import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

# Prefer the GPU when available; every module/tensor below is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

1. Progressive Growing

Strategy

Start at low resolution (4×4), progressively add layers:

\[4 \times 4 \rightarrow 8 \times 8 \rightarrow 16 \times 16 \rightarrow 32 \times 32\]

Fade-in

\[\text{output} = \alpha \cdot \text{new\_layer} + (1-\alpha) \cdot \text{old\_upsampled}\]

where \(\alpha\) increases from 0 to 1.

📚 Reference Materials:

class PixelNorm(nn.Module):
    """Pixel-wise feature vector normalization (ProGAN).

    Divides each spatial position's channel vector by the square root of
    its mean squared activation; 1e-8 guards against division by zero.
    """

    def forward(self, x):
        norm = torch.sqrt(torch.mean(x * x, dim=1, keepdim=True) + 1e-8)
        return x / norm

class EqualizedConv2d(nn.Module):
    """Conv2d with equalized learning rate (ProGAN).

    Weights are drawn from N(0, 1) and a He-style constant
    sqrt(2 / fan_in) is applied at runtime.  Scaling the input before the
    convolution is mathematically equivalent to scaling the weight while
    leaving the bias unscaled.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        fan_in = in_channels * kernel_size ** 2
        self.scale = np.sqrt(2 / fan_in)
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        scaled = x * self.scale
        return self.conv(scaled)

class GeneratorBlock(nn.Module):
    """Generator stage: 2x nearest-neighbor upsample, then two 3x3
    equalized convolutions, each followed by leaky ReLU and pixel norm."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = EqualizedConv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = EqualizedConv2d(out_channels, out_channels, 3, padding=1)
        self.pixel_norm = PixelNorm()

    def forward(self, x):
        out = self.upsample(x)
        for conv in (self.conv1, self.conv2):
            out = self.pixel_norm(F.leaky_relu(conv(out), 0.2))
        return out

