What are VAEs?
Variational Autoencoders (VAEs) are generative models that learn to encode data into a compressed latent space and decode it back. Unlike regular autoencoders, which map each input to a single latent point, VAEs encode each input as a probability distribution over the latent space; sampling from that distribution lets them generate new data similar to the training set.
Key Concepts:
- Encoder: Compresses input into latent distribution
- Latent Space: Compressed probabilistic representation
- Decoder: Generates output from latent sample
- Reparameterization Trick: Enables backpropagation through sampling (see the sketch below)
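To make the last concept concrete, here is a minimal, self-contained sketch of the reparameterization trick (the 4-dimensional latent size is an arbitrary choice for illustration). Sampling z ~ N(mu, sigma^2) directly would block gradients; sampling eps ~ N(0, I) and computing z = mu + sigma * eps keeps z differentiable with respect to mu and logvar:

import torch

# Hypothetical latent distribution parameters for a 4-dimensional latent space
mu = torch.zeros(4, requires_grad=True)
logvar = torch.zeros(4, requires_grad=True)

std = torch.exp(0.5 * logvar)  # sigma = exp(log(sigma^2) / 2)
eps = torch.randn_like(std)    # eps ~ N(0, I); the randomness lives outside the graph
z = mu + eps * std             # z ~ N(mu, sigma^2), differentiable w.r.t. mu and logvar

z.sum().backward()
print(mu.grad)      # all ones: dz/dmu = 1
print(logvar.grad)  # 0.5 * eps * std: gradient flows through the variance too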
🏗️ VAE Architecture
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
class VAE(nn.Module):
"""Variational Autoencoder for image generation"""
def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()
        self.input_dim = input_dim
        # Encoder layers
        self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc_mu = nn.Linear(hidden_dim, latent_dim) # Mean of latent distribution
self.fc_logvar = nn.Linear(hidden_dim, latent_dim) # Log variance
# Decoder layers
self.fc3 = nn.Linear(latent_dim, hidden_dim)
self.fc4 = nn.Linear(hidden_dim, input_dim)
def encode(self, x):
"""Encode input to latent distribution parameters"""
h = F.relu(self.fc1(x))
mu = self.fc_mu(h)
logvar = self.fc_logvar(h)
return mu, logvar
def reparameterize(self, mu, logvar):
"""Reparameterization trick: z = mu + std * epsilon"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mu + eps * std
return z
def decode(self, z):
"""Decode latent vector to output"""
h = F.relu(self.fc3(z))
x_recon = torch.sigmoid(self.fc4(h))
return x_recon
def forward(self, x):
"""Full forward pass"""
        mu, logvar = self.encode(x.view(-1, self.input_dim))
z = self.reparameterize(mu, logvar)
x_recon = self.decode(z)
return x_recon, mu, logvar
# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE(input_dim=784, hidden_dim=400, latent_dim=20).to(device)
print("VAE Architecture:")
print(model)
print(f"\nUsing device: {device}")
🎯 Loss Function
def vae_loss(x_recon, x, mu, logvar):
"""
VAE loss = Reconstruction Loss + KL Divergence
Reconstruction Loss: How well we reconstruct the input
    KL Divergence: Regularization to keep the latent distribution close to the standard normal prior N(0, I)
"""
# Reconstruction loss (Binary Cross Entropy)
recon_loss = F.binary_cross_entropy(x_recon, x.view(-1, 784), reduction='sum')
# KL divergence loss
    # KL(q(z|x) || p(z)) where p(z) = N(0, I)
# = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# Total loss
total_loss = recon_loss + kl_loss
return total_loss, recon_loss, kl_loss
# Example loss calculation (use torch.rand so the fake pixels lie in [0, 1], as BCE requires)
x_sample = torch.rand(32, 1, 28, 28).to(device)
x_recon, mu, logvar = model(x_sample)
loss, recon, kl = vae_loss(x_recon, x_sample, mu, logvar)
print(f"Total Loss: {loss.item():.2f}")
print(f"Reconstruction Loss: {recon.item():.2f}")
print(f"KL Divergence: {kl.item():.2f}")
🚀 Training VAE
from torch.utils.data import DataLoader
# Load MNIST dataset
transform = transforms.Compose([
transforms.ToTensor(),
])
train_dataset = torchvision.datasets.MNIST(
root='./data', train=True, transform=transform, download=True
)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
# Training setup
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 10
print("Training VAE...\n")
train_losses = []
for epoch in range(num_epochs):
model.train()
train_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
data = data.to(device)
# Forward pass
x_recon, mu, logvar = model(data)
# Calculate loss
loss, recon_loss, kl_loss = vae_loss(x_recon, data, mu, logvar)
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.item()
if batch_idx % 100 == 0:
print(f'Epoch [{epoch+1}/{num_epochs}] Batch [{batch_idx}/{len(train_loader)}] '
f'Loss: {loss.item()/len(data):.4f}')
avg_loss = train_loss / len(train_loader.dataset)
train_losses.append(avg_loss)
print(f'Epoch [{epoch+1}/{num_epochs}] Average Loss: {avg_loss:.4f}\n')
# Plot training loss
plt.figure(figsize=(10, 5))
plt.plot(train_losses)
plt.xlabel('Epoch')
plt.ylabel('Average Loss')
plt.title('VAE Training Loss')
plt.grid(True)
plt.show()
print("Training completed!")
🎨 Generating New Images
def generate_images(model, num_images=16):
"""Generate new images by sampling from latent space"""
model.eval()
with torch.no_grad():
        # Sample from the standard normal prior; infer latent_dim from the model
        z = torch.randn(num_images, model.fc_mu.out_features).to(device)
# Decode to generate images
generated = model.decode(z)
generated = generated.view(-1, 1, 28, 28).cpu()
# Plot generated images
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
ax.imshow(generated[i].squeeze(), cmap='gray')
ax.axis('off')
plt.suptitle('VAE Generated Images')
plt.tight_layout()
plt.show()
# Generate new images
generate_images(model, num_images=16)
print("Generated 16 new images from random latent vectors!")
🔍 Reconstructing Images
def reconstruct_images(model, data_loader, num_images=8):
"""Show original vs reconstructed images"""
model.eval()
# Get a batch of images
data_iter = iter(data_loader)
images, _ = next(data_iter)
images = images[:num_images].to(device)
with torch.no_grad():
recon, _, _ = model(images)
recon = recon.view(-1, 1, 28, 28)
# Plot original and reconstructed
fig, axes = plt.subplots(2, num_images, figsize=(15, 4))
for i in range(num_images):
# Original
axes[0, i].imshow(images[i].cpu().squeeze(), cmap='gray')
axes[0, i].axis('off')
if i == 0:
axes[0, i].set_title('Original', fontsize=12)
# Reconstructed
axes[1, i].imshow(recon[i].cpu().squeeze(), cmap='gray')
axes[1, i].axis('off')
if i == 0:
axes[1, i].set_title('Reconstructed', fontsize=12)
plt.tight_layout()
plt.show()
# Show reconstructions
reconstruct_images(model, train_loader)
print("Comparison of original vs reconstructed images")
🗺️ Latent Space Exploration
def explore_latent_space(model, data_loader, num_samples=1000):
"""Visualize the learned latent space"""
model.eval()
latent_vectors = []
labels = []
with torch.no_grad():
for images, lbls in data_loader:
            if sum(len(v) for v in latent_vectors) >= num_samples:
break
images = images.to(device)
mu, _ = model.encode(images.view(-1, 784))
latent_vectors.append(mu.cpu().numpy())
labels.append(lbls.numpy())
latent_vectors = np.concatenate(latent_vectors)[:num_samples]
labels = np.concatenate(labels)[:num_samples]
# Use PCA to visualize 20D latent space in 2D
from sklearn.decomposition import PCA
pca = PCA(n_components=2)
latent_2d = pca.fit_transform(latent_vectors)
# Plot
plt.figure(figsize=(12, 10))
scatter = plt.scatter(latent_2d[:, 0], latent_2d[:, 1],
c=labels, cmap='tab10', alpha=0.6)
plt.colorbar(scatter, label='Digit')
plt.xlabel('First Principal Component')
plt.ylabel('Second Principal Component')
plt.title('VAE Latent Space Visualization (2D PCA)')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print(f"Explained variance: {pca.explained_variance_ratio_.sum():.2%}")
# Visualize latent space
explore_latent_space(model, train_loader)
print("Latent space shows clear clustering by digit!")
🔄 Latent Space Interpolation
def interpolate_latent(model, data_loader, steps=10):
"""Interpolate between two images in latent space"""
model.eval()
# Get two random images
data_iter = iter(data_loader)
images, _ = next(data_iter)
img1, img2 = images[0:1].to(device), images[1:2].to(device)
with torch.no_grad():
# Encode to latent space
mu1, _ = model.encode(img1.view(-1, 784))
mu2, _ = model.encode(img2.view(-1, 784))
# Interpolate
alphas = np.linspace(0, 1, steps)
interpolated_images = []
for alpha in alphas:
z_interp = (1 - alpha) * mu1 + alpha * mu2
img_interp = model.decode(z_interp)
interpolated_images.append(img_interp.view(28, 28).cpu())
# Plot interpolation
fig, axes = plt.subplots(1, steps, figsize=(20, 3))
for i, (ax, img) in enumerate(zip(axes, interpolated_images)):
ax.imshow(img, cmap='gray')
ax.axis('off')
        ax.set_title(f'α={alphas[i]:.1f}')
plt.suptitle('Latent Space Interpolation')
plt.tight_layout()
plt.show()
# Show interpolation
interpolate_latent(model, train_loader, steps=10)
print("Smooth transition between two images through latent space!")
📊 VAEs vs GANs
| Aspect | VAEs | GANs |
|---|---|---|
| Training | Stable, easier to train | Unstable, mode collapse |
| Image Quality | Often blurry | Sharp, realistic |
| Latent Space | Continuous, interpretable | Less structured |
| Diversity | Good coverage | Can miss modes |
| Use Cases | Compression, interpolation | High-quality generation |
🎨 Convolutional VAE
class ConvVAE(nn.Module):
"""Convolutional VAE for better image generation"""
def __init__(self, latent_dim=128):
super(ConvVAE, self).__init__()
# Encoder
self.encoder = nn.Sequential(
nn.Conv2d(1, 32, 4, stride=2, padding=1), # 14x14
nn.ReLU(),
nn.Conv2d(32, 64, 4, stride=2, padding=1), # 7x7
nn.ReLU(),
nn.Flatten()
)
self.fc_mu = nn.Linear(64 * 7 * 7, latent_dim)
self.fc_logvar = nn.Linear(64 * 7 * 7, latent_dim)
# Decoder
self.fc_decode = nn.Linear(latent_dim, 64 * 7 * 7)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), # 14x14
nn.ReLU(),
nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1), # 28x28
nn.Sigmoid()
)
def encode(self, x):
h = self.encoder(x)
return self.fc_mu(h), self.fc_logvar(h)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z):
h = self.fc_decode(z)
h = h.view(-1, 64, 7, 7)
return self.decoder(h)
def forward(self, x):
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar
# Create convolutional VAE
conv_vae = ConvVAE(latent_dim=128).to(device)
print("\nConvolutional VAE:")
print(conv_vae)
print("\nBetter for high-resolution images!")
🎯 Key Takeaways
- VAEs learn distributions, not just deterministic mappings
- Reparameterization trick enables gradient flow
- Loss = Reconstruction + KL divergence
- Continuous latent space enables interpolation
- Stable training compared to GANs
- Blurry outputs but good for compression
- Useful for data compression, anomaly detection, and interpolation (see the sketch below)
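To illustrate the anomaly-detection use case, the idea reduces to thresholding reconstruction error: inputs the VAE reconstructs poorly are unlikely under the training distribution. A minimal sketch (any actual threshold is an assumption you would tune on held-out data):

def anomaly_score(model, x):
    """Per-sample reconstruction error; higher scores suggest out-of-distribution inputs."""
    model.eval()
    with torch.no_grad():
        x_recon, _, _ = model(x.to(device))
        # Per-pixel binary cross-entropy, summed per sample
        score = F.binary_cross_entropy(
            x_recon, x.view(-1, 784).to(device), reduction='none'
        ).sum(dim=1)
    return score

images, _ = next(iter(train_loader))
scores = anomaly_score(model, images[:8])
print(scores)  # in-distribution digits should score comparatively low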