Conditional GAN for MNIST Handwritten Digits
A Generative Adversarial Network (GAN) consists of two models. The generator produces new, plausible fake samples that look as if they were drawn from an existing distribution of samples. The discriminator classifies a given sample as real or fake. The two models are trained against each other. The discriminator's weights are updated so that it better separates real from fake samples. The generator's weights are updated according to how well its samples fool the discriminator.
A conditional GAN conditions the generator on extra information, here a class label. This lets it generate targeted images of a given type, whereas a standard GAN generates a random image from the domain. In this article, a conditional generative adversarial network is developed for the targeted generation of MNIST handwritten digits.
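For reference, the conditional GAN paper by Mirza and Osindero (2014) writes this idea as the usual GAN minimax objective with the class label y fed to both networks. The discriminator scores D(x | y) and the generator produces G(z | y):

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x \mid y)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z \mid y))\big)\big]
```

In practice, the training loop later in this article does not minimize the second term directly. It trains the generator by labelling its samples as real, which amounts to maximizing log D(G(z | y)). This is a common variant that gives stronger gradients early in training.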
The Modified National Institute of Standards and Technology (MNIST) dataset contains 70,000 grayscale 28x28 pixel images of handwritten digits from 0 to 9. Keras provides it through the mnist.load_data() function. This function returns the training set (60,000 images) and the test set (10,000 images) as two (images, labels) tuples.
import tensorflow
from tensorflow.keras.datasets.mnist import load_data

(trainX, trainY), (testX, testY) = load_data()
We can plot some of the training images with matplotlib's imshow() function. Here we plot the first 25 images in a 5x5 grid.
import matplotlib.pyplot as pyplot

for i in range(25):
    # define subplot
    pyplot.subplot(5, 5, 1 + i)
    # turn off axis
    pyplot.axis('off')
    # plot raw pixel data
    pyplot.imshow(trainX[i], cmap='gray_r')
pyplot.show()
Discriminator model
The discriminator model takes an image and decides whether it is real (from the dataset) or fake (generated). It has two inputs: the image, with shape 28x28x1, and an integer class label. The class label is passed through an Embedding layer of size 50, so each class label maps to a different 50-element vector. This vector is passed to a fully connected (Dense) layer with 28 x 28 = 784 activations. These activations are reshaped into a single 28x28 activation map and concatenated with the input image as an extra channel. The next layer therefore receives a two-channel 28x28 input.
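To make the label-as-channel idea concrete, here is a minimal NumPy sketch of the data flow only. It is not the Keras model. The weights are hypothetical random stand-ins for the trained Embedding(10, 50) and Dense(784) layers. The sketch also drops the length-1 axis that Keras keeps after embedding a label of shape (1,).

```python
import numpy as np

# Hypothetical, randomly initialised weights standing in for the
# discriminator's Embedding(10, 50) and Dense(28 * 28) layers.
rng = np.random.default_rng(0)
n_classes, embed_dim, height, width = 10, 50, 28, 28
embedding_table = rng.normal(size=(n_classes, embed_dim))      # one 50-vector per class
dense_w = rng.normal(size=(embed_dim, height * width)) * 0.01  # linear map 50 -> 784
dense_b = np.zeros(height * width)


def label_channel(labels):
    """Map integer labels to one extra 28x28 channel per image."""
    vectors = embedding_table[labels]          # (batch, 50) embedding lookup
    activations = vectors @ dense_w + dense_b  # (batch, 784) linear Dense layer
    return activations.reshape(-1, height, width, 1)


def condition_images(images, labels):
    """Concatenate the label channel onto (batch, 28, 28, 1) images."""
    return np.concatenate([images, label_channel(labels)], axis=-1)


images = rng.uniform(-1.0, 1.0, size=(4, height, width, 1))
labels = np.array([0, 3, 7, 9])
merged = condition_images(images, labels)
print(merged.shape)  # (4, 28, 28, 2)
```

Because the label map is produced by a learned layer rather than a fixed encoding, the discriminator can learn whatever spatial pattern makes the label most useful alongside the pixels.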
from tensorflow.keras.layers import (Concatenate, Conv2D, Dense, Dropout, Embedding,
                                     Flatten, Input, LeakyReLU, Reshape)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# define standalone discriminator
def define_discriminator(input_shape=(28, 28, 1), n_classes=10):
    in_labels = Input(shape=(1,))  # label input
    em = Embedding(n_classes, 50)(in_labels)  # embedding for categorical input
    # scale up to the image dimensions with linear activations, then reshape
    # to one additional 28x28 channel
    d1 = Dense(input_shape[0] * input_shape[1])(em)
    d1 = Reshape((input_shape[0], input_shape[1], 1))(d1)
    image_input = Input(shape=input_shape)  # image input
    merge = Concatenate()([image_input, d1])  # concatenate label as a channel
    # downsample to 14x14
    fe = Conv2D(128, (3, 3), strides=(2, 2), padding='same')(merge)
    fe = LeakyReLU(0.2)(fe)
    # downsample to 7x7; this layer takes the previous output (fe), not merge
    fe = Conv2D(128, (3, 3), strides=(2, 2), padding='same')(fe)
    fe = LeakyReLU(0.2)(fe)
    fe = Flatten()(fe)  # flatten feature maps
    fe = Dropout(0.4)(fe)
    out_layer = Dense(1, activation='sigmoid')(fe)  # probability the image is real
    model = Model([image_input, in_labels], out_layer)  # define model
    # compile model
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model
Plotting the discriminator model.
Summary of the discriminator model.
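As a sanity check on the summary, the output sizes and parameter counts can be worked out by hand. The sketch below assumes that the second Conv2D is applied to the output of the first, which is what the "downsample" comments describe. Under that assumption, the total should match what model.summary() reports. Layers without weights (Reshape, Concatenate, LeakyReLU, Flatten, Dropout) contribute no parameters.

```python
import math


def conv_same_output(size, stride):
    """Spatial size after a Conv2D with padding='same'."""
    return math.ceil(size / stride)


def conv_params(kernel, in_channels, filters):
    """Weights plus biases of a 2D convolution with a square kernel."""
    return kernel * kernel * in_channels * filters + filters


def dense_params(inputs, units):
    """Weights plus biases of a fully connected layer."""
    return inputs * units + units


size = 28
size = conv_same_output(size, 2)   # first Conv2D: 28 -> 14
size = conv_same_output(size, 2)   # second Conv2D: 14 -> 7
flat_features = size * size * 128  # length of the Flatten() output

params = {
    "embedding": 10 * 50,
    "label_dense": dense_params(50, 28 * 28),
    "conv_1": conv_params(3, 2, 128),  # 2 input channels after concatenation
    "conv_2": conv_params(3, 128, 128),
    "output_dense": dense_params(flat_features, 1),
}
print(size, flat_features)   # 7 6272
print(sum(params.values()))  # 196773
```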
Generator model
The generator model takes a point from the latent space as input and outputs a plausible image of a handwritten digit. The latent space is a vector space of values drawn from a standard Gaussian distribution. The generator also takes a class label, which makes the point in the latent space conditional on that label.

The class label passes through an Embedding layer that maps it to a 50-element vector. This vector then goes through a fully connected layer with 7 x 7 = 49 units, whose output is reshaped into a single 7x7 feature map. Meanwhile, the latent point passes through a Dense layer and is reshaped into 128 feature maps of size 7x7. The label map is added to these as one more channel, giving 129 feature maps. These are upsampled to 14x14 and then 28x28 by Conv2DTranspose layers with LeakyReLU activations.
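The shape bookkeeping of the generator's two branches can be sketched in NumPy. The weights here are hypothetical random stand-ins and only the shapes matter. The last part records that a Conv2DTranspose with strides=(2, 2) and padding='same' doubles the spatial size.

```python
import numpy as np

rng = np.random.default_rng(1)
latent_dim, n_classes, batch = 100, 10, 4

# Hypothetical random weights standing in for the generator's layers.
w_latent = rng.normal(size=(latent_dim, 128 * 7 * 7)) * 0.01  # Dense(128*7*7)
embedding_table = rng.normal(size=(n_classes, 50))             # Embedding(10, 50)
w_label = rng.normal(size=(50, 7 * 7)) * 0.1                   # Dense(7*7)


def leaky_relu(x, alpha=0.2):
    return np.where(x > 0, x, alpha * x)


z = rng.standard_normal((batch, latent_dim))  # points in the latent space
labels = rng.integers(0, n_classes, batch)     # class labels to condition on

image_branch = leaky_relu(z @ w_latent).reshape(batch, 7, 7, 128)
label_branch = (embedding_table[labels] @ w_label).reshape(batch, 7, 7, 1)
merged = np.concatenate([image_branch, label_branch], axis=-1)
print(merged.shape)  # (4, 7, 7, 129)

# Two stride-2 'same' transposed convolutions: 7 -> 14 -> 28; the final
# 'same' Conv2D keeps 28x28.
sizes = [7]
for _ in range(2):
    sizes.append(sizes[-1] * 2)
print(sizes)  # [7, 14, 28]
```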
from tensorflow.keras.layers import (Concatenate, Conv2D, Conv2DTranspose, Dense,
                                     Embedding, Input, LeakyReLU, Reshape)
from tensorflow.keras.models import Model

# define standalone generator model
def define_generator(latent_dim, n_classes=10):
    label_input = Input(shape=(1,))  # label input
    em = Embedding(n_classes, 50)(label_input)  # embedding layer
    nodes = 7 * 7
    em = Dense(nodes)(em)  # one 7x7 map for the label
    em = Reshape((7, 7, 1))(em)
    latent_input = Input(shape=(latent_dim,))  # latent point input
    nodes = 128 * 7 * 7
    d1 = Dense(nodes)(latent_input)  # foundation for 128 maps of 7x7
    d1 = LeakyReLU(0.2)(d1)
    d1 = Reshape((7, 7, 128))(d1)
    merge = Concatenate()([d1, em])  # merge to 7x7x129
    # upsample to 14x14
    gen = Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same')(merge)
    gen = LeakyReLU(0.2)(gen)
    # upsample to 28x28
    gen = Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same')(gen)
    gen = LeakyReLU(0.2)(gen)
    # output layer: one channel with values in [-1, 1]
    out_layer = Conv2D(1, (7, 7), activation='tanh', padding='same')(gen)
    model = Model([latent_input, label_input], out_layer)  # define model
    return model
Summary of the generator model.
Plotting the generator model.
GAN model
When the discriminator is good at detecting fake samples, the generator receives large updates. When the discriminator is poor at detecting them, the generator's weights are updated less.

To train the generator, we define a new composite GAN model. It takes a latent point and a class label and uses the generator to produce an image. It then feeds that image, together with the same label, to the discriminator. The discriminator's output is used to update the generator's weights, so the composite model predicts whether its generated input is real or fake. Inside the composite model, the discriminator is marked as non-trainable, because it is updated separately, as a standalone model, on real and fake images.
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model):
    # make the discriminator's weights non-trainable inside the composite model
    d_model.trainable = False
    # get the noise and label inputs from the generator
    gen_noise, gen_label = g_model.input
    # get the image output from the generator
    gen_output = g_model.output
    # connect the generated image and the label input as inputs to the discriminator
    gan_output = d_model([gen_output, gen_label])
    # the GAN model takes noise and a label and outputs a real/fake classification
    model = Model([gen_noise, gen_label], gan_output)
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model
Plotting our GAN model.
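The intuition about "updated more" and "updated less" can be made concrete with the loss the composite model optimizes. When the generator is trained through the GAN model, its samples are labelled 1 ("real"), so the per-sample binary cross-entropy is -log D(G(z | y)). For a sigmoid output, its gradient with respect to the discriminator's logit is D(G(z | y)) - 1. This is a simplified, illustrative view in plain Python, not part of the training code. It ignores how that gradient is then backpropagated through the discriminator into the generator.

```python
import math


def generator_loss(d_prob_fake):
    """Binary cross-entropy with target 1 ("real") for one generated sample.

    d_prob_fake is the discriminator's sigmoid output D(G(z | y)).
    """
    return -math.log(d_prob_fake)


def grad_wrt_logit(d_prob_fake):
    """d(loss)/d(logit) for a sigmoid output with BCE and target 1: p - 1."""
    return d_prob_fake - 1.0


# A confidently rejected fake (p near 0) gives a large loss and gradient;
# a fake that already fools the discriminator (p near 1) gives small ones.
for p in (0.05, 0.5, 0.95):
    print(f"D(G(z))={p:.2f}  loss={generator_loss(p):.3f}  grad={grad_wrt_logit(p):+.2f}")
```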
Loading Dataset
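We load the MNIST training set with the Keras load_data() function. The images are returned as an array of shape (60000, 28, 28). However, the discriminator's image input expects each image to have a channel dimension (28x28x1), so we add one with expand_dims(). We also convert the pixel values to floats and scale them from [0, 255] to [-1, 1], which matches the range of the generator's tanh output.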
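from numpy import expand_dims
from tensorflow.keras.datasets.mnist import load_data

# load MNIST images
def load_real_samples():
    (trainX, trainY), (testX, testY) = load_data()
    # expand to 3D by adding a channel dimension
    X = expand_dims(trainX, axis=-1)
    # convert from unsigned ints to floats
    X = X.astype('float32')
    # scale from [0, 255] to [-1, 1]
    X = (X - 127.5) / 127.5
    return [X, trainY]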
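The same preprocessing can be checked on a tiny synthetic array instead of real MNIST data. This NumPy sketch shows the added channel axis, the [-1, 1] range, and the inverse mapping. The inverse is useful before displaying generated images, which leave the generator's tanh in [-1, 1].

```python
import numpy as np

# A tiny stand-in for a batch of uint8 MNIST images: shape (2, 28, 28).
raw = np.zeros((2, 28, 28), dtype=np.uint8)
raw[0, 0, 0], raw[1, 0, 0] = 255, 128

X = np.expand_dims(raw, axis=-1).astype("float32")  # (2, 28, 28, 1)
X = (X - 127.5) / 127.5                             # [0, 255] -> [-1, 1]
print(X.shape, X.min(), X.max())

# Inverse mapping from [-1, 1] back to [0, 255].
restored = (X + 1.0) * 127.5
```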
To select a random batch of real samples from the dataset, we use the following generate_real_samples() function. It returns the images, their class labels, and a target of 1 ("real") for each sample.
from numpy import ones
from numpy.random import randint

# select real samples
def generate_real_samples(dataset, n_samples):
    # split into images and labels
    images, labels = dataset
    # choose random instances
    ix = randint(0, images.shape[0], n_samples)
    # select images and labels
    X, labels = images[ix], labels[ix]
    # generate 'real' targets (1) for the discriminator
    y = ones((n_samples, 1))
    return [X, labels], y
We also create a function that generates points in the latent space, together with random class labels, as input for the generator.
from numpy.random import randint, randn

# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples, n_classes=10):
    # generate points in the latent space
    x_input = randn(latent_dim * n_samples)
    # reshape into a batch of inputs for the network
    z_input = x_input.reshape(n_samples, latent_dim)
    # generate random class labels
    labels = randint(0, n_classes, n_samples)
    return [z_input, labels]
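The generate_fake_samples() function passes the output of generate_latent_points() to the generator model to produce fake images. It returns these images together with their class labels and a target of 0 ("fake") for each sample.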
from numpy import zeros

# use the generator to generate n fake samples, with class labels
def generate_fake_samples(latent_dim, n_samples, generator):
    # generate points in latent space
    z_input, labels = generate_latent_points(latent_dim, n_samples)
    # predict outputs (verbose=0 suppresses the per-call progress bar)
    images = generator.predict([z_input, labels], verbose=0)
    # generate 'fake' targets (0) for the discriminator
    y = zeros((n_samples, 1))
    return [images, labels], y
Model Training
We can now train the GAN model for 100 epochs. For each batch, the discriminator is updated once on a half batch of real images and once on a half batch of fake images. The generator is then updated via the composite GAN model, using a full batch of latent points whose targets are set to 1 ("real"). We save the generator model after every 10 epochs. With a batch size of 128 and 60,000 training images, each epoch contains about 468 batches (60,000 / 128, rounded down). Each batch gives the discriminator 64 real and 64 fake samples.
from numpy import ones

def train_gan(gen_model, disc_model, gan_model, latent_dim, dataset, n_epochs=100, n_batch=128):
    batches_per_epoch = int(dataset[0].shape[0] / n_batch)  # n_batch is the batch size
    half_batch = int(n_batch / 2)
    for i in range(n_epochs):
        print('epoch', i + 1)
        for j in range(batches_per_epoch):
            # generate real and fake samples (half a batch each)
            [X_real, labels_real], y_real = generate_real_samples(dataset, half_batch)
            [X_fake, labels_fake], y_fake = generate_fake_samples(latent_dim, half_batch, gen_model)
            # train the discriminator on real and fake samples
            loss_real, acc_real = disc_model.train_on_batch([X_real, labels_real], y_real)
            loss_fake, acc_fake = disc_model.train_on_batch([X_fake, labels_fake], y_fake)
            # prepare points in the latent space as input to the generator
            [z_input, labels_input] = generate_latent_points(latent_dim, n_batch)
            # inverted labels: the generator wants its fakes classified as real (1)
            y_gan = ones((n_batch, 1))
            # update the generator via the discriminator's error
            g_loss = gan_model.train_on_batch([z_input, labels_input], y_gan)
        # save the generator after every 10th epoch (c_gan_010.h5, c_gan_020.h5, ...)
        if (i + 1) % 10 == 0:
            gen_model.save('c_gan_%03d.h5' % (i + 1))
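The bookkeeping in train_gan() can be checked with a few lines of plain Python and NumPy, independent of TensorFlow. The check covers how many batches fit in an epoch and the size of each half batch. It also shows which target vector goes with each update: 1 for real images, 0 for generated images, and 1 again when the generator is trained through the composite model. The 1-based epoch numbering in save_epochs is a presentation choice for this sketch.

```python
import numpy as np

n_images, n_batch, n_epochs = 60000, 128, 100

batches_per_epoch = n_images // n_batch  # the incomplete last batch is dropped
half_batch = n_batch // 2

# Targets used for each kind of update in one training step.
y_real = np.ones((half_batch, 1))   # discriminator update on real images
y_fake = np.zeros((half_batch, 1))  # discriminator update on generated images
y_gan = np.ones((n_batch, 1))       # generator update: fakes labelled "real"

# Epochs (1-based) after which the generator is saved: every 10th epoch.
save_epochs = [epoch for epoch in range(1, n_epochs + 1) if epoch % 10 == 0]

print(batches_per_epoch, half_batch)  # 468 64
print(batches_per_epoch * n_epochs)   # 46800 generator updates in total
print(save_epochs)
```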
The dimension of the latent space is set to 100, and the composite GAN model is trained for 100 epochs.
# size of the latent space
latent_dim = 100
# discriminator model
discriminator = define_discriminator()
# generator model
generator = define_generator(latent_dim)
# create the GAN model
gan = define_gan(generator, discriminator)
# load the dataset
dataset = load_real_samples()
# train the model
train_gan(generator, discriminator, gan, latent_dim, dataset)
The following save_plot() function plots generated images in an n x n grid.
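import matplotlib.pyplot as plt

def save_plot(X, n):
    for i in range(n * n):
        # create the subplot first, then turn off its axis
        plt.subplot(n, n, i + 1)
        plt.axis('off')
        # plot the single channel of the generated image
        plt.imshow(X[i, :, :, 0], cmap='gray_r')
    plt.show()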
To generate images, the generator needs points in the latent space and the class labels of the digits to produce. Here we generate 100 images, ten of each digit from 0 to 9.
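import numpy as np

# size of the latent space
latent_dim = 100
# 100 latent points (the random labels returned here are replaced below)
[inputs, labels] = generate_latent_points(latent_dim, 100)
# labels 0-9, repeated 10 times
labels = np.asarray([x for _ in range(10) for x in range(10)])
# generate images with the generator trained above
X = generator.predict([inputs, labels])
save_plot(X, 10)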
Plotting the images generated by our GAN model.
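The label array determines where each digit appears in the grid. save_plot() fills subplots row by row (index i + 1), so image i lands in row i // 10 and column i % 10. With labels 0 to 9 repeated ten times, every row reads 0 to 9 and every column shows a single digit. The NumPy sketch below checks this and shows an equivalent way of building the labels with np.tile.

```python
import numpy as np

# Labels in the order used above: 0..9, repeated 10 times.
labels = np.asarray([x for _ in range(10) for x in range(10)])

# The same array built with np.tile.
assert np.array_equal(labels, np.tile(np.arange(10), 10))

# Arrange the labels the way save_plot lays out the images (row-major).
grid = labels.reshape(10, 10)
print(grid[0])     # [0 1 2 3 4 5 6 7 8 9]
print(grid[:, 3])  # [3 3 3 3 3 3 3 3 3 3]
```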