Semi-Supervised GAN for MNIST Handwritten Digits

Saif Gazali
Aug 5, 2021
Photo by TowardsDataScience

A Semi-Supervised GAN involves training a supervised discriminator, an unsupervised discriminator and a generator model simultaneously. The result is a supervised classifier that predicts the class label of an image and a generator model that generates images from the domain. More generally, semi-supervised learning is the problem of training a classifier on a dataset that has a small number of labeled examples compared to the number of unlabeled examples. The model must therefore learn from the small amount of labeled data while also making use of the unlabeled examples, so that it generalizes well enough to classify new images.

The discriminator in a traditional GAN predicts whether a given image is real or fake. It can be reused via transfer learning when developing a classifier, allowing the supervised prediction task to take advantage of the unsupervised training of the GAN. In a Semi-Supervised GAN, the discriminator is updated to make predictions for N+1 classes, where N is the number of classes in the prediction problem and the additional label is a new "fake" class. The discriminator is therefore trained for both the unsupervised GAN task and the supervised task simultaneously. Unsupervised training lets the model learn useful features from the large unlabeled dataset, and supervised training then uses those extracted features to assign class labels.
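To make the N+1 class idea concrete, here is a small NumPy sketch with made-up logits for N = 3 classes (the values and the `softmax` helper are our own, for illustration only). It appends an extra "fake" logit fixed at 0, which is the simplification used by Salimans et al., so the probability of "real" is the total probability of the N real classes, while the class prediction comes from the first N columns.

```python
import numpy as np

def softmax(v):
    # numerically stable softmax over the last axis
    e = np.exp(v - np.max(v, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)

# hypothetical logits for N = 3 real classes, for two images
class_logits = np.array([[2.0, -1.0, 0.5],
                         [-3.0, -2.0, -4.0]])
# append the (N+1)-th "fake" logit, fixed at 0
fake_logit = np.zeros((class_logits.shape[0], 1))
probs = softmax(np.concatenate([class_logits, fake_logit], axis=-1))

p_fake = probs[:, -1]      # probability of the extra fake class
p_real = 1.0 - p_fake      # probability of any of the N real classes
# the supervised prediction only looks at the N real classes
pred_class = np.argmax(probs[:, :-1], axis=-1)
print(p_real, pred_class)
```

With the fake logit fixed at 0, `p_real` works out to Z/(Z+1), where Z is the sum of the exponentiated class logits. This is the same quantity the stacked discriminator in the next section computes.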

Discriminator model

There are many ways to develop a discriminator model for a Semi-Supervised GAN. One is a single discriminator with two output layers, one for the supervised task and one for the unsupervised task, both sharing the same feature extraction layers. Another is to use separate discriminator models with shared weights: one model predicts whether an image is real or fake, and a second model predicts the class of a given image.

In this article, we create the discriminator using a stacked approach, in which the output layer of one model is reused as the input to another model. This technique is based on the definition of semi-supervised learning in the 2016 paper by Tim Salimans, et al. The supervised model is created with K output classes and a softmax activation function, whereas the unsupervised model takes the output of the supervised model before the softmax and calculates a normalized sum of exponential outputs using the following equation:

$$D(x) = \frac{Z(x)}{Z(x) + 1}, \qquad Z(x) = \sum_{k=1}^{K} \exp\big(l_k(x)\big)$$

where $l_k(x)$ is the k-th output (logit) of the supervised model for image $x$. The above equation can be implemented with Keras backend functions.

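from tensorflow.keras import backend

# custom activation: D(x) = Z(x) / (Z(x) + 1), where Z(x) is the sum of exp(logits)
def custom_activation(output):
    # Z(x): sum of the exponentiated class logits for each image
    exp_sum = backend.sum(backend.exp(output), axis=-1, keepdims=True)
    # close to 1 when the logits are large (real), close to 0 when they are small (fake)
    result = exp_sum / (exp_sum + 1.0)
    return result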

The activations here are the outputs of the supervised model's final Dense layer before the softmax is applied (the logits); they are real values that can be positive or negative. For real examples the model should output a large logit for at least one class, which pushes D(x) towards 1, and for fake samples it should output small logits, which pushes D(x) towards 0. The supervised model uses a softmax activation and a sparse categorical cross-entropy loss (the labels are integers), whereas the unsupervised model is stacked on top of the supervised model's output layer prior to the activation, so the node activations pass through the custom_activation function instead.
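The following NumPy sketch (our own illustration, not part of the original code; both function names are hypothetical) evaluates D(x) on made-up logits for one "real-looking" and one "fake-looking" image. It also checks that Z/(Z+1) equals the sigmoid of the log-sum-exp of the logits, which is one way to compute the same value without overflowing `exp` when logits are large.

```python
import numpy as np

def custom_activation_np(logits):
    """NumPy version of D(x) = Z(x) / (Z(x) + 1), with Z(x) = sum_k exp(l_k(x))."""
    z = np.sum(np.exp(logits), axis=-1, keepdims=True)
    return z / (z + 1.0)

def custom_activation_stable(logits):
    """Same value computed as sigmoid(logsumexp(logits)), avoiding overflow in exp."""
    m = np.max(logits, axis=-1, keepdims=True)
    lse = m + np.log(np.sum(np.exp(logits - m), axis=-1, keepdims=True))
    return 1.0 / (1.0 + np.exp(-lse))

# made-up logits: first row confident (real-looking), second row strongly negative (fake-looking)
logits = np.array([[8.0, -1.0, 0.5],
                   [-6.0, -7.0, -5.5]])
print(custom_activation_np(logits).ravel())
print(custom_activation_stable(logits).ravel())
```

The identity holds because sigmoid(log Z) = 1 / (1 + 1/Z) = Z / (Z + 1).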

from tensorflow.keras.layers import Input, Conv2D, LeakyReLU, Flatten, Dense, Activation, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# define stacked discriminator for semi-supervised learning
def define_discriminator(in_shape=(28, 28, 1), n_classes=10):
    in_image = Input(shape=in_shape)
    # downsample to 14x14
    fc = Conv2D(128, (3, 3), strides=(2, 2), padding='same')(in_image)
    fc = LeakyReLU(0.2)(fc)
    # downsample to 7x7
    fc = Conv2D(128, (3, 3), strides=(2, 2), padding='same')(fc)
    fc = LeakyReLU(0.2)(fc)
    # flatten feature maps
    fc = Flatten()(fc)
    # output layer nodes: one logit per class, shared by both heads
    fc = Dense(n_classes)(fc)
    # supervised output: softmax over the digit classes
    c_out_layer = Activation('softmax')(fc)
    # define and compile supervised discriminator model
    c_model = Model(in_image, c_out_layer)
    c_model.compile(loss='sparse_categorical_crossentropy',
                    optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
                    metrics=['accuracy'])
    c_model.summary()
    # unsupervised output: probability that the image is real
    d_out_layer = Lambda(custom_activation)(fc)
    # define and compile unsupervised discriminator model
    d_model = Model(in_image, d_out_layer)
    d_model.compile(loss='binary_crossentropy',
                    optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
    d_model.summary()
    return d_model, c_model

Plotting the unsupervised and supervised discriminator models.

Unsupervised discriminator model
Supervised discriminator model

Generator model

The generator model takes a point from the latent space as input and outputs a plausible image of a handwritten digit. The latent space is a vector space of standard Gaussian distributed values; during training, the generator learns to assign meaning to points in it. The first hidden layer is a Dense layer with enough nodes to represent multiple low-resolution versions of the output image: 128 * 7 * 7 in our case, i.e. 128 versions of a 7x7 image. Its output is reshaped into 128 different 7x7 feature maps. The low-resolution image is then upsampled to a higher resolution with a Conv2DTranspose layer whose stride is set to (2,2); doing this twice takes the feature maps from 7x7 to 14x14 and then to 28x28. The output layer is a Conv2D layer with a single filter, a 7x7 kernel and a tanh activation, so the output values lie in [-1,1], the same range the real images are scaled to. The generator model is not compiled, because it is not trained directly.
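As a sanity check on the layer sizes described above, this plain-Python sketch (no TensorFlow needed; the helper names are hypothetical) tracks the feature-map shapes through the generator and counts trainable parameters with the standard formulas for Dense and convolution layers. It assumes, as in Keras, that a transposed convolution with padding='same' multiplies the spatial size by the stride. LeakyReLU and Reshape add no parameters.

```python
def conv_transpose_same_out(size, stride):
    # with padding='same', a transposed convolution multiplies the spatial size by the stride
    return size * stride

def dense_params(n_in, n_out):
    return n_in * n_out + n_out  # weights + biases

def conv_params(kernel, c_in, c_out):
    return kernel * kernel * c_in * c_out + c_out  # weights + biases

latent_dim = 100
size = 7
shapes = [(size, size, 128)]          # after Dense + Reshape
for _ in range(2):                    # two Conv2DTranspose layers with stride 2
    size = conv_transpose_same_out(size, 2)
    shapes.append((size, size, 128))
shapes.append((size, size, 1))        # final Conv2D with padding='same' keeps 28x28

params = {
    "dense": dense_params(latent_dim, 128 * 7 * 7),
    "conv_transpose_1": conv_params(4, 128, 128),
    "conv_transpose_2": conv_params(4, 128, 128),
    "output_conv": conv_params(7, 128, 1),
}
print(shapes)
print(params, sum(params.values()))
```

These counts can be compared against the output of model.summary() when the generator is built.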

from tensorflow.keras.layers import Input, Dense, LeakyReLU, Reshape, Conv2DTranspose, Conv2D
from tensorflow.keras.models import Model

# define the standalone generator model
def define_generator(latent_dim):
    in_lat = Input(shape=(latent_dim,))
    # enough activations for 128 feature maps of size 7x7
    nodes = 128 * 7 * 7
    gen = Dense(nodes)(in_lat)
    gen = LeakyReLU(0.2)(gen)
    gen = Reshape((7, 7, 128))(gen)
    # upsample to 14x14
    fc = Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same')(gen)
    fc = LeakyReLU(0.2)(fc)
    # upsample to 28x28
    fc = Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same')(fc)
    fc = LeakyReLU(0.2)(fc)
    # output: one 28x28 channel with tanh, so pixel values lie in [-1,1]
    out_layer = Conv2D(1, (7, 7), activation='tanh', padding='same')(fc)
    # define model (not compiled: it is trained through the composite GAN model)
    model = Model(in_lat, out_layer)
    model.summary()
    return model

Plotting the generator model.

Generator model Plot

When the discriminator is good at detecting fake samples, the generator's weights are updated more; when the discriminator is poor at detecting fake samples, the generator's weights are updated less. To achieve this, we define a new composite GAN model that takes a random latent point as input, uses the generator to produce an image, and passes that image to the unsupervised discriminator. The discriminator's output is then used to update the generator's weights, while the discriminator's own weights are frozen inside the composite model.
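The following NumPy illustration shows this behavior. It assumes, as in define_gan and the training loop, that the composite model is trained with target label 1 ("real") for generated images. The binary cross-entropy for a generated image is then -log D(G(z)). Because D(x) is the sigmoid of the log-sum-exp of the logits, the gradient of that loss with respect to the log-sum-exp is D - 1. Both values are largest when the discriminator confidently rejects the fake. The D(G(z)) values below are made up.

```python
import numpy as np

# illustrative D(G(z)) values: the discriminator's "real" probability for generated images
d_fake = np.array([0.01, 0.1, 0.5, 0.9])

# binary cross-entropy with target y = 1, as used to train the composite model
gen_loss = -np.log(d_fake)
# D = sigmoid(s) with s = logsumexp(logits), so dL/ds = D - 1
grad_wrt_s = d_fake - 1.0

for d, loss, g in zip(d_fake, gen_loss, grad_wrt_s):
    print(f"D(G(z))={d:.2f}  loss={loss:.3f}  dL/ds={g:.2f}")
```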

from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# the generator model is updated via the unsupervised discriminator
def define_gan(g_model, d_model):
    # make weights in the discriminator model not trainable
    d_model.trainable = False
    # connect image output from generator as input to the discriminator
    gan_output = d_model(g_model.output)
    # define gan model as taking noise and outputting a real/fake probability
    model = Model(g_model.input, gan_output)
    # compile model
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model

We create a latent space that provides the points used as input to the generator model. This can be done with NumPy's randn() function, which generates arrays of random numbers drawn from a standard Gaussian distribution. The array is then reshaped into n rows with 100 elements each (latent_dim is set to 100).

from numpy.random import randn

# generate points in the latent space
def generate_latent_space(latent_dim, n_samples):
    # draw latent_dim * n_samples values from a standard Gaussian
    X = randn(latent_dim * n_samples)
    # reshape into a batch of inputs for the network
    X = X.reshape(n_samples, latent_dim)
    return X

The points from the latent space can be used as inputs to the generator model to generate new samples.

from numpy import zeros
# use the generator to generate n fake samples, labeled 0 (fake)
def generate_fake_samples(model, latent_dim, n_samples):
    X = generate_latent_space(latent_dim, n_samples)
    X_output = model.predict(X)
    y = zeros((n_samples, 1))
    return X_output, y

We need to train our discriminator model on both real and fake images. First, we create a function that loads the real samples. We use the mnist.load_data() function to load the MNIST dataset. The loaded images are 2D arrays of pixels, but convolutional layers expect each image to be a 3D array (height, width, channels). Hence we add a channel dimension for grayscale using NumPy's expand_dims() function. Finally, we scale the pixel values from [0,255] to [-1,1] so they match the tanh output of the generator. All of these steps are done in the function load_real_samples.
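Here is a small NumPy check of the reshaping and scaling steps. It uses a tiny synthetic stand-in for the uint8 arrays returned by mnist.load_data() (one all-black and one all-white 28x28 image) rather than the real dataset. The inverse mapping at the end is what you would apply to tanh outputs before displaying or saving them.

```python
import numpy as np

# synthetic stand-in for MNIST uint8 training images: one black and one white 28x28 image
raw = np.array([np.zeros((28, 28)), np.full((28, 28), 255)], dtype=np.uint8)

X = np.expand_dims(raw, axis=-1).astype("float32")  # (2, 28, 28) -> (2, 28, 28, 1)
X = (X - 127.5) / 127.5                             # [0, 255] -> [-1, 1]

# inverse mapping, e.g. for turning generator (tanh) outputs back into pixel values
back = X * 127.5 + 127.5
print(X.shape, X.min(), X.max())
```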

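from numpy import expand_dims
from tensorflow.keras.datasets.mnist import load_data

# load the MNIST training images and labels
def load_real_samples():
    (trainX, trainY), (testX, testY) = load_data()
    # add a channel dimension: (60000, 28, 28) -> (60000, 28, 28, 1)
    X = expand_dims(trainX, axis=-1)
    X = X.astype('float32')
    # scale pixel values from [0,255] to [-1,1]
    X = (X - 127.5) / 127.5
    return [X, trainY]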

A labeled subset of the training dataset is selected with the function select_supervised_samples and is used to train the supervised version of the discriminator model. The function selects examples at random and keeps the classes balanced (10 examples per class by default).

import numpy as np
from numpy.random import randint

# select a supervised subset of the dataset, ensuring classes are balanced
def select_supervised_samples(dataset, n_samples=100, n_classes=10):
    X, y = dataset
    X_list, y_list = list(), list()
    n_samples_per_class = int(n_samples / n_classes)
    for i in range(n_classes):
        # get all the images with the specified class
        X_class = X[y == i]
        # choose random instances (randint samples with replacement)
        ix = randint(0, len(X_class), n_samples_per_class)
        # add them to the lists
        X_list.extend(X_class[j] for j in ix)
        y_list.extend([i] * n_samples_per_class)
    return np.array(X_list), np.array(y_list)
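Because randint draws indices with replacement, the same labeled image can occasionally be picked twice. If you want exactly 100 distinct labeled images, a possible variant is sketched below. It uses NumPy's `Generator.choice` with `replace=False`. The function name and the toy dataset are hypothetical, and this is not the author's code.

```python
import numpy as np

def select_balanced_without_replacement(X, y, n_samples=100, n_classes=10, seed=0):
    """Hypothetical variant of select_supervised_samples that never picks an image twice."""
    rng = np.random.default_rng(seed)
    per_class = n_samples // n_classes
    idx = []
    for c in range(n_classes):
        class_idx = np.flatnonzero(y == c)  # positions of all images of class c
        idx.extend(rng.choice(class_idx, size=per_class, replace=False))
    idx = np.array(idx)
    return X[idx], y[idx]

# toy dataset: 1000 "images" (just their index) with labels cycling through 0..9
X_toy = np.arange(1000).reshape(1000, 1)
y_toy = np.arange(1000) % 10
X_sup, y_sup = select_balanced_without_replacement(X_toy, y_toy)
print(X_sup.shape, np.bincount(y_sup))
```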

GAN Model Training

Model training first selects the labeled subset of the training dataset and calculates the number of training steps. In each step, the supervised model is trained on labeled examples and the unsupervised model is updated using real and fake samples. Finally, the generator model is updated through the composite model.

import numpy as np

# train the generator and discriminator
def train_gan(gen_model, d_model, c_model, gan_model, dataset, latent_dim, n_epochs=20, n_batches=100):
    # n_batches is the batch size; half of each discriminator batch is real, half is fake
    half_batch = int(n_batches / 2)
    # calculate the number of batches per epoch
    batches_per_epoch = int(dataset[0].shape[0] / n_batches)
    # calculate the number of training iterations
    n_steps = int(batches_per_epoch * n_epochs)
    # select the small labeled subset used by the supervised discriminator
    X_sup, y_sup = select_supervised_samples(dataset)
    print(X_sup.shape, y_sup.shape)
    for i in range(n_steps):
        # update the supervised discriminator on labeled examples
        [X_supReal, y_supReal], _ = generate_real_samples([X_sup, y_sup], half_batch)
        c_loss, c_acc = c_model.train_on_batch(X_supReal, y_supReal)
        # update the unsupervised discriminator on real (label 1) and fake (label 0) images
        [X_real, _], y_real = generate_real_samples(dataset, half_batch)
        d_loss_real = d_model.train_on_batch(X_real, y_real)
        X_fake, y_fake = generate_fake_samples(gen_model, latent_dim, half_batch)
        d_loss_fake = d_model.train_on_batch(X_fake, y_fake)
        # update the generator via the composite model, labeling fake images as real
        X_gan = generate_latent_space(latent_dim, n_batches)
        y_gan = np.ones((n_batches, 1))
        gan_loss = gan_model.train_on_batch(X_gan, y_gan)
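The training loop calls generate_real_samples, which is not shown in this article. The sketch below is our assumption of what it does, inferred from how it is called: it picks n_samples random images and their labels, and returns them as `[X, labels]` together with target labels of 1 ("real"). The optional `rng` parameter is our own addition.

```python
import numpy as np

def generate_real_samples(dataset, n_samples, rng=None):
    """Hypothetical helper: pick n_samples random images with their class labels,
    and mark them as real (target 1) for the unsupervised discriminator."""
    rng = np.random.default_rng() if rng is None else rng
    images, labels = dataset
    ix = rng.integers(0, images.shape[0], n_samples)  # random indices, with replacement
    X, labels = images[ix], labels[ix]
    y = np.ones((n_samples, 1))
    return [X, labels], y
```

In the loop it is used both on the small labeled subset (where the class labels train c_model) and on the full dataset (where only the images and the "real" targets are used).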

Training the model for 12,000 steps, we get a classification accuracy of around 93.37%.
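The step count follows from the defaults in train_gan and the 60,000-image MNIST training split. The short calculation below also shows how often, on average, each of the 100 labeled images is seen by the supervised update, since c_model is trained on half a batch per step.

```python
n_train_images = 60000   # size of the MNIST training split
n_batch = 100            # train_gan's n_batches, used as the batch size
n_epochs = 20

half_batch = n_batch // 2
batches_per_epoch = n_train_images // n_batch
n_steps = batches_per_epoch * n_epochs

# c_model sees half_batch labeled images per step, drawn from 100 labeled examples
labeled_pool = 100       # select_supervised_samples default
reuse_per_labeled_image = n_steps * half_batch / labeled_pool
print(batches_per_epoch, n_steps, reuse_per_labeled_image)
```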

# latent space
latent_dim = 100
# discriminator model
d_model,c_model = define_discriminator()
#generator model
gen = define_generator(latent_dim)
#gan model
gan = define_gan(gen,d_model)
#load dataset
dataset = load_real_samples()
# train the generator, discriminators and classifier
train_gan(gen,d_model,c_model,gan,dataset,latent_dim)

Evaluating our supervised model on the test dataset gives around 94% classification accuracy.

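# load and scale both splits the same way as the training images ([-1,1])
(trainX, trainY), (testX, testY) = load_data()
trainX = (expand_dims(trainX, axis=-1).astype('float32') - 127.5) / 127.5
testX = (expand_dims(testX, axis=-1).astype('float32') - 127.5) / 127.5
# evaluating the supervised discriminator (c_model) on the training data
c_model.evaluate(trainX, trainY)
# evaluating on the testing data
c_model.evaluate(testX, testY)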
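The accuracy reported by evaluate for this model is the fraction of images whose highest-probability class matches the integer label. The NumPy sketch below computes the same quantity from predicted probabilities, such as the output of c_model.predict. This form is handy if you also want per-class accuracy. The probabilities, the labels and the helper name are made up for illustration.

```python
import numpy as np

def accuracy_from_probs(probs, labels):
    # fraction of rows whose argmax equals the integer class label
    return float(np.mean(np.argmax(probs, axis=-1) == labels))

# made-up softmax outputs for 4 images over 3 classes, and their true labels
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6],
                  [0.5, 0.4, 0.1],
                  [0.2, 0.5, 0.3]])
labels = np.array([0, 2, 1, 1])
acc = accuracy_from_probs(probs, labels)
print(acc)
```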

We use our trained generator model to generate 100 random images of MNIST handwritten digits in a 10x10 grid.

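import matplotlib.pyplot as plt

# generating 100 new images with our generator model
latent_dim = 100
images, _ = generate_fake_samples(gen, latent_dim, 100)
# rescale from [-1,1] (tanh output) to [0,1] for display
images = (images + 1) / 2.0
for i in range(100):
    plt.subplot(10, 10, i + 1)
    plt.axis('off')
    plt.imshow(images[i, :, :, 0], cmap='gray_r')
plt.show()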
Generated Plot
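As an alternative to 100 separate subplots, the NumPy sketch below stitches generated images into a single 280x280 array. The result can be shown with one `plt.imshow(grid, cmap='gray_r')` call or saved to disk. The helper name tile_grid is hypothetical. The sketch assumes images shaped (100, 28, 28, 1) with values in [-1,1], as the tanh generator produces.

```python
import numpy as np

def tile_grid(images, rows=10, cols=10):
    """Arrange (rows*cols, H, W, 1) images in [-1, 1] into one (rows*H, cols*W) array in [0, 1]."""
    n, h, w, _ = images.shape
    if n != rows * cols:
        raise ValueError("expected rows * cols images")
    grid = images[..., 0].reshape(rows, cols, h, w)         # image r*cols + c -> grid[r, c]
    grid = grid.transpose(0, 2, 1, 3).reshape(rows * h, cols * w)
    return (grid + 1.0) / 2.0                                # [-1, 1] -> [0, 1]

# stand-in for generator output: image i is filled with a constant value in [-1, 1]
values = np.linspace(-1.0, 1.0, 100)
fake_images = values[:, None, None, None] * np.ones((100, 28, 28, 1))
grid = tile_grid(fake_images)
print(grid.shape)
```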

The images generated by our Semi-Supervised GAN model appear to be of better quality than those generated by the other GAN models we implemented.

We have also implemented other GAN models to generate MNIST handwritten digits; the articles can be found at the following links.

Deep Convolutional GAN

Conditional GAN

Auxiliary GAN

Resources

Semi-Supervised Learning with Generative Adversarial Networks, 2016.

Improved Techniques for Training GANs, 2016.

Unsupervised and Semi-supervised Learning with Categorical Generative Adversarial Networks, 2015.

Machine Learning Mastery

Keras API
