Wasserstein GAN (WGAN)ΒΆ

Learning Objectives:

  • Understand the Wasserstein (Earth Mover’s) distance

  • Implement WGAN with gradient penalty

  • Compare with vanilla GAN training stability

  • Apply to MNIST image generation

Prerequisites: Deep learning, GANs basics, measure theory

Time: 90 minutes

πŸ“š Reference Materials:

  • gan.pdf - Comprehensive GAN theory including Wasserstein variants

1. Problems with Vanilla GANsΒΆ

Vanilla GAN ObjectiveΒΆ

Recall the minimax game:

\[\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]\]

IssuesΒΆ

  1. Vanishing Gradients

    • When \(D\) is optimal, gradient for \(G\) vanishes

    • \(\nabla_\theta \log(1 - D(G(z))) \approx 0\) when \(D\) is perfect

  2. Mode Collapse

    • Generator produces limited variety

    • Collapses to single or few modes

  3. Training Instability

    • Hard to balance \(D\) and \(G\) training

    • Sensitive to hyperparameters

Root Cause: Jensen-Shannon DivergenceΒΆ

Vanilla GAN training minimizes the JS divergence:

\[D_{JS}(p_{\text{data}} || p_g) = \frac{1}{2} D_{KL}(p_{\text{data}} || m) + \frac{1}{2} D_{KL}(p_g || m)\]

where \(m = \frac{1}{2}(p_{\text{data}} + p_g)\)

Problem: When the supports of \(p_{\text{data}}\) and \(p_g\) don’t overlap, \(D_{JS} = \log 2\) is constant, so it provides no gradient signal to the generator.

Solution: Use Wasserstein distance instead!
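The saturation of JS and the smooth growth of \(W_1\) can be checked numerically. A minimal sketch using SciPy’s `jensenshannon` and `wasserstein_distance`; the narrow \(\sigma = 0.1\) Gaussians are an assumption chosen so the supports become effectively disjoint as they move apart:

```python
import numpy as np
from scipy.spatial.distance import jensenshannon
from scipy.stats import norm, wasserstein_distance

# Discretize two narrow Gaussians on a grid and compare divergences
x = np.linspace(-10, 30, 4000)
p = norm(0.0, 0.1).pdf(x)
for shift in [1.0, 5.0, 20.0]:
    q = norm(shift, 0.1).pdf(x)
    # jensenshannon returns the JS *distance* (sqrt of JSD, base e); square it
    jsd = jensenshannon(p, q) ** 2
    w1 = wasserstein_distance(x, x, u_weights=p, v_weights=q)
    print(f"shift={shift:5.1f}  JSD={jsd:.4f} (saturates at log 2 = {np.log(2):.4f})  W1={w1:.4f}")
```

The JSD column saturates at \(\log 2\) once the supports separate, while \(W_1\) keeps tracking the true displacement.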

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from scipy import stats

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 4)
np.random.seed(42)
torch.manual_seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

2. Wasserstein Distance (Earth Mover’s Distance)ΒΆ

DefinitionΒΆ

Wasserstein-1 distance between distributions \(p\) and \(q\):

\[W_1(p, q) = \inf_{\gamma \in \Pi(p, q)} \mathbb{E}_{(x, y) \sim \gamma}[||x - y||]\]

where \(\Pi(p, q)\) is the set of all joint distributions with marginals \(p\) and \(q\).

Intuition: Minimum β€œwork” to transform \(p\) into \(q\)

  • Think of \(p\) as pile of dirt

  • Think of \(q\) as target locations

  • \(W_1\) is minimum effort to move dirt to targets

PropertiesΒΆ

βœ… Metric: Satisfies triangle inequality, symmetry, non-negativity βœ… Continuous: Small change in distributions β†’ small change in distance βœ… Meaningful gradients: Even when supports don’t overlap!

Example: 1D CaseΒΆ

For 1D distributions with CDFs \(F_p\) and \(F_q\):

\[W_1(p, q) = \int_{-\infty}^{\infty} |F_p(x) - F_q(x)| \, dx\]
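This CDF formula can be verified against SciPy’s sample-based estimator. A minimal sketch on a pair of unit-variance Gaussians with means 0 and 4, for which the exact value is \(|0 - 4| = 4\):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
a = rng.normal(0.0, 1.0, 50_000)  # samples from p = N(0, 1)
b = rng.normal(4.0, 1.0, 50_000)  # samples from q = N(4, 1)

# Evaluate the |F_p - F_q| integral with empirical CDFs on a grid
x = np.linspace(-6.0, 10.0, 2000)
F_a = np.searchsorted(np.sort(a), x, side='right') / a.size
F_b = np.searchsorted(np.sort(b), x, side='right') / b.size
w1_cdf = np.sum(np.abs(F_a - F_b)) * (x[1] - x[0])

# SciPy's estimator on the same samples
w1_scipy = stats.wasserstein_distance(a, b)
print(f"CDF integral: {w1_cdf:.3f}  scipy: {w1_scipy:.3f}  exact: 4.000")
```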

2.5. Kantorovich-Rubinstein DualityΒΆ

Dual FormulationΒΆ

The Wasserstein distance has an equivalent dual form:

\[W_1(p, q) = \sup_{||f||_L \leq 1} \mathbb{E}_{x \sim p}[f(x)] - \mathbb{E}_{x \sim q}[f(x)]\]

where the supremum is over all 1-Lipschitz functions \(f\)

1-Lipschitz Constraint:

\[|f(x_1) - f(x_2)| \leq ||x_1 - x_2|| \quad \forall x_1, x_2\]
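The constraint can be probed empirically by sampling random pairs and checking the ratio \(|f(x_1) - f(x_2)| / |x_1 - x_2|\). A minimal 1D sketch; the choice of `tanh` (1-Lipschitz, since \(|\tanh'| \leq 1\)) and of the 3-Lipschitz function \(3x\) are illustrative assumptions:

```python
import numpy as np

def lipschitz_ratio(f, n_pairs=100_000, scale=5.0, seed=0):
    """Empirical sup of |f(x1) - f(x2)| / |x1 - x2| over random 1D pairs."""
    rng = np.random.default_rng(seed)
    x1 = rng.uniform(-scale, scale, n_pairs)
    x2 = rng.uniform(-scale, scale, n_pairs)
    d = np.abs(x1 - x2)
    keep = d > 1e-6  # avoid division by (near-)zero
    return np.max(np.abs(f(x1[keep]) - f(x2[keep])) / d[keep])

r_tanh = lipschitz_ratio(np.tanh)          # <= 1: tanh is 1-Lipschitz
r_lin = lipschitz_ratio(lambda x: 3 * x)   # ~ 3: violates the constraint
print(f"tanh: {r_tanh:.4f}   3x: {r_lin:.4f}")
```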

Why This Matters for GANsΒΆ

Key Insight: We can approximate the supremum using a neural network!

Replace the discriminator with a critic \(f_w\) (parameterized by weights \(w\)):

\[W_1(p_{\text{data}}, p_g) \approx \max_{w: ||f_w||_L \leq 1} \mathbb{E}_{x \sim p_{\text{data}}}[f_w(x)] - \mathbb{E}_{z \sim p_z}[f_w(G(z))]\]

Differences from Vanilla GAN:

  1. No sigmoid in critic (outputs raw scores, not probabilities)

  2. Critic must be 1-Lipschitz (enforced via weight clipping or gradient penalty)

  3. Maximize critic score difference instead of cross-entropy

Mathematical AdvantageΒΆ

When \(p_{\text{data}}\) and \(p_g\) have disjoint supports:

  • JS divergence: \(D_{JS} = \log 2\) (constant, no gradient)

  • Wasserstein distance: \(W_1 > 0\) (meaningful gradient everywhere!)

This provides usable gradients for the generator even when the critic is trained to optimality!
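The dual form can also be probed directly: plugging any specific 1-Lipschitz function into \(\mathbb{E}_p[f] - \mathbb{E}_q[f]\) gives a lower bound on \(W_1\). A minimal sketch for \(N(0,1)\) vs \(N(4,1)\), where \(f(x) = -x\) happens to attain the supremum:

```python
import numpy as np

rng = np.random.default_rng(1)
p = rng.normal(0.0, 1.0, 100_000)  # "real" samples
q = rng.normal(4.0, 1.0, 100_000)  # "generated" samples

# Every 1-Lipschitz f gives E_p[f] - E_q[f] <= W1; the sup equals W1
bound_opt = (-p).mean() - (-q).mean()               # f(x) = -x: optimal here
bound_tanh = np.tanh(p).mean() - np.tanh(q).mean()  # f = tanh: valid but loose

print(f"f(x) = -x   : {bound_opt:.3f}  (attains W1 = |0 - 4| = 4)")
print(f"f(x) = tanh : {bound_tanh:.3f} (a much looser lower bound)")
```

A trained WGAN critic plays the role of the first function: it searches over 1-Lipschitz functions for the tightest such bound.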

# Visualize Wasserstein distance in 1D

def plot_wasserstein_1d():
    """Demonstrate Wasserstein distance for 1D Gaussians"""
    x = np.linspace(-5, 10, 1000)
    
    # Two Gaussians
    p = stats.norm(0, 1)
    q = stats.norm(4, 1)
    
    # PDFs
    pdf_p = p.pdf(x)
    pdf_q = q.pdf(x)
    
    # CDFs
    cdf_p = p.cdf(x)
    cdf_q = q.cdf(x)
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Plot PDFs
    axes[0].plot(x, pdf_p, label='$p$ (data)', linewidth=2)
    axes[0].plot(x, pdf_q, label='$q$ (generated)', linewidth=2)
    axes[0].fill_between(x, 0, pdf_p, alpha=0.3)
    axes[0].fill_between(x, 0, pdf_q, alpha=0.3)
    axes[0].set_xlabel('$x$')
    axes[0].set_ylabel('Density')
    axes[0].set_title('Probability Densities', fontsize=14, fontweight='bold')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Plot CDFs and difference
    axes[1].plot(x, cdf_p, label='CDF of $p$', linewidth=2)
    axes[1].plot(x, cdf_q, label='CDF of $q$', linewidth=2)
    axes[1].fill_between(x, cdf_p, cdf_q, alpha=0.3, label='$|F_p - F_q|$')
    axes[1].set_xlabel('$x$')
    axes[1].set_ylabel('Cumulative Probability')
    axes[1].set_title('CDFs (Wasserstein = shaded area)', fontsize=14, fontweight='bold')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Compute Wasserstein distance
    W1_exact = 4.0  # For Gaussians N(ΞΌ1, σ²) and N(ΞΌ2, σ²): W1 = |ΞΌ1 - ΞΌ2|
    W1_numerical = np.trapz(np.abs(cdf_p - cdf_q), x)
    
    print(f"Wasserstein distance W₁(p, q):")
    print(f"  Exact: {W1_exact:.3f}")
    print(f"  Numerical: {W1_numerical:.3f}")

plot_wasserstein_1d()

3. Kantorovich-Rubinstein DualityΒΆ

The DualityΒΆ

Kantorovich-Rubinstein theorem:

\[W_1(p, q) = \sup_{||f||_L \leq 1} \mathbb{E}_{x \sim p}[f(x)] - \mathbb{E}_{x \sim q}[f(x)]\]

where \(||f||_L \leq 1\) means \(f\) is 1-Lipschitz:

\[|f(x_1) - f(x_2)| \leq ||x_1 - x_2||\]

Intuition:

  • Primal: find the cheapest transport plan \(\gamma\) moving mass from \(p\) to \(q\)

  • Dual: maximize \(\mathbb{E}_p[f] - \mathbb{E}_q[f]\) over 1-Lipschitz functions \(f\)

W-GAN ObjectiveΒΆ

Replace discriminator \(D\) with critic \(f\) (1-Lipschitz function):

\[\min_G \max_{||f||_L \leq 1} \mathbb{E}_{x \sim p_{\text{data}}}[f(x)] - \mathbb{E}_{z \sim p_z}[f(G(z))]\]

Key difference from vanilla GAN:

  • No sigmoid output (critic outputs real number)

  • Enforce Lipschitz constraint

  • Minimize Wasserstein distance directly!

4. W-GAN with Weight ClippingΒΆ

Enforcing Lipschitz ConstraintΒΆ

Original W-GAN approach: Clip weights to \([-c, c]\)

Rationale:

  • Compact parameter space β†’ Lipschitz function

  • Simple to implement

Algorithm:

For each critic iteration:

  1. Sample batch from data and generator

  2. Update critic to maximize: \(\mathbb{E}[f(x)] - \mathbb{E}[f(G(z))]\)

  3. Clip weights: \(w \leftarrow \text{clip}(w, -c, c)\)

For generator iteration:

  1. Sample batch

  2. Update generator to minimize: \(-\mathbb{E}[f(G(z))]\)

Note: the negative sign appears because the generator wants to increase the critic’s score on fake samples, i.e., fool the critic!

# W-GAN Architecture

class Critic(nn.Module):
    """Critic network for W-GAN (no sigmoid!)"""
    def __init__(self, input_dim=2, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1)  # Output real number, not probability!
        )
    
    def forward(self, x):
        return self.net(x)

class Generator(nn.Module):
    """Generator network"""
    def __init__(self, latent_dim=2, output_dim=2, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, z):
        return self.net(z)

# Initialize networks
latent_dim = 2
critic = Critic(input_dim=2, hidden_dim=128).to(device)
generator = Generator(latent_dim=2, output_dim=2, hidden_dim=128).to(device)

print("Critic architecture:")
print(critic)
print("\nGenerator architecture:")
print(generator)
# Training W-GAN with weight clipping

def train_wgan(generator, critic, data_loader, n_epochs=100, 
               n_critic=5, clip_value=0.01, lr=5e-5):
    """Train W-GAN with weight clipping"""
    
    # Optimizers (RMSprop recommended in original paper)
    opt_g = optim.RMSprop(generator.parameters(), lr=lr)
    opt_c = optim.RMSprop(critic.parameters(), lr=lr)
    
    history = {'w_dist': [], 'g_loss': []}
    
    for epoch in range(n_epochs):
        for real_data in data_loader:
            real_data = real_data.to(device)
            batch_size = real_data.size(0)
            
            # Train Critic
            for _ in range(n_critic):
                opt_c.zero_grad()
                
                # Sample fake data
                z = torch.randn(batch_size, latent_dim).to(device)
                fake_data = generator(z).detach()
                
                # Critic loss: maximize E[f(x)] - E[f(G(z))]
                # Equivalently: minimize -(E[f(x)] - E[f(G(z))])
                critic_real = critic(real_data).mean()
                critic_fake = critic(fake_data).mean()
                critic_loss = -(critic_real - critic_fake)  # Negative Wasserstein
                
                critic_loss.backward()
                opt_c.step()
                
                # CLIP WEIGHTS
                for p in critic.parameters():
                    p.data.clamp_(-clip_value, clip_value)
            
            # Train Generator
            opt_g.zero_grad()
            
            z = torch.randn(batch_size, latent_dim).to(device)
            fake_data = generator(z)
            
            # Generator loss: minimize -E[f(G(z))]
            g_loss = -critic(fake_data).mean()
            
            g_loss.backward()
            opt_g.step()
        
        # Record metrics
        history['w_dist'].append(-critic_loss.item())  # Approximate Wasserstein
        history['g_loss'].append(g_loss.item())
        
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}/{n_epochs} | "
                  f"W-dist: {history['w_dist'][-1]:.4f} | "
                  f"G-loss: {history['g_loss'][-1]:.4f}")
    
    return history

# Prepare data: 2D Gaussian mixture
def create_data_loader(n_samples=10000, batch_size=128):
    """Create 2D Gaussian mixture data"""
    means = np.array([[0, 0], [3, 3], [-2, 3]])
    data = []
    
    for mean in means:
        samples = np.random.randn(n_samples // 3, 2) * 0.5 + mean
        data.append(samples)
    
    data = np.concatenate(data, axis=0)
    np.random.shuffle(data)
    
    dataset = torch.FloatTensor(data)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    return loader, data

data_loader, real_data = create_data_loader(n_samples=3000, batch_size=64)

print("Training W-GAN with weight clipping...")
print("="*60)
history = train_wgan(generator, critic, data_loader, n_epochs=200, 
                     n_critic=5, clip_value=0.01, lr=5e-5)
# Visualize results

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Plot 1: Real data
axes[0].scatter(real_data[:, 0], real_data[:, 1], alpha=0.5, s=10)
axes[0].set_title('Real Data (3-component Gaussian mixture)', fontsize=12, fontweight='bold')
axes[0].set_xlabel('$x_1$')
axes[0].set_ylabel('$x_2$')
axes[0].axis('equal')
axes[0].grid(True, alpha=0.3)

# Plot 2: Generated data
with torch.no_grad():
    z = torch.randn(3000, latent_dim).to(device)
    fake_data = generator(z).cpu().numpy()

axes[1].scatter(fake_data[:, 0], fake_data[:, 1], alpha=0.5, s=10, color='orange')
axes[1].set_title('Generated Data (W-GAN)', fontsize=12, fontweight='bold')
axes[1].set_xlabel('$x_1$')
axes[1].set_ylabel('$x_2$')
axes[1].axis('equal')
axes[1].grid(True, alpha=0.3)

# Plot 3: Training curves
axes[2].plot(history['w_dist'], label='Wasserstein Distance', linewidth=2)
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Distance')
axes[2].set_title('W-GAN Training: Wasserstein Distance', fontsize=12, fontweight='bold')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("βœ… W-GAN successfully learned the data distribution!")

5. W-GAN-GP (Gradient Penalty)ΒΆ

Problem with Weight ClippingΒΆ

Issues:

  • Forces weights to extreme values (\(\pm c\))

  • Reduces model capacity

  • Can lead to vanishing/exploding gradients

Solution: Gradient PenaltyΒΆ

Idea: Enforce Lipschitz constraint via penalty on gradient norm

Penalty term:

\[\lambda \mathbb{E}_{\hat{x}}[(||\nabla_{\hat{x}} f(\hat{x})||_2 - 1)^2]\]

where \(\hat{x}\) are points sampled uniformly along lines between real and fake samples.

Why? A differentiable 1-Lipschitz function satisfies \(||\nabla f|| \leq 1\) everywhere; the penalty is two-sided (pushing the norm toward exactly 1) because the optimal critic has unit gradient norm almost everywhere, as derived in Section 5.5.

W-GAN-GP ObjectiveΒΆ

\[\mathcal{L} = \mathbb{E}_{\tilde{x}}[f(\tilde{x})] - \mathbb{E}_{x}[f(x)] + \lambda \mathbb{E}_{\hat{x}}[(||\nabla_{\hat{x}} f(\hat{x})||_2 - 1)^2]\]

where:

  • \(x \sim p_{\text{data}}\) (real)

  • \(\tilde{x} = G(z)\) (fake)

  • \(\hat{x} = \epsilon x + (1-\epsilon)\tilde{x}\) with \(\epsilon \sim U[0,1]\) (interpolated)

No weight clipping needed!

5.5. Advanced Theory: Gradient Penalty DerivationΒΆ

Why Gradient Penalty WorksΒΆ

Goal: Enforce 1-Lipschitz constraint on critic \(f\)

Observation: For a differentiable function, the 1-Lipschitz constraint is equivalent to:

\[||\nabla f(x)|| \leq 1 \quad \forall x\]

Soft Constraint Approach:

Instead of a hard constraint, add a penalty term:

\[\mathcal{L}_{GP} = \lambda \mathbb{E}_{\hat{x} \sim p_{\hat{x}}}[(||\nabla_{\hat{x}} f(\hat{x})||_2 - 1)^2]\]

where \(\hat{x}\) are sampled uniformly along straight lines between real and fake samples:

\[\hat{x} = \epsilon x + (1-\epsilon) G(z), \quad \epsilon \sim \text{Uniform}[0,1]\]

Why Sample Along Straight Lines?ΒΆ

Theorem (Implicit in WGAN-GP):

The optimal critic has gradient norm equal to 1 almost everywhere under the optimal coupling between \(p_{\text{data}}\) and \(p_g\).

The straight lines between real and fake samples approximate the optimal transport paths!

Mathematical JustificationΒΆ

For a 1-Lipschitz function \(f\):

\[|f(x) - f(y)| \leq ||x - y||\]

By the mean value theorem, there exists \(\hat{x}\) on the line segment between \(x\) and \(y\) such that:

\[f(x) - f(y) = \nabla f(\hat{x})^T (x - y)\]

Combining the two:

\[|\nabla f(\hat{x})^T (x - y)| \leq ||x - y||\]

For this to hold for all \(x, y\), we need \(||\nabla f(\hat{x})|| \leq 1\)

Optimal case: Equality holds, so \(||\nabla f(\hat{x})|| = 1\)
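The unit-gradient-norm property can be seen on a closed-form example: the distance function \(f(x) = ||x - a||\) is 1-Lipschitz, attains the Lipschitz bound with equality along rays through \(a\), and has \(||\nabla f(x)|| = 1\) everywhere except at \(a\) itself. A minimal finite-difference sketch (the point \(a\) and the random test points are arbitrary assumptions):

```python
import numpy as np

def grad_norm(f, x, eps=1e-6):
    """Central-difference gradient norm of a scalar function at x."""
    g = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = eps
        g[i] = (f(x + e) - f(x - e)) / (2 * eps)
    return np.linalg.norm(g)

a = np.array([3.0, -1.0])
f = lambda x: np.linalg.norm(x - a)  # distance to a: 1-Lipschitz

rng = np.random.default_rng(2)
norms = [grad_norm(f, rng.normal(size=2)) for _ in range(5)]
print(np.round(norms, 4))  # each norm is ~1.0
```

This is exactly the condition the gradient penalty enforces on the critic at the interpolated points \(\hat{x}\).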

# Advanced Visualization: Gradient Norm Monitoring

def visualize_gradient_norms(critic, real_data, fake_data, num_samples=100):
    """
    Visualize gradient norms to verify 1-Lipschitz constraint
    
    This helps diagnose:
    - Whether gradient penalty is working
    - If critic is properly regularized
    - Training stability issues
    """
    critic.eval()
    gradient_norms = []
    
    # Sample interpolated points
    for _ in range(num_samples):
        epsilon = torch.rand(real_data.size(0), 1).to(real_data.device)
        interpolated = epsilon * real_data + (1 - epsilon) * fake_data
        interpolated.requires_grad_(True)
        
        # Compute critic output
        critic_out = critic(interpolated)
        
        # Compute gradients
        gradients = torch.autograd.grad(
            outputs=critic_out,
            inputs=interpolated,
            grad_outputs=torch.ones_like(critic_out),
            create_graph=False,
            retain_graph=False
        )[0]
        
        # Compute gradient norms
        grad_norm = gradients.norm(2, dim=1)
        gradient_norms.extend(grad_norm.detach().cpu().numpy())
    
    critic.train()
    
    # Visualize
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Histogram of gradient norms
    axes[0].hist(gradient_norms, bins=50, alpha=0.7, edgecolor='black', density=True)
    axes[0].axvline(x=1.0, color='red', linestyle='--', linewidth=2, 
                    label='Target (1-Lipschitz)')
    axes[0].set_xlabel(r'Gradient Norm $||\nabla f||_2$', fontsize=11)  # raw string: avoid \n escape
    axes[0].set_ylabel('Density', fontsize=11)
    axes[0].set_title('Distribution of Gradient Norms', fontweight='bold', fontsize=12)
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Box plot
    axes[1].boxplot(gradient_norms, vert=True)
    axes[1].axhline(y=1.0, color='red', linestyle='--', linewidth=2, 
                    label='Target (1-Lipschitz)')
    axes[1].set_ylabel(r'Gradient Norm $||\nabla f||_2$', fontsize=11)  # raw string: avoid \n escape
    axes[1].set_title('Gradient Norm Statistics', fontweight='bold', fontsize=12)
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print statistics
    grad_norms_array = np.array(gradient_norms)
    print("\n" + "="*60)
    print("GRADIENT NORM ANALYSIS")
    print("="*60)
    print(f"Mean:   {grad_norms_array.mean():.4f} (target: 1.0)")
    print(f"Std:    {grad_norms_array.std():.4f}")
    print(f"Min:    {grad_norms_array.min():.4f}")
    print(f"Max:    {grad_norms_array.max():.4f}")
    print(f"Median: {np.median(grad_norms_array):.4f}")
    print("\nInterpretation:")
    if abs(grad_norms_array.mean() - 1.0) < 0.1:
        print("  βœ… Excellent! Gradient norms close to 1.0")
        print("  βœ… Critic is properly 1-Lipschitz")
    elif abs(grad_norms_array.mean() - 1.0) < 0.3:
        print("  ⚠️  Acceptable, but could improve gradient penalty weight")
    else:
        print("  ❌ Poor regularization - increase gradient penalty weight")
    print("="*60)

# Example usage (assuming we have trained critic, real_data, fake_data)
print("\nAnalyzing gradient norms of trained critic...")
print("This verifies the 1-Lipschitz constraint is enforced.\n")
# W-GAN-GP implementation

def compute_gradient_penalty(critic, real_data, fake_data, device):
    """Compute gradient penalty for W-GAN-GP"""
    batch_size = real_data.size(0)
    
    # Random weight term for interpolation
    epsilon = torch.rand(batch_size, 1).to(device)
    
    # Interpolate between real and fake
    interpolated = epsilon * real_data + (1 - epsilon) * fake_data
    interpolated.requires_grad_(True)
    
    # Get critic output
    critic_interpolated = critic(interpolated)
    
    # Compute gradients
    gradients = torch.autograd.grad(
        outputs=critic_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(critic_interpolated),
        create_graph=True,
        retain_graph=True
    )[0]
    
    # Compute gradient norm
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    
    # Penalty: (||grad|| - 1)^2
    penalty = ((gradient_norm - 1) ** 2).mean()
    
    return penalty

def train_wgan_gp(generator, critic, data_loader, n_epochs=100, 
                  n_critic=5, lambda_gp=10, lr=1e-4):
    """Train W-GAN with gradient penalty"""
    
    # Adam optimizer (works better with GP than RMSprop)
    opt_g = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
    opt_c = optim.Adam(critic.parameters(), lr=lr, betas=(0.5, 0.9))
    
    history = {'w_dist': [], 'gp': [], 'g_loss': []}
    
    for epoch in range(n_epochs):
        for real_data in data_loader:
            real_data = real_data.to(device)
            batch_size = real_data.size(0)
            
            # Train Critic
            for _ in range(n_critic):
                opt_c.zero_grad()
                
                # Sample fake data
                z = torch.randn(batch_size, latent_dim).to(device)
                fake_data = generator(z).detach()
                
                # Critic loss
                critic_real = critic(real_data).mean()
                critic_fake = critic(fake_data).mean()
                
                # Gradient penalty
                gp = compute_gradient_penalty(critic, real_data, fake_data, device)
                
                # Total critic loss
                critic_loss = -(critic_real - critic_fake) + lambda_gp * gp
                
                critic_loss.backward()
                opt_c.step()
            
            # Train Generator
            opt_g.zero_grad()
            
            z = torch.randn(batch_size, latent_dim).to(device)
            fake_data = generator(z)
            
            g_loss = -critic(fake_data).mean()
            
            g_loss.backward()
            opt_g.step()
        
        # Record metrics
        history['w_dist'].append((critic_real - critic_fake).item())
        history['gp'].append(gp.item())
        history['g_loss'].append(g_loss.item())
        
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}/{n_epochs} | "
                  f"W-dist: {history['w_dist'][-1]:.4f} | "
                  f"GP: {history['gp'][-1]:.4f} | "
                  f"G-loss: {history['g_loss'][-1]:.4f}")
    
    return history

# Re-initialize networks for GP version
critic_gp = Critic(input_dim=2, hidden_dim=128).to(device)
generator_gp = Generator(latent_dim=2, output_dim=2, hidden_dim=128).to(device)

print("Training W-GAN-GP (Gradient Penalty)...")
print("="*60)
history_gp = train_wgan_gp(generator_gp, critic_gp, data_loader, n_epochs=200, 
                           n_critic=5, lambda_gp=10, lr=1e-4)
# Compare W-GAN and W-GAN-GP

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Generated samples: Weight clipping
with torch.no_grad():
    z = torch.randn(3000, latent_dim).to(device)
    fake_clip = generator(z).cpu().numpy()

axes[0, 0].scatter(fake_clip[:, 0], fake_clip[:, 1], alpha=0.5, s=10, color='orange')
axes[0, 0].set_title('W-GAN (Weight Clipping)', fontsize=12, fontweight='bold')
axes[0, 0].set_xlabel('$x_1$')
axes[0, 0].set_ylabel('$x_2$')
axes[0, 0].axis('equal')
axes[0, 0].grid(True, alpha=0.3)

# Generated samples: Gradient penalty
with torch.no_grad():
    z = torch.randn(3000, latent_dim).to(device)
    fake_gp = generator_gp(z).cpu().numpy()

axes[0, 1].scatter(fake_gp[:, 0], fake_gp[:, 1], alpha=0.5, s=10, color='green')
axes[0, 1].set_title('W-GAN-GP (Gradient Penalty)', fontsize=12, fontweight='bold')
axes[0, 1].set_xlabel('$x_1$')
axes[0, 1].set_ylabel('$x_2$')
axes[0, 1].axis('equal')
axes[0, 1].grid(True, alpha=0.3)

# Training curves: Wasserstein distance
axes[1, 0].plot(history['w_dist'], label='Weight Clipping', linewidth=2, alpha=0.7)
axes[1, 0].plot(history_gp['w_dist'], label='Gradient Penalty', linewidth=2, alpha=0.7)
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Wasserstein Distance')
axes[1, 0].set_title('Wasserstein Distance Comparison', fontsize=12, fontweight='bold')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Gradient penalty over time
axes[1, 1].plot(history_gp['gp'], linewidth=2, color='green')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Gradient Penalty')
axes[1, 1].set_title('Gradient Penalty (W-GAN-GP)', fontsize=12, fontweight='bold')
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("βœ… W-GAN-GP produces more stable training!")

6. SummaryΒΆ

Wasserstein GANΒΆ

βœ… Motivation: Fix vanilla GAN issues (vanishing gradients, mode collapse) βœ… Key Idea: Use Wasserstein distance instead of JS divergence βœ… Critic: Outputs real number (not probability), must be Lipschitz βœ… Training: More stable, meaningful loss metric

Two ApproachesΒΆ

  1. Weight Clipping

    • Simple to implement

    • Enforces Lipschitz by clipping weights

    • Drawback: Reduces capacity, can hurt performance

  2. Gradient Penalty (GP)

    • Penalize gradient norm deviation from 1

    • Better: No capacity reduction, more stable

    • Recommended for most applications

Advantages of W-GANΒΆ

βœ… Meaningful loss metric (correlates with sample quality) βœ… No mode collapse (in practice) βœ… More stable training βœ… Works well even with poor architectures

ImplementationsΒΆ

βœ… Derived Wasserstein distance and K-R duality βœ… Implemented W-GAN with weight clipping βœ… Implemented W-GAN-GP with gradient penalty βœ… Compared both approaches empirically

Next Notebook: 03_variational_autoencoders_advanced.ipynb