Auxiliary GAN for MNIST Handwritten Digits
Generative Adversarial Networks (GANs) are an architecture for training generative models, such as deep convolutional neural networks, to generate images. A GAN requires a discriminator model, which classifies whether a given image is real or fake, and a generator model, which transforms an input into an image of pixel values. In this article we use the MNIST handwritten digits dataset, a standard image dataset that lets us focus on the GAN architecture.
The Auxiliary Classifier GAN (AC-GAN) is an extension of the Conditional GAN in which the discriminator predicts the class label of a given image instead of receiving it as input. This further stabilizes training and leads to the generation of high-quality images. The discriminator therefore predicts both whether the given image is real or fake and what its class label is. The generator receives both a point in the latent space and a class label as input and produces an image of that class.
Discriminator model
The discriminator model includes an auxiliary classifier that predicts the class of the image. The two can be considered separate models that share all layers except their final output layers. In our implementation we build a single convolutional neural network with two outputs. The first output predicts whether the image is real or fake, as a probability from the sigmoid activation function. The second gives the probability of the image belonging to each class via the softmax activation function, which is used for multiclass classification.
The discriminator seeks to maximize the probability of correctly identifying whether an image is real or fake and of correctly predicting its class label. The generator seeks to minimize the discriminator's ability to separate real from fake images while maximizing the discriminator's ability to predict the class label of the generated images. The appendix of the AC-GAN paper provides recommendations for the model configuration. The table summarizes the suggestions from the paper.
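For reference, the AC-GAN paper states these two goals as log-likelihoods. $L_S$ is the log-likelihood of the correct source (real or fake). $L_C$ is the log-likelihood of the correct class $c$. The notation below follows the paper.

```latex
\begin{aligned}
L_S &= \mathbb{E}\left[\log P(S = \mathrm{real} \mid X_{\mathrm{real}})\right] + \mathbb{E}\left[\log P(S = \mathrm{fake} \mid X_{\mathrm{fake}})\right] \\
L_C &= \mathbb{E}\left[\log P(C = c \mid X_{\mathrm{real}})\right] + \mathbb{E}\left[\log P(C = c \mid X_{\mathrm{fake}})\right] \\
\text{discriminator: } &\max \; L_S + L_C \qquad \text{generator: } \max \; L_C - L_S
\end{aligned}
```

The binary and sparse categorical cross-entropy losses used in the code are the negatives of these log-likelihood terms, so minimizing the losses corresponds to maximizing the likelihoods. In practice the generator is trained by labelling its images as real rather than by literally minimizing $L_S$. That is a common GAN training choice and is not spelled out in the equation above.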
The discriminator is defined according to the DCGAN architecture, using batch normalization, LeakyReLU and dropout. Unlike a traditional discriminator, however, the AC-GAN discriminator has two output layers. The model is trained with binary cross-entropy for the real/fake output and categorical cross-entropy for the class output. Because we want to compare integer class labels directly rather than one-hot encodings of them, we use sparse categorical cross-entropy instead of categorical cross-entropy. The model is fit with the Adam optimizer using the small learning rate recommended for DCGANs (0.0002, with beta_1 = 0.5).
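To see why sparse categorical cross-entropy is a drop-in replacement, the sketch below computes both losses by hand in NumPy on random softmax outputs. These are plain re-implementations for illustration only, not Keras's own code; for example, Keras also clips probabilities for numerical stability. With integer labels and their one-hot encodings, the two losses agree.

```python
import numpy as np

def categorical_crossentropy(one_hot, probs):
    # mean over samples of -sum_k y_k * log(p_k), with one-hot targets y
    return float(-np.mean(np.sum(one_hot * np.log(probs), axis=1)))

def sparse_categorical_crossentropy(labels, probs):
    # same loss, but takes log(p) of the true class directly from integer labels
    return float(-np.mean(np.log(probs[np.arange(len(labels)), labels])))

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 10))
# softmax over the 10 classes, as produced by the discriminator's class output
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
labels = np.array([3, 0, 7, 9])   # integer class labels
one_hot = np.eye(10)[labels]      # the equivalent one-hot encoding

print(categorical_crossentropy(one_hot, probs))
print(sparse_categorical_crossentropy(labels, probs))
```

The sparse version avoids building the one-hot matrix, which is why MNIST's integer labels can be passed to the model unchanged.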
```python
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, Dense, Conv2D, Flatten, Dropout,
                                     LeakyReLU, BatchNormalization)
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.optimizers import Adam

# define the discriminator model with two outputs
def define_discriminator(input_shape=(28, 28, 1), n_classes=10):
    init = RandomNormal(stddev=0.02)
    input_img = Input(shape=input_shape)
    # downsample to 14x14
    model = Conv2D(32, (3, 3), strides=(2, 2), kernel_initializer=init, padding='same')(input_img)
    model = LeakyReLU(0.2)(model)
    model = Dropout(0.5)(model)
    # normal (no downsampling)
    model = Conv2D(64, (3, 3), kernel_initializer=init, padding='same')(model)
    model = BatchNormalization()(model)
    model = LeakyReLU(0.2)(model)
    model = Dropout(0.5)(model)
    # downsample to 7x7
    model = Conv2D(128, (3, 3), strides=(2, 2), kernel_initializer=init, padding='same')(model)
    model = BatchNormalization()(model)
    model = LeakyReLU(0.2)(model)
    model = Dropout(0.5)(model)
    # normal (no downsampling)
    model = Conv2D(256, (3, 3), kernel_initializer=init, padding='same')(model)
    model = BatchNormalization()(model)
    model = LeakyReLU(0.2)(model)
    model = Dropout(0.5)(model)
    # flatten feature maps
    model = Flatten()(model)
    # output 1: probability that the image is real
    out_layer1 = Dense(1, activation='sigmoid')(model)
    # output 2: probability of each class
    out_layer2 = Dense(n_classes, activation='softmax')(model)
    disc_model = Model(input_img, [out_layer1, out_layer2])
    # compile with one loss per output
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    disc_model.compile(loss=['binary_crossentropy', 'sparse_categorical_crossentropy'], optimizer=opt)
    return disc_model
```
Plotting our discriminator model.
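The shapes annotated in the comments ("downsample to 14x14", "downsample to 7x7") can be checked without building the model. The sketch below relies on the rule that a Keras `Conv2D` with `padding='same'` produces an output of size ceil(input / stride). It walks through the four convolutional layers of `define_discriminator`. The helper names are hypothetical and exist only for this illustration.

```python
import math

def same_conv_output(size, stride):
    # Conv2D with padding='same': output size = ceil(input size / stride)
    return math.ceil(size / stride)

def trace_discriminator(size=28):
    # (filters, stride) of the four Conv2D layers in define_discriminator
    layers = [(32, 2), (64, 1), (128, 2), (256, 1)]
    shapes = []
    for filters, stride in layers:
        size = same_conv_output(size, stride)
        shapes.append((size, size, filters))
    return shapes

shapes = trace_discriminator()
for s in shapes:
    print(s)
h, w, c = shapes[-1]
flat_size = h * w * c   # number of features fed to both output layers
print('Flatten ->', flat_size)
```

Both output layers read from the same flattened feature vector, which is what "sharing the same model weights" amounts to in this implementation.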
Generator model
The generator takes a point in the latent space and a class label and generates an image. The class label is provided to the generator as an additional channel. An embedding layer maps the label to a vector, and a fully connected layer with a linear activation interprets that vector as a single 7 x 7 feature map. The latent point is projected to multiple 7 x 7 feature maps that provide the basis for a low-resolution version of the output image. After the two are concatenated, the feature maps are upsampled by transposed convolutional layers. The final output is a 28 x 28 grayscale image with pixel values in the range [-1, 1], produced by a tanh activation.
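```python
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, Dense, Embedding, Reshape, Concatenate,
                                     Conv2DTranspose, BatchNormalization, Activation)
from tensorflow.keras.initializers import RandomNormal

# define the standalone generator, which takes a latent point and a class label as input
def define_generator(latent_dim, n_classes=10):
    # DCGAN-style weight initialization, as in the discriminator
    init = RandomNormal(stddev=0.02)
    # label input: embed the class label and project it to one 7x7 feature map
    label_input = Input(shape=(1,))
    li = Embedding(n_classes, 50)(label_input)
    nodes = 7 * 7
    li = Dense(nodes, kernel_initializer=init)(li)
    li = Reshape((7, 7, 1))(li)
    # latent input: project the latent point to 384 feature maps of 7x7
    image_input = Input(shape=(latent_dim,))
    m_nodes = 384 * 7 * 7
    gen = Dense(m_nodes, kernel_initializer=init)(image_input)
    gen = Activation('relu')(gen)
    gen = Reshape((7, 7, 384))(gen)
    # merge the inputs: the label becomes an extra channel (7x7x385)
    merge = Concatenate()([gen, li])
    # upsample to 14x14
    gen = Conv2DTranspose(192, (5, 5), strides=(2, 2), padding='same', kernel_initializer=init)(merge)
    gen = BatchNormalization()(gen)
    gen = Activation('relu')(gen)
    # upsample to 28x28 with a single channel
    gen = Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', kernel_initializer=init)(gen)
    # tanh keeps pixel values in [-1, 1]
    out_layer = Activation('tanh')(gen)
    model = Model([image_input, label_input], out_layer)
    return model
```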
Plotting the generator model
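The idea of passing the label "as an additional channel" can be traced with plain NumPy arrays. In the sketch below, random matrices are hypothetical stand-ins for the trained `Embedding` and `Dense` weights, and biases are omitted. It follows the two branches of `define_generator` up to the concatenation. It then applies the rule that a `Conv2DTranspose` with `padding='same'` multiplies the spatial size by its stride.

```python
import numpy as np

rng = np.random.default_rng(0)
n_classes, embed_dim, latent_dim, batch = 10, 50, 100, 4

# hypothetical random weights standing in for the trained layers (no biases)
embedding_table = rng.normal(size=(n_classes, embed_dim))   # Embedding(n_classes, 50)
label_dense = rng.normal(size=(embed_dim, 7 * 7))           # Dense(7*7), linear activation
latent_dense = rng.normal(size=(latent_dim, 384 * 7 * 7))   # Dense(384*7*7)

labels = rng.integers(0, n_classes, size=batch)
z = rng.normal(size=(batch, latent_dim))

# label branch: lookup -> linear projection -> one 7x7 feature map
li = (embedding_table[labels] @ label_dense).reshape(batch, 7, 7, 1)
# latent branch: projection -> ReLU -> 384 feature maps of 7x7
gen = np.maximum(z @ latent_dense, 0).reshape(batch, 7, 7, 384)
# concatenating along the channel axis turns the label into channel 385
merged = np.concatenate([gen, li], axis=-1)
print(merged.shape)

# Conv2DTranspose with padding='same' multiplies the spatial size by the stride
size = 7
for stride in (2, 2):
    size *= stride
print(size)
```

Because the label channel is learned rather than a fixed one-hot plane, the generator can decide how strongly each class should influence every spatial position.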
AC-GAN model
When the discriminator is good at detecting fake samples, the generator is updated more. When the discriminator is poor at detecting fake samples, the generator's weights are updated less. To achieve this, we define a new composite GAN model. It takes a latent point and a class label as input, uses the generator to produce an image, and feeds that image to the discriminator. The discriminator's output is then used to update the generator's weights. The discriminator is trained on real and fake samples as a standalone model. The generator, in contrast, is updated only through the composite model, according to how well the discriminator detects its fake samples. Therefore the discriminator inside the composite GAN model is marked as not trainable, so that training the composite model changes only the generator's weights.
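The claim that the generator is "updated more" when the discriminator rejects its images can be illustrated with the binary cross-entropy the composite model uses. Generated images are labelled as real (target 1), so the loss is -log p, where p is the discriminator's probability that the fake image is real. Its derivative is -1/p. The NumPy sketch below is an illustration only. It shows the gradient with respect to the discriminator's output. The actual weight update also depends on backpropagation through the discriminator and on Adam's per-parameter scaling, so this is a tendency rather than a guarantee.

```python
import numpy as np

def generator_bce(p_real):
    # binary cross-entropy with target 1 for a generated image: -log D(G(z))
    return -np.log(p_real)

def generator_bce_grad(p_real):
    # derivative of -log(p) with respect to the discriminator output p
    return -1.0 / p_real

for p in (0.01, 0.5, 0.9):
    print(f"D(G(z))={p:.2f}  loss={generator_bce(p):.3f}  |dloss/dp|={abs(generator_bce_grad(p)):.2f}")
```

A confident rejection (p close to 0) yields a large loss and a steep gradient, while a fooled discriminator (p close to 1) yields a small one.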
```python
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# define the composite GAN model used to update the generator
def define_gan(g_model, d_model):
    # freeze the discriminator's weights inside the composite model
    d_model.trainable = False
    # the generator's output image is the discriminator's input
    gan_output = d_model(g_model.output)
    # the GAN takes noise and a label as input and outputs both discriminator predictions
    model = Model(g_model.input, gan_output)
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss=['binary_crossentropy', 'sparse_categorical_crossentropy'], optimizer=opt)
    return model
```
Plotting the GAN model.
Loading Dataset
We load the MNIST dataset from the Keras datasets module. The returned images are 2D arrays (28 x 28). The discriminator's convolutional layers, however, expect a 3D input per image (height, width, channels), so we add a channel dimension. We also scale the pixel values from [0, 255] to [-1, 1] to match the range of the generator's tanh output.
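```python
from numpy import expand_dims
from tensorflow.keras.datasets.mnist import load_data

# load and prepare the MNIST training images and labels for the AC-GAN
def load_real_samples():
    (trainX, trainY), (testX, testY) = load_data()
    # add a channel dimension: (60000, 28, 28) -> (60000, 28, 28, 1)
    X = expand_dims(trainX, axis=-1)
    X = X.astype('float32')
    # scale from [0, 255] to [-1, 1]
    X = (X - 127.5) / 127.5
    return [X, trainY]
```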
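The channel axis and the scaling can be sanity-checked on synthetic data without downloading MNIST. The sketch below applies the same two steps as `load_real_samples` to a small hand-made uint8 array. It also includes a hypothetical inverse, `unscale_images`, that maps values in [-1, 1] back to [0, 255]. That inverse is not part of the article's code but is handy for saving generated images.

```python
import numpy as np

def scale_images(images):
    # add a channel axis: (n, 28, 28) -> (n, 28, 28, 1), as Conv2D expects
    X = np.expand_dims(images, axis=-1).astype('float32')
    # map [0, 255] linearly onto [-1, 1], the range of the generator's tanh output
    return (X - 127.5) / 127.5

def unscale_images(X):
    # hypothetical inverse: map [-1, 1] back to [0, 255], rounding before the cast
    return np.clip(np.rint((X + 1.0) * 127.5), 0, 255).astype('uint8')

# synthetic stand-in for MNIST pixels: two 28x28 uint8 images
raw = np.zeros((2, 28, 28), dtype=np.uint8)
raw[0] = 255                  # an all-white image
raw[1, 10:18, 10:18] = 128    # a grey square on black

X = scale_images(raw)
print(X.shape, X.min(), X.max())
```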
To select a random batch of real samples from the dataset, we use the following generate_real_samples() function.
```python
import numpy as np
from numpy.random import randint

# select a random batch of real images with their class labels
def generate_real_samples(dataset, n_samples):
    images, labels = dataset
    # choose random indices
    ix = randint(0, images.shape[0], n_samples)
    # get the selected images and labels
    X, labels = images[ix], labels[ix]
    # real/fake target: 1 means real
    y = np.ones((n_samples, 1))
    return [X, labels], y
```
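We also need a function that generates points in the latent space, together with random class labels, as input to the generator.
```python
from numpy.random import randn, randint

# generate points in the latent space and random class labels as generator input
def generate_latent_points(latent_dim, n_samples, n_classes=10):
    # draw standard-normal values and reshape them into one latent vector per sample
    x_input = randn(latent_dim * n_samples)
    z_input = x_input.reshape(n_samples, latent_dim)
    # generate random class labels
    labels = randint(0, n_classes, n_samples)
    return [z_input, labels]
```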
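If reproducible batches are useful, for example to compare checkpoints on the same inputs, the same sampling can be written with NumPy's `Generator` API and a seed. The sketch below is a hypothetical variant, `generate_latent_points_seeded`, and not the article's function. It draws the latent matrix directly in the shape `(n_samples, latent_dim)`, which is equivalent in distribution to drawing a flat vector and reshaping it.

```python
import numpy as np

def generate_latent_points_seeded(latent_dim, n_samples, n_classes=10, seed=None):
    # hypothetical seeded variant of generate_latent_points
    rng = np.random.default_rng(seed)
    # one standard-normal latent vector per sample
    z_input = rng.standard_normal((n_samples, latent_dim))
    # one integer class label per sample, uniform over [0, n_classes)
    labels = rng.integers(0, n_classes, size=n_samples)
    return [z_input, labels]

z, labels = generate_latent_points_seeded(100, 32, seed=42)
print(z.shape, labels.shape, labels.min(), labels.max())
```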
The generate_fake_samples() function passes the output of generate_latent_points() to the generator to create fake images, and labels them as fake (0).
```python
import numpy as np

# use the generator to create a batch of fake images with their class labels
def generate_fake_samples(latent_dim, n_samples, gen_model):
    # generate latent points and class labels
    X, labels = generate_latent_points(latent_dim, n_samples)
    # generate images
    images = gen_model.predict([X, labels], verbose=0)
    # real/fake target: 0 means fake
    y = np.zeros((n_samples, 1))
    return [images, labels], y
```
Model Training
We can now train the GAN for 100 epochs. In each step the discriminator is updated on half a batch of real images and half a batch of fake images. The generator is then updated via the composite GAN model. Every 10 epochs the generator model is saved via summarize_perf(). The default batch size is 64; with 60,000 training images this gives 937 batches of real and fake samples per epoch.
```python
import numpy as np

# train the generator and discriminator
def train_gan(gen_model, disc_model, gan_model, dataset, latent_dim, n_epochs=100, n_batch=64):
    # number of batches per epoch (n_batch is the batch size)
    batches_per_epoch = int(dataset[0].shape[0] / n_batch)
    # total number of training steps
    n_steps = batches_per_epoch * n_epochs
    # half batch: the discriminator sees half real and half fake images
    half_batch = int(n_batch / 2)
    for i in range(n_steps):
        print('step', i + 1, 'of', n_steps)
        # update the discriminator on real samples
        [X_real, labels_real], y_real = generate_real_samples(dataset, half_batch)
        disc_model.train_on_batch(X_real, [y_real, labels_real])
        # update the discriminator on fake samples
        [X_fake, labels_fake], y_fake = generate_fake_samples(latent_dim, half_batch, gen_model)
        disc_model.train_on_batch(X_fake, [y_fake, labels_fake])
        # latent points and labels as input for the composite GAN model
        [z_inputs, z_labels] = generate_latent_points(latent_dim, n_batch)
        # inverted targets: label the generated images as real (1)
        y_gan = np.ones((n_batch, 1))
        # update the generator via the composite model
        gan_model.train_on_batch([z_inputs, z_labels], [y_gan, z_labels])
        # every 10 epochs, summarize performance and save the generator
        # (summarize_perf is assumed to be defined elsewhere; it is not shown here)
        if (i + 1) % (batches_per_epoch * 10) == 0:
            summarize_perf(i, gen_model, latent_dim)
```
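The step bookkeeping in train_gan can be worked out by hand for the default arguments. `generate_real_samples` draws random indices with replacement, so an "epoch" here is a fixed number of steps rather than a guaranteed pass over every image. The variable names below mirror the function's locals and are used only for illustration.

```python
n_images = 60_000   # size of the MNIST training set
n_batch = 64        # default batch size of train_gan
n_epochs = 100      # default number of epochs

batches_per_epoch = int(n_images / n_batch)   # 60000 / 64 = 937.5, truncated
half_batch = int(n_batch / 2)                  # real (and fake) images per discriminator update
n_steps = batches_per_epoch * n_epochs          # total generator updates
checkpoint_every = batches_per_epoch * 10       # steps between summarize_perf calls

print(batches_per_epoch, half_batch, n_steps, checkpoint_every)
```

With a batch size of 128 instead, the same arithmetic gives 468 batches per epoch and 46,800 steps.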
The dimension of the latent space is set to 100. With the default batch size of 64, the composite GAN model is trained for 100 × 937 = 93,700 steps.
```python
# size of the latent space
latent_dim = 100
# define the discriminator
disc = define_discriminator()
# define the generator
gen = define_generator(latent_dim)
# define the composite GAN model
gan = define_gan(gen, disc)
# load the dataset
dataset = load_real_samples()
# train the GAN
train_gan(gen, disc, gan, dataset, latent_dim)
```
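We can create a function to plot generated images in a grid of n x n images.
```python
import matplotlib.pyplot as plt

# plot the first n*n generated images in an n x n grid
def save_plot(X, n):
    for i in range(n * n):
        # create the subplot first, then hide its axes
        plt.subplot(n, n, i + 1)
        plt.axis('off')
        # plot the single channel as an inverted grayscale image
        plt.imshow(X[i, :, :, 0], cmap='gray_r')
    plt.show()
```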
Plotting 100 randomly generated images in a 10x10 grid.
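Because the generator is conditioned on the class label, the grid does not have to show random digits. One option, which is an assumption and not taken from the article, is to pass fixed labels so that each column of a 10x10 grid shows a single digit. Subplot `i` sits at row `i // 10` and column `i % 10`, so tiling the labels 0 to 9 achieves this. The helper `grid_labels` is hypothetical, and the generator call is left as a comment because it needs the trained model.

```python
import numpy as np

def grid_labels(n_rows=10, n_classes=10):
    # each row holds one image of every class 0..n_classes-1, so column j shows digit j
    return np.tile(np.arange(n_classes), n_rows)

labels = grid_labels()
print(labels[:12])
# With a trained generator (not run here), one could then call, for example:
# z = np.random.randn(100, 100)        # 100 latent points of dimension 100
# X = gen.predict([z, labels])
# save_plot(X, 10)
```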
Resources
Conditional Image Synthesis With Auxiliary Classifier GANs, Reviewer Comments