What are GANs?
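A Generative Adversarial Network (GAN) consists of two networks trained against each other: a Generator that creates fake images and a Discriminator that tries to tell those fakes apart from real images. As the Discriminator gets better at spotting fakes, the Generator is pushed to produce increasingly realistic outputs.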
The GAN Game:
- Generator: Creates fake images from random noise
- Discriminator: Distinguishes real vs fake images
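- Training: The two networks are updated in alternation; the Generator improves at fooling the Discriminator while the Discriminator improves at catching fakes
- Result: When training succeeds, generated images that are hard to tell apart from real ones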
🏗️ GAN Architecture
Training Process:
- Generator takes random noise (z) → Creates fake image
- Discriminator sees real images (label=1) and fake images (label=0)
- Discriminator trains to classify correctly
- Generator trains to make the Discriminator classify its fake images as real (label=1)
- Repeat both steps for many epochs, checking generated samples along the way
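The steps above correspond to the minimax objective of the original GAN formulation. In the sketch below, $D(x)$ is the Discriminator's estimated probability that $x$ is real, $G(z)$ is the Generator's output for noise $z$, and $p_{\text{data}}$ and $p_z$ denote the data and noise distributions:

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
```

In practice the Generator is usually trained to maximize $\log D(G(z))$ rather than to minimize $\log(1 - D(G(z)))$. That is what a binary cross-entropy loss against the label 1 computes.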
PyTorch Implementation
import torch
import torch.nn as nn


# Generator: noise → image
class Generator(nn.Module):
    def __init__(self, latent_dim=100, img_channels=3, img_size=64):
        super().__init__()
        self.model = nn.Sequential(
            # Input: latent_dim (100)
            nn.Linear(latent_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            # Output: flattened image
            nn.Linear(1024, img_channels * img_size * img_size),
            nn.Tanh()  # Output in [-1, 1]
        )
        self.img_channels = img_channels
        self.img_size = img_size

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), self.img_channels, self.img_size, self.img_size)
        return img


# Discriminator: image → real/fake probability
class Discriminator(nn.Module):
    def __init__(self, img_channels=3, img_size=64):
        super().__init__()
        self.model = nn.Sequential(
            # Flatten image
            nn.Flatten(),
            nn.Linear(img_channels * img_size * img_size, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            # Output: probability (0=fake, 1=real)
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        validity = self.model(img)
        return validity
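To get a feel for the size of these fully connected networks, the layer sizes above can be turned into a parameter count. This is a minimal sketch in plain Python; the helper names `linear_params` and `batchnorm_params` are made up for illustration and are not part of PyTorch.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


def batchnorm_params(num_features):
    """Learnable scale and shift of a batch-norm layer (running stats excluded)."""
    return 2 * num_features


latent_dim, img_channels, img_size = 100, 3, 64
flat_img = img_channels * img_size * img_size  # values in one flattened image

generator_params = (
    linear_params(latent_dim, 256) + batchnorm_params(256)
    + linear_params(256, 512) + batchnorm_params(512)
    + linear_params(512, 1024) + batchnorm_params(1024)
    + linear_params(1024, flat_img)
)
discriminator_params = (
    linear_params(flat_img, 1024)
    + linear_params(1024, 512)
    + linear_params(512, 256)
    + linear_params(256, 1)
)

print(f"Generator:     {generator_params:,} parameters")
print(f"Discriminator: {discriminator_params:,} parameters")
```

Almost all of the parameters sit in the single layer that touches the flattened image, which is one reason convolutional architectures such as DCGAN are commonly preferred for images.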
🎯 Training a GAN
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Hyperparameters
latent_dim = 100
lr = 0.0002
batch_size = 64
epochs = 100

# Initialize models
generator = Generator(latent_dim=latent_dim)
discriminator = Discriminator()

# Optimizers (Adam with beta1=0.5 is a common choice for GANs)
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# Loss function
adversarial_loss = nn.BCELoss()

# Load real data
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    # One mean/std per RGB channel: maps [0, 1] to [-1, 1], matching the Tanh output
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
dataloader = DataLoader(
    datasets.CIFAR10(root="./data", train=True, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True
)

# Training loop
for epoch in range(epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        # The last batch may be smaller than batch_size, so use the actual size
        n = real_imgs.size(0)

        # Labels
        real_labels = torch.ones(n, 1)
        fake_labels = torch.zeros(n, 1)

        # ---------------------
        # Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Real images
        real_loss = adversarial_loss(discriminator(real_imgs), real_labels)

        # Fake images (detach so this step does not update the Generator)
        z = torch.randn(n, latent_dim)
        fake_imgs = generator(z)
        fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake_labels)

        # Total discriminator loss
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        # Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Generator wants discriminator to think fakes are real
        z = torch.randn(n, latent_dim)
        fake_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(fake_imgs), real_labels)
        g_loss.backward()
        optimizer_G.step()

        # Log progress
        if i % 100 == 0:
            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] "
                  f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

    # Save generated samples
    if epoch % 10 == 0:
        with torch.no_grad():
            z = torch.randn(16, latent_dim)
            generated = generator(z)
            # Save images here (for example with torchvision.utils.save_image)
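Because the real images are normalized to [-1, 1] and the Generator ends in Tanh, generated samples have to be mapped back to a displayable range before saving. The sketch below shows that arithmetic on single pixel values in plain Python; the function names are hypothetical helpers, not torchvision APIs.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Map a pixel in [0, 1] to [-1, 1], the same arithmetic as Normalize(0.5, 0.5)."""
    return (pixel - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Invert normalize(): map a Tanh output in [-1, 1] back to [0, 1]."""
    return value * std + mean


def to_uint8(value):
    """Convert a generator output in [-1, 1] to an 8-bit pixel intensity."""
    clipped = min(max(denormalize(value), 0.0), 1.0)
    return round(clipped * 255)


for tanh_output in (-1.0, -0.5, 0.5, 1.0):
    print(tanh_output, "->", denormalize(tanh_output), "->", to_uint8(tanh_output))
```

The same scaling applies elementwise to a whole image tensor, for example `(generated + 1) / 2` before writing the samples to disk.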
🚀 Advanced GAN Variants
DCGAN
Deep Convolutional GAN
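- Uses convolutional and transposed-convolutional layers instead of fully connected ones
- More stable training than fully connected GANs
- Better image quality
- Common starting baseline for image GANs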
StyleGAN
Style-based Generator
- Controls image style at different levels
- Photorealistic faces
- Used by thispersondoesnotexist.com
- State-of-the-art face synthesis at the time of its release
CycleGAN
Unpaired Image Translation
- Photo → Painting style
- Horse → Zebra conversion
- No paired training data needed
- Creative applications
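CycleGAN can work without paired data because of its cycle-consistency loss: translating an image to the other domain and back should return the original image. A sketch of that term, assuming two generators $G: X \to Y$ and $F: Y \to X$ (notation chosen here for illustration):

```latex
\mathcal{L}_{\text{cyc}}(G, F) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\lVert F(G(x)) - x \rVert_1\right]
  + \mathbb{E}_{y \sim p_{\text{data}}(y)}\left[\lVert G(F(y)) - y \rVert_1\right]
```

This term is added to the usual adversarial losses of the two generator/discriminator pairs.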
Conditional GAN
Controlled Generation
- Generate specific classes
- Text-to-image (early approach)
- Guided generation
- More control over output
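One common way to condition a generator on a class is to append a one-hot label vector to the noise vector. The snippet below is a minimal sketch of that input construction only, assuming 10 classes and the same 100-dimensional noise used elsewhere in this tutorial; the variable names are illustrative.

```python
import torch
import torch.nn.functional as F

latent_dim = 100   # same noise size as the tutorial's Generator
num_classes = 10   # assumption: for example, the 10 CIFAR-10 classes
batch_size = 4

z = torch.randn(batch_size, latent_dim)
labels = torch.tensor([0, 3, 3, 9])               # desired class for each sample
one_hot = F.one_hot(labels, num_classes).float()  # shape: (4, 10)

# A conditional generator's first layer would take latent_dim + num_classes inputs.
conditioned_input = torch.cat([z, one_hot], dim=1)
print(conditioned_input.shape)
```

The Discriminator is usually given the same label alongside the image, so it can judge whether the image is both realistic and of the requested class.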
Pix2Pix
Paired Image Translation
- Sketch → Photo
- Day → Night
- Requires paired data
- Precise transformations
BigGAN
Large-scale GAN
- Trained on ImageNet
- High-resolution (512×512)
- Diverse outputs
- Strong class-conditional baseline
⚡ DCGAN Implementation
import torch.nn as nn


class DCGANGenerator(nn.Module):
    def __init__(self, latent_dim=100, channels=3):
        super().__init__()
        self.model = nn.Sequential(
            # Input: latent_dim x 1 x 1
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # State: 512 x 4 x 4
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # State: 256 x 8 x 8
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # State: 128 x 16 x 16
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # State: 64 x 32 x 32
            nn.ConvTranspose2d(64, channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # Output: channels x 64 x 64
        )

    def forward(self, z):
        # Reshape (batch, latent_dim) noise to (batch, latent_dim, 1, 1)
        z = z.view(z.size(0), z.size(1), 1, 1)
        return self.model(z)


class DCGANDiscriminator(nn.Module):
    def __init__(self, channels=3):
        super().__init__()
        self.model = nn.Sequential(
            # Input: channels x 64 x 64
            nn.Conv2d(channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 64 x 32 x 32
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 128 x 16 x 16
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 256 x 8 x 8
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 512 x 4 x 4
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
            # Output: 1 x 1 x 1
        )

    def forward(self, img):
        # Flatten (batch, 1, 1, 1) to (batch, 1)
        return self.model(img).view(-1, 1)
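The "State" comments in the DCGAN code can be checked with the standard output-size formulas for convolutions and transposed convolutions. The sketch below assumes dilation 1 and no output padding, which matches the layers above; the helper functions are written for illustration.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Output size of ConvTranspose2d (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel


def conv_out(size, kernel, stride, padding):
    """Output size of Conv2d (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) for each layer, copied from the DCGAN code above
generator_layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]
discriminator_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 1  # the noise vector is reshaped to latent_dim x 1 x 1
generator_sizes = []
for kernel, stride, padding in generator_layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    generator_sizes.append(size)

size = 64  # the discriminator receives 64 x 64 images
discriminator_sizes = []
for kernel, stride, padding in discriminator_layers:
    size = conv_out(size, kernel, stride, padding)
    discriminator_sizes.append(size)

print(generator_sizes)      # [4, 8, 16, 32, 64]
print(discriminator_sizes)  # [32, 16, 8, 4, 1]
```

With kernel 4, stride 2, and padding 1, each transposed convolution doubles the spatial size and each convolution halves it, which is why the two networks mirror each other.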
⚠️ Training Challenges
Common Issues:
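- Mode Collapse: Generator produces the same few outputs regardless of the input noise
- Unstable Training: Losses oscillate and may never settle
- Vanishing Gradients: A confident Discriminator passes almost no gradient back, so the Generator stops improving
- Discriminator Wins: Discriminator becomes too accurate too early, and the Generator can't learn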
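The vanishing-gradient issue can be made concrete with a little calculus. Assume the Discriminator's output for a fake image is `sigmoid(a)`, where `a` is its pre-sigmoid score (an assumption that holds for the Sigmoid-terminated networks in this tutorial). The sketch compares the gradient with respect to `a` of the original minimax generator loss `log(1 - D)` and of the non-saturating loss `-log(D)`, which is what BCE against the "real" label computes.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def saturating_grad(a):
    """d/da of log(1 - sigmoid(a)): the original minimax generator loss."""
    return -sigmoid(a)


def non_saturating_grad(a):
    """d/da of -log(sigmoid(a)): BCE against the 'real' label."""
    return -(1.0 - sigmoid(a))


# Negative scores mean the Discriminator is confident the image is fake.
for a in (-6.0, -2.0, 0.0):
    print(f"D(G(z))={sigmoid(a):.4f}  "
          f"saturating grad={saturating_grad(a):+.4f}  "
          f"non-saturating grad={non_saturating_grad(a):+.4f}")
```

When the Discriminator confidently rejects a fake, the minimax loss gives the Generator a gradient close to zero, while the non-saturating loss still provides a strong signal. That is why the training loop in this tutorial trains the Generator with BCE against real labels.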
Solutions:
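- Use DCGAN architecture (more stable than fully connected networks)
- One-sided label smoothing: Use 0.9 instead of 1.0 for real labels when training the Discriminator
- Feature matching: Train the Generator to match the Discriminator's intermediate feature statistics on real data, not just to fool its final output
- Minibatch discrimination: Let the Discriminator compare samples within a batch, which encourages diversity
- Spectral normalization: Constrain the Discriminator's weight matrices to stabilize its training
- Lower learning rate: 0.0002 with Adam (beta1 = 0.5) is common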
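To see why label smoothing helps, it is enough to look at the binary cross-entropy for a single real image under a hard target of 1.0 and a smoothed target of 0.9. The sketch below uses plain Python and a hand-written `bce` helper (illustrative, not the PyTorch implementation).

```python
import math


def bce(prediction, target):
    """Binary cross-entropy for one prediction in (0, 1)."""
    return -(target * math.log(prediction)
             + (1.0 - target) * math.log(1.0 - prediction))


for prediction in (0.5, 0.9, 0.99, 0.999):
    hard = bce(prediction, 1.0)
    smooth = bce(prediction, 0.9)
    print(f"D(x)={prediction:<6} target 1.0: {hard:.4f}   target 0.9: {smooth:.4f}")
```

With a hard target the loss keeps falling as the Discriminator's output approaches 1, which rewards overconfidence. With the smoothed target the loss is lowest at 0.9 and rises again beyond it. In the training loop this would amount to building the real labels for the Discriminator step with `torch.full((n, 1), 0.9)` while keeping 1.0 for the Generator step.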
🎯 Key Takeaways
- GANs use adversarial training to generate realistic images
- Generator vs Discriminator - competition improves both
- DCGAN is the standard architecture to start with
- Training is tricky - requires careful tuning
- StyleGAN achieves photorealistic results, most notably on faces
- Diffusion models are now preferred for many image-generation use cases