Wasserstein GAN (WGAN)ΒΆ
Learning Objectives:
Understand Wasserstein distance and Earth Moverβs distance
Implement WGAN with gradient penalty
Compare with vanilla GAN training stability
Apply to MNIST image generation
Prerequisites: Deep learning, GANs basics, measure theory
Time: 90 minutes
π Reference Materials:
gan.pdf - Comprehensive GAN theory including Wasserstein variants
1. Problems with Vanilla GANsΒΆ
Vanilla GAN ObjectiveΒΆ
Recall the minimax game: $\(\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]\)$
IssuesΒΆ
Vanishing Gradients
When \(D\) is optimal, gradient for \(G\) vanishes
\(\nabla_\theta \log(1 - D(G(z))) \approx 0\) when \(D\) is perfect
Mode Collapse
Generator produces limited variety
Collapses to single or few modes
Training Instability
Hard to balance \(D\) and \(G\) training
Sensitive to hyperparameters
Root Cause: Jensen-Shannon DivergenceΒΆ
Vanilla GAN minimizes JS divergence: $\(D_{JS}(p_{\text{data}} || p_g) = \frac{1}{2} D_{KL}(p_{\text{data}} || m) + \frac{1}{2} D_{KL}(p_g || m)\)$
where \(m = \frac{1}{2}(p_{\text{data}} + p_g)\)
Problem: When supports donβt overlap, \(D_{JS} = \log 2\) (constant!)
Solution: Use Wasserstein distance instead!
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from scipy import stats
# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 4)
np.random.seed(42)
torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
2. Wasserstein Distance (Earth Moverβs Distance)ΒΆ
DefinitionΒΆ
Wasserstein-1 distance between distributions \(p\) and \(q\):
where \(\Pi(p, q)\) is the set of all joint distributions with marginals \(p\) and \(q\).
Intuition: Minimum βworkβ to transform \(p\) into \(q\)
Think of \(p\) as pile of dirt
Think of \(q\) as target locations
\(W_1\) is minimum effort to move dirt to targets
PropertiesΒΆ
β Metric: Satisfies triangle inequality, symmetry, non-negativity β Continuous: Small change in distributions β small change in distance β Meaningful gradients: Even when supports donβt overlap!
Example: 1D CaseΒΆ
For 1D distributions with CDFs \(F_p\) and \(F_q\): $\(W_1(p, q) = \int_{-\infty}^{\infty} |F_p(x) - F_q(x)| dx\)$
2.5. Kantorovich-Rubinstein DualityΒΆ
Dual FormulationΒΆ
The Wasserstein distance has an equivalent dual form:
where the supremum is over all 1-Lipschitz functions \(f\)
1-Lipschitz Constraint: $\(|f(x_1) - f(x_2)| \leq ||x_1 - x_2|| \quad \forall x_1, x_2\)$
Why This Matters for GANsΒΆ
Key Insight: We can approximate the supremum using a neural network!
Replace the discriminator with a critic \(f_w\) (parameterized by weights \(w\)): $\(W_1(p_{\text{data}}, p_g) \approx \max_{w: ||f_w||_L \leq 1} \mathbb{E}_{x \sim p_{\text{data}}}[f_w(x)] - \mathbb{E}_{z \sim p_z}[f_w(G(z))]\)$
Differences from Vanilla GAN:
No sigmoid in critic (outputs raw scores, not probabilities)
Critic must be 1-Lipschitz (enforced via weight clipping or gradient penalty)
Maximize critic score difference instead of cross-entropy
Mathematical AdvantageΒΆ
When \(p_{\text{data}}\) and \(p_g\) have disjoint supports:
JS divergence: \(D_{JS} = \log 2\) (constant, no gradient)
Wasserstein distance: \(W_1 > 0\) (meaningful gradient everywhere!)
This provides continuous gradients for the generator even when discriminator is optimal!
# Visualize Wasserstein distance in 1D
def plot_wasserstein_1d():
"""Demonstrate Wasserstein distance for 1D Gaussians"""
x = np.linspace(-5, 10, 1000)
# Two Gaussians
p = stats.norm(0, 1)
q = stats.norm(4, 1)
# PDFs
pdf_p = p.pdf(x)
pdf_q = q.pdf(x)
# CDFs
cdf_p = p.cdf(x)
cdf_q = q.cdf(x)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Plot PDFs
axes[0].plot(x, pdf_p, label='$p$ (data)', linewidth=2)
axes[0].plot(x, pdf_q, label='$q$ (generated)', linewidth=2)
axes[0].fill_between(x, 0, pdf_p, alpha=0.3)
axes[0].fill_between(x, 0, pdf_q, alpha=0.3)
axes[0].set_xlabel('$x$')
axes[0].set_ylabel('Density')
axes[0].set_title('Probability Densities', fontsize=14, fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# Plot CDFs and difference
axes[1].plot(x, cdf_p, label='CDF of $p$', linewidth=2)
axes[1].plot(x, cdf_q, label='CDF of $q$', linewidth=2)
axes[1].fill_between(x, cdf_p, cdf_q, alpha=0.3, label='$|F_p - F_q|$')
axes[1].set_xlabel('$x$')
axes[1].set_ylabel('Cumulative Probability')
axes[1].set_title('CDFs (Wasserstein = shaded area)', fontsize=14, fontweight='bold')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# Compute Wasserstein distance
W1_exact = 4.0 # For Gaussians N(ΞΌ1, ΟΒ²) and N(ΞΌ2, ΟΒ²): W1 = |ΞΌ1 - ΞΌ2|
W1_numerical = np.trapz(np.abs(cdf_p - cdf_q), x)
print(f"Wasserstein distance Wβ(p, q):")
print(f" Exact: {W1_exact:.3f}")
print(f" Numerical: {W1_numerical:.3f}")
plot_wasserstein_1d()
3. Kantorovich-Rubinstein DualityΒΆ
The DualityΒΆ
Kantorovich-Rubinstein theorem:
where \(||f||_L \leq 1\) means \(f\) is 1-Lipschitz: $\(|f(x_1) - f(x_2)| \leq ||x_1 - x_2||\)$
Intuition:
Primal: Transport plan
Dual: Maximization over Lipschitz functions
W-GAN ObjectiveΒΆ
Replace discriminator \(D\) with critic \(f\) (1-Lipschitz function):
Key difference from vanilla GAN:
No sigmoid output (critic outputs real number)
Enforce Lipschitz constraint
Minimize Wasserstein distance directly!
4. W-GAN with Weight ClippingΒΆ
Enforcing Lipschitz ConstraintΒΆ
Original W-GAN approach: Clip weights to \([-c, c]\)
Rationale:
Compact parameter space β Lipschitz function
Simple to implement
Algorithm:
For each critic iteration:
Sample batch from data and generator
Update critic to maximize: \(\mathbb{E}[f(x)] - \mathbb{E}[f(G(z))]\)
Clip weights: \(w \leftarrow \text{clip}(w, -c, c)\)
For generator iteration:
Sample batch
Update generator to minimize: \(-\mathbb{E}[f(G(z))]\)
Note: Negative sign because we want to fool critic!
# W-GAN Architecture
class Critic(nn.Module):
"""Critic network for W-GAN (no sigmoid!)"""
def __init__(self, input_dim=2, hidden_dim=128):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.LeakyReLU(0.2),
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(0.2),
nn.Linear(hidden_dim, 1) # Output real number, not probability!
)
def forward(self, x):
return self.net(x)
class Generator(nn.Module):
"""Generator network"""
def __init__(self, latent_dim=2, output_dim=2, hidden_dim=128):
super().__init__()
self.net = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim)
)
def forward(self, z):
return self.net(z)
# Initialize networks
latent_dim = 2
critic = Critic(input_dim=2, hidden_dim=128).to(device)
generator = Generator(latent_dim=2, output_dim=2, hidden_dim=128).to(device)
print("Critic architecture:")
print(critic)
print("\nGenerator architecture:")
print(generator)
# Training W-GAN with weight clipping
def train_wgan(generator, critic, data_loader, n_epochs=100,
n_critic=5, clip_value=0.01, lr=5e-5):
"""Train W-GAN with weight clipping"""
# Optimizers (RMSprop recommended in original paper)
opt_g = optim.RMSprop(generator.parameters(), lr=lr)
opt_c = optim.RMSprop(critic.parameters(), lr=lr)
history = {'w_dist': [], 'g_loss': []}
for epoch in range(n_epochs):
for real_data in data_loader:
real_data = real_data.to(device)
batch_size = real_data.size(0)
# Train Critic
for _ in range(n_critic):
opt_c.zero_grad()
# Sample fake data
z = torch.randn(batch_size, latent_dim).to(device)
fake_data = generator(z).detach()
# Critic loss: maximize E[f(x)] - E[f(G(z))]
# Equivalently: minimize -(E[f(x)] - E[f(G(z))])
critic_real = critic(real_data).mean()
critic_fake = critic(fake_data).mean()
critic_loss = -(critic_real - critic_fake) # Negative Wasserstein
critic_loss.backward()
opt_c.step()
# CLIP WEIGHTS
for p in critic.parameters():
p.data.clamp_(-clip_value, clip_value)
# Train Generator
opt_g.zero_grad()
z = torch.randn(batch_size, latent_dim).to(device)
fake_data = generator(z)
# Generator loss: minimize -E[f(G(z))]
g_loss = -critic(fake_data).mean()
g_loss.backward()
opt_g.step()
# Record metrics
history['w_dist'].append(-critic_loss.item()) # Approximate Wasserstein
history['g_loss'].append(g_loss.item())
if (epoch + 1) % 20 == 0:
print(f"Epoch {epoch+1}/{n_epochs} | "
f"W-dist: {history['w_dist'][-1]:.4f} | "
f"G-loss: {history['g_loss'][-1]:.4f}")
return history
# Prepare data: 2D Gaussian mixture
def create_data_loader(n_samples=10000, batch_size=128):
"""Create 2D Gaussian mixture data"""
means = np.array([[0, 0], [3, 3], [-2, 3]])
data = []
for mean in means:
samples = np.random.randn(n_samples // 3, 2) * 0.5 + mean
data.append(samples)
data = np.concatenate(data, axis=0)
np.random.shuffle(data)
dataset = torch.FloatTensor(data)
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
return loader, data
data_loader, real_data = create_data_loader(n_samples=3000, batch_size=64)
print("Training W-GAN with weight clipping...")
print("="*60)
history = train_wgan(generator, critic, data_loader, n_epochs=200,
n_critic=5, clip_value=0.01, lr=5e-5)
# Visualize results
fig, axes = plt.subplots(1, 3, figsize=(16, 5))
# Plot 1: Real data
axes[0].scatter(real_data[:, 0], real_data[:, 1], alpha=0.5, s=10)
axes[0].set_title('Real Data (3-component Gaussian mixture)', fontsize=12, fontweight='bold')
axes[0].set_xlabel('$x_1$')
axes[0].set_ylabel('$x_2$')
axes[0].axis('equal')
axes[0].grid(True, alpha=0.3)
# Plot 2: Generated data
with torch.no_grad():
z = torch.randn(3000, latent_dim).to(device)
fake_data = generator(z).cpu().numpy()
axes[1].scatter(fake_data[:, 0], fake_data[:, 1], alpha=0.5, s=10, color='orange')
axes[1].set_title('Generated Data (W-GAN)', fontsize=12, fontweight='bold')
axes[1].set_xlabel('$x_1$')
axes[1].set_ylabel('$x_2$')
axes[1].axis('equal')
axes[1].grid(True, alpha=0.3)
# Plot 3: Training curves
axes[2].plot(history['w_dist'], label='Wasserstein Distance', linewidth=2)
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Distance')
axes[2].set_title('W-GAN Training: Wasserstein Distance', fontsize=12, fontweight='bold')
axes[2].legend()
axes[2].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("β
W-GAN successfully learned the data distribution!")
5. W-GAN-GP (Gradient Penalty)ΒΆ
Problem with Weight ClippingΒΆ
Issues:
Forces weights to extreme values (\(\pm c\))
Reduces model capacity
Can lead to vanishing/exploding gradients
Solution: Gradient PenaltyΒΆ
Idea: Enforce Lipschitz constraint via penalty on gradient norm
Penalty term: $\(\lambda \mathbb{E}_{\hat{x}}[(||\nabla_{\hat{x}} f(\hat{x})||_2 - 1)^2]\)$
where \(\hat{x}\) are points sampled uniformly along lines between real and fake samples.
Why? For 1-Lipschitz function, \(||\nabla f|| \leq 1\) everywhere
W-GAN-GP ObjectiveΒΆ
where:
\(x \sim p_{\text{data}}\) (real)
\(\tilde{x} = G(z)\) (fake)
\(\hat{x} = \epsilon x + (1-\epsilon)\tilde{x}\) with \(\epsilon \sim U[0,1]\) (interpolated)
No weight clipping needed!
5.5. Advanced Theory: Gradient Penalty DerivationΒΆ
Why Gradient Penalty WorksΒΆ
Goal: Enforce 1-Lipschitz constraint on critic \(f\)
Observation: For a differentiable function, 1-Lipschitz constraint is equivalent to: $\(||\nabla f(x)|| \leq 1 \quad \forall x\)$
Soft Constraint Approach:
Instead of hard constraint, add penalty term: $\(\mathcal{L}_{GP} = \lambda \mathbb{E}_{\hat{x} \sim p_{\hat{x}}}[(||\nabla_{\hat{x}} f(\hat{x})||_2 - 1)^2]\)$
where \(\hat{x}\) are sampled uniformly along straight lines between real and fake samples: $\(\hat{x} = \epsilon x + (1-\epsilon) G(z), \quad \epsilon \sim \text{Uniform}[0,1]\)$
Why Sample Along Straight Lines?ΒΆ
Theorem (Implicit in WGAN-GP):
The optimal critic has gradient norm equal to 1 almost everywhere under the optimal coupling between \(p_{\text{data}}\) and \(p_g\).
The straight lines between real and fake samples approximate the optimal transport paths!
Mathematical JustificationΒΆ
For 1-Lipschitz function \(f\): $\(|f(x) - f(y)| \leq ||x - y||\)$
By mean value theorem, there exists \(\hat{x}\) on line segment such that: $\(f(x) - f(y) = \nabla f(\hat{x})^T (x - y)\)$
Combining: $\(|\nabla f(\hat{x})^T (x - y)| \leq ||x - y||\)$
For this to hold for all \(x, y\), we need \(||\nabla f(\hat{x})|| \leq 1\)
Optimal case: Equality holds, so \(||\nabla f(\hat{x})|| = 1\)
# Advanced Visualization: Gradient Norm Monitoring
def visualize_gradient_norms(critic, real_data, fake_data, num_samples=100):
"""
Visualize gradient norms to verify 1-Lipschitz constraint
This helps diagnose:
- Whether gradient penalty is working
- If critic is properly regularized
- Training stability issues
"""
critic.eval()
gradient_norms = []
# Sample interpolated points
for _ in range(num_samples):
epsilon = torch.rand(real_data.size(0), 1).to(real_data.device)
interpolated = epsilon * real_data + (1 - epsilon) * fake_data
interpolated.requires_grad_(True)
# Compute critic output
critic_out = critic(interpolated)
# Compute gradients
gradients = torch.autograd.grad(
outputs=critic_out,
inputs=interpolated,
grad_outputs=torch.ones_like(critic_out),
create_graph=False,
retain_graph=False
)[0]
# Compute gradient norms
grad_norm = gradients.norm(2, dim=1)
gradient_norms.extend(grad_norm.detach().cpu().numpy())
critic.train()
# Visualize
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Histogram of gradient norms
axes[0].hist(gradient_norms, bins=50, alpha=0.7, edgecolor='black', density=True)
axes[0].axvline(x=1.0, color='red', linestyle='--', linewidth=2,
label='Target (1-Lipschitz)')
axes[0].set_xlabel('Gradient Norm $||\nabla f||_2$', fontsize=11)
axes[0].set_ylabel('Density', fontsize=11)
axes[0].set_title('Distribution of Gradient Norms', fontweight='bold', fontsize=12)
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# Box plot
axes[1].boxplot(gradient_norms, vert=True)
axes[1].axhline(y=1.0, color='red', linestyle='--', linewidth=2,
label='Target (1-Lipschitz)')
axes[1].set_ylabel('Gradient Norm $||\nabla f||_2$', fontsize=11)
axes[1].set_title('Gradient Norm Statistics', fontweight='bold', fontsize=12)
axes[1].legend()
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# Print statistics
grad_norms_array = np.array(gradient_norms)
print("\n" + "="*60)
print("GRADIENT NORM ANALYSIS")
print("="*60)
print(f"Mean: {grad_norms_array.mean():.4f} (target: 1.0)")
print(f"Std: {grad_norms_array.std():.4f}")
print(f"Min: {grad_norms_array.min():.4f}")
print(f"Max: {grad_norms_array.max():.4f}")
print(f"Median: {np.median(grad_norms_array):.4f}")
print("\nInterpretation:")
if abs(grad_norms_array.mean() - 1.0) < 0.1:
print(" β
Excellent! Gradient norms close to 1.0")
print(" β
Critic is properly 1-Lipschitz")
elif abs(grad_norms_array.mean() - 1.0) < 0.3:
print(" β οΈ Acceptable, but could improve gradient penalty weight")
else:
print(" β Poor regularization - increase gradient penalty weight")
print("="*60)
# Example usage (assuming we have trained critic, real_data, fake_data)
print("\nAnalyzing gradient norms of trained critic...")
print("This verifies the 1-Lipschitz constraint is enforced.\n")
# W-GAN-GP implementation
def compute_gradient_penalty(critic, real_data, fake_data, device):
"""Compute gradient penalty for W-GAN-GP"""
batch_size = real_data.size(0)
# Random weight term for interpolation
epsilon = torch.rand(batch_size, 1).to(device)
# Interpolate between real and fake
interpolated = epsilon * real_data + (1 - epsilon) * fake_data
interpolated.requires_grad_(True)
# Get critic output
critic_interpolated = critic(interpolated)
# Compute gradients
gradients = torch.autograd.grad(
outputs=critic_interpolated,
inputs=interpolated,
grad_outputs=torch.ones_like(critic_interpolated),
create_graph=True,
retain_graph=True
)[0]
# Compute gradient norm
gradients = gradients.view(batch_size, -1)
gradient_norm = gradients.norm(2, dim=1)
# Penalty: (||grad|| - 1)^2
penalty = ((gradient_norm - 1) ** 2).mean()
return penalty
def train_wgan_gp(generator, critic, data_loader, n_epochs=100,
n_critic=5, lambda_gp=10, lr=1e-4):
"""Train W-GAN with gradient penalty"""
# Adam optimizer (works better with GP than RMSprop)
opt_g = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
opt_c = optim.Adam(critic.parameters(), lr=lr, betas=(0.5, 0.9))
history = {'w_dist': [], 'gp': [], 'g_loss': []}
for epoch in range(n_epochs):
for real_data in data_loader:
real_data = real_data.to(device)
batch_size = real_data.size(0)
# Train Critic
for _ in range(n_critic):
opt_c.zero_grad()
# Sample fake data
z = torch.randn(batch_size, latent_dim).to(device)
fake_data = generator(z).detach()
# Critic loss
critic_real = critic(real_data).mean()
critic_fake = critic(fake_data).mean()
# Gradient penalty
gp = compute_gradient_penalty(critic, real_data, fake_data, device)
# Total critic loss
critic_loss = -(critic_real - critic_fake) + lambda_gp * gp
critic_loss.backward()
opt_c.step()
# Train Generator
opt_g.zero_grad()
z = torch.randn(batch_size, latent_dim).to(device)
fake_data = generator(z)
g_loss = -critic(fake_data).mean()
g_loss.backward()
opt_g.step()
# Record metrics
history['w_dist'].append((critic_real - critic_fake).item())
history['gp'].append(gp.item())
history['g_loss'].append(g_loss.item())
if (epoch + 1) % 20 == 0:
print(f"Epoch {epoch+1}/{n_epochs} | "
f"W-dist: {history['w_dist'][-1]:.4f} | "
f"GP: {history['gp'][-1]:.4f} | "
f"G-loss: {history['g_loss'][-1]:.4f}")
return history
# Re-initialize networks for GP version
critic_gp = Critic(input_dim=2, hidden_dim=128).to(device)
generator_gp = Generator(latent_dim=2, output_dim=2, hidden_dim=128).to(device)
print("Training W-GAN-GP (Gradient Penalty)...")
print("="*60)
history_gp = train_wgan_gp(generator_gp, critic_gp, data_loader, n_epochs=200,
n_critic=5, lambda_gp=10, lr=1e-4)
# Compare W-GAN and W-GAN-GP
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# Generated samples: Weight clipping
with torch.no_grad():
z = torch.randn(3000, latent_dim).to(device)
fake_clip = generator(z).cpu().numpy()
axes[0, 0].scatter(fake_clip[:, 0], fake_clip[:, 1], alpha=0.5, s=10, color='orange')
axes[0, 0].set_title('W-GAN (Weight Clipping)', fontsize=12, fontweight='bold')
axes[0, 0].set_xlabel('$x_1$')
axes[0, 0].set_ylabel('$x_2$')
axes[0, 0].axis('equal')
axes[0, 0].grid(True, alpha=0.3)
# Generated samples: Gradient penalty
with torch.no_grad():
z = torch.randn(3000, latent_dim).to(device)
fake_gp = generator_gp(z).cpu().numpy()
axes[0, 1].scatter(fake_gp[:, 0], fake_gp[:, 1], alpha=0.5, s=10, color='green')
axes[0, 1].set_title('W-GAN-GP (Gradient Penalty)', fontsize=12, fontweight='bold')
axes[0, 1].set_xlabel('$x_1$')
axes[0, 1].set_ylabel('$x_2$')
axes[0, 1].axis('equal')
axes[0, 1].grid(True, alpha=0.3)
# Training curves: Wasserstein distance
axes[1, 0].plot(history['w_dist'], label='Weight Clipping', linewidth=2, alpha=0.7)
axes[1, 0].plot(history_gp['w_dist'], label='Gradient Penalty', linewidth=2, alpha=0.7)
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Wasserstein Distance')
axes[1, 0].set_title('Wasserstein Distance Comparison', fontsize=12, fontweight='bold')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)
# Gradient penalty over time
axes[1, 1].plot(history_gp['gp'], linewidth=2, color='green')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Gradient Penalty')
axes[1, 1].set_title('Gradient Penalty (W-GAN-GP)', fontsize=12, fontweight='bold')
axes[1, 1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("β
W-GAN-GP produces more stable training!")
6. SummaryΒΆ
Wasserstein GANΒΆ
β Motivation: Fix vanilla GAN issues (vanishing gradients, mode collapse) β Key Idea: Use Wasserstein distance instead of JS divergence β Critic: Outputs real number (not probability), must be Lipschitz β Training: More stable, meaningful loss metric
Two ApproachesΒΆ
Weight Clipping
Simple to implement
Enforces Lipschitz by clipping weights
Drawback: Reduces capacity, can hurt performance
Gradient Penalty (GP)
Penalize gradient norm deviation from 1
Better: No capacity reduction, more stable
Recommended for most applications
Advantages of W-GANΒΆ
β Meaningful loss metric (correlates with sample quality) β No mode collapse (in practice) β More stable training β Works well even with poor architectures
ImplementationsΒΆ
β Derived Wasserstein distance and K-R duality β Implemented W-GAN with weight clipping β Implemented W-GAN-GP with gradient penalty β Compared both approaches empirically
Next Notebook: 03_variational_autoencoders_advanced.ipynb