import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
# Select GPU when available; all models and tensors below are moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
1. Progressive Growing
Strategy
Start at low resolution (4×4), progressively add layers:
Fade-in
\[ x_{\text{out}} = \alpha \cdot x_{\text{new}} + (1 - \alpha) \cdot x_{\text{old}} \]
where \(\alpha\) increases from 0 to 1.
Reference Materials:
gan.pdf - GAN
class PixelNorm(nn.Module):
    """Normalize each spatial position to unit average feature magnitude.

    Divides every feature vector (taken across the channel dimension) by its
    RMS value; a small epsilon keeps the division numerically stable.
    """

    def forward(self, x):
        rms = torch.sqrt((x * x).mean(dim=1, keepdim=True) + 1e-8)
        return x / rms
class EqualizedConv2d(nn.Module):
    """Conv2d with runtime He-style weight scaling (equalized learning rate).

    Weights are initialized from N(0, 1) and the He constant
    sqrt(2 / fan_in) is applied at every forward pass instead of being
    baked into the initialization, so every layer sees comparable
    per-parameter update magnitudes under Adam.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        # He constant for a square kernel: sqrt(2 / (fan_in)).
        self.scale = np.sqrt(2 / (in_channels * kernel_size ** 2))
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        # Scaling the input is mathematically equivalent to scaling the
        # weights (the bias term is unaffected by the input scale).
        return self.conv(x * self.scale)
class GeneratorBlock(nn.Module):
    """One generator stage: 2x nearest-neighbour upsample + two 3x3 convs.

    Each convolution is followed by LeakyReLU(0.2) and pixel-wise feature
    normalization; the block doubles the spatial resolution overall.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = EqualizedConv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = EqualizedConv2d(out_channels, out_channels, 3, padding=1)
        self.pixel_norm = PixelNorm()

    def forward(self, x):
        out = self.upsample(x)
        for conv in (self.conv1, self.conv2):
            out = self.pixel_norm(F.leaky_relu(conv(out), 0.2))
        return out
