🎨 GANs - Generative Adversarial Networks

Two neural networks competing to create realistic images

What are GANs?

A GAN consists of two networks trained against each other: a Generator that creates fake images, and a Discriminator that tries to tell those fakes apart from real images. As the Discriminator gets better at spotting fakes, the Generator is pushed to produce increasingly realistic outputs.

The GAN Game:

  • Generator: Creates fake images from random noise
  • Discriminator: Distinguishes real vs fake images
  • Training: Generator improves until it can fool the Discriminator
  • Result: Generated images that are hard to tell apart from real ones

🏗️ GAN Architecture

Training Process:

  1. Generator takes random noise (z) → Creates fake image
  2. Discriminator sees real images (label=1) and fake images (label=0)
  3. Discriminator trains to classify both correctly
  4. Generator trains to make the Discriminator label its fakes as real
  5. Repeat, alternating steps 3 and 4, until the Generator's samples look realistic
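
Steps 3 and 4 come down to two binary cross-entropy losses. The sketch below works them out for a made-up batch of Discriminator outputs; the probabilities in `d_real` and `d_fake` are illustrative values chosen for this example, not results from a trained model.

```python
import torch
import torch.nn.functional as F

# Hypothetical Discriminator outputs (probability that each image is real)
d_real = torch.tensor([[0.9], [0.8]])  # scores for two real images
d_fake = torch.tensor([[0.3], [0.1]])  # scores for two generated images

ones = torch.ones_like(d_real)
zeros = torch.zeros_like(d_fake)

# Step 3: the Discriminator wants real -> 1 and fake -> 0
d_loss = (F.binary_cross_entropy(d_real, ones)
          + F.binary_cross_entropy(d_fake, zeros)) / 2

# Step 4: the Generator wants its fakes scored as real (target 1)
g_loss = F.binary_cross_entropy(d_fake, ones)

print(f"D loss: {d_loss.item():.4f}  G loss: {g_loss.item():.4f}")
```

Here the Discriminator is mostly right, so its loss is small while the Generator's loss is large. As the Generator improves, `d_fake` moves toward 1 and the balance shifts.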

PyTorch Implementation

import torch
import torch.nn as nn

# Generator: noise → image
class Generator(nn.Module):
    def __init__(self, latent_dim=100, img_channels=3, img_size=64):
        super().__init__()
        self.model = nn.Sequential(
            # Input: latent_dim (100)
            nn.Linear(latent_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            
            # Output: flattened image
            nn.Linear(1024, img_channels * img_size * img_size),
            nn.Tanh()  # Output in [-1, 1]
        )
        self.img_channels = img_channels
        self.img_size = img_size
    
    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), self.img_channels, self.img_size, self.img_size)
        return img

# Discriminator: image → real/fake probability
class Discriminator(nn.Module):
    def __init__(self, img_channels=3, img_size=64):
        super().__init__()
        self.model = nn.Sequential(
            # Flatten image
            nn.Flatten(),
            
            nn.Linear(img_channels * img_size * img_size, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            # Output: probability (0=fake, 1=real)
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, img):
        validity = self.model(img)
        return validity

🎯 Training a GAN

import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Hyperparameters
latent_dim = 100
lr = 0.0002
batch_size = 64
epochs = 100

# Initialize models
generator = Generator(latent_dim=latent_dim)
discriminator = Discriminator()

# Optimizers: Adam with lr=0.0002 and beta1=0.5 is a common choice for GANs (settings from the DCGAN paper)
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# Loss function
adversarial_loss = nn.BCELoss()

# Load real data
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # Per RGB channel: maps [0, 1] to [-1, 1]
])
dataloader = DataLoader(
    datasets.CIFAR10(root="./data", train=True, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True
)

# Training loop
for epoch in range(epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
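        # Actual size of this batch (the last batch of an epoch can be smaller)
        batch_size = real_imgs.size(0)
        
        # Targets for BCELoss: 1 = real, 0 = fake
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)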
        
        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        
        # Real images
        real_loss = adversarial_loss(discriminator(real_imgs), real_labels)
        
        # Fake images
        z = torch.randn(batch_size, latent_dim)
        fake_imgs = generator(z)
        fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake_labels)
        
        # Total discriminator loss
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()
        
        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()
        
        # Generator wants discriminator to think fakes are real
        z = torch.randn(batch_size, latent_dim)
        fake_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(fake_imgs), real_labels)
        
        g_loss.backward()
        optimizer_G.step()
        
        # Log progress
        if i % 100 == 0:
            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] "
                  f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")
    
    # Save generated samples
    if epoch % 10 == 0:
        with torch.no_grad():
            z = torch.randn(16, latent_dim)
            generated = generator(z)
            # Save images...

🚀 Advanced GAN Variants

DCGAN

Deep Convolutional GAN

  • Uses convolutional layers
  • More stable training than fully connected GANs
  • Better image quality
  • Widely used baseline architecture

StyleGAN

Style-based Generator

  • Controls image style at different levels
  • Photorealistic faces
  • Used by thispersondoesnotexist.com
  • State-of-the-art face synthesis at the time of its release

CycleGAN

Unpaired Image Translation

  • Photo → Painting style
  • Horse → Zebra conversion
  • No paired training data needed
  • Creative applications
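
What lets CycleGAN work without paired data is a cycle-consistency loss: translating an image to the other domain and back should reproduce the original. A minimal sketch, assuming single convolution layers as stand-ins for the two translators (the names `G_ab` and `G_ba` are hypothetical, and real CycleGAN generators are much deeper):

```python
import torch
import torch.nn as nn

# Tiny stand-in translators
G_ab = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # domain A -> B (e.g. horse -> zebra)
G_ba = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # domain B -> A

real_a = torch.randn(2, 3, 64, 64)  # unpaired batch from domain A
real_b = torch.randn(2, 3, 64, 64)  # unpaired batch from domain B

l1 = nn.L1Loss()

# A -> B -> A and B -> A -> B should each reproduce the input
cycle_loss = l1(G_ba(G_ab(real_a)), real_a) + l1(G_ab(G_ba(real_b)), real_b)
print(cycle_loss.item())
```

In the full method this term is added to the usual adversarial losses, one Discriminator per domain; only the cycle term is shown here.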

Conditional GAN

Controlled Generation

  • Generate specific classes
  • Text-to-image (early approach)
  • Guided generation
  • More control over output
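
A common way to get this control is to feed the class label to the Generator alongside the noise. The sketch below is one possible layout, not a reference implementation: `ConditionalGenerator`, its layer sizes, and the choice of 10 classes are assumptions made for illustration.

```python
import torch
import torch.nn as nn

class ConditionalGenerator(nn.Module):
    """Hypothetical MLP generator that also receives a class label."""

    def __init__(self, latent_dim=100, num_classes=10, img_channels=3, img_size=32):
        super().__init__()
        self.img_shape = (img_channels, img_size, img_size)
        # One learned embedding vector per class
        self.label_embedding = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 256),
            nn.ReLU(),
            nn.Linear(256, img_channels * img_size * img_size),
            nn.Tanh(),  # Output in [-1, 1]
        )

    def forward(self, z, labels):
        # Concatenate noise and label embedding so the class steers the output
        x = torch.cat([z, self.label_embedding(labels)], dim=1)
        return self.model(x).view(z.size(0), *self.img_shape)

cond_generator = ConditionalGenerator()
z = torch.randn(4, 100)
labels = torch.tensor([0, 3, 3, 7])  # requested class for each sample
imgs = cond_generator(z, labels)
print(imgs.shape)
```

The Discriminator is conditioned the same way, so it judges whether an image is both realistic and consistent with its label.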

Pix2Pix

Paired Image Translation

  • Sketch → Photo
  • Day → Night
  • Requires paired data
  • Precise transformations
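
Because each input has a known target, Pix2Pix can add a pixel-wise L1 term to the adversarial loss. A rough sketch of the Generator objective follows; the one-layer networks are stand-ins (the paper uses a U-Net generator and a PatchGAN discriminator), and the weight of 100 on the L1 term is the value reported in the Pix2Pix paper.

```python
import torch
import torch.nn as nn

# Stand-in networks, far smaller than the real ones
generator = nn.Conv2d(3, 3, kernel_size=3, padding=1)
discriminator = nn.Sequential(
    nn.Conv2d(6, 1, 4, 2, 1),  # sees input and output stacked: 6 channels
    nn.Flatten(),
    nn.Linear(32 * 32, 1),
    nn.Sigmoid(),
)

sketch = torch.randn(2, 3, 64, 64)  # input image (e.g. a sketch)
photo = torch.randn(2, 3, 64, 64)   # its paired target photo

fake_photo = generator(sketch)
# The Discriminator judges the (input, output) pair, not the output alone
d_on_fake = discriminator(torch.cat([sketch, fake_photo], dim=1))

lambda_l1 = 100.0  # weight on the reconstruction term
adv_loss = nn.BCELoss()(d_on_fake, torch.ones_like(d_on_fake))
l1_loss = nn.L1Loss()(fake_photo, photo)
g_loss = adv_loss + lambda_l1 * l1_loss
print(g_loss.item())
```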

BigGAN

Large-scale GAN

  • Trained on ImageNet
  • Class-conditional generation at up to 512×512
  • Diverse outputs
  • Strong baseline

⚡ DCGAN Implementation

import torch.nn as nn

class DCGANGenerator(nn.Module):
    def __init__(self, latent_dim=100, channels=3):
        super().__init__()
        self.model = nn.Sequential(
            # Input: latent_dim x 1 x 1
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # State: 512 x 4 x 4
            
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # State: 256 x 8 x 8
            
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # State: 128 x 16 x 16
            
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # State: 64 x 32 x 32
            
            nn.ConvTranspose2d(64, channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # Output: channels x 64 x 64
        )
    
    def forward(self, z):
        z = z.view(z.size(0), z.size(1), 1, 1)
        return self.model(z)

class DCGANDiscriminator(nn.Module):
    def __init__(self, channels=3):
        super().__init__()
        self.model = nn.Sequential(
            # Input: channels x 64 x 64
            nn.Conv2d(channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 64 x 32 x 32
            
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 128 x 16 x 16
            
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 256 x 8 x 8
            
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 512 x 4 x 4
            
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
            # Output: 1 x 1 x 1
        )
    
    def forward(self, img):
        return self.model(img).view(-1, 1)
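
DCGAN models are usually paired with a custom weight initialization: the DCGAN paper draws all weights from a zero-centered normal distribution with standard deviation 0.02. The sketch below follows that, plus the common convention of initializing BatchNorm scales around 1. The `weights_init` helper and the small `net` stand-in are illustrative names, not part of the code above.

```python
import torch
import torch.nn as nn

def weights_init(m):
    """Conv weights from N(0, 0.02); BatchNorm scales from N(1, 0.02), biases 0."""
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.zeros_(m.bias)

# Small stand-in using the same layer types as the first DCGAN generator stage
net = nn.Sequential(
    nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
    nn.BatchNorm2d(512),
    nn.ReLU(True),
)
net.apply(weights_init)  # apply() visits every submodule recursively

z = torch.randn(8, 100, 1, 1)
out = net(z)
print(out.shape)
```

With the classes above, the same call would be `DCGANGenerator().apply(weights_init)` and likewise for the discriminator.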

⚠️ Training Challenges

Common Issues:

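  • Mode Collapse: Generator produces the same few outputs regardless of the input noise
  • Unstable Training: Losses oscillate and may never settle
  • Vanishing Gradients: A very confident Discriminator passes almost no gradient back, so the Generator stops improving
  • Discriminator Wins: Discriminator becomes too accurate too early for the Generator to learn from it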
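
Mode collapse is usually spotted by looking at samples, but a rough numeric check can help during long runs. The sketch below is a simple heuristic, not a standard metric: it measures the average L2 distance between samples in a batch, which drops toward zero when the Generator keeps producing the same image. The helper name `mean_pairwise_distance` is made up for this example, and random tensors stand in for generated batches.

```python
import torch

def mean_pairwise_distance(samples):
    """Average L2 distance between all pairs of samples in a batch."""
    flat = samples.flatten(start_dim=1)
    return torch.pdist(flat).mean()

diverse = torch.randn(16, 3, 64, 64)                       # varied samples
collapsed = torch.randn(1, 3, 64, 64).repeat(16, 1, 1, 1)  # one sample repeated 16 times

diverse_score = mean_pairwise_distance(diverse).item()
collapsed_score = mean_pairwise_distance(collapsed).item()
print(diverse_score, collapsed_score)
```

A score that keeps shrinking over training while the losses look healthy is a hint to inspect the samples by eye.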

Solutions:

  • Use DCGAN architecture (more stable)
  • Label smoothing: Use 0.9 instead of 1.0 for real labels
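  • Feature matching: Train the Generator to match the Discriminator's intermediate feature statistics on real data, not its final output
  • Minibatch discrimination: Let the Discriminator compare samples within a batch, which penalizes low diversity
  • Spectral normalization: Constrain the Discriminator's weight norms to stabilize training
  • Lower learning rate: 0.0002 is a common starting point with Adam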
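
Two of these fixes take only a few lines in PyTorch. The sketch below shows one-sided label smoothing and spectral normalization on a small stand-in Discriminator. The layer sizes and the random `real_imgs` batch are placeholders, and `spectral_norm` is the wrapper from `torch.nn.utils`.

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

# Stand-in Discriminator with spectral normalization on each linear layer
discriminator = nn.Sequential(
    nn.Flatten(),
    spectral_norm(nn.Linear(3 * 64 * 64, 256)),
    nn.LeakyReLU(0.2),
    spectral_norm(nn.Linear(256, 1)),
    nn.Sigmoid(),
)

real_imgs = torch.randn(8, 3, 64, 64)  # placeholder for a batch of real images

# One-sided label smoothing: real targets become 0.9, fake targets stay at 0.0
real_labels = torch.full((8, 1), 0.9)

real_loss = nn.BCELoss()(discriminator(real_imgs), real_labels)
real_loss.backward()
print(real_loss.item())
```

Only the real targets are smoothed. The targets used for the Generator update stay at 1.0.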

🎯 Key Takeaways