GAN Mathematics: Comprehensive Theory

Introduction

Generative Adversarial Networks (GANs), introduced by Ian Goodfellow et al. in 2014, revolutionized generative modeling by framing it as a two-player minimax game between a generator and a discriminator. This adversarial framework enables learning complex data distributions without explicit density estimation.

Core Idea: Train two neural networks in opposition:

  • Generator \(G\): Learns to create fake samples that resemble real data

  • Discriminator \(D\): Learns to distinguish real samples from fake ones

Through this competition, \(G\) improves at generating realistic samples, while \(D\) becomes better at detecting fakes. At equilibrium, \(G\) produces samples indistinguishable from real data.

Mathematical Formulation

The Minimax Game

GANs are defined by the following minimax objective:

\[\min_G \max_D V(D, G) = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}(\mathbf{x})}[\log D(\mathbf{x})] + \mathbb{E}_{\mathbf{z} \sim p_z(\mathbf{z})}[\log(1 - D(G(\mathbf{z})))]\]

Where:

  • \(\mathbf{x}\): Real data samples

  • \(\mathbf{z}\): Latent noise vector (typically \(\mathbf{z} \sim \mathcal{N}(0, I)\))

  • \(G(\mathbf{z})\): Generated sample from noise \(\mathbf{z}\)

  • \(D(\mathbf{x}) \in [0, 1]\): Probability that \(\mathbf{x}\) is real

  • \(p_{\text{data}}\): True data distribution

  • \(p_z\): Prior distribution over latent space (usually Gaussian or uniform)

  • \(p_g\): Distribution induced by generator \(G\)

Interpretation

Discriminator’s Goal (maximization):

\[\max_D \left[\underbrace{\mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}[\log D(\mathbf{x})]}_{\text{correctly classify real}} + \underbrace{\mathbb{E}_{\mathbf{z} \sim p_z}[\log(1 - D(G(\mathbf{z})))]}_{\text{correctly classify fake}}\right]\]

  • Assign \(D(\mathbf{x}) \approx 1\) for real samples

  • Assign \(D(G(\mathbf{z})) \approx 0\) for fake samples

Generator’s Goal (minimization):

\[\min_G \mathbb{E}_{\mathbf{z} \sim p_z}[\log(1 - D(G(\mathbf{z})))]\]

  • Make \(D(G(\mathbf{z})) \approx 1\) (fool the discriminator)

  • Equivalent to making generated samples indistinguishable from real

Optimal Discriminator

Proposition 1: Optimal Discriminator for Fixed Generator

For a fixed generator \(G\), the optimal discriminator is:

\[D^*_G(\mathbf{x}) = \frac{p_{\text{data}}(\mathbf{x})}{p_{\text{data}}(\mathbf{x}) + p_g(\mathbf{x})}\]

Proof:

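The objective for \(D\) can be rewritten by expressing the second expectation over \(\mathbf{x} = G(\mathbf{z}) \sim p_g\):

\[\begin{aligned} V(D, G) &= \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}[\log D(\mathbf{x})] + \mathbb{E}_{\mathbf{x} \sim p_g}[\log(1 - D(\mathbf{x}))] \\ &= \int_{\mathbf{x}} \left[p_{\text{data}}(\mathbf{x}) \log D(\mathbf{x}) + p_g(\mathbf{x}) \log(1 - D(\mathbf{x}))\right] d\mathbf{x} \end{aligned}\]

The integrand can be maximized separately for each \(\mathbf{x}\). For \(a, b \geq 0\), not both zero, the function \(y \mapsto a \log y + b \log(1 - y)\) is concave on \((0, 1)\), so its stationary point is the maximum. Take the derivative with respect to \(D(\mathbf{x})\) and set it to zero: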

\[\frac{\partial}{\partial D(\mathbf{x})} \left[p_{\text{data}}(\mathbf{x}) \log D(\mathbf{x}) + p_g(\mathbf{x}) \log(1 - D(\mathbf{x}))\right] = 0\]
\[\frac{p_{\text{data}}(\mathbf{x})}{D(\mathbf{x})} - \frac{p_g(\mathbf{x})}{1 - D(\mathbf{x})} = 0\]

Solving for \(D(\mathbf{x})\):

\[D^*_G(\mathbf{x}) = \frac{p_{\text{data}}(\mathbf{x})}{p_{\text{data}}(\mathbf{x}) + p_g(\mathbf{x})}\]

Interpretation:

  • If \(p_{\text{data}}(\mathbf{x}) \gg p_g(\mathbf{x})\): \(D^*_G(\mathbf{x}) \approx 1\) (likely real)

  • If \(p_g(\mathbf{x}) \gg p_{\text{data}}(\mathbf{x})\): \(D^*_G(\mathbf{x}) \approx 0\) (likely fake)

  • If \(p_{\text{data}}(\mathbf{x}) = p_g(\mathbf{x})\): \(D^*_G(\mathbf{x}) = 0.5\) (perfectly confused)
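As a quick numerical sanity check of Proposition 1, the sketch below maximizes the pointwise integrand \(p_{\text{data}} \log D + p_g \log(1 - D)\) by brute-force grid search and compares the result with the closed form. The density values are arbitrary illustrative numbers, not taken from any real model.

```python
import numpy as np

def pointwise_value(d, p_data, p_g):
    """Integrand of V(D, G) at a single x: p_data*log(D) + p_g*log(1 - D)."""
    return p_data * np.log(d) + p_g * np.log(1.0 - d)

def optimal_d(p_data, p_g):
    """Closed form from Proposition 1: D*_G(x) = p_data / (p_data + p_g)."""
    return p_data / (p_data + p_g)

# Dense grid of candidate discriminator outputs strictly inside (0, 1)
grid = np.linspace(1e-4, 1 - 1e-4, 100_001)

# Illustrative density values (p_data(x), p_g(x)) at a few hypothetical points x
for p_data, p_g in [(0.9, 0.1), (0.2, 0.6), (0.5, 0.5)]:
    d_grid = grid[np.argmax(pointwise_value(grid, p_data, p_g))]
    print(f"p_data={p_data}, p_g={p_g}: grid search {d_grid:.4f}, "
          f"closed form {optimal_d(p_data, p_g):.4f}")
```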

Virtual Training Criterion

Proposition 2: Generator Training with Optimal Discriminator

Substituting \(D^*_G\) into the objective yields:

\[C(G) = \max_D V(D, G) = V(D^*_G, G)\]

Expanding:

\[C(G) = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}\left[\log \frac{p_{\text{data}}(\mathbf{x})}{p_{\text{data}}(\mathbf{x}) + p_g(\mathbf{x})}\right] + \mathbb{E}_{\mathbf{x} \sim p_g}\left[\log \frac{p_g(\mathbf{x})}{p_{\text{data}}(\mathbf{x}) + p_g(\mathbf{x})}\right]\]

Rewrite using the average distribution \(p_m = \frac{p_{\text{data}} + p_g}{2}\):

\[\begin{aligned} C(G) &= \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}\left[\log \frac{p_{\text{data}}(\mathbf{x})}{2p_m(\mathbf{x})}\right] + \mathbb{E}_{\mathbf{x} \sim p_g}\left[\log \frac{p_g(\mathbf{x})}{2p_m(\mathbf{x})}\right] \\ &= \left(\text{KL}(p_{\text{data}} \| p_m) - \log 2\right) + \left(\text{KL}(p_g \| p_m) - \log 2\right) \\ &= -\log 4 + \text{KL}(p_{\text{data}} \| p_m) + \text{KL}(p_g \| p_m) \\ &= -\log 4 + 2 \cdot \text{JSD}(p_{\text{data}} \| p_g) \end{aligned}\]

Where \(\text{JSD}\) is the Jensen-Shannon divergence:

\[\text{JSD}(p \| q) = \frac{1}{2}\text{KL}(p \| m) + \frac{1}{2}\text{KL}(q \| m), \quad m = \frac{p + q}{2}\]

Properties of JSD:

  1. Non-negative: \(\text{JSD}(p \| q) \geq 0\)

  2. Symmetric: \(\text{JSD}(p \| q) = \text{JSD}(q \| p)\)

  3. Bounded: \(0 \leq \text{JSD}(p \| q) \leq \log 2\)

  4. Zero iff equal: \(\text{JSD}(p \| q) = 0 \iff p = q\)
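The identity \(C(G) = -\log 4 + 2\,\text{JSD}(p_{\text{data}} \| p_g)\) can be checked numerically on small discrete distributions. This sketch assumes both distributions have full support on a 10-point grid, so every logarithm is finite. The random Dirichlet draws are purely illustrative.

```python
import numpy as np

def kl(p, q):
    """KL(p || q) for discrete distributions with q > 0 wherever p > 0."""
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def jsd(p, q):
    """Jensen-Shannon divergence with the mixture m = (p + q) / 2."""
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def virtual_criterion(p_data, p_g):
    """C(G) = V(D*_G, G), evaluated directly (assumes both have full support)."""
    d_star = p_data / (p_data + p_g)
    return float(np.sum(p_data * np.log(d_star)) + np.sum(p_g * np.log(1.0 - d_star)))

rng = np.random.default_rng(0)
p_data = rng.dirichlet(np.ones(10))  # a random "data" distribution over 10 points
p_g = rng.dirichlet(np.ones(10))     # a random "generator" distribution

print(virtual_criterion(p_data, p_g))         # direct evaluation with D*_G
print(-np.log(4) + 2 * jsd(p_data, p_g))      # closed form from the derivation
print(virtual_criterion(p_data, p_data))      # p_g = p_data gives -log 4
```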

Global Optimality

Theorem: Global Minimum of \(C(G)\)

Statement: The global minimum of \(C(G)\) is achieved if and only if \(p_g = p_{\text{data}}\). At this point:

\[C(G^*) = -\log 4\]

And the optimal discriminator becomes:

\[D^*_{G^*}(\mathbf{x}) = \frac{1}{2} \quad \forall \mathbf{x}\]

Proof:

From the Jensen-Shannon divergence formulation:

\[C(G) = -\log 4 + 2 \cdot \text{JSD}(p_{\text{data}} \| p_g)\]

Since \(\text{JSD} \geq 0\) with equality iff \(p_{\text{data}} = p_g\):

\[C(G) \geq -\log 4\]

Equality holds when \(p_g = p_{\text{data}}\), giving:

\[C(G^*) = -\log 4 \approx -1.386\]

Interpretation:

  • When \(G\) perfectly matches the data distribution, \(D\) cannot do better than random guessing (50% accuracy)

  • The game reaches a Nash equilibrium

Nash Equilibrium

Definition

A pair \((G^*, D^*)\) is a Nash equilibrium if:

  1. \(D^* \in \arg\max_D V(D, G^*)\) (optimal discriminator for \(G^*\))

  2. \(G^* \in \arg\min_G V(D^*, G)\) (optimal generator against \(D^*\))

Theorem: The Nash equilibrium is achieved when:

  • \(p_g = p_{\text{data}}\)

  • \(D^*(\mathbf{x}) = \frac{1}{2}\) for all \(\mathbf{x}\)

Training Dynamics

Algorithm: GAN Training

For each iteration:

  1. Update Discriminator (k steps):

    • Sample minibatch of \(m\) real samples: \(\{\mathbf{x}^{(1)}, \ldots, \mathbf{x}^{(m)}\} \sim p_{\text{data}}\)

    • Sample minibatch of \(m\) noise samples: \(\{\mathbf{z}^{(1)}, \ldots, \mathbf{z}^{(m)}\} \sim p_z\)

    • Update discriminator by ascending its stochastic gradient: \[\nabla_{\theta_d} \frac{1}{m} \sum_{i=1}^m \left[\log D(\mathbf{x}^{(i)}) + \log(1 - D(G(\mathbf{z}^{(i)})))\right]\]

  2. Update Generator (1 step):

    • Sample minibatch of \(m\) noise samples: \(\{\mathbf{z}^{(1)}, \ldots, \mathbf{z}^{(m)}\} \sim p_z\)

    • Update generator by descending its stochastic gradient: \[\nabla_{\theta_g} \frac{1}{m} \sum_{i=1}^m \log(1 - D(G(\mathbf{z}^{(i)})))\]

Hyperparameter \(k\): Number of discriminator updates per generator update (typically \(k=1\)).
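The algorithm above can be sketched in PyTorch on a toy problem. Every specific choice here is an illustrative assumption: 1-D data from \(\mathcal{N}(4, 1.25^2)\), small MLPs, Adam with learning rate \(10^{-3}\), \(m = 64\), \(k = 1\), and a small `eps` inside the logarithms for numerical safety. The generator step uses the original \(\log(1 - D(G(\mathbf{z})))\) objective exactly as written. A non-saturating variant would instead minimize \(-\log D(G(\mathbf{z}))\).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def sample_data(m: int) -> torch.Tensor:
    """Hypothetical 1-D data distribution: N(4, 1.25^2)."""
    return 4.0 + 1.25 * torch.randn(m, 1)

G = nn.Sequential(nn.Linear(1, 16), nn.ReLU(), nn.Linear(16, 1))
D = nn.Sequential(nn.Linear(1, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1), nn.Sigmoid())
opt_d = torch.optim.Adam(D.parameters(), lr=1e-3, betas=(0.5, 0.999))
opt_g = torch.optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.999))
m, k, eps = 64, 1, 1e-8  # minibatch size, D steps per G step, log stabilizer

def train_iteration():
    # 1. k discriminator steps: ascend mean[log D(x) + log(1 - D(G(z)))]
    #    (implemented as gradient descent on the negative)
    for _ in range(k):
        x = sample_data(m)
        z = torch.randn(m, 1)
        fake = G(z).detach()  # no generator gradients during the D step
        d_obj = torch.log(D(x) + eps).mean() + torch.log(1 - D(fake) + eps).mean()
        opt_d.zero_grad()
        (-d_obj).backward()
        opt_d.step()

    # 2. One generator step: descend mean[log(1 - D(G(z)))]
    z = torch.randn(m, 1)
    g_obj = torch.log(1 - D(G(z)) + eps).mean()
    opt_g.zero_grad()   # D also receives grads here; opt_d.zero_grad() clears them later
    g_obj.backward()
    opt_g.step()
    return d_obj.item(), g_obj.item()

for _ in range(200):
    d_val, g_val = train_iteration()
print(f"last D objective {d_val:.3f}, last G objective {g_val:.3f}")
```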

Non-Saturating Generator Loss

Problem with Original Loss:

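When \(D\) confidently rejects fakes (\(D(G(\mathbf{z})) \approx 0\)), which is typical early in training, the gradient of \(\log(1 - D(G(\mathbf{z})))\) becomes very small. Write \(D = \sigma(a)\), where \(a\) is the discriminator's pre-sigmoid logit. Then:

\[\frac{\partial}{\partial a} \log(1 - \sigma(a)) = -\sigma(a) = -D\]

When \(D \approx 0\), this gradient vanishes. As a result, \(G\) receives almost no learning signal exactly when its samples are worst.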

Non-Saturating Alternative:

Instead of minimizing \(\log(1 - D(G(\mathbf{z})))\), maximize \(\log D(G(\mathbf{z}))\):

\[\max_G \mathbb{E}_{\mathbf{z} \sim p_z}[\log D(G(\mathbf{z}))]\]

Gradient Comparison (with respect to the discriminator logit \(a\), where \(D = \sigma(a)\)):

  • Original: \(\frac{\partial}{\partial a} \log(1 - D) = -D\) (vanishes when \(D \approx 0\))

  • Non-saturating: \(\frac{\partial}{\partial a} \log D = 1 - D\) (close to 1 when \(D \approx 0\))

Trade-off:

  • Non-saturating loss provides stronger gradients early

  • It keeps the same fixed point (\(p_g = p_{\text{data}}\)). However, the game is no longer zero-sum, so the minimax/JSD analysis above does not apply to it directly
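The saturation effect is easy to see numerically. This sketch assumes \(D = \sigma(a)\) with pre-sigmoid logit \(a\). It compares the analytic gradients of both generator objectives with respect to \(a\) at a few illustrative logits, and checks them by central finite differences.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

# Discriminator logits on fake samples; very negative = D confidently says "fake"
logits = np.array([-8.0, -4.0, 0.0, 4.0])
d = sigmoid(logits)

# Analytic gradients w.r.t. the logit a, where D = sigmoid(a):
#   d/da log(1 - sigmoid(a)) = -sigmoid(a) = -D       (minimax / saturating)
#   d/da log(sigmoid(a))     = 1 - sigmoid(a) = 1 - D (non-saturating)
grad_saturating = -d
grad_nonsaturating = 1.0 - d

# Central finite differences as an independent check
h = 1e-6
fd_sat = (np.log(1 - sigmoid(logits + h)) - np.log(1 - sigmoid(logits - h))) / (2 * h)
fd_ns = (np.log(sigmoid(logits + h)) - np.log(sigmoid(logits - h))) / (2 * h)

for a, gs, gn in zip(logits, grad_saturating, grad_nonsaturating):
    print(f"logit {a:+.0f}: saturating {gs:+.5f}, non-saturating {gn:+.5f}")
```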

Mode Collapse

Definition

Mode collapse occurs when the generator produces a limited variety of samples, concentrating on a few modes of \(p_{\text{data}}\) while ignoring others.

Types

  1. Complete collapse: \(G\) produces a single sample regardless of \(\mathbf{z}\)

  2. Partial collapse: \(G\) covers some modes but misses others

  3. Oscillation: \(G\) cycles through modes during training

Causes

  1. Missing Mass: \(G\) finds a region where \(D\) is weak and concentrates its samples there

  2. Gradient Pathologies: Local minima that don’t correspond to the global optimum

  3. Training Imbalance: \(D\) too strong → \(G\) gets stuck; \(D\) too weak → poor guidance

Mitigation Strategies

  1. Unrolled GANs: Update \(G\) considering \(k\) future updates of \(D\)

  2. Minibatch Discrimination: Penalize \(G\) if minibatch samples are too similar

  3. Feature Matching: Match statistics of intermediate layers instead of fooling \(D\)

  4. Experience Replay: Store past generated samples, retrain \(D\) on them

  5. Multiple GANs: Ensemble of generators covering different modes
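One common form of feature matching comes from Salimans et al. (2016), "Improved Techniques for Training GANs". It trains \(G\) to match the mean of an intermediate discriminator layer's features on real and generated minibatches, \(\|\mathbb{E}_{\mathbf{x}} f(\mathbf{x}) - \mathbb{E}_{\mathbf{z}} f(G(\mathbf{z}))\|_2^2\). The NumPy sketch below computes this loss from precomputed feature arrays. Extracting \(f\) from \(D\) is left out, and the random features are placeholders.

```python
import numpy as np

def feature_matching_loss(real_features, fake_features):
    """Squared L2 distance between minibatch means of intermediate D features.

    real_features, fake_features: (batch, feature_dim) arrays holding f(x) and
    f(G(z)) for some intermediate discriminator layer f.
    """
    diff = real_features.mean(axis=0) - fake_features.mean(axis=0)
    return float(np.sum(diff ** 2))

rng = np.random.default_rng(0)
real = rng.normal(size=(128, 16))                # placeholder features of real samples
print(feature_matching_loss(real, real))         # 0.0: identical statistics
print(feature_matching_loss(real, real + 0.5))   # mean shift of 0.5 in all 16 dims
```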

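Theoretical Challenges

1. Vanishing Gradients

Problem: If \(D\) is (near-)optimal, the gradient for \(G\) can vanish.

Sketch: When \(p_g\) and \(p_{\text{data}}\) have disjoint supports,

\[\text{JSD}(p_{\text{data}} \| p_g) = \log 2\]

holds for every such \(p_g\). Therefore \(C(G) = -\log 4 + 2\log 2 = 0\) is constant, and its gradient with respect to \(G\) is zero. Equivalently, \(D^*\) can separate real from fake perfectly (zero loss), leaving no gradient signal for \(G\).

Solution: Add noise to the inputs (instance noise), or use a different divergence, as the Wasserstein GAN does.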
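The disjoint-support argument can be checked on a discrete grid. In this sketch the grid size and block positions are arbitrary. The generator's mass is moved toward the data while its support stays disjoint. Throughout, JSD remains at \(\log 2\) and \(C(G)\) at 0, so the criterion gives no signal until the supports overlap.

```python
import numpy as np

def kl(a, b):
    """KL(a || b) for discrete distributions; terms with a = 0 contribute 0."""
    mask = a > 0
    return float(np.sum(a[mask] * np.log(a[mask] / b[mask])))

def jsd(p, q):
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

n = 100
p_data = np.zeros(n)
p_data[:10] = 0.1          # data mass spread uniformly over grid points 0..9

def block(start):
    """Generator distribution: uniform mass on grid points start..start+9."""
    p = np.zeros(n)
    p[start:start + 10] = 0.1
    return p

for start in [90, 60, 30, 10, 5]:   # move p_g toward the data; start=5 finally overlaps
    j = jsd(p_data, block(start))
    print(f"start {start:2d}: JSD = {j:.6f}, C(G) = {-np.log(4) + 2 * j:.6f}")
```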

2. Non-Convergence

Problem: Training may oscillate indefinitely without reaching equilibrium.

Example: Consider a simple case where \(G\) and \(D\) update in cycles:

  • \(G\) moves to fool \(D\)

  • \(D\) adapts to new \(G\)

  • \(G\) shifts again

  • Repeat (never converges)

Mitigation: Careful learning rate selection, momentum, regularization.
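A standard toy illustration of non-convergence is the bilinear game \(\min_x \max_y xy\), whose only equilibrium is \((0, 0)\). Take simultaneous gradient steps of size \(\eta\). Each update multiplies \(x^2 + y^2\) by exactly \(1 + \eta^2\), so the iterates spiral away from the equilibrium instead of approaching it. The starting point and \(\eta = 0.1\) below are arbitrary.

```python
import numpy as np

def simultaneous_gda(x, y, lr, steps):
    """Simultaneous gradient descent on x and ascent on y for f(x, y) = x * y."""
    traj = [(x, y)]
    for _ in range(steps):
        grad_x, grad_y = y, x                      # df/dx = y, df/dy = x
        x, y = x - lr * grad_x, y + lr * grad_y    # both players move at once
        traj.append((x, y))
    return np.array(traj)

traj = simultaneous_gda(1.0, 1.0, lr=0.1, steps=100)
radius = np.linalg.norm(traj, axis=1)   # distance to the equilibrium (0, 0)
print(f"start {radius[0]:.4f}, after 100 steps {radius[-1]:.4f}")
```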

Evaluation Metrics

Inception Score (IS)

\[\text{IS}(G) = \exp\left(\mathbb{E}_{\mathbf{x} \sim p_g}\left[D_{\text{KL}}(p(y|\mathbf{x}) \| p(y))\right]\right)\]

Where:

  • \(p(y|\mathbf{x})\): Class probabilities from Inception network

  • \(p(y) = \mathbb{E}_{\mathbf{x} \sim p_g}[p(y|\mathbf{x})]\): Marginal distribution

Interpretation:

  • High score: Generated images are diverse (high \(H(y)\)) and high-quality (low \(H(y|\mathbf{x})\))

Limitations:

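  • Relies on an ImageNet-trained classifier, so it is meaningful mainly for images resembling ImageNet classes; it never compares against real data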

  • Can be fooled (generate one high-quality image per class)
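Once the classifier's predictions are available, IS reduces to a few array operations. This sketch assumes `probs` already holds the softmax outputs \(p(y|\mathbf{x})\) for the generated samples. Running the Inception network itself is out of scope. Synthetic probability matrices show the two extremes: IS lies between 1 and the number of classes \(K\).

```python
import numpy as np

def inception_score(probs, eps=1e-12):
    """IS = exp(E_x[KL(p(y|x) || p(y))]) from a (num_samples, num_classes) matrix."""
    p_y = probs.mean(axis=0, keepdims=True)                        # marginal p(y)
    kl = np.sum(probs * (np.log(probs + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl.mean()))

K = 10
diverse_confident = np.eye(K)[np.arange(1000) % K]   # one-hot, all classes equally often
collapsed = np.tile(np.eye(K)[0], (1000, 1))          # every sample predicted as class 0
uncertain = np.full((1000, K), 1.0 / K)               # uniform p(y|x) for every sample

print(inception_score(diverse_confident))  # ≈ K, the maximum
print(inception_score(collapsed))          # ≈ 1
print(inception_score(uncertain))          # ≈ 1
```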

Fréchet Inception Distance (FID)

\[\text{FID} = \|\mu_r - \mu_g\|^2 + \text{Tr}\left(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2}\right)\]

Where:

  • \((\mu_r, \Sigma_r)\): Mean and covariance of real image features (from Inception pool3)

  • \((\mu_g, \Sigma_g)\): Mean and covariance of generated image features

Properties:

  • Lower is better

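  • More robust than IS: it compares generated features against real-data statistics instead of relying on class predictions alone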

  • Considers both quality and diversity
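The sketch below implements the FID formula in NumPy. It assumes the Inception pool3 features have already been extracted into two `(num_samples, dim)` arrays. The matrix square-root trace is computed through a symmetric eigendecomposition; reference implementations typically call `scipy.linalg.sqrtm` instead. The random features are placeholders, not real Inception activations.

```python
import numpy as np

def _sqrtm_psd(mat):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    w, v = np.linalg.eigh(mat)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T

def fid_from_features(feat_r, feat_g):
    """FID between two (num_samples, dim) feature arrays."""
    mu_r, mu_g = feat_r.mean(axis=0), feat_g.mean(axis=0)
    sigma_r = np.cov(feat_r, rowvar=False)
    sigma_g = np.cov(feat_g, rowvar=False)
    # Tr((S_r S_g)^{1/2}) = sum of sqrt-eigenvalues of S_r^{1/2} S_g S_r^{1/2},
    # which is symmetric PSD and has the same eigenvalues as S_r S_g
    sqrt_r = _sqrtm_psd(sigma_r)
    m = sqrt_r @ sigma_g @ sqrt_r
    eigvals = np.linalg.eigvalsh(0.5 * (m + m.T))   # symmetrize away round-off
    tr_covmean = np.sum(np.sqrt(np.clip(eigvals, 0.0, None)))
    mean_term = np.sum((mu_r - mu_g) ** 2)
    return float(mean_term + np.trace(sigma_r) + np.trace(sigma_g) - 2.0 * tr_covmean)

rng = np.random.default_rng(0)
real = rng.normal(size=(2000, 8))             # placeholder "features"
print(fid_from_features(real, real))           # ≈ 0: identical statistics
print(fid_from_features(real, real + 0.5))     # ≈ 8 * 0.5**2 = 2.0 (mean shift only)
print(fid_from_features(real, 2.0 * real))     # covariance mismatch also raises FID
```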

Variants and Extensions

1. DCGAN (Deep Convolutional GAN)

Architecture Guidelines:

  • Replace pooling with strided convolutions (discriminator) and fractional-strided convolutions (generator)

  • Use batch normalization in both \(G\) and \(D\) (except the \(G\) output layer and the \(D\) input layer)

  • Remove fully connected layers

  • Use ReLU in generator (except output: Tanh)

  • Use LeakyReLU in discriminator

Impact: Stabilized training, enabled high-quality image generation.

2. Conditional GAN (cGAN)

Objective:

\[\min_G \max_D V(D, G) = \mathbb{E}_{\mathbf{x},y}[\log D(\mathbf{x}, y)] + \mathbb{E}_{\mathbf{z},y}[\log(1 - D(G(\mathbf{z}, y), y))]\]

Where \(y\) is conditioning information (e.g., class label, text, image).

3. Wasserstein GAN (WGAN)

Objective (Earth Mover’s Distance):

\[\min_G \max_{D \in \mathcal{D}} \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}[D(\mathbf{x})] - \mathbb{E}_{\mathbf{z} \sim p_z}[D(G(\mathbf{z}))]\]

Where \(\mathcal{D}\) is the set of 1-Lipschitz functions.

Benefits:

  • Meaningful loss metric (correlates with sample quality)

  • Less prone to mode collapse in practice

  • Gradients remain informative even when \(p_g\) and \(p_{\text{data}}\) have disjoint supports

Implementation: Weight clipping or gradient penalty to enforce Lipschitz constraint.
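Weight clipping is the Lipschitz heuristic of the original WGAN. It clamps every critic parameter after each critic update. The sketch uses the clip value \(c = 0.01\) and RMSprop with learning rate \(5 \times 10^{-5}\) reported by Arjovsky et al. The critic architecture and the data batches are placeholders.

```python
import torch
import torch.nn as nn

def clip_critic_weights(critic: nn.Module, clip_value: float = 0.01) -> None:
    """Clamp every critic parameter to [-clip_value, clip_value] in place."""
    with torch.no_grad():
        for p in critic.parameters():
            p.clamp_(-clip_value, clip_value)

torch.manual_seed(0)
# Placeholder critic: real-valued output, no sigmoid
critic = nn.Sequential(nn.Linear(2, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))
opt = torch.optim.RMSprop(critic.parameters(), lr=5e-5)

real = torch.randn(64, 2) + 2.0   # placeholder "real" batch
fake = torch.randn(64, 2)         # placeholder "generated" batch

# The critic minimizes -(E[D(x)] - E[D(G(z))]), i.e. maximizes the Wasserstein estimate
loss = -critic(real).mean() + critic(fake).mean()
opt.zero_grad()
loss.backward()
opt.step()
clip_critic_weights(critic, clip_value=0.01)   # applied after every critic step
```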

4. Progressive GAN

Idea: Train GAN progressively from low to high resolution.

  • Start with 4×4 images

  • Gradually add layers for 8×8, 16×16, …, 1024×1024

Benefits: Faster training, higher quality, more stable.
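When a new resolution is added, Progressive GAN fades the new layers in smoothly rather than switching abruptly. The output blends the upsampled previous-resolution image with the new block's output, using a weight \(\alpha\) that grows from 0 to 1 during training. The NumPy sketch of that blend below assumes nearest-neighbour upsampling and single-channel placeholder arrays.

```python
import numpy as np

def upsample_nearest(img, factor=2):
    """Nearest-neighbour upsampling along the last two axes (H, W)."""
    return np.repeat(np.repeat(img, factor, axis=-2), factor, axis=-1)

def fade_in(old_low_res, new_high_res, alpha):
    """Blend during a resolution transition.

    alpha = 0: output is the upsampled previous-resolution image only;
    alpha = 1: output comes entirely from the newly added block.
    """
    return (1.0 - alpha) * upsample_nearest(old_low_res) + alpha * new_high_res

old = np.arange(16, dtype=float).reshape(4, 4)   # placeholder 4x4 output of the old network
new = np.ones((8, 8))                            # placeholder 8x8 output of the new block
for alpha in [0.0, 0.5, 1.0]:                    # alpha is ramped up over training
    print(alpha, fade_in(old, new, alpha)[0, :4])
```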

5. StyleGAN

Architecture: Style-based generator with adaptive instance normalization (AdaIN).

Innovations:

  • Mapping network: \(\mathbf{z} \rightarrow \mathbf{w}\) (disentangled latent space)

  • Modulation: Control style at each resolution

  • Noise injection: Stochastic variation

Result: State-of-the-art image quality, fine-grained control.
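AdaIN normalizes each feature map of each sample to zero mean and unit variance, then applies a per-channel scale and bias derived from the style. In the sketch below, the style parameters are passed in directly. StyleGAN computes them with a learned affine transform of \(\mathbf{w}\), which is omitted here. Shapes follow the \((N, C, H, W)\) convention.

```python
import numpy as np

def adain(x, style_scale, style_bias, eps=1e-5):
    """Adaptive instance normalization.

    x: (N, C, H, W) feature maps; style_scale, style_bias: (N, C) per-sample,
    per-channel style parameters.
    """
    mean = x.mean(axis=(2, 3), keepdims=True)     # statistics per sample and channel
    std = x.std(axis=(2, 3), keepdims=True)
    normalized = (x - mean) / (std + eps)
    return style_scale[:, :, None, None] * normalized + style_bias[:, :, None, None]

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(2, 4, 8, 8))   # placeholder feature maps
scale = rng.uniform(0.5, 2.0, size=(2, 4))              # placeholder style scales
bias = rng.normal(size=(2, 4))                          # placeholder style biases
out = adain(x, scale, bias)
# Each output map now carries the style's statistics: mean ≈ bias, std ≈ scale
print(np.abs(out.mean(axis=(2, 3)) - bias).max(), np.abs(out.std(axis=(2, 3)) - scale).max())
```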

6. BigGAN

Scaling to High Resolution:

  • Large batch sizes (2048)

  • Spectral normalization

  • Orthogonal regularization

  • Class-conditional batch normalization

Result: 512×512 ImageNet generation with high fidelity.

Advanced Training Techniques

Spectral Normalization

Idea: Constrain Lipschitz constant of discriminator.

Method: Normalize weight matrices by their spectral norm (largest singular value):

\[\bar{W} = W / \sigma(W)\]

Where \(\sigma(W)\) is estimated via power iteration.

Benefit: Stabilizes training by preventing discriminator from becoming too confident.
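Power iteration estimates \(\sigma(W)\) by alternating \(\mathbf{v} \leftarrow W^\top \mathbf{u} / \|W^\top \mathbf{u}\|\) and \(\mathbf{u} \leftarrow W \mathbf{v} / \|W \mathbf{v}\|\), then reading off \(\sigma \approx \mathbf{u}^\top W \mathbf{v}\). The NumPy sketch below builds a test matrix with known singular values so that the estimate can be compared with the exact answer. During GAN training, `torch.nn.utils.spectral_norm` by default runs one iteration per forward pass and keeps \(\mathbf{u}\) between steps. This is cheap because the weights change slowly.

```python
import numpy as np

def spectral_norm_estimate(W, n_iters=20, seed=0):
    """Estimate sigma(W), the largest singular value, by power iteration."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=W.shape[0])
    u /= np.linalg.norm(u)
    for _ in range(n_iters):
        v = W.T @ u
        v /= np.linalg.norm(v)       # v = W^T u / ||W^T u||
        u = W @ v
        u /= np.linalg.norm(u)       # u = W v / ||W v||
    return float(u @ W @ v)          # sigma ≈ u^T W v

# Test matrix with known singular values 5, 2, 1, 0.5
rng = np.random.default_rng(1)
U, _ = np.linalg.qr(rng.normal(size=(6, 4)))
V, _ = np.linalg.qr(rng.normal(size=(5, 4)))
W = U @ np.diag([5.0, 2.0, 1.0, 0.5]) @ V.T

sigma = spectral_norm_estimate(W)
W_bar = W / sigma
print(sigma, np.linalg.svd(W, compute_uv=False)[0])  # power iteration vs exact SVD
print(np.linalg.norm(W_bar, 2))                       # spectral norm after normalization
```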

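Two-Time-Scale Update Rule (TTUR)

Observation: With a single shared learning rate, the discriminator and generator can learn at mismatched speeds.

Solution: Use different learning rates:

  • \(\alpha_D > \alpha_G\) (e.g., \(\alpha_D = 4 \times 10^{-4}\), \(\alpha_G = 1 \times 10^{-4}\))

Theory: Heusel et al. (2017) prove that, under standard stochastic-approximation assumptions, TTUR converges to a stationary local Nash equilibrium.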
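Configuring TTUR only requires two separate optimizers. This minimal PyTorch sketch uses the example rates above. The tiny `nn.Linear` modules stand in for real networks, and the Adam betas follow the general tip under Practical Tips rather than any specific paper.

```python
import torch
import torch.nn as nn

# Placeholder networks; only the optimizer setup matters here
G = nn.Linear(8, 16)
D = nn.Linear(16, 1)

# TTUR: two optimizers, with a larger learning rate for the discriminator
opt_d = torch.optim.Adam(D.parameters(), lr=4e-4, betas=(0.5, 0.999))
opt_g = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.999))

# In the training loop, each optimizer steps only its own network's parameters:
#   opt_d.zero_grad(); d_loss.backward(); opt_d.step()
#   opt_g.zero_grad(); g_loss.backward(); opt_g.step()
```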

Self-Attention

Architecture: Add self-attention layers to capture long-range dependencies.

Formulation:

\[\mathbf{o}_i = \sum_j \alpha_{ij} \mathbf{v}_j, \quad \alpha_{ij} = \frac{\exp(s_{ij})}{\sum_k \exp(s_{ik})}, \quad s_{ij} = \mathbf{q}_i^T \mathbf{k}_j\]

Benefit: Model global structure (e.g., consistent object shapes).
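The attention formula maps directly onto matrix operations. This NumPy sketch treats the \(N\) rows of `x` as feature vectors (for an image, its flattened spatial positions) and uses hypothetical projection matrices `W_q`, `W_k`, `W_v`. SAGAN additionally adds the result back to the input, scaled by a learned scalar \(\gamma\); that step is omitted here.

```python
import numpy as np

def self_attention(x, W_q, W_k, W_v):
    """o_i = sum_j alpha_ij v_j, alpha_ij = softmax_j(s_ij), s_ij = q_i^T k_j.

    x: (N, d) feature vectors; W_q, W_k: (d, d_k); W_v: (d, d_v).
    Returns the outputs (N, d_v) and the attention weights alpha (N, N).
    """
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    s = q @ k.T                                    # s[i, j] = q_i^T k_j
    s = s - s.max(axis=1, keepdims=True)           # stabilizes exp; softmax is unchanged
    e = np.exp(s)
    alpha = e / e.sum(axis=1, keepdims=True)       # each row sums to 1
    return alpha @ v, alpha

rng = np.random.default_rng(0)
N, d, d_k, d_v = 6, 5, 3, 4                        # e.g. N = H * W positions of a feature map
x = rng.normal(size=(N, d))
W_q, W_k, W_v = rng.normal(size=(d, d_k)), rng.normal(size=(d, d_k)), rng.normal(size=(d, d_v))
out, alpha = self_attention(x, W_q, W_k, W_v)
print(out.shape, alpha.sum(axis=1))
```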

Hinge Loss

Discriminator:

\[\mathcal{L}_D = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}[\max(0, 1 - D(\mathbf{x}))] + \mathbb{E}_{\mathbf{z} \sim p_z}[\max(0, 1 + D(G(\mathbf{z})))]\]

Generator:

\[\mathcal{L}_G = -\mathbb{E}_{\mathbf{z} \sim p_z}[D(G(\mathbf{z}))]\]

Benefit: Non-saturating, stable gradients.
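The sketch below writes the hinge losses in NumPy, using raw (unbounded) discriminator scores. Real scores above 1 and fake scores below -1 already satisfy the margin and contribute nothing to \(\mathcal{L}_D\). The toy scores are made up to illustrate this.

```python
import numpy as np

def hinge_d_loss(d_real, d_fake):
    """L_D = E[max(0, 1 - D(x))] + E[max(0, 1 + D(G(z)))] on raw scores."""
    return float(np.mean(np.maximum(0.0, 1.0 - d_real)) + np.mean(np.maximum(0.0, 1.0 + d_fake)))

def hinge_g_loss(d_fake):
    """L_G = -E[D(G(z))]."""
    return float(-np.mean(d_fake))

d_real = np.array([2.0, 0.5])    # raw scores on real samples (2.0 is past the margin)
d_fake = np.array([-2.0, 0.5])   # raw scores on generated samples (-2.0 is past the margin)
print(hinge_d_loss(d_real, d_fake))  # 0.25 + 0.75 = 1.0
print(hinge_g_loss(d_fake))          # -(-2.0 + 0.5) / 2 = 0.75
```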

Connections to Other Frameworks

f-GAN

Generalization: Any f-divergence can be used, for a convex \(f\) with \(f(1) = 0\):

\[D_f(p \| q) = \mathbb{E}_{q(x)}\left[f\left(\frac{p(x)}{q(x)}\right)\right]\]

Examples:

  • \(f(t) = t \log t\): KL divergence

  • \(f(t) = -\log t\): Reverse KL

  • \(f(t) = (t-1)^2\): \(\chi^2\) divergence

  • \(f(t) = \frac{1}{2}\left(t \log t - (1+t) \log \frac{1+t}{2}\right)\): Jensen-Shannon (which the standard GAN minimizes, up to scaling and a constant)
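The NumPy check below plugs each generator \(f\) from the list into \(D_f(p \| q) = \mathbb{E}_q[f(p/q)]\) and compares the result with the divergence computed directly. The distributions are randomly drawn with full support, for illustration only. For Jensen-Shannon it uses \(f(t) = \frac{1}{2}(t \log t - (1+t)\log\frac{1+t}{2})\), which reproduces the JSD defined earlier.

```python
import numpy as np

def f_kl(t):          # KL divergence
    return t * np.log(t)

def f_reverse_kl(t):  # reverse KL: D_f(p || q) = KL(q || p)
    return -np.log(t)

def f_chi2(t):        # Pearson chi-squared
    return (t - 1) ** 2

def f_js(t):          # Jensen-Shannon, matching JSD = 0.5 KL(p||m) + 0.5 KL(q||m)
    return 0.5 * (t * np.log(t) - (1 + t) * np.log((1 + t) / 2))

def f_divergence(p, q, f):
    """D_f(p || q) = E_q[f(p / q)] for discrete p, q with full support."""
    return float(np.sum(q * f(p / q)))

rng = np.random.default_rng(0)
p, q = rng.dirichlet(np.ones(8)), rng.dirichlet(np.ones(8))
m = 0.5 * (p + q)
print(f_divergence(p, q, f_kl), np.sum(p * np.log(p / q)))
print(f_divergence(p, q, f_js), 0.5 * np.sum(p * np.log(p / m)) + 0.5 * np.sum(q * np.log(q / m)))
```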

Energy-Based Models

Connection: The discriminator can be viewed as an energy function:

\[D(\mathbf{x}) = \sigma(-E(\mathbf{x}))\]

Where \(E(\mathbf{x})\) is energy (low for real data, high for fake).

Implicit Models

GAN Perspective: The generator defines an implicit probability model:

\[p_g(\mathbf{x}) = \int p_z(\mathbf{z}) \delta(\mathbf{x} - G(\mathbf{z})) d\mathbf{z}\]

Benefit: Training never evaluates \(p_g(\mathbf{x})\), so no tractable density is required. By contrast, normalizing flows need exact likelihoods and VAEs need a tractable decoder likelihood.

Practical Tips

Hyperparameters

  • Learning rate: \(1 \times 10^{-4}\) to \(2 \times 10^{-4}\)

  • Batch size: 64-256 (larger is better if memory permits)

  • Optimizer: Adam with \(\beta_1 = 0.5\), \(\beta_2 = 0.999\)

  • Discriminator updates: \(k = 1\) (or 5 for WGAN)

Architecture

  • LeakyReLU with slope 0.2 in discriminator

  • ReLU or LeakyReLU in generator

  • Batch normalization in both networks (except the \(G\) output layer and the \(D\) input layer)

  • Avoid fully connected layers (use convolutions)

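Debugging

  • Check gradients: Should be non-zero and bounded

  • Monitor losses: At the ideal equilibrium \(D = 0.5\) everywhere, so each BCE term equals \(\log 2 \approx 0.693\) and a summed real+fake \(D\) loss sits near \(2\log 2 \approx 1.386\). A \(D\) loss collapsing toward 0 suggests \(D\) is overpowering \(G\)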

  • Visualize samples: Every few hundred iterations

  • Inspect discriminator outputs: Should be near 0.5 for good \(G\)
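The debugging checks above can be automated with a few lines of PyTorch. This sketch runs one backward pass through a hypothetical discriminator on placeholder batches. It then computes the global gradient norm and the mean \(D\) output on fake samples. In a real training loop, these numbers would be logged periodically next to sample grids.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def global_grad_norm(model: nn.Module) -> float:
    """L2 norm of all parameter gradients taken together (0.0 if none exist yet)."""
    norms = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
    return float(torch.norm(torch.stack(norms), 2)) if norms else 0.0

torch.manual_seed(0)
# Hypothetical discriminator and placeholder batches, only to show the checks
D = nn.Sequential(nn.Linear(2, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1), nn.Sigmoid())
real, fake = torch.randn(32, 2) + 1.0, torch.randn(32, 2) - 1.0

d_real, d_fake = D(real), D(fake)
loss_d = (F.binary_cross_entropy(d_real, torch.ones_like(d_real))
          + F.binary_cross_entropy(d_fake, torch.zeros_like(d_fake)))
loss_d.backward()

grad_norm = global_grad_norm(D)       # should be non-zero and finite
mean_d_fake = d_fake.mean().item()    # should drift toward 0.5 as G improves
print(f"D loss {loss_d.item():.3f}, grad norm {grad_norm:.3e}, mean D(G(z)) {mean_d_fake:.3f}")
```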

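Conclusion

GANs revolutionized generative modeling through adversarial training, enabling high-quality sample generation without explicit density modeling. The mathematical foundation is rooted in minimax optimization and the Jensen-Shannon divergence. It guarantees convergence to the data distribution only in an idealized setting: sufficient capacity, an optimal discriminator at every step, and updates made directly in the space of distributions. Practical training with finite networks carries no such guarantee.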

Key Insights:

  1. Minimax game: Two-player optimization leads to Nash equilibrium

  2. Optimal discriminator: Ratio of densities \(p_{\text{data}} / (p_{\text{data}} + p_g)\)

  3. Global optimum: \(p_g = p_{\text{data}}\) with \(D^* = 0.5\) everywhere

  4. Training challenges: Mode collapse, vanishing gradients, non-convergence

Impact: GANs spawned numerous variants (DCGAN, WGAN, StyleGAN, BigGAN) and applications (image synthesis, style transfer, data augmentation, super-resolution).

Despite recent advances in diffusion models, GANs remain relevant for their speed (single forward pass) and theoretical elegance, continuing to inspire research in adversarial learning and generative modeling.

"""
GAN Mathematics - Complete Implementation

This implementation demonstrates the mathematical foundations of GANs with:
1. Standard GAN (vanilla architecture)
2. DCGAN (deep convolutional architecture)
3. Wasserstein GAN with gradient penalty
4. Spectral normalization
5. Multiple loss functions
6. Training stability techniques
7. Comprehensive demonstrations
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import grad
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional, Dict, List
from dataclasses import dataclass

# ============================================================================
# Configuration
# ============================================================================

@dataclass
class GANConfig:
    """Configuration for GAN training."""
    latent_dim: int = 100
    img_channels: int = 1
    img_size: int = 28
    hidden_dim: int = 128
    learning_rate_g: float = 2e-4
    learning_rate_d: float = 2e-4
    beta1: float = 0.5
    beta2: float = 0.999
    batch_size: int = 64
    num_epochs: int = 50
    k_discriminator: int = 1  # Discriminator updates per generator update
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

# ============================================================================
# 1. Standard GAN (Vanilla Architecture)
# ============================================================================

class VanillaGenerator(nn.Module):
    """
    Simple feedforward generator.
    
    Architecture:
    - Input: latent vector z (latent_dim,)
    - Hidden: 3 fully connected layers with BatchNorm and LeakyReLU
    - Output: flattened image (img_size * img_size * img_channels,)
    """
    def __init__(self, latent_dim: int = 100, img_size: int = 28, img_channels: int = 1,
                 hidden_dim: int = 256):
        super().__init__()
        self.img_size = img_size
        self.img_channels = img_channels
        self.img_shape = (img_channels, img_size, img_size)
        
        self.model = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.BatchNorm1d(hidden_dim * 2),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Linear(hidden_dim * 2, hidden_dim * 4),
            nn.BatchNorm1d(hidden_dim * 4),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Linear(hidden_dim * 4, int(np.prod(self.img_shape))),
            nn.Tanh()  # Output in [-1, 1]
        )
    
    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Args:
            z: Latent vectors (batch_size, latent_dim)
        
        Returns:
            Generated images (batch_size, img_channels, img_size, img_size)
        """
        img_flat = self.model(z)
        img = img_flat.view(img_flat.size(0), *self.img_shape)
        return img


class VanillaDiscriminator(nn.Module):
    """
    Simple feedforward discriminator.
    
    Architecture:
    - Input: flattened image
    - Hidden: 3 fully connected layers with LeakyReLU
    - Output: probability (sigmoid)
    """
    def __init__(self, img_size: int = 28, img_channels: int = 1, hidden_dim: int = 256):
        super().__init__()
        self.img_shape = (img_channels, img_size, img_size)
        
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(self.img_shape)), hidden_dim * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
    
    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """
        Args:
            img: Images (batch_size, img_channels, img_size, img_size)
        
        Returns:
            Probabilities (batch_size, 1)
        """
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity


# ============================================================================
# 2. DCGAN (Deep Convolutional GAN)
# ============================================================================

class DCGANGenerator(nn.Module):
    """
    Deep Convolutional GAN generator.
    
    Architecture guidelines (from DCGAN paper):
    - Replace pooling with fractional-strided convolutions (transposed conv)
    - Use batch normalization
    - Remove fully connected layers
    - Use ReLU activations (Tanh for output)
    """
    def __init__(self, latent_dim: int = 100, img_channels: int = 1, feature_dim: int = 64):
        super().__init__()
        self.latent_dim = latent_dim
        
        # Initial projection: latent_dim -> 4x4x(feature_dim*8)
        self.init_size = 4
        self.fc = nn.Linear(latent_dim, feature_dim * 8 * self.init_size * self.init_size)
        
        # Convolutional layers: 4x4 -> 8x8 -> 16x16 -> 32x32 -> 28x28
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(feature_dim * 8),
            nn.ReLU(inplace=True),  # activate the projected features before upsampling
            
            # 4x4 -> 8x8
            nn.ConvTranspose2d(feature_dim * 8, feature_dim * 4, 4, 2, 1),
            nn.BatchNorm2d(feature_dim * 4),
            nn.ReLU(inplace=True),
            
            # 8x8 -> 16x16
            nn.ConvTranspose2d(feature_dim * 4, feature_dim * 2, 4, 2, 1),
            nn.BatchNorm2d(feature_dim * 2),
            nn.ReLU(inplace=True),
            
            # 16x16 -> 32x32
            nn.ConvTranspose2d(feature_dim * 2, feature_dim, 4, 2, 1),
            nn.BatchNorm2d(feature_dim),
            nn.ReLU(inplace=True),
            
            # 32x32 -> 28x28: valid 5x5 convolution (32 - 5 + 1 = 28)
            nn.Conv2d(feature_dim, img_channels, kernel_size=5, stride=1, padding=0),
            nn.Tanh()
        )
    
    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Args:
            z: Latent vectors (batch_size, latent_dim)
        
        Returns:
            Generated images (batch_size, img_channels, 28, 28)
        """
        out = self.fc(z)
        out = out.view(out.size(0), -1, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


class DCGANDiscriminator(nn.Module):
    """
    Deep Convolutional GAN discriminator.
    
    Architecture guidelines:
    - Replace pooling with strided convolutions
    - Use batch normalization (except first layer)
    - Use LeakyReLU activations
    """
    def __init__(self, img_channels: int = 1, feature_dim: int = 64):
        super().__init__()
        
        self.conv_blocks = nn.Sequential(
            # 28x28 -> 14x14
            nn.Conv2d(img_channels, feature_dim, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 14x14 -> 7x7
            nn.Conv2d(feature_dim, feature_dim * 2, 4, 2, 1),
            nn.BatchNorm2d(feature_dim * 2),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 7x7 -> 3x3
            nn.Conv2d(feature_dim * 2, feature_dim * 4, 4, 2, 1),
            nn.BatchNorm2d(feature_dim * 4),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 3x3 -> 1x1
            nn.Conv2d(feature_dim * 4, feature_dim * 8, 3, 1, 0),
            nn.BatchNorm2d(feature_dim * 8),
            nn.LeakyReLU(0.2, inplace=True),
        )
        
        self.adv_layer = nn.Sequential(
            nn.Conv2d(feature_dim * 8, 1, 1, 1, 0),
            nn.Sigmoid()
        )
    
    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """
        Args:
            img: Images (batch_size, img_channels, 28, 28)
        
        Returns:
            Probabilities (batch_size, 1)
        """
        out = self.conv_blocks(img)
        validity = self.adv_layer(out)
        validity = validity.view(validity.size(0), -1)
        return validity


# ============================================================================
# 3. Wasserstein GAN with Gradient Penalty
# ============================================================================

class WGANDiscriminator(nn.Module):
    """
    WGAN discriminator (critic).
    
    Differences from standard discriminator:
    - No sigmoid at output (outputs real-valued scores)
    - Called "critic" because it estimates Wasserstein distance
    """
    def __init__(self, img_channels: int = 1, feature_dim: int = 64):
        super().__init__()
        
        self.model = nn.Sequential(
            # 28x28 -> 14x14
            nn.Conv2d(img_channels, feature_dim, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 14x14 -> 7x7
            nn.Conv2d(feature_dim, feature_dim * 2, 4, 2, 1),
            nn.InstanceNorm2d(feature_dim * 2),  # Not BatchNorm: it couples samples, which conflicts with the per-sample gradient penalty
            nn.LeakyReLU(0.2, inplace=True),
            
            # 7x7 -> 3x3
            nn.Conv2d(feature_dim * 2, feature_dim * 4, 4, 2, 1),
            nn.InstanceNorm2d(feature_dim * 4),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 3x3 -> 1x1
            nn.Conv2d(feature_dim * 4, 1, 3, 1, 0),
            # No activation (output real-valued scores)
        )
    
    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """
        Args:
            img: Images (batch_size, img_channels, 28, 28)
        
        Returns:
            Critic scores (batch_size, 1)
        """
        validity = self.model(img)
        validity = validity.view(validity.size(0), -1)
        return validity


def compute_gradient_penalty(discriminator: nn.Module, real_samples: torch.Tensor,
                              fake_samples: torch.Tensor, device: str = 'cpu') -> torch.Tensor:
    """
    Compute gradient penalty for WGAN-GP.
    
    Softly enforces the 1-Lipschitz constraint ||∇_x D(x)||₂ ≤ 1 via the penalty
    E[(||∇_x̂ D(x̂)||₂ - 1)²], where x̂ = ε*x_real + (1-ε)*x_fake, ε ~ U(0,1).
    
    The caller multiplies the returned value by λ (λ = 10 in the WGAN-GP paper).
    
    Args:
        discriminator: Discriminator network
        real_samples: Real images (batch_size, C, H, W)
        fake_samples: Fake images (batch_size, C, H, W)
        device: Device
    
    Returns:
        Gradient penalty scalar
    """
    batch_size = real_samples.size(0)
    
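    # Random weight for interpolation, broadcast over all non-batch dimensions
    epsilon = torch.rand(batch_size, *([1] * (real_samples.dim() - 1)), device=device)
    
    # Interpolated samples (inputs detached so the penalty does not backpropagate into G)
    interpolated = epsilon * real_samples.detach() + (1 - epsilon) * fake_samples.detach()
    interpolated.requires_grad_(True)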
    
    # Discriminator output for interpolated samples
    d_interpolated = discriminator(interpolated)
    
    # Compute gradients w.r.t. interpolated samples
    grad_outputs = torch.ones_like(d_interpolated)  # vector for the vector-Jacobian product
    gradients = grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=grad_outputs,
        create_graph=True,   # keep the graph so the penalty is differentiable w.r.t. D's weights
        retain_graph=True,
    )[0]
    
    # Flatten gradients
    gradients = gradients.view(batch_size, -1)
    
    # Compute gradient norm
    gradient_norm = gradients.norm(2, dim=1)
    
    # Penalty: (||gradient|| - 1)²
    gradient_penalty = ((gradient_norm - 1) ** 2).mean()
    
    return gradient_penalty


# ============================================================================
# 4. Spectral Normalization
# ============================================================================

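class SpectralNorm:
    """
    Spectral normalization for weight matrices (educational re-implementation;
    SNDiscriminator below uses PyTorch's built-in nn.utils.spectral_norm).
    
    Normalizes weights by their spectral norm (largest singular value):
    W_normalized = W / σ(W)
    
    σ(W) is estimated via power iteration. Conv kernels of shape
    (out, in, kh, kw) are reshaped to an (out, in*kh*kw) matrix first.
    """
    def __init__(self, name: str = 'weight', n_power_iterations: int = 1, eps: float = 1e-12):
        self.name = name
        self.n_power_iterations = n_power_iterations
        self.eps = eps
    
    def compute_weight(self, module: nn.Module) -> torch.Tensor:
        """Compute spectral-normalized weight and update the stored u, v vectors."""
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        v = getattr(module, self.name + '_v')
        weight_mat = weight.view(weight.size(0), -1)  # 2D view (handles conv kernels)
        
        # Power iteration (no gradient flows through u and v)
        with torch.no_grad():
            for _ in range(self.n_power_iterations):
                # v = W^T u / ||W^T u||
                v = F.normalize(torch.mv(weight_mat.t(), u), dim=0, eps=self.eps)
                # u = W v / ||W v||
                u = F.normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps)
            # Persist u, v so the estimate keeps improving across forward passes
            getattr(module, self.name + '_u').copy_(u)
            getattr(module, self.name + '_v').copy_(v)
        
        # Spectral norm: σ(W) ≈ u^T W v (differentiable w.r.t. W)
        sigma = torch.dot(u, torch.mv(weight_mat, v))
        
        # Normalize weight
        return weight / sigma
    
    @staticmethod
    def apply(module: nn.Module, name: str = 'weight', n_power_iterations: int = 1) -> 'SpectralNorm':
        """Move `name` to `name + '_orig'`, add u/v buffers and a forward pre-hook."""
        fn = SpectralNorm(name, n_power_iterations)
        weight = getattr(module, name)
        delattr(module, name)
        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
        h, w = weight.size(0), weight[0].numel()
        module.register_buffer(name + '_u', F.normalize(weight.data.new_empty(h).normal_(), dim=0))
        module.register_buffer(name + '_v', F.normalize(weight.data.new_empty(w).normal_(), dim=0))
        setattr(module, name, weight.data)  # plain attribute, recomputed by the hook
        module.register_forward_pre_hook(fn)
        return fn
    
    def __call__(self, module: nn.Module, inputs):
        """Apply spectral normalization before forward pass."""
        setattr(module, self.name, self.compute_weight(module))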


def spectral_norm(module: nn.Module, name: str = 'weight', n_power_iterations: int = 1) -> nn.Module:
    """
    Apply spectral normalization to a module.
    
    Example:
        conv = spectral_norm(nn.Conv2d(3, 64, 3))
    """
    SpectralNorm.apply(module, name, n_power_iterations)
    return module


class SNDiscriminator(nn.Module):
    """
    Discriminator with spectral normalization.
    
    Spectral normalization stabilizes training by constraining Lipschitz constant.
    """
    def __init__(self, img_channels: int = 1, feature_dim: int = 64):
        super().__init__()
        
        self.model = nn.Sequential(
            # Use PyTorch's built-in spectral_norm
            nn.utils.spectral_norm(nn.Conv2d(img_channels, feature_dim, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.utils.spectral_norm(nn.Conv2d(feature_dim, feature_dim * 2, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.utils.spectral_norm(nn.Conv2d(feature_dim * 2, feature_dim * 4, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.utils.spectral_norm(nn.Conv2d(feature_dim * 4, 1, 3, 1, 0)),
            nn.Sigmoid()
        )
    
    def forward(self, img: torch.Tensor) -> torch.Tensor:
        validity = self.model(img)
        validity = validity.view(validity.size(0), -1)
        return validity


# ============================================================================
# 5. Loss Functions
# ============================================================================

class GANLoss:
    """Collection of GAN loss functions."""
    
    @staticmethod
    def vanilla_discriminator_loss(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
        """
        Vanilla GAN discriminator loss.
        
        L_D = -E[log D(x)] - E[log(1 - D(G(z)))]
        
        Equivalent to binary cross-entropy:
        L_D = BCE(D(x), 1) + BCE(D(G(z)), 0)
        """
        real_loss = F.binary_cross_entropy(d_real, torch.ones_like(d_real))
        fake_loss = F.binary_cross_entropy(d_fake, torch.zeros_like(d_fake))
        return real_loss + fake_loss
    
    @staticmethod
    def vanilla_generator_loss(d_fake: torch.Tensor) -> torch.Tensor:
        """
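        Vanilla GAN generator loss, in the non-saturating form used in practice.
        
        L_G = -E[log D(G(z))]
        
        Equivalent to:
        L_G = BCE(D(G(z)), 1)
        
        (The literal minimax objective, minimizing E[log(1 - D(G(z)))], saturates
        when D confidently rejects generated samples.)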
        """
        return F.binary_cross_entropy(d_fake, torch.ones_like(d_fake))
    
    @staticmethod
    def nonsaturating_generator_loss(d_fake: torch.Tensor) -> torch.Tensor:
        """
        Non-saturating generator loss.
        
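        Instead of minimizing E[log(1 - D(G(z)))] (the minimax objective),
        maximize E[log D(G(z))], i.e. minimize -E[log D(G(z))].
        
        Provides stronger gradients when D confidently rejects generated samples.
        """
        # Identical to vanilla_generator_loss, which already uses this form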
        return F.binary_cross_entropy(d_fake, torch.ones_like(d_fake))
    
    @staticmethod
    def wasserstein_discriminator_loss(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
        """
        Wasserstein GAN discriminator (critic) loss.
        
        L_D = -E[D(x)] + E[D(G(z))]
        
        Minimizing this maximizes E[D(x)] - E[D(G(z))], which estimates the
        Wasserstein-1 (Earth Mover's) distance when D is 1-Lipschitz.
        """
        return -d_real.mean() + d_fake.mean()
    
    @staticmethod
    def wasserstein_generator_loss(d_fake: torch.Tensor) -> torch.Tensor:
        """
        Wasserstein GAN generator loss.
        
        L_G = -E[D(G(z))]
        """
        return -d_fake.mean()
    
    @staticmethod
    def hinge_discriminator_loss(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
        """
        Hinge loss for discriminator.
        
        L_D = E[max(0, 1 - D(x))] + E[max(0, 1 + D(G(z)))]
        
        Used in BigGAN and other modern architectures.
        """
        real_loss = F.relu(1.0 - d_real).mean()
        fake_loss = F.relu(1.0 + d_fake).mean()
        return real_loss + fake_loss
    
    @staticmethod
    def hinge_generator_loss(d_fake: torch.Tensor) -> torch.Tensor:
        """
        Hinge loss for generator.
        
        L_G = -E[D(G(z))]
        """
        return -d_fake.mean()


# ============================================================================
# 6. Training Utilities
# ============================================================================

def train_vanilla_gan_step(
    generator: nn.Module,
    discriminator: nn.Module,
    real_imgs: torch.Tensor,
    optimizer_g: optim.Optimizer,
    optimizer_d: optim.Optimizer,
    latent_dim: int,
    device: str = 'cpu'
) -> Dict[str, float]:
    """
    Single training step for vanilla GAN.
    
    Returns:
        Dictionary with loss values
    """
    batch_size = real_imgs.size(0)
    
    # ---------------------
    # Train Discriminator
    # ---------------------
    optimizer_d.zero_grad()
    
    # Real images
    real_validity = discriminator(real_imgs)
    
    # Fake images
    z = torch.randn(batch_size, latent_dim, device=device)
    fake_imgs = generator(z)
    fake_validity = discriminator(fake_imgs.detach())
    
    # Discriminator loss
    d_loss = GANLoss.vanilla_discriminator_loss(real_validity, fake_validity)
    d_loss.backward()
    optimizer_d.step()
    
    # -----------------
    # Train Generator
    # -----------------
    optimizer_g.zero_grad()
    
    # Generate new fake images
    z = torch.randn(batch_size, latent_dim, device=device)
    fake_imgs = generator(z)
    fake_validity = discriminator(fake_imgs)
    
    # Generator loss
    g_loss = GANLoss.vanilla_generator_loss(fake_validity)
    g_loss.backward()
    optimizer_g.step()
    
    return {
        'd_loss': d_loss.item(),
        'g_loss': g_loss.item(),
        'd_real': real_validity.mean().item(),
        'd_fake': fake_validity.mean().item()
    }


def train_wgan_gp_step(
    generator: nn.Module,
    discriminator: nn.Module,
    real_imgs: torch.Tensor,
    optimizer_g: optim.Optimizer,
    optimizer_d: optim.Optimizer,
    latent_dim: int,
    lambda_gp: float = 10.0,
    device: str = 'cpu'
) -> Dict[str, float]:
    """
    Single training step for WGAN-GP.
    
    Args:
        lambda_gp: Gradient penalty coefficient (default: 10)
    
    Returns:
        Dictionary with loss values
    """
    batch_size = real_imgs.size(0)
    
    # ---------------------
    # Train Discriminator
    # ---------------------
    optimizer_d.zero_grad()
    
    # Real images
    real_validity = discriminator(real_imgs)
    
    # Fake images
    z = torch.randn(batch_size, latent_dim, device=device)
    fake_imgs = generator(z)
    fake_validity = discriminator(fake_imgs.detach())
    
    # Gradient penalty
    gp = compute_gradient_penalty(discriminator, real_imgs, fake_imgs.detach(), device)
    
    # Discriminator loss: -E[D(x)] + E[D(G(z))] + λ*GP
    d_loss = GANLoss.wasserstein_discriminator_loss(real_validity, fake_validity) + lambda_gp * gp
    d_loss.backward()
    optimizer_d.step()
    
    # -----------------
    # Train Generator
    # -----------------
    optimizer_g.zero_grad()
    
    # Generate new fake images
    z = torch.randn(batch_size, latent_dim, device=device)
    fake_imgs = generator(z)
    gen_validity = discriminator(fake_imgs)
    
    # Generator loss: -E[D(G(z))]
    g_loss = GANLoss.wasserstein_generator_loss(gen_validity)
    g_loss.backward()
    optimizer_g.step()
    
    # Report the critic-step scores so the Wasserstein estimate compares
    # real and fake outputs of the same critic
    return {
        'd_loss': d_loss.item(),
        'g_loss': g_loss.item(),
        'gp': gp.item(),
        'd_real': real_validity.mean().item(),
        'd_fake': fake_validity.mean().item(),
        'wasserstein_dist': (real_validity.mean() - fake_validity.mean()).item()
    }


# ============================================================================
# 7. Visualization
# ============================================================================

def visualize_generated_samples(generator: nn.Module, latent_dim: int, num_samples: int = 16,
                                  device: str = 'cpu', figsize: Tuple[int, int] = (8, 8)):
    """
    Visualize generated samples.
    
    Args:
        generator: Generator network
        latent_dim: Dimension of latent space
        num_samples: Number of samples to generate (should be square number)
        device: Device
        figsize: Figure size
    """
    generator.eval()
    
    with torch.no_grad():
        z = torch.randn(num_samples, latent_dim, device=device)
        fake_imgs = generator(z).cpu()
    
    # Plot
    grid_size = int(np.sqrt(num_samples))
    fig, axes = plt.subplots(grid_size, grid_size, figsize=figsize)
    
    for i in range(grid_size):
        for j in range(grid_size):
            idx = i * grid_size + j
            img = fake_imgs[idx, 0]  # Remove channel dimension
            axes[i, j].imshow(img, cmap='gray')
            axes[i, j].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    generator.train()


def visualize_training_progress(losses: Dict[str, List[float]], figsize: Tuple[int, int] = (12, 4)):
    """
    Visualize training losses over time.
    
    Args:
        losses: Dictionary with loss histories
        figsize: Figure size
    """
    fig, axes = plt.subplots(1, 3, figsize=figsize)
    
    # Discriminator and Generator losses
    axes[0].plot(losses['d_loss'], label='D Loss', alpha=0.7)
    axes[0].plot(losses['g_loss'], label='G Loss', alpha=0.7)
    axes[0].set_xlabel('Iteration')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('GAN Losses')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Discriminator outputs
    axes[1].plot(losses['d_real'], label='D(x)', alpha=0.7)
    axes[1].plot(losses['d_fake'], label='D(G(z))', alpha=0.7)
    axes[1].axhline(y=0.5, color='r', linestyle='--', alpha=0.5, label='Random')
    axes[1].set_xlabel('Iteration')
    axes[1].set_ylabel('Discriminator Output')
    axes[1].set_title('Discriminator Predictions')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    # Wasserstein distance (if available)
    if 'wasserstein_dist' in losses:
        axes[2].plot(losses['wasserstein_dist'], label='W Distance', alpha=0.7)
        axes[2].set_xlabel('Iteration')
        axes[2].set_ylabel('Distance')
        axes[2].set_title('Wasserstein Distance')
        axes[2].legend()
        axes[2].grid(True, alpha=0.3)
    else:
        axes[2].axis('off')
    
    plt.tight_layout()
    plt.show()


def visualize_latent_space_interpolation(generator: nn.Module, latent_dim: int, num_steps: int = 10,
                                          device: str = 'cpu', figsize: Tuple[int, int] = (15, 3)):
    """
    Interpolate between two random points in latent space.
    
    Args:
        generator: Generator network
        latent_dim: Dimension of latent space
        num_steps: Number of interpolation steps
        device: Device
        figsize: Figure size
    """
    generator.eval()
    
    # Sample two random latent vectors
    z1 = torch.randn(1, latent_dim, device=device)
    z2 = torch.randn(1, latent_dim, device=device)
    
    # Interpolate
    alphas = torch.linspace(0, 1, num_steps, device=device).view(-1, 1)
    z_interp = z1 * (1 - alphas) + z2 * alphas
    
    with torch.no_grad():
        imgs = generator(z_interp).cpu()
    
    # Plot
    fig, axes = plt.subplots(1, num_steps, figsize=figsize)
    
    for i in range(num_steps):
        img = imgs[i, 0]
        axes[i].imshow(img, cmap='gray')
        axes[i].axis('off')
        axes[i].set_title(f'α={alphas[i].item():.1f}')
    
    plt.tight_layout()
    plt.show()
    
    generator.train()


# ============================================================================
# 8. Demonstrations
# ============================================================================

def demo_vanilla_gan():
    """Demonstrate vanilla GAN architecture and forward pass."""
    print("=" * 70)
    print("Vanilla GAN Demonstration")
    print("=" * 70)
    
    latent_dim = 100
    batch_size = 4
    
    # Initialize models
    generator = VanillaGenerator(latent_dim=latent_dim)
    discriminator = VanillaDiscriminator()
    
    print(f"\nGenerator Architecture:")
    print(generator)
    print(f"\nGenerator Parameters: {sum(p.numel() for p in generator.parameters()):,}")
    
    print(f"\nDiscriminator Architecture:")
    print(discriminator)
    print(f"\nDiscriminator Parameters: {sum(p.numel() for p in discriminator.parameters()):,}")
    
    # Forward pass
    z = torch.randn(batch_size, latent_dim)
    fake_imgs = generator(z)
    print(f"\nGenerated Images Shape: {fake_imgs.shape}")
    print(f"Generated Images Range: [{fake_imgs.min():.3f}, {fake_imgs.max():.3f}]")
    
    validity = discriminator(fake_imgs)
    print(f"\nDiscriminator Output Shape: {validity.shape}")
    print(f"Discriminator Output (probabilities): {validity.squeeze().tolist()}")
    print()


def demo_dcgan():
    """Demonstrate DCGAN architecture."""
    print("=" * 70)
    print("DCGAN Demonstration")
    print("=" * 70)
    
    latent_dim = 100
    batch_size = 4
    
    # Initialize models
    generator = DCGANGenerator(latent_dim=latent_dim)
    discriminator = DCGANDiscriminator()
    
    print(f"\nGenerator Architecture:")
    print(generator)
    print(f"\nGenerator Parameters: {sum(p.numel() for p in generator.parameters()):,}")
    
    print(f"\nDiscriminator Architecture:")
    print(discriminator)
    print(f"\nDiscriminator Parameters: {sum(p.numel() for p in discriminator.parameters()):,}")
    
    # Forward pass
    z = torch.randn(batch_size, latent_dim)
    fake_imgs = generator(z)
    print(f"\nGenerated Images Shape: {fake_imgs.shape}")
    
    validity = discriminator(fake_imgs)
    print(f"Discriminator Output Shape: {validity.shape}")
    print()


def demo_wgan():
    """Demonstrate WGAN-GP architecture and gradient penalty."""
    print("=" * 70)
    print("WGAN-GP Demonstration")
    print("=" * 70)
    
    latent_dim = 100
    batch_size = 4
    device = 'cpu'
    
    # Initialize models
    generator = DCGANGenerator(latent_dim=latent_dim)
    discriminator = WGANDiscriminator()
    
    print(f"\nCritic (Discriminator) Architecture:")
    print(discriminator)
    
    # Dummy "real" images (random noise) and generated fake images
    real_imgs = torch.randn(batch_size, 1, 28, 28)
    z = torch.randn(batch_size, latent_dim)
    fake_imgs = generator(z)
    
    # Compute gradient penalty
    gp = compute_gradient_penalty(discriminator, real_imgs, fake_imgs, device)
    print(f"\nGradient Penalty: {gp.item():.4f}")
    
    # Compute Wasserstein distance estimate
    with torch.no_grad():
        real_validity = discriminator(real_imgs)
        fake_validity = discriminator(fake_imgs)
        w_dist = (real_validity.mean() - fake_validity.mean()).item()
    
    print(f"Wasserstein Distance Estimate: {w_dist:.4f}")
    print()


def demo_spectral_normalization():
    """Demonstrate spectral normalization."""
    print("=" * 70)
    print("Spectral Normalization Demonstration")
    print("=" * 70)
    
    # Regular conv layer
    conv_regular = nn.Conv2d(1, 64, 4, 2, 1)
    
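    # Spectral normalized conv layer
    conv_sn = nn.utils.spectral_norm(nn.Conv2d(1, 64, 4, 2, 1))
    
    # nn.utils.spectral_norm rescales `weight` only during forward passes (one
    # power iteration per call), so run a few passes before inspecting it
    with torch.no_grad():
        for _ in range(20):
            conv_sn(torch.randn(1, 1, 28, 28))
    
    # Compute spectral norms
    def spectral_norm_value(layer):
        """Largest singular value of the weight reshaped to (out_channels, -1)."""
        weight_mat = layer.weight.detach().reshape(layer.weight.size(0), -1)
        return torch.linalg.svdvals(weight_mat)[0].item()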
    
    sn_regular = spectral_norm_value(conv_regular)
    sn_normalized = spectral_norm_value(conv_sn)
    
    print(f"\nRegular Conv Layer:")
    print(f"  Spectral Norm: {sn_regular:.4f}")
    
    print(f"\nSpectral Normalized Conv Layer:")
    print(f"  Spectral Norm: {sn_normalized:.4f}")
    
    print("\nSpectral normalization rescales each layer's weight so its largest singular value is ≈1,")
    print("which approximately constrains the layer's Lipschitz constant to 1")
    print()


def demo_loss_functions():
    """Demonstrate different GAN loss functions."""
    print("=" * 70)
    print("GAN Loss Functions Demonstration")
    print("=" * 70)
    
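    # Dummy discriminator outputs (probabilities). Wasserstein and hinge losses
    # normally expect unbounded critic scores; reusing these values below only
    # illustrates the arithmetic of each loss.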
    d_real = torch.tensor([[0.9], [0.8], [0.85], [0.95]])
    d_fake = torch.tensor([[0.2], [0.1], [0.15], [0.25]])
    
    print("\nDiscriminator Outputs:")
    print(f"  D(x_real): {d_real.squeeze().tolist()}")
    print(f"  D(x_fake): {d_fake.squeeze().tolist()}")
    
    # Compute losses
    print("\nDiscriminator Losses:")
    print(f"  Vanilla: {GANLoss.vanilla_discriminator_loss(d_real, d_fake):.4f}")
    print(f"  Wasserstein: {GANLoss.wasserstein_discriminator_loss(d_real, d_fake):.4f}")
    print(f"  Hinge: {GANLoss.hinge_discriminator_loss(d_real, d_fake):.4f}")
    
    print("\nGenerator Losses:")
    print(f"  Vanilla: {GANLoss.vanilla_generator_loss(d_fake):.4f}")
    print(f"  Wasserstein: {GANLoss.wasserstein_generator_loss(d_fake):.4f}")
    print(f"  Hinge: {GANLoss.hinge_generator_loss(d_fake):.4f}")
    print()


def demo_optimal_discriminator():
    """Demonstrate optimal discriminator theorem."""
    print("=" * 70)
    print("Optimal Discriminator Demonstration")
    print("=" * 70)
    
    # Create grid of x values
    x = torch.linspace(-5, 5, 100)
    
    # Define different scenarios for p_data and p_g
    scenarios = [
        ("Perfect Generator", 
         lambda x: torch.exp(-x**2), 
         lambda x: torch.exp(-x**2)),
        
        ("Good Generator", 
         lambda x: torch.exp(-x**2), 
         lambda x: torch.exp(-(x-0.5)**2)),
        
        ("Poor Generator", 
         lambda x: torch.exp(-x**2), 
         lambda x: torch.exp(-(x-2)**2)),
    ]
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    for idx, (name, p_data_fn, p_g_fn) in enumerate(scenarios):
        p_data = p_data_fn(x)
        p_g = p_g_fn(x)
        
        # Optimal discriminator: D*(x) = p_data(x) / (p_data(x) + p_g(x))
        d_star = p_data / (p_data + p_g)
        
        axes[idx].plot(x.numpy(), p_data.numpy(), label='$p_{data}(x)$', linewidth=2)
        axes[idx].plot(x.numpy(), p_g.numpy(), label='$p_g(x)$', linewidth=2)
        axes[idx].plot(x.numpy(), d_star.numpy(), label='$D^*(x)$', linewidth=2, linestyle='--')
        axes[idx].axhline(y=0.5, color='r', linestyle=':', alpha=0.5, label='Random')
        axes[idx].set_xlabel('x')
        axes[idx].set_ylabel('Probability / D(x)')
        axes[idx].set_title(name)
        axes[idx].legend()
        axes[idx].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print("\nKey Observations:")
    print("  1. When p_g = p_data, D*(x) = 0.5 (cannot distinguish)")
    print("  2. When p_data > p_g, D*(x) > 0.5 (likely real)")
    print("  3. When p_g > p_data, D*(x) < 0.5 (likely fake)")
    print()


# ============================================================================
# Main Demonstration
# ============================================================================

if __name__ == "__main__":
    print("\n" + "="*70)
    print(" GAN Mathematics - Comprehensive Implementation Demonstrations")
    print("="*70 + "\n")
    
    # Run all demonstrations
    demo_vanilla_gan()
    demo_dcgan()
    demo_wgan()
    demo_spectral_normalization()
    demo_loss_functions()
    demo_optimal_discriminator()
    
    print("="*70)
    print("All demonstrations completed successfully!")
    print("="*70)

Generative Adversarial Networks (GANs)

Learning Objectives:

  • Understand GAN theory and mathematics

  • Implement generator and discriminator networks

  • Train GANs with proper loss functions

  • Apply to image generation

Prerequisites: Deep learning basics, PyTorch

Time: 90 minutes

1. The GAN Framework

Original Formulation (Goodfellow et al., 2014)

Players:

  • Generator \(G\): Maps noise \(z \sim p_z\) to fake data \(G(z)\)

  • Discriminator \(D\): Classifies real vs fake data

Objective (Minimax Game):

\[\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]\]

Intuition:

  • \(D\) tries to maximize (distinguish real from fake)

  • \(G\) tries to minimize (fool \(D\))

  • At equilibrium: \(p_g = p_{data}\) (generator matches data distribution)
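In practice the expectations in \(V(D, G)\) are estimated with minibatch averages. The short sketch below uses made-up discriminator outputs `d_real` and `d_fake` as stand-ins for a real model; it computes that estimate and checks that the usual discriminator binary cross-entropy loss equals \(-V\), which is why minimizing BCE implements the \(\max_D\) step.

```python
import torch
import torch.nn.functional as F

def value_function(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
    """Monte Carlo estimate of V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))]."""
    return torch.log(d_real).mean() + torch.log(1 - d_fake).mean()

# Hypothetical discriminator outputs on a batch of real and generated samples
d_real = torch.tensor([0.9, 0.8, 0.7, 0.95])
d_fake = torch.tensor([0.2, 0.1, 0.3, 0.05])

v = value_function(d_real, d_fake)

# Discriminator BCE loss: real samples labelled 1, generated samples labelled 0
bce = (F.binary_cross_entropy(d_real, torch.ones_like(d_real))
       + F.binary_cross_entropy(d_fake, torch.zeros_like(d_fake)))

print(f"V(D, G) = {v.item():.4f}, BCE loss = {bce.item():.4f}")  # BCE equals -V
```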

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, TensorDataset
from scipy import stats

# Set style and seed
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 4)
torch.manual_seed(42)
np.random.seed(42)

# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

2. Theoretical Analysis

Optimal Discriminator

For fixed \(G\), the optimal discriminator is:

\[D^*_G(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}\]

Proof: Writing the second expectation as an expectation over \(x = G(z) \sim p_g\), maximize:
\[V(D, G) = \int_x p_{data}(x) \log D(x) \, dx + \int_x p_g(x) \log(1-D(x)) \, dx\]

The integrand can be maximized separately at each \(x\). Taking the derivative w.r.t. \(D(x)\) and setting it to zero:
\[\frac{p_{data}(x)}{D(x)} - \frac{p_g(x)}{1-D(x)} = 0\]

Solving: \(D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}\)
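The pointwise result can also be checked numerically. This sketch picks arbitrary density values at a single point (`p_data_x = 0.3`, `p_g_x = 0.1`, chosen only for illustration) and brute-forces the integrand \(p_{data}(x) \log D + p_g(x) \log(1 - D)\) over a fine grid of \(D\) values; the maximizer should agree with the closed form.

```python
import numpy as np

def pointwise_objective(d, p_data, p_g):
    """Integrand of V at a single x: p_data * log D + p_g * log(1 - D)."""
    return p_data * np.log(d) + p_g * np.log(1 - d)

# Hypothetical density values at one point x
p_data_x, p_g_x = 0.3, 0.1

# Brute-force search over D(x) in (0, 1)
d_grid = np.linspace(1e-4, 1 - 1e-4, 100_001)
d_best = d_grid[np.argmax(pointwise_objective(d_grid, p_data_x, p_g_x))]

d_star = p_data_x / (p_data_x + p_g_x)  # closed form: 0.3 / 0.4 = 0.75
print(f"grid search: {d_best:.4f}, closed form: {d_star:.4f}")
```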

Global Optimal Point

When \(D = D^*_G\), the value function becomes:

\[C(G) = \max_D V(D,G) = -\log 4 + 2 \cdot JSD(p_{data} || p_g)\]

where \(JSD\) is the Jensen-Shannon divergence.

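Key Result: Since \(JSD \geq 0\), with equality only when the two distributions are equal, the global minimum of \(C(G)\) is achieved exactly when \(p_g = p_{data}\)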

  • At optimum: \(D^*(x) = 1/2\) everywhere

  • \(C(G) = -\log 4\)
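A numeric check of \(C(G) = -\log 4 + 2 \cdot JSD(p_{data} || p_g)\), assuming small discrete distributions on a four-point support (the probabilities are arbitrary illustration values). Plugging \(D^*\) into \(V\) should reproduce the JSD expression, and give exactly \(-\log 4\) when \(p_g = p_{data}\).

```python
import numpy as np

def kl(p, q):
    """KL(p || q) for discrete distributions (terms with p = 0 contribute 0)."""
    mask = p > 0
    return np.sum(p[mask] * np.log(p[mask] / q[mask]))

def jsd(p, q):
    """Jensen-Shannon divergence (natural log)."""
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def value_at_optimal_d(p_data, p_g):
    """V(D*_G, G) with D* = p_data / (p_data + p_g) on a discrete support."""
    d_star = p_data / (p_data + p_g)
    # Zero-weight terms are dropped to avoid 0 * log 0
    real = np.sum(p_data[p_data > 0] * np.log(d_star[p_data > 0]))
    fake = np.sum(p_g[p_g > 0] * np.log(1 - d_star[p_g > 0]))
    return real + fake

p_data = np.array([0.1, 0.4, 0.3, 0.2])
p_g = np.array([0.25, 0.25, 0.25, 0.25])

c_g = value_at_optimal_d(p_data, p_g)
print(f"C(G)           = {c_g:.6f}")
print(f"-log 4 + 2*JSD = {-np.log(4) + 2 * jsd(p_data, p_g):.6f}")
print(f"C(G) with p_g = p_data: {value_at_optimal_d(p_data, p_data):.6f} "
      f"(-log 4 = {-np.log(4):.6f})")
```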

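2.5. Advanced Theory: Mode Collapse Analysis

Why Mode Collapse Occurs

Jensen-Shannon Divergence Properties:

When distributions have non-overlapping support:
\[JSD(P || Q) = \log 2\]

Because this value stays the same no matter how far apart the supports are, the divergence gives the generator no signal about which direction to move. This can lead to: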

  • Gradient saturation when discriminator is too strong

  • Missing modes in generated distribution

  • Oscillating training dynamics
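A quick numeric illustration: with disjoint supports the JSD sits at \(\log 2\) however far the generated mass is from the data. The distributions below are toy 100-bin histograms chosen only for illustration.

```python
import numpy as np

def jsd(p, q):
    """Jensen-Shannon divergence between two discrete distributions (natural log)."""
    m = 0.5 * (p + q)
    def kl(a, b):
        mask = a > 0
        return np.sum(a[mask] * np.log(a[mask] / b[mask]))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

n_bins = 100
p_data = np.zeros(n_bins)
p_data[:10] = 0.1                      # data occupies bins 0-9

for shift in [10, 30, 80]:             # generator mass placed progressively farther away
    p_g = np.zeros(n_bins)
    p_g[shift:shift + 10] = 0.1
    print(f"shift={shift:2d}  JSD={jsd(p_data, p_g):.4f}  (log 2 = {np.log(2):.4f})")
```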

Mathematical Analysis of Mode Collapse

Consider a generator that produces samples only from mode \(m_i\) of the true distribution. The discriminator can then easily classify:

  • Samples near \(m_i\): could be real or fake

  • Samples near other modes \(m_j\) (\(j \neq i\)): must be real

Implication: The generator's loss only evaluates \(D\) at its own samples near \(m_i\), so it receives little or no gradient signal pushing it toward the other modes.
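The argument can be made concrete with a toy 1D setup. In this sketch all functions are hypothetical illustrations, not models from this notebook: a collapsed generator places its samples around \(x = 0\), and two hand-written discriminators agree near 0, but only the second also assigns high "real" probability near a second mode at \(x = 4\). Because the non-saturating loss only evaluates \(D\) at the generator's own samples, the generator's gradient comes out the same in both cases.

```python
import torch

def generator_gradient(discriminator_fn):
    """Gradient of the non-saturating G loss w.r.t. the location of a collapsed generator.

    Hypothetical 1-parameter generator: G(z) = b + 0.3 z with b = 0, so all
    samples sit around a single mode at x = 0.
    """
    torch.manual_seed(0)                       # same z for every call
    b = torch.tensor(0.0, requires_grad=True)
    fake = b + 0.3 * torch.randn(1000)
    loss = -torch.log(discriminator_fn(fake)).mean()
    loss.backward()
    return b.grad.item()

def d_one_mode(x):
    # Toy discriminator: "real" near x = 0 only
    return torch.sigmoid(2.0 - x ** 2)

def d_two_modes(x):
    # Same near x = 0, but also confident that a second real mode exists near x = 4
    return torch.sigmoid(2.0 - x ** 2 + 6.0 * torch.exp(-(x - 4.0) ** 2 / 0.1))

g1 = generator_gradient(d_one_mode)
g2 = generator_gradient(d_two_modes)
print(f"dL/db with one-mode D: {g1:.6f}, with two-mode D: {g2:.6f}")
```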

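Reverse KL Divergence Perspective

With the non-saturating generator loss and an optimal discriminator, the generator effectively minimizes \(KL(p_g || p_{data}) - 2 \cdot JSD(p_g || p_{data})\) (Arjovsky & Bottou, 2017), so its behavior is dominated by the reverse KL divergence:
\[KL(p_g || p_{data}) = \int_x p_g(x) \log \frac{p_g(x)}{p_{data}(x)} \, dx\]

Properties:

  • Mode-seeking: \(p_g\) concentrates on high-probability regions of \(p_{data}\)

  • Zero-forcing: where \(p_{data}(x) = 0\), any mass \(p_g(x) > 0\) incurs an unbounded penalty, so \(p_g\) avoids regions outside the data support

  • Cheap missing modes: where \(p_g(x) = 0\) but \(p_{data}(x) > 0\), the integrand is zero, so dropping whole modes costs little (mode collapse)

Contrast with the forward KL divergence:
\[KL(p_{data} || p_g) \quad \text{(mode-covering; maximum likelihood, which VAEs approximately optimize via the ELBO)}\]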
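To see the two behaviors numerically, the sketch below compares two hypothetical generator distributions against a bimodal target on a dense grid: one Gaussian sitting on a single mode and one broad Gaussian covering both. Under the reverse KL the single-mode fit should win (with a value near \(\log 2\)), because missing a mode is cheap while placing mass between the modes is expensive; under the forward KL the broad fit should win. The densities and grid are illustrative choices.

```python
import numpy as np

# Dense 1D grid used for numerical integration
x = np.linspace(-12, 12, 24_001)
dx = x[1] - x[0]

def gaussian(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

def kl(p, q):
    """Numerical KL(p || q) on the grid (Riemann sum)."""
    return np.sum(p * np.log(p / q)) * dx

# Bimodal target distribution with modes at -3 and +3
p_data = 0.5 * gaussian(x, -3, 0.5) + 0.5 * gaussian(x, 3, 0.5)

# Two hypothetical generator distributions
candidates = {
    "single_mode": gaussian(x, 3, 0.5),   # sits on one mode only
    "broad": gaussian(x, 0, 3),           # spreads over both modes
}

for name, p_g in candidates.items():
    print(f"{name:12s} reverse KL(p_g || p_data) = {kl(p_g, p_data):7.3f}   "
          f"forward KL(p_data || p_g) = {kl(p_data, p_g):7.3f}")
```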

# Visualize mode collapse phenomenon
def demonstrate_mode_collapse():
    """
    Simulate mode collapse by training a generator 
    that only captures one mode of a multi-modal distribution
    """
    
    # True distribution: 5 Gaussians
    def sample_true_data(n):
        modes = np.array([-4, -2, 0, 2, 4])
        chosen_modes = np.random.choice(5, n)
        samples = modes[chosen_modes] + np.random.randn(n) * 0.3
        return samples
    
    # Collapsed generator: only mode at x=0
    def sample_collapsed_generator(n):
        return np.random.randn(n) * 0.3  # Only around 0
    
    # Diverse generator: captures all modes
    def sample_diverse_generator(n):
        modes = np.array([-4, -2, 0, 2, 4])
        chosen_modes = np.random.choice(5, n)
        samples = modes[chosen_modes] + np.random.randn(n) * 0.35
        return samples
    
    fig, axes = plt.subplots(1, 3, figsize=(16, 4))
    
    # True distribution
    true_samples = sample_true_data(10000)
    axes[0].hist(true_samples, bins=80, density=True, alpha=0.7, 
                 color='blue', edgecolor='black')
    axes[0].set_title('True Distribution (5 Modes)', fontweight='bold', fontsize=12)
    axes[0].set_xlabel('x')
    axes[0].set_ylabel('Density')
    axes[0].set_xlim([-6, 6])
    axes[0].grid(True, alpha=0.3)
    
    # Collapsed generator
    collapsed_samples = sample_collapsed_generator(10000)
    axes[1].hist(collapsed_samples, bins=80, density=True, alpha=0.7, 
                 color='red', edgecolor='black')
    axes[1].axvline(x=0, color='darkred', linestyle='--', linewidth=2, 
                    label='Captured mode')
    axes[1].set_title('Mode Collapse (1/5 Modes)', fontweight='bold', fontsize=12)
    axes[1].set_xlabel('x')
    axes[1].set_ylabel('Density')
    axes[1].set_xlim([-6, 6])
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    # Diverse generator
    diverse_samples = sample_diverse_generator(10000)
    axes[2].hist(diverse_samples, bins=80, density=True, alpha=0.7, 
                 color='green', edgecolor='black')
    axes[2].set_title('Successful GAN (All Modes)', fontweight='bold', fontsize=12)
    axes[2].set_xlabel('x')
    axes[2].set_ylabel('Density')
    axes[2].set_xlim([-6, 6])
    axes[2].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print("\n" + "="*60)
    print("MODE COLLAPSE ANALYSIS")
    print("="*60)
    print(f"True distribution: 5 modes at positions {[-4, -2, 0, 2, 4]}")
    print(f"Collapsed GAN: Captures only 1 mode (20% coverage)")
    print(f"Successful GAN: Captures all 5 modes (100% coverage)")
    print("\nMetrics:")
    print(f"  - Diversity (std): True={true_samples.std():.2f}, "
          f"Collapsed={collapsed_samples.std():.2f}, "
          f"Diverse={diverse_samples.std():.2f}")
    print("="*60)

demonstrate_mode_collapse()
# Visualize optimal discriminator
def optimal_discriminator(x, p_data_func, p_g_func):
    """Compute D*(x) = p_data(x) / (p_data(x) + p_g(x))"""
    p_data = p_data_func(x)
    p_g = p_g_func(x)
    return p_data / (p_data + p_g + 1e-10)

# Example: Gaussian distributions
x = np.linspace(-5, 5, 1000)
p_data = stats.norm(0, 1).pdf(x)  # Real data: N(0, 1)
p_g_close = stats.norm(0.2, 1.1).pdf(x)  # Generator close to real
p_g_far = stats.norm(2, 0.5).pdf(x)  # Generator far from real

D_star_close = optimal_discriminator(x, lambda x: stats.norm(0, 1).pdf(x), 
                                      lambda x: stats.norm(0.2, 1.1).pdf(x))
D_star_far = optimal_discriminator(x, lambda x: stats.norm(0, 1).pdf(x), 
                                    lambda x: stats.norm(2, 0.5).pdf(x))

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Plot 1: Distributions when G is close
axes[0].plot(x, p_data, label='$p_{data}$', linewidth=2)
axes[0].plot(x, p_g_close, label='$p_g$ (close)', linewidth=2)
axes[0].set_title('Generator Close to Data', fontweight='bold')
axes[0].set_xlabel('x')
axes[0].set_ylabel('Probability Density')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Plot 2: Optimal D when G is close
axes[1].plot(x, D_star_close, linewidth=2, color='purple')
axes[1].axhline(y=0.5, color='r', linestyle='--', label='Perfect ($D^* = 0.5$)')
axes[1].set_title('$D^*$ when $p_g$ ≈ $p_{data}$', fontweight='bold')
axes[1].set_xlabel('x')
axes[1].set_ylabel('$D^*(x)$')
axes[1].set_ylim([0, 1])
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Plot 3: Optimal D when G is far
axes[2].plot(x, D_star_far, linewidth=2, color='orange')
axes[2].axhline(y=0.5, color='r', linestyle='--', label='Perfect ($D^* = 0.5$)')
axes[2].set_title('$D^*$ when $p_g$ ≠ $p_{data}$', fontweight='bold')
axes[2].set_xlabel('x')
axes[2].set_ylabel('$D^*(x)$')
axes[2].set_ylim([0, 1])
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nObservations:")
print("- When p_g ≈ p_data: D*(x) ≈ 0.5 everywhere (cannot distinguish)")
print("- When p_g ≠ p_data: D*(x) varies (can distinguish)")

Implementing Vanilla GAN – 1D Gaussian Mixture

To build intuition before scaling to images, we start with a GAN that learns a simple 1D mixture of three Gaussians. The generator maps samples from a standard Gaussian prior \(z \sim \mathcal{N}(0, I)\) (10-dimensional in the code below) to the data space, while the discriminator receives both real samples from the target distribution and fake samples from the generator. Training alternates between two objectives: the discriminator maximizes \(\mathbb{E}[\log D(x)] + \mathbb{E}[\log(1 - D(G(z)))]\), learning to tell real from fake, while the generator minimizes \(-\mathbb{E}[\log D(G(z))]\) (the non-saturating form of the minimax objective, discussed in Section 4), learning to fool the discriminator. In 1D it is easy to watch the generated distribution gradually align with the target, which builds useful intuition before moving to high-dimensional image generation.

# Generate 1D data: Mixture of Gaussians
def generate_real_data(n_samples=1000):
    """Generate samples from mixture of 3 Gaussians"""
    # Mix components
    components = np.random.choice(3, n_samples, p=[0.3, 0.5, 0.2])
    
    samples = np.zeros(n_samples)
    samples[components == 0] = np.random.normal(-2, 0.5, (components == 0).sum())
    samples[components == 1] = np.random.normal(0, 0.8, (components == 1).sum())
    samples[components == 2] = np.random.normal(3, 0.3, (components == 2).sum())
    
    return torch.FloatTensor(samples).unsqueeze(1)

# Visualize real data
real_data = generate_real_data(5000)
plt.figure(figsize=(10, 4))
plt.hist(real_data.numpy(), bins=100, density=True, alpha=0.7, edgecolor='black')
plt.title('Real Data Distribution: Mixture of 3 Gaussians', fontsize=13, fontweight='bold')
plt.xlabel('x')
plt.ylabel('Density')
plt.grid(True, alpha=0.3)
plt.show()
# Define Generator
class Generator(nn.Module):
    def __init__(self, noise_dim=10, hidden_dim=128):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # Output: 1D
        )
    
    def forward(self, z):
        return self.net(z)

# Define Discriminator
class Discriminator(nn.Module):
    def __init__(self, hidden_dim=128):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden_dim),  # Input: 1D
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()  # Output: probability
        )
    
    def forward(self, x):
        return self.net(x)

# Initialize models
noise_dim = 10
G = Generator(noise_dim=noise_dim).to(device)
D = Discriminator().to(device)

print("Generator:")
print(G)
print(f"\nTotal parameters: {sum(p.numel() for p in G.parameters())}")

print("\nDiscriminator:")
print(D)
print(f"\nTotal parameters: {sum(p.numel() for p in D.parameters())}")
# Training function
def train_gan(G, D, n_epochs=5000, batch_size=256, lr=0.0002):
    """Train the GAN with BCE losses; each 'epoch' is a single minibatch update."""
    
    # Optimizers
    optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    
    # Loss function
    criterion = nn.BCELoss()
    
    # Training history
    history = {'D_loss': [], 'G_loss': [], 'D_real': [], 'D_fake': []}
    
    for epoch in range(n_epochs):
        # Generate real data batch
        real_data = generate_real_data(batch_size).to(device)
        
        # Labels
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)
        
        # =====================
        # Train Discriminator
        # =====================
        optimizer_D.zero_grad()
        
        # Real data
        D_real = D(real_data)
        loss_D_real = criterion(D_real, real_labels)
        
        # Fake data
        z = torch.randn(batch_size, noise_dim).to(device)
        fake_data = G(z).detach()  # Detach to avoid training G
        D_fake = D(fake_data)
        loss_D_fake = criterion(D_fake, fake_labels)
        
        # Total D loss
        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        optimizer_D.step()
        
        # ==================
        # Train Generator
        # ==================
        optimizer_G.zero_grad()
        
        # Generate fake data
        z = torch.randn(batch_size, noise_dim).to(device)
        fake_data = G(z)
        D_fake = D(fake_data)
        
        # Non-saturating G loss: BCE against "real" labels, i.e. -log D(G(z))
        loss_G = criterion(D_fake, real_labels)
        loss_G.backward()
        optimizer_G.step()
        
        # Record history
        history['D_loss'].append(loss_D.item())
        history['G_loss'].append(loss_G.item())
        history['D_real'].append(D_real.mean().item())
        history['D_fake'].append(D_fake.mean().item())
        
        # Print progress
        if (epoch + 1) % 500 == 0:
            print(f"Epoch [{epoch+1}/{n_epochs}] | "
                  f"D Loss: {loss_D.item():.4f} | G Loss: {loss_G.item():.4f} | "
                  f"D(real): {D_real.mean():.4f} | D(fake): {D_fake.mean():.4f}")
    
    return history

# Train GAN
print("Training GAN...")
history = train_gan(G, D, n_epochs=5000, batch_size=256)
print("\n✓ Training complete!")
# Plot training history
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Loss curves
axes[0].plot(history['D_loss'], label='D Loss', alpha=0.7)
axes[0].plot(history['G_loss'], label='G Loss', alpha=0.7)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Losses', fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Discriminator outputs
axes[1].plot(history['D_real'], label='D(real)', alpha=0.7)
axes[1].plot(history['D_fake'], label='D(fake)', alpha=0.7)
axes[1].axhline(y=0.5, color='r', linestyle='--', alpha=0.5, label='Equilibrium')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Discriminator Output')
axes[1].set_title('Discriminator Performance', fontweight='bold')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
# Generate and visualize samples
G.eval()
with torch.no_grad():
    z = torch.randn(5000, noise_dim).to(device)
    generated_data = G(z).cpu()

# Plot real vs generated
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.hist(real_data.numpy(), bins=100, density=True, alpha=0.7, label='Real', edgecolor='black')
plt.title('Real Data Distribution', fontsize=13, fontweight='bold')
plt.xlabel('x')
plt.ylabel('Density')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.hist(generated_data.numpy(), bins=100, density=True, alpha=0.7, color='orange', label='Generated', edgecolor='black')
plt.title('Generated Data Distribution', fontsize=13, fontweight='bold')
plt.xlabel('x')
plt.ylabel('Density')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n✓ Sampling complete: compare the two histograms to check that every mode of the mixture is covered.")

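4. Training Challenges

Common Issues

  1. Mode Collapse: The generator produces only a few kinds of samples (e.g. it covers a single mode of a mixture)

  2. Non-Convergence: Losses oscillate instead of settling toward an equilibrium

  3. Vanishing Gradients: G can't learn when D is too strong

  4. Hyperparameter Sensitivity: Results depend strongly on the learning rate and architecture choices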
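Mode collapse is easy to spot in a 1D example by counting how many generated samples land near each mixture component. The sketch below uses a hypothetical helper, `mode_coverage` (not part of any library); the mode centers and radius are assumptions that you would replace with the component means used in `generate_real_data`.

```python
from typing import Dict, List, Sequence

def mode_coverage(samples: Sequence[float], modes: List[float], radius: float) -> Dict[float, float]:
    """Fraction of samples within `radius` of each assumed mode center.

    A mode whose fraction is close to 0 was probably dropped by the generator.
    """
    n = len(samples)
    return {m: sum(1 for s in samples if abs(s - m) <= radius) / n for m in modes}

# Hypothetical mode centers for a two-component mixture.
assumed_modes = [-2.0, 2.0]

# A collapsed generator: every sample sits near one mode.
collapsed = [2.0 + 0.01 * k for k in range(-10, 11)]
# A healthy generator: samples are split across both modes.
healthy = collapsed + [-2.0 + 0.01 * k for k in range(-10, 11)]

print(mode_coverage(collapsed, assumed_modes, radius=1.0))
print(mode_coverage(healthy, assumed_modes, radius=1.0))
```

Assuming `generated_data` holds one value per sample, `mode_coverage(generated_data.squeeze(-1).tolist(), modes, radius)` would apply the same check to the trained generator's output.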

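Alternative Loss: Non-Saturating Loss

Problem: The original generator loss saturates when D confidently rejects generated samples:

\[\nabla_G \log(1 - D(G(z))) \rightarrow 0 \text{ when } D(G(z)) \rightarrow 0\]

With a sigmoid output \(D = \sigma(a)\), the derivative with respect to the logit is \(\frac{\partial}{\partial a}\log(1 - \sigma(a)) = -\sigma(a)\), which vanishes as \(D(G(z)) \rightarrow 0\). This is exactly the situation early in training, when generated samples are poor.

Solution: Maximize \(\log D(G(z))\) instead:

\[\max_G \mathbb{E}_{z \sim p_z}[\log D(G(z))]\]

Here \(\frac{\partial}{\partial a}\log \sigma(a) = 1 - \sigma(a) \rightarrow 1\) as \(D(G(z)) \rightarrow 0\), so this objective keeps the same fixed point as the minimax game but gives much stronger gradients early in training.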
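To see the saturation numerically, the sketch below writes \(D(G(z)) = \sigma(a)\) for a scalar logit \(a\) and compares the derivatives of the two generator terms with respect to \(a\). It relies only on the closed forms \(\partial_a \log(1-\sigma(a)) = -\sigma(a)\) and \(\partial_a \log\sigma(a) = 1-\sigma(a)\); the function names are illustrative.

```python
import math

def sigmoid(a: float) -> float:
    """Logistic function mapping a discriminator logit a to D = sigma(a)."""
    return 1.0 / (1.0 + math.exp(-a))

def grad_saturating(a: float) -> float:
    """d/da of log(1 - sigma(a)): the original minimax generator term."""
    return -sigmoid(a)

def grad_non_saturating(a: float) -> float:
    """d/da of log(sigma(a)): the non-saturating generator term."""
    return 1.0 - sigmoid(a)

# Very negative logits mean D confidently labels the generated sample as fake.
for a in [-8.0, -4.0, 0.0, 4.0]:
    print(f"D(G(z))={sigmoid(a):.4f}  "
          f"|grad saturating|={abs(grad_saturating(a)):.4f}  "
          f"|grad non-saturating|={abs(grad_non_saturating(a)):.4f}")
```

For strongly negative logits the first gradient is close to zero while the second is close to one. In PyTorch, BCE against a target of 1 computes \(-\log D(G(z))\), which is why the generator steps in this notebook (e.g. `torch.ones` targets in `ImprovedGANTrainer`) already use the non-saturating form.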

4.5. Advanced Training Techniques

Technique 1: Label Smoothing

Replace hard labels with soft targets:

  • Real labels: \(0.9\) instead of \(1.0\)

  • Fake labels: \(0.1\) instead of \(0.0\)

Benefit: Reduces discriminator overconfidence, since BCE against a target of \(0.9\) is minimized at \(D(x) = 0.9\) instead of pushing \(D(x)\) all the way to \(1\)
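A quick way to see why smoothing curbs overconfidence: binary cross-entropy against a target \(y\) is minimized at \(D = y\). The grid search below uses hypothetical helpers (`bce`, `best_prediction`, not library functions) to find that minimizer for a hard target of \(1.0\) and a smoothed target of \(0.9\).

```python
import math

def bce(d: float, y: float) -> float:
    """Binary cross-entropy for one prediction d in (0, 1) and target y."""
    return -(y * math.log(d) + (1.0 - y) * math.log(1.0 - d))

def best_prediction(y: float, steps: int = 999) -> float:
    """Grid-search the D output in (0, 1) that minimizes BCE against target y."""
    grid = [(i + 1) / (steps + 1) for i in range(steps)]
    return min(grid, key=lambda d: bce(d, y))

print(best_prediction(1.0))  # largest grid value: D is pushed toward 1
print(best_prediction(0.9))  # the minimum sits at the smoothed target
```

With a hard target the optimum lies at the boundary, so D's logit is driven toward \(+\infty\); with a target of \(0.9\) the optimal logit stays finite (\(\log 9 \approx 2.2\)).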

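Technique 2: One-Sided Label Smoothing

Only smooth the real labels (recommended):

  • Real: \(0.9\)

  • Fake: \(0.0\)

Rationale: With real targets \(\alpha\) and fake targets \(\beta\), the optimal discriminator becomes \(D^*(x) = \frac{\alpha\, p_{data}(x) + \beta\, p_g(x)}{p_{data}(x) + p_g(x)}\). When \(\beta > 0\), the \(p_g\) term in the numerator means that in regions where \(p_{data} \approx 0\) but \(p_g\) is large, erroneous generated samples have little incentive to move toward the data. Keeping \(\beta = 0\) removes that term.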
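The effect of the fake target \(\beta\) can be checked pointwise. Minimizing the expected BCE \(p_{data}(x)\,\ell(D, \alpha) + p_g(x)\,\ell(D, \beta)\) at a single \(x\) gives \(D^*(x) = \frac{\alpha\, p_{data}(x) + \beta\, p_g(x)}{p_{data}(x) + p_g(x)}\). The sketch below evaluates that formula; the density values are made up for illustration and `smoothed_optimal_d` is a hypothetical helper.

```python
def smoothed_optimal_d(p_data: float, p_g: float,
                       alpha: float = 0.9, beta: float = 0.0) -> float:
    """Pointwise optimal discriminator when real targets are alpha and fake targets are beta."""
    return (alpha * p_data + beta * p_g) / (p_data + p_g)

# Illustrative densities: first a point where only the generator places mass
# (p_data = 0), then points where the two distributions overlap.
for p_data, p_g in [(0.0, 1.0), (0.5, 1.0), (1.0, 1.0)]:
    one_sided = smoothed_optimal_d(p_data, p_g, alpha=0.9, beta=0.0)
    two_sided = smoothed_optimal_d(p_data, p_g, alpha=0.9, beta=0.1)
    print(f"p_data={p_data}, p_g={p_g}: "
          f"one-sided D*={one_sided:.3f}, two-sided D*={two_sided:.3f}")
```

Where \(p_{data} = 0\), two-sided smoothing leaves \(D^* = \beta\) instead of \(0\), so D no longer flags those generated samples as clearly fake. With \(\alpha = 1, \beta = 0\) the formula reduces to the standard \(D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}\).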

Technique 3: Feature Matching

Instead of maximizing \(\log D(G(z))\) directly, train G to match the statistics of real data in D's feature space:

\[\min_G \left\|\mathbb{E}_{x \sim p_{data}}[f(x)] - \mathbb{E}_{z \sim p_z}[f(G(z))]\right\|^2_2\]

where \(f(x)\) denotes the activations of an intermediate layer of \(D\).
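Feature matching compares the batch means of D's features. The sketch below computes the squared L2 distance between those means for features given as plain lists; in a real model each row would be the intermediate-layer activation \(f(x)\) of \(D\) for one sample, which is an assumption about how you would extract it. The helper names are hypothetical.

```python
from typing import List, Sequence

def mean_features(batch: Sequence[Sequence[float]]) -> List[float]:
    """Per-dimension mean of a batch of feature vectors."""
    n = len(batch)
    return [sum(row[k] for row in batch) / n for k in range(len(batch[0]))]

def feature_matching_loss(real_feats: Sequence[Sequence[float]],
                          fake_feats: Sequence[Sequence[float]]) -> float:
    """Squared L2 distance between mean real features and mean generated features."""
    mu_real = mean_features(real_feats)
    mu_fake = mean_features(fake_feats)
    return sum((r - f) ** 2 for r, f in zip(mu_real, mu_fake))

real = [[0.0, 0.0], [2.0, 2.0]]
fake_same_mean = [[1.0, 1.0], [1.0, 1.0]]  # same mean, no spread
fake_shifted = [[1.0, 2.0], [1.0, 2.0]]    # mean off by 1 in the second feature
print(feature_matching_loss(real, fake_same_mean))
print(feature_matching_loss(real, fake_shifted))
```

The first case gives zero loss even though the fake batch has no spread: matching first moments alone does not force the variances to match.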

Technique 4: Minibatch Discrimination

Let the discriminator look at relationships between samples in a batch instead of judging each sample in isolation:

  • Discourages mode collapse by rewarding diversity within a batch

  • The discriminator can detect “collapse” when all samples in a batch are similar, which a per-sample discriminator cannot see
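The core statistic can be sketched without a neural network: for each sample, sum a similarity kernel \(\exp(-\lVert f_i - f_j \rVert_1)\) over the other samples in the batch. This is a simplified illustration under stated assumptions; the original minibatch discrimination method additionally projects features through a learned tensor and concatenates the resulting statistics to D's intermediate features, which this sketch omits. `minibatch_similarity` is a hypothetical name.

```python
import math
from typing import List, Sequence

def minibatch_similarity(features: Sequence[Sequence[float]]) -> List[float]:
    """For each sample i, sum exp(-L1 distance) to every other sample j != i.

    Large values mean sample i has many near-duplicates in the batch.
    """
    scores = []
    for i, f_i in enumerate(features):
        total = 0.0
        for j, f_j in enumerate(features):
            if i != j:
                l1 = sum(abs(a - b) for a, b in zip(f_i, f_j))
                total += math.exp(-l1)
        scores.append(total)
    return scores

collapsed_batch = [[1.0, 2.0]] * 4              # every sample identical
diverse_batch = [[0.0], [5.0], [10.0], [15.0]]  # well-separated samples
print(minibatch_similarity(collapsed_batch))
print(minibatch_similarity(diverse_batch))
```

A discriminator that receives these scores alongside each sample's features can reject batches with abnormally high similarity, which pushes the generator toward diverse outputs.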

# Implementation: Label Smoothing (one- or two-sided)
class ImprovedGANTrainer:
    """
    GAN trainer with optional label smoothing for the discriminator targets
    """
    
    def __init__(self, G, D, lr=0.0002, betas=(0.5, 0.999)):
        self.G = G
        self.D = D
        self.optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=betas)
        self.optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=betas)
        self.criterion = nn.BCELoss()
        
    def train_step(self, real_data, noise_dim, 
                   label_smoothing=True, smooth_real=0.9, smooth_fake=0.0):
        """
        Single alternating update (one D step, then one G step) with optional label smoothing
        
        Args:
            real_data: Batch of real samples
            noise_dim: Dimension of noise vector
            label_smoothing: Whether to use label smoothing
            smooth_real: Label value for real samples (default 0.9)
            smooth_fake: Label value for fake samples (default 0.0)
        """
        batch_size = real_data.size(0)
        device = real_data.device
        
        # Create labels with optional smoothing
        if label_smoothing:
            real_labels = torch.full((batch_size, 1), smooth_real, device=device)
            fake_labels = torch.full((batch_size, 1), smooth_fake, device=device)
        else:
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)
        
        # =====================
        # Train Discriminator
        # =====================
        self.optimizer_D.zero_grad()
        
        # Real samples
        D_real = self.D(real_data)
        loss_D_real = self.criterion(D_real, real_labels)
        
        # Fake samples
        z = torch.randn(batch_size, noise_dim, device=device)
        fake_data = self.G(z).detach()
        D_fake = self.D(fake_data)
        loss_D_fake = self.criterion(D_fake, fake_labels)
        
        # Combined loss
        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        self.optimizer_D.step()
        
        # ==================
        # Train Generator
        # ==================
        self.optimizer_G.zero_grad()
        
        # Generate new fake samples
        z = torch.randn(batch_size, noise_dim, device=device)
        fake_data = self.G(z)
        D_fake = self.D(fake_data)
        
        # Generator loss (non-saturating): targets stay at 1.0 even when D's real labels are smoothed
        loss_G = self.criterion(D_fake, torch.ones(batch_size, 1, device=device))
        loss_G.backward()
        self.optimizer_G.step()
        
        return {
            'loss_D': loss_D.item(),
            'loss_G': loss_G.item(),
            'D_real': D_real.mean().item(),
            'D_fake': D_fake.mean().item()
        }

# Compare vanilla vs improved training
print("="*60)
print("COMPARISON: Vanilla vs Improved GAN Training")
print("="*60)

# Reset models
G_improved = Generator(noise_dim=10).to(device)
D_improved = Discriminator().to(device)

trainer = ImprovedGANTrainer(G_improved, D_improved)

# Train for fewer iterations than the vanilla run (2,000 vs 5,000)
n_epochs = 2000
history_improved = {'D_loss': [], 'G_loss': [], 'D_real': [], 'D_fake': []}

print("\nTraining with one-sided label smoothing...")
for epoch in range(n_epochs):
    real_batch = generate_real_data(256).to(device)
    
    # Train with label smoothing
    metrics = trainer.train_step(real_batch, noise_dim=10, 
                                label_smoothing=True, 
                                smooth_real=0.9, 
                                smooth_fake=0.0)
    
    history_improved['D_loss'].append(metrics['loss_D'])
    history_improved['G_loss'].append(metrics['loss_G'])
    history_improved['D_real'].append(metrics['D_real'])
    history_improved['D_fake'].append(metrics['D_fake'])
    
    if (epoch + 1) % 500 == 0:
        print(f"Epoch [{epoch+1}/{n_epochs}] | "
              f"D Loss: {metrics['loss_D']:.4f} | "
              f"G Loss: {metrics['loss_G']:.4f} | "
              f"D(real): {metrics['D_real']:.4f} | "
              f"D(fake): {metrics['D_fake']:.4f}")

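print("\n✓ Improved training complete!")
print("\nExpected effects of one-sided label smoothing (compare with the vanilla history):")
print("  - Discriminator is discouraged from becoming overconfident on real data")
print("  - Training dynamics tend to be more stable")
print("  - Generator receives more useful gradients while D is not saturated")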

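5. Summary

Key Mathematical Results

✅ Optimal Discriminator: \(D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}\)

✅ Global Optimum: With \(D = D^*\), the generator's criterion is \(C(G) = -\log 4 + 2 \cdot \text{JSD}(p_{data} \,\|\, p_g)\), which is minimized exactly when \(p_g = p_{data}\)

✅ Value Function: Minimax game whose convergence guarantee assumes sufficient model capacity and that D reaches its optimum for each G

✅ Training Dynamics: Alternating gradient updates of D and G (one D step per G step in this notebook)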
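For distributions on a finite support, these results can be checked numerically: plugging \(D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}\) into \(V(D, G)\) should give \(-\log 4 + 2\,\text{JSD}(p_{data} \,\|\, p_g)\). The probability vectors below are arbitrary examples chosen for the check, and the helper names are illustrative.

```python
import math
from typing import List, Sequence

def optimal_d(p_data: Sequence[float], p_g: Sequence[float]) -> List[float]:
    """D*(x) = p_data(x) / (p_data(x) + p_g(x)); assumes p_data(x) + p_g(x) > 0 everywhere."""
    return [pd / (pd + pg) for pd, pg in zip(p_data, p_g)]

def value_at_optimal_d(p_data: Sequence[float], p_g: Sequence[float]) -> float:
    """V(D*, G) = sum_x p_data log D*(x) + p_g log(1 - D*(x)); zero-mass terms are skipped."""
    total = 0.0
    for pd, pg, d in zip(p_data, p_g, optimal_d(p_data, p_g)):
        if pd > 0:
            total += pd * math.log(d)
        if pg > 0:
            total += pg * math.log(1.0 - d)
    return total

def kl(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for discrete distributions, skipping points where p = 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def jsd(p: Sequence[float], q: Sequence[float]) -> float:
    """Jensen-Shannon divergence with natural logarithms."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p_data = [0.5, 0.3, 0.2]
p_g = [0.2, 0.2, 0.6]
print(value_at_optimal_d(p_data, p_g), -math.log(4) + 2 * jsd(p_data, p_g))
print(value_at_optimal_d(p_data, p_data), -math.log(4))
```

When \(p_g = p_{data}\), \(D^* = 1/2\) at every point, which is the dashed "Equilibrium" line at 0.5 in the discriminator-output plot.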

Practical Implementation

✅ Implemented vanilla GAN from scratch

✅ Trained on 1D mixture distribution

✅ Visualized learning dynamics

✅ Understood training challenges

Next Topics

  • W-GAN: Wasserstein distance for better stability

  • Info-GAN: Interpretable latent codes

  • Conditional GAN: Class-conditional generation

  • Progressive GAN: High-resolution images

Next Notebook: 02_wgan_theory.ipynb