import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

# Pick the fastest available accelerator, falling back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

1. Style-Based GeneratorΒΆ

Key InnovationΒΆ

Mapping network \(f: \mathcal{Z} \to \mathcal{W}\)

  • Latent \(z \sim \mathcal{N}(0, I)\)

  • Intermediate \(w = f(z)\)

  • Controls synthesis via AdaIN

Adaptive Instance NormalizationΒΆ

\[\text{AdaIN}(x_i, y) = y_{s,i} \frac{x_i - \mu(x_i)}{\sigma(x_i)} + y_{b,i}\]

where \(y = (y_s, y_b)\) from affine transform of \(w\).

📚 Reference Materials:

class AdaptiveInstanceNorm(nn.Module):
    """Adaptive Instance Normalization (AdaIN).

    Normalizes each feature map per sample (instance norm), then re-scales
    and re-shifts it with per-channel (gamma, beta) values produced by a
    learned affine projection of the style vector w.
    """

    def __init__(self, num_features, w_dim):
        super().__init__()
        # Plain instance norm; the affine part comes from the style instead.
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        # Projects w to 2*C values: the first C are scales, the last C biases.
        self.style = nn.Linear(w_dim, num_features * 2)

    def forward(self, x, w):
        # (B, 2C) -> (B, 2C, 1, 1) so scale/bias broadcast over H and W.
        params = self.style(w)[:, :, None, None]
        gamma, beta = params.chunk(2, dim=1)
        return gamma * self.norm(x) + beta

# Smoke test: AdaIN must preserve the input's (B, C, H, W) shape.
adain = AdaptiveInstanceNorm(64, 512).to(device)
x = torch.randn(4, 64, 16, 16).to(device)
w = torch.randn(4, 512).to(device)
out = adain(x, w)
print(f"AdaIN output: {out.shape}")

Generator ArchitectureΒΆ

StyleGAN’s generator replaces the traditional single-pass architecture with a mapping network followed by a synthesis network. The mapping network transforms the input latent \(z\) through 8 fully connected layers into an intermediate latent \(w\), which lives in a more disentangled \(\mathcal{W}\) space. The synthesis network then uses \(w\) to modulate feature maps at each resolution through adaptive instance normalization (AdaIN): \(\text{AdaIN}(x_i, y) = y_{s,i} \frac{x_i - \mu(x_i)}{\sigma(x_i)} + y_{b,i}\). This style-based design allows independent control of coarse, medium, and fine image attributes by injecting different \(w\) vectors at different layers.

class MappingNetwork(nn.Module):
    """MLP that maps the input latent z to the intermediate latent w."""

    def __init__(self, z_dim=512, w_dim=512, n_layers=8):
        super().__init__()
        # Layer widths: z_dim in, then w_dim for every subsequent layer.
        dims = [z_dim] + [w_dim] * n_layers
        blocks = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            blocks.append(nn.Linear(d_in, d_out))
            blocks.append(nn.LeakyReLU(0.2))
        self.mapping = nn.Sequential(*blocks)

    def forward(self, z):
        return self.mapping(z)

class StyleBlock(nn.Module):
    """One synthesis step: 3x3 conv, style modulation (AdaIN), LeakyReLU."""

    def __init__(self, in_channels, out_channels, w_dim):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.adain = AdaptiveInstanceNorm(out_channels, w_dim)
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x, w):
        # conv -> style modulation -> nonlinearity
        return self.activation(self.adain(self.conv(x), w))

class StyleGAN_Generator(nn.Module):
    """Simplified StyleGAN generator.

    Maps z -> w with a mapping MLP, then synthesizes an image starting
    from a learned constant input, modulating every resolution block
    with the same w via AdaIN and upsampling 2x after each block.

    Args:
        z_dim: Dimensionality of the input latent z.
        w_dim: Dimensionality of the intermediate latent w.
        img_size: Target output resolution (4x4 doubled four times = 64).
        img_channels: Number of output image channels.
    """

    def __init__(self, z_dim=512, w_dim=512, img_size=64, img_channels=1):
        super().__init__()
        self.img_size = img_size

        # Mapping network: z -> w
        self.mapping = MappingNetwork(z_dim, w_dim)

        # Learned constant input (replaces feeding z into the synthesis net).
        self.const = nn.Parameter(torch.randn(1, 512, 4, 4))

        # Synthesis blocks; forward() upsamples 2x after each one.
        self.block1 = StyleBlock(512, 256, w_dim)
        self.block2 = StyleBlock(256, 128, w_dim)
        self.block3 = StyleBlock(128, 64, w_dim)
        self.block4 = StyleBlock(64, 32, w_dim)

        # 1x1 conv projecting features to image channels.
        self.to_rgb = nn.Conv2d(32, img_channels, 1)

    def forward(self, z, truncation=1.0):
        """Generate images from latent codes.

        Args:
            z: Latent codes of shape (batch, z_dim).
            truncation: Truncation factor psi. This simplified version
                assumes the mean w is ~0 and rescales w' = psi * w;
                psi < 1 trades diversity for quality, psi > 1
                exaggerates diversity.

        Returns:
            Tensor of shape (batch, img_channels, 64, 64) in [-1, 1].
        """
        w = self.mapping(z)

        # Truncation trick (simplified: mean w assumed to be zero).
        # BUG FIX: the condition was `truncation < 1.0`, so factors above
        # 1 (used in the truncation comparison demo) silently had no
        # effect. Apply the scaling for any psi != 1.
        if truncation != 1.0:
            w = truncation * w

        # Broadcast the learned constant across the batch.
        x = self.const.repeat(z.size(0), 1, 1, 1)

        # Progressive upsampling: 4x4 -> 8 -> 16 -> 32 -> 64.
        for block in (self.block1, self.block2, self.block3, self.block4):
            x = block(x, w)
            x = F.interpolate(x, scale_factor=2)

        # Map to image channels and squash into [-1, 1].
        return torch.tanh(self.to_rgb(x))

# Smoke test: 4 latents should produce 4 single-channel 64x64 images.
gen = StyleGAN_Generator(img_size=64).to(device)
z = torch.randn(4, 512).to(device)
imgs = gen(z)
print(f"Generated images: {imgs.shape}")

Training on MNISTΒΆ

While StyleGAN was designed for high-resolution face generation, training a simplified version on MNIST demonstrates the core architectural ideas at a manageable scale. The progressive growing strategy (if used) starts at low resolution (\(4 \times 4\)) and gradually adds layers, stabilizing training at each stage before introducing finer detail. Even on MNIST, the style-based injection should produce better latent space disentanglement than a standard GAN, which we can verify by manipulating styles at different layers.

