Getting started with Generative Adversarial Networks

A Beginner’s Guide to GANs

What are GANs?

Deep neural networks are used mainly for supervised learning: classification or regression. Generative Adversarial Networks (GANs), however, use neural networks for a very different purpose: generative modeling.

Generative modeling is the task of generating new examples that could plausibly have been drawn from the original dataset.

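In short, a generator network turns random noise into fake images, while a discriminator network learns to tell those fakes apart from real images, and the two are trained against each other. The following flowchart summarizes how a GAN works: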

flow.png
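For readers who prefer an equation to a flowchart, the original GAN paper (Goodfellow et al., 2014) states this competition as a two-player minimax game. Here D(x) is the discriminator's estimated probability that x is real, G(z) is the image the generator produces from random noise z, p_data is the distribution of real images, and p_z is the noise distribution:

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

The discriminator tries to maximize V by scoring real images near 1 and generated ones near 0, while the generator tries to minimize it. In practice, as in the training code in this article, the generator is usually trained to maximize log D(G(z)) instead, which gives it stronger gradients early in training.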

Dataset

We will be using a dump of all images from the famous Bored Ape Yacht Club (BAYC) NFT collection.

bubble-ape.gif


Making dataset ready for training

Let’s load this dataset using the ImageFolder class from torchvision. We will also resize and center-crop the images to 64 x 64 px, and normalize the pixel values with a mean and standard deviation of 0.5 for each channel. This maps pixel values from [0, 1] to the range [-1, 1], which matches the range of the generator’s tanh output, so real and generated images reach the discriminator on the same scale.
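As a quick check of the range claim: T.Normalize applies (x - mean) / std to each channel. The plain-Python sketch below (normalize and denormalize are illustrative names, not torchvision functions) shows that mean = std = 0.5 maps the [0, 1] output of T.ToTensor onto [-1, 1], and that the inverse formula used by denorm restores the original values.

```python
# Plain-Python sketch of the per-channel arithmetic behind T.Normalize(mean, std).
# normalize/denormalize are illustrative names, not torchvision functions.
MEAN, STD = 0.5, 0.5

def normalize(x):
    # T.Normalize computes (x - mean) / std for each channel
    return (x - MEAN) / STD

def denormalize(y):
    # Inverse mapping, the same formula as the article's denorm(): y * std + mean
    return y * STD + MEAN

# T.ToTensor() produces pixel values in [0, 1]
for x in (0.0, 0.25, 0.5, 1.0):
    print(f"{x:.2f} -> {normalize(x):+.2f}")
```

Black (0.0) lands at -1, mid-gray (0.5) at 0, and white (1.0) at +1.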

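import os
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
# Jupyter/IPython magic that renders plots inside the notebook (remove in a plain .py script)
%matplotlib inline

DATA_DIR = '../input/bored-apes-yacht-club/'
print(os.listdir(DATA_DIR))

image_size = 64
batch_size = 128
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)  # per-channel (mean, std)

# ImageFolder expects DATA_DIR to contain at least one subfolder of images
train_ds = ImageFolder(DATA_DIR, transform=T.Compose([
    T.Resize(image_size),       # resize so the shorter side is 64 px
    T.CenterCrop(image_size),   # crop the center to 64 x 64
    T.ToTensor(),               # PIL image -> float tensor in [0, 1]
    T.Normalize(*stats)]))      # [0, 1] -> [-1, 1]

train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=2, pin_memory=True)

def denorm(img_tensors):
    # Undo Normalize (x * std + mean) to bring pixels back to [0, 1] for display
    return img_tensors * stats[1][0] + stats[0][0]

def show_images(images, nmax=64):
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([]); ax.set_yticks([])
    # make_grid returns C x H x W, but imshow expects H x W x C, hence permute
    ax.imshow(make_grid(denorm(images.detach()[:nmax]), nrow=8).permute(1, 2, 0))

def show_batch(dl, nmax=64):
    # Display only the first batch from the data loader
    for images, _ in dl:
        show_images(images, nmax)
        break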

Let’s look at a sample batch from the dataset:

sample.png

Discriminator Network

The discriminator takes an image as input and tries to classify it as “real” or “generated”. In this sense, it is an ordinary binary classifier. We’ll use a convolutional neural network (CNN) that outputs a single number for every image: the probability that the image is real. Each hidden convolution uses a stride of 2 to progressively halve the size of the feature map.

discriminator.gif

import torch.nn as nn

discriminator = nn.Sequential(
    # in: 3 x 64 x 64

    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 64 x 32 x 32

    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 128 x 16 x 16

    nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 256 x 8 x 8

    nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(512),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 512 x 4 x 4

    nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
    # out: 1 x 1 x 1

    nn.Flatten(),
    nn.Sigmoid())
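The shape comments in the discriminator can be checked with the output-size formula for nn.Conv2d from the PyTorch documentation. This pure-Python sketch (conv2d_out is a hypothetical helper, and it assumes the default dilation of 1) traces a 64 x 64 input through the five convolutions:

```python
# Spatial output size of nn.Conv2d (PyTorch Conv2d docs, with dilation=1):
#   out = floor((in + 2 * padding - kernel_size) / stride) + 1
def conv2d_out(size, kernel_size, stride, padding):
    return (size + 2 * padding - kernel_size) // stride + 1

# (kernel_size, stride, padding) of each Conv2d layer in the discriminator
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 64
sizes = [size]
for k, s, p in layers:
    size = conv2d_out(size, k, s, p)
    sizes.append(size)
print(sizes)  # [64, 32, 16, 8, 4, 1]
```

Each stride-2 layer with kernel 4 and padding 1 exactly halves the size, and the final 4 x 4 kernel without padding collapses the 4 x 4 map into one value per image, which Flatten and Sigmoid turn into a probability.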


Generator Network

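The input to the generator is typically a vector or matrix of random numbers (referred to as a latent tensor), which is used as a seed for generating an image. The generator converts a latent tensor of shape (128, 1, 1) into an image tensor of shape 3 x 64 x 64, using transposed convolutions to upsample and a final tanh so that output pixels lie in [-1, 1], the same range as the normalized training images.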

import torch.nn as nn

latent_size = 128  # size of the random latent vector that seeds each generated image

generator = nn.Sequential(
    # in: latent_size x 1 x 1

    nn.ConvTranspose2d(latent_size, 512, kernel_size=4, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(512),
    nn.ReLU(True),
    # out: 512 x 4 x 4

    nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.ReLU(True),
    # out: 256 x 8 x 8

    nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(True),
    # out: 128 x 16 x 16

    nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(True),
    # out: 64 x 32 x 32

    nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
    nn.Tanh()
    # out: 3 x 64 x 64
)
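The generator mirrors the discriminator. The PyTorch docs give the output size of nn.ConvTranspose2d as (in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1; with the defaults dilation=1 and output_padding=0 this reduces to (in - 1) * stride - 2 * padding + kernel_size. The sketch below (conv_transpose2d_out is a hypothetical helper) confirms the shape comments, growing a 1 x 1 latent into a 64 x 64 image:

```python
# Spatial output size of nn.ConvTranspose2d with dilation=1 and output_padding=0:
#   out = (in - 1) * stride - 2 * padding + kernel_size
def conv_transpose2d_out(size, kernel_size, stride, padding):
    return (size - 1) * stride - 2 * padding + kernel_size

# (kernel_size, stride, padding) of each ConvTranspose2d layer in the generator
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1
sizes = [size]
for k, s, p in layers:
    size = conv_transpose2d_out(size, k, s, p)
    sizes.append(size)
print(sizes)  # [1, 4, 8, 16, 32, 64]
```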


Discriminator Network Training

The steps involved in training the discriminator are:

  • We expect the discriminator to output 1 if the image was picked from the real dataset, and 0 if it was generated using the generator network.

  • We first pass a batch of real images through the discriminator and compute the loss, setting the target labels to 1.

  • Then we generate a batch of fake images using the generator, pass them into the discriminator, and compute the loss, setting the target labels to 0.

  • Finally, we add the two losses and use the overall loss to perform gradient descent to adjust the weights of the discriminator.

import torch
import torch.nn.functional as F

# Train on the GPU when one is available, and keep both models on that device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
discriminator = discriminator.to(device)
generator = generator.to(device)

def train_discriminator(real_images, opt_d):
    # Clear discriminator gradients
    opt_d.zero_grad()

    # Pass real images through discriminator (target label 1 = real)
    real_preds = discriminator(real_images)
    real_targets = torch.ones(real_images.size(0), 1, device=device)
    real_loss = F.binary_cross_entropy(real_preds, real_targets)
    real_score = torch.mean(real_preds).item()

    # Generate fake images; detach() so this step computes no generator gradients
    latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
    fake_images = generator(latent).detach()

    # Pass fake images through discriminator (target label 0 = generated)
    fake_targets = torch.zeros(fake_images.size(0), 1, device=device)
    fake_preds = discriminator(fake_images)
    fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)
    fake_score = torch.mean(fake_preds).item()

    # Update discriminator weights using the combined loss
    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()
    return loss.item(), real_score, fake_score
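To make these losses concrete, here is the same computation in plain Python for a hypothetical batch of two real and two generated images. bce mirrors, per element, what F.binary_cross_entropy computes (it averages over the batch with its default mean reduction); the prediction values are made up for illustration.

```python
import math

def bce(pred, target):
    # Binary cross-entropy for a single prediction/target pair
    return -(target * math.log(pred) + (1 - target) * math.log(1 - pred))

def batch_bce(preds, targets):
    # Mean over the batch, like F.binary_cross_entropy's default reduction
    return sum(bce(p, t) for p, t in zip(preds, targets)) / len(preds)

# Hypothetical discriminator outputs (probability that each image is real)
real_preds = [0.9, 0.8]   # for real images
fake_preds = [0.2, 0.3]   # for generated images

real_loss = batch_bce(real_preds, [1, 1])   # target 1 for real images
fake_loss = batch_bce(fake_preds, [0, 0])   # target 0 for generated images
loss = real_loss + fake_loss

# Equivalent of real_score / fake_score in train_discriminator
real_score = sum(real_preds) / len(real_preds)
fake_score = sum(fake_preds) / len(fake_preds)
print(f"real_loss={real_loss:.4f} fake_loss={fake_loss:.4f} loss={loss:.4f}")
print(f"real_score={real_score:.2f} fake_score={fake_score:.2f}")
```

The loss shrinks as real_preds approach 1 and fake_preds approach 0, which is exactly the direction opt_d.step() pushes the discriminator.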
    


Generator Network Training

Since the outputs of the generator are images, it’s not obvious how we can train the generator. This is where we employ a rather elegant trick, which is to use the discriminator as a part of the loss function. Here’s how it works:

  • We generate a batch of images using the generator and pass them into the discriminator.

  • We calculate the loss by setting the target labels to 1, i.e., “real”. We do this because the generator’s objective is to “fool” the discriminator.

  • We use this loss to perform gradient descent on the generator’s weights only, so that it gets better at generating realistic images that “fool” the discriminator.

def train_generator(opt_g):
    # Clear generator gradients
    opt_g.zero_grad()

    # Generate fake images
    latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
    fake_images = generator(latent)

    # Try to fool the discriminator
    preds = discriminator(fake_images)
    targets = torch.ones(batch_size, 1, device=device)
    loss = F.binary_cross_entropy(preds, targets)

    # Update generator weights
    loss.backward()
    opt_g.step()

    return loss.item()
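Setting the targets to 1 means train_generator minimizes -log D(G(z)) rather than the log(1 - D(G(z))) term of the original minimax objective; this is the "non-saturating" generator loss suggested in the original GAN paper. The plain-Python sketch below (helper names are hypothetical) compares the gradient each loss sends back through the discriminator's final sigmoid, with respect to its input logit a, when the discriminator confidently rejects a fake:

```python
import math

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def grad_saturating(a):
    # d/da of log(1 - sigmoid(a)) = -sigmoid(a)
    # (the term the original minimax objective has the generator minimize)
    return -sigmoid(a)

def grad_non_saturating(a):
    # d/da of -log(sigmoid(a)) = -(1 - sigmoid(a))
    # (the loss train_generator minimizes via BCE with target 1)
    return -(1.0 - sigmoid(a))

# Early in training the discriminator easily spots fakes, e.g. D(G(z)) = 0.01
a = math.log(0.01 / 0.99)   # logit whose sigmoid is 0.01
print(f"saturating:     {grad_saturating(a):.4f}")
print(f"non-saturating: {grad_non_saturating(a):.4f}")
```

With D(G(z)) = 0.01 the saturating form yields a gradient of about -0.01, while the non-saturating form yields about -0.99, so the generator still receives a useful learning signal when its samples are easy to spot.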
    


Full Training

Let’s define a fit function that trains the discriminator and then the generator on each batch of training data. We’ll use the Adam optimizer with custom betas of (0.5, 0.999), which are known to work well for GANs. We will also save some sample generated images at the end of every epoch for inspection.

from tqdm.auto import tqdm

def fit(epochs, lr, start_idx=1):
    torch.cuda.empty_cache()

    # Losses & scores
    losses_g = []
    losses_d = []
    real_scores = []
    fake_scores = []

    # Create optimizers
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

    for epoch in range(epochs):
        for real_images, _ in tqdm(train_dl):
            # Move the batch to the same device as the models
            real_images = real_images.to(device)
            # Train discriminator
            loss_d, real_score, fake_score = train_discriminator(real_images, opt_d)
            # Train generator
            loss_g = train_generator(opt_g)

        # Record losses & scores (last batch of the epoch)
        losses_g.append(loss_g)
        losses_d.append(loss_d)
        real_scores.append(real_score)
        fake_scores.append(fake_score)

        # Log losses & scores (last batch)
        print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
            epoch+1, epochs, loss_g, loss_d, real_score, fake_score))

        # Save generated images (save_samples() and fixed_latent are helpers not shown in this article)
        save_samples(epoch+start_idx, fixed_latent, show=False)

    return losses_g, losses_d, real_scores, fake_scores
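For intuition about the betas argument, here is the standard Adam update rule (Kingma & Ba) applied to a single scalar in plain Python. It is an illustrative sketch, not PyTorch's implementation, and it ignores options such as weight decay. beta1 controls how quickly the running average of gradients (the momentum) forgets old values; beta1 = 0.5 gives a much shorter memory than PyTorch's default of 0.9, a setting popularized by the DCGAN paper. The default lr=0.0002 below is an assumption, since the article does not state the learning rate it used.

```python
import math

def adam_step(param, grad, m, v, t, lr=0.0002, beta1=0.5, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad          # running average of gradients (momentum)
    v = beta2 * v + (1 - beta2) * grad * grad   # running average of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias correction for the zero initialization
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

# Hypothetical gradient sequence that flips sign on the 4th step, as can happen
# in a GAN when the opposing network changes the loss landscape
grads = [1.0, 1.0, 1.0, -1.0]
momentum = {}
for beta1 in (0.5, 0.9):
    p, m, v = 0.0, 0.0, 0.0
    for t, g in enumerate(grads, start=1):
        p, m, v = adam_step(p, g, m, v, t, beta1=beta1)
    momentum[beta1] = m
    print(f"beta1={beta1}: momentum after the sign flip = {m:+.4f}")
```

With beta1 = 0.5 the momentum already follows the new gradient direction after one step, while with beta1 = 0.9 it still points the old way. A shorter memory is one common explanation for why the lower beta1 helps with the constantly shifting targets of GAN training.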
    


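Results

The network clearly improves over the course of training, and a few of the generated images are convincing, but we are still far from our initial goal of generating images that could plausibly pass for real BAYC apes.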

eg.png

We struggle to avoid mode collapse, a state in which every generator input leads to the same (or nearly the same) output.
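One rough way to spot collapse is to measure how different the generator's outputs are from each other. The sketch below is a simple heuristic, not a method used in the article: it averages the pairwise Euclidean distance between flattened samples, so a value near zero means different latent inputs are producing nearly the same image. The sample values are made up; with real outputs you could flatten a batch of generated images first, for example with fake_images.flatten(1).tolist().

```python
import math
from itertools import combinations

def mean_pairwise_distance(samples):
    """Average Euclidean distance over all pairs of flattened samples."""
    pairs = list(combinations(samples, 2))
    return sum(math.dist(a, b) for a, b in pairs) / len(pairs)

# Hypothetical flattened outputs for four different latent vectors
collapsed = [[0.1, 0.2, 0.3]] * 4   # every input gives the same output
diverse = [[0.1, 0.2, 0.3], [0.9, -0.5, 0.0], [-0.7, 0.4, 0.8], [0.3, 0.3, -0.9]]

print(mean_pairwise_distance(collapsed))  # 0.0
print(mean_pairwise_distance(diverse))
```

Tracking this number across epochs (for example on images generated from a fixed set of latents) can flag collapse before it is obvious by eye, although a nonzero value alone does not prove the samples are varied in a meaningful way.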

Some of our limitations are due to our network, which is quite simple and small by current standards.

Link to Code: bit.ly/3T3EHgz

More generated examples: bit.ly/3dIZuGg

If you liked this article, I would be super excited if you hit the like button or shared it with your curious friends. Anyway, thanks for reading, and have a great day!