class DiscriminatorBlock(nn.Module):
    """One discriminator stage: two 3x3 convs, then 2x average-pool downsample."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = EqualizedConv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = EqualizedConv2d(out_channels, out_channels, 3, padding=1)
        self.downsample = nn.AvgPool2d(2)

    def forward(self, x):
        out = F.leaky_relu(self.conv1(x), 0.2)
        out = F.leaky_relu(self.conv2(out), 0.2)
        return self.downsample(out)
print("Blocks defined")
Progressive Generator
The progressive generator starts producing images at the lowest resolution (e.g., \(4 \times 4\)) and gradually adds upsampling layers to reach the target resolution. When a new resolution layer is introduced, it is faded in using a blending parameter \(\alpha\) that transitions from 0 (only the old lower-resolution output) to 1 (only the new higher-resolution output) over several thousand training steps. This smooth transition prevents the sudden shock of new layers and allows the existing weights to adapt gradually. The result is a generator that first learns coarse structure, then progressively refines fine detail — mirroring how artists work from rough sketch to finished piece.
class ProgressiveGenerator(nn.Module):
    """Progressive GAN generator.

    Starts at 4x4 and grows by one GeneratorBlock per depth level, up to
    ``max_resolution`` (depth 3 = 32x32 with the default configuration).
    Each resolution has its own 1x1 to-RGB head (single channel here, for MNIST).
    """

    def __init__(self, latent_dim=512, max_resolution=32):
        super().__init__()
        self.latent_dim = latent_dim
        self.max_resolution = max_resolution
        # Initial 4x4 block: project the latent vector to a 4x4 feature map.
        self.initial = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0),
            nn.LeakyReLU(0.2),
            PixelNorm(),
            EqualizedConv2d(512, 512, 3, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm()
        )
        # Progressive upsampling blocks, one per resolution doubling.
        self.blocks = nn.ModuleList([
            GeneratorBlock(512, 512),  # 4 -> 8
            GeneratorBlock(512, 256),  # 8 -> 16
            GeneratorBlock(256, 128),  # 16 -> 32
        ])
        # to_rgb[d] converts the feature map at depth d to a 1-channel image;
        # channel counts match the feature maps entering/leaving each block.
        self.to_rgb = nn.ModuleList([
            EqualizedConv2d(512, 1, 1),  # 4x4
            EqualizedConv2d(512, 1, 1),  # 8x8
            EqualizedConv2d(256, 1, 1),  # 16x16
            EqualizedConv2d(128, 1, 1),  # 32x32
        ])

    def forward(self, z, depth, alpha):
        """Generate images at resolution 4 * 2**depth with fade-in blending.

        Args:
            z: Latent codes, shape (batch, latent_dim).
            depth: Current resolution level (0 = 4x4, 1 = 8x8, ...).
            alpha: Fade-in factor in [0, 1]; 1.0 uses only the newest block.

        Returns:
            Images in [-1, 1] (tanh), shape (batch, 1, res, res).
        """
        x = self.initial(z.view(-1, self.latent_dim, 1, 1))
        if depth == 0:
            return torch.tanh(self.to_rgb[0](x))
        # Run every fully-stabilized block except the newest one.
        for i in range(depth - 1):
            x = self.blocks[i](x)
        # Fade in the newest block: blend its RGB output with a naive
        # upsampling of the previous resolution's RGB output.
        # (The original code had an unreachable final return after this loop;
        # the fade-in branch always returns for depth >= 1.)
        x_old = F.interpolate(x, scale_factor=2, mode='nearest')
        x_old = self.to_rgb[depth - 1](x_old)
        x_new = self.to_rgb[depth](self.blocks[depth - 1](x))
        return torch.tanh(alpha * x_new + (1 - alpha) * x_old)
# Instantiate the generator and move its parameters to the selected device.
gen = ProgressiveGenerator().to(device)
print("Generator created")
Progressive Discriminator
The discriminator mirrors the generator's progressive growth: it starts operating on low-resolution images and gains new downsampling layers as the resolution increases. The same fade-in mechanism is applied, ensuring the discriminator and generator grow in lockstep. At each resolution, the discriminator's task is calibrated to the level of detail currently being generated, which provides more useful gradient signals than a full-resolution discriminator judging blurry low-resolution outputs. This matched progression is what makes progressive training stable for generating high-resolution images (up to \(1024 \times 1024\) in the original paper).
class ProgressiveDiscriminator(nn.Module):
    """Progressive GAN discriminator (critic).

    Mirror image of the generator: from-RGB heads per resolution, a stack of
    downsampling blocks, and a final 4x4 block producing one score per image.
    """

    def __init__(self, max_resolution=32):
        super().__init__()
        # from_rgb[d] lifts a 1-channel image at depth d into feature space;
        # output channels match the input of the block that runs at that depth.
        self.from_rgb = nn.ModuleList([
            EqualizedConv2d(1, 512, 1),  # 4x4
            EqualizedConv2d(1, 512, 1),  # 8x8
            EqualizedConv2d(1, 256, 1),  # 16x16
            EqualizedConv2d(1, 128, 1),  # 32x32
        ])
        # Downsampling blocks, ordered from highest resolution to lowest:
        # blocks[len(blocks) - depth] is the newest block at a given depth.
        self.blocks = nn.ModuleList([
            DiscriminatorBlock(128, 256),  # 32 -> 16
            DiscriminatorBlock(256, 512),  # 16 -> 8
            DiscriminatorBlock(512, 512),  # 8 -> 4
        ])
        # Final 4x4 block: conv, collapse to 1x1, project to a scalar score.
        self.final = nn.Sequential(
            EqualizedConv2d(512, 512, 3, padding=1),
            nn.LeakyReLU(0.2),
            EqualizedConv2d(512, 512, 4),
            nn.LeakyReLU(0.2),
            EqualizedConv2d(512, 1, 1)
        )

    def forward(self, x, depth, alpha):
        """Score images at resolution 4 * 2**depth with fade-in blending.

        Args:
            x: Images, shape (batch, 1, res, res).
            depth: Resolution level (0 = 4x4, 1 = 8x8, ...).
            alpha: Fade-in factor in [0, 1] for the newest layer.

        Returns:
            Per-sample critic scores, shape (batch,).

        BUG FIX: the original blended ``from_rgb[depth](x)`` (full resolution,
        new channel count) directly with ``from_rgb[depth-1]`` of the pooled
        image (half resolution, old channel count) — a shape mismatch that
        raises for any depth >= 1. The correct ProGAN fade-in runs the newest
        block on the new path first, so both paths meet at the lower
        resolution with identical channel counts.
        """
        if depth == 0:
            x = F.leaky_relu(self.from_rgb[0](x), 0.2)
            return self.final(x).view(-1)
        n = len(self.blocks)
        # New path: enter at the full resolution, then run the newest block
        # so the blend happens at the previous resolution.
        x_new = F.leaky_relu(self.from_rgb[depth](x), 0.2)
        x_new = self.blocks[n - depth](x_new)
        # Old path: downsample the image first, then use the previous head.
        x_old = F.avg_pool2d(x, 2)
        x_old = F.leaky_relu(self.from_rgb[depth - 1](x_old), 0.2)
        x = alpha * x_new + (1 - alpha) * x_old
        # Remaining, already-stabilized blocks down to 4x4.
        for i in range(n - depth + 1, n):
            x = self.blocks[i](x)
        return self.final(x).view(-1)
# Instantiate the discriminator and move its parameters to the selected device.
disc = ProgressiveDiscriminator().to(device)
print("Discriminator created")
Training
Progressive training proceeds through multiple phases, one per resolution level. At each phase, the model trains for a fixed number of iterations at the current resolution before transitioning to the next. The WGAN-GP loss (Wasserstein loss with gradient penalty) is commonly used because it provides smoother gradients than the standard GAN loss, which is especially important during the resolution transitions. Training time scales with resolution: higher resolutions require more parameters, larger batch sizes, and more iterations. Monitoring FID (Fréchet Inception Distance) at each resolution level tracks generation quality throughout the progressive training process.
# Load MNIST, scaled to [-1, 1] to match the generator's tanh output range.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
mnist = datasets.MNIST('./data', train=True, download=True, transform=transform)
loader = torch.utils.data.DataLoader(mnist, batch_size=64, shuffle=True)
def _gradient_penalty(disc, real, fake, depth, alpha):
    """WGAN-GP penalty: push the critic's gradient norm toward 1 on
    random interpolations between real and fake samples."""
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = disc(x_hat, depth, alpha)
    grads = torch.autograd.grad(
        outputs=scores, inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True, retain_graph=True)[0]
    grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()


def train_progressive(gen, disc, depth=0, n_epochs=3, gp_weight=10.0):
    """Train generator and critic at one resolution level.

    Args:
        gen: ProgressiveGenerator.
        disc: ProgressiveDiscriminator (critic).
        depth: Resolution level (0 = 4x4, 1 = 8x8, ...).
        n_epochs: Epochs to train at this resolution; alpha fades in
            linearly over the epochs (reaching 1.0 at the last epoch).
        gp_weight: Gradient-penalty coefficient (10.0 as in WGAN-GP).
            The original code used the raw Wasserstein loss with no
            Lipschitz constraint, which is unstable; the penalty makes
            this an actual WGAN-GP objective as the notebook describes.
    """
    optimizer_G = torch.optim.Adam(gen.parameters(), lr=1e-3, betas=(0.0, 0.99))
    optimizer_D = torch.optim.Adam(disc.parameters(), lr=1e-3, betas=(0.0, 0.99))
    resolution = 4 * (2 ** depth)
    for epoch in range(n_epochs):
        # Coarse per-epoch fade-in schedule (a full implementation would
        # advance alpha per step).
        alpha = min(1.0, (epoch + 1) / n_epochs)
        for real_imgs, _ in loader:
            batch_size = real_imgs.size(0)
            # Resize real images to the current training resolution.
            real_imgs = F.interpolate(real_imgs, size=resolution)
            real_imgs = real_imgs.to(device)
            # --- Train critic: maximize E[D(real)] - E[D(fake)] - GP ---
            z = torch.randn(batch_size, 512).to(device)
            fake_imgs = gen(z, depth, alpha)
            real_score = disc(real_imgs, depth, alpha)
            fake_score = disc(fake_imgs.detach(), depth, alpha)
            gp = _gradient_penalty(disc, real_imgs, fake_imgs.detach(), depth, alpha)
            loss_D = -torch.mean(real_score) + torch.mean(fake_score) + gp_weight * gp
            optimizer_D.zero_grad()
            loss_D.backward()
            optimizer_D.step()
            # --- Train generator: maximize E[D(fake)] ---
            z = torch.randn(batch_size, 512).to(device)
            fake_imgs = gen(z, depth, alpha)
            fake_score = disc(fake_imgs, depth, alpha)
            loss_G = -torch.mean(fake_score)
            optimizer_G.zero_grad()
            loss_G.backward()
            optimizer_G.step()
        print(f"Depth {depth}, Epoch {epoch+1}, α={alpha:.2f}, D={loss_D.item():.4f}, G={loss_G.item():.4f}")
# Progressive training schedule: one phase per resolution level.
# NOTE(review): depths 0-2 cover 4x4 through 16x16; the networks also support
# depth 3 (32x32) -- change to range(4) to train the full resolution.
for depth in range(3):
    res = 4 * (2 ** depth)
    # Fixed mojibake in the original f-string ("Γ" -> "×").
    print(f"\nTraining at {res}×{res}")
    train_progressive(gen, disc, depth, n_epochs=3)
Generate Samples
After progressive training completes at the final resolution, we generate a batch of samples to assess the model's output quality and diversity. Comparing samples generated at intermediate resolutions (from earlier training phases) with final-resolution samples shows how the model's capabilities evolve during progressive training — from capturing coarse shapes to rendering fine textures and details.
# Visualize the same 8 latent codes rendered at each trained resolution.
gen.eval()
fig, axes = plt.subplots(3, 8, figsize=(16, 6))
with torch.no_grad():
    z = torch.randn(8, 512).to(device)
    for depth in range(3):
        imgs = gen(z, depth, alpha=1.0)
        res = 4 * (2 ** depth)
        for i in range(8):
            ax = axes[depth, i]
            ax.imshow(imgs[i, 0].cpu(), cmap='gray')
            # Hide ticks/frame instead of axis('off'): axis('off') would also
            # suppress the row label set below (bug in the original code).
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
        # Fixed mojibake in the original f-string ("Γ" -> "×").
        axes[depth, 0].set_ylabel(f'{res}×{res}', fontsize=11)
plt.suptitle('Progressive Generation', fontsize=13)
plt.tight_layout()
plt.show()
Summary
Progressive GAN:
Gradual resolution increase
Layer fade-in smooths training
Stable high-resolution synthesis
Faster convergence
Key Techniques:
Pixel normalization
Equalized learning rate
Minibatch standard deviation
Applications:
High-resolution face generation (1024×1024)
Texture synthesis
Data augmentation
Extensions:
StyleGAN: Style-based generation
ProGAN: Original implementation
BigGAN: Large-scale training