class Discriminator(nn.Module):
    """DCGAN-style discriminator: four stride-2 convs, then a linear logit.

    Each conv halves the spatial size (64 -> 32 -> 16 -> 8 -> 4), so the
    final linear layer sees a flattened 256 x 4 x 4 feature map.
    """

    def __init__(self, img_channels=1):
        super().__init__()
        layers = []
        c_in = img_channels
        for c_out in (32, 64, 128, 256):
            layers += [nn.Conv2d(c_in, c_out, 4, 2, 1), nn.LeakyReLU(0.2)]
            c_in = c_out
        layers += [nn.Flatten(), nn.Linear(256 * 4 * 4, 1)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # Returns raw logits of shape (batch, 1); no sigmoid here.
        return self.model(x)

# MNIST pipeline: upsample digits to 64x64 and scale pixels to [-1, 1]
# to match the generator's tanh output range.
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
mnist = datasets.MNIST('./data', train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(mnist, batch_size=64, shuffle=True)

# Generator / discriminator pair.
G = StyleGAN_Generator().to(device)
D = Discriminator().to(device)

# Adam with the usual GAN betas (low beta1 stabilizes adversarial training).
adam_kwargs = dict(lr=2e-4, betas=(0.5, 0.999))
opt_G = torch.optim.Adam(G.parameters(), **adam_kwargs)
opt_D = torch.optim.Adam(D.parameters(), **adam_kwargs)

print("Setup complete")
def train_stylegan(n_epochs=5):
    """Run a standard BCE GAN training loop over MNIST.

    Relies on module-level G, D, opt_G, opt_D, dataloader, and device.

    Args:
        n_epochs: Number of passes over the dataset.

    Returns:
        (generator_losses, discriminator_losses), one entry per batch.
    """
    g_history, d_history = [], []

    for epoch in range(n_epochs):
        for step, (real, _) in enumerate(dataloader):
            real = real.to(device)
            bs = real.size(0)
            ones = torch.ones(bs, 1).to(device)
            zeros = torch.zeros(bs, 1).to(device)

            # --- Discriminator step: real -> 1, fake -> 0 ---
            fake = G(torch.randn(bs, 512).to(device))
            loss_real = F.binary_cross_entropy_with_logits(D(real), ones)
            # detach() stops gradients flowing into G on the D update.
            loss_fake = F.binary_cross_entropy_with_logits(D(fake.detach()), zeros)
            d_loss = (loss_real + loss_fake) / 2

            opt_D.zero_grad()
            d_loss.backward()
            opt_D.step()

            # --- Generator step: fool D into predicting 1 on fresh fakes ---
            fake = G(torch.randn(bs, 512).to(device))
            g_loss = F.binary_cross_entropy_with_logits(D(fake), ones)

            opt_G.zero_grad()
            g_loss.backward()
            opt_G.step()

            g_history.append(g_loss.item())
            d_history.append(d_loss.item())

            if step % 200 == 0:
                print(f"Epoch {epoch}, Batch {step}, D: {d_loss.item():.4f}, G: {g_loss.item():.4f}")

    return g_history, d_history

losses_G, losses_D = train_stylegan(n_epochs=5)

# Plot per-batch losses for both networks on a shared axis.
plt.figure(figsize=(10, 5))
for series, label in ((losses_G, 'Generator'), (losses_D, 'Discriminator')):
    plt.plot(series, alpha=0.5, label=label)
plt.xlabel('Iteration', fontsize=11)
plt.ylabel('Loss', fontsize=11)
plt.title('StyleGAN Training', fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

Style MixingΒΆ

Style mixing is a key capability of StyleGAN: by injecting the \(w\) vector from one latent code at coarse layers and from another at fine layers, we can combine the overall structure of one image with the fine details of another. For example, in face generation, coarse styles control pose and face shape while fine styles control hair texture and background. On MNIST, style mixing might combine the global shape of one digit with the stroke style of another, demonstrating that different layers have learned to control different levels of abstraction.

def style_mixing_demo():
    """Generate two digits plus a coarse/fine style mix of the pair.

    Coarse layers (block1, block2) take w from the first latent while
    fine layers (block3, block4) take w from the second; the mixed image
    is shown between the two unmixed references.
    """
    G.eval()

    with torch.no_grad():
        za = torch.randn(1, 512).to(device)
        zb = torch.randn(1, 512).to(device)

        # Unmixed reference images.
        img_a = G(za)
        img_b = G(zb)

        wa = G.mapping(za)
        wb = G.mapping(zb)

        # Re-run the synthesis path by hand, switching styles midway.
        x = G.const.repeat(1, 1, 1, 1)
        for block, style in ((G.block1, wa), (G.block2, wa),
                             (G.block3, wb), (G.block4, wb)):
            x = F.interpolate(block(x, style), scale_factor=2)
        img_mix = torch.tanh(G.to_rgb(x))

    # Layout: style A | mixed | style B.
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    panels = ((axes[0], img_a, 'Style 1'),
              (axes[1], img_mix, 'Mixed (coarse:1, fine:2)'),
              (axes[2], img_b, 'Style 2'))
    for ax, img, title in panels:
        ax.imshow(img[0, 0].cpu(), cmap='gray')
        ax.set_title(title, fontsize=11)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

style_mixing_demo()

Truncation TrickΒΆ

The truncation trick improves sample quality at the expense of diversity by interpolating the \(w\) vector toward the mean of \(\mathcal{W}\): \(w' = \bar{w} + \psi (w - \bar{w})\), where \(\psi \in [0, 1]\) controls the trade-off. At \(\psi = 1\) we get full diversity (standard sampling); at \(\psi = 0\) all samples converge to the β€œaverage” image. Values around \(\psi = 0.7\) typically produce the best visual quality. This technique exploits the fact that the density of training data is highest near the mean of the latent distribution, so truncating avoids low-density regions that may produce artifacts.

def truncation_comparison():
    """Render the same latent at several truncation factors psi.

    In this simplified generator, truncation rescales w by psi (mean w
    assumed ~0): lower psi trades diversity for quality, psi = 1.0 is
    unmodified sampling.
    """
    G.eval()

    z = torch.randn(1, 512).to(device)
    truncations = [0.5, 0.7, 1.0, 1.5]

    fig, axes = plt.subplots(1, 4, figsize=(16, 4))

    with torch.no_grad():
        for i, trunc in enumerate(truncations):
            img = G(z, truncation=trunc)
            axes[i].imshow(img[0, 0].cpu(), cmap='gray')
            # BUG FIX: the title previously contained mojibake ("ψ")
            # instead of the Greek letter psi.
            axes[i].set_title(f'ψ={trunc}', fontsize=11)
            axes[i].axis('off')

    plt.suptitle('Truncation Trick', fontsize=12)
    plt.tight_layout()
    plt.show()

truncation_comparison()

Generate GridΒΆ

Generating a grid of samples provides a comprehensive visual assessment of the model’s output quality and diversity. Each image in the grid is produced from an independent random latent vector, so the grid should exhibit varied attributes (different digit identities, stroke widths, orientations) while maintaining consistent quality. Inspecting the grid for repeated patterns, artifacts, or missing modes helps diagnose training issues like mode collapse or insufficient model capacity.

# Sample a 4x4 grid of digits; mild truncation (0.7) favors cleaner samples.
G.eval()
with torch.no_grad():
    z = torch.randn(16, 512).to(device)
    imgs = G(z, truncation=0.7)

fig, axes = plt.subplots(4, 4, figsize=(12, 12))
for ax, img in zip(axes.flat, imgs):
    ax.imshow(img[0].cpu(), cmap='gray')
    ax.axis('off')

plt.suptitle('StyleGAN Generated Samples', fontsize=13)
plt.tight_layout()
plt.show()

SummaryΒΆ

StyleGAN Innovations:ΒΆ

  1. Mapping network - \(z \to w\) disentanglement

  2. AdaIN - Style control at each layer

  3. Progressive growing - High-resolution synthesis

  4. Style mixing - Localized control

Key Benefits:ΒΆ

  • Better disentanglement

  • Controllable generation

  • High quality images

  • Interpretable latent space

Applications:ΒΆ

  • Face generation

  • Image editing

  • Style transfer

  • Data augmentation

Extensions:ΒΆ

  • StyleGAN2 (improved quality)

  • StyleGAN3 (alias-free)

  • StyleGAN-XL (large scale)

Next Steps:ΒΆ

  • 02_wgan_theory_implementation.ipynb - WGAN

  • Study progressive growing

  • Explore GANSpace editing

Advanced StyleGAN: Mathematical Foundations and Architecture Deep DiveΒΆ

Table of ContentsΒΆ

  1. Introduction to StyleGAN

  2. Motivation and Key Innovations

  3. Style-Based Generator Architecture

  4. Adaptive Instance Normalization (AdaIN)

  5. Mapping Network

  6. Noise Injection

  7. Mixing Regularization

  8. Progressive Growing

  9. StyleGAN2 Improvements

  10. StyleGAN3: Alias-Free Generation

  11. Latent Space Analysis

  12. Applications and Extensions

  13. Implementation Details

1. Introduction to StyleGANΒΆ

What is StyleGAN?ΒΆ

StyleGAN (Style-Based Generator Architecture for GANs) is a revolutionary generative model introduced by NVIDIA in 2018 that generates highly realistic images with unprecedented control over image synthesis.

Key Achievement: Generates photorealistic faces indistinguishable from real photos

Core Idea: Instead of feeding latent code directly into generator, transform it into intermediate latent space that controls style at different scales.

Evolution TimelineΒΆ

  • StyleGAN (Dec 2018): Style-based architecture, AdaIN, mixing regularization

  • StyleGAN2 (Dec 2019): Fixes artifacts, improved quality, removes progressive growing

  • StyleGAN3 (June 2021): Alias-free generation, better texture sticking

Why StyleGAN MattersΒΆ

Before StyleGAN: Limited control, mode collapse, artifacts, entangled representations

After StyleGAN:

  1. Disentangled Latent Space: Independent control of features (hair, pose, age, expression)

  2. High Quality: 1024Γ—1024 photorealistic images

  3. Controllability: Precise manipulation of attributes

  4. Interpolation: Smooth transitions between images

  5. Artistic Applications: Face editing, style transfer, image synthesis

2. Motivation and Key InnovationsΒΆ

Problems with Traditional GANsΒΆ

1. Entangled Latent Space

In standard GANs: $\( \mathbf{z} \sim \mathcal{N}(0, I) \quad \rightarrow \quad G(\mathbf{z}) = \mathbf{x} \)$

Problem: Changing \(\mathbf{z}\) affects multiple attributes simultaneously

  • Moving in latent space changes age + pose + expression

  • No semantic meaning to latent directions

  • Difficult to find interpretable controls

2. Limited Scale-Specific Control

Traditional generators use \(\mathbf{z}\) only at input:

  • All scales (coarse to fine) derived from same \(\mathbf{z}\)

  • Cannot control high-level structure separately from fine details

  • E.g., cannot change pose without affecting skin texture

3. Lack of Stochasticity

Real images have stochastic variation (hair placement, pores, wrinkles) not captured by deterministic \(G(\mathbf{z})\).

StyleGAN’s SolutionsΒΆ

1. Mapping Network: \(\mathbf{z} \rightarrow \mathbf{w}\)

  • Transform initial latent \(\mathbf{z}\) into intermediate latent \(\mathbf{w}\)

  • \(\mathbf{w}\) is more disentangled and easier to control

2. Style-Based Generation:

  • Feed \(\mathbf{w}\) into generator at multiple scales via AdaIN

  • Each scale controlled independently

  • Coarse levels: pose, face shape

  • Fine levels: hair, skin texture

3. Noise Injection:

  • Add stochastic noise at each layer

  • Generate realistic micro-variations (hair strands, pores)

4. Mixing Regularization:

  • Mix two latent codes \(\mathbf{w}_1, \mathbf{w}_2\) at different layers

  • Forces disentanglement (coarse vs. fine features)

3. Style-Based Generator ArchitectureΒΆ

Overall ArchitectureΒΆ

z (512) β†’ Mapping Network β†’ w (512)
                                ↓
                        [Duplicate 18 times]
                                ↓
                        w_1, w_2, ..., w_18
                                ↓
Synthesis Network: 
    Const 4Γ—4Γ—512
        ↓ (AdaIN with w_1, w_2)
    Conv 4Γ—4
        ↓ (AdaIN with w_3, w_4)
    Upsample β†’ 8Γ—8
        ↓ (AdaIN with w_5, w_6)
    Conv 8Γ—8
        ...
        ↓ (AdaIN with w_17, w_18)
    Conv 1024Γ—1024 β†’ RGB output

Key Difference from Traditional GAN:

Traditional: $\( \text{Input: } \mathbf{z} \rightarrow \text{Conv layers} \rightarrow \text{Output} \)$

StyleGAN: $\( \text{Const input} \rightarrow \text{Conv + AdaIN}(\mathbf{w}_i) \rightarrow \text{Output} \)$

Advantages:

  1. Latent \(\mathbf{w}\) not constrained by input spatial structure

  2. Each layer independently controlled

  3. Better disentanglement

Synthesis Network DetailsΒΆ

Starting Point: Learned constant \(4 \times 4 \times 512\) tensor (not random noise!)

Each Resolution Block:

  1. AdaIN with style \(\mathbf{w}_i\): Modulates feature statistics

  2. Conv 3Γ—3: Processes features

  3. Noise Injection: Adds stochastic detail

  4. Activation: Leaky ReLU

  5. Repeat AdaIN + Conv for second layer at same resolution

  6. Upsample: Increase resolution (nearest neighbor or bilinear)

Output: RGB image via \(1 \times 1\) convolution

Total Layers:

  • 1024Γ—1024 output: 18 layers (2 per resolution from 4Γ—4 to 1024Γ—1024)

  • Each layer gets its own \(\mathbf{w}_i\)

4. Adaptive Instance Normalization (AdaIN)ΒΆ

Standard Instance NormalizationΒΆ

Instance Normalization normalizes each channel independently per sample:

\[ \text{IN}(\mathbf{x})_{i,c} = \frac{\mathbf{x}_{i,c} - \mu_c}{\sigma_c} \]

where:

  • \(\mathbf{x}_{i,c}\): Feature at spatial position \(i\), channel \(c\)

  • \(\mu_c = \frac{1}{HW} \sum_{i} \mathbf{x}_{i,c}\): Mean of channel \(c\)

  • \(\sigma_c = \sqrt{\frac{1}{HW} \sum_{i} (\mathbf{x}_{i,c} - \mu_c)^2 + \epsilon}\): Std of channel \(c\)

Adaptive Instance Normalization (AdaIN)ΒΆ

Idea: After normalization, apply learned affine transformation based on style \(\mathbf{w}\):

\[ \text{AdaIN}(\mathbf{x}, \mathbf{w}) = \mathbf{y}_s(\mathbf{w}) \frac{\mathbf{x} - \mu(\mathbf{x})}{\sigma(\mathbf{x})} + \mathbf{y}_b(\mathbf{w}) \]

where:

  • \(\mathbf{y}_s(\mathbf{w}) = A_s \mathbf{w}\): Scale (learned linear projection)

  • \(\mathbf{y}_b(\mathbf{w}) = A_b \mathbf{w}\): Bias (learned linear projection)

Per Channel: Each of \(C\) channels gets its own scale and bias:

  • \(\mathbf{y}_s \in \mathbb{R}^C\)

  • \(\mathbf{y}_b \in \mathbb{R}^C\)

Intuition:

  • Scale \(\mathbf{y}_s\): Controls variance/contrast of features

  • Bias \(\mathbf{y}_b\): Shifts mean of features

  • Together, they control the β€œstyle” of that layer

Effect:

  • Low-res layers (4Γ—4, 8Γ—8): Control pose, face shape, coarse structure

  • Mid-res layers (16Γ—16 - 128Γ—128): Control facial features, hair style

  • High-res layers (256Γ—256 - 1024Γ—1024): Control color scheme, fine detail

Mathematical FormulationΒΆ

At layer \(i\) with features \(\mathbf{h}\) and style \(\mathbf{w}_i\):

\[ \mathbf{h}' = \text{AdaIN}(\mathbf{h}, \mathbf{w}_i) = \gamma(\mathbf{w}_i) \odot \frac{\mathbf{h} - \mu(\mathbf{h})}{\sigma(\mathbf{h})} + \beta(\mathbf{w}_i) \]

where:

  • \(\gamma(\mathbf{w}_i) = \text{FC}^{(\gamma)}_i(\mathbf{w}_i)\): Learned scale

  • \(\beta(\mathbf{w}_i) = \text{FC}^{(\beta)}_i(\mathbf{w}_i)\): Learned bias

  • \(\odot\): Element-wise multiplication

Comparison to Conditional Batch Normalization:

Conditional BN (BigGAN, etc.): $\( \text{CBN}(\mathbf{h}, \mathbf{c}) = \gamma(\mathbf{c}) \frac{\mathbf{h} - \mu_{\text{batch}}}{\sigma_{\text{batch}}} + \beta(\mathbf{c}) \)$

AdaIN: $\( \text{AdaIN}(\mathbf{h}, \mathbf{w}) = \gamma(\mathbf{w}) \frac{\mathbf{h} - \mu_{\text{instance}}}{\sigma_{\text{instance}}} + \beta(\mathbf{w}) \)$

Key Difference: Instance norm (per-sample) vs. Batch norm (across batch)

  • Instance norm: Better for style transfer (don’t mix statistics across samples)

  • Allows independent control per image

5. Mapping NetworkΒΆ

ArchitectureΒΆ

Goal: Transform random latent \(\mathbf{z} \sim \mathcal{N}(0, I)\) into disentangled latent \(\mathbf{w}\)

Structure: 8-layer MLP (fully-connected network)

\[ \mathbf{w} = f(\mathbf{z}) = \text{FC}_8 \circ \text{LeakyReLU} \circ \cdots \circ \text{FC}_1(\mathbf{z}) \]

Dimensions:

  • Input: \(\mathbf{z} \in \mathbb{R}^{512}\)

  • Hidden: \(512\) units per layer

  • Output: \(\mathbf{w} \in \mathbb{R}^{512}\)

Activation: Leaky ReLU with slope 0.2

No Normalization: No batch norm or instance norm in mapping network

Why Mapping Network?ΒΆ

Hypothesis: The distribution \(p(\mathbf{z})\) (Gaussian) must be mapped to distribution of training images, which may be complex/non-linear.

Problem with Direct \(\mathbf{z}\):

Training data may have complex distribution:

  • Sparse regions: e.g., no half-male-half-female faces

  • Correlations: e.g., age correlates with wrinkles

If \(\mathbf{z}\) is Gaussian, generator must β€œwarp” this to match data distribution β†’ entanglement.

Solution: Introduce intermediate \(\mathbf{w}\) that doesn’t need to follow fixed distribution.

Result: \(\mathbf{w}\) can be non-linear transformation of \(\mathbf{z}\), allowing:

  • Features to lie on non-linear manifold

  • Better disentanglement (independent features)

Disentanglement MetricsΒΆ

Perceptual Path Length (PPL):

Measures how smooth interpolation is in latent space:

\[ \text{PPL} = \mathbb{E}\left[\frac{1}{\epsilon^2} d\left(G(w_1), G(w_2)\right)\right] \]

where \(w_2 = w_1 + \epsilon \cdot \text{direction}\) and \(d\) is perceptual distance (LPIPS).

Lower PPL = Better Disentanglement

Linear Separability:

Train a linear classifier to predict an attribute from \(\mathbf{w}\): \( \text{classifier}: \mathbf{w} \rightarrow \{\text{male}, \text{female}\} \)

Higher accuracy = more linear/disentangled.

Truncation TrickΒΆ

Observation: Images near center of \(\mathbf{w}\) distribution have higher quality.

Truncation: Sample \(\mathbf{z} \sim \mathcal{N}(0, I)\), then shrink towards mean:

\[ \mathbf{w}' = \bar{\mathbf{w}} + \psi (\mathbf{w} - \bar{\mathbf{w}}) \]

where:

  • \(\bar{\mathbf{w}} = \mathbb{E}_{\mathbf{z}}[f(\mathbf{z})]\): Mean \(\mathbf{w}\) (computed over many samples)

  • \(\psi \in [0, 1]\): Truncation factor

Effect:

  • \(\psi = 1\): Full diversity, may have artifacts

  • \(\psi = 0.7\): Good quality-diversity trade-off (common choice)

  • \(\psi = 0\): Collapse to single image (average face)

6. Noise InjectionΒΆ

MotivationΒΆ

Real images have stochastic variation not captured by deterministic generator:

  • Exact placement of individual hair strands

  • Skin pores, freckles, stubble

  • Background textures

Solution: Inject random noise at each layer.

ImplementationΒΆ

At each layer, after convolution and before AdaIN:

\[ \mathbf{h}' = \mathbf{h} + B \odot \mathbf{n} \]

where:

  • \(\mathbf{n} \sim \mathcal{N}(0, I)\): Gaussian noise (same spatial size as \(\mathbf{h}\))

  • \(B \in \mathbb{R}^C\): Learned per-channel scaling factors

  • \(\odot\): Broadcasting: \(B_c\) multiplies all spatial locations in channel \(c\)

Key: Noise is broadcasted (same noise value for all spatial locations within a channel).

Per-Layer Noise:

  • Each layer gets independent noise

  • Low-res layers: Affects coarse stochasticity (overall lighting variation)

  • High-res layers: Affects fine details (individual hair strands, pores)

Effect of NoiseΒΆ

Experiment: Remove noise from specific layers:

  • No noise at all: Images look β€œplastic”, overly smooth, lack detail

  • No noise at low-res: Coarse structure too deterministic

  • No noise at high-res: Fine details (hair, skin texture) too regular

Optimal: Noise at all layers (default in StyleGAN)

Mathematical FormulationΒΆ

Full forward pass at layer \(i\):

\[ \mathbf{h}_i = \text{LeakyReLU}\left(\text{AdaIN}(\text{Conv}(\mathbf{h}_{i-1}) + B_i \odot \mathbf{n}_i, \mathbf{w}_i)\right) \]

Order of Operations:

  1. Convolution

  2. Add noise

  3. AdaIN (modulate with style)

  4. Activation

7. Mixing RegularizationΒΆ

MotivationΒΆ

Risk: Generator may correlate adjacent styles (e.g., use \(\mathbf{w}_i\) and \(\mathbf{w}_{i+1}\) together).

Goal: Force localization of stylesβ€”each layer should control independent aspects.

MethodΒΆ

During training, use two latent codes \(\mathbf{w}_1, \mathbf{w}_2\):

  1. Sample two independent latents: \(\mathbf{z}_1, \mathbf{z}_2 \sim \mathcal{N}(0, I)\)

  2. Map to \(\mathbf{w}\)-space: \(\mathbf{w}_1 = f(\mathbf{z}_1), \quad \mathbf{w}_2 = f(\mathbf{z}_2)\)

  3. Choose random crossover point \(k \in \{1, \ldots, 18\}\)

  4. Use \(\mathbf{w}_1\) for layers \(1\) to \(k-1\), and \(\mathbf{w}_2\) for layers \(k\) to \(18\):

\[\begin{split} \mathbf{w}_i = \begin{cases} \mathbf{w}_1 & \text{if } i < k \\ \mathbf{w}_2 & \text{if } i \geq k \end{cases} \end{split}\]

Result: Image has coarse features from \(\mathbf{w}_1\) (pose, face shape) and fine features from \(\mathbf{w}_2\) (hair color, expression).

EffectsΒΆ

Low-Level Crossover (\(k\) small, e.g., \(k=4\)):

  • Only very coarse features from \(\mathbf{w}_1\) (pose, rough face shape)

  • Most features from \(\mathbf{w}_2\)

Mid-Level Crossover (\(k\) around 8-10):

  • Pose and face shape from \(\mathbf{w}_1\)

  • Hair style, facial features, expression from \(\mathbf{w}_2\)

High-Level Crossover (\(k\) large, e.g., \(k=14\)):

  • Almost everything from \(\mathbf{w}_1\)

  • Only color scheme and fine texture from \(\mathbf{w}_2\)

Training Benefit:

  • Prevents β€œmode collapse” where generator ignores specific layers

  • Encourages disentanglement: each layer must be independently useful

  • Typically applied with probability 0.9 (90% of training batches use mixing)

8. Progressive GrowingΒΆ

Concept (StyleGAN 1)ΒΆ

Idea: Start training at low resolution (4Γ—4), gradually add layers to increase resolution (8Γ—8, 16Γ—16, …, 1024Γ—1024).

Procedure:

  1. Phase 1: Train generator and discriminator at \(4 \times 4\)

  2. Phase 2: Add layers for \(8 \times 8\), fade them in smoothly

  3. Phase 3: Fully switch to \(8 \times 8\), add layers for \(16 \times 16\)

  4. Repeat until target resolution

Fade-In: Use \(\alpha \in [0, 1]\) to smoothly blend:

\[ \text{output} = \alpha \cdot \text{new\_path} + (1 - \alpha) \cdot \text{old\_path\_upsampled} \]

Initially \(\alpha = 0\) (only old path), gradually increase to \(\alpha = 1\) (only new path).

Benefits:

  • Faster training: Low-res training is faster

  • Stability: Easier optimization at low resolution first

  • Quality: Helps learn coarse structure before fine details

Note: StyleGAN2 removes progressive growing (found to cause artifacts).

9. StyleGAN2 ImprovementsΒΆ

Problem 1: Characteristic ArtifactsΒΆ

Observation: StyleGAN1 generates β€œblob” artifacts that appear/disappear frame-to-frame in videos.

Root Cause: AdaIN normalizes each feature map independently β†’ can β€œerase” information, creating artifacts.

Solution: Modify normalization scheme

StyleGAN2 Normalization:

Instead of: $\( \text{AdaIN}(\mathbf{h}) = \gamma \frac{\mathbf{h} - \mu}{\sigma} + \beta \)$

Use weight demodulation:

  1. Modulate conv weights directly: \( w'_{ijk} = s_i \cdot w_{ijk} \), where \(s_i\) is the style-based scale for filter \(i\).

  2. Demodulate (normalize) after convolution: $\( y_i = \frac{\sum_{jk} w'_{ijk} x_{jk}}{\sqrt{\sum_{jk} (w'_{ijk})^2 + \epsilon}} \)$

Benefit: Removes explicit feature normalization β†’ eliminates artifacts.

Problem 2: Progressive Growing ArtifactsΒΆ

Issue: Phase transitions in progressive growing cause position-dependent artifacts.

Solution: Remove progressive growing entirely

  • Train all layers from start (directly at target resolution)

  • Use techniques from StyleGAN2 to stabilize training:

    • Path length regularization

    • Lazy regularization

Problem 3: PPL (Perceptual Path Length)ΒΆ

Goal: Encourage smooth interpolation in \(\mathbf{w}\)-space.

Path Length Regularization:

Penalize large perceptual changes for small \(\mathbf{w}\) changes:

\[ \mathcal{L}_{\text{PPL}} = \mathbb{E}\left[\left(\left\|\mathbf{J}_\mathbf{w}^T \mathbf{y}\right\|_2 - a\right)^2\right] \]

where:

  • \(\mathbf{J}_\mathbf{w}\): Jacobian of generator w.r.t. \(\mathbf{w}\)

  • \(\mathbf{y}\): Random direction in image space

  • \(a\): Moving average of \(\|\mathbf{J}_\mathbf{w}^T \mathbf{y}\|_2\)

Effect: Smooths \(\mathbf{w}\)-space, improves interpolation quality.

Summary of StyleGAN2ΒΆ

Changes from StyleGAN1:

  1. Weight demodulation instead of AdaIN

  2. No progressive growing

  3. Path length regularization

  4. Lazy regularization (apply R1 less frequently)

Results:

  • Higher quality (FID improved from 4.4 to 2.8 on FFHQ)

  • No artifacts

  • Better latent space

10. StyleGAN3: Alias-Free GenerationΒΆ

Problem: Texture StickingΒΆ

Observation: When interpolating between faces, textures (stubble, hair) β€œstick” to screen coordinates instead of object.

Example: Rotate a face β†’ stubble doesn’t rotate with face, stays in same screen location.

Root Cause: Aliasing from upsampling operations.

Aliasing in Neural NetworksΒΆ

Convolution + ReLU + Downsampling/Upsampling violates Nyquist-Shannon sampling theorem:

  • Creates high-frequency signals (from ReLU nonlinearity)

  • Upsampling can alias these signals

  • Aliased signals leak to other frequencies

Result: Generator learns to exploit aliased frequencies β†’ texture sticking.

StyleGAN3 SolutionΒΆ

Alias-Free Generator:

  1. Filtered Nonlinearities: Apply low-pass filter after each ReLU to remove high frequencies

  2. Filtered Upsampling/Downsampling: Use proper anti-aliasing filters

  3. Fourier Features: Input layer uses continuous Fourier features instead of learned constant

Mathematical Formulation:

Standard upsampling (aliases): $\( \mathbf{h}' = \text{Upsample}(\mathbf{h}) \)$

Alias-free upsampling: $\( \mathbf{h}' = \text{LowPassFilter}(\text{Upsample}(\mathbf{h})) \)$

Effect: Textures now rotate/transform with object, not stuck to screen.

Configuration OptionsΒΆ

StyleGAN3 has two configs:

  • StyleGAN3-T (β€œtranslation”): Partial equivariance (translation + rotation)

  • StyleGAN3-R (β€œfull rotation”): Full equivariance (arbitrary rotations)

Trade-Off:

  • StyleGAN3-R: Perfect equivariance, but slightly lower quality

  • StyleGAN3-T: Better quality, less strict equivariance

11. Latent Space AnalysisΒΆ

\(\mathcal{Z}\) vs. \(\mathcal{W}\) SpaceΒΆ

\(\mathcal{Z}\)-Space:

  • Original latent space: \(\mathbf{z} \sim \mathcal{N}(0, I)\)

  • Entangled: Changing \(\mathbf{z}\) affects multiple attributes

\(\mathcal{W}\)-Space:

  • Intermediate latent: \(\mathbf{w} = f(\mathbf{z})\)

  • More disentangled

  • Fixed distribution (learned manifold)

\(\mathcal{W}+\) Space:

  • Extended space: Allow different \(\mathbf{w}_i\) per layer (18 separate vectors)

  • Even more expressive

  • Used for real image inversion

Latent Space EditingΒΆ

Linear Directions in \(\mathcal{W}\):

Many attributes correspond to linear directions:

\[ \mathbf{w}_{\text{male}} = \mathbf{w} + \alpha \cdot \mathbf{d}_{\text{gender}} \]

where \(\mathbf{d}_{\text{gender}}\) is learned direction.

Finding Directions:

  1. Supervised: Train classifier, use gradient: $\( \mathbf{d} = \nabla_{\mathbf{w}} P(y=\text{male} \mid \mathbf{w}) \)$

  2. Unsupervised: PCA on \(\{\mathbf{w}_i\}\) samples, find principal components

    • Component 1 might correspond to age

    • Component 2 might correspond to gender

    • Etc.

  3. Closed-Form (GANSpace): SVD on generator features: \( \mathbf{G} = \mathbf{U} \boldsymbol{\Sigma} \mathbf{V}^T \). Use the columns of \(\mathbf{V}\) as edit directions.

Semantic DirectionsΒΆ

Common Edits:

  • Age: \(\mathbf{w} + \alpha \cdot \mathbf{d}_{\text{age}}\)

  • Gender: \(\mathbf{w} + \alpha \cdot \mathbf{d}_{\text{gender}}\)

  • Smile: \(\mathbf{w} + \alpha \cdot \mathbf{d}_{\text{smile}}\)

  • Pose: \(\mathbf{w} + \alpha \cdot \mathbf{d}_{\text{yaw}}\)

Orthogonality: In good latent space, directions are approximately orthogonal: $\( \mathbf{d}_i^T \mathbf{d}_j \approx 0 \quad \text{for } i \neq j \)$

12. Applications and ExtensionsΒΆ

1. Image-to-Image Translation (pix2pixHD, SPADE)ΒΆ

Use StyleGAN as backbone for high-quality image translation.

2. 3D-Aware Generation (Ο€-GAN, EG3D)ΒΆ

Combine StyleGAN with 3D representations (NeRF) for view-consistent generation.

3. Text-to-Image (StyleCLIP)ΒΆ

Use CLIP to guide StyleGAN edits: $\( \mathbf{w}^* = \arg\min_{\mathbf{w}} \mathcal{L}_{\text{CLIP}}(G(\mathbf{w}), \text{text}) + \lambda \|\mathbf{w} - \mathbf{w}_0\|^2 \)$

4. Real Image Editing (GAN Inversion)ΒΆ

Goal: Edit real photos using StyleGAN’s latent space

Steps:

  1. Invert image to latent code: \(\mathbf{x} \rightarrow \mathbf{w}\) $\( \mathbf{w}^* = \arg\min_{\mathbf{w}} \|G(\mathbf{w}) - \mathbf{x}\|^2 + \text{regularization} \)$

  2. Edit latent: \(\mathbf{w}' = \mathbf{w}^* + \alpha \cdot \mathbf{d}\)

  3. Generate: \(\mathbf{x}' = G(\mathbf{w}')\)

Methods:

  • Optimization: Directly optimize \(\mathbf{w}\) (slow but accurate)

  • Encoder: Train encoder \(E: \mathbf{x} \rightarrow \mathbf{w}\) (fast, less accurate)

  • Hybrid: Encoder + optimization fine-tuning

5. Domain Adaptation (StyleGAN-NADA)ΒΆ

Fine-tune StyleGAN on new domain using text guidance (no new images needed!).

Example: β€œPixar character”, β€œZombie face”

13. Implementation DetailsΒΆ

TrainingΒΆ

Datasets:

  • FFHQ (Flickr Faces HQ): 70k faces at 1024Γ—1024

  • LSUN: Bedroom, church, etc.

  • Custom datasets (min ~1k images for fine-tuning)

Hyperparameters:

  • Batch size: 32 (distributed over multiple GPUs)

  • Learning rate: 0.002 (generator), 0.002 (discriminator)

  • Optimizer: Adam with \(\beta_1 = 0\), \(\beta_2 = 0.99\)

  • R1 regularization: \(\gamma = 10\)

  • Training time: ~1 week on 8Γ— V100 GPUs for FFHQ

Architecture DetailsΒΆ

Mapping Network:

  • 8 fully-connected layers

  • 512 units per layer

  • Leaky ReLU (slope 0.2)

  • Learning rate multiplier: 0.01 (slower learning for mapping net)

Synthesis Network:

  • Constant input: \(4 \times 4 \times 512\)

  • Bilinear upsampling (StyleGAN2) or filtered upsampling (StyleGAN3)

  • Conv kernels: \(3 \times 3\)

  • Noise: Per-layer learnable scaling

Discriminator:

  • Mirror of synthesis network (no mapping network)

  • Residual connections (StyleGAN2)

  • MiniBatch StdDev layer at end

InferenceΒΆ

Sampling:

  1. Sample \(\mathbf{z} \sim \mathcal{N}(0, I)\)

  2. \(\mathbf{w} = f(\mathbf{z})\)

  3. (Optional) Truncate: \(\mathbf{w}' = \bar{\mathbf{w}} + \psi(\mathbf{w} - \bar{\mathbf{w}})\)

  4. \(\mathbf{x} = G(\mathbf{w}')\)

Speed:

  • StyleGAN2 (1024Γ—1024): ~30ms per image on V100

  • StyleGAN3: ~50ms per image (more computation due to filtering)

Best PracticesΒΆ

Training from ScratchΒΆ

  1. Data: At least 1k high-quality images (10k+ for best results)

  2. Augmentation: Use ADA (Adaptive Discriminator Augmentation) for small datasets

  3. Resolution: Start with 256Γ—256 or 512Γ—512 (faster iteration)

  4. Monitor: FID, qualitative samples, discriminator/generator loss ratio

Fine-Tuning Pretrained ModelΒΆ

  1. Use StyleGAN2-ADA checkpoint (best for transfer)

  2. Lower learning rate: 0.0001 (10Γ— smaller)

  3. Freeze mapping network (optional, helps preserve quality)

  4. Small dataset: ADA augmentation essential

Latent Space EditingΒΆ

  1. Use \(\mathcal{W}+\) space for real image inversion (better reconstruction)

  2. Regularize: Encourage \(\mathbf{w}_i \approx \bar{\mathbf{w}}\) to stay in-distribution

  3. Multiple edits: Apply orthogonalized directions to avoid interference

Common Issues and SolutionsΒΆ

1. Mode CollapseΒΆ

Symptoms: Generator produces limited variety. Solutions:

  • Increase batch size

  • Add R1 regularization

  • Use mixing regularization

  • Check discriminator isn’t too strong

2. Training InstabilityΒΆ

Symptoms: Loss oscillates, quality degrades. Solutions:

  • Reduce learning rate

  • Increase R1 regularization weight

  • Use gradient clipping

  • Check for NaNs in gradients

3. Poor Inversion QualityΒΆ

Symptoms: \(G(E(\mathbf{x})) \not\approx \mathbf{x}\). Solutions:

  • Use \(\mathcal{W}+\) space (18 separate \(\mathbf{w}_i\))

  • Add perceptual loss (LPIPS)

  • Optimize with noise regularization

  • Hybrid encoder + optimization

Future DirectionsΒΆ

  1. 3D-Aware StyleGAN: Full 3D control (EG3D, StyleSDF)

  2. Video StyleGAN: Temporally consistent generation

  3. Controllable Editing: More semantic, disentangled controls

  4. Efficient Training: Reduce computational cost

  5. Generalization: Single model for multiple domains

SummaryΒΆ

StyleGAN revolutionized generative modeling by introducing:

  1. Mapping Network: \(\mathbf{z} \rightarrow \mathbf{w}\) for disentanglement

  2. AdaIN: Style control at multiple scales

  3. Noise Injection: Stochastic variation

  4. Mixing Regularization: Forced localization

StyleGAN2 improvements:

  • Weight demodulation (no artifacts)

  • No progressive growing

  • Path length regularization

StyleGAN3 advances:

  • Alias-free generation (texture sticking solved)

Impact:

  • SOTA image quality (FID ~2 on FFHQ)

  • Controllable generation (linear edits)

  • Wide applications (editing, 3D, text-to-image)

Key Insight: Separation of content (constant input) and style (\(\mathbf{w}\) via AdaIN) enables unprecedented control.

"""
Advanced StyleGAN - Production Implementation
Comprehensive implementation of StyleGAN architecture

Components Implemented:
1. Mapping Network (8-layer MLP)
2. Synthesis Network with AdaIN
3. Noise Injection
4. Style Mixing
5. Truncation Trick
6. Progressive Growing (StyleGAN1)
7. Weight Demodulation (StyleGAN2)

Features:
- Complete StyleGAN generator
- Style-based layer with AdaIN
- Latent space manipulation
- Image generation and interpolation
- Style mixing visualization

Author: Advanced Deep Learning Course
Date: 2024
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple
import numpy as np
import matplotlib.pyplot as plt
from math import log2


# ============================================================================
# 1. Mapping Network
# ============================================================================

class MappingNetwork(nn.Module):
    """
    Mapping network f: z -> w.

    A stack of equalized-LR fully-connected layers (default 8) with
    LeakyReLU activations that transforms the Gaussian latent z into the
    intermediate latent w used to drive the synthesis network.
    """

    def __init__(
        self,
        z_dim: int = 512,
        w_dim: int = 512,
        num_layers: int = 8,
        lr_multiplier: float = 0.01
    ):
        """
        Args:
            z_dim: Dimensionality of the input latent z.
            w_dim: Dimensionality of the intermediate latent w.
            num_layers: Number of fully-connected layers.
            lr_multiplier: Learning-rate multiplier applied inside each
                EqualizedLinear (the mapping net learns more slowly).
        """
        super().__init__()
        self.z_dim = z_dim
        self.w_dim = w_dim

        # First layer maps z_dim -> w_dim; every later layer is w_dim -> w_dim.
        dims = [z_dim] + [w_dim] * num_layers
        modules = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            modules.append(EqualizedLinear(d_in, d_out, lr_multiplier=lr_multiplier))
            modules.append(nn.LeakyReLU(0.2))

        self.mapping = nn.Sequential(*modules)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Map latents z [batch, z_dim] to intermediate latents w [batch, w_dim]."""
        return self.mapping(z)


# ============================================================================
# 2. Equalized Learning Rate Layers
# ============================================================================

class EqualizedLinear(nn.Module):
    """
    Fully-connected layer with equalized learning rate.

    Weights are stored as N(0, 1) samples and multiplied by a He-style
    constant at every forward pass instead of being scaled once at
    initialization (the ProGAN/StyleGAN trick).
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, lr_multiplier: float = 1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.lr_multiplier = lr_multiplier

        # Raw parameters kept at unit variance; scaling happens at runtime.
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

        # He-init constant combined with the learning-rate multiplier.
        self.scale = (1 / np.sqrt(in_features)) * lr_multiplier

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply y = x @ (scale * W)^T + lr_multiplier * b."""
        if self.bias is None:
            return F.linear(x, self.weight * self.scale, None)
        return F.linear(x, self.weight * self.scale, self.bias * self.lr_multiplier)


class EqualizedConv2d(nn.Module):
    """2-D convolution with runtime weight scaling (equalized learning rate)."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # Unit-variance raw kernel; He scaling is applied in forward().
        self.weight = nn.Parameter(
            torch.randn(out_channels, in_channels, kernel_size, kernel_size)
        )
        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None

        # fan_in = in_channels * kernel_size^2
        self.scale = 1 / np.sqrt(in_channels * kernel_size * kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve x with the runtime-scaled kernel."""
        scaled_kernel = self.weight * self.scale
        return F.conv2d(x, scaled_kernel, self.bias, self.stride, self.padding)


# ============================================================================
# 3. Adaptive Instance Normalization (AdaIN)
# ============================================================================

class AdaIN(nn.Module):
    """
    Adaptive Instance Normalization (StyleGAN v1).

    AdaIN(x, w) = y_s * (x - mu(x)) / sigma(x) + y_b, where (y_s, y_b)
    are per-channel affine projections of the style vector w and the
    statistics are computed per sample, per channel, over spatial dims.
    """

    def __init__(self, num_features: int, w_dim: int):
        """
        Args:
            num_features: Number of feature channels to modulate.
            w_dim: Dimensionality of the style vector w.
        """
        super().__init__()
        # Learned affine transforms producing per-channel style scale/bias.
        self.scale_transform = EqualizedLinear(w_dim, num_features)
        self.bias_transform = EqualizedLinear(w_dim, num_features)

    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Features [batch, channels, height, width]
            w: Style vector [batch, w_dim]

        Returns:
            Style-modulated features [batch, channels, height, width]
        """
        # Per-channel style parameters, broadcast over spatial dims.
        scale = self.scale_transform(w).unsqueeze(-1).unsqueeze(-1)  # [B, C, 1, 1]
        bias = self.bias_transform(w).unsqueeze(-1).unsqueeze(-1)    # [B, C, 1, 1]

        # Instance statistics. Use the biased (population) variance to match
        # nn.InstanceNorm2d and the AdaIN definition of sigma(x); the
        # previous Tensor.std call defaulted to the unbiased estimator
        # (divide by N-1), which deviates from the formula and yields NaN
        # for 1x1 feature maps. Eps goes under the sqrt for stability.
        mean = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True, unbiased=False)
        normalized = (x - mean) * torch.rsqrt(var + 1e-8)

        return scale * normalized + bias


# ============================================================================
# 4. Noise Injection
# ============================================================================

class NoiseInjection(nn.Module):
    """
    Add per-pixel Gaussian noise scaled by a learned per-channel weight.

    The scale starts at zero, so noise initially has no effect and its
    contribution is learned during training.
    """

    def __init__(self, num_channels: int):
        super().__init__()
        # One learnable scale per channel, broadcast over batch and space.
        self.weight = nn.Parameter(torch.zeros(1, num_channels, 1, 1))

    def forward(self, x: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: Features [batch, channels, height, width]
            noise: Optional single-channel noise [batch, 1, height, width];
                drawn fresh from N(0, 1) when not supplied.

        Returns:
            x plus channel-scaled noise, same shape as x.
        """
        if noise is None:
            b, _, h, w = x.shape
            noise = torch.randn(b, 1, h, w, device=x.device)

        return self.weight * noise + x


# ============================================================================
# 5. Style-Based Generator Block
# ============================================================================

class StyleBlock(nn.Module):
    """
    One style-modulated synthesis step.

    Pipeline: (optional 2x bilinear upsample) -> 3x3 conv -> noise
    injection -> AdaIN with style w -> LeakyReLU.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        w_dim: int,
        upsample: bool = False
    ):
        super().__init__()
        self.upsample = upsample

        # Conv / noise / AdaIN / activation sub-modules, applied in order.
        self.conv = EqualizedConv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.noise = NoiseInjection(out_channels)
        self.adain = AdaIN(out_channels, w_dim)
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x: torch.Tensor, w: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: Input features [batch, in_channels, H, W].
            w: Style vector [batch, w_dim].
            noise: Optional pre-generated noise map.

        Returns:
            Activated, style-modulated features (spatial dims doubled
            when upsample=True).
        """
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

        out = self.conv(x)
        out = self.noise(out, noise)
        out = self.adain(out, w)
        return self.activation(out)


# ============================================================================
# 6. StyleGAN Generator
# ============================================================================

class StyleGANGenerator(nn.Module):
    """
    StyleGAN Generator.

    Maps z -> w via the mapping network, then renders an image with the
    synthesis network: a learned 4x4 constant, two initial StyleBlocks,
    and one (upsample + refine) StyleBlock pair per resolution doubling,
    each block modulated by its own style vector via AdaIN.

    Note: img_resolution must be a power of two >= 8.
    """

    def __init__(
        self,
        z_dim: int = 512,
        w_dim: int = 512,
        img_resolution: int = 256,
        img_channels: int = 3,
        base_channels: int = 512
    ):
        """
        Args:
            z_dim: Dimension of the input latent z.
            w_dim: Dimension of the intermediate latent w.
            img_resolution: Output image resolution (power of two).
            img_channels: Number of output image channels (e.g. 3 for RGB).
            base_channels: Channels of the learned 4x4 constant input.
        """
        super().__init__()
        self.z_dim = z_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        # Mapping network: z -> w
        self.mapping = MappingNetwork(z_dim, w_dim)

        # Constant input (learned); replaces the z-fed input of classic GANs.
        self.const_input = nn.Parameter(torch.randn(1, base_channels, 4, 4))

        # The initial 4x4 stage already covers log2(4) = 2, so the
        # progressive stage needs log2(res) - 2 doublings, 2 blocks each.
        # BUG FIX: the loop below previously started at i=2, adding one
        # extra doubling and emitting images at 2x the requested resolution
        # (e.g. 128x128 for img_resolution=64).
        self.log_resolution = int(log2(img_resolution))
        self.num_layers = (self.log_resolution - 2) * 2  # 2 conv per resolution

        # Build synthesis network
        self.synthesis_blocks = nn.ModuleList()
        self.to_rgb_layers = nn.ModuleList()

        in_channels = base_channels

        # Initial blocks operate at 4x4 (no upsample)
        self.initial_block1 = StyleBlock(base_channels, base_channels, w_dim, upsample=False)
        self.initial_block2 = StyleBlock(base_channels, base_channels, w_dim, upsample=False)

        # Progressive blocks: i = 3..log_resolution, i.e. 8x8 up to full size.
        for i in range(3, self.log_resolution + 1):
            # Halve the channel count at each doubling, floored at 16.
            out_channels = base_channels // (2 ** (i - 3)) if i > 3 else base_channels
            out_channels = max(out_channels, 16)  # Min 16 channels

            # Two conv blocks per resolution: upsample, then refine.
            block1 = StyleBlock(in_channels, out_channels, w_dim, upsample=True)
            block2 = StyleBlock(out_channels, out_channels, w_dim, upsample=False)

            self.synthesis_blocks.append(nn.ModuleList([block1, block2]))

            # To RGB (one per resolution; only the last one is used here,
            # kept per-resolution for progressive-growing variants).
            to_rgb = EqualizedConv2d(out_channels, img_channels, kernel_size=1)
            self.to_rgb_layers.append(to_rgb)

            in_channels = out_channels

    def forward(
        self,
        z: torch.Tensor,
        truncation: float = 1.0,
        truncation_mean: Optional[torch.Tensor] = None,
        styles: Optional[List[torch.Tensor]] = None,
        mix_index: Optional[int] = None,
        return_latents: bool = False
    ) -> torch.Tensor:
        """
        Args:
            z: Random latent [batch, z_dim]
            truncation: Truncation psi (1.0 = no truncation)
            truncation_mean: Mean w for truncation (required when truncation < 1)
            styles: Optional pair [w_a, w_b] of style vectors for mixing;
                must be given together with mix_index.
            mix_index: Layer index at which to switch from styles[0] to styles[1]
            return_latents: Whether to also return the w latents

        Returns:
            Generated image [batch, channels, resolution, resolution]; when
            return_latents is True, the tuple (image, w) instead.
        """
        batch_size = z.size(0)

        # Map z to w
        w = self.mapping(z)

        # Truncation trick: pull w towards the mean, trading diversity for quality.
        if truncation < 1.0 and truncation_mean is not None:
            w = truncation_mean + truncation * (w - truncation_mean)

        # Build the per-layer style list (one entry per StyleBlock).
        if styles is None:
            # Use the same w for every layer
            styles = [w] * (self.num_layers + 2)  # +2 for initial blocks
        else:
            if mix_index is not None:
                # Coarse layers (< mix_index) take styles[0], the rest styles[1].
                styles = [styles[0] if i < mix_index else styles[1] for i in range(self.num_layers + 2)]

        # Start from the learned constant, broadcast over the batch.
        x = self.const_input.repeat(batch_size, 1, 1, 1)

        # Initial 4x4 blocks
        x = self.initial_block1(x, styles[0])
        x = self.initial_block2(x, styles[1])

        # Progressive synthesis: each block pair doubles the resolution.
        style_idx = 2
        for block1, block2 in self.synthesis_blocks:
            x = block1(x, styles[style_idx])
            style_idx += 1
            x = block2(x, styles[style_idx])
            style_idx += 1

        # Final to RGB (intermediate to_rgb layers exist for progressive
        # growing but are unused in this simplified forward pass).
        rgb = self.to_rgb_layers[-1](x)

        if return_latents:
            return rgb, w
        return rgb

    def generate(
        self,
        batch_size: int = 1,
        truncation: float = 0.7,
        device: str = 'cuda'
    ) -> torch.Tensor:
        """Sample random images with the truncation trick.

        The mean w is estimated from 1000 fresh z samples on each call.
        """
        z = torch.randn(batch_size, self.z_dim, device=device)

        # Estimate the center of the w distribution for truncation.
        with torch.no_grad():
            z_mean = torch.randn(1000, self.z_dim, device=device)
            w_mean = self.mapping(z_mean).mean(0, keepdim=True)

        return self.forward(z, truncation=truncation, truncation_mean=w_mean)

    def style_mixing(
        self,
        z1: torch.Tensor,
        z2: torch.Tensor,
        mix_index: int
    ) -> torch.Tensor:
        """
        Mix styles from two latents.

        Args:
            z1: First latent (supplies styles for layers < mix_index)
            z2: Second latent (supplies styles for layers >= mix_index)
            mix_index: Layer at which to switch from z1's to z2's style

        Returns:
            Mixed image
        """
        w1 = self.mapping(z1)
        w2 = self.mapping(z2)

        return self.forward(z1, styles=[w1, w2], mix_index=mix_index)


# ============================================================================
# 7. StyleGAN2: Weight Demodulation
# ============================================================================

class ModulatedConv2d(nn.Module):
    """
    Modulated convolution with optional weight demodulation (StyleGAN2).

    Instead of normalizing activations (AdaIN), the style scales the
    convolution weights per sample, and demodulation rescales each output
    filter back to unit expected norm.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        w_dim: int,
        demodulate: bool = True
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate

        # Shared (sample-independent) raw kernel with equalized-LR scaling.
        self.weight = nn.Parameter(
            torch.randn(out_channels, in_channels, kernel_size, kernel_size)
        )
        self.scale = 1 / np.sqrt(in_channels * kernel_size * kernel_size)

        # Projects w to one multiplicative gain per input channel.
        self.modulation = EqualizedLinear(w_dim, in_channels)

    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input [batch, in_channels, height, width]
            w: Style [batch, w_dim]

        Returns:
            Output [batch, out_channels, height, width]
        """
        b, c_in, h, w_sp = x.shape
        k = self.kernel_size

        # Per-sample modulation of the shared kernel: [B, C_out, C_in, K, K].
        gains = self.modulation(w).view(b, 1, c_in, 1, 1)
        kernel = self.weight.unsqueeze(0) * self.scale
        kernel = kernel * gains

        # Demodulation: normalize each output filter's expected magnitude.
        if self.demodulate:
            inv_norm = torch.rsqrt((kernel ** 2).sum(dim=[2, 3, 4], keepdim=True) + 1e-8)
            kernel = kernel * inv_norm

        # Express the batch of per-sample convolutions as one grouped conv:
        # fold the batch into the channel dimension with groups=batch.
        kernel = kernel.reshape(b * self.out_channels, c_in, k, k)
        folded = x.reshape(1, b * c_in, h, w_sp)
        out = F.conv2d(folded, kernel, padding=k // 2, groups=b)

        # Unfold back to [batch, out_channels, height, width].
        return out.reshape(b, self.out_channels, h, w_sp)


# ============================================================================
# 8. Visualization and Utils
# ============================================================================

def visualize_generation(generator: StyleGANGenerator, num_samples: int = 16, device: str = 'cuda'):
    """Sample images from the generator and show them in a square grid."""
    generator.eval()

    with torch.no_grad():
        batch = generator.generate(num_samples, truncation=0.7, device=device)

    # Tensor [N, C, H, W] in [-1, 1] -> numpy [N, H, W, C] in [0, 1].
    grid = batch.cpu().permute(0, 2, 3, 1).numpy()
    grid = np.clip((grid + 1) / 2, 0, 1)

    side = int(np.sqrt(num_samples))
    fig, axes = plt.subplots(side, side, figsize=(12, 12))

    for idx, ax in enumerate(axes.flat):
        if idx < num_samples:
            ax.imshow(grid[idx])
        ax.axis('off')

    plt.tight_layout()
    plt.show()


def visualize_style_mixing(
    generator: StyleGANGenerator,
    num_sources: int = 4,
    num_destinations: int = 4,
    device: str = 'cuda'
):
    """
    Plot a style-mixing grid.

    Cell (i, j) takes coarse styles (pose, face shape) from source i and
    fine styles (color, texture) from destination j.
    """
    generator.eval()

    # Independent latents for the coarse and fine style sources.
    z_coarse = torch.randn(num_sources, generator.z_dim, device=device)
    z_fine = torch.randn(num_destinations, generator.z_dim, device=device)

    # Map both sets to w.
    with torch.no_grad():
        w_coarse = generator.mapping(z_coarse)
        w_fine = generator.mapping(z_fine)

    # Build the [sources, destinations] image matrix, switching styles at
    # layer 8 (roughly the coarse/fine boundary of an 18-layer net).
    rows = []
    for i in range(num_sources):
        cells = [
            generator(z_coarse[i:i+1], styles=[w_coarse[i:i+1], w_fine[j:j+1]], mix_index=8)[0]
            for j in range(num_destinations)
        ]
        rows.append(torch.stack(cells))

    images = torch.stack(rows)  # [sources, destinations, C, H, W]

    # To numpy, [-1, 1] -> [0, 1].
    images = images.cpu().permute(0, 1, 3, 4, 2).numpy()
    images = np.clip((images + 1) / 2, 0, 1)

    fig, axes = plt.subplots(num_sources, num_destinations, figsize=(15, 15))

    for i in range(num_sources):
        for j in range(num_destinations):
            axes[i, j].imshow(images[i, j])
            axes[i, j].axis('off')

            if i == 0:
                axes[i, j].set_title(f'Fine {j}', fontsize=10)
            if j == 0:
                axes[i, j].set_ylabel(f'Coarse {i}', fontsize=10)

    plt.suptitle('Style Mixing: Coarse (rows) Γ— Fine (cols)', fontsize=14)
    plt.tight_layout()
    plt.show()


def interpolate_latents(
    generator: StyleGANGenerator,
    z1: torch.Tensor,
    z2: torch.Tensor,
    num_steps: int = 10,
    device: str = 'cuda'
) -> torch.Tensor:
    """
    Generate images along the straight line between two z latents.

    Args:
        generator: StyleGAN generator
        z1: Start latent
        z2: End latent
        num_steps: Number of interpolation steps (endpoints included)

    Returns:
        Stacked images [num_steps, C, H, W]
    """
    generator.eval()

    frames = []
    with torch.no_grad():
        # Convex combinations from z1 (t=0) to z2 (t=1).
        for t in torch.linspace(0, 1, num_steps, device=device):
            blended = (1 - t) * z1 + t * z2
            frames.append(generator(blended)[0])

    return torch.stack(frames)


# ============================================================================
# 9. Demo
# ============================================================================

def demo_stylegan():
    """Demonstrate StyleGAN capabilities.

    Walks through random generation, style mixing, latent-space
    interpolation, the truncation trick, and an architecture summary,
    printing tensor shapes and statistics along the way.
    """
    print("=" * 80)
    print("StyleGAN - Advanced Generative Model Demo")
    print("=" * 80)
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"\nDevice: {device}")
    
    # Create generator (small version for demo)
    print("\nCreating StyleGAN generator (64x64 for demo)...")
    generator = StyleGANGenerator(
        z_dim=512,
        w_dim=512,
        img_resolution=64,
        img_channels=3
    ).to(device)
    
    # Count parameters
    total_params = sum(p.numel() for p in generator.parameters())
    print(f"Total parameters: {total_params:,}")
    
    # 1. Random generation
    print("\n" + "=" * 60)
    print("1. Random Image Generation")
    print("=" * 60)
    
    z = torch.randn(4, 512, device=device)
    
    with torch.no_grad():
        images = generator(z)
    
    print(f"Generated images: {images.shape}")
    print(f"Min: {images.min().item():.3f}, Max: {images.max().item():.3f}")
    
    # 2. Style mixing
    print("\n" + "=" * 60)
    print("2. Style Mixing")
    print("=" * 60)
    
    z1 = torch.randn(1, 512, device=device)
    z2 = torch.randn(1, 512, device=device)
    
    print("Mixing at different layer indices:")
    for mix_idx in [2, 6, 10]:
        with torch.no_grad():
            mixed = generator.style_mixing(z1, z2, mix_idx)
        print(f"  Mix at layer {mix_idx}: {mixed.shape}")
    
    # 3. Latent interpolation
    print("\n" + "=" * 60)
    print("3. Latent Space Interpolation")
    print("=" * 60)
    
    interp_images = interpolate_latents(generator, z1, z2, num_steps=8, device=device)
    print(f"Interpolated images: {interp_images.shape}")
    
    # 4. Truncation trick
    print("\n" + "=" * 60)
    print("4. Truncation Trick")
    print("=" * 60)
    
    # The mean w is loop-invariant: estimate it once here instead of
    # re-sampling and re-mapping 100 latents for every psi value.
    with torch.no_grad():
        z_samples = torch.randn(100, 512, device=device)
        w_mean = generator.mapping(z_samples).mean(0, keepdim=True)
    
    for psi in [0.5, 0.7, 1.0]:
        with torch.no_grad():
            img = generator(z1, truncation=psi, truncation_mean=w_mean)
        print(f"  Truncation ψ={psi}: quality-diversity trade-off")
    
    # 5. Architecture analysis
    print("\n" + "=" * 60)
    print("5. Architecture Analysis")
    print("=" * 60)
    
    print(f"\nMapping Network:")
    print(f"  Input: z ∈ R^{generator.z_dim}")
    print(f"  Output: w ∈ R^{generator.w_dim}")
    print(f"  Layers: 8 Γ— FC({generator.w_dim}) with LeakyReLU")
    
    print(f"\nSynthesis Network:")
    print(f"  Starting resolution: 4Γ—4")
    print(f"  Final resolution: {generator.img_resolution}Γ—{generator.img_resolution}")
    print(f"  Number of style blocks: {len(generator.synthesis_blocks) * 2 + 2}")
    print(f"  AdaIN layers: {generator.num_layers + 2}")
    
    print("\n" + "=" * 80)
    print("Demo complete!")
    print("=" * 80)

def demo_comparison():
    """Print a side-by-side architectural comparison of a standard GAN
    and StyleGAN."""
    header = "=" * 80
    # Emit the comparison line by line (one print per output line).
    for line in (
        "\n" + header,
        "Architectural Comparison: Standard GAN vs StyleGAN",
        header,
        "\nStandard GAN:",
        "  Input: z β†’ [Concat with spatial coords]",
        "  Process: Conv layers progressively increase resolution",
        "  Control: Global (entire z affects all features)",
        "  Problem: Entangled latent space",
        "\nStyleGAN:",
        "  Input: Learned constant 4Γ—4 tensor",
        "  Latent: z β†’ MappingNet β†’ w (disentangled)",
        "  Control: Local (each AdaIN layer independently controlled)",
        "  Noise: Per-layer stochastic variation",
        "  Benefits:",
        "    - Disentangled representations",
        "    - Scale-specific control (pose vs texture)",
        "    - Smooth interpolation",
        "    - Style mixing",
    ):
        print(line)


if __name__ == "__main__":
    # Script entry point: run both demos, then print a recap of the key ideas.
    print("\nAdvanced StyleGAN Implementation\n")
    
    # Run demos
    demo_stylegan()
    demo_comparison()
    
    # Closing summary of the concepts demonstrated above.
    print("\n" + "=" * 80)
    print("Key Takeaways:")
    print("=" * 80)
    print("1. Mapping Network (z β†’ w) enables disentanglement")
    print("2. AdaIN provides scale-specific style control")
    print("3. Noise injection adds stochastic variation")
    print("4. Style mixing forces feature localization")
    print("5. Truncation trick balances quality vs diversity")
    print("6. StyleGAN2's weight demodulation fixes artifacts")
    print("=" * 80)