import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
# Pick the GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")
1. Style-Based GeneratorΒΆ
Key InnovationΒΆ
Mapping network \(f: \mathcal{Z} \to \mathcal{W}\)
Latent \(z \sim \mathcal{N}(0, I)\)
Intermediate \(w = f(z)\)
Controls synthesis via AdaIN
Adaptive Instance NormalizationΒΆ
where \(y = (y_s, y_b)\) from affine transform of \(w\).
📚 Reference Materials:
gan.pdf - GAN
class AdaptiveInstanceNorm(nn.Module):
    """Adaptive Instance Normalization (AdaIN).

    Each feature map is instance-normalized per sample, then re-scaled and
    re-shifted with a (scale, shift) pair produced from the style vector
    ``w`` by a single learned linear projection.
    """

    def __init__(self, num_features, w_dim):
        super().__init__()
        # Plain instance norm; the affine parameters come from the style.
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        # One projection emits both scale and shift (2 * num_features).
        self.style = nn.Linear(w_dim, num_features * 2)

    def forward(self, x, w):
        # Project w to per-channel parameters and broadcast over H, W.
        params = self.style(w)[:, :, None, None]
        scale, shift = params.chunk(2, dim=1)
        normalized = self.norm(x)
        return scale * normalized + shift
# Smoke test: AdaIN must preserve the feature-map shape.
adain = AdaptiveInstanceNorm(64, 512).to(device)
features = torch.randn(4, 64, 16, 16).to(device)
style_vec = torch.randn(4, 512).to(device)
out = adain(features, style_vec)
print(f"AdaIN output: {out.shape}")
Generator ArchitectureΒΆ
StyleGAN's generator replaces the traditional single-pass architecture with a mapping network followed by a synthesis network. The mapping network transforms the input latent \(z\) through 8 fully connected layers into an intermediate latent \(w\), which lives in a more disentangled \(\mathcal{W}\) space. The synthesis network then uses \(w\) to modulate feature maps at each resolution through adaptive instance normalization (AdaIN): \(\text{AdaIN}(x_i, y) = y_{s,i} \frac{x_i - \mu(x_i)}{\sigma(x_i)} + y_{b,i}\). This style-based design allows independent control of coarse, medium, and fine image attributes by injecting different \(w\) vectors at different layers.
class MappingNetwork(nn.Module):
    """Fully connected mapping network f: z -> w."""

    def __init__(self, z_dim=512, w_dim=512, n_layers=8):
        super().__init__()
        stages = []
        in_dim = z_dim
        # n_layers Linear + LeakyReLU pairs; only the first changes width.
        for _ in range(n_layers):
            stages.append(nn.Linear(in_dim, w_dim))
            stages.append(nn.LeakyReLU(0.2))
            in_dim = w_dim
        self.mapping = nn.Sequential(*stages)

    def forward(self, z):
        # Map z [batch, z_dim] to the intermediate latent w [batch, w_dim].
        return self.mapping(z)
class StyleBlock(nn.Module):
    """One synthesis step: conv -> AdaIN(style) -> LeakyReLU."""

    def __init__(self, in_channels, out_channels, w_dim):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.adain = AdaptiveInstanceNorm(out_channels, w_dim)
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x, w):
        # Convolve, modulate with the style vector, then apply nonlinearity.
        h = self.conv(x)
        h = self.adain(h, w)
        return self.activation(h)
class StyleGAN_Generator(nn.Module):
    """Simplified StyleGAN generator.

    A mapping network turns z into the intermediate latent w, which then
    modulates a stack of AdaIN synthesis blocks that grow a learned 4x4
    constant up to ``img_size`` x ``img_size``.

    Fix: the truncation trick previously just scaled w (``w *= psi``).
    The proper trick interpolates toward the mean of W-space,
    ``w' = w_avg + psi * (w - w_avg)``; we now track ``w_avg`` as an EMA
    buffer during training. The buffer starts at zero, so before any
    training updates the behavior is identical to the old scaling.
    """

    def __init__(self, z_dim=512, w_dim=512, img_size=64, img_channels=1):
        super().__init__()
        self.img_size = img_size
        # Mapping network f: z -> w
        self.mapping = MappingNetwork(z_dim, w_dim)
        # Learned constant input (replaces a random-noise input tensor).
        self.const = nn.Parameter(torch.randn(1, 512, 4, 4))
        # Synthesis blocks: 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64
        self.block1 = StyleBlock(512, 256, w_dim)
        self.block2 = StyleBlock(256, 128, w_dim)
        self.block3 = StyleBlock(128, 64, w_dim)
        self.block4 = StyleBlock(64, 32, w_dim)
        # 1x1 conv projecting features to image channels.
        self.to_rgb = nn.Conv2d(32, img_channels, 1)
        # Running mean of w for the truncation trick (zero-initialized).
        self.register_buffer("w_avg", torch.zeros(w_dim))
        self.w_avg_beta = 0.995  # EMA decay for the mean of w

    def forward(self, z, truncation=1.0):
        """Generate images from latent codes.

        Args:
            z: Latent codes, shape (batch, z_dim).
            truncation: psi of the truncation trick; values >= 1.0 leave
                w untouched (matching the original behavior).

        Returns:
            Images in [-1, 1], shape (batch, img_channels, img_size, img_size).
        """
        # Map to the intermediate latent space.
        w = self.mapping(z)
        # Track the mean of w while training so truncation can pull
        # samples toward it at inference time.
        if self.training:
            with torch.no_grad():
                self.w_avg.lerp_(w.detach().mean(0), 1.0 - self.w_avg_beta)
        # Truncation trick: w' = w_avg + psi * (w - w_avg).
        if truncation < 1.0:
            w = self.w_avg + truncation * (w - self.w_avg)
        # Start from the learned constant, one copy per sample.
        x = self.const.repeat(z.size(0), 1, 1, 1)
        # Progressive upsampling through the style blocks.
        x = self.block1(x, w)
        x = F.interpolate(x, scale_factor=2)
        x = self.block2(x, w)
        x = F.interpolate(x, scale_factor=2)
        x = self.block3(x, w)
        x = F.interpolate(x, scale_factor=2)
        x = self.block4(x, w)
        x = F.interpolate(x, scale_factor=2)
        # Project to image space; tanh bounds outputs to [-1, 1].
        return torch.tanh(self.to_rgb(x))
# Smoke test: the generator should emit img_size x img_size images.
gen = StyleGAN_Generator(img_size=64).to(device)
latents = torch.randn(4, 512).to(device)
imgs = gen(latents)
print(f"Generated images: {imgs.shape}")
Training on MNISTΒΆ
While StyleGAN was designed for high-resolution face generation, training a simplified version on MNIST demonstrates the core architectural ideas at a manageable scale. The progressive growing strategy (if used) starts at low resolution (\(4 \times 4\)) and gradually adds layers, stabilizing training at each stage before introducing finer detail. Even on MNIST, the style-based injection should produce better latent space disentanglement than a standard GAN, which we can verify by manipulating styles at different layers.
class Discriminator(nn.Module):
    """Plain convolutional discriminator for 64x64 inputs.

    Four stride-2 conv stages halve the resolution down to 4x4, then a
    linear head produces one raw logit per image.
    """

    def __init__(self, img_channels=1):
        super().__init__()
        stages = []
        widths = [img_channels, 32, 64, 128, 256]
        # Each (conv, leaky-relu) pair halves H and W.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Conv2d(c_in, c_out, 4, 2, 1))
            stages.append(nn.LeakyReLU(0.2))
        stages.append(nn.Flatten())
        stages.append(nn.Linear(256 * 4 * 4, 1))
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        # Returns unnormalized logits of shape (batch, 1).
        return self.model(x)
