
If you’ve been even a bit active in the machine learning community, you have probably already heard about GANs (Generative Adversarial Networks).
What exactly are GANs?
To put it simply, a GAN pits two neural networks against each other. I know that statement doesn’t make much clear on its own, so let me give you an analogy.
Let’s consider the relationship between a forger and an investigator.
The task of a forger is to create fraudulent imitations of original paintings by famous artists. If the forged piece can pass as the original, the forger gets a lot of money in exchange for it.
On the other hand, an art investigator’s task is to catch the forgers who create these fraudulent pieces. How does he do it? He knows the properties that set the original artist apart and what kind of painting that artist would have created. He compares the piece in hand against this knowledge to decide whether it is real.
This contest of forger vs investigator goes on, which ultimately produces world-class investigators (and, unfortunately, world-class forgers).
Now you might wonder how this relates to GANs. Well, GANs work in a similar way: the Generator (forger) creates fake data (images), while the Discriminator (investigator) examines the data, tries to separate the fake samples from the real ones, and through its verdicts gives the Generator the feedback it needs to create better images.
The Generator network takes random noise as input and generates images from it. These generated images are mixed with real images and sent to the Discriminator, whose task is to take an input (either from the real data or from the Generator) and predict whether it is real or generated.
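For reference, the original GAN paper (Goodfellow et al., 2014) writes this contest as a two-player minimax game. Here D(x) is the Discriminator’s estimated probability that x is real, G(z) is the image the Generator produces from noise z, p_data is the real-data distribution and p_z is the noise distribution:

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

The Discriminator maximizes V by pushing D(x) toward 1 on real images and D(G(z)) toward 0 on fakes. The Generator minimizes V by pushing D(G(z)) toward 1. In practice the Generator is usually trained to maximize log D(G(z)) instead, which is what labeling fake images as 1.0 in the adversarial step amounts to, because this gives stronger gradients early in training.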
Let’s learn the technicalities by implementing a basic DCGAN (Deep Convolutional GAN).
Deep Convolutional GAN (DCGAN) is one of the models that demonstrated how to build a practical GAN that can learn by itself to synthesize new images. A DCGAN is essentially a GAN that uses deep convolutional networks in place of fully connected networks.
DCGAN implementation
Now let’s implement a DCGAN using the MNIST dataset.
Importing Libraries
We will use TensorFlow 2.0 for the implementation, mainly with the following layers: Dense, Conv2D and Conv2DTranspose. RMSprop is the optimizer for the models, and the matplotlib library is used for plotting.
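Here is a sketch of the imports such an implementation might need, assuming the tf.keras API that ships with TensorFlow 2.x. Besides the three layers named above, the helper layers (Input, Reshape, Flatten, BatchNormalization, Activation, LeakyReLU), load_model and the MNIST loader are assumptions. They are not taken from the original code.

```python
# Plotting and array handling
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# MNIST ships with tf.keras; layers, models and the RMSprop optimizer
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D,
                                     Conv2DTranspose, Dense, Flatten, Input,
                                     LeakyReLU, Reshape)
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.optimizers import RMSprop
```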
Generator Neural Network
As we can see from the image above, the Generator first takes random noise as input. In a DCGAN we use Conv2DTranspose (transposed convolution, sometimes loosely called deconvolution) to upsample it into fake data, which is then pitted against the Discriminator. We are using the functional approach to build this model; if you are unsure what exactly the functional approach is, you can refer to my previous blog here.
In the Generator network above, we first reshape the data into a small feature map. Then, for each of the filter sizes [128, 64, 32, 1], we choose a stride and apply a Conv2DTranspose layer, which upsamples the feature map until the output has the shape of an image. Finally, the function returns the Generator model.
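The original code isn’t reproduced here, so below is a minimal sketch of such a Generator that reuses the filter list [128, 64, 32, 1] from the text. The following details are assumptions: a 100-dimensional noise vector, 5x5 kernels, stride 2 in the first two transposed convolutions, BatchNormalization + ReLU before each of them, and a final sigmoid.

```python
import numpy as np
from tensorflow.keras.layers import (Activation, BatchNormalization,
                                     Conv2DTranspose, Dense, Input, Reshape)
from tensorflow.keras.models import Model

def build_generator(inputs, image_size=28, kernel_size=5):
    """Map a noise vector to an (image_size, image_size, 1) image (sketch)."""
    image_resize = image_size // 4  # 28 -> 7; two stride-2 layers bring it back to 28
    layer_filters = [128, 64, 32, 1]
    # Project the noise with a Dense layer and reshape it into a 7x7x128 feature map
    x = Dense(image_resize * image_resize * layer_filters[0])(inputs)
    x = Reshape((image_resize, image_resize, layer_filters[0]))(x)
    for filters in layer_filters:
        # Assumed rule: upsample (stride 2) for 128 and 64 filters, keep size afterwards
        strides = 2 if filters > layer_filters[-2] else 1
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
        x = Conv2DTranspose(filters=filters, kernel_size=kernel_size,
                            strides=strides, padding="same")(x)
    # Sigmoid keeps pixel values in [0, 1], matching MNIST scaled by 1/255
    x = Activation("sigmoid")(x)
    return Model(inputs, x, name="generator")

latent_size = 100  # assumed noise dimension
generator = build_generator(Input(shape=(latent_size,), name="z_input"))
# Generate a batch of 16 fake digits from uniform noise
fake = generator.predict(np.random.uniform(-1.0, 1.0, size=(16, latent_size)), verbose=0)
```

With these strides the feature map grows 7 → 14 → 28 → 28 → 28, so `fake` has shape (16, 28, 28, 1).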
Discriminator Neural Network
Now let’s build the Discriminator part of the model, which acts as the investigator: it labels real images as 1.0 and the fake images produced by the Generator as 0.0. A sigmoid activation on its output is used to distinguish between real and fake images.
In the Discriminator model we use Conv2D layers to capture the features in the images. Because the last activation is a sigmoid, the output is a value between 0.0 (fake) and 1.0 (real), which can be read as the Discriminator’s estimated probability that the input is real.
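A matching sketch of the Discriminator follows. The filter list [32, 64, 128, 256], the 5x5 kernels, LeakyReLU(0.2) and the stride pattern are all assumptions chosen to mirror the Generator above. The single sigmoid unit at the end produces the real/fake probability described in the text.

```python
import numpy as np
from tensorflow.keras.layers import Activation, Conv2D, Dense, Flatten, Input, LeakyReLU
from tensorflow.keras.models import Model

def build_discriminator(inputs, kernel_size=5):
    """Map a 28x28x1 image to the probability that it is real (sketch)."""
    layer_filters = [32, 64, 128, 256]  # assumed filter sizes
    x = inputs
    for filters in layer_filters:
        # Downsample with stride 2 in every layer except the last one
        strides = 1 if filters == layer_filters[-1] else 2
        x = LeakyReLU(0.2)(x)
        x = Conv2D(filters=filters, kernel_size=kernel_size,
                   strides=strides, padding="same")(x)
    x = Flatten()(x)
    x = Dense(1)(x)
    # Sigmoid squashes the score into (0, 1): near 1.0 = "real", near 0.0 = "fake"
    x = Activation("sigmoid")(x)
    return Model(inputs, x, name="discriminator")

discriminator = build_discriminator(Input(shape=(28, 28, 1), name="discriminator_input"))
scores = discriminator.predict(np.random.rand(2, 28, 28, 1), verbose=0)
```

Each score is a probability rather than a hard 0/1 label. Thresholding at 0.5 turns it into a real/fake decision.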
Train the GAN model
Now let’s alternately train the Discriminator and Adversarial networks, batch by batch. The Discriminator is trained first on real and fake images. Then the Adversarial network is trained on fake images that pretend to be real. Every 500 steps we generate a set of fake images to see how the Generator is improving with the Discriminator’s feedback.
Remember that for a GAN to be stable and for the Generator to learn something, the Discriminator needs to stay ahead of the Generator (which is why it is trained first on every batch). Only then can it send useful feedback that the Generator can learn from to create better fake images. It should not get too far ahead, though: if it rejects every fake with complete confidence, the Generator’s gradients can become very small and it stops improving.
In the train function above, we generate fake images from random noise and combine them with real images from the dataset. We then train the Discriminator on this combined batch, labeling the real images as 1 and the fake ones as 0, and store its loss and accuracy.
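To make “its loss and accuracy” concrete, here is a small NumPy sketch of what happens to one Discriminator batch. Real and fake images are stacked with labels 1.0 and 0.0. The loss is binary cross-entropy on the sigmoid outputs, and the accuracy thresholds them at 0.5. The helper names are hypothetical; Keras computes the same quantities internally when the model is compiled with binary_crossentropy and the accuracy metric.

```python
import numpy as np

def discriminator_batch(real_images, fake_images):
    """Stack real and fake images with matching 1.0 / 0.0 labels (hypothetical helper)."""
    x = np.concatenate([real_images, fake_images], axis=0)
    y = np.concatenate([np.ones(len(real_images)), np.zeros(len(fake_images))])
    return x, y.reshape(-1, 1)

def bce_and_accuracy(y_true, y_pred, eps=1e-7):
    """Binary cross-entropy and 0.5-threshold accuracy for sigmoid outputs."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    loss = -np.mean(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    acc = np.mean((y_pred > 0.5) == (y_true > 0.5))
    return float(loss), float(acc)
```

A Discriminator that outputs 0.5 for everything (pure guessing) gets a loss of ln 2 ≈ 0.693 and an accuracy of 0.5 on a balanced batch. A perfect one drives the loss toward 0 and the accuracy to 1.0.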
After training the Discriminator, we train the Adversarial network, in which the fake images are labeled as 1.0. While the Adversarial network is being trained, the Discriminator’s weights are frozen, so only the Generator is updated. We also create a function that plots the images the Generator has produced, and we save the current model every 500 time steps.
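The train function described above can be sketched as follows. It assumes Keras-style models: `discriminator` and `adversarial` are compiled, and `discriminator.trainable` was set to False before `adversarial` was compiled. The batch size of 64, the noise size of 100, uniform noise in [-1, 1] and the `on_save` callback (a hypothetical hook for plotting and saving) are assumptions. Only the alternating order, the 1.0/0.0 labels and the 500-step interval come from the text.

```python
import numpy as np

def train(generator, discriminator, adversarial, x_train, train_steps,
          batch_size=64, latent_size=100, save_interval=500, on_save=None):
    """Alternate one Discriminator update and one Adversarial update per step."""
    rng = np.random.default_rng()
    history = []
    for step in range(1, train_steps + 1):
        # 1) Discriminator: a random real batch labeled 1.0, generated images labeled 0.0
        idx = rng.integers(0, x_train.shape[0], size=batch_size)
        real_images = x_train[idx]
        noise = rng.uniform(-1.0, 1.0, size=(batch_size, latent_size))
        fake_images = generator.predict(noise, verbose=0)
        x = np.concatenate([real_images, fake_images])
        y = np.concatenate([np.ones((batch_size, 1)), np.zeros((batch_size, 1))])
        # train_on_batch returns [loss, accuracy] when accuracy is a metric
        d_loss = float(np.ravel(discriminator.train_on_batch(x, y))[0])

        # 2) Adversarial: fresh noise labeled 1.0 ("pretend to be real");
        #    the frozen Discriminator scores it and only the Generator changes
        noise = rng.uniform(-1.0, 1.0, size=(batch_size, latent_size))
        a_out = adversarial.train_on_batch(noise, np.ones((batch_size, 1)))
        a_loss = float(np.ravel(a_out)[0])

        history.append((d_loss, a_loss))
        # Every save_interval steps, hand control to the plotting/saving hook
        if on_save is not None and step % save_interval == 0:
            on_save(step)
    return history
```

Calling `generator.predict` once per step is simple rather than fast. The point here is the order of the two updates and which labels each one sees.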
Plot the Generator Images
Now let’s create a function that plots the generated images and saves them every 500 timesteps.
In the function above, we create a grid of images that the Generator produces from noise (they improve over time thanks to the Discriminator’s feedback), and we store all these images in a folder called GAN_mnist.
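A minimal matplotlib sketch of such a plotting function is below. It keeps the GAN_mnist folder name from the text. The function name, the zero-padded file naming and the square grid layout are assumptions. The non-interactive "Agg" backend is selected so it also works on a machine without a display.

```python
import math
import os

import matplotlib
matplotlib.use("Agg")  # render straight to files, no window needed
import matplotlib.pyplot as plt
import numpy as np

def plot_images(images, step, folder="GAN_mnist"):
    """Save a square grid of generated images as <folder>/<step>.png (sketch)."""
    os.makedirs(folder, exist_ok=True)
    n = images.shape[0]
    rows = int(math.ceil(math.sqrt(n)))  # e.g. 16 images -> 4x4 grid
    plt.figure(figsize=(4, 4))
    for i in range(n):
        plt.subplot(rows, rows, i + 1)
        # Drop the channel axis: (28, 28, 1) -> (28, 28) for imshow
        plt.imshow(images[i].reshape(images.shape[1], images.shape[2]), cmap="gray")
        plt.axis("off")
    path = os.path.join(folder, f"{step:05d}.png")
    plt.savefig(path)
    plt.close("all")
    return path
```

Passing the same fixed batch of noise to the Generator at every save makes the saved grids directly comparable over time.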
Build and Train the models
Now that we have functions for both the Generator and the Discriminator, let’s build and train all the networks.
We load the MNIST dataset, then build and compile the Discriminator first (for the stability and progress reasons mentioned above), and then build the Generator, which takes noise as input to generate images.
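Before any of the models see MNIST, the images need the shape and value range the networks expect. Here is a small sketch, assuming the Generator ends in a sigmoid so real images should also lie in [0, 1]. The helper name prepare_mnist is hypothetical.

```python
import numpy as np

def prepare_mnist(x_train):
    """Reshape (N, 28, 28) uint8 digits to (N, 28, 28, 1) float32 in [0, 1]."""
    image_size = x_train.shape[1]
    # Add the single grayscale channel that Conv2D expects
    x = np.reshape(x_train, (-1, image_size, image_size, 1))
    # Scale pixel values from 0..255 to 0..1
    return x.astype("float32") / 255.0
```

With tf.keras this would typically be used as `(x_train, _), (_, _) = mnist.load_data()` followed by `x_train = prepare_mnist(x_train)`. The labels are not needed, because the GAN only learns from the images.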
The Adversarial network is the Generator followed by the Discriminator. Because the Discriminator’s weights must stay frozen while the Adversarial network is trained, we mark the Discriminator as non-trainable and then compile the Adversarial network, using binary_crossentropy as the loss function.
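Here is a minimal sketch of this freeze-then-compile pattern. To keep it short, tiny hypothetical Dense stand-ins replace the convolutional Generator and Discriminator, and the learning rates are assumptions. Relying on the trainable flag being captured at compile time is tf.keras 2.x behaviour. Keras 3 may handle it differently, so check this against your version (or use an explicit GradientTape loop).

```python
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import RMSprop

def build_adversarial(generator, discriminator, latent_size, learning_rate=1e-4):
    """Adversarial network = Generator followed by a frozen Discriminator."""
    # Freeze the Discriminator after it was compiled on its own. In tf.keras 2.x
    # the trainable flag is captured at compile time, so discriminator.train_on_batch
    # keeps updating it, while the Adversarial network only updates the Generator.
    discriminator.trainable = False
    z = Input(shape=(latent_size,), name="z_input")
    adversarial = Model(z, discriminator(generator(z)), name="adversarial")
    adversarial.compile(loss="binary_crossentropy",
                        optimizer=RMSprop(learning_rate=learning_rate))
    return adversarial

# Hypothetical tiny stand-ins for the real convolutional networks
latent_size = 100
z_in = Input(shape=(latent_size,))
generator = Model(z_in, Dense(28 * 28, activation="sigmoid")(z_in), name="generator")
img_in = Input(shape=(28 * 28,))
discriminator = Model(img_in, Dense(1, activation="sigmoid")(img_in), name="discriminator")
# The Discriminator is compiled first, while it is still trainable
discriminator.compile(loss="binary_crossentropy",
                      optimizer=RMSprop(learning_rate=2e-4), metrics=["accuracy"])
adversarial = build_adversarial(generator, discriminator, latent_size)
```

Only the Generator’s kernel and bias remain trainable inside `adversarial`, which is exactly what “freeze the Discriminator” means here.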
Run the DCGAN
Now let’s run the model
The code above first checks whether you have a pre-trained saved model. If so, it continues training from there; if not, it starts fresh.
Here is the result I got while training the DCGAN on MNIST:
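A minimal sketch of that resume-or-start-fresh check, assuming each model is saved with `model.save` to a single file. The function name load_or_build and the file name used below are hypothetical. The `.keras` extension works with recent TensorFlow versions; older 2.x releases store a SavedModel directory at that path instead, which `load_model` also reads.

```python
import os

import tensorflow as tf

def load_or_build(model_path, build_fn):
    """Return (model, resumed): load a saved model if present, else build a new one."""
    if os.path.exists(model_path):
        # A previous run saved a model here: continue training from it
        return tf.keras.models.load_model(model_path), True
    # Nothing saved yet: start from freshly initialized weights
    return build_fn(), False
```

For example, `generator, resumed = load_or_build("dcgan_generator.keras", build_fn)` returns `resumed=False` on the first run. Once the training loop has called `generator.save("dcgan_generator.keras")`, later runs pick up from that file.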
The End(at least for now)
You can find the code for this blog here.
PS: If you have any doubts, you can mail me here (pavankunchalapk@gmail.com), contact me on LinkedIn here, or check out my other code (there’s some really cool stuff) on my GitHub here.
I am also looking for freelancing opportunities in the field of deep learning and computer vision. If you are willing to collaborate, mail me here (pavankunchalapk@gmail.com).
Have a wonderful day!