class DiscriminatorBlock(nn.Module):
    """Discriminator stage: two 3x3 equalized convolutions with leaky
    ReLU, then 2x average-pool downsampling."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = EqualizedConv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = EqualizedConv2d(out_channels, out_channels, 3, padding=1)
        self.downsample = nn.AvgPool2d(2)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = F.leaky_relu(conv(x), 0.2)
        return self.downsample(x)

# Sanity message: the building-block classes were defined without errors.
print("Blocks defined")

Progressive Generator

The progressive generator starts producing images at the lowest resolution (e.g., \(4 \times 4\)) and gradually adds upsampling layers to reach the target resolution. When a new resolution layer is introduced, it is faded in using a blending parameter \(\alpha\) that transitions from 0 (only the old lower-resolution output) to 1 (only the new higher-resolution output) over several thousand training steps. This smooth transition prevents the sudden shock of new layers and allows the existing weights to adapt gradually. The result is a generator that first learns coarse structure, then progressively refines fine detail — mirroring how artists work from rough sketch to finished piece.

class ProgressiveGenerator(nn.Module):
    """Progressive GAN generator.

    Synthesizes single-channel images starting at 4x4 and doubling the
    resolution with each additional block, up to ``max_resolution``.
    During a resolution transition the newest block's RGB output is
    blended with an upsampled copy of the previous resolution's RGB
    output, weighted by the fade-in parameter ``alpha``.

    Fix vs. the original: the trailing ``return torch.tanh(self.to_rgb[depth](x))``
    was unreachable dead code (for depth >= 1 the loop always returned from
    its fade-in branch); the forward pass is restructured without it.
    """

    def __init__(self, latent_dim=512, max_resolution=32):
        super().__init__()
        self.latent_dim = latent_dim
        self.max_resolution = max_resolution

        # Initial 4x4 stage: project the latent vector to a 4x4 feature map.
        self.initial = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0),
            nn.LeakyReLU(0.2),
            PixelNorm(),
            EqualizedConv2d(512, 512, 3, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm()
        )

        # One upsampling block per doubling of resolution.
        self.blocks = nn.ModuleList([
            GeneratorBlock(512, 512),  # 4 -> 8
            GeneratorBlock(512, 256),  # 8 -> 16
            GeneratorBlock(256, 128),  # 16 -> 32
        ])

        # 1x1 convs mapping features to a one-channel image, per resolution.
        self.to_rgb = nn.ModuleList([
            EqualizedConv2d(512, 1, 1),  # 4x4
            EqualizedConv2d(512, 1, 1),  # 8x8
            EqualizedConv2d(256, 1, 1),  # 16x16
            EqualizedConv2d(128, 1, 1),  # 32x32
        ])

    def forward(self, z, depth, alpha):
        """Generate images at resolution ``4 * 2 ** depth``.

        Args:
            z: Latent code of shape (batch, latent_dim).
            depth: Current resolution level (0=4x4, 1=8x8, ...).
            alpha: Fade-in parameter in [0, 1]; 1 means the newest block
                is fully active.

        Returns:
            Tensor of shape (batch, 1, res, res) with values in [-1, 1].
        """
        x = self.initial(z.view(-1, self.latent_dim, 1, 1))

        if depth == 0:
            return torch.tanh(self.to_rgb[0](x))

        # Fully faded-in blocks, up to (but excluding) the newest one.
        for i in range(depth - 1):
            x = self.blocks[i](x)

        # Fade-in: blend the newest block's RGB output with an upsampled
        # copy of the previous resolution's RGB output.
        x_old = F.interpolate(x, scale_factor=2, mode='nearest')
        x_old = self.to_rgb[depth - 1](x_old)

        x_new = self.blocks[depth - 1](x)
        x_new = self.to_rgb[depth](x_new)

        return torch.tanh(alpha * x_new + (1 - alpha) * x_old)

# Instantiate the generator and move its parameters to the chosen device.
gen = ProgressiveGenerator().to(device)
print("Generator created")

Progressive Discriminator

The discriminator mirrors the generator's progressive growth: it starts operating on low-resolution images and gains new downsampling layers as the resolution increases. The same fade-in mechanism is applied, ensuring the discriminator and generator grow in lockstep. At each resolution, the discriminator's task is calibrated to the level of detail currently being generated, which provides more useful gradient signals than a full-resolution discriminator judging blurry low-resolution outputs. This matched progression is what makes progressive training stable for generating high-resolution images (up to \(1024 \times 1024\) in the original paper).

class ProgressiveDiscriminator(nn.Module):
    """Progressive GAN discriminator (mirror of the generator).

    Scores images at resolution ``4 * 2 ** depth``.  During a resolution
    transition the newest from_rgb/block path is blended with a pooled
    image fed through the previous from_rgb, weighted by ``alpha``.

    BUG FIX vs. the original: the old code blended
    ``from_rgb[depth](x)`` (resolution R, new channel count) directly with
    ``from_rgb[depth - 1](avg_pool2d(x, 2))`` (resolution R/2, old channel
    count), which raises a shape-mismatch error for every depth > 0
    (e.g. depth=1 blends (B,512,8,8) with (B,512,4,4)).  Following the
    ProGAN architecture, the new path must first pass through its
    downsampling block so both branches have identical shape.
    """

    def __init__(self, max_resolution=32):
        super().__init__()

        # 1x1 convs mapping a one-channel image to features, per resolution.
        self.from_rgb = nn.ModuleList([
            EqualizedConv2d(1, 512, 1),  # 4x4
            EqualizedConv2d(1, 512, 1),  # 8x8
            EqualizedConv2d(1, 256, 1),  # 16x16
            EqualizedConv2d(1, 128, 1),  # 32x32
        ])

        # Progressive downsampling blocks, highest resolution first.
        self.blocks = nn.ModuleList([
            DiscriminatorBlock(128, 256),  # 32 -> 16
            DiscriminatorBlock(256, 512),  # 16 -> 8
            DiscriminatorBlock(512, 512),  # 8 -> 4
        ])

        # Final 4x4 stage producing a single score per image.
        self.final = nn.Sequential(
            EqualizedConv2d(512, 512, 3, padding=1),
            nn.LeakyReLU(0.2),
            EqualizedConv2d(512, 512, 4),
            nn.LeakyReLU(0.2),
            EqualizedConv2d(512, 1, 1)
        )

    def forward(self, x, depth, alpha):
        """Score a batch of images.

        Args:
            x: Images of shape (batch, 1, res, res), res = 4 * 2 ** depth.
            depth: Current resolution level (0=4x4, 1=8x8, ...).
            alpha: Fade-in parameter in [0, 1].

        Returns:
            Tensor of shape (batch,) with one score per image.
        """
        if depth == 0:
            x = F.leaky_relu(self.from_rgb[0](x), 0.2)
            return self.final(x).view(-1)

        # New path: full-resolution features, downsampled by the newest
        # block.  blocks[len(blocks) - depth] handles resolution 4 * 2**depth.
        x_new = F.leaky_relu(self.from_rgb[depth](x), 0.2)
        x_new = self.blocks[len(self.blocks) - depth](x_new)

        # Old path: pooled image through the previous from_rgb; now matches
        # x_new in both channel count and spatial size.
        x_old = F.avg_pool2d(x, 2)
        x_old = F.leaky_relu(self.from_rgb[depth - 1](x_old), 0.2)

        x = alpha * x_new + (1 - alpha) * x_old

        # Remaining, fully faded-in blocks down to 4x4.
        for i in range(depth - 2, -1, -1):
            x = self.blocks[2 - i](x)

        return self.final(x).view(-1)

# Instantiate the discriminator and move its parameters to the chosen device.
disc = ProgressiveDiscriminator().to(device)
print("Discriminator created")

Training

Progressive training proceeds through multiple phases, one per resolution level. At each phase, the model trains for a fixed number of iterations at the current resolution before transitioning to the next. The WGAN-GP loss (Wasserstein loss with gradient penalty) is commonly used because it provides smoother gradients than the standard GAN loss, which is especially important during the resolution transitions. Training time scales with resolution: higher resolutions require more parameters, larger batch sizes, and more iterations. Monitoring FID (Fréchet Inception Distance) at each resolution level tracks generation quality throughout the progressive training process.

# Load MNIST
# Pixels are normalized from [0, 1] to [-1, 1] so real images match the
# generator's tanh output range.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
mnist = datasets.MNIST('./data', train=True, download=True, transform=transform)
loader = torch.utils.data.DataLoader(mnist, batch_size=64, shuffle=True)

def train_progressive(gen, disc, depth=0, n_epochs=3):
    """Run one training phase at a fixed resolution level.

    Uses Wasserstein-style critic/generator losses.  The fade-in weight
    alpha ramps linearly from 1/n_epochs to 1 over the epochs of this
    phase.  Relies on the module-level ``loader`` and ``device``.

    Args:
        gen: Progressive generator (called as gen(z, depth, alpha)).
        disc: Progressive discriminator (called as disc(x, depth, alpha)).
        depth: Resolution level to train at (0=4x4, 1=8x8, ...).
        n_epochs: Number of passes over the data loader.
    """
    opt_g = torch.optim.Adam(gen.parameters(), lr=1e-3, betas=(0.0, 0.99))
    opt_d = torch.optim.Adam(disc.parameters(), lr=1e-3, betas=(0.0, 0.99))

    resolution = 4 * (2 ** depth)

    for epoch in range(n_epochs):
        # Linear fade-in schedule across this phase.
        alpha = min(1.0, (epoch + 1) / n_epochs)

        for real_imgs, _ in loader:
            batch_size = real_imgs.size(0)
            # Resize reals to the resolution currently being trained.
            real_imgs = F.interpolate(real_imgs, size=resolution).to(device)

            # --- Discriminator step: raise real scores, lower fake scores.
            z = torch.randn(batch_size, 512).to(device)
            fake_imgs = gen(z, depth, alpha)

            real_score = disc(real_imgs, depth, alpha)
            fake_score = disc(fake_imgs.detach(), depth, alpha)
            loss_D = torch.mean(fake_score) - torch.mean(real_score)

            opt_d.zero_grad()
            loss_D.backward()
            opt_d.step()

            # --- Generator step: raise the scores of freshly sampled fakes.
            z = torch.randn(batch_size, 512).to(device)
            loss_G = -torch.mean(disc(gen(z, depth, alpha), depth, alpha))

            opt_g.zero_grad()
            loss_G.backward()
            opt_g.step()

        print(f"Depth {depth}, Epoch {epoch+1}, Ξ±={alpha:.2f}, D={loss_D.item():.4f}, G={loss_G.item():.4f}")

# Train progressively
# Phases run at depths 0..2, i.e. resolutions 4x4, 8x8 and 16x16.
# Note: the 32x32 level defined by the models is not trained by this loop.
for depth in range(3):
    print(f"\nTraining at {4 * (2 ** depth)}Γ—{4 * (2 ** depth)}")
    train_progressive(gen, disc, depth, n_epochs=3)

Generate Samples

After progressive training completes at the final resolution, we generate a batch of samples to assess the model's output quality and diversity. Comparing samples generated at intermediate resolutions (from earlier training phases) with final-resolution samples shows how the model's capabilities evolve during progressive training — from capturing coarse shapes to rendering fine textures and details.

# Render the same 8 latent codes at every resolution level trained above,
# one grid row per level.
gen.eval()

fig, axes = plt.subplots(3, 8, figsize=(16, 6))

with torch.no_grad():
    z = torch.randn(8, 512).to(device)

    for depth in range(3):
        imgs = gen(z, depth, alpha=1.0)
        for col in range(8):
            ax = axes[depth, col]
            ax.imshow(imgs[col, 0].cpu(), cmap='gray')
            ax.axis('off')
            if col == 0:
                # Label the first column of each row with its resolution.
                res = 4 * (2 ** depth)
                ax.set_ylabel(f'{res}Γ—{res}', fontsize=11)

plt.suptitle('Progressive Generation', fontsize=13)
plt.tight_layout()
plt.show()

Summary

Progressive GAN:

  1. Gradual resolution increase

  2. Layer fade-in smooths training

  3. Stable high-resolution synthesis

  4. Faster convergence

Key Techniques:

  • Pixel normalization

  • Equalized learning rate

  • Minibatch standard deviation

Applications:

  • High-resolution face generation (1024×1024)

  • Texture synthesis

  • Data augmentation

Extensions:

  • StyleGAN: Style-based generation

  • ProGAN: Original implementation

  • BigGAN: Large-scale training