Conditional Generative Adversarial Nets: Summary & Implementation

Caleb Eunho Lee
3 min read · Mar 29, 2021

--

Original paper: Conditional Generative Adversarial Nets

Summary

With a vanilla GAN, you cannot control the modes of the data being generated; in other words, you cannot choose which kind of image to generate. The conditional generative adversarial network (cGAN) makes this possible by feeding the data we wish to condition on to both the generator and the discriminator. In the paper, the authors show that this model can generate MNIST digits conditioned on class labels.

The cGAN Architecture

The cGAN generator takes two inputs: random noise z and the conditioning data y. In the paper, y is concatenated with the noise before being fed to the generator, and it is likewise combined with the input image before being fed to the discriminator.

Implementation

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras import layers
import time
import tensorflow_datasets as tfds
import progressbar
import random
from IPython import display

We train the cGAN on the Fashion-MNIST dataset (`fashion_mnist` in TensorFlow Datasets), then generate 10 images for each class once training is finished.

# Load the Fashion-MNIST training split from TensorFlow Datasets. Each element
# is a dict; the training loop reads its "image" and "label" entries.
dataset = tfds.load('fashion_mnist',
                    split='train', shuffle_files=True, download=True,
                    data_dir='/content/drive/MyDrive/data')
# shuffle_files only randomizes the order in which shard files are read; it
# does not shuffle individual examples. Shuffle the examples themselves
# (reshuffled every epoch by default) so batches differ across epochs.
dataset = dataset.shuffle(60000).batch(256)

Define the Model

Generator:

We adapt the DCGAN generator architecture to accept a class label as an extra input.

The generator has 2 inputs — random noise and the class label. The class label is concatenated with the random noise.

def generator_model():
    """Build the conditional generator.

    Inputs:
        z: noise vector of shape (100,).
        class_labels: one-hot class label of shape (10,).

    Returns:
        A ``tf.keras.Model`` mapping ``[z, class_labels]`` to a (28, 28, 1)
        image with a tanh activation, i.e. values in [-1, 1].
    """
    z = tf.keras.Input(shape=(100,))
    class_labels = tf.keras.Input(shape=(10,))

    # Condition on the label by appending it to the noise vector, giving a
    # 110-dim input to the dense projection below.
    x = layers.Concatenate(axis=-1)([z, class_labels])
    x = layers.Dense(7 * 7 * 256, use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Reshape((7, 7, 256))(x)

    # 7x7x256 -> 7x7x128 (stride 1 keeps the spatial size).
    x = layers.Conv2DTranspose(128, (5, 5), strides=(1, 1),
                               padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # 7x7 -> 14x14.
    x = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2),
                               padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # 14x14 -> 28x28, single channel, squashed to [-1, 1].
    x = layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same',
                               use_bias=False, activation='tanh')(x)
    return tf.keras.Model(inputs=[z, class_labels], outputs=x)
cGAN Generator Architecture

Discriminator:

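The discriminator also has two inputs: the image and its class label. In this implementation, the one-hot label is projected by a dense layer into a 14×14×10 map, which is concatenated with the feature map produced by the first convolution; the result is then fed through the rest of the network.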

def discriminator_model():
    """Build the conditional discriminator.

    Inputs:
        image: tensor of shape (28, 28, 1).
        class_labels: one-hot class label of shape (10,).

    Returns:
        A ``tf.keras.Model`` producing one unnormalized logit per example.
        No sigmoid is applied, because training uses
        ``BinaryCrossentropy(from_logits=True)``, which applies the sigmoid
        itself; a sigmoid here would squash the output twice.
    """
    image_input = tf.keras.Input(shape=(28, 28, 1))
    class_labels = tf.keras.Input(shape=(10,))

    # Project the label to a 14x14x10 map so it can be stacked channel-wise
    # with the feature map from the first stride-2 convolution.
    class_embedding = layers.Dense(14 * 14 * 10)(class_labels)
    class_embedding = layers.Reshape((14, 14, 10))(class_embedding)
    class_embedding = layers.LeakyReLU()(class_embedding)

    # 28x28x1 -> 14x14x64.
    x = layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')(image_input)
    x = layers.LeakyReLU()(x)

    # 14x14x64 + 14x14x10 -> 14x14x74, then downsample to 7x7x128.
    x = layers.Concatenate(axis=-1)([x, class_embedding])
    x = layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')(x)
    x = layers.LeakyReLU()(x)
    x = layers.Flatten()(x)

    # Raw logit output (see docstring).
    x = layers.Dense(1)(x)

    return tf.keras.Model(inputs=[image_input, class_labels], outputs=x)
cGAN Discriminator Architecture

Define the Training Loss

# from_logits=True: the loss expects raw discriminator scores and applies the
# sigmoid internally.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(real_output, fake_output):
    """Return the discriminator loss.

    Real samples are pushed toward label 1 and generated samples toward
    label 0; the two cross-entropy terms are summed.
    """
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss


def generator_loss(fake_output):
    """Return the generator loss: generated samples should be scored as real (1)."""
    return cross_entropy(tf.ones_like(fake_output), fake_output)


# Instantiate the two networks that train_step and train operate on.
generator = generator_model()
discriminator = discriminator_model()

generator_optimizer = tf.keras.optimizers.Adam(0.0002)
discriminator_optimizer = tf.keras.optimizers.Adam(0.0002)

Both networks use the Adam optimizer with a learning rate of 0.0002, and we train for 100 epochs.

epochs = 100
noise_dim = 100
num_examples = 100  # number of images in the fixed preview grid
BATCH_SIZE = 256

# Fixed noise/label pairs reused for every preview so progress is comparable
# across epochs. Labels cycle through 0..9; the label list is sized from
# num_examples so it always matches the noise batch.
seed_image = tf.random.normal([num_examples, noise_dim])
seed_label = [i % 10 for i in range(num_examples)]
seed_label = tf.one_hot(seed_label, 10)
seed = [seed_image, seed_label]

Generate Images:

def generate_images(model, epoch, test_input, save=False):
    """Render the generator's outputs for ``test_input`` on a 10-column grid.

    Args:
        model: generator called as ``model(test_input, training=False)``.
        epoch: epoch number, shown as the figure title and used in the
            saved file name.
        test_input: ``[noise, one_hot_labels]`` batch to render.
        save: if True, also write the figure to
            ``image_at_epoch_<epoch>.png``.
    """
    predictions = model(test_input, training=False)
    num_images = predictions.shape[0]
    cols = 10
    rows = -(-num_images // cols)  # ceiling division; works for any count
    fig = plt.figure(figsize=(cols, rows))
    for i in range(num_images):
        plt.subplot(rows, cols, i + 1)
        # The generator ends in tanh, so map [-1, 1] back to [0, 255].
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5,
                   cmap='gray', vmin=0, vmax=255)
        plt.axis('off')
    fig.suptitle(f'Epoch {epoch}')
    if save:
        plt.savefig(f'image_at_epoch_{epoch:04d}.png')
    plt.show()
    # Close the figure so repeated per-epoch calls do not accumulate figures.
    plt.close(fig)

Training Process:

# Map uint8 pixels from [0, 255] to [-1, 1] (x / 127.5 - 1) so real images
# span the same range as the generator's tanh output. A scale of 1/255 with
# offset -1 would only reach [-1, 0].
rescale = tf.keras.Sequential([
    layers.experimental.preprocessing.Rescaling(1. / 127.5, offset=-1)
])
@tf.function
def train_step(inputs):
    """Run one adversarial update of the generator and the discriminator.

    Args:
        inputs: ``[images, one_hot_labels]`` for a batch of real data, with
            images already passed through ``rescale``.
    """
    # Match the fake batch to the real one, so the final (possibly smaller)
    # dataset batch is handled correctly.
    batch_size = tf.shape(inputs[0])[0]
    noise = tf.random.normal([batch_size, noise_dim])
    # maxval is exclusive for integer dtypes: maxval=10 samples labels 0..9.
    fake_label_ids = tf.random.uniform([batch_size], minval=0, maxval=10,
                                       dtype=tf.dtypes.int64)
    fake_labels = tf.one_hot(fake_label_ids, 10)
    generator_inputs = [noise, fake_labels]

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(generator_inputs, training=True)
        real_output = discriminator(inputs, training=True)
        # Fakes are judged against the labels they were generated from.
        fake_output = discriminator([generated_images, fake_labels],
                                    training=True)
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gen_grad = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_grad = disc_tape.gradient(disc_loss,
                                   discriminator.trainable_variables)
    generator_optimizer.apply_gradients(
        zip(gen_grad, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(
        zip(disc_grad, discriminator.trainable_variables))
def train(dataset, epochs):
    """Train the cGAN for ``epochs`` passes over ``dataset``.

    After every epoch, the samples for the fixed ``seed`` are plotted. Every
    15 epochs, both models are saved to ``.h5`` files. The final samples are
    plotted once training finishes.

    Args:
        dataset: batched dataset whose elements are dicts with "image"
            (raw pixels) and "label" (integer class ids) entries.
        epochs: number of passes over ``dataset``.
    """
    for epoch in range(epochs):
        start = time.time()

        for data in progressbar.progressbar(dataset):
            image_batch = rescale(data["image"])
            label_batch = tf.one_hot(data["label"], 10)
            train_step([image_batch, label_batch])

        display.clear_output(wait=True)
        generate_images(generator, epoch + 1, seed)

        # Periodic checkpoint of both networks.
        if (epoch + 1) % 15 == 0:
            generator.save("generator_model.h5")
            discriminator.save("discriminator_model.h5")

        print(f'Time for epoch {epoch + 1} is {time.time() - start:.1f} sec')

    display.clear_output(wait=True)
    generate_images(generator, epochs, seed)

Result:

The image below shows the generated samples, 10 for each class.

Generated Images, sorted by class

--

--

No responses yet