GAN Mathematics: Comprehensive Theory
Introduction
Generative Adversarial Networks (GANs), introduced by Ian Goodfellow et al. in 2014, revolutionized generative modeling by framing it as a two-player minimax game between a generator and discriminator. This adversarial framework enables learning complex data distributions without explicit density estimation.
Core Idea: Train two neural networks in opposition:
Generator \(G\): Learns to create fake samples that resemble real data
Discriminator \(D\): Learns to distinguish real samples from fake ones
Through this competition, \(G\) improves at generating realistic samples, while \(D\) becomes better at detecting fakes. At equilibrium, \(G\) produces samples indistinguishable from real data.
Mathematical Formulation
The Minimax Game
GANs are defined by the following minimax objective: $\(\min_G \max_D V(D, G) = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}[\log D(\mathbf{x})] + \mathbb{E}_{\mathbf{z} \sim p_z}[\log(1 - D(G(\mathbf{z})))]\)$
Where:
\(\mathbf{x}\): Real data samples
\(\mathbf{z}\): Latent noise vector (typically \(\mathbf{z} \sim \mathcal{N}(0, I)\))
\(G(\mathbf{z})\): Generated sample from noise \(\mathbf{z}\)
\(D(\mathbf{x}) \in [0, 1]\): Probability that \(\mathbf{x}\) is real
\(p_{\text{data}}\): True data distribution
\(p_z\): Prior distribution over latent space (usually Gaussian or uniform)
\(p_g\): Distribution induced by generator \(G\)
Interpretation
Discriminator's Goal (maximization): $\(\max_D \left[\underbrace{\mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}[\log D(\mathbf{x})]}_{\text{correctly classify real}} + \underbrace{\mathbb{E}_{\mathbf{z} \sim p_z}[\log(1 - D(G(\mathbf{z})))]}_{\text{correctly classify fake}}\right]\)$
Assign \(D(\mathbf{x}) \approx 1\) for real samples
Assign \(D(G(\mathbf{z})) \approx 0\) for fake samples
Generator's Goal (minimization): $\(\min_G \mathbb{E}_{\mathbf{z} \sim p_z}[\log(1 - D(G(\mathbf{z})))]\)$
Make \(D(G(\mathbf{z})) \approx 1\) (fool the discriminator)
Equivalent to making generated samples indistinguishable from real
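These two objectives can be evaluated directly on toy numbers; the discriminator outputs below are illustrative values, not from a trained model:

```python
import math

# Illustrative discriminator outputs (not from a trained model)
d_real = [0.9, 0.8, 0.95]   # D(x) on real samples
d_fake = [0.1, 0.2, 0.05]   # D(G(z)) on fake samples

# Discriminator's objective (to maximize): E[log D(x)] + E[log(1 - D(G(z)))]
v_d = sum(math.log(d) for d in d_real) / 3 + sum(math.log(1 - d) for d in d_fake) / 3

# Generator's objective (to minimize): E[log(1 - D(G(z)))]
v_g = sum(math.log(1 - d) for d in d_fake) / 3
```

Here the confident discriminator makes both expectations mildly negative; as \(G\) improves and \(D(G(\mathbf{z}))\) rises toward 1, the generator's term decreases without bound.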
Optimal Discriminator
Proposition 1: Optimal Discriminator for Fixed Generator
For a fixed generator \(G\), the optimal discriminator is: $\(D^*_G(\mathbf{x}) = \frac{p_{\text{data}}(\mathbf{x})}{p_{\text{data}}(\mathbf{x}) + p_g(\mathbf{x})}\)$
Proof:
The objective for \(D\) can be rewritten as: $\(V(D, G) = \int_{\mathbf{x}} \left[p_{\text{data}}(\mathbf{x}) \log D(\mathbf{x}) + p_g(\mathbf{x}) \log(1 - D(\mathbf{x}))\right] d\mathbf{x}\)$
To maximize w.r.t. \(D(\mathbf{x})\) for each \(\mathbf{x}\), take the derivative and set to zero: $\(\frac{p_{\text{data}}(\mathbf{x})}{D(\mathbf{x})} - \frac{p_g(\mathbf{x})}{1 - D(\mathbf{x})} = 0\)$
Solving for \(D(\mathbf{x})\): $\(D^*_G(\mathbf{x}) = \frac{p_{\text{data}}(\mathbf{x})}{p_{\text{data}}(\mathbf{x}) + p_g(\mathbf{x})}\)$
Interpretation:
If \(p_{\text{data}}(\mathbf{x}) \gg p_g(\mathbf{x})\): \(D^*_G(\mathbf{x}) \approx 1\) (likely real)
If \(p_g(\mathbf{x}) \gg p_{\text{data}}(\mathbf{x})\): \(D^*_G(\mathbf{x}) \approx 0\) (likely fake)
If \(p_{\text{data}}(\mathbf{x}) = p_g(\mathbf{x})\): \(D^*_G(\mathbf{x}) = 0.5\) (perfectly confused)
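The three regimes above can be checked numerically; the helper below is an illustrative sketch evaluating \(D^*_G\) at a single point from hypothetical density values:

```python
def optimal_discriminator(p_data: float, p_g: float) -> float:
    """D*_G(x) = p_data(x) / (p_data(x) + p_g(x)) at a single point x."""
    return p_data / (p_data + p_g)

assert optimal_discriminator(0.9, 0.1) == 0.9   # data-dominated: near 1
assert optimal_discriminator(0.1, 0.9) == 0.1   # generator-dominated: near 0
assert optimal_discriminator(0.5, 0.5) == 0.5   # matched densities: perfectly confused
```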
Virtual Training Criterion
Proposition 2: Generator Training with Optimal Discriminator
Substituting \(D^*_G\) into the objective yields: $\(C(G) = \max_D V(D, G) = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}\left[\log \frac{p_{\text{data}}(\mathbf{x})}{p_{\text{data}}(\mathbf{x}) + p_g(\mathbf{x})}\right] + \mathbb{E}_{\mathbf{x} \sim p_g}\left[\log \frac{p_g(\mathbf{x})}{p_{\text{data}}(\mathbf{x}) + p_g(\mathbf{x})}\right]\)$
Expanding (add and subtract \(\log 4\)): $\(C(G) = -\log 4 + \text{KL}\left(p_{\text{data}} \,\Big\|\, \frac{p_{\text{data}} + p_g}{2}\right) + \text{KL}\left(p_g \,\Big\|\, \frac{p_{\text{data}} + p_g}{2}\right)\)$
Rewrite using the average distribution \(p_m = \frac{p_{\text{data}} + p_g}{2}\): $\(C(G) = -\log 4 + 2 \cdot \text{JSD}(p_{\text{data}} \| p_g)\)$
Where \(\text{JSD}\) is the Jensen-Shannon divergence: $\(\text{JSD}(p \| q) = \frac{1}{2}\text{KL}\left(p \,\Big\|\, \frac{p+q}{2}\right) + \frac{1}{2}\text{KL}\left(q \,\Big\|\, \frac{p+q}{2}\right)\)$
Properties of JSD:
Non-negative: \(\text{JSD}(p \| q) \geq 0\)
Symmetric: \(\text{JSD}(p \| q) = \text{JSD}(q \| p)\)
Bounded: \(0 \leq \text{JSD}(p \| q) \leq \log 2\)
Zero iff equal: \(\text{JSD}(p \| q) = 0 \iff p = q\)
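All four properties are easy to verify numerically on small discrete distributions; `kl` and `jsd` below are ad-hoc helpers, not library calls:

```python
import math

def kl(p, q):
    # KL(p || q) over a discrete support; terms with p_i = 0 contribute 0
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def jsd(p, q):
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p = [0.7, 0.2, 0.1]
q = [0.1, 0.3, 0.6]
assert jsd(p, q) >= 0                        # non-negative
assert abs(jsd(p, q) - jsd(q, p)) < 1e-12    # symmetric
assert jsd(p, q) <= math.log(2)              # bounded by log 2
assert jsd(p, p) < 1e-12                     # zero iff equal
```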
Global Optimality
Theorem: Global Minimum of \(C(G)\)
Statement: The global minimum of \(C(G)\) is achieved if and only if \(p_g = p_{\text{data}}\). At this point: $\(C(G^*) = -\log 4\)$
And the optimal discriminator becomes: $\(D^*_{G^*}(\mathbf{x}) = \frac{1}{2} \quad \forall \mathbf{x}\)$
Proof:
From the Jensen-Shannon divergence formulation: $\(C(G) = -\log 4 + 2 \cdot \text{JSD}(p_{\text{data}} \| p_g)\)$
Since \(\text{JSD} \geq 0\) with equality iff \(p_{\text{data}} = p_g\): $\(C(G) \geq -\log 4\)$
Equality holds when \(p_g = p_{\text{data}}\), giving: $\(C(G^*) = -\log 4 \approx -1.386\)$
Interpretation:
When \(G\) perfectly matches the data distribution, \(D\) cannot do better than random guessing (50% accuracy)
The game reaches a Nash equilibrium
Nash Equilibrium
Definition
A pair \((G^*, D^*)\) is a Nash equilibrium if:
\(D^* \in \arg\max_D V(D, G^*)\) (optimal discriminator for \(G^*\))
\(G^* \in \arg\min_G V(D^*, G)\) (optimal generator against \(D^*\))
Theorem: The Nash equilibrium is achieved when:
\(p_g = p_{\text{data}}\)
\(D^*(\mathbf{x}) = \frac{1}{2}\) for all \(\mathbf{x}\)
Training Dynamics
Algorithm: GAN Training
For each iteration:
Update Discriminator (k steps):
Sample minibatch of \(m\) real samples: \(\{\mathbf{x}^{(1)}, \ldots, \mathbf{x}^{(m)}\} \sim p_{\text{data}}\)
Sample minibatch of \(m\) noise samples: \(\{\mathbf{z}^{(1)}, \ldots, \mathbf{z}^{(m)}\} \sim p_z\)
Update discriminator by ascending its stochastic gradient: $\(\nabla_{\theta_d} \frac{1}{m} \sum_{i=1}^m \left[\log D(\mathbf{x}^{(i)}) + \log(1 - D(G(\mathbf{z}^{(i)})))\right]\)$
Update Generator (1 step):
Sample minibatch of \(m\) noise samples: \(\{\mathbf{z}^{(1)}, \ldots, \mathbf{z}^{(m)}\} \sim p_z\)
Update generator by descending its stochastic gradient: $\(\nabla_{\theta_g} \frac{1}{m} \sum_{i=1}^m \log(1 - D(G(\mathbf{z}^{(i)})))\)$
Hyperparameter \(k\): Number of discriminator updates per generator update (typically \(k=1\)).
Non-Saturating Generator Loss
Problem with Original Loss:
When \(D\) is very confident (\(D(G(\mathbf{z})) \approx 0\)), the gradient of \(\log(1 - D(G(\mathbf{z})))\) w.r.t. the discriminator output is: $\(\frac{\partial}{\partial D} \log(1 - D) = -\frac{1}{1 - D}\)$
When \(D(G(\mathbf{z})) \approx 0\), this gradient is at its smallest magnitude, causing slow learning early in training.
Non-Saturating Alternative:
Instead of minimizing \(\log(1 - D(G(\mathbf{z})))\), maximize \(\log D(G(\mathbf{z}))\): $\(\max_G \mathbb{E}_{\mathbf{z} \sim p_z}[\log D(G(\mathbf{z}))]\)$
Gradient Comparison:
Original: \(\frac{\partial}{\partial D} \log(1 - D) = -\frac{1}{1 - D}\) (small when \(D \approx 0\))
Non-saturating: \(\frac{\partial}{\partial D} \log D = \frac{1}{D}\) (large when \(D \approx 0\))
Trade-off:
Non-saturating loss provides stronger gradients early
But changes the fixed-point of the dynamics (not equivalent to original objective)
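The gradient comparison above, evaluated at a confidently rejected fake sample (\(D(G(\mathbf{z})) = 0.01\), as happens early in training):

```python
# Gradients w.r.t. the discriminator output at D(G(z)) = 0.01
d = 0.01
grad_saturating = -1 / (1 - d)   # d/dD log(1 - D): magnitude about 1
grad_non_saturating = 1 / d      # d/dD log D: magnitude about 100

# The non-saturating loss gives a far stronger learning signal here
assert abs(grad_non_saturating) > 50 * abs(grad_saturating)
```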
Mode Collapse
Definition
Mode collapse occurs when the generator produces limited variety of samples, concentrating on a few modes of \(p_{\text{data}}\) while ignoring others.
Types
Complete collapse: \(G\) produces single sample regardless of \(\mathbf{z}\)
Partial collapse: \(G\) covers some modes but misses others
Oscillation: \(G\) cycles through modes during training
Causes
Missing Mass: Generator finds a single mode where \(D\) is weak, exploits it
Gradient Pathologies: Local minima that don't correspond to the global optimum
Training Imbalance: \(D\) too strong → \(G\) gets stuck; \(D\) too weak → poor guidance
Mitigation Strategies
Unrolled GANs: Update \(G\) considering \(k\) future updates of \(D\)
Minibatch Discrimination: Penalize \(G\) if minibatch samples are too similar
Feature Matching: Match statistics of intermediate layers instead of fooling \(D\)
Experience Replay: Store past generated samples, retrain \(D\) on them
Multiple GANs: Ensemble of generators covering different modes
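As one concrete sketch of the minibatch-discrimination idea, a minibatch standard-deviation feature (a simplified variant; the function name and shapes below are illustrative) appends a batch-diversity statistic so the discriminator can detect collapsed batches:

```python
import torch

def minibatch_stddev(features: torch.Tensor) -> torch.Tensor:
    # features: (batch, C, H, W)
    std = features.std(dim=0, unbiased=False)   # per-location spread across the batch
    mean_std = std.mean()                       # single scalar diversity statistic
    stat_map = mean_std.view(1, 1, 1, 1).expand(
        features.size(0), 1, features.size(2), features.size(3)
    )
    # A collapsed batch (near-identical samples) yields a near-zero extra channel
    return torch.cat([features, stat_map], dim=1)   # (batch, C+1, H, W)
```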
Theoretical Challenges
1. Vanishing Gradients
Problem: If \(D\) is optimal, gradient for \(G\) can vanish.
Proof: When \(p_g\) and \(p_{\text{data}}\) have disjoint supports: $\(\text{JSD}(p_{\text{data}} \| p_g) = \log 2\)$
This means \(D^*\) can be perfect (zero loss), providing no gradient signal to \(G\).
Solution: Add noise to inputs, use Wasserstein GAN (different divergence).
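The disjoint-support argument can be checked numerically: the JSD sits exactly at \(\log 2\) no matter how the generator mass is arranged, as long as the supports stay disjoint (ad-hoc helpers below):

```python
import math

def kl(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def jsd(p, q):
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

# p_data supported on the first two bins, p_g on the last two: disjoint supports
p_data = [0.5, 0.5, 0.0, 0.0]
p_g = [0.0, 0.0, 0.5, 0.5]
assert abs(jsd(p_data, p_g) - math.log(2)) < 1e-12

# Reshaping p_g while keeping supports disjoint changes nothing:
# the objective is flat, so it carries no gradient signal toward p_data
p_g2 = [0.0, 0.0, 0.1, 0.9]
assert abs(jsd(p_data, p_g2) - math.log(2)) < 1e-12
```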
2. Non-Convergence
Problem: Training may oscillate indefinitely without reaching equilibrium.
Example: Consider simple case where \(G\) and \(D\) update in cycles:
\(G\) moves to fool \(D\)
\(D\) adapts to new \(G\)
\(G\) shifts again
Repeat (never converges)
Mitigation: Careful learning rate selection, momentum, regularization.
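A minimal toy model of this cycling (a bilinear game with its equilibrium at the origin; illustrative, not a real GAN) shows simultaneous gradient updates spiraling away from equilibrium rather than converging:

```python
# Bilinear game V(g, d) = g * d: G descends in g, D ascends in d
eta = 0.1
g, d = 1.0, 1.0
for _ in range(100):
    g, d = g - eta * d, d + eta * g   # simultaneous (not alternating) updates

# Each step multiplies the distance from (0, 0) by sqrt(1 + eta^2),
# so the iterates spiral outward from the equilibrium
norm = (g * g + d * d) ** 0.5
assert norm > 2 ** 0.5   # grew beyond the initial distance sqrt(2)
```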
Evaluation Metrics
Inception Score (IS)
$\(\text{IS}(G) = \exp\left(\mathbb{E}_{\mathbf{x} \sim p_g}\left[\text{KL}(p(y|\mathbf{x}) \,\|\, p(y))\right]\right)\)$
Where:
\(p(y|\mathbf{x})\): Class probabilities from Inception network
\(p(y) = \mathbb{E}_{\mathbf{x} \sim p_g}[p(y|\mathbf{x})]\): Marginal distribution
Interpretation:
High score: Generated images are diverse (high \(H(y)\)) and high-quality (low \(H(y|\mathbf{x})\))
Limitations:
Only for class-conditional generation
Can be fooled (generate one high-quality image per class)
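A toy computation of IS from a hypothetical class-posterior matrix, including the collapsed case noted above (`inception_score` is an ad-hoc helper, not a library call):

```python
import math

def inception_score(p_y_given_x):
    # p_y_given_x: rows are p(y|x) for each generated sample
    n, k = len(p_y_given_x), len(p_y_given_x[0])
    p_y = [sum(row[j] for row in p_y_given_x) / n for j in range(k)]  # marginal p(y)
    mean_kl = sum(
        sum(p * math.log(p / p_y[j]) for j, p in enumerate(row) if p > 0)
        for row in p_y_given_x
    ) / n
    return math.exp(mean_kl)

# Diverse and confident: each sample is a different, clearly recognized class
diverse = inception_score([[0.9, 0.05, 0.05], [0.05, 0.9, 0.05], [0.05, 0.05, 0.9]])
# Collapsed: every sample looks like the same class, so p(y|x) = p(y) and IS = 1
collapsed = inception_score([[0.9, 0.05, 0.05]] * 3)
```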
Fréchet Inception Distance (FID)
$\(\text{FID} = \|\mu_r - \mu_g\|_2^2 + \text{Tr}\left(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2}\right)\)$
Where:
\((\mu_r, \Sigma_r)\): Mean and covariance of real image features (from Inception pool3)
\((\mu_g, \Sigma_g)\): Mean and covariance of generated image features
Properties:
Lower is better
More robust than IS
Considers both quality and diversity
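In one dimension the FID formula reduces to \((\mu_r - \mu_g)^2 + (\sigma_r - \sigma_g)^2\), which makes a quick sanity check easy (illustrative helper):

```python
def fid_1d(mu_r, sigma_r, mu_g, sigma_g):
    # Univariate special case of the FID formula between two Gaussians
    return (mu_r - mu_g) ** 2 + (sigma_r - sigma_g) ** 2

assert fid_1d(0.0, 1.0, 0.0, 1.0) == 0.0   # identical distributions
assert fid_1d(0.0, 1.0, 2.0, 1.0) == 4.0   # pure mean shift
assert fid_1d(0.0, 1.0, 0.0, 3.0) == 4.0   # pure variance mismatch
```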
Variants and Extensions
1. DCGAN (Deep Convolutional GAN)
Architecture Guidelines:
Replace pooling with strided convolutions (discriminator) and fractional-strided convolutions (generator)
Use batch normalization in both \(G\) and \(D\)
Remove fully connected layers
Use ReLU in generator (except output: Tanh)
Use LeakyReLU in discriminator
Impact: Stabilized training, enabled high-quality image generation.
2. Conditional GAN (cGAN)
Objective: $\(\min_G \max_D V(D, G) = \mathbb{E}_{\mathbf{x},y}[\log D(\mathbf{x}, y)] + \mathbb{E}_{\mathbf{z},y}[\log(1 - D(G(\mathbf{z}, y), y))]\)$
Where \(y\) is conditioning information (e.g., class label, text, image).
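A minimal sketch of label conditioning (embed \(y\) and concatenate with \(\mathbf{z}\)); the class name and layer sizes here are illustrative assumptions, not from the objective above:

```python
import torch
import torch.nn as nn

class ConditionalGenerator(nn.Module):
    def __init__(self, latent_dim=100, num_classes=10, embed_dim=16, out_dim=784):
        super().__init__()
        self.label_embed = nn.Embedding(num_classes, embed_dim)
        self.net = nn.Sequential(
            nn.Linear(latent_dim + embed_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, out_dim),
            nn.Tanh(),
        )

    def forward(self, z, y):
        # Condition on y by concatenating its embedding to the noise vector
        zy = torch.cat([z, self.label_embed(y)], dim=1)
        return self.net(zy)
```

The discriminator is conditioned the same way, receiving \((\mathbf{x}, y)\) pairs.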
3. Wasserstein GAN (WGAN)
Objective (Earth Mover's Distance): $\(\min_G \max_{D \in \mathcal{D}} \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}[D(\mathbf{x})] - \mathbb{E}_{\mathbf{z} \sim p_z}[D(G(\mathbf{z}))]\)$
Where \(\mathcal{D}\) is the set of 1-Lipschitz functions.
Benefits:
Meaningful loss metric (correlates with sample quality)
Greatly reduced mode collapse
No vanishing gradients, even when \(p_g\) and \(p_{\text{data}}\) have disjoint supports
Implementation: Weight clipping or gradient penalty to enforce Lipschitz constraint.
4. Progressive GAN
Idea: Train GAN progressively from low to high resolution.
Start with 4×4 images
Gradually add layers for 8×8, 16×16, …, 1024×1024
Benefits: Faster training, higher quality, more stable.
5. StyleGAN
Architecture: Style-based generator with adaptive instance normalization (AdaIN).
Innovations:
Mapping network: \(\mathbf{z} \rightarrow \mathbf{w}\) (disentangled latent space)
Modulation: Control style at each resolution
Noise injection: Stochastic variation
Result: State-of-the-art image quality, fine-grained control.
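A minimal AdaIN sketch (the function and argument shapes are assumptions for illustration): normalize each feature map per sample, then scale and shift with style-derived parameters:

```python
import torch

def adain(x, gamma, beta, eps=1e-5):
    # x: (N, C, H, W); gamma, beta: (N, C) scale/shift derived from the style w
    mu = x.mean(dim=(2, 3), keepdim=True)
    var = x.var(dim=(2, 3), keepdim=True, unbiased=False)
    x_norm = (x - mu) / torch.sqrt(var + eps)   # per-sample, per-channel normalization
    return gamma[:, :, None, None] * x_norm + beta[:, :, None, None]
```

In the generator, \((\gamma, \beta)\) come from an affine transform of \(\mathbf{w}\), so the style controls the statistics of every feature map.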
6. BigGAN
Scaling to High Resolution:
Large batch sizes (2048)
Spectral normalization
Orthogonal regularization
Class-conditional batch normalization
Result: 512×512 ImageNet generation with high fidelity.
Advanced Training Techniques
Spectral Normalization
Idea: Constrain Lipschitz constant of discriminator.
Method: Normalize weight matrices by their spectral norm (largest singular value): $\(\bar{W} = W / \sigma(W)\)$
Where \(\sigma(W)\) is estimated via power iteration.
Benefit: Stabilizes training by preventing discriminator from becoming too confident.
Two-Time-Scale Update Rule (TTUR)
Observation: Discriminator learns faster than generator.
Solution: Use different learning rates:
\(\alpha_D > \alpha_G\) (e.g., \(\alpha_D = 4 \times 10^{-4}\), \(\alpha_G = 1 \times 10^{-4}\))
Theory: Ensures both networks converge at similar rates.
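A minimal TTUR sketch with the learning rates quoted above (the placeholder networks are illustrative stand-ins for real models):

```python
import torch.nn as nn
import torch.optim as optim

# Placeholder networks, only to show the optimizer setup
G = nn.Linear(100, 784)
D = nn.Linear(784, 1)

# Same optimizer family, different learning rates for D and G
optimizer_d = optim.Adam(D.parameters(), lr=4e-4, betas=(0.5, 0.999))
optimizer_g = optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.999))
```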
Self-Attention
Architecture: Add self-attention layers to capture long-range dependencies.
Formulation: $\(\mathbf{o}_i = \sum_j \alpha_{ij} \mathbf{v}_j, \quad \alpha_{ij} = \frac{\exp(s_{ij})}{\sum_k \exp(s_{ik})}, \quad s_{ij} = \mathbf{q}_i^T \mathbf{k}_j\)$
Benefit: Model global structure (e.g., consistent object shapes).
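The formulation above, reduced to a minimal sketch over flattened spatial positions (the helper name is illustrative; real SAGAN-style layers add learned 1×1 projections and a residual connection):

```python
import torch

def self_attention(q, k, v):
    # q, k, v: (num_positions, dim)
    scores = q @ k.t()                     # s_ij = q_i^T k_j
    attn = torch.softmax(scores, dim=-1)   # alpha_ij (each row sums to 1)
    return attn @ v                        # o_i = sum_j alpha_ij v_j
```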
Hinge Loss
Discriminator: $\(\mathcal{L}_D = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}[\max(0, 1 - D(\mathbf{x}))] + \mathbb{E}_{\mathbf{z} \sim p_z}[\max(0, 1 + D(G(\mathbf{z})))]\)$
Generator: $\(\mathcal{L}_G = -\mathbb{E}_{\mathbf{z} \sim p_z}[D(G(\mathbf{z}))]\)$
Benefit: Non-saturating, stable gradients.
Connections to Other Frameworks
f-GAN
Generalization: Any f-divergence can be used: $\(D_f(p \| q) = \mathbb{E}_{q(x)}\left[f\left(\frac{p(x)}{q(x)}\right)\right]\)$
Examples:
\(f(t) = t \log t\): KL divergence
\(f(t) = -\log t\): Reverse KL
\(f(t) = (t-1)^2\): \(\chi^2\) divergence
\(f(t) = \frac{1}{2}\left(t \log t - (1+t) \log \frac{1+t}{2}\right)\): Jensen-Shannon (standard GAN)
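The definition can be checked numerically for the first few generator functions (ad-hoc helper on small discrete distributions):

```python
import math

def f_divergence(p, q, f):
    # D_f(p || q) = E_q[f(p(x)/q(x))] over a discrete support
    return sum(qi * f(pi / qi) for pi, qi in zip(p, q))

p = [0.6, 0.3, 0.1]
q = [0.2, 0.3, 0.5]

kl = f_divergence(p, q, lambda t: t * math.log(t))        # forward KL
reverse_kl = f_divergence(p, q, lambda t: -math.log(t))   # reverse KL = KL(q || p)
chi2 = f_divergence(p, q, lambda t: (t - 1) ** 2)         # chi-squared

# f(t) = t log t recovers KL(p || q) computed directly
direct_kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))
assert abs(kl - direct_kl) < 1e-12
```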
Energy-Based Models
Connection: Discriminator can be viewed as energy function: $\(D(\mathbf{x}) = \sigma(-E(\mathbf{x}))\)$
Where \(E(\mathbf{x})\) is energy (low for real data, high for fake).
Implicit Models
GAN Perspective: Generator defines an implicit probability model: $\(p_g(\mathbf{x}) = \int p_z(\mathbf{z}) \delta(\mathbf{x} - G(\mathbf{z})) d\mathbf{z}\)$
Benefit: No need for tractable density \(p_g(\mathbf{x})\) (unlike VAEs, normalizing flows).
Practical Tips
Hyperparameters
Learning rate: \(1 \times 10^{-4}\) to \(2 \times 10^{-4}\)
Batch size: 64-256 (larger is better if memory permits)
Optimizer: Adam with \(\beta_1 = 0.5\), \(\beta_2 = 0.999\)
Discriminator updates: \(k = 1\) (or 5 for WGAN)
Architecture
LeakyReLU with slope 0.2 in discriminator
ReLU or LeakyReLU in generator
Batch normalization in both networks (except \(D\) output and \(G\) input)
Avoid fully connected layers (use convolutions)
Debugging
Check gradients: Should be non-zero and bounded
Monitor losses: \(D\) loss should stabilize around 0.5-0.7
Visualize samples: Every few hundred iterations
Inspect discriminator outputs: Should be near 0.5 for good \(G\)
Conclusion
GANs revolutionized generative modeling through adversarial training, enabling high-quality sample generation without explicit density modeling. The mathematical foundation, rooted in minimax optimization and the Jensen-Shannon divergence, characterizes the global optimum: the generator distribution exactly matches the data distribution (though convergence of the practical training dynamics is not guaranteed).
Key Insights:
Minimax game: Two-player optimization leads to Nash equilibrium
Optimal discriminator: Ratio of densities \(p_{\text{data}} / (p_{\text{data}} + p_g)\)
Global optimum: \(p_g = p_{\text{data}}\) with \(D^* = 0.5\) everywhere
Training challenges: Mode collapse, vanishing gradients, non-convergence
Impact: GANs spawned numerous variants (DCGAN, WGAN, StyleGAN, BigGAN) and applications (image synthesis, style transfer, data augmentation, super-resolution).
Despite recent advances in diffusion models, GANs remain relevant for their speed (single forward pass) and theoretical elegance, continuing to inspire research in adversarial learning and generative modeling.
"""
GAN Mathematics - Complete Implementation
This implementation demonstrates the mathematical foundations of GANs with:
1. Standard GAN (vanilla architecture)
2. DCGAN (deep convolutional architecture)
3. Wasserstein GAN with gradient penalty
4. Spectral normalization
5. Multiple loss functions
6. Training stability techniques
7. Comprehensive demonstrations
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import grad
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional, Dict, List
from dataclasses import dataclass
# ============================================================================
# Configuration
# ============================================================================
@dataclass
class GANConfig:
"""Configuration for GAN training."""
latent_dim: int = 100
img_channels: int = 1
img_size: int = 28
hidden_dim: int = 128
learning_rate_g: float = 2e-4
learning_rate_d: float = 2e-4
beta1: float = 0.5
beta2: float = 0.999
batch_size: int = 64
num_epochs: int = 50
k_discriminator: int = 1 # Discriminator updates per generator update
device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
# ============================================================================
# 1. Standard GAN (Vanilla Architecture)
# ============================================================================
class VanillaGenerator(nn.Module):
"""
Simple feedforward generator.
Architecture:
- Input: latent vector z (latent_dim,)
- Hidden: 2 fully connected layers with BatchNorm and LeakyReLU
- Output: flattened image (img_size * img_size * img_channels,)
"""
def __init__(self, latent_dim: int = 100, img_size: int = 28, img_channels: int = 1,
hidden_dim: int = 256):
super().__init__()
self.img_size = img_size
self.img_channels = img_channels
self.img_shape = (img_channels, img_size, img_size)
self.model = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(hidden_dim, hidden_dim * 2),
nn.BatchNorm1d(hidden_dim * 2),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(hidden_dim * 2, hidden_dim * 4),
nn.BatchNorm1d(hidden_dim * 4),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(hidden_dim * 4, int(np.prod(self.img_shape))),
nn.Tanh() # Output in [-1, 1]
)
def forward(self, z: torch.Tensor) -> torch.Tensor:
"""
Args:
z: Latent vectors (batch_size, latent_dim)
Returns:
Generated images (batch_size, img_channels, img_size, img_size)
"""
img_flat = self.model(z)
img = img_flat.view(img_flat.size(0), *self.img_shape)
return img
class VanillaDiscriminator(nn.Module):
"""
Simple feedforward discriminator.
Architecture:
- Input: flattened image
- Hidden: 3 fully connected layers with LeakyReLU
- Output: probability (sigmoid)
"""
def __init__(self, img_size: int = 28, img_channels: int = 1, hidden_dim: int = 256):
super().__init__()
self.img_shape = (img_channels, img_size, img_size)
self.model = nn.Sequential(
nn.Linear(int(np.prod(self.img_shape)), hidden_dim * 4),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
nn.Linear(hidden_dim * 4, hidden_dim * 2),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
nn.Linear(hidden_dim * 2, hidden_dim),
nn.LeakyReLU(0.2, inplace=True),
nn.Dropout(0.3),
nn.Linear(hidden_dim, 1),
nn.Sigmoid()
)
def forward(self, img: torch.Tensor) -> torch.Tensor:
"""
Args:
img: Images (batch_size, img_channels, img_size, img_size)
Returns:
Probabilities (batch_size, 1)
"""
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity
# ============================================================================
# 2. DCGAN (Deep Convolutional GAN)
# ============================================================================
class DCGANGenerator(nn.Module):
"""
Deep Convolutional GAN generator.
Architecture guidelines (from DCGAN paper):
- Replace pooling with fractional-strided convolutions (transposed conv)
- Use batch normalization
- Remove fully connected layers
- Use ReLU activations (Tanh for output)
"""
def __init__(self, latent_dim: int = 100, img_channels: int = 1, feature_dim: int = 64):
super().__init__()
self.latent_dim = latent_dim
# Initial projection: latent_dim -> 4x4x(feature_dim*8)
self.init_size = 4
self.fc = nn.Linear(latent_dim, feature_dim * 8 * self.init_size * self.init_size)
        # Convolutional layers: 4x4 -> 8x8 -> 16x16 -> 32x32 -> 28x28
self.conv_blocks = nn.Sequential(
nn.BatchNorm2d(feature_dim * 8),
# 4x4 -> 8x8
nn.ConvTranspose2d(feature_dim * 8, feature_dim * 4, 4, 2, 1),
nn.BatchNorm2d(feature_dim * 4),
nn.ReLU(inplace=True),
# 8x8 -> 16x16
nn.ConvTranspose2d(feature_dim * 4, feature_dim * 2, 4, 2, 1),
nn.BatchNorm2d(feature_dim * 2),
nn.ReLU(inplace=True),
# 16x16 -> 32x32
nn.ConvTranspose2d(feature_dim * 2, feature_dim, 4, 2, 1),
nn.BatchNorm2d(feature_dim),
nn.ReLU(inplace=True),
            # 32x32 -> 28x28 (transposed conv: (32-1)*1 - 2*4 + (5-1) + 1 = 28)
            nn.ConvTranspose2d(feature_dim, img_channels, 5, 1, 4),
nn.Tanh()
)
def forward(self, z: torch.Tensor) -> torch.Tensor:
"""
Args:
z: Latent vectors (batch_size, latent_dim)
Returns:
Generated images (batch_size, img_channels, 28, 28)
"""
out = self.fc(z)
out = out.view(out.size(0), -1, self.init_size, self.init_size)
img = self.conv_blocks(out)
return img
class DCGANDiscriminator(nn.Module):
"""
Deep Convolutional GAN discriminator.
Architecture guidelines:
- Replace pooling with strided convolutions
- Use batch normalization (except first layer)
- Use LeakyReLU activations
"""
def __init__(self, img_channels: int = 1, feature_dim: int = 64):
super().__init__()
self.conv_blocks = nn.Sequential(
# 28x28 -> 14x14
nn.Conv2d(img_channels, feature_dim, 4, 2, 1),
nn.LeakyReLU(0.2, inplace=True),
# 14x14 -> 7x7
nn.Conv2d(feature_dim, feature_dim * 2, 4, 2, 1),
nn.BatchNorm2d(feature_dim * 2),
nn.LeakyReLU(0.2, inplace=True),
# 7x7 -> 3x3
nn.Conv2d(feature_dim * 2, feature_dim * 4, 4, 2, 1),
nn.BatchNorm2d(feature_dim * 4),
nn.LeakyReLU(0.2, inplace=True),
# 3x3 -> 1x1
nn.Conv2d(feature_dim * 4, feature_dim * 8, 3, 1, 0),
nn.BatchNorm2d(feature_dim * 8),
nn.LeakyReLU(0.2, inplace=True),
)
self.adv_layer = nn.Sequential(
nn.Conv2d(feature_dim * 8, 1, 1, 1, 0),
nn.Sigmoid()
)
def forward(self, img: torch.Tensor) -> torch.Tensor:
"""
Args:
img: Images (batch_size, img_channels, 28, 28)
Returns:
Probabilities (batch_size, 1)
"""
out = self.conv_blocks(img)
validity = self.adv_layer(out)
validity = validity.view(validity.size(0), -1)
return validity
# ============================================================================
# 3. Wasserstein GAN with Gradient Penalty
# ============================================================================
class WGANDiscriminator(nn.Module):
"""
WGAN discriminator (critic).
Differences from standard discriminator:
- No sigmoid at output (outputs real-valued scores)
- Called "critic" because it estimates Wasserstein distance
"""
def __init__(self, img_channels: int = 1, feature_dim: int = 64):
super().__init__()
self.model = nn.Sequential(
# 28x28 -> 14x14
nn.Conv2d(img_channels, feature_dim, 4, 2, 1),
nn.LeakyReLU(0.2, inplace=True),
# 14x14 -> 7x7
nn.Conv2d(feature_dim, feature_dim * 2, 4, 2, 1),
nn.InstanceNorm2d(feature_dim * 2), # Use InstanceNorm instead of BatchNorm for WGAN
nn.LeakyReLU(0.2, inplace=True),
# 7x7 -> 3x3
nn.Conv2d(feature_dim * 2, feature_dim * 4, 4, 2, 1),
nn.InstanceNorm2d(feature_dim * 4),
nn.LeakyReLU(0.2, inplace=True),
# 3x3 -> 1x1
nn.Conv2d(feature_dim * 4, 1, 3, 1, 0),
# No activation (output real-valued scores)
)
def forward(self, img: torch.Tensor) -> torch.Tensor:
"""
Args:
img: Images (batch_size, img_channels, 28, 28)
Returns:
Critic scores (batch_size, 1)
"""
validity = self.model(img)
validity = validity.view(validity.size(0), -1)
return validity
def compute_gradient_penalty(discriminator: nn.Module, real_samples: torch.Tensor,
fake_samples: torch.Tensor, device: str = 'cpu') -> torch.Tensor:
"""
Compute gradient penalty for WGAN-GP.
    Enforces Lipschitz constraint: ||∇_x D(x)||₂ ≤ 1
    Penalty: λ * E[(||∇_x̃ D(x̃)||₂ - 1)²]
    where x̃ = ε*x_real + (1-ε)*x_fake, ε ~ U(0,1)
Args:
discriminator: Discriminator network
real_samples: Real images (batch_size, C, H, W)
fake_samples: Fake images (batch_size, C, H, W)
device: Device
Returns:
Gradient penalty scalar
"""
batch_size = real_samples.size(0)
# Random weight for interpolation
epsilon = torch.rand(batch_size, 1, 1, 1, device=device)
# Interpolated samples
interpolated = epsilon * real_samples + (1 - epsilon) * fake_samples
interpolated.requires_grad_(True)
# Discriminator output for interpolated samples
d_interpolated = discriminator(interpolated)
    # Compute gradients w.r.t. interpolated samples
    grad_outputs = torch.ones_like(d_interpolated)  # seed for the vector-Jacobian product
    gradients = grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
# Flatten gradients
gradients = gradients.view(batch_size, -1)
# Compute gradient norm
gradient_norm = gradients.norm(2, dim=1)
    # Penalty: (||gradient|| - 1)²
gradient_penalty = ((gradient_norm - 1) ** 2).mean()
return gradient_penalty
# ============================================================================
# 4. Spectral Normalization
# ============================================================================
class SpectralNorm:
"""
Spectral normalization for weight matrices.
    Normalizes weights by their spectral norm (largest singular value):
        W_normalized = W / σ(W)
    σ(W) is estimated via power iteration.
"""
def __init__(self, name: str = 'weight', n_power_iterations: int = 1, eps: float = 1e-12):
self.name = name
self.n_power_iterations = n_power_iterations
self.eps = eps
def compute_weight(self, module: nn.Module) -> torch.Tensor:
"""Compute spectral-normalized weight."""
weight = getattr(module, self.name + '_orig')
u = getattr(module, self.name + '_u')
v = getattr(module, self.name + '_v')
        # Power iteration (update the stored u, v buffers in place so the
        # estimate persists and improves across forward passes)
        with torch.no_grad():
            for _ in range(self.n_power_iterations):
                # v = W^T u / ||W^T u||
                v.copy_(F.normalize(torch.mv(weight.t(), u), dim=0, eps=self.eps))
                # u = W v / ||W v||
                u.copy_(F.normalize(torch.mv(weight, v), dim=0, eps=self.eps))
        # Spectral norm: σ(W) = u^T W v
        sigma = torch.dot(u, torch.mv(weight, v))
# Normalize weight
weight_normalized = weight / sigma
return weight_normalized
def __call__(self, module: nn.Module, inputs):
"""Apply spectral normalization before forward pass."""
setattr(module, self.name, self.compute_weight(module))
def spectral_norm(module: nn.Module, name: str = 'weight', n_power_iterations: int = 1) -> nn.Module:
    """
    Apply spectral normalization to a module with a 2-D weight (e.g. nn.Linear).
    (For conv layers, whose weights are 4-D, prefer nn.utils.spectral_norm.)
    Example:
        fc = spectral_norm(nn.Linear(128, 64))
    """
    fn = SpectralNorm(name, n_power_iterations)
    weight = getattr(module, name)
    delattr(module, name)
    module.register_parameter(name + '_orig', nn.Parameter(weight.data))
    module.register_buffer(name + '_u', F.normalize(torch.randn(weight.size(0)), dim=0))
    module.register_buffer(name + '_v', F.normalize(torch.randn(weight.size(1)), dim=0))
    setattr(module, name, weight.data)
    module.register_forward_pre_hook(fn)
    return module
class SNDiscriminator(nn.Module):
"""
Discriminator with spectral normalization.
Spectral normalization stabilizes training by constraining Lipschitz constant.
"""
def __init__(self, img_channels: int = 1, feature_dim: int = 64):
super().__init__()
self.model = nn.Sequential(
# Use PyTorch's built-in spectral_norm
nn.utils.spectral_norm(nn.Conv2d(img_channels, feature_dim, 4, 2, 1)),
nn.LeakyReLU(0.2, inplace=True),
nn.utils.spectral_norm(nn.Conv2d(feature_dim, feature_dim * 2, 4, 2, 1)),
nn.LeakyReLU(0.2, inplace=True),
nn.utils.spectral_norm(nn.Conv2d(feature_dim * 2, feature_dim * 4, 4, 2, 1)),
nn.LeakyReLU(0.2, inplace=True),
nn.utils.spectral_norm(nn.Conv2d(feature_dim * 4, 1, 3, 1, 0)),
nn.Sigmoid()
)
def forward(self, img: torch.Tensor) -> torch.Tensor:
validity = self.model(img)
validity = validity.view(validity.size(0), -1)
return validity
# ============================================================================
# 5. Loss Functions
# ============================================================================
class GANLoss:
"""Collection of GAN loss functions."""
@staticmethod
def vanilla_discriminator_loss(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
"""
Vanilla GAN discriminator loss.
L_D = -E[log D(x)] - E[log(1 - D(G(z)))]
Equivalent to binary cross-entropy:
L_D = BCE(D(x), 1) + BCE(D(G(z)), 0)
"""
real_loss = F.binary_cross_entropy(d_real, torch.ones_like(d_real))
fake_loss = F.binary_cross_entropy(d_fake, torch.zeros_like(d_fake))
return real_loss + fake_loss
@staticmethod
def vanilla_generator_loss(d_fake: torch.Tensor) -> torch.Tensor:
"""
Vanilla GAN generator loss.
L_G = -E[log D(G(z))]
Equivalent to:
L_G = BCE(D(G(z)), 1)
"""
return F.binary_cross_entropy(d_fake, torch.ones_like(d_fake))
@staticmethod
def nonsaturating_generator_loss(d_fake: torch.Tensor) -> torch.Tensor:
"""
Non-saturating generator loss.
Instead of minimizing -E[log D(G(z))], maximize E[log D(G(z))].
Provides stronger gradients when D is confident.
"""
# Same as vanilla in this implementation
return F.binary_cross_entropy(d_fake, torch.ones_like(d_fake))
@staticmethod
def wasserstein_discriminator_loss(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
"""
Wasserstein GAN discriminator (critic) loss.
L_D = -E[D(x)] + E[D(G(z))]
Minimize this to maximize E[D(x)] - E[D(G(z))] (Earth Mover's Distance).
"""
return -d_real.mean() + d_fake.mean()
@staticmethod
def wasserstein_generator_loss(d_fake: torch.Tensor) -> torch.Tensor:
"""
Wasserstein GAN generator loss.
L_G = -E[D(G(z))]
"""
return -d_fake.mean()
@staticmethod
def hinge_discriminator_loss(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
"""
Hinge loss for discriminator.
L_D = E[max(0, 1 - D(x))] + E[max(0, 1 + D(G(z)))]
Used in BigGAN and other modern architectures.
"""
real_loss = F.relu(1.0 - d_real).mean()
fake_loss = F.relu(1.0 + d_fake).mean()
return real_loss + fake_loss
@staticmethod
def hinge_generator_loss(d_fake: torch.Tensor) -> torch.Tensor:
"""
Hinge loss for generator.
L_G = -E[D(G(z))]
"""
return -d_fake.mean()
# ============================================================================
# 6. Training Utilities
# ============================================================================
def train_vanilla_gan_step(
generator: nn.Module,
discriminator: nn.Module,
real_imgs: torch.Tensor,
optimizer_g: optim.Optimizer,
optimizer_d: optim.Optimizer,
latent_dim: int,
device: str = 'cpu'
) -> Dict[str, float]:
"""
Single training step for vanilla GAN.
Returns:
Dictionary with loss values
"""
batch_size = real_imgs.size(0)
# ---------------------
# Train Discriminator
# ---------------------
optimizer_d.zero_grad()
# Real images
real_validity = discriminator(real_imgs)
# Fake images
z = torch.randn(batch_size, latent_dim, device=device)
fake_imgs = generator(z)
fake_validity = discriminator(fake_imgs.detach())
# Discriminator loss
d_loss = GANLoss.vanilla_discriminator_loss(real_validity, fake_validity)
d_loss.backward()
optimizer_d.step()
# -----------------
# Train Generator
# -----------------
optimizer_g.zero_grad()
# Generate new fake images
z = torch.randn(batch_size, latent_dim, device=device)
fake_imgs = generator(z)
fake_validity = discriminator(fake_imgs)
# Generator loss
g_loss = GANLoss.vanilla_generator_loss(fake_validity)
g_loss.backward()
optimizer_g.step()
return {
'd_loss': d_loss.item(),
'g_loss': g_loss.item(),
'd_real': real_validity.mean().item(),
'd_fake': fake_validity.mean().item()
}
def train_wgan_gp_step(
generator: nn.Module,
discriminator: nn.Module,
real_imgs: torch.Tensor,
optimizer_g: optim.Optimizer,
optimizer_d: optim.Optimizer,
latent_dim: int,
lambda_gp: float = 10.0,
device: str = 'cpu'
) -> Dict[str, float]:
"""
Single training step for WGAN-GP.
Args:
lambda_gp: Gradient penalty coefficient (default: 10)
Returns:
Dictionary with loss values
"""
batch_size = real_imgs.size(0)
# ---------------------
# Train Discriminator
# ---------------------
optimizer_d.zero_grad()
# Real images
real_validity = discriminator(real_imgs)
# Fake images
z = torch.randn(batch_size, latent_dim, device=device)
fake_imgs = generator(z)
fake_validity = discriminator(fake_imgs.detach())
# Gradient penalty
gp = compute_gradient_penalty(discriminator, real_imgs, fake_imgs.detach(), device)
    # Discriminator loss: -E[D(x)] + E[D(G(z))] + λ*GP
d_loss = GANLoss.wasserstein_discriminator_loss(real_validity, fake_validity) + lambda_gp * gp
d_loss.backward()
optimizer_d.step()
# -----------------
# Train Generator
# -----------------
optimizer_g.zero_grad()
# Generate new fake images
z = torch.randn(batch_size, latent_dim, device=device)
fake_imgs = generator(z)
fake_validity = discriminator(fake_imgs)
# Generator loss: -E[D(G(z))]
g_loss = GANLoss.wasserstein_generator_loss(fake_validity)
g_loss.backward()
optimizer_g.step()
return {
'd_loss': d_loss.item(),
'g_loss': g_loss.item(),
'gp': gp.item(),
'd_real': real_validity.mean().item(),
'd_fake': fake_validity.mean().item(),
'wasserstein_dist': (real_validity.mean() - fake_validity.mean()).item()
}
# ============================================================================
# 7. Visualization
# ============================================================================
def visualize_generated_samples(generator: nn.Module, latent_dim: int, num_samples: int = 16,
device: str = 'cpu', figsize: Tuple[int, int] = (8, 8)):
"""
Visualize generated samples.
Args:
generator: Generator network
latent_dim: Dimension of latent space
num_samples: Number of samples to generate (should be square number)
device: Device
figsize: Figure size
"""
generator.eval()
with torch.no_grad():
z = torch.randn(num_samples, latent_dim, device=device)
fake_imgs = generator(z).cpu()
# Plot
grid_size = int(np.sqrt(num_samples))
fig, axes = plt.subplots(grid_size, grid_size, figsize=figsize)
for i in range(grid_size):
for j in range(grid_size):
idx = i * grid_size + j
img = fake_imgs[idx, 0] # Remove channel dimension
axes[i, j].imshow(img, cmap='gray')
axes[i, j].axis('off')
plt.tight_layout()
plt.show()
generator.train()
def visualize_training_progress(losses: Dict[str, List[float]], figsize: Tuple[int, int] = (12, 4)):
"""
Visualize training losses over time.
Args:
losses: Dictionary with loss histories
figsize: Figure size
"""
fig, axes = plt.subplots(1, 3, figsize=figsize)
# Discriminator and Generator losses
axes[0].plot(losses['d_loss'], label='D Loss', alpha=0.7)
axes[0].plot(losses['g_loss'], label='G Loss', alpha=0.7)
axes[0].set_xlabel('Iteration')
axes[0].set_ylabel('Loss')
axes[0].set_title('GAN Losses')
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# Discriminator outputs
axes[1].plot(losses['d_real'], label='D(x)', alpha=0.7)
axes[1].plot(losses['d_fake'], label='D(G(z))', alpha=0.7)
axes[1].axhline(y=0.5, color='r', linestyle='--', alpha=0.5, label='Random')
axes[1].set_xlabel('Iteration')
axes[1].set_ylabel('Discriminator Output')
axes[1].set_title('Discriminator Predictions')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
# Wasserstein distance (if available)
if 'wasserstein_dist' in losses:
axes[2].plot(losses['wasserstein_dist'], label='W Distance', alpha=0.7)
axes[2].set_xlabel('Iteration')
axes[2].set_ylabel('Distance')
axes[2].set_title('Wasserstein Distance')
axes[2].legend()
axes[2].grid(True, alpha=0.3)
else:
axes[2].axis('off')
plt.tight_layout()
plt.show()
def visualize_latent_space_interpolation(generator: nn.Module, latent_dim: int, num_steps: int = 10,
device: str = 'cpu', figsize: Tuple[int, int] = (15, 3)):
"""
Interpolate between two random points in latent space.
Args:
generator: Generator network
latent_dim: Dimension of latent space
num_steps: Number of interpolation steps
device: Device
figsize: Figure size
"""
generator.eval()
# Sample two random latent vectors
z1 = torch.randn(1, latent_dim, device=device)
z2 = torch.randn(1, latent_dim, device=device)
# Interpolate
alphas = torch.linspace(0, 1, num_steps, device=device).view(-1, 1)
z_interp = z1 * (1 - alphas) + z2 * alphas
with torch.no_grad():
imgs = generator(z_interp).cpu()
# Plot
fig, axes = plt.subplots(1, num_steps, figsize=figsize)
for i in range(num_steps):
img = imgs[i, 0]
axes[i].imshow(img, cmap='gray')
axes[i].axis('off')
axes[i].set_title(f'α={alphas[i].item():.1f}')
plt.tight_layout()
plt.show()
generator.train()
# ============================================================================
# 8. Demonstrations
# ============================================================================
def demo_vanilla_gan():
"""Demonstrate vanilla GAN architecture and forward pass."""
print("=" * 70)
print("Vanilla GAN Demonstration")
print("=" * 70)
latent_dim = 100
batch_size = 4
# Initialize models
generator = VanillaGenerator(latent_dim=latent_dim)
discriminator = VanillaDiscriminator()
print(f"\nGenerator Architecture:")
print(generator)
print(f"\nGenerator Parameters: {sum(p.numel() for p in generator.parameters()):,}")
print(f"\nDiscriminator Architecture:")
print(discriminator)
print(f"\nDiscriminator Parameters: {sum(p.numel() for p in discriminator.parameters()):,}")
# Forward pass
z = torch.randn(batch_size, latent_dim)
fake_imgs = generator(z)
print(f"\nGenerated Images Shape: {fake_imgs.shape}")
print(f"Generated Images Range: [{fake_imgs.min():.3f}, {fake_imgs.max():.3f}]")
validity = discriminator(fake_imgs)
print(f"\nDiscriminator Output Shape: {validity.shape}")
print(f"Discriminator Output (probabilities): {validity.squeeze().tolist()}")
print()
def demo_dcgan():
"""Demonstrate DCGAN architecture."""
print("=" * 70)
print("DCGAN Demonstration")
print("=" * 70)
latent_dim = 100
batch_size = 4
# Initialize models
generator = DCGANGenerator(latent_dim=latent_dim)
discriminator = DCGANDiscriminator()
print(f"\nGenerator Architecture:")
print(generator)
print(f"\nGenerator Parameters: {sum(p.numel() for p in generator.parameters()):,}")
print(f"\nDiscriminator Architecture:")
print(discriminator)
print(f"\nDiscriminator Parameters: {sum(p.numel() for p in discriminator.parameters()):,}")
# Forward pass
z = torch.randn(batch_size, latent_dim)
fake_imgs = generator(z)
print(f"\nGenerated Images Shape: {fake_imgs.shape}")
validity = discriminator(fake_imgs)
print(f"Discriminator Output Shape: {validity.shape}")
print()
def demo_wgan():
"""Demonstrate WGAN-GP architecture and gradient penalty."""
print("=" * 70)
print("WGAN-GP Demonstration")
print("=" * 70)
latent_dim = 100
batch_size = 4
device = 'cpu'
# Initialize models
generator = DCGANGenerator(latent_dim=latent_dim)
discriminator = WGANDiscriminator()
print(f"\nCritic (Discriminator) Architecture:")
print(discriminator)
# Create fake real/fake images
real_imgs = torch.randn(batch_size, 1, 28, 28)
z = torch.randn(batch_size, latent_dim)
fake_imgs = generator(z)
# Compute gradient penalty
gp = compute_gradient_penalty(discriminator, real_imgs, fake_imgs, device)
print(f"\nGradient Penalty: {gp.item():.4f}")
# Compute Wasserstein distance estimate
with torch.no_grad():
real_validity = discriminator(real_imgs)
fake_validity = discriminator(fake_imgs)
w_dist = (real_validity.mean() - fake_validity.mean()).item()
print(f"Wasserstein Distance Estimate: {w_dist:.4f}")
print()
def demo_spectral_normalization():
"""Demonstrate spectral normalization."""
print("=" * 70)
print("Spectral Normalization Demonstration")
print("=" * 70)
# Regular conv layer
conv_regular = nn.Conv2d(1, 64, 4, 2, 1)
# Spectral normalized conv layer
conv_sn = nn.utils.spectral_norm(nn.Conv2d(1, 64, 4, 2, 1))
# Compute spectral norms
def spectral_norm_value(layer):
"""Compute spectral norm of weight matrix."""
weight = layer.weight.data
# Reshape to 2D matrix
weight_mat = weight.view(weight.size(0), -1)
# Largest singular value (torch.svd is deprecated; torch.linalg is preferred)
return torch.linalg.svdvals(weight_mat)[0].item()
sn_regular = spectral_norm_value(conv_regular)
sn_normalized = spectral_norm_value(conv_sn)
print(f"\nRegular Conv Layer:")
print(f" Spectral Norm: {sn_regular:.4f}")
print(f"\nSpectral Normalized Conv Layer:")
print(f" Spectral Norm: {sn_normalized:.4f}")
print("\nSpectral normalization constrains the Lipschitz constant to ≤ 1")
print()
def demo_loss_functions():
"""Demonstrate different GAN loss functions."""
print("=" * 70)
print("GAN Loss Functions Demonstration")
print("=" * 70)
# Create dummy discriminator outputs
d_real = torch.tensor([[0.9], [0.8], [0.85], [0.95]])
d_fake = torch.tensor([[0.2], [0.1], [0.15], [0.25]])
print("\nDiscriminator Outputs:")
print(f" D(x_real): {d_real.squeeze().tolist()}")
print(f" D(x_fake): {d_fake.squeeze().tolist()}")
# Compute losses
print("\nDiscriminator Losses:")
print(f" Vanilla: {GANLoss.vanilla_discriminator_loss(d_real, d_fake):.4f}")
print(f" Wasserstein: {GANLoss.wasserstein_discriminator_loss(d_real, d_fake):.4f}")
print(f" Hinge: {GANLoss.hinge_discriminator_loss(d_real, d_fake):.4f}")
print("\nGenerator Losses:")
print(f" Vanilla: {GANLoss.vanilla_generator_loss(d_fake):.4f}")
print(f" Wasserstein: {GANLoss.wasserstein_generator_loss(d_fake):.4f}")
print(f" Hinge: {GANLoss.hinge_generator_loss(d_fake):.4f}")
print()
def demo_optimal_discriminator():
"""Demonstrate optimal discriminator theorem."""
print("=" * 70)
print("Optimal Discriminator Demonstration")
print("=" * 70)
# Create grid of x values
x = torch.linspace(-5, 5, 100)
# Define different scenarios for p_data and p_g
scenarios = [
("Perfect Generator",
lambda x: torch.exp(-x**2),
lambda x: torch.exp(-x**2)),
("Good Generator",
lambda x: torch.exp(-x**2),
lambda x: torch.exp(-(x-0.5)**2)),
("Poor Generator",
lambda x: torch.exp(-x**2),
lambda x: torch.exp(-(x-2)**2)),
]
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for idx, (name, p_data_fn, p_g_fn) in enumerate(scenarios):
p_data = p_data_fn(x)
p_g = p_g_fn(x)
# Optimal discriminator: D*(x) = p_data(x) / (p_data(x) + p_g(x))
d_star = p_data / (p_data + p_g)
axes[idx].plot(x.numpy(), p_data.numpy(), label='$p_{data}(x)$', linewidth=2)
axes[idx].plot(x.numpy(), p_g.numpy(), label='$p_g(x)$', linewidth=2)
axes[idx].plot(x.numpy(), d_star.numpy(), label='$D^*(x)$', linewidth=2, linestyle='--')
axes[idx].axhline(y=0.5, color='r', linestyle=':', alpha=0.5, label='Random')
axes[idx].set_xlabel('x')
axes[idx].set_ylabel('Probability / D(x)')
axes[idx].set_title(name)
axes[idx].legend()
axes[idx].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("\nKey Observations:")
print(" 1. When p_g = p_data, D*(x) = 0.5 (cannot distinguish)")
print(" 2. When p_data > p_g, D*(x) > 0.5 (likely real)")
print(" 3. When p_g > p_data, D*(x) < 0.5 (likely fake)")
print()
# ============================================================================
# Main Demonstration
# ============================================================================
if __name__ == "__main__":
print("\n" + "="*70)
print(" GAN Mathematics - Comprehensive Implementation Demonstrations")
print("="*70 + "\n")
# Run all demonstrations
demo_vanilla_gan()
demo_dcgan()
demo_wgan()
demo_spectral_normalization()
demo_loss_functions()
demo_optimal_discriminator()
print("="*70)
print("All demonstrations completed successfully!")
print("="*70)
Generative Adversarial Networks (GANs)
Learning Objectives:
Understand GAN theory and mathematics
Implement generator and discriminator networks
Train GANs with proper loss functions
Apply to image generation
Prerequisites: Deep learning basics, PyTorch
Time: 90 minutes
Reference Materials:
gan.pdf - GAN theory and variants
rbm_gan.pdf - RBM and GAN connections
generative_models.pdf - Overview of generative models
1. The GAN Framework
Original Formulation (Goodfellow et al., 2014)
Players:
Generator \(G\): Maps noise \(z \sim p_z\) to fake data \(G(z)\)
Discriminator \(D\): Classifies real vs fake data
Objective (Minimax Game): $\(\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]\)$
Intuition:
\(D\) tries to maximize (distinguish real from fake)
\(G\) tries to minimize (fool \(D\))
At equilibrium: \(p_g = p_{data}\) (generator matches data distribution)
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, TensorDataset
from scipy import stats
# Set style and seed
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 4)
torch.manual_seed(42)
np.random.seed(42)
# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
2. Theoretical Analysis
Optimal Discriminator
For fixed \(G\), the optimal discriminator is: $\(D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}\)$
Proof: Maximize: $\(V(D, G) = \int_x p_{data}(x) \log D(x) dx + \int_x p_g(x) \log(1-D(x)) dx\)$
Taking derivative w.r.t. \(D(x)\) and setting to zero: $\(\frac{p_{data}(x)}{D(x)} - \frac{p_g(x)}{1-D(x)} = 0\)$
Solving: \(D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}\)
Global Optimum
When \(D = D^*_G\), the value function becomes: $\(C(G) = -\log 4 + 2 \cdot JSD(p_{data} || p_g)\)$
where \(JSD\) is the Jensen-Shannon divergence.
Key Result: Global minimum achieved when \(p_g = p_{data}\)
At optimum: \(D^*(x) = 1/2\) everywhere
\(C(G) = -\log 4\)
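As a quick numerical sanity check (a sketch with discrete toy distributions, not part of the original derivation), we can verify both claims: when \(p_g = p_{data}\), the optimal discriminator outputs \(1/2\) everywhere and the value function evaluates to \(-\log 4\):

```python
import numpy as np

# Discrete toy distributions over 5 outcomes
p_data = np.array([0.1, 0.2, 0.4, 0.2, 0.1])
p_g = p_data.copy()  # perfect generator: p_g = p_data

# Optimal discriminator: D*(x) = p_data(x) / (p_data(x) + p_g(x))
d_star = p_data / (p_data + p_g)
print(d_star)  # 0.5 everywhere

# Value function at D*: C(G) = E_pdata[log D*] + E_pg[log(1 - D*)]
c_g = np.sum(p_data * np.log(d_star)) + np.sum(p_g * np.log(1 - d_star))
print(c_g, -np.log(4))  # both ≈ -1.3863
```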
2.5. Advanced Theory: Mode Collapse Analysis
Why Mode Collapse Occurs
Jensen-Shannon Divergence Properties:
When distributions have non-overlapping support: $\(JSD(P || Q) = \log 2\)$
This can lead to:
Gradient saturation when discriminator is too strong
Missing modes in generated distribution
Oscillating training dynamics
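A small numerical check of this saturation property (a sketch using discrete distributions): once the supports are disjoint, the JSD is pinned at \(\log 2\) no matter how far apart the distributions are, which is why it carries no useful gradient signal:

```python
import numpy as np

def jsd(p, q, eps=1e-12):
    """Jensen-Shannon divergence between discrete distributions (in nats)."""
    m = 0.5 * (p + q)
    kl = lambda a, b: np.sum(np.where(a > 0, a * np.log((a + eps) / (b + eps)), 0.0))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

# Disjoint supports: P lives on the first two outcomes, Q on the last two
p = np.array([0.5, 0.5, 0.0, 0.0])
q = np.array([0.0, 0.0, 0.5, 0.5])
print(jsd(p, q), np.log(2))  # both ≈ 0.6931: saturated

# Moving Q "further away" cannot increase the JSD beyond log 2
q_far = np.array([0.0, 0.0, 0.0, 1.0])
print(jsd(p, q_far))  # still ≈ 0.6931
```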
Mathematical Analysis of Mode Collapse
Consider generator producing only mode \(m_i\) of true distribution. The discriminator can easily classify:
Samples near \(m_i\): could be real or fake
Samples near other modes \(m_j\) (\(j \neq i\)): must be real
Implication: Generator receives no gradient signal to explore other modes!
Reverse KL Divergence Perspective
GAN training implicitly minimizes a form related to the reverse KL divergence: $\(KL(p_g || p_{data})\)$
Properties:
Mode-seeking: \(p_g\) concentrates on high-probability regions of \(p_{data}\)
Zero-forcing: where \(p_{data}(x) = 0\), the divergence blows up unless \(p_g(x) = 0\), so hallucinated samples are heavily penalized
Mode-dropping: setting \(p_g(x) = 0\) where \(p_{data}(x) > 0\) incurs no penalty, which permits mode collapse
Contrast with the forward direction: $\(KL(p_{data} || p_g) \text{ (mode-covering, zero-avoiding; used in maximum-likelihood training, e.g. VAEs)}\)$
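The asymmetry between the two KL directions can be made concrete numerically. The sketch below (illustrative distributions chosen by us, not from the original text) fits a bimodal target with two single-Gaussian candidates: reverse KL favors the mode-seeking fit, forward KL the mode-covering one:

```python
import numpy as np

def kl(p, q, eps=1e-12):
    """KL(p || q) for discrete probability masses (in nats)."""
    return np.sum(np.where(p > 0, p * np.log((p + eps) / (q + eps)), 0.0))

# Discretize densities on a grid and convert to probability masses
x = np.linspace(-6, 6, 2001)
dx = x[1] - x[0]
gauss = lambda mu, s: np.exp(-(x - mu) ** 2 / (2 * s ** 2)) / (s * np.sqrt(2 * np.pi))

p_data = 0.5 * gauss(-2, 0.5) + 0.5 * gauss(2, 0.5)  # bimodal target
q_mode = gauss(2, 0.5)    # mode-seeking candidate: sits on one mode
q_cover = gauss(0, 2.2)   # mode-covering candidate: spreads over both

p_m, qm_m, qc_m = p_data * dx, q_mode * dx, q_cover * dx

# Reverse KL (the direction GANs relate to) prefers the single-mode fit
print(kl(qm_m, p_m), kl(qc_m, p_m))
# Forward KL (maximum likelihood / VAE) prefers the covering fit
print(kl(p_m, qm_m), kl(p_m, qc_m))
```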
# Visualize mode collapse phenomenon
def demonstrate_mode_collapse():
"""
Simulate mode collapse by training a generator
that only captures one mode of a multi-modal distribution
"""
# True distribution: 5 Gaussians
def sample_true_data(n):
modes = np.array([-4, -2, 0, 2, 4])
chosen_modes = np.random.choice(5, n)
samples = modes[chosen_modes] + np.random.randn(n) * 0.3
return samples
# Collapsed generator: only mode at x=0
def sample_collapsed_generator(n):
return np.random.randn(n) * 0.3 # Only around 0
# Diverse generator: captures all modes
def sample_diverse_generator(n):
modes = np.array([-4, -2, 0, 2, 4])
chosen_modes = np.random.choice(5, n)
samples = modes[chosen_modes] + np.random.randn(n) * 0.35
return samples
fig, axes = plt.subplots(1, 3, figsize=(16, 4))
# True distribution
true_samples = sample_true_data(10000)
axes[0].hist(true_samples, bins=80, density=True, alpha=0.7,
color='blue', edgecolor='black')
axes[0].set_title('True Distribution (5 Modes)', fontweight='bold', fontsize=12)
axes[0].set_xlabel('x')
axes[0].set_ylabel('Density')
axes[0].set_xlim([-6, 6])
axes[0].grid(True, alpha=0.3)
# Collapsed generator
collapsed_samples = sample_collapsed_generator(10000)
axes[1].hist(collapsed_samples, bins=80, density=True, alpha=0.7,
color='red', edgecolor='black')
axes[1].axvline(x=0, color='darkred', linestyle='--', linewidth=2,
label='Captured mode')
axes[1].set_title('Mode Collapse (1/5 Modes)', fontweight='bold', fontsize=12)
axes[1].set_xlabel('x')
axes[1].set_ylabel('Density')
axes[1].set_xlim([-6, 6])
axes[1].legend()
axes[1].grid(True, alpha=0.3)
# Diverse generator
diverse_samples = sample_diverse_generator(10000)
axes[2].hist(diverse_samples, bins=80, density=True, alpha=0.7,
color='green', edgecolor='black')
axes[2].set_title('Successful GAN (All Modes)', fontweight='bold', fontsize=12)
axes[2].set_xlabel('x')
axes[2].set_ylabel('Density')
axes[2].set_xlim([-6, 6])
axes[2].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("\n" + "="*60)
print("MODE COLLAPSE ANALYSIS")
print("="*60)
print(f"True distribution: 5 modes at positions {[-4, -2, 0, 2, 4]}")
print(f"Collapsed GAN: Captures only 1 mode (20% coverage)")
print(f"Successful GAN: Captures all 5 modes (100% coverage)")
print("\nMetrics:")
print(f" - Diversity (std): True={true_samples.std():.2f}, "
f"Collapsed={collapsed_samples.std():.2f}, "
f"Diverse={diverse_samples.std():.2f}")
print("="*60)
demonstrate_mode_collapse()
# Visualize optimal discriminator
def optimal_discriminator(x, p_data_func, p_g_func):
"""Compute D*(x) = p_data(x) / (p_data(x) + p_g(x))"""
p_data = p_data_func(x)
p_g = p_g_func(x)
return p_data / (p_data + p_g + 1e-10)
# Example: Gaussian distributions
x = np.linspace(-5, 5, 1000)
p_data = stats.norm(0, 1).pdf(x) # Real data: N(0, 1)
p_g_close = stats.norm(0.2, 1.1).pdf(x) # Generator close to real
p_g_far = stats.norm(2, 0.5).pdf(x) # Generator far from real
D_star_close = optimal_discriminator(x, lambda x: stats.norm(0, 1).pdf(x),
lambda x: stats.norm(0.2, 1.1).pdf(x))
D_star_far = optimal_discriminator(x, lambda x: stats.norm(0, 1).pdf(x),
lambda x: stats.norm(2, 0.5).pdf(x))
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
# Plot 1: Distributions when G is close
axes[0].plot(x, p_data, label='$p_{data}$', linewidth=2)
axes[0].plot(x, p_g_close, label='$p_g$ (close)', linewidth=2)
axes[0].set_title('Generator Close to Data', fontweight='bold')
axes[0].set_xlabel('x')
axes[0].set_ylabel('Probability Density')
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# Plot 2: Optimal D when G is close
axes[1].plot(x, D_star_close, linewidth=2, color='purple')
axes[1].axhline(y=0.5, color='r', linestyle='--', label='Perfect ($D^* = 0.5$)')
axes[1].set_title(r'$D^*$ when $p_g \approx p_{data}$', fontweight='bold')
axes[1].set_xlabel('x')
axes[1].set_ylabel('$D^*(x)$')
axes[1].set_ylim([0, 1])
axes[1].legend()
axes[1].grid(True, alpha=0.3)
# Plot 3: Optimal D when G is far
axes[2].plot(x, D_star_far, linewidth=2, color='orange')
axes[2].axhline(y=0.5, color='r', linestyle='--', label='Perfect ($D^* = 0.5$)')
axes[2].set_title(r'$D^*$ when $p_g \neq p_{data}$', fontweight='bold')
axes[2].set_xlabel('x')
axes[2].set_ylabel('$D^*(x)$')
axes[2].set_ylim([0, 1])
axes[2].legend()
axes[2].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("\nObservations:")
print("- When p_g ≈ p_data: D*(x) ≈ 0.5 everywhere (cannot distinguish)")
print("- When p_g ≠ p_data: D*(x) varies (can distinguish)")
3. Implementing a Vanilla GAN: 1D Gaussian Mixture
To build intuition before scaling to images, we start with a GAN that learns a simple 1D Gaussian mixture distribution. The generator maps samples from a uniform prior \(z \sim U(-1, 1)\) to the data space, while the discriminator receives both real samples from the target distribution and fake samples from the generator. Training alternates between two objectives: the discriminator maximizes \(\mathbb{E}[\log D(x)] + \mathbb{E}[\log(1 - D(G(z)))]\), learning to tell real from fake, while the generator minimizes \(\mathbb{E}[\log(1 - D(G(z)))]\), learning to fool the discriminator. Observing this minimax game in 1D makes it easy to visualize how the generated distribution gradually aligns with the target: a key insight before moving to high-dimensional image generation.
# Generate 1D data: Mixture of Gaussians
def generate_real_data(n_samples=1000):
"""Generate samples from mixture of 3 Gaussians"""
# Mix components
components = np.random.choice(3, n_samples, p=[0.3, 0.5, 0.2])
samples = np.zeros(n_samples)
samples[components == 0] = np.random.normal(-2, 0.5, (components == 0).sum())
samples[components == 1] = np.random.normal(0, 0.8, (components == 1).sum())
samples[components == 2] = np.random.normal(3, 0.3, (components == 2).sum())
return torch.FloatTensor(samples).unsqueeze(1)
# Visualize real data
real_data = generate_real_data(5000)
plt.figure(figsize=(10, 4))
plt.hist(real_data.numpy(), bins=100, density=True, alpha=0.7, edgecolor='black')
plt.title('Real Data Distribution: Mixture of 3 Gaussians', fontsize=13, fontweight='bold')
plt.xlabel('x')
plt.ylabel('Density')
plt.grid(True, alpha=0.3)
plt.show()
# Define Generator
class Generator(nn.Module):
def __init__(self, noise_dim=10, hidden_dim=128):
super(Generator, self).__init__()
self.net = nn.Sequential(
nn.Linear(noise_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1) # Output: 1D
)
def forward(self, z):
return self.net(z)
# Define Discriminator
class Discriminator(nn.Module):
def __init__(self, hidden_dim=128):
super(Discriminator, self).__init__()
self.net = nn.Sequential(
nn.Linear(1, hidden_dim), # Input: 1D
nn.LeakyReLU(0.2),
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(0.2),
nn.Linear(hidden_dim, 1),
nn.Sigmoid() # Output: probability
)
def forward(self, x):
return self.net(x)
# Initialize models
noise_dim = 10
G = Generator(noise_dim=noise_dim).to(device)
D = Discriminator().to(device)
print("Generator:")
print(G)
print(f"\nTotal parameters: {sum(p.numel() for p in G.parameters())}")
print("\nDiscriminator:")
print(D)
print(f"\nTotal parameters: {sum(p.numel() for p in D.parameters())}")
# Training function
def train_gan(G, D, n_epochs=5000, batch_size=256, lr=0.0002):
"""Train GAN with original loss"""
# Optimizers
optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
# Loss function
criterion = nn.BCELoss()
# Training history
history = {'D_loss': [], 'G_loss': [], 'D_real': [], 'D_fake': []}
for epoch in range(n_epochs):
# Generate real data batch
real_data = generate_real_data(batch_size).to(device)
# Labels
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
# =====================
# Train Discriminator
# =====================
optimizer_D.zero_grad()
# Real data
D_real = D(real_data)
loss_D_real = criterion(D_real, real_labels)
# Fake data
z = torch.randn(batch_size, noise_dim).to(device)
fake_data = G(z).detach() # Detach to avoid training G
D_fake = D(fake_data)
loss_D_fake = criterion(D_fake, fake_labels)
# Total D loss
loss_D = loss_D_real + loss_D_fake
loss_D.backward()
optimizer_D.step()
# ==================
# Train Generator
# ==================
optimizer_G.zero_grad()
# Generate fake data
z = torch.randn(batch_size, noise_dim).to(device)
fake_data = G(z)
D_fake = D(fake_data)
# G loss: fool D (D_fake should be close to 1)
loss_G = criterion(D_fake, real_labels)
loss_G.backward()
optimizer_G.step()
# Record history
history['D_loss'].append(loss_D.item())
history['G_loss'].append(loss_G.item())
history['D_real'].append(D_real.mean().item())
history['D_fake'].append(D_fake.mean().item())
# Print progress
if (epoch + 1) % 500 == 0:
print(f"Epoch [{epoch+1}/{n_epochs}] | "
f"D Loss: {loss_D.item():.4f} | G Loss: {loss_G.item():.4f} | "
f"D(real): {D_real.mean():.4f} | D(fake): {D_fake.mean():.4f}")
return history
# Train GAN
print("Training GAN...")
history = train_gan(G, D, n_epochs=5000, batch_size=256)
print("\n✓ Training complete!")
# Plot training history
fig, axes = plt.subplots(1, 2, figsize=(14, 4))
# Loss curves
axes[0].plot(history['D_loss'], label='D Loss', alpha=0.7)
axes[0].plot(history['G_loss'], label='G Loss', alpha=0.7)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Losses', fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# Discriminator outputs
axes[1].plot(history['D_real'], label='D(real)', alpha=0.7)
axes[1].plot(history['D_fake'], label='D(fake)', alpha=0.7)
axes[1].axhline(y=0.5, color='r', linestyle='--', alpha=0.5, label='Equilibrium')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Discriminator Output')
axes[1].set_title('Discriminator Performance', fontweight='bold')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# Generate and visualize samples
G.eval()
with torch.no_grad():
z = torch.randn(5000, noise_dim).to(device)
generated_data = G(z).cpu()
# Plot real vs generated
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.hist(real_data.numpy(), bins=100, density=True, alpha=0.7, label='Real', edgecolor='black')
plt.title('Real Data Distribution', fontsize=13, fontweight='bold')
plt.xlabel('x')
plt.ylabel('Density')
plt.legend()
plt.grid(True, alpha=0.3)
plt.subplot(1, 2, 2)
plt.hist(generated_data.numpy(), bins=100, density=True, alpha=0.7, color='orange', label='Generated', edgecolor='black')
plt.title('Generated Data Distribution', fontsize=13, fontweight='bold')
plt.xlabel('x')
plt.ylabel('Density')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("\n✓ Generator successfully learned the mixture distribution!")
4. Training Challenges
Common Issues
Mode Collapse: Generator produces limited variety
Non-Convergence: Oscillating losses
Vanishing Gradients: G can't learn when D is too strong
Hyperparameter Sensitivity: lr, architecture choices
Alternative Loss: Non-Saturating Loss
Problem: Original G loss saturates when D is confident: $\(\nabla_G \log(1 - D(G(z))) \rightarrow 0 \text{ when } D(G(z)) \rightarrow 0\)$
Solution: Maximize \(\log D(G(z))\) instead: $\(\max_G \mathbb{E}_{z \sim p_z}[\log D(G(z))]\)$
Stronger gradients early in training!
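The saturation is easy to verify with autograd. This sketch compares the gradient of the two generator losses with respect to \(D(G(z))\): when the discriminator confidently rejects fakes (\(D(G(z)) \approx 0\)), the saturating loss has gradient \(-1/(1-D) \approx -1\), while the non-saturating loss has gradient \(-1/D\), which is large in magnitude:

```python
import torch

# D(G(z)) from near 0 (D confidently rejects fakes) up to near 1
d_out = torch.linspace(0.01, 0.99, 5, requires_grad=True)

# Saturating generator loss: minimize log(1 - D(G(z)))
grad_sat = torch.autograd.grad(torch.log(1 - d_out).sum(), d_out)[0]

# Non-saturating generator loss: minimize -log D(G(z))
d_out2 = d_out.detach().clone().requires_grad_(True)
grad_ns = torch.autograd.grad((-torch.log(d_out2)).sum(), d_out2)[0]

# At D(G(z)) = 0.01: d/dD log(1-D) = -1/(1-D) ≈ -1.01 (weak signal),
# while d/dD (-log D) = -1/D ≈ -100 (strong signal)
print(grad_sat)
print(grad_ns)
```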
4.5. Advanced Training Techniques
Technique 1: Label Smoothing
Replace hard labels with smooth labels:
Real labels: \(0.9\) instead of \(1.0\)
Fake labels: \(0.1\) instead of \(0.0\)
Benefit: Prevents discriminator overconfidence
Technique 2: One-Sided Label Smoothing
Only smooth real labels (recommended):
Real: \(0.9\)
Fake: \(0.0\)
Rationale: Smoothing fake labels can lead to unstable generator
Technique 3: Feature Matching
Instead of maximizing \(\log D(G(z))\), match feature statistics: $\(\min_G \left\| \mathbb{E}_{x \sim p_{data}}[f(x)] - \mathbb{E}_{z \sim p_z}[f(G(z))] \right\|_2^2\)$
where \(f(x)\) are features from an intermediate layer of \(D\)
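A minimal sketch of the feature-matching objective. The feature extractor below is a hypothetical stand-in (a small MLP over the 1D data used earlier); in practice one reuses the activations of an intermediate layer of \(D\):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for an intermediate layer of D (1D data, as above)
feature_extractor = nn.Sequential(nn.Linear(1, 64), nn.LeakyReLU(0.2))

def feature_matching_loss(real_data, fake_data):
    """|| E[f(x)] - E[f(G(z))] ||^2 over batch-mean features."""
    f_real = feature_extractor(real_data).mean(dim=0).detach()  # no grad through real path
    f_fake = feature_extractor(fake_data).mean(dim=0)
    return ((f_real - f_fake) ** 2).sum()

real = torch.randn(256, 1)
fake = torch.randn(256, 1) + 2.0  # generator output with wrong statistics
print(feature_matching_loss(real, fake))  # > 0; shrinks as statistics align
print(feature_matching_loss(real, real))  # exactly 0 for identical batches
```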
Technique 4: Minibatch Discrimination
Add a term to discriminator that considers relationships between samples in a batch:
Prevents mode collapse by encouraging diversity
Discriminator detects "collapse" when all samples in a batch are similar
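A simplified sketch of the idea, using a minibatch standard-deviation feature (a lighter variant of minibatch discrimination; the function name is ours). If the generator collapses, every sample in the batch looks alike and this extra feature drops to zero, giving the discriminator an easy signal:

```python
import torch

torch.manual_seed(0)

def minibatch_std_feature(x):
    """Append the mean across-batch std as one extra feature per sample."""
    std = x.std(dim=0).mean()                 # scalar batch statistic
    std_col = std * torch.ones(x.size(0), 1)  # broadcast to every sample
    return torch.cat([x, std_col], dim=1)

diverse = torch.randn(8, 4)            # varied batch
collapsed = torch.full((8, 4), 0.5)    # identical batch (mode collapse)

print(minibatch_std_feature(diverse)[0, -1])    # clearly > 0
print(minibatch_std_feature(collapsed)[0, -1])  # tensor(0.)
```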
# Implementation: Label Smoothing and Advanced Training
class ImprovedGANTrainer:
"""
Advanced GAN training with multiple stabilization techniques
"""
def __init__(self, G, D, lr=0.0002, betas=(0.5, 0.999)):
self.G = G
self.D = D
self.optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=betas)
self.optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=betas)
self.criterion = nn.BCELoss()
def train_step(self, real_data, noise_dim,
label_smoothing=True, smooth_real=0.9, smooth_fake=0.0):
"""
Single training step with advanced techniques
Args:
real_data: Batch of real samples
noise_dim: Dimension of noise vector
label_smoothing: Whether to use label smoothing
smooth_real: Label value for real samples (default 0.9)
smooth_fake: Label value for fake samples (default 0.0)
"""
batch_size = real_data.size(0)
device = real_data.device
# Create labels with optional smoothing
if label_smoothing:
real_labels = torch.full((batch_size, 1), smooth_real, device=device)
fake_labels = torch.full((batch_size, 1), smooth_fake, device=device)
else:
real_labels = torch.ones(batch_size, 1, device=device)
fake_labels = torch.zeros(batch_size, 1, device=device)
# =====================
# Train Discriminator
# =====================
self.optimizer_D.zero_grad()
# Real samples
D_real = self.D(real_data)
loss_D_real = self.criterion(D_real, real_labels)
# Fake samples
z = torch.randn(batch_size, noise_dim, device=device)
fake_data = self.G(z).detach()
D_fake = self.D(fake_data)
loss_D_fake = self.criterion(D_fake, fake_labels)
# Combined loss
loss_D = loss_D_real + loss_D_fake
loss_D.backward()
self.optimizer_D.step()
# ==================
# Train Generator
# ==================
self.optimizer_G.zero_grad()
# Generate new fake samples
z = torch.randn(batch_size, noise_dim, device=device)
fake_data = self.G(z)
D_fake = self.D(fake_data)
# Generator loss (fool discriminator)
loss_G = self.criterion(D_fake, torch.ones(batch_size, 1, device=device))
loss_G.backward()
self.optimizer_G.step()
return {
'loss_D': loss_D.item(),
'loss_G': loss_G.item(),
'D_real': D_real.mean().item(),
'D_fake': D_fake.mean().item()
}
# Compare vanilla vs improved training
print("="*60)
print("COMPARISON: Vanilla vs Improved GAN Training")
print("="*60)
# Reset models
G_improved = Generator(noise_dim=10).to(device)
D_improved = Discriminator().to(device)
trainer = ImprovedGANTrainer(G_improved, D_improved)
# Train for fewer epochs to demonstrate improvement
n_epochs = 2000
history_improved = {'D_loss': [], 'G_loss': [], 'D_real': [], 'D_fake': []}
print("\nTraining with label smoothing and advanced techniques...")
for epoch in range(n_epochs):
real_batch = generate_real_data(256).to(device)
# Train with label smoothing
metrics = trainer.train_step(real_batch, noise_dim=10,
label_smoothing=True,
smooth_real=0.9,
smooth_fake=0.0)
history_improved['D_loss'].append(metrics['loss_D'])
history_improved['G_loss'].append(metrics['loss_G'])
history_improved['D_real'].append(metrics['D_real'])
history_improved['D_fake'].append(metrics['D_fake'])
if (epoch + 1) % 500 == 0:
print(f"Epoch [{epoch+1}/{n_epochs}] | "
f"D Loss: {metrics['loss_D']:.4f} | "
f"G Loss: {metrics['loss_G']:.4f} | "
f"D(real): {metrics['D_real']:.4f} | "
f"D(fake): {metrics['D_fake']:.4f}")
print("\n✓ Improved training complete!")
print("\nKey Differences:")
print(" - Label smoothing prevents discriminator overconfidence")
print(" - More stable training dynamics")
print(" - Better gradient flow to generator")
5. Summary
Key Mathematical Results
✓ Optimal Discriminator: \(D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}\)
✓ Global Optimum: \(p_g = p_{data}\) minimizes JS divergence
✓ Value Function: Minimax game with theoretical guarantees
✓ Training Dynamics: Alternating optimization
Practical Implementation
✓ Implemented vanilla GAN from scratch
✓ Trained on 1D mixture distribution
✓ Visualized learning dynamics
✓ Understood training challenges
Next Topics
W-GAN: Wasserstein distance for better stability
Info-GAN: Interpretable latent codes
Conditional GAN: Class-conditional generation
Progressive GAN: High-resolution images
Next Notebook: 02_wgan_theory.ipynb