🎨 Variational Autoencoders (VAEs)

Probabilistic generative models for image synthesis

What are VAEs?

Variational Autoencoders are generative models that learn to encode data into a compressed latent space and then decode it back. Unlike regular autoencoders, which map each input to a single point, VAEs map each input to a probability distribution over the latent space, so new, similar data can be generated by sampling from it.

Key Concepts:

  • Encoder: Compresses input into latent distribution
  • Latent Space: Compressed probabilistic representation
  • Decoder: Generates output from latent sample
  • Reparameterization Trick: Enables backpropagation through sampling (see the standalone sketch below)
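
Before building the full model, here is a minimal standalone sketch of the reparameterization trick (assuming nothing beyond PyTorch itself): sampling z ~ N(mu, sigma^2) directly is not differentiable, but rewriting it as z = mu + sigma * eps with eps ~ N(0, 1) moves the randomness into eps, so gradients can flow back to mu and sigma.

import torch

mu = torch.zeros(3, requires_grad=True)
logvar = torch.zeros(3, requires_grad=True)

std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)  # the randomness lives here, not in the parameters
z = mu + eps * std           # differentiable w.r.t. mu and logvar

z.sum().backward()
print(mu.grad)      # gradients reach mu (all ones here)
print(logvar.grad)  # and logvar, despite the sampling step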

πŸ—οΈ VAE Architecture

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np

class VAE(nn.Module):
    """Variational Autoencoder for image generation"""
    
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()
        
        # Encoder layers
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)  # Mean of latent distribution
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)  # Log variance
        
        # Decoder layers
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)
        
    def encode(self, x):
        """Encode input to latent distribution parameters"""
        h = F.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        """Reparameterization trick: z = mu + std * epsilon"""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z
    
    def decode(self, z):
        """Decode latent vector to output"""
        h = F.relu(self.fc3(z))
        x_recon = torch.sigmoid(self.fc4(h))
        return x_recon
    
    def forward(self, x):
        """Full forward pass"""
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar

# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE(input_dim=784, hidden_dim=400, latent_dim=20).to(device)

print("VAE Architecture:")
print(model)
print(f"\nUsing device: {device}")

🎯 Loss Function

def vae_loss(x_recon, x, mu, logvar):
    """
    VAE loss = Reconstruction Loss + KL Divergence
    
    Reconstruction Loss: How well we reconstruct the input
    KL Divergence: Regularization to keep latent distribution close to N(0,1)
    """
    # Reconstruction loss (Binary Cross Entropy)
    recon_loss = F.binary_cross_entropy(x_recon, x.view(-1, 784), reduction='sum')
    
    # KL divergence loss
    # KL(q(z|x) || p(z)) where p(z) = N(0,1)
    # = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    # Total loss
    total_loss = recon_loss + kl_loss
    
    return total_loss, recon_loss, kl_loss

# Example loss calculation on a dummy batch
# (torch.rand keeps values in [0, 1]; BCE targets outside that range raise an error)
x_sample = torch.rand(32, 1, 28, 28).to(device)
x_recon, mu, logvar = model(x_sample)
loss, recon, kl = vae_loss(x_recon, x_sample, mu, logvar)

print(f"Total Loss: {loss.item():.2f}")
print(f"Reconstruction Loss: {recon.item():.2f}")
print(f"KL Divergence: {kl.item():.2f}")

🚀 Training VAE

from torch.utils.data import DataLoader

# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, transform=transform, download=True
)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Training setup
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 10

print("Training VAE...\n")

train_losses = []

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        
        # Forward pass
        x_recon, mu, logvar = model(data)
        
        # Calculate loss
        loss, recon_loss, kl_loss = vae_loss(x_recon, data, mu, logvar)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        
        if batch_idx % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}] Batch [{batch_idx}/{len(train_loader)}] '
                  f'Loss: {loss.item()/len(data):.4f}')
    
    avg_loss = train_loss / len(train_loader.dataset)
    train_losses.append(avg_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}] Average Loss: {avg_loss:.4f}\n')

# Plot training loss
plt.figure(figsize=(10, 5))
plt.plot(train_losses)
plt.xlabel('Epoch')
plt.ylabel('Average Loss')
plt.title('VAE Training Loss')
plt.grid(True)
plt.show()

print("Training completed!")

🎨 Generating New Images

def generate_images(model, num_images=16):
    """Generate new images by sampling from latent space"""
    model.eval()
    
    with torch.no_grad():
        # Sample from standard normal distribution
        z = torch.randn(num_images, 20).to(device)
        
        # Decode to generate images
        generated = model.decode(z)
        generated = generated.view(-1, 1, 28, 28).cpu()
    
    # Plot generated images
    fig, axes = plt.subplots(4, 4, figsize=(10, 10))
    for i, ax in enumerate(axes.flat):
        ax.imshow(generated[i].squeeze(), cmap='gray')
        ax.axis('off')
    plt.suptitle('VAE Generated Images')
    plt.tight_layout()
    plt.show()

# Generate new images
generate_images(model, num_images=16)

print("Generated 16 new images from random latent vectors!")

🔄 Reconstructing Images

def reconstruct_images(model, data_loader, num_images=8):
    """Show original vs reconstructed images"""
    model.eval()
    
    # Get a batch of images
    data_iter = iter(data_loader)
    images, _ = next(data_iter)
    images = images[:num_images].to(device)
    
    with torch.no_grad():
        recon, _, _ = model(images)
        recon = recon.view(-1, 1, 28, 28)
    
    # Plot original and reconstructed
    fig, axes = plt.subplots(2, num_images, figsize=(15, 4))
    
    for i in range(num_images):
        # Original
        axes[0, i].imshow(images[i].cpu().squeeze(), cmap='gray')
        axes[0, i].axis('off')
        if i == 0:
            axes[0, i].set_title('Original', fontsize=12)
        
        # Reconstructed
        axes[1, i].imshow(recon[i].cpu().squeeze(), cmap='gray')
        axes[1, i].axis('off')
        if i == 0:
            axes[1, i].set_title('Reconstructed', fontsize=12)
    
    plt.tight_layout()
    plt.show()

# Show reconstructions
reconstruct_images(model, train_loader)

print("Comparison of original vs reconstructed images")

🗺️ Latent Space Exploration

def explore_latent_space(model, data_loader, num_samples=1000):
    """Visualize the learned latent space"""
    model.eval()
    
    latent_vectors = []
    labels = []
    
    with torch.no_grad():
        for images, lbls in data_loader:
            if sum(len(v) for v in latent_vectors) >= num_samples:
                break
            
            images = images.to(device)
            mu, _ = model.encode(images.view(-1, 784))
            latent_vectors.append(mu.cpu().numpy())
            labels.append(lbls.numpy())
    
    latent_vectors = np.concatenate(latent_vectors)[:num_samples]
    labels = np.concatenate(labels)[:num_samples]
    
    # Use PCA to visualize 20D latent space in 2D
    from sklearn.decomposition import PCA
    pca = PCA(n_components=2)
    latent_2d = pca.fit_transform(latent_vectors)
    
    # Plot
    plt.figure(figsize=(12, 10))
    scatter = plt.scatter(latent_2d[:, 0], latent_2d[:, 1], 
                         c=labels, cmap='tab10', alpha=0.6)
    plt.colorbar(scatter, label='Digit')
    plt.xlabel('First Principal Component')
    plt.ylabel('Second Principal Component')
    plt.title('VAE Latent Space Visualization (2D PCA)')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    
    print(f"Explained variance: {pca.explained_variance_ratio_.sum():.2%}")

# Visualize latent space
explore_latent_space(model, train_loader)

print("Latent space shows clear clustering by digit!")

🎭 Latent Space Interpolation

def interpolate_latent(model, data_loader, steps=10):
    """Interpolate between two images in latent space"""
    model.eval()
    
    # Get two random images
    data_iter = iter(data_loader)
    images, _ = next(data_iter)
    img1, img2 = images[0:1].to(device), images[1:2].to(device)
    
    with torch.no_grad():
        # Encode to latent space
        mu1, _ = model.encode(img1.view(-1, 784))
        mu2, _ = model.encode(img2.view(-1, 784))
        
        # Interpolate
        alphas = np.linspace(0, 1, steps)
        interpolated_images = []
        
        for alpha in alphas:
            z_interp = (1 - alpha) * mu1 + alpha * mu2
            img_interp = model.decode(z_interp)
            interpolated_images.append(img_interp.view(28, 28).cpu())
    
    # Plot interpolation
    fig, axes = plt.subplots(1, steps, figsize=(20, 3))
    for i, (ax, img) in enumerate(zip(axes, interpolated_images)):
        ax.imshow(img, cmap='gray')
        ax.axis('off')
        ax.set_title(f'α={alphas[i]:.1f}')
    plt.suptitle('Latent Space Interpolation')
    plt.tight_layout()
    plt.show()

# Show interpolation
interpolate_latent(model, train_loader, steps=10)

print("Smooth transition between two images through latent space!")

🆚 VAEs vs GANs

Aspect        | VAEs                       | GANs
------------- | -------------------------- | ------------------------
Training      | Stable, easier to train    | Unstable, mode collapse
Image Quality | Often blurry               | Sharp, realistic
Latent Space  | Continuous, interpretable  | Less structured
Diversity     | Good coverage              | Can miss modes
Use Cases     | Compression, interpolation | High-quality generation
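
The "often blurry" samples partly reflect the loss weighting. Scaling the KL term by a factor beta (the beta-VAE idea) trades reconstruction sharpness against latent-space structure; a sketch, with beta=1 recovering the vae_loss defined earlier:

def beta_vae_loss(x_recon, x, mu, logvar, beta=4.0):
    """VAE loss with a weighted KL term (beta-VAE)."""
    recon_loss = F.binary_cross_entropy(x_recon, x.view(-1, 784), reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_loss

Larger beta values push the posterior toward the prior and tend to disentangle latent factors, at the cost of blurrier reconstructions.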

🎨 Convolutional VAE

class ConvVAE(nn.Module):
    """Convolutional VAE for better image generation"""
    
    def __init__(self, latent_dim=128):
        super(ConvVAE, self).__init__()
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 4, stride=2, padding=1),  # 14x14
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 7x7
            nn.ReLU(),
            nn.Flatten()
        )
        
        self.fc_mu = nn.Linear(64 * 7 * 7, latent_dim)
        self.fc_logvar = nn.Linear(64 * 7 * 7, latent_dim)
        
        # Decoder
        self.fc_decode = nn.Linear(latent_dim, 64 * 7 * 7)
        
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # 14x14
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1),  # 28x28
            nn.Sigmoid()
        )
    
    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        h = self.fc_decode(z)
        h = h.view(-1, 64, 7, 7)
        return self.decoder(h)
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Create convolutional VAE
conv_vae = ConvVAE(latent_dim=128).to(device)
print("\nConvolutional VAE:")
print(conv_vae)
print("\nBetter for high-resolution images!")

🎯 Key Takeaways