Implementing ACGAN (Auxiliary Classifier GAN)
Original Paper: Conditional Image Synthesis with Auxiliary Classifier GANs
ACGAN, as the name implies, conditions image generation on class labels by attaching an auxiliary classifier to the discriminator.
The discriminator learns both to classify images and to distinguish real from fake, so it is trained to maximize Lc + Ls, where
Ls = E[log P(S = real | X_real)] + E[log P(S = fake | X_fake)] (the source loss)
Lc = E[log P(C = c | X_real)] + E[log P(C = c | X_fake)] (the class loss)
The generator's objective is to fool the discriminator while producing clear, class-specific images, so it is trained to maximize Lc - Ls.
The following figure shows the ACGAN architecture.
The discriminator network has two outputs: a real/fake score and a class prediction.
Implementation
import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import random
import time
import progressbar
from IPython import display
I used the MNIST dataset for training the ACGAN network.
dataset = tfds.load('mnist', split='train', data_dir='/content/drive/MyDrive/data/mnist')
# drop_remainder keeps every batch at a fixed size, which train_step below assumes
dataset = dataset.batch(256, drop_remainder=True)
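As an optional sanity check (my addition, not part of the original notebook), one batch can be inspected to confirm the structure tfds returns: a dict of uint8 images and integer labels.
# Optional check: tfds.load('mnist') yields dicts of images and labels
batch = next(iter(dataset))
print(batch["image"].shape, batch["image"].dtype)  # (256, 28, 28, 1) uint8
print(batch["label"].shape, batch["label"].dtype)  # (256,) int64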
Generator:
The generator architecture follows the cGAN approach: the noise vector and the one-hot class label are concatenated before being upsampled into an image.
def generator_model():
    z = tf.keras.Input(shape=(100,))
    class_labels = tf.keras.Input(shape=(10,))
    # Condition the noise on the class label by simple concatenation
    x = tf.keras.layers.Concatenate(axis=-1)([z, class_labels])
    x = tf.keras.layers.Dense(7*7*256, use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.Reshape((7, 7, 256))(x)
    # Upsample 7x7 -> 7x7 -> 14x14 -> 28x28
    x = tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')(x)
    return tf.keras.Model(inputs=[z, class_labels], outputs=x)

generator = generator_model()
tf.keras.utils.plot_model(generator,show_shapes=True)
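Before training, a quick smoke test (my addition, not part of the original notebook) confirms the generator maps a noise/label pair to a 28x28 image:
# Untrained generator: one noise vector conditioned on the digit 3
z_test = tf.random.normal([1, 100])
label_test = tf.one_hot([3], 10)
fake = generator([z_test, label_test], training=False)
print(fake.shape)  # (1, 28, 28, 1); values lie in [-1, 1] because of the tanh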
Discriminator:
def discriminator_model(channels=28, image_size=(28, 28)):
    inputs = tf.keras.Input(shape=image_size + (1,))
    x = tf.keras.layers.Conv2D(channels, 4, strides=(2, 2), padding='same')(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU()(x)
    x = tf.keras.layers.Conv2D(channels*2, 4, strides=(2, 2), padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU()(x)
    x = tf.keras.layers.Conv2D(channels*4, 4, strides=(2, 2), padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU()(x)
    x = tf.keras.layers.Flatten()(x)
    # Two heads: a real/fake probability and a 10-way class distribution
    disc_output = tf.keras.layers.Dense(1, activation='sigmoid')(x)
    class_output = tf.keras.layers.Dense(10, activation='softmax')(x)
    return tf.keras.Model(inputs=inputs, outputs=[disc_output, class_output])

discriminator = discriminator_model()
tf.keras.utils.plot_model(discriminator,show_shapes=True)
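A similar smoke test (again, my addition) shows the two heads coming back from a single forward pass:
# Two outputs per image: a real/fake probability and a 10-way class distribution
dummy = tf.random.normal([4, 28, 28, 1])
realness, class_probs = discriminator(dummy, training=False)
print(realness.shape, class_probs.shape)  # (4, 1) and (4, 10)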
Define the loss
# The discriminator heads already apply sigmoid/softmax, so the losses
# must take probabilities (from_logits=False, the default)
bce = tf.keras.losses.BinaryCrossentropy()
cce = tf.keras.losses.CategoricalCrossentropy()
The discriminator loss is Ls + Lc, where
Ls = cross-entropy(1, D(x)) + cross-entropy(0, D(G(z)))
Lc = cross-entropy(y, class output for the real images x)
def discriminator_loss(real_output, fake_output, real_label, pred):
    real_loss = bce(tf.ones_like(real_output), real_output)   # real images -> 1
    fake_loss = bce(tf.zeros_like(fake_output), fake_output)  # fakes -> 0
    class_loss = cce(real_label, pred)                        # auxiliary classifier on real images
    total_loss = real_loss + fake_loss + class_loss
    return total_loss
The generator loss:
Lc = cross-entropy(y_fake, class output for the generated images)
Lg = cross-entropy(1, D(G(z))) + Lc
def generator_loss(fake_output, fake_label, fake_pred):
    disc_loss = bce(tf.ones_like(fake_output), fake_output)  # the generator wants fakes scored as real
    class_loss = cce(fake_label, fake_pred)                  # and correctly classified
    return class_loss + disc_loss
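As a toy sanity check (with hypothetical values, not from the original notebook), a confident and correct discriminator should produce a small loss:
# Hypothetical values: real images scored near 1, fakes near 0, classes nearly one-hot
real_output = tf.constant([[0.95], [0.90]])
fake_output = tf.constant([[0.05], [0.10]])
labels = tf.one_hot([3, 8], 10)
pred = labels * 0.90 + 0.01  # rows sum to 1: correct class 0.91, the rest 0.01
print(discriminator_loss(real_output, fake_output, labels, pred).numpy())  # ~0.25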
Training process:
generator_optimizer = tf.keras.optimizers.Adam(0.0002)
discriminator_optimizer = tf.keras.optimizers.Adam(0.0002)

epochs = 150
noise_dim = 100
num_examples = 16
BATCH_SIZE = 256  # must match the batch size used in dataset.batch() above

# Fixed seed: 16 noise vectors paired with the labels 0-9 (cycled),
# so the same conditioned samples can be watched evolve over training
seed_image = tf.random.normal([num_examples, noise_dim])
seed_label = tf.one_hot([i % 10 for i in range(num_examples)], 10)
seed = [seed_image, seed_label]

def generate_and_save_images(model, epoch, test_input):
    predictions = model(test_input, training=False)
    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        # tanh output is in [-1, 1]; rescale to [0, 255] for display
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()

@tf.function
def train_step(image, label):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    # maxval is exclusive, so use 10 to sample all digits 0-9
    noise_label_ = tf.random.uniform([BATCH_SIZE], minval=0, maxval=10, dtype=tf.dtypes.int64)
    noise_label = tf.one_hot(noise_label_, 10)
    seed = [noise, noise_label]
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(seed, training=True)
        real_output, real_pred = discriminator(image, training=True)
        fake_output, fake_pred = discriminator(generated_images, training=True)
        gen_loss = generator_loss(fake_output, noise_label, fake_pred)
        disc_loss = discriminator_loss(real_output, fake_output, label, real_pred)
    gen_grad = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_grad = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gen_grad, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_grad, discriminator.trainable_variables))

def train(dataset, epochs):
    for epoch in range(epochs):
        start = time.time()
        for data in progressbar.progressbar(dataset):
            # Scale images to [-1, 1] to match the generator's tanh output
            image_batch = (tf.cast(data["image"], tf.float32) - 127.5) / 127.5
            label_batch = tf.one_hot(data["label"], 10)
            train_step(image_batch, label_batch)
        display.clear_output(wait=True)
        generate_and_save_images(generator, epoch + 1, seed)
        if (epoch + 1) % 2 == 0:
            generator.save("generator_model.h5")
            discriminator.save("discriminator_model.h5")
        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))
    display.clear_output(wait=True)
    generate_and_save_images(generator, epochs, seed)
%%time
train(dataset, epochs)
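Saving .h5 files keeps the model weights but not the optimizer states, since the optimizers here live outside the models. If you want to resume training, tf.train.Checkpoint is one alternative (a sketch, not part of the original notebook):
# Optional: checkpoint the models and optimizers together so training can resume
ckpt = tf.train.Checkpoint(generator=generator, discriminator=discriminator,
                           generator_optimizer=generator_optimizer,
                           discriminator_optimizer=discriminator_optimizer)
manager = tf.train.CheckpointManager(ckpt, './training_checkpoints', max_to_keep=3)
manager.save()  # call this every few epochs instead of (or alongside) model.save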
Result after 150 epochs
In this run, the digit 9 was not generated properly. The likely cause is in the original train_step: the fake labels were sampled with tf.random.uniform(..., maxval=9), and since maxval is exclusive, the label 9 was never drawn during training. The code above uses maxval=10 so that all ten digits are sampled.
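Once training has finished, the saved generator can be loaded back to synthesize any chosen digit. A short sketch, assuming the generator_model.h5 file written during training:
# Generate a 4x4 grid of a chosen digit with the trained generator
loaded = tf.keras.models.load_model("generator_model.h5")
digit = 9
z = tf.random.normal([16, 100])
labels = tf.one_hot([digit] * 16, 10)
images = loaded([z, labels], training=False)
plt.figure(figsize=(4, 4))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow(images[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
    plt.axis('off')
plt.show()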