# Load data
# MNIST digits are resized to 64x64 (the generator's output resolution) and
# normalized to [-1, 1] so real images match the generator's tanh range.
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
mnist = datasets.MNIST('./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(mnist, batch_size=64, shuffle=True)
# Models
G = StyleGAN_Generator().to(device)
D = Discriminator().to(device)
# Optimizers
# Adam with beta1 = 0.5 — the standard DCGAN-style setting for GAN training.
opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
print("Setup complete")
def train_stylegan(n_epochs=5):
    """Standard non-saturating GAN training loop over MNIST.

    Uses the module-level G, D, opt_G, opt_D, dataloader and device.
    Returns (generator_losses, discriminator_losses), one entry per batch.
    """
    g_history, d_history = [], []
    for epoch in range(n_epochs):
        for batch_idx, (real_imgs, _) in enumerate(dataloader):
            real_imgs = real_imgs.to(device)
            n = real_imgs.size(0)
            real_labels = torch.ones(n, 1).to(device)
            fake_labels = torch.zeros(n, 1).to(device)

            # --- Discriminator step: real vs. detached fakes ---
            latent = torch.randn(n, 512).to(device)
            fakes = G(latent)
            loss_real = F.binary_cross_entropy_with_logits(D(real_imgs), real_labels)
            loss_fake = F.binary_cross_entropy_with_logits(D(fakes.detach()), fake_labels)
            d_loss = (loss_real + loss_fake) / 2
            opt_D.zero_grad()
            d_loss.backward()
            opt_D.step()

            # --- Generator step: make D classify fresh fakes as real ---
            latent = torch.randn(n, 512).to(device)
            fakes = G(latent)
            g_loss = F.binary_cross_entropy_with_logits(D(fakes), real_labels)
            opt_G.zero_grad()
            g_loss.backward()
            opt_G.step()

            g_history.append(g_loss.item())
            d_history.append(d_loss.item())
            if batch_idx % 200 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, D: {d_loss.item():.4f}, G: {g_loss.item():.4f}")
    return g_history, d_history
losses_G, losses_D = train_stylegan(n_epochs=5)

# Plot per-iteration losses for both networks.
plt.figure(figsize=(10, 5))
for series, name in ((losses_G, 'Generator'), (losses_D, 'Discriminator')):
    plt.plot(series, alpha=0.5, label=name)
plt.xlabel('Iteration', fontsize=11)
plt.ylabel('Loss', fontsize=11)
plt.title('StyleGAN Training', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
Style MixingΒΆ
Style mixing is a key capability of StyleGAN: by injecting the \(w\) vector from one latent code at coarse layers and from another at fine layers, we can combine the overall structure of one image with the fine details of another. For example, in face generation, coarse styles control pose and face shape while fine styles control hair texture and background. On MNIST, style mixing might combine the global shape of one digit with the stroke style of another, demonstrating that different layers have learned to control different levels of abstraction.
def style_mixing_demo():
    """Show coarse/fine style mixing between two latent codes."""
    G.eval()
    with torch.no_grad():
        # Two independent latents and their pure generations.
        z1 = torch.randn(1, 512).to(device)
        z2 = torch.randn(1, 512).to(device)
        img1 = G(z1)
        img2 = G(z2)
        # Map both to w-space for manual synthesis.
        w1 = G.mapping(z1)
        w2 = G.mapping(z2)
        # Run the synthesis stack by hand: coarse blocks (1-2) styled by
        # w1, fine blocks (3-4) styled by w2.
        x = G.const.repeat(1, 1, 1, 1)
        schedule = ((G.block1, w1), (G.block2, w1), (G.block3, w2), (G.block4, w2))
        for block, style in schedule:
            x = block(x, style)
            x = F.interpolate(x, scale_factor=2)
        img_mixed = torch.tanh(G.to_rgb(x))
        # Side-by-side comparison of the two pure styles and the mix.
        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
        panels = (
            (img1, 'Style 1'),
            (img_mixed, 'Mixed (coarse:1, fine:2)'),
            (img2, 'Style 2'),
        )
        for ax, (img, title) in zip(axes, panels):
            ax.imshow(img[0, 0].cpu(), cmap='gray')
            ax.set_title(title, fontsize=11)
            ax.axis('off')
        plt.tight_layout()
        plt.show()

style_mixing_demo()
Truncation TrickΒΆ
The truncation trick improves sample quality at the expense of diversity by interpolating the \(w\) vector toward the mean of \(\mathcal{W}\): \(w' = \bar{w} + \psi (w - \bar{w})\), where \(\psi \in [0, 1]\) controls the trade-off. At \(\psi = 1\) we get full diversity (standard sampling); at \(\psi = 0\) all samples converge to the "average" image. Values around \(\psi = 0.7\) typically produce the best visual quality. This technique exploits the fact that the density of training data is highest near the mean of the latent distribution, so truncating avoids low-density regions that may produce artifacts.
def truncation_comparison():
    """Compare samples from the same z at different truncation values.

    Note: the simplified generator only applies truncation when psi < 1.0,
    so the psi = 1.0 and psi = 1.5 panels show the same untruncated image.
    """
    G.eval()
    z = torch.randn(1, 512).to(device)
    truncations = [0.5, 0.7, 1.0, 1.5]
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    with torch.no_grad():
        for i, trunc in enumerate(truncations):
            img = G(z, truncation=trunc)
            axes[i].imshow(img[0, 0].cpu(), cmap='gray')
            # Fixed label: the original contained mojibake ('Ο') where the
            # Greek letter psi was intended.
            axes[i].set_title(f'ψ={trunc}', fontsize=11)
            axes[i].axis('off')
    plt.suptitle('Truncation Trick', fontsize=12)
    plt.tight_layout()
    plt.show()

truncation_comparison()
Generate GridΒΆ
Generating a grid of samples provides a comprehensive visual assessment of the modelβs output quality and diversity. Each image in the grid is produced from an independent random latent vector, so the grid should exhibit varied attributes (different digit identities, stroke widths, orientations) while maintaining consistent quality. Inspecting the grid for repeated patterns, artifacts, or missing modes helps diagnose training issues like mode collapse or insufficient model capacity.
# Sample a 4x4 grid of digits with mild truncation for higher quality.
G.eval()
with torch.no_grad():
    z = torch.randn(16, 512).to(device)
    imgs = G(z, truncation=0.7)
fig, axes = plt.subplots(4, 4, figsize=(12, 12))
for idx in range(16):
    panel = axes[idx // 4, idx % 4]
    panel.imshow(imgs[idx, 0].cpu(), cmap='gray')
    panel.axis('off')
plt.suptitle('StyleGAN Generated Samples', fontsize=13)
plt.tight_layout()
plt.show()
SummaryΒΆ
StyleGAN Innovations:ΒΆ
Mapping network - \(z \to w\) disentanglement
AdaIN - Style control at each layer
Progressive growing - High-resolution synthesis
Style mixing - Localized control
Key Benefits:ΒΆ
Better disentanglement
Controllable generation
High quality images
Interpretable latent space
Applications:ΒΆ
Face generation
Image editing
Style transfer
Data augmentation
Extensions:ΒΆ
StyleGAN2 (improved quality)
StyleGAN3 (alias-free)
StyleGAN-XL (large scale)
Next Steps:ΒΆ
02_wgan_theory_implementation.ipynb - WGAN
Study progressive growing
Explore GANSpace editing
Advanced StyleGAN: Mathematical Foundations and Architecture Deep DiveΒΆ
Table of ContentsΒΆ
Introduction to StyleGAN
Motivation and Key Innovations
Style-Based Generator Architecture
Adaptive Instance Normalization (AdaIN)
Mapping Network
Noise Injection
Mixing Regularization
Progressive Growing
StyleGAN2 Improvements
StyleGAN3: Alias-Free Generation
Latent Space Analysis
Applications and Extensions
Implementation Details
1. Introduction to StyleGANΒΆ
What is StyleGAN?ΒΆ
StyleGAN (Style-Based Generator Architecture for GANs) is a revolutionary generative model introduced by NVIDIA in 2018 that generates highly realistic images with unprecedented control over image synthesis.
Key Achievement: Generates photorealistic faces indistinguishable from real photos
Core Idea: Instead of feeding latent code directly into generator, transform it into intermediate latent space that controls style at different scales.
Evolution TimelineΒΆ
StyleGAN (Dec 2018): Style-based architecture, AdaIN, mixing regularization
StyleGAN2 (Dec 2019): Fixes artifacts, improved quality, removes progressive growing
StyleGAN3 (June 2021): Alias-free generation, better texture sticking
Why StyleGAN MattersΒΆ
Before StyleGAN: Limited control, mode collapse, artifacts, entangled representations
After StyleGAN:
Disentangled Latent Space: Independent control of features (hair, pose, age, expression)
High Quality: 1024Γ1024 photorealistic images
Controllability: Precise manipulation of attributes
Interpolation: Smooth transitions between images
Artistic Applications: Face editing, style transfer, image synthesis
2. Motivation and Key InnovationsΒΆ
Problems with Traditional GANsΒΆ
1. Entangled Latent Space
In standard GANs: $\( \mathbf{z} \sim \mathcal{N}(0, I) \quad \rightarrow \quad G(\mathbf{z}) = \mathbf{x} \)$
Problem: Changing \(\mathbf{z}\) affects multiple attributes simultaneously
Moving in latent space changes age + pose + expression
No semantic meaning to latent directions
Difficult to find interpretable controls
2. Limited Scale-Specific Control
Traditional generators use \(\mathbf{z}\) only at input:
All scales (coarse to fine) derived from same \(\mathbf{z}\)
Cannot control high-level structure separately from fine details
E.g., cannot change pose without affecting skin texture
3. Lack of Stochasticity
Real images have stochastic variation (hair placement, pores, wrinkles) not captured by deterministic \(G(\mathbf{z})\).
StyleGANβs SolutionsΒΆ
1. Mapping Network: \(\mathbf{z} \rightarrow \mathbf{w}\)
Transform initial latent \(\mathbf{z}\) into intermediate latent \(\mathbf{w}\)
\(\mathbf{w}\) is more disentangled and easier to control
2. Style-Based Generation:
Feed \(\mathbf{w}\) into generator at multiple scales via AdaIN
Each scale controlled independently
Coarse levels: pose, face shape
Fine levels: hair, skin texture
3. Noise Injection:
Add stochastic noise at each layer
Generate realistic micro-variations (hair strands, pores)
4. Mixing Regularization:
Mix two latent codes \(\mathbf{w}_1, \mathbf{w}_2\) at different layers
Forces disentanglement (coarse vs. fine features)
3. Style-Based Generator ArchitectureΒΆ
Overall ArchitectureΒΆ
z (512) → Mapping Network → w (512)
        │
[Duplicate 18 times]
        │
w_1, w_2, ..., w_18
        │
Synthesis Network:
Const 4×4×512
    │ (AdaIN with w_1, w_2)
Conv 4×4
    │ (AdaIN with w_3, w_4)
Upsample → 8×8
    │ (AdaIN with w_5, w_6)
Conv 8×8
...
    │ (AdaIN with w_17, w_18)
Conv 1024×1024 → RGB output
Key Difference from Traditional GAN:
Traditional: $\( \text{Input: } \mathbf{z} \rightarrow \text{Conv layers} \rightarrow \text{Output} \)$
StyleGAN: $\( \text{Const input} \rightarrow \text{Conv + AdaIN}(\mathbf{w}_i) \rightarrow \text{Output} \)$
Advantages:
Latent \(\mathbf{w}\) not constrained by input spatial structure
Each layer independently controlled
Better disentanglement
Synthesis Network DetailsΒΆ
Starting Point: Learned constant \(4 \times 4 \times 512\) tensor (not random noise!)
Each Resolution Block:
AdaIN with style \(\mathbf{w}_i\): Modulates feature statistics
Conv 3Γ3: Processes features
Noise Injection: Adds stochastic detail
Activation: Leaky ReLU
Repeat AdaIN + Conv for second layer at same resolution
Upsample: Increase resolution (nearest neighbor or bilinear)
Output: RGB image via \(1 \times 1\) convolution
Total Layers:
1024Γ1024 output: 18 layers (2 per resolution from 4Γ4 to 1024Γ1024)
Each layer gets its own \(\mathbf{w}_i\)
4. Adaptive Instance Normalization (AdaIN)ΒΆ
Standard Instance NormalizationΒΆ
Instance Normalization normalizes each channel independently per sample:
where:
\(\mathbf{x}_{i,c}\): Feature at spatial position \(i\), channel \(c\)
\(\mu_c = \frac{1}{HW} \sum_{i} \mathbf{x}_{i,c}\): Mean of channel \(c\)
\(\sigma_c = \sqrt{\frac{1}{HW} \sum_{i} (\mathbf{x}_{i,c} - \mu_c)^2 + \epsilon}\): Std of channel \(c\)
Adaptive Instance Normalization (AdaIN)ΒΆ
Idea: After normalization, apply learned affine transformation based on style \(\mathbf{w}\):
where:
\(\mathbf{y}_s(\mathbf{w}) = A_s \mathbf{w}\): Scale (learned linear projection)
\(\mathbf{y}_b(\mathbf{w}) = A_b \mathbf{w}\): Bias (learned linear projection)
Per Channel: Each of \(C\) channels gets its own scale and bias:
\(\mathbf{y}_s \in \mathbb{R}^C\)
\(\mathbf{y}_b \in \mathbb{R}^C\)
Intuition:
Scale \(\mathbf{y}_s\): Controls variance/contrast of features
Bias \(\mathbf{y}_b\): Shifts mean of features
Together, they control the βstyleβ of that layer
Effect:
Low-res layers (4Γ4, 8Γ8): Control pose, face shape, coarse structure
Mid-res layers (16Γ16 - 128Γ128): Control facial features, hair style
High-res layers (256Γ256 - 1024Γ1024): Control color scheme, fine detail
Mathematical FormulationΒΆ
At layer \(i\) with features \(\mathbf{h}\) and style \(\mathbf{w}_i\):
where:
\(\gamma(\mathbf{w}_i) = \text{FC}^{(\gamma)}_i(\mathbf{w}_i)\): Learned scale
\(\beta(\mathbf{w}_i) = \text{FC}^{(\beta)}_i(\mathbf{w}_i)\): Learned bias
\(\odot\): Element-wise multiplication
Comparison to Conditional Batch Normalization:
Conditional BN (BigGAN, etc.): $\( \text{CBN}(\mathbf{h}, \mathbf{c}) = \gamma(\mathbf{c}) \frac{\mathbf{h} - \mu_{\text{batch}}}{\sigma_{\text{batch}}} + \beta(\mathbf{c}) \)$
AdaIN: $\( \text{AdaIN}(\mathbf{h}, \mathbf{w}) = \gamma(\mathbf{w}) \frac{\mathbf{h} - \mu_{\text{instance}}}{\sigma_{\text{instance}}} + \beta(\mathbf{w}) \)$
Key Difference: Instance norm (per-sample) vs. Batch norm (across batch)
Instance norm: Better for style transfer (donβt mix statistics across samples)
Allows independent control per image
5. Mapping NetworkΒΆ
ArchitectureΒΆ
Goal: Transform random latent \(\mathbf{z} \sim \mathcal{N}(0, I)\) into disentangled latent \(\mathbf{w}\)
Structure: 8-layer MLP (fully-connected network)
Dimensions:
Input: \(\mathbf{z} \in \mathbb{R}^{512}\)
Hidden: \(512\) units per layer
Output: \(\mathbf{w} \in \mathbb{R}^{512}\)
Activation: Leaky ReLU with slope 0.2
No Normalization: No batch norm or instance norm in mapping network
Why Mapping Network?ΒΆ
Hypothesis: The distribution \(p(\mathbf{z})\) (Gaussian) must be mapped to distribution of training images, which may be complex/non-linear.
Problem with Direct \(\mathbf{z}\):
Training data may have complex distribution:
Sparse regions: e.g., no half-male-half-female faces
Correlations: e.g., age correlates with wrinkles
If \(\mathbf{z}\) is Gaussian, the generator must "warp" this distribution to match the data distribution → entanglement.
Solution: Introduce intermediate \(\mathbf{w}\) that doesnβt need to follow fixed distribution.
Result: \(\mathbf{w}\) can be non-linear transformation of \(\mathbf{z}\), allowing:
Features to lie on non-linear manifold
Better disentanglement (independent features)
Disentanglement MetricsΒΆ
Perceptual Path Length (PPL):
Measures how smooth interpolation is in latent space:
where \(w_2 = w_1 + \epsilon \cdot \text{direction}\) and \(d\) is perceptual distance (LPIPS).
Lower PPL = Better Disentanglement
Linear Separability:
Train linear classifier to predict attribute from \(\mathbf{w}\): $\( \text{classifier}: \mathbf{w} \rightarrow \{\text{male}, \text{female}\} \)$
Higher accuracy = more linear/disentangled.
Truncation TrickΒΆ
Observation: Images near center of \(\mathbf{w}\) distribution have higher quality.
Truncation: Sample \(\mathbf{z} \sim \mathcal{N}(0, I)\), then shrink towards mean:
where:
\(\bar{\mathbf{w}} = \mathbb{E}_{\mathbf{z}}[f(\mathbf{z})]\): Mean \(\mathbf{w}\) (computed over many samples)
\(\psi \in [0, 1]\): Truncation factor
Effect:
\(\psi = 1\): Full diversity, may have artifacts
\(\psi = 0.7\): Good quality-diversity trade-off (common choice)
\(\psi = 0\): Collapse to single image (average face)
6. Noise InjectionΒΆ
MotivationΒΆ
Real images have stochastic variation not captured by deterministic generator:
Exact placement of individual hair strands
Skin pores, freckles, stubble
Background textures
Solution: Inject random noise at each layer.
ImplementationΒΆ
At each layer, after convolution and before AdaIN:
where:
\(\mathbf{n} \sim \mathcal{N}(0, I)\): Gaussian noise (same spatial size as \(\mathbf{h}\))
\(B \in \mathbb{R}^C\): Learned per-channel scaling factors
\(\odot\): Broadcasting: \(B_c\) multiplies all spatial locations in channel \(c\)
Key: Noise is broadcasted (same noise value for all spatial locations within a channel).
Per-Layer Noise:
Each layer gets independent noise
Low-res layers: Affects coarse stochasticity (overall lighting variation)
High-res layers: Affects fine details (individual hair strands, pores)
Effect of NoiseΒΆ
Experiment: Remove noise from specific layers:
No noise at all: Images look βplasticβ, overly smooth, lack detail
No noise at low-res: Coarse structure too deterministic
No noise at high-res: Fine details (hair, skin texture) too regular
Optimal: Noise at all layers (default in StyleGAN)
Mathematical FormulationΒΆ
Full forward pass at layer \(i\):
Order of Operations:
Convolution
Add noise
AdaIN (modulate with style)
Activation
7. Mixing RegularizationΒΆ
MotivationΒΆ
Risk: Generator may correlate adjacent styles (e.g., use \(\mathbf{w}_i\) and \(\mathbf{w}_{i+1}\) together).
Goal: Force localization of stylesβeach layer should control independent aspects.
MethodΒΆ
During training, use two latent codes \(\mathbf{w}_1, \mathbf{w}_2\):
Sample two independent latents: \(\mathbf{z}_1, \mathbf{z}_2 \sim \mathcal{N}(0, I)\)
Map to \(\mathbf{w}\)-space: \(\mathbf{w}_1 = f(\mathbf{z}_1), \quad \mathbf{w}_2 = f(\mathbf{z}_2)\)
Choose random crossover point \(k \in \{1, \ldots, 18\}\)
Use \(\mathbf{w}_1\) for layers \(1\) to \(k-1\), and \(\mathbf{w}_2\) for layers \(k\) to \(18\):
Result: Image has coarse features from \(\mathbf{w}_1\) (pose, face shape) and fine features from \(\mathbf{w}_2\) (hair color, expression).
EffectsΒΆ
Low-Level Crossover (\(k\) small, e.g., \(k=4\)):
Only very coarse features from \(\mathbf{w}_1\) (pose, rough face shape)
Most features from \(\mathbf{w}_2\)
Mid-Level Crossover (\(k\) around 8-10):
Pose and face shape from \(\mathbf{w}_1\)
Hair style, facial features, expression from \(\mathbf{w}_2\)
High-Level Crossover (\(k\) large, e.g., \(k=14\)):
Almost everything from \(\mathbf{w}_1\)
Only color scheme and fine texture from \(\mathbf{w}_2\)
Training Benefit:
Prevents βmode collapseβ where generator ignores specific layers
Encourages disentanglement: each layer must be independently useful
Typically applied with probability 0.9 (90% of training batches use mixing)
8. Progressive GrowingΒΆ
Concept (StyleGAN 1)ΒΆ
Idea: Start training at low resolution (4Γ4), gradually add layers to increase resolution (8Γ8, 16Γ16, β¦, 1024Γ1024).
Procedure:
Phase 1: Train generator and discriminator at \(4 \times 4\)
Phase 2: Add layers for \(8 \times 8\), fade them in smoothly
Phase 3: Fully switch to \(8 \times 8\), add layers for \(16 \times 16\)
Repeat until target resolution
Fade-In: Use \(\alpha \in [0, 1]\) to smoothly blend:
Initially \(\alpha = 0\) (only old path), gradually increase to \(\alpha = 1\) (only new path).
Benefits:
Faster training: Low-res training is faster
Stability: Easier optimization at low resolution first
Quality: Helps learn coarse structure before fine details
Note: StyleGAN2 removes progressive growing (found to cause artifacts).
9. StyleGAN2 ImprovementsΒΆ
Problem 1: Characteristic ArtifactsΒΆ
Observation: StyleGAN1 generates βblobβ artifacts that appear/disappear frame-to-frame in videos.
Root Cause: AdaIN normalizes each feature map independently β can βeraseβ information, creating artifacts.
Solution: Modify normalization scheme
StyleGAN2 Normalization:
Instead of: $\( \text{AdaIN}(\mathbf{h}) = \gamma \frac{\mathbf{h} - \mu}{\sigma} + \beta \)$
Use weight demodulation:
Modulate conv weights directly: $\( w'_{ijk} = s_i \cdot w_{ijk} \)\( where \)s_i\( is style-based scale for filter \)i$.
Demodulate (normalize) after convolution: $\( y_i = \frac{\sum_{jk} w'_{ijk} x_{jk}}{\sqrt{\sum_{jk} (w'_{ijk})^2 + \epsilon}} \)$
Benefit: Removes explicit feature normalization β eliminates artifacts.
Problem 2: Progressive Growing ArtifactsΒΆ
Issue: Phase transitions in progressive growing cause position-dependent artifacts.
Solution: Remove progressive growing entirely
Train all layers from start (directly at target resolution)
Use techniques from StyleGAN2 to stabilize training:
Path length regularization
Lazy regularization
Problem 3: PPL (Perceptual Path Length)ΒΆ
Goal: Encourage smooth interpolation in \(\mathbf{w}\)-space.
Path Length Regularization:
Penalize large perceptual changes for small \(\mathbf{w}\) changes:
where:
\(\mathbf{J}_\mathbf{w}\): Jacobian of generator w.r.t. \(\mathbf{w}\)
\(\mathbf{y}\): Random direction in image space
\(a\): Moving average of \(\|\mathbf{J}_\mathbf{w}^T \mathbf{y}\|_2\)
Effect: Smooths \(\mathbf{w}\)-space, improves interpolation quality.
Summary of StyleGAN2ΒΆ
Changes from StyleGAN1:
Weight demodulation instead of AdaIN
No progressive growing
Path length regularization
Lazy regularization (apply R1 less frequently)
Results:
Higher quality (FID improved from 4.4 to 2.8 on FFHQ)
No artifacts
Better latent space
10. StyleGAN3: Alias-Free GenerationΒΆ
Problem: Texture StickingΒΆ
Observation: When interpolating between faces, textures (stubble, hair) βstickβ to screen coordinates instead of object.
Example: Rotate a face β stubble doesnβt rotate with face, stays in same screen location.
Root Cause: Aliasing from upsampling operations.
Aliasing in Neural NetworksΒΆ
Convolution + ReLU + Downsampling/Upsampling violates Nyquist-Shannon sampling theorem:
Creates high-frequency signals (from ReLU nonlinearity)
Upsampling can alias these signals
Aliased signals leak to other frequencies
Result: Generator learns to exploit aliased frequencies β texture sticking.
StyleGAN3 SolutionΒΆ
Alias-Free Generator:
Filtered Nonlinearities: Apply low-pass filter after each ReLU to remove high frequencies
Filtered Upsampling/Downsampling: Use proper anti-aliasing filters
Fourier Features: Input layer uses continuous Fourier features instead of learned constant
Mathematical Formulation:
Standard upsampling (aliases): $\( \mathbf{h}' = \text{Upsample}(\mathbf{h}) \)$
Alias-free upsampling: $\( \mathbf{h}' = \text{LowPassFilter}(\text{Upsample}(\mathbf{h})) \)$
Effect: Textures now rotate/transform with object, not stuck to screen.
Configuration OptionsΒΆ
StyleGAN3 has two configs:
StyleGAN3-T (βtranslationβ): Partial equivariance (translation + rotation)
StyleGAN3-R (βfull rotationβ): Full equivariance (arbitrary rotations)
Trade-Off:
StyleGAN3-R: Perfect equivariance, but slightly lower quality
StyleGAN3-T: Better quality, less strict equivariance
11. Latent Space AnalysisΒΆ
\(\mathcal{Z}\) vs. \(\mathcal{W}\) SpaceΒΆ
\(\mathcal{Z}\)-Space:
Original latent space: \(\mathbf{z} \sim \mathcal{N}(0, I)\)
Entangled: Changing \(\mathbf{z}\) affects multiple attributes
\(\mathcal{W}\)-Space:
Intermediate latent: \(\mathbf{w} = f(\mathbf{z})\)
More disentangled
Fixed distribution (learned manifold)
\(\mathcal{W}+\) Space:
Extended space: Allow different \(\mathbf{w}_i\) per layer (18 separate vectors)
Even more expressive
Used for real image inversion
Latent Space EditingΒΆ
Linear Directions in \(\mathcal{W}\):
Many attributes correspond to linear directions:
where \(\mathbf{d}_{\text{gender}}\) is learned direction.
Finding Directions:
Supervised: Train classifier, use gradient: $\( \mathbf{d} = \nabla_{\mathbf{w}} P(y=\text{male} \mid \mathbf{w}) \)$
Unsupervised: PCA on \(\{\mathbf{w}_i\}\) samples, find principal components
Component 1 might correspond to age
Component 2 might correspond to gender
Etc.
Closed-Form (GANSpace): SVD on generator features: $\( \mathbf{G} = \mathbf{U} \boldsymbol{\Sigma} \mathbf{V}^T \)\( Use columns of \)\mathbf{V}$ as edit directions.
Semantic DirectionsΒΆ
Common Edits:
Age: \(\mathbf{w} + \alpha \cdot \mathbf{d}_{\text{age}}\)
Gender: \(\mathbf{w} + \alpha \cdot \mathbf{d}_{\text{gender}}\)
Smile: \(\mathbf{w} + \alpha \cdot \mathbf{d}_{\text{smile}}\)
Pose: \(\mathbf{w} + \alpha \cdot \mathbf{d}_{\text{yaw}}\)
Orthogonality: In good latent space, directions are approximately orthogonal: $\( \mathbf{d}_i^T \mathbf{d}_j \approx 0 \quad \text{for } i \neq j \)$
12. Applications and ExtensionsΒΆ
1. Image-to-Image Translation (pix2pixHD, SPADE)ΒΆ
Use StyleGAN as backbone for high-quality image translation.
2. 3D-Aware Generation (Ο-GAN, EG3D)ΒΆ
Combine StyleGAN with 3D representations (NeRF) for view-consistent generation.
3. Text-to-Image (StyleCLIP)ΒΆ
Use CLIP to guide StyleGAN edits: $\( \mathbf{w}^* = \arg\min_{\mathbf{w}} \mathcal{L}_{\text{CLIP}}(G(\mathbf{w}), \text{text}) + \lambda \|\mathbf{w} - \mathbf{w}_0\|^2 \)$
4. Real Image Editing (GAN Inversion)ΒΆ
Goal: Edit real photos using StyleGANβs latent space
Steps:
Invert image to latent code: \(\mathbf{x} \rightarrow \mathbf{w}\) $\( \mathbf{w}^* = \arg\min_{\mathbf{w}} \|G(\mathbf{w}) - \mathbf{x}\|^2 + \text{regularization} \)$
Edit latent: \(\mathbf{w}' = \mathbf{w}^* + \alpha \cdot \mathbf{d}\)
Generate: \(\mathbf{x}' = G(\mathbf{w}')\)
Methods:
Optimization: Directly optimize \(\mathbf{w}\) (slow but accurate)
Encoder: Train encoder \(E: \mathbf{x} \rightarrow \mathbf{w}\) (fast, less accurate)
Hybrid: Encoder + optimization fine-tuning
5. Domain Adaptation (StyleGAN-NADA)ΒΆ
Fine-tune StyleGAN on new domain using text guidance (no new images needed!).
Example: βPixar characterβ, βZombie faceβ
13. Implementation DetailsΒΆ
TrainingΒΆ
Datasets:
FFHQ (Flickr Faces HQ): 70k faces at 1024Γ1024
LSUN: Bedroom, church, etc.
Custom datasets (min ~1k images for fine-tuning)
Hyperparameters:
Batch size: 32 (distributed over multiple GPUs)
Learning rate: 0.002 (generator), 0.002 (discriminator)
Optimizer: Adam with \(\beta_1 = 0\), \(\beta_2 = 0.99\)
R1 regularization: \(\gamma = 10\)
Training time: ~1 week on 8Γ V100 GPUs for FFHQ
Architecture DetailsΒΆ
Mapping Network:
8 fully-connected layers
512 units per layer
Leaky ReLU (slope 0.2)
Learning rate multiplier: 0.01 (slower learning for mapping net)
Synthesis Network:
Constant input: \(4 \times 4 \times 512\)
Bilinear upsampling (StyleGAN2) or filtered upsampling (StyleGAN3)
Conv kernels: \(3 \times 3\)
Noise: Per-layer learnable scaling
Discriminator:
Mirror of synthesis network (no mapping network)
Residual connections (StyleGAN2)
MiniBatch StdDev layer at end
InferenceΒΆ
Sampling:
Sample \(\mathbf{z} \sim \mathcal{N}(0, I)\)
\(\mathbf{w} = f(\mathbf{z})\)
(Optional) Truncate: \(\mathbf{w}' = \bar{\mathbf{w}} + \psi(\mathbf{w} - \bar{\mathbf{w}})\)
\(\mathbf{x} = G(\mathbf{w}')\)
Speed:
StyleGAN2 (1024Γ1024): ~30ms per image on V100
StyleGAN3: ~50ms per image (more computation due to filtering)
Best PracticesΒΆ
Training from ScratchΒΆ
Data: At least 1k high-quality images (10k+ for best results)
Augmentation: Use ADA (Adaptive Discriminator Augmentation) for small datasets
Resolution: Start with 256Γ256 or 512Γ512 (faster iteration)
Monitor: FID, qualitative samples, discriminator/generator loss ratio
Fine-Tuning Pretrained ModelΒΆ
Use StyleGAN2-ADA checkpoint (best for transfer)
Lower learning rate: 0.0001 (10Γ smaller)
Freeze mapping network (optional, helps preserve quality)
Small dataset: ADA augmentation essential
Latent Space EditingΒΆ
Use \(\mathcal{W}+\) space for real image inversion (better reconstruction)
Regularize: Encourage \(\mathbf{w}_i \approx \bar{\mathbf{w}}\) to stay in-distribution
Multiple edits: Apply orthogonalized directions to avoid interference
Common Issues and SolutionsΒΆ
1. Mode CollapseΒΆ
Symptoms: Generator produces limited variety Solutions:
Increase batch size
Add R1 regularization
Use mixing regularization
Check discriminator isnβt too strong
2. Training InstabilityΒΆ
Symptoms: Loss oscillates, quality degrades Solutions:
Reduce learning rate
Increase R1 regularization weight
Use gradient clipping
Check for NaNs in gradients
3. Poor Inversion QualityΒΆ
Symptoms: \(G(E(\mathbf{x})) \not\approx \mathbf{x}\) Solutions:
Use \(\mathcal{W}+\) space (18 separate \(\mathbf{w}_i\))
Add perceptual loss (LPIPS)
Optimize with noise regularization
Hybrid encoder + optimization
Future DirectionsΒΆ
3D-Aware StyleGAN: Full 3D control (EG3D, StyleSDF)
Video StyleGAN: Temporally consistent generation
Controllable Editing: More semantic, disentangled controls
Efficient Training: Reduce computational cost
Generalization: Single model for multiple domains
SummaryΒΆ
StyleGAN revolutionized generative modeling by introducing:
Mapping Network: \(\mathbf{z} \rightarrow \mathbf{w}\) for disentanglement
AdaIN: Style control at multiple scales
Noise Injection: Stochastic variation
Mixing Regularization: Forced localization
StyleGAN2 improvements:
Weight demodulation (no artifacts)
No progressive growing
Path length regularization
StyleGAN3 advances:
Alias-free generation (texture sticking solved)
Impact:
SOTA image quality (FID ~2 on FFHQ)
Controllable generation (linear edits)
Wide applications (editing, 3D, text-to-image)
Key Insight: Separation of content (constant input) and style (\(\mathbf{w}\) via AdaIN) enables unprecedented control.
"""
Advanced StyleGAN - Production Implementation
Comprehensive implementation of StyleGAN architecture
Components Implemented:
1. Mapping Network (8-layer MLP)
2. Synthesis Network with AdaIN
3. Noise Injection
4. Style Mixing
5. Truncation Trick
6. Progressive Growing (StyleGAN1)
7. Weight Demodulation (StyleGAN2)
Features:
- Complete StyleGAN generator
- Style-based layer with AdaIN
- Latent space manipulation
- Image generation and interpolation
- Style mixing visualization
Author: Advanced Deep Learning Course
Date: 2024
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple
import numpy as np
import matplotlib.pyplot as plt
from math import log2
# ============================================================================
# 1. Mapping Network
# ============================================================================
class MappingNetwork(nn.Module):
    """
    Mapping network f: z -> w.

    An MLP of `num_layers` equalized linear layers with LeakyReLU(0.2)
    that transforms the sampled latent z into the intermediate latent w,
    which drives the style modulation of the synthesis network.
    """
    def __init__(
        self,
        z_dim: int = 512,
        w_dim: int = 512,
        num_layers: int = 8,
        lr_multiplier: float = 0.01
    ):
        """
        Args:
            z_dim: Dimensionality of the input latent z.
            w_dim: Dimensionality of the output latent w.
            num_layers: Number of fully-connected layers.
            lr_multiplier: Learning-rate multiplier passed to each
                equalized linear layer (the mapping network is trained
                much more slowly than the synthesis network).
        """
        super().__init__()
        self.z_dim = z_dim
        self.w_dim = w_dim
        # Only the first layer may change width (z_dim -> w_dim);
        # every subsequent layer maps w_dim -> w_dim.
        dims = [z_dim] + [w_dim] * num_layers
        blocks = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            blocks.append(EqualizedLinear(d_in, d_out, lr_multiplier=lr_multiplier))
            blocks.append(nn.LeakyReLU(0.2))
        self.mapping = nn.Sequential(*blocks)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Args:
            z: Random latent [batch_size, z_dim].
        Returns:
            w: Intermediate latent [batch_size, w_dim].
        """
        return self.mapping(z)
# ============================================================================
# 2. Equalized Learning Rate Layers
# ============================================================================
class EqualizedLinear(nn.Module):
    """
    Fully-connected layer with equalized learning rate.

    Weights are kept at unit variance and rescaled by the He constant
    (times an optional learning-rate multiplier) on every forward pass,
    instead of being scaled once at initialization, so all layers see
    comparable effective learning rates.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = True, lr_multiplier: float = 1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.lr_multiplier = lr_multiplier
        # Unit-variance init; the effective scale is applied in forward().
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        # He constant 1/sqrt(fan_in), folded with the lr multiplier.
        self.scale = (1 / np.sqrt(in_features)) * lr_multiplier

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply y = x @ (scale * W)^T + lr_multiplier * b."""
        if self.bias is None:
            return F.linear(x, self.weight * self.scale, None)
        return F.linear(x, self.weight * self.scale, self.bias * self.lr_multiplier)
class EqualizedConv2d(nn.Module):
    """
    2-D convolution with equalized learning rate.

    The weight is stored as unit-variance noise and multiplied by the
    He constant 1/sqrt(fan_in) on every forward pass, keeping gradient
    magnitudes comparable across layers regardless of fan-in.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # Unit-variance initialization; scaling happens at runtime instead.
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None
        # He constant for fan_in = in_channels * k * k.
        self.scale = 1 / np.sqrt(in_channels * kernel_size * kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve x with the runtime-scaled weight."""
        return F.conv2d(x, self.scale * self.weight, self.bias, self.stride, self.padding)
# ============================================================================
# 3. Adaptive Instance Normalization (AdaIN)
# ============================================================================
class AdaIN(nn.Module):
    """
    Adaptive Instance Normalization.

    AdaIN(x, y) = y_s * (x - mu(x)) / sigma(x) + y_b, where the
    per-channel style parameters (y_s, y_b) come from learned affine
    transforms of the intermediate latent w.
    """
    def __init__(self, num_features: int, w_dim: int):
        """
        Args:
            num_features: Number of feature channels being modulated.
            w_dim: Dimensionality of the style vector w.
        """
        super().__init__()
        # Affine transforms w -> per-channel style scale / bias.
        self.scale_transform = EqualizedLinear(w_dim, num_features)
        self.bias_transform = EqualizedLinear(w_dim, num_features)

    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Features [batch, channels, height, width]
            w: Style vector [batch, w_dim]
        Returns:
            Modulated features [batch, channels, height, width]
        """
        # Offset the scale by +1 so that with the zero-initialized affine
        # bias AdaIN starts as (approximately) an identity modulation,
        # matching the official StyleGAN initialization (y_s bias = 1).
        # Without the offset, scales are zero-centered at init and many
        # channels start suppressed or sign-flipped.
        scale = self.scale_transform(w).unsqueeze(-1).unsqueeze(-1) + 1.0  # [B, C, 1, 1]
        bias = self.bias_transform(w).unsqueeze(-1).unsqueeze(-1)          # [B, C, 1, 1]
        # Per-sample, per-channel statistics over the spatial dimensions.
        mean = x.mean(dim=[2, 3], keepdim=True)
        std = x.std(dim=[2, 3], keepdim=True) + 1e-8  # eps avoids div-by-zero on flat maps
        normalized = (x - mean) / std
        return scale * normalized + bias
# ============================================================================
# 4. Noise Injection
# ============================================================================
class NoiseInjection(nn.Module):
    """
    Add per-pixel Gaussian noise scaled by a learned per-channel weight.

    The scaling weight starts at zero, so noise initially has no effect
    and its influence is learned during training.
    """
    def __init__(self, num_channels: int):
        super().__init__()
        # One learnable strength per channel, broadcast over batch and space.
        self.weight = nn.Parameter(torch.zeros(1, num_channels, 1, 1))

    def forward(self, x: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: Features [batch, channels, height, width]
            noise: Optional noise [batch, 1, height, width]; fresh
                N(0, 1) noise is sampled on x's device when omitted.
        Returns:
            Features with scaled noise added.
        """
        if noise is None:
            b, _, h, w = x.shape
            noise = torch.randn(b, 1, h, w, device=x.device)
        return x + self.weight * noise
# ============================================================================
# 5. Style-Based Generator Block
# ============================================================================
class StyleBlock(nn.Module):
    """
    One style-based synthesis step.

    Pipeline: (optional 2x upsample) -> Conv3x3 -> NoiseInjection
    -> AdaIN -> LeakyReLU(0.2).
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        w_dim: int,
        upsample: bool = False
    ):
        super().__init__()
        self.upsample = upsample
        # 3x3 convolution; padding=1 keeps the spatial size.
        self.conv = EqualizedConv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Learned per-channel noise strength.
        self.noise = NoiseInjection(out_channels)
        # Style modulation driven by w.
        self.adain = AdaIN(out_channels, w_dim)
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x: torch.Tensor, w: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: Input features.
            w: Style vector.
            noise: Optional explicit noise map.
        Returns:
            Styled output features (at 2x resolution when upsample=True).
        """
        if self.upsample:
            # Bilinear 2x upsampling before the convolution.
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        out = self.conv(x)
        out = self.noise(out, noise)
        out = self.adain(out, w)
        return self.activation(out)
# ============================================================================
# 6. StyleGAN Generator
# ============================================================================
class StyleGANGenerator(nn.Module):
    """
    StyleGAN Generator.

    Maps z -> w via the mapping network, then synthesizes an image from a
    learned constant 4x4 input, modulating every StyleBlock with w.
    Supports the truncation trick and style mixing.
    """
    def __init__(
        self,
        z_dim: int = 512,
        w_dim: int = 512,
        img_resolution: int = 256,
        img_channels: int = 3,
        base_channels: int = 512
    ):
        """
        Args:
            z_dim: Dimensionality of the input latent z.
            w_dim: Dimensionality of the intermediate latent w.
            img_resolution: Output resolution (must be a power of two >= 4).
            img_channels: Number of output image channels.
            base_channels: Channel count at the 4x4 resolution.
        """
        super().__init__()
        self.z_dim = z_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        # Mapping network: z -> w
        self.mapping = MappingNetwork(z_dim, w_dim)
        # Learned constant input: synthesis starts here instead of from z.
        self.const_input = nn.Parameter(torch.randn(1, base_channels, 4, 4))
        # Number of styled conv layers: two per resolution from 8x8 up,
        # not counting the two initial 4x4 blocks.
        self.log_resolution = int(log2(img_resolution))
        self.num_layers = (self.log_resolution - 1) * 2  # 2 conv per resolution
        # Build synthesis network
        self.synthesis_blocks = nn.ModuleList()
        self.to_rgb_layers = nn.ModuleList()
        in_channels = base_channels
        resolution = 4
        # Initial 4x4 blocks (no upsample).
        self.initial_block1 = StyleBlock(base_channels, base_channels, w_dim, upsample=False)
        self.initial_block2 = StyleBlock(base_channels, base_channels, w_dim, upsample=False)
        # Progressive blocks: one upsampling pair per resolution doubling,
        # halving channels each step down to a floor of 16.
        for i in range(2, self.log_resolution + 1):
            out_channels = base_channels // (2 ** (i - 2)) if i > 2 else base_channels
            out_channels = max(out_channels, 16)  # Min 16 channels
            # Two conv blocks per resolution
            block1 = StyleBlock(in_channels, out_channels, w_dim, upsample=True)
            block2 = StyleBlock(out_channels, out_channels, w_dim, upsample=False)
            self.synthesis_blocks.append(nn.ModuleList([block1, block2]))
            # 1x1 to-RGB conv; NOTE: only the last one is used in forward().
            to_rgb = EqualizedConv2d(out_channels, img_channels, kernel_size=1)
            self.to_rgb_layers.append(to_rgb)
            in_channels = out_channels
            resolution *= 2

    def forward(
        self,
        z: torch.Tensor,
        truncation: float = 1.0,
        truncation_mean: Optional[torch.Tensor] = None,
        styles: Optional[List[torch.Tensor]] = None,
        mix_index: Optional[int] = None,
        return_latents: bool = False
    ) -> torch.Tensor:
        """
        Args:
            z: Random latent [batch, z_dim]
            truncation: Truncation psi (1.0 = no truncation)
            truncation_mean: Mean w for truncation
            styles: Optional list of w vectors for style mixing
            mix_index: Layer index to switch styles (for mixing)
            return_latents: Whether to return w latents
        Returns:
            Generated image [batch, channels, resolution, resolution]
        """
        batch_size = z.size(0)
        # Map z to w.
        # NOTE(review): w is computed even when `styles` overrides it below,
        # so the mapping pass is wasted work in the style-mixing path.
        w = self.mapping(z)
        # Truncation trick: pull w toward the mean (trades diversity for quality).
        if truncation < 1.0 and truncation_mean is not None:
            w = truncation_mean + truncation * (w - truncation_mean)
        # Style mixing
        if styles is None:
            # Use same w for all layers
            styles = [w] * (self.num_layers + 2)  # +2 for initial blocks
        else:
            # Mix styles at mix_index: styles[0] before it, styles[1] after.
            if mix_index is not None:
                styles = [styles[0] if i < mix_index else styles[1] for i in range(self.num_layers + 2)]
        # Start with constant input
        x = self.const_input.repeat(batch_size, 1, 1, 1)
        # Initial blocks (4x4)
        x = self.initial_block1(x, styles[0])
        x = self.initial_block2(x, styles[1])
        # Progressive synthesis: consume two styles per resolution pair.
        style_idx = 2
        for i, (block1, block2) in enumerate(self.synthesis_blocks):
            x = block1(x, styles[style_idx])
            style_idx += 1
            x = block2(x, styles[style_idx])
            style_idx += 1
            # Skip to RGB for intermediate resolutions (progressive growing)
            # For simplicity, we only use final to_rgb
        # Final to RGB
        rgb = self.to_rgb_layers[-1](x)
        if return_latents:
            # NOTE(review): in the style-mixing path this returns the unmixed
            # w derived from z, not the styles that were actually applied.
            return rgb, w
        return rgb

    def generate(
        self,
        batch_size: int = 1,
        truncation: float = 0.7,
        device: str = 'cuda'
    ) -> torch.Tensor:
        """Generate random images.

        NOTE(review): `device` defaults to 'cuda' and will fail on
        CPU-only machines unless 'cpu' is passed explicitly.
        """
        z = torch.randn(batch_size, self.z_dim, device=device)
        # Estimate the truncation mean w by Monte-Carlo over 1000 samples.
        with torch.no_grad():
            z_mean = torch.randn(1000, self.z_dim, device=device)
            w_mean = self.mapping(z_mean).mean(0, keepdim=True)
        return self.forward(z, truncation=truncation, truncation_mean=w_mean)

    def style_mixing(
        self,
        z1: torch.Tensor,
        z2: torch.Tensor,
        mix_index: int
    ) -> torch.Tensor:
        """
        Mix styles from two latents.

        Args:
            z1: First latent (provides layers below mix_index)
            z2: Second latent (provides layers from mix_index on)
            mix_index: Layer to switch from z1's style to z2's
        Returns:
            Mixed image
        """
        w1 = self.mapping(z1)
        w2 = self.mapping(z2)
        return self.forward(z1, styles=[w1, w2], mix_index=mix_index)
# ============================================================================
# 7. StyleGAN2: Weight Demodulation
# ============================================================================
class ModulatedConv2d(nn.Module):
    """
    Modulated convolution with weight demodulation (StyleGAN2).

    Replaces AdaIN: the style scales the convolution weights per input
    channel (modulation), then each output filter is rescaled to unit
    expected norm (demodulation).
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        w_dim: int,
        demodulate: bool = True
    ):
        """
        Args:
            in_channels: Input feature channels.
            out_channels: Output feature channels.
            kernel_size: Square kernel size.
            w_dim: Dimensionality of the style vector w.
            demodulate: Apply weight demodulation (StyleGAN2 default).
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        # Convolution weight: unit-variance init with runtime He scaling.
        self.weight = nn.Parameter(
            torch.randn(out_channels, in_channels, kernel_size, kernel_size)
        )
        self.scale = 1 / np.sqrt(in_channels * kernel_size * kernel_size)
        # Style modulation: affine map w -> per-input-channel scale.
        self.modulation = EqualizedLinear(w_dim, in_channels)

    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input [batch, in_channels, height, width]
            w: Style [batch, w_dim]
        Returns:
            Output [batch, out_channels, height, width]
        """
        batch, in_c, height, width = x.shape
        # Per-sample, per-input-channel modulation factors.
        style = self.modulation(w).view(batch, 1, in_c, 1, 1)  # [B, 1, C_in, 1, 1]
        # Modulate weights: every sample gets its own scaled kernel bank.
        weight = self.weight.unsqueeze(0) * self.scale  # [1, C_out, C_in, K, K]
        weight = weight * style  # [B, C_out, C_in, K, K]
        # Demodulate: normalize each output filter's L2 norm so activation
        # magnitudes stay controlled (eps guards rsqrt of zero).
        if self.demodulate:
            demod = torch.rsqrt((weight ** 2).sum(dim=[2, 3, 4], keepdim=True) + 1e-8)
            weight = weight * demod
        # Per-sample convolutions are implemented as one grouped conv:
        # fold the batch into the channel dimension and set groups=batch.
        weight = weight.view(batch * self.out_channels, in_c, self.kernel_size, self.kernel_size)
        # Reshape input to [1, B*C_in, H, W] for the grouped conv.
        x = x.view(1, batch * in_c, height, width)
        # padding=k//2 with stride 1 keeps the spatial size unchanged.
        out = F.conv2d(x, weight, padding=self.kernel_size // 2, groups=batch)
        # Unfold the batch dimension back out of the channels.
        out = out.view(batch, self.out_channels, height, width)
        return out
# ============================================================================
# 8. Visualization and Utils
# ============================================================================
def visualize_generation(generator: StyleGANGenerator, num_samples: int = 16, device: str = 'cuda'):
    """
    Plot a grid of random samples from the generator.

    Args:
        generator: StyleGAN generator (trained or untrained).
        num_samples: Number of images to generate and display.
        device: Device used for sampling.
    """
    generator.eval()
    with torch.no_grad():
        images = generator.generate(num_samples, truncation=0.7, device=device)
    # [B, C, H, W] in [-1, 1]  ->  [B, H, W, C] in [0, 1]
    images = images.cpu().permute(0, 2, 3, 1).numpy()
    images = np.clip((images + 1) / 2, 0, 1)
    # Use a ceil-sized grid so every sample is shown even when num_samples
    # is not a perfect square (int(sqrt(n)) would drop images), and
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, where
    # plt.subplots would otherwise return a bare Axes without `.flat`.
    cols = int(np.ceil(np.sqrt(num_samples)))
    rows = int(np.ceil(num_samples / cols))
    fig, axes = plt.subplots(rows, cols, figsize=(12, 12), squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < num_samples:
            ax.imshow(images[i])
        ax.axis('off')
    plt.tight_layout()
    plt.show()
def visualize_style_mixing(
    generator: StyleGANGenerator,
    num_sources: int = 4,
    num_destinations: int = 4,
    device: str = 'cuda'
):
    """
    Visualize a style-mixing matrix.

    Rows provide the coarse styles (pose, face shape); columns provide
    the fine styles (color, texture). Entry (i, j) uses coarse style i
    up to the mixing layer and fine style j from there on.
    """
    generator.eval()
    # Independent latents for the two style sources.
    z_coarse = torch.randn(num_sources, generator.z_dim, device=device)
    z_fine = torch.randn(num_destinations, generator.z_dim, device=device)
    # Map to w
    with torch.no_grad():
        w_coarse = generator.mapping(z_coarse)
        w_fine = generator.mapping(z_fine)
    # Generate the mixing matrix under no_grad: the original built autograd
    # graphs for every generated image, wasting memory during visualization.
    images = []
    with torch.no_grad():
        for i in range(num_sources):
            row = []
            for j in range(num_destinations):
                # Mix at mid-point (layer 8 out of 18)
                img = generator(z_coarse[i:i+1], styles=[w_coarse[i:i+1], w_fine[j:j+1]], mix_index=8)
                row.append(img[0])
            images.append(torch.stack(row))
    images = torch.stack(images)  # [sources, destinations, C, H, W]
    # [-1, 1] -> [0, 1] for display.
    images = images.cpu().permute(0, 1, 3, 4, 2).numpy()
    images = np.clip((images + 1) / 2, 0, 1)
    # squeeze=False keeps `axes` 2-D even when a grid dimension is 1.
    fig, axes = plt.subplots(num_sources, num_destinations, figsize=(15, 15), squeeze=False)
    for i in range(num_sources):
        for j in range(num_destinations):
            ax = axes[i, j]
            ax.imshow(images[i, j])
            ax.axis('off')
            if i == 0:
                ax.set_title(f'Fine {j}', fontsize=10)
            if j == 0:
                # set_ylabel is hidden by axis('off'); draw the row label
                # as free text just left of the image instead.
                ax.text(-0.05, 0.5, f'Coarse {i}', transform=ax.transAxes,
                        rotation=90, va='center', ha='center', fontsize=10)
    plt.suptitle('Style Mixing: Coarse (rows) Γ Fine (cols)', fontsize=14)
    plt.tight_layout()
    plt.show()
def interpolate_latents(
    generator: StyleGANGenerator,
    z1: torch.Tensor,
    z2: torch.Tensor,
    num_steps: int = 10,
    device: str = 'cuda'
) -> torch.Tensor:
    """
    Linearly interpolate between two latents and generate each step.

    Args:
        generator: StyleGAN generator.
        z1: Start latent.
        z2: End latent.
        num_steps: Number of interpolation steps (endpoints included).
        device: Device on which the interpolation weights are created.
    Returns:
        Stacked images, one per interpolation step.
    """
    generator.eval()
    frames = []
    with torch.no_grad():
        # Walk alpha from 0 (pure z1) to 1 (pure z2).
        for alpha in torch.linspace(0, 1, num_steps, device=device):
            blended = (1 - alpha) * z1 + alpha * z2
            frames.append(generator(blended)[0])
    return torch.stack(frames)
# ============================================================================
# 9. Demo
# ============================================================================
def demo_stylegan():
    """Demonstrate StyleGAN capabilities.

    Runs five short showcases on a small 64x64 generator: random
    generation, style mixing, latent interpolation, the truncation
    trick, and a printed architecture summary. No training is
    performed; all outputs come from randomly initialized weights.
    """
    print("=" * 80)
    print("StyleGAN - Advanced Generative Model Demo")
    print("=" * 80)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"\nDevice: {device}")
    # Create generator (small version for demo)
    print("\nCreating StyleGAN generator (64x64 for demo)...")
    generator = StyleGANGenerator(
        z_dim=512,
        w_dim=512,
        img_resolution=64,
        img_channels=3
    ).to(device)
    # Count parameters
    total_params = sum(p.numel() for p in generator.parameters())
    print(f"Total parameters: {total_params:,}")
    # 1. Random generation
    print("\n" + "=" * 60)
    print("1. Random Image Generation")
    print("=" * 60)
    z = torch.randn(4, 512, device=device)
    with torch.no_grad():
        images = generator(z)
    print(f"Generated images: {images.shape}")
    print(f"Min: {images.min().item():.3f}, Max: {images.max().item():.3f}")
    # 2. Style mixing
    print("\n" + "=" * 60)
    print("2. Style Mixing")
    print("=" * 60)
    z1 = torch.randn(1, 512, device=device)
    z2 = torch.randn(1, 512, device=device)
    print("Mixing at different layer indices:")
    # Early mix indices transfer coarse attributes; later ones, fine detail.
    for mix_idx in [2, 6, 10]:
        with torch.no_grad():
            mixed = generator.style_mixing(z1, z2, mix_idx)
        print(f" Mix at layer {mix_idx}: {mixed.shape}")
    # 3. Latent interpolation
    print("\n" + "=" * 60)
    print("3. Latent Space Interpolation")
    print("=" * 60)
    interp_images = interpolate_latents(generator, z1, z2, num_steps=8, device=device)
    print(f"Interpolated images: {interp_images.shape}")
    # 4. Truncation trick
    print("\n" + "=" * 60)
    print("4. Truncation Trick")
    print("=" * 60)
    for psi in [0.5, 0.7, 1.0]:
        with torch.no_grad():
            # Estimate mean w from a small Monte-Carlo sample of z.
            z_samples = torch.randn(100, 512, device=device)
            w_mean = generator.mapping(z_samples).mean(0, keepdim=True)
            img = generator(z1, truncation=psi, truncation_mean=w_mean)
        print(f" Truncation Ο={psi}: quality-diversity trade-off")
    # 5. Architecture analysis
    print("\n" + "=" * 60)
    print("5. Architecture Analysis")
    print("=" * 60)
    print(f"\nMapping Network:")
    print(f" Input: z β R^{generator.z_dim}")
    print(f" Output: w β R^{generator.w_dim}")
    print(f" Layers: 8 Γ FC({generator.w_dim}) with LeakyReLU")
    print(f"\nSynthesis Network:")
    print(f" Starting resolution: 4Γ4")
    print(f" Final resolution: {generator.img_resolution}Γ{generator.img_resolution}")
    print(f" Number of style blocks: {len(generator.synthesis_blocks) * 2 + 2}")
    print(f" AdaIN layers: {generator.num_layers + 2}")
    print("\n" + "=" * 80)
    print("Demo complete!")
    print("=" * 80)
def demo_comparison():
    """Compare standard GAN vs StyleGAN architecture.

    Purely informational: prints a side-by-side summary of the two
    designs; no computation is performed.
    """
    print("\n" + "=" * 80)
    print("Architectural Comparison: Standard GAN vs StyleGAN")
    print("=" * 80)
    print("\nStandard GAN:")
    print(" Input: z β [Concat with spatial coords]")
    print(" Process: Conv layers progressively increase resolution")
    print(" Control: Global (entire z affects all features)")
    print(" Problem: Entangled latent space")
    print("\nStyleGAN:")
    print(" Input: Learned constant 4Γ4 tensor")
    print(" Latent: z β MappingNet β w (disentangled)")
    print(" Control: Local (each AdaIN layer independently controlled)")
    print(" Noise: Per-layer stochastic variation")
    print(" Benefits:")
    print(" - Disentangled representations")
    print(" - Scale-specific control (pose vs texture)")
    print(" - Smooth interpolation")
    print(" - Style mixing")
# Script entry point: run both demos, then print the key takeaways.
if __name__ == "__main__":
    print("\nAdvanced StyleGAN Implementation\n")
    # Run demos
    demo_stylegan()
    demo_comparison()
    print("\n" + "=" * 80)
    print("Key Takeaways:")
    print("=" * 80)
    print("1. Mapping Network (z β w) enables disentanglement")
    print("2. AdaIN provides scale-specific style control")
    print("3. Noise injection adds stochastic variation")
    print("4. Style mixing forces feature localization")
    print("5. Truncation trick balances quality vs diversity")
    print("6. StyleGAN2's weight demodulation fixes artifacts")
    print("=" * 80)