Generative Adversarial Networks (GANs) Explained


Introduction

In this article, we will discuss GANs and how exactly they can increase the amount of data available for your models: by generating a larger sample dataset to train and evaluate on, they can help improve the accuracy of the model being built. We will stick to the most common application, the generation of fake images.

GANs are a type of generative model in the field of deep learning used to create synthetic data that closely resemble your pre-existing training data.

GANs work mainly by making two components of the model compete with each other. A battle takes place between a 'Generator' and a 'Discriminator': the generator tries to produce synthetic (fake) data that mimics the pre-existing data, while the discriminator tries not to be fooled by the forged data the generator produces.

Architecture of GANs

Figure: Architecture of a GAN. Real-world images are sampled and fed into the discriminator. In parallel, a random input passes through the generator, and the generated samples are also fed into the discriminator. The discriminator classifies each sample as 'Real' or 'Fake', and from this classification the discriminator loss and the generator loss are computed.

We first supply a random noise seed, also known as a latent vector, which is a 1D vector. It is passed into the generator network, where it is scaled up to a 2D image format. Simultaneously, we pass in real images, and both the real and the generated images are fed into the 'Discriminator Network'.
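As a minimal sketch of this step, the snippet below maps a 1D latent vector to a 2D image with a single linear layer. The 100-dimensional latent size and 28x28 output shape are illustrative assumptions; a real generator would be a deep network with trained weights.

```python
import numpy as np

rng = np.random.default_rng(0)

latent_dim = 100          # size of the random noise (latent) vector
img_shape = (28, 28)      # target 2D image format

# Randomly initialised weights stand in for a trained generator network.
W = rng.normal(0, 0.02, size=(latent_dim, img_shape[0] * img_shape[1]))
b = np.zeros(img_shape[0] * img_shape[1])

def generate(z):
    """Map a 1D latent vector z to a 2D image via a linear layer + tanh."""
    flat = np.tanh(z @ W + b)          # values squashed into [-1, 1]
    return flat.reshape(img_shape)     # scaled up to the 2D image format

z = rng.normal(size=latent_dim)        # the random noise seed
fake_image = generate(z)
print(fake_image.shape)                # (28, 28)
```

The discriminator would then receive `fake_image` alongside real images of the same shape.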

The discriminator network is then trained to sort the real images from the fake ones (generated by the generator), and finally a discriminator loss and a generator loss are calculated. This process is repeated continuously, depending on the number of images required.

To summarize, the steps taken to train a GAN network are :

  1. We define a GAN architecture based on the application and the type of data available.

  2. We train a discriminator to be able to sort through synthetic vs original data.

  3. We train a generator to create data that can fool the discriminator.

  4. We run this process of pitting the generator and discriminator against each other for multiple epochs.

  5. We finally use the trained model to create new data that closely resembles the original data but is synthetic in nature.
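The steps above can be sketched end to end with a deliberately tiny example. Here a one-parameter generator learns to mimic samples from a Gaussian with mean 3, and a logistic-regression discriminator tries to tell real samples from generated ones. The gradients are written out by hand to keep the sketch dependency-light; a real GAN would use deep networks and automatic differentiation.

```python
import numpy as np

rng = np.random.default_rng(1)

w_g, b_g = 0.1, 0.0        # generator: G(z) = w_g * z + b_g
w_d, b_d = 0.1, 0.0        # discriminator: D(x) = sigmoid(w_d * x + b_d)
lr = 0.05                  # learning rate

def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))

for step in range(500):
    # 1. Sample real data and fake data from the current generator.
    x_real = rng.normal(3.0, 1.0)           # real-world sample
    z = rng.normal()                        # random noise / latent input
    x_fake = w_g * z + b_g                  # generator output

    # 2. Discriminator update: push D(x_real) -> 1 and D(x_fake) -> 0.
    d_real = sigmoid(w_d * x_real + b_d)
    d_fake = sigmoid(w_d * x_fake + b_d)
    # Gradients of -log D(x_real) - log(1 - D(x_fake)) w.r.t. (w_d, b_d).
    g_w_d = (d_real - 1.0) * x_real + d_fake * x_fake
    g_b_d = (d_real - 1.0) + d_fake
    w_d -= lr * g_w_d
    b_d -= lr * g_b_d

    # 3. Generator update: push D(G(z)) -> 1 (non-saturating loss -log D).
    d_fake = sigmoid(w_d * x_fake + b_d)
    g_x = (d_fake - 1.0) * w_d              # chain rule through D
    w_g -= lr * g_x * z
    b_g -= lr * g_x

# With enough steps, b_g tends to drift toward the real mean of 3.
print(w_g, b_g)
```

The two alternating updates inside the loop correspond directly to steps 2-4 of the summary above.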

Individual components of GAN networks

The 'Generator' and the 'Discriminator' components of a GAN are both neural networks that compete with each other: the discriminator is a classifier, while the generator learns to produce synthetic data that is similar to the original data.

1. Generator

The generator is mainly responsible for generating fake data, guided by the feedback provided to it by the discriminator. In general, it takes random noise as input and translates this noise into data that looks valid.

After the discriminator finishes classifying the incoming data as real or fake, the generator's weights are updated through back-propagation, using gradients of the score computed from the discriminator's classification. This score is known as the 'generator loss'.

To summarize, the steps taken by the Generator are :

  1. Random noise or a latent vector is supplied to it

  2. An output is produced with the help of this random noise

  3. The results of classification as 'real' or 'fake' by the discriminator for the provided data are received

  4. The generator loss is calculated from the discriminator's classification

  5. Backpropagation takes place through the network to obtain gradients

  6. The gradients are used to update the weights of the generator for better / more accurate results.
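Step 4 can be illustrated with the non-saturating generator loss, -log(D(G(z))), computed from the discriminator's score on a generated sample. This is a toy, framework-free illustration:

```python
import math

def generator_loss(d_score_on_fake):
    """Non-saturating generator loss: lower when the discriminator
    is fooled (score close to 1), higher when the fake is spotted."""
    return -math.log(d_score_on_fake)

# Discriminator is fooled into scoring the fake as 0.9 "real":
print(round(generator_loss(0.9), 3))   # 0.105 (small loss)
# Discriminator confidently spots the fake, scoring it 0.1:
print(round(generator_loss(0.1), 3))   # 2.303 (large loss)
```

The gradients of this loss with respect to the generator's weights (step 5) are what drive the update in step 6.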

2. Discriminator

The discriminator is the core component of the GAN. It acts as the main classifier between the real data and the synthetic data (generated by the generator), labelling each sample as 'real' or 'fake' and assigning it a score. From these scores, the discriminator loss and the generator loss are computed.

The discriminator connects to two loss functions. During its own training, the discriminator ignores the generator loss and uses only the discriminator loss.

To summarize, the steps taken by the Discriminator are :

  1. It acts as a classifier between the real and fake data

  2. The discriminator loss is used to penalize the discriminator for its misclassifications: classifying real data as fake, or vice versa.

  3. Finally, the weights of the discriminator are updated via back-propagation using the discriminator loss.
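Steps 2 and 3 can be illustrated with the standard binary cross-entropy discriminator loss, again as a framework-free toy example:

```python
import math

def discriminator_loss(d_real, d_fake):
    """Binary cross-entropy over one real and one fake sample:
    penalizes scoring real data low or fake data high."""
    return -math.log(d_real) - math.log(1.0 - d_fake)

# Good discriminator: real sample scored 0.9, fake scored 0.1 -> low loss.
print(round(discriminator_loss(0.9, 0.1), 3))   # 0.211
# Fooled discriminator: real scored 0.4, fake scored 0.8 -> high loss.
print(round(discriminator_loss(0.4, 0.8), 3))   # 2.526
```

Back-propagating this loss through the discriminator's weights is what step 3 performs.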

Loss functions used by GANs

Now, we will take a step back to understand what exactly the scores, i.e. the 'generator loss' and the 'discriminator loss', denote in GANs.

Overall, the algorithm plays a game of minimax, in which each component of the network tries to compete with and out-smart the other.
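This minimax game is usually written as the value function from the original GAN paper (Goodfellow et al., 2014), where the discriminator D maximizes the objective and the generator G minimizes it:

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] +
  \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

The first term rewards D for scoring real data highly; the second rewards D for scoring generated data G(z) low, which is exactly what G tries to prevent.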

  • Generator Loss : The generator loss measures how well the generator managed to fool the discriminator into mis-classifying fake data as real.

  • Discriminator Loss : The discriminator loss measures the amount of misclassification performed by the discriminator; it penalizes the discriminator for classifying a fake image as real, or vice versa.

These loss functions however have their own set of limitations which are :

  1. Mode Collapse : The model can over-fit and generate synthetic data for only one specific subset of the classes present in the original data.

  2. Vanishing Gradients : The generator's gradients may saturate when the discriminator over-powers it and performs far better.

  3. Convergence : Ideally, both components reach a stable equilibrium given similar input data. In practice this is not always the case; the balance can shift, and one model can over-power and beat the other.

Applications of GANs

GANs are one of the main deep learning models that led to a spurt in generative deep learning models used for synthetic data generation and more. Some of their common applications are :

  1. Image Inpainting : GANs can fill in the empty gaps in an image that were cut off or lost.

  2. Denoising : Removing noise to improve clarity, especially for hard-to-read images such as X-rays.

  3. Image quality enhancement : GANs can increase attributes such as contrast and sharpness, drastically improving the overall quality of an image.

Conclusion

In conclusion, Generative Adversarial Networks have proved crucial for complex data augmentation tasks, whether to bring variance into a dataset or to increase the size of the sample set used to train an AI model and improve its accuracy.
