Deep Convolutional GAN for MNIST Handwritten digits

Saif Gazali
8 min read · Aug 1, 2021
Photo by SquareUp

Generative Adversarial Networks (GANs) can use deep convolutional neural networks to generate images. A GAN needs two models. The discriminator classifies whether a given image is real or fake. The generator transforms an input into an image made of pixel values. In this article we use the MNIST handwritten digits dataset. Because it is a standard image dataset, we can focus on the GAN architecture rather than on data preprocessing.

The Modified National Institute of Standards and Technology (MNIST) dataset contains 70,000 grayscale 28x28 pixel images of handwritten digits from 0 to 9. Keras provides the dataset through the mnist.load_data() function, which returns the training and test sets.

```python
from tensorflow.keras.datasets.mnist import load_data

# load_data() downloads MNIST on first use and returns (train, test) tuples
(trainX, trainY), (testX, testY) = load_data()
```

Printing the shapes of the training and test sets shows 60,000 images in the training set and 10,000 images in the test set.

```python
# summarize the shape of the dataset
print('Train', trainX.shape, trainY.shape)  # Train (60000, 28, 28) (60000,)
print('Test', testX.shape, testY.shape)     # Test (10000, 28, 28) (10000,)
```

We can plot some of the images from the training set using matplotlib's imshow() function. Here we plot the first 25 images in a 5x5 grid.

```python
import matplotlib.pyplot as pyplot

for i in range(25):
    # define subplot
    pyplot.subplot(5, 5, 1 + i)
    # turn off axis
    pyplot.axis('off')
    # plot raw pixel data (gray_r: white background, dark digits)
    pyplot.imshow(trainX[i], cmap='gray_r')
pyplot.show()
```

The images from the dataset are used to train a Generative Adversarial Network. Its generator model learns to generate new, plausible images of handwritten digits from 0 to 9.

Discriminator Model

The discriminator model takes an image and must classify it as real (from the dataset) or fake (generated). Our discriminator has two convolutional layers, each with 128 filters, a 3x3 kernel and a stride of 2. Dropout layers are added for regularization to reduce overfitting. We use LeakyReLU instead of ReLU because it is the recommended activation for GAN discriminators. The model is trained with the Adam optimizer, using a learning rate of 0.0002 and a momentum (beta_1) of 0.5.

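```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, LeakyReLU, Dropout, Flatten, Dense
from tensorflow.keras.optimizers import Adam

# define the standalone discriminator model
def define_discriminator(in_shape=(28, 28, 1)):
    model = Sequential()
    # downsample 28x28 -> 14x14
    model.add(Conv2D(128, 3, strides=(2, 2), padding='same', input_shape=in_shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    # downsample 14x14 -> 7x7
    model.add(Conv2D(128, 3, strides=(2, 2), padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    # single sigmoid output: probability that the input image is real
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
    # learning rate 0.0002 and beta_1 0.5, as described in the text
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model
```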
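The discriminator above uses LeakyReLU(alpha=0.2) in place of ReLU. The plain-Python sketch below shows the element-wise function. It is an illustration only, not Keras's implementation. The two activations differ only in how they treat negative inputs.

```python
def relu(x: float) -> float:
    """Standard ReLU for comparison: negative inputs become 0."""
    return max(0.0, x)

def leaky_relu(x: float, alpha: float = 0.2) -> float:
    """Pass positive inputs through unchanged; scale negative inputs by alpha."""
    return x if x > 0 else alpha * x

for v in (-2.0, -0.5, 0.0, 1.5):
    print(v, relu(v), leaky_relu(v))
```

Negative inputs keep a small slope of 0.2. As a result, gradients still flow through units that ReLU would silence completely.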

We define the discriminator model and print its summary.

```python
model = define_discriminator()
model.summary()
```
Model Summary
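You can check the output shapes and parameter counts in the summary by hand. The sketch below assumes two Keras rules. With padding='same', a strided Conv2D produces ceil(size / stride) pixels per side. A convolution has kernel x kernel x in_channels x filters weights plus one bias per filter. LeakyReLU, Dropout and Flatten add no parameters. These are hand calculations, so compare them against your own summary output.

```python
import math

def conv_same_output(size: int, stride: int) -> int:
    """Spatial size after a Conv2D with padding='same': ceil(size / stride)."""
    return math.ceil(size / stride)

def conv_params(kernel: int, in_channels: int, filters: int) -> int:
    """Weights (kernel x kernel x in_channels x filters) plus one bias per filter."""
    return kernel * kernel * in_channels * filters + filters

# Layer-by-layer bookkeeping for the discriminator defined above.
size = 28
conv1 = conv_params(3, 1, 128)        # first Conv2D on a 1-channel image
size = conv_same_output(size, 2)      # 28 -> 14
conv2 = conv_params(3, 128, 128)      # second Conv2D
size = conv_same_output(size, 2)      # 14 -> 7
flat = size * size * 128              # length of the Flatten output
dense = flat * 1 + 1                  # Dense(1): weights plus bias
total = conv1 + conv2 + dense
print(size, flat, total)
```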

Discriminator Model Training

We need to train the discriminator on both real and fake images. First, we create a function that loads the real samples, using mnist.load_data() to load the MNIST dataset. Each loaded image is a 2D array of pixels. Convolutional layers, however, expect 3D inputs (height, width, channels), so we use NumPy's expand_dims() to add a dimension for the single grayscale channel. Finally, we convert the pixels to floats and scale them from [0,255] to [0,1]. The function load_real_samples performs all of these steps.

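```python
from numpy import expand_dims
from tensorflow.keras.datasets.mnist import load_data

# load and prepare the MNIST training images
def load_real_samples():
    (trainX, _), (_, _) = load_data()
    # add a channel dimension: (60000, 28, 28) -> (60000, 28, 28, 1)
    X = expand_dims(trainX, axis=-1)
    # convert from unsigned ints to floats and scale from [0, 255] to [0, 1]
    X = X.astype('float32')
    X = X / 255.0
    return X
```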
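The next sketch shows what the two preprocessing steps do. It applies them with nested lists to a hypothetical 2x2 image. NumPy performs the same operations on whole arrays at once, so this is only a way to see the effect on shape and values.

```python
def add_channel_axis(image):
    """Wrap each pixel in its own list: shape (H, W) -> (H, W, 1), like expand_dims(axis=-1)."""
    return [[[pixel] for pixel in row] for row in image]

def scale_to_unit(image):
    """Divide every value by 255 so pixels move from [0, 255] to [0.0, 1.0]."""
    return [[[value / 255.0 for value in pixel] for pixel in row] for row in image]

toy = [[0, 128], [255, 64]]   # hypothetical 2x2 grayscale image
prepared = scale_to_unit(add_channel_axis(toy))
print(prepared)
```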

The model is updated in batches, and for each batch we select random images from the training set. Each selected image is returned with a class label of 1, meaning it is real.

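```python
from numpy import ones
from numpy.random import randint

# select real samples
def generate_real_samples(dataset, n_samples):
    # choose random indices into the dataset
    ix = randint(0, dataset.shape[0], n_samples)
    # retrieve selected images
    X = dataset[ix]
    # generate class labels: 1 = real
    y = ones((n_samples, 1))
    return X, y
```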
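numpy.random.randint draws each index independently, so it samples with replacement. A batch can therefore occasionally contain the same image twice, which is harmless for training. The stdlib sketch below contrasts this with sampling without replacement. The function names are hypothetical, and the article's code does not use the second variant.

```python
import random

def sample_indices_with_replacement(n_images, n_samples, rng):
    """Like randint(0, n_images, n_samples): indices may repeat."""
    return [rng.randrange(n_images) for _ in range(n_samples)]

def sample_indices_without_replacement(n_images, n_samples, rng):
    """Alternative (not used in this article): every index in the batch is distinct."""
    return rng.sample(range(n_images), n_samples)

rng = random.Random(0)
with_repl = sample_indices_with_replacement(60000, 128, rng)
without_repl = sample_indices_without_replacement(60000, 128, rng)
labels = [1] * len(with_repl)   # real samples are labelled 1
print(len(set(with_repl)), len(set(without_repl)))
```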

To generate fake images at this stage, we can simply use random pixel values drawn uniformly from the range [0,1], each with a class label of 0.

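```python
from numpy import zeros
from numpy.random import rand

# generate n fake samples: uniform random noise in [0, 1) with class label 0
def generate_fake_samples(n_samples):
    X = rand(28 * 28 * n_samples)
    X = X.reshape((n_samples, 28, 28, 1))
    y = zeros((n_samples, 1))
    return X, y
```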
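The reshape in generate_fake_samples() takes a flat run of 28*28*n_samples values and groups them into n_samples images, filling them row by row (NumPy's default C order). The plain-Python sketch below reproduces that grouping so the index arithmetic is visible. fake_uniform_images is a hypothetical helper written for illustration.

```python
import random

def fake_uniform_images(n_samples, rng, size=28):
    """Draw size*size*n_samples uniform values in [0, 1) and group them into
    n_samples images of shape (size, size, 1), row by row like a C-order reshape."""
    flat = [rng.random() for _ in range(size * size * n_samples)]
    images = []
    for s in range(n_samples):
        start = s * size * size
        image = [[[flat[start + r * size + c]] for c in range(size)] for r in range(size)]
        images.append(image)
    labels = [0] * n_samples   # fake samples are labelled 0
    return images, labels

images, labels = fake_uniform_images(4, random.Random(0))
print(len(images), len(images[0]), len(images[0][0]), len(images[0][0][0]), labels)
```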

Finally, we train the discriminator on 256 images per iteration: 128 real and 128 fake. We train the model for 100 iterations using the train_on_batch() method.

```python
# train the discriminator model on real images and random noise
def train_discriminator(model, dataset, n_iter=100, n_batch=256):
    half_batch = int(n_batch / 2)
    for i in range(n_iter):
        # generate real samples
        X_real, y_real = generate_real_samples(dataset, half_batch)
        # update the discriminator on the real samples
        lossreal, accreal = model.train_on_batch(X_real, y_real)
        # generate fake samples
        X_fake, y_fake = generate_fake_samples(half_batch)
        # update the discriminator on the fake samples
        lossfake, accfake = model.train_on_batch(X_fake, y_fake)
        # report iteration number and accuracy on real and fake samples
        print(i + 1, accreal, accfake)
```

We now train the discriminator and print its accuracy on real and fake images at each iteration.

```python
model = define_discriminator()
dataset = load_real_samples()
train_discriminator(model, dataset)
```
Discriminator model training
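Each train_on_batch() call above minimizes the binary cross-entropy loss the discriminator was compiled with. For one sample this is -[y*log(p) + (1-y)*log(1-p)]. The plain-Python sketch below averages it over a few hand-picked predictions. It clips p away from 0 and 1 to avoid log(0). It illustrates the formula only and is not Keras's implementation.

```python
import math

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean of -[y*log(p) + (1-y)*log(1-p)], with p clipped to [eps, 1-eps]."""
    total = 0.0
    for y, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)

# A confident, correct discriminator has low loss; a confident, wrong one has high loss.
good = binary_crossentropy([1, 1, 0, 0], [0.9, 0.8, 0.1, 0.2])
bad = binary_crossentropy([1, 1, 0, 0], [0.1, 0.2, 0.9, 0.8])
print(round(good, 4), round(bad, 4))
```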

Generator Model

The generator model takes a point from the latent space as input and outputs a plausible image of a handwritten digit. The latent space is a vector space of values drawn from a standard Gaussian distribution. During training, the generator learns to assign meaning to points in this space.

The first hidden layer is a Dense layer with enough nodes to represent many versions of a low-resolution image. We use 128 * 7 * 7 nodes, which gives 128 different 7x7 versions of the output image. The output is then reshaped into 128 feature maps of size 7x7. Next, a Conv2DTranspose layer with a stride of (2,2) upsamples the low-resolution image, doubling its width and height. This is done twice, going from 7x7 to 14x14 and then to 28x28. The output layer is a Conv2D layer with a single filter, a 7x7 kernel and a sigmoid activation, so the output pixel values stay in the desired range [0,1]. The generator model is not compiled because it is not trained directly.

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LeakyReLU, Reshape, Conv2DTranspose, Conv2D

# define the standalone generator model
def define_generator(latent_dim):
    model = Sequential()
    # foundation for a 7x7 image: 128 feature maps of size 7x7
    model.add(Dense(128 * 7 * 7, input_dim=latent_dim))
    model.add(LeakyReLU(0.2))
    model.add(Reshape((7, 7, 128)))
    # upsample to 14x14
    model.add(Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'))
    model.add(LeakyReLU(0.2))
    # upsample to 28x28
    model.add(Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'))
    model.add(LeakyReLU(0.2))
    # one output channel with pixel values in [0, 1]
    model.add(Conv2D(1, (7, 7), activation='sigmoid', padding='same'))
    return model
```

We define the generator model and print its summary.

```python
# latent dimension
latent_dim = 100
# define the generator model
model = define_generator(latent_dim)
model.summary()
```
Generator model summary
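The generator's summary can be checked the same way. The sketch below assumes that Keras's Conv2DTranspose with padding='same' multiplies the spatial size by the stride. It also assumes that transposed and ordinary convolutions both have kernel x kernel x in_channels x filters weights plus one bias per filter. The figures are hand calculations for latent_dim = 100, so compare them with your own summary.

```python
def conv_transpose_same_output(size: int, stride: int) -> int:
    """Spatial size after Conv2DTranspose with padding='same': size * stride."""
    return size * stride

def kernel_params(kernel: int, in_channels: int, filters: int) -> int:
    """Kernel weights plus one bias per filter."""
    return kernel * kernel * in_channels * filters + filters

latent_dim = 100
dense = latent_dim * (128 * 7 * 7) + 128 * 7 * 7   # Dense weights plus biases
size = 7                                            # after Reshape((7, 7, 128))
up1 = kernel_params(4, 128, 128)
size = conv_transpose_same_output(size, 2)          # 7 -> 14
up2 = kernel_params(4, 128, 128)
size = conv_transpose_same_output(size, 2)          # 14 -> 28
out = kernel_params(7, 128, 1)                      # final Conv2D; 'same' keeps 28x28
total = dense + up1 + up2 + out
print(size, total)
```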

Next we need points in the latent space to feed to the generator. NumPy's randn() function creates them by generating arrays of random numbers from a standard Gaussian distribution. We then reshape the array into n rows of latent_dim elements each (latent_dim is 100 here).

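```python
from numpy.random import randn

# generate points in the latent space as input for the generator
def generate_latent_space(latent_dim, n_samples):
    # draw latent_dim * n_samples values from a standard Gaussian
    X = randn(latent_dim * n_samples)
    # reshape into a batch of inputs for the network
    X = X.reshape(n_samples, latent_dim)
    return X
```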
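To check that the latent points really follow a standard Gaussian, the stdlib sketch below builds the same kind of batch with random.gauss instead of NumPy's randn and measures the sample mean and variance. latent_points is a hypothetical helper written for illustration.

```python
import random

def latent_points(latent_dim, n_samples, rng):
    """n_samples rows of latent_dim values from N(0, 1), like
    randn(latent_dim * n_samples).reshape(n_samples, latent_dim)."""
    return [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in range(n_samples)]

rng = random.Random(42)
points = latent_points(100, 1000, rng)
values = [v for row in points for v in row]
mean = sum(values) / len(values)
var = sum((v - mean) ** 2 for v in values) / len(values)
print(len(points), len(points[0]), round(mean, 3), round(var, 3))
```

With 100,000 draws, the mean comes out close to 0 and the variance close to 1.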

The points from the latent space can be used as inputs to the generator model to generate new samples.

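```python
from numpy import zeros

# use the generator to generate n fake samples with class label 0
# (this replaces the random-noise generate_fake_samples() defined earlier)
def generate_fake_samples(model, latent_dim, n_samples):
    X = generate_latent_space(latent_dim, n_samples)
    X_output = model.predict(X)
    y = zeros((n_samples, 1))
    return X_output, y
```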

We can plot images produced by the generator. The generator has not been trained yet, so the images will not resemble digits.

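```python
import matplotlib.pyplot as plt

# latent dimension
latent_dim = 100
# define an untrained generator
model = define_generator(latent_dim)
n_samples = 10
X, y = generate_fake_samples(model, latent_dim, n_samples)
# plot each generated image in its own panel of a 2x5 grid
for i in range(n_samples):
    plt.subplot(2, 5, i + 1)
    plt.axis('off')
    plt.imshow(X[i, :, :, 0], cmap='gray_r')
plt.show()
```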
Random generated image

Generator Model Training

The generator is updated according to how well the discriminator detects its fake samples. When the discriminator is good at detecting fake samples, the generator's weights are updated more. When it is poor at detecting them, the generator's weights are updated less. To achieve this, we define a new composite GAN model. It takes a random latent point as input, passes it through the generator to produce a sample, and feeds that sample to the discriminator. The discriminator's output is then used to update the generator's weights.

The discriminator is trained on real and fake samples as a standalone model. The generator, in contrast, is updated only through the discriminator's verdict on its fake samples. For this reason, the discriminator inside the composite GAN model is marked as not trainable, so training the composite model changes only the generator's weights. When training the composite model, we also label the generated samples as real (class 1). If the discriminator classifies them as fake, the loss is large. Backpropagation then updates the generator's weights to reduce that loss, which means producing images the discriminator is more likely to classify as real.

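```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

# define the composite generator + discriminator model, used to update the generator
def define_gan(generator, discriminator):
    # freeze the discriminator's weights inside the composite model
    discriminator.trainable = False
    model = Sequential()
    model.add(generator)
    model.add(discriminator)
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model
```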
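In the composite model, fake images are labelled 1. The binary cross-entropy for one generated image therefore reduces to -log(p), where p is the discriminator's probability that the image is real. The sketch below works through this case in plain Python, assuming the discriminator's sigmoid output p = sigmoid(logit). The gradient of -log(sigmoid(logit)) with respect to the logit is p - 1. Its magnitude is close to 1 when the discriminator confidently rejects the fake and close to 0 when the fake already fools it. This is the mechanism behind the earlier remark that the generator is updated more when the discriminator is good at spotting fakes.

```python
import math

def generator_loss(p_real: float) -> float:
    """Loss for one generated image labelled 'real' (1): -log(D(G(z)))."""
    return -math.log(p_real)

def grad_wrt_logit(p_real: float) -> float:
    """d/d(logit) of -log(sigmoid(logit)) equals sigmoid(logit) - 1 = p_real - 1."""
    return p_real - 1.0

# p: the discriminator's probability that a generated image is real
for p in (0.01, 0.5, 0.99):
    print(p, round(generator_loss(p), 3), round(grad_wrt_logit(p), 3))
```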

We define the GAN model and print its summary.

```python
# size of the latent space
latent_dim = 100
# create the discriminator
d_model = define_discriminator()
# create the generator
g_model = define_generator(latent_dim)
# create the gan
gan_model = define_gan(g_model, d_model)
# summarize gan model
gan_model.summary()
```
GAN model summary

Each training step first updates the standalone discriminator on a half batch of real samples and a half batch of fake samples. It then updates the generator through the composite model.

```python
from numpy import ones

# train the standalone discriminator and, via the composite model, the generator
def train_gan(gen_model, disc_model, gan_model, dataset, latent_dim, n_epochs=50, n_batch=256):
    half_batch = int(n_batch / 2)
    batch_per_epoch = int(dataset.shape[0] / n_batch)
    # manually enumerate epochs
    for i in range(n_epochs):
        # enumerate batches over the training set
        for j in range(batch_per_epoch):
            # randomly selected real samples
            X_real, y_real = generate_real_samples(dataset, half_batch)
            # fake samples from the current generator
            X_fake, y_fake = generate_fake_samples(gen_model, latent_dim, half_batch)
            # update the standalone discriminator
            lossreal, accreal = disc_model.train_on_batch(X_real, y_real)
            lossfake, accfake = disc_model.train_on_batch(X_fake, y_fake)
            # latent points as generator input, labelled as real (1)
            X_gan = generate_latent_space(latent_dim, n_batch)
            y_gan = ones((n_batch, 1))
            # update the generator via the discriminator's error
            loss, acc = gan_model.train_on_batch(X_gan, y_gan)
        # evaluate performance every 10 epochs
        if (i + 1) % 10 == 0:
            summarize_perf(i, gen_model, disc_model, dataset, latent_dim, 100)
```
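The sketch below works out the amount of training implied by train_gan() with its defaults on the 60,000-image training set. Each "epoch" here is 234 randomly sampled batches rather than one exact pass over the data. The sketch also lists the checkpoint filenames produced by the 'generated_model_%03d.h5' pattern in summarize_perf(), which runs every 10 epochs.

```python
n_images, n_batch, n_epochs = 60000, 256, 50
half_batch = n_batch // 2
batch_per_epoch = n_images // n_batch      # same as int(60000 / 256)
steps = batch_per_epoch * n_epochs         # generator (composite model) updates
disc_updates = steps * 2                   # one real + one fake train_on_batch per step
checkpoints = ['generated_model_%03d.h5' % (i + 1)
               for i in range(n_epochs) if (i + 1) % 10 == 0]
print(batch_per_epoch, steps, disc_updates, checkpoints)
```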

GAN Performance Evaluation

The quality of GAN-generated images is usually judged by a human operator. Our evaluation approach has three steps:

  1. Evaluate the accuracy of discriminator model on real and fake images.
  2. Generate images using generator model periodically.
  3. Save the generator model.

We define a function summarize_perf that evaluates the discriminator's accuracy on real and fake images. It then saves a plot of generated images and the generator model.

```python
# evaluate the discriminator, then plot generated images and save the generator
def summarize_perf(epoch, gen_model, disc_model, dataset, latent_dim, n_samples=100):
    # accuracy on real samples
    X_real, y_real = generate_real_samples(dataset, n_samples)
    lossreal, accreal = disc_model.evaluate(X_real, y_real, verbose=0)
    # accuracy on fake samples (evaluated on the generated images, not the real ones)
    X_fake, y_fake = generate_fake_samples(gen_model, latent_dim, n_samples)
    lossfake, accfake = disc_model.evaluate(X_fake, y_fake, verbose=0)
    print(accreal, accfake)
    # save plot
    save_plot(X_fake, epoch)
    # save the generator model
    filename = 'generated_model_%03d.h5' % (epoch + 1)
    gen_model.save(filename)
```
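The accuracies printed by summarize_perf() count how many thresholded predictions match the labels. The plain-Python sketch below illustrates the idea for a single sigmoid output, assuming a threshold of 0.5. The predictions are made-up numbers.

```python
def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of samples whose thresholded prediction matches the 0/1 label."""
    hits = sum(1 for y, p in zip(y_true, y_pred) if (p > threshold) == bool(y))
    return hits / len(y_true)

acc_real = binary_accuracy([1, 1, 1, 1], [0.9, 0.7, 0.4, 0.6])   # real images, label 1
acc_fake = binary_accuracy([0, 0, 0, 0], [0.2, 0.6, 0.1, 0.3])   # fake images, label 0
print(acc_real, acc_fake)
```

A fake-image accuracy near 100% suggests the discriminator still spots the generator's images easily. A lower value suggests the generated images are becoming harder to tell apart from real ones.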

We create a function to plot the images that are generated by our generator model. We plot 100 images as a 10x10 grid.

```python
import matplotlib.pyplot as plt

# create and save a plot of generated images as an n x n grid (100 images by default)
def save_plot(examples, epoch, n=10):
    for i in range(n * n):
        plt.subplot(n, n, i + 1)
        plt.axis('off')
        plt.imshow(examples[i, :, :, 0], cmap='gray_r')
    # save plot to file
    filename = 'generated_plot_e%03d.png' % (epoch + 1)
    plt.savefig(filename)
    plt.close()
```
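plt.subplot(n, n, i + 1) numbers its panels from 1, left to right and top to bottom. The sketch below shows where image i lands in the 10x10 grid and which files save_plot() writes for the 0-based epoch indices 9 and 19. grid_position is a hypothetical helper written for illustration.

```python
def grid_position(i: int, n: int):
    """Map 0-based image index i to its (row, column) in an n x n subplot grid."""
    return divmod(i, n)

positions = [grid_position(i, 10) for i in range(100)]
filenames = ['generated_plot_e%03d.png' % (e + 1) for e in (9, 19)]
print(positions[0], positions[9], positions[10], positions[99], filenames)
```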

We train the GAN for 50 epochs (the default n_epochs in train_gan), using a 50-dimensional latent space.

```python
# size of the latent space
latent_dim = 50
# define discriminator
discriminator = define_discriminator()
# define generator
generator = define_generator(latent_dim)
# define gan
gan = define_gan(generator, discriminator)
# load image data
dataset = load_real_samples()
# train gan
train_gan(generator, discriminator, gan, dataset, latent_dim)
```
GAN model training

Below are some of the images generated by our generator model.

Images generated after 20 epochs
Images generated after 40 epochs

Our generator model appears to generate plausible images of handwritten digits. We can load a saved generator model to generate more images.

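```python
from tensorflow.keras.models import load_model

# load the generator saved after epoch 10 by summarize_perf()
model = load_model('generated_model_010.h5')
# 25 latent points; latent_dim must match training (50)
latent_points = generate_latent_space(50, 25)
X = model.predict(latent_points)
# plot the 25 images as a 5x5 grid
save_plot(X, 5, n=5)
```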

A single image can be generated as well using the model.

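```python
from numpy import asarray
import matplotlib.pyplot as plt

# a single latent point of all zeros (the mean of the latent distribution)
vector = asarray([[0.0 for _ in range(50)]])
X = model.predict(vector)
plt.imshow(X[0, :, :, 0], cmap='gray_r')
plt.show()
```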
