import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from itertools import chain
# Select the compute device once; every model/tensor below is moved to it.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")
1. CycleGAN Theory
Unpaired Translation
Learn mappings \(G: X \rightarrow Y\) and \(F: Y \rightarrow X\) from unpaired sets.
Cycle Consistency Loss
Full Objective
Reference Materials:
gan.pdf - GAN
class ResidualBlock(nn.Module):
    """Residual block: x + N2(C2(relu(N1(C1(x))))).

    Two 3x3 same-padding convolutions with instance normalization;
    ReLU follows the first norm only. Spatial size and channel count
    are preserved, so the skip connection is a plain addition.
    """

    def __init__(self, channels):
        super().__init__()
        # Both convs map `channels -> channels` at the same resolution.
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm1 = nn.InstanceNorm2d(channels)
        self.norm2 = nn.InstanceNorm2d(channels)

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(self.norm1(out))
        out = self.norm2(self.conv2(out))
        # Identity shortcut around the transformation branch.
        return out + x
class Generator(nn.Module):
    """ResNet-style CycleGAN generator: encode -> transform -> decode.

    The encoder downsamples by 4x (7x7 stem plus two stride-2 convs),
    a stack of residual blocks transforms features at the bottleneck,
    and the decoder mirrors the encoder with transposed convs. The
    final Tanh maps outputs to [-1, 1], matching the data normalization.
    """

    def __init__(self, input_channels=1, output_channels=1, n_residual=6):
        super().__init__()
        # Downsampling path.
        encoder_layers = [
            nn.Conv2d(input_channels, 64, 7, padding=3),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        # Feature transformation at 1/4 resolution.
        self.residual = nn.Sequential(
            *(ResidualBlock(256) for _ in range(n_residual))
        )
        # Upsampling path; output_padding=1 restores exact even sizes.
        decoder_layers = [
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, output_channels, 7, padding=3),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        features = self.encoder(x)
        features = self.residual(features)
        return self.decoder(features)
class Discriminator(nn.Module):
    """PatchGAN discriminator.

    Three stride-2 conv stages followed by two stride-1 convs produce a
    spatial map of real/fake scores (one score per receptive-field patch)
    rather than a single scalar per image. No sigmoid: the training loop
    uses a least-squares (LSGAN) objective on the raw scores.
    """

    def __init__(self, input_channels=1):
        super().__init__()
        # InstanceNorm is omitted on the first stage, per common practice.
        self.model = nn.Sequential(
            nn.Conv2d(input_channels, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, padding=1),
        )

    def forward(self, x):
        score_map = self.model(x)
        return score_map
# Instantiate the four networks: one discriminator per domain and one
# generator per translation direction, all on the selected device.
D_X = Discriminator().to(device)
D_Y = Discriminator().to(device)
G_XY = Generator().to(device)  # translates domain X to domain Y
G_YX = Generator().to(device)  # translates domain Y to domain X
print("Models created")
Prepare Data
CycleGAN requires unpaired data from two domains — for example, photographs and paintings, or summer and winter scenes. Unlike paired image translation (pix2pix), CycleGAN does not need aligned input-output pairs, making it applicable to many more real-world scenarios where paired data is unavailable. Data preparation involves loading images from both domains, applying standard preprocessing (resize, normalize, optional augmentation), and creating two independent data loaders. The training set size per domain can differ, and the images need not depict the same scenes.
# Resize MNIST to 32x32 (divisible by 4 for the generator's down/upsampling)
# and normalize to [-1, 1] to match the generator's Tanh output range.
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
mnist = datasets.MNIST('./data', train=True, download=True, transform=transform)
# Create two domains: digits 0-4 (domain X), 5-9 (domain Y).
# Select indices via the raw label tensor: enumerate(mnist) would decode and
# transform all 60k images twice just to read their labels. torch.where over
# mnist.targets yields the same indices in the same order, essentially free.
labels = mnist.targets
indices_X = torch.where(labels < 5)[0][:1000].tolist()
indices_Y = torch.where(labels >= 5)[0][:1000].tolist()
dataset_X = torch.utils.data.Subset(mnist, indices_X)
dataset_Y = torch.utils.data.Subset(mnist, indices_Y)
loader_X = torch.utils.data.DataLoader(dataset_X, batch_size=32, shuffle=True)
loader_Y = torch.utils.data.DataLoader(dataset_Y, batch_size=32, shuffle=True)
print(f"Domain X: {len(dataset_X)}, Domain Y: {len(dataset_Y)}")
Training Loop
CycleGAN trains four networks simultaneously: two generators (\(G: A \to B\) and \(F: B \to A\)) and two discriminators (\(D_A\) and \(D_B\)). The total loss combines adversarial losses (each discriminator classifies real vs. translated images), cycle consistency losses (\(\|F(G(x)) - x\|\) and \(\|G(F(y)) - y\|\)), and optionally an identity loss (\(\|G(y) - y\|\), \(\|F(x) - x\|\)). Cycle consistency is the key innovation: it enforces that translating an image to the other domain and back should recover the original, preventing the generators from producing arbitrary outputs that fool the discriminators.
# One Adam optimizer per adversarial side: both generators share a single
# optimizer, and both discriminators share another (standard CycleGAN setup;
# beta1=0.5 follows the DCGAN/CycleGAN papers).
gen_params = chain(G_XY.parameters(), G_YX.parameters())
optimizer_G = torch.optim.Adam(gen_params, lr=2e-4, betas=(0.5, 0.999))
disc_params = chain(D_X.parameters(), D_Y.parameters())
optimizer_D = torch.optim.Adam(disc_params, lr=2e-4, betas=(0.5, 0.999))
def train_cyclegan(n_epochs=10, lambda_cyc=10.0):
    """Train the CycleGAN (G_XY, G_YX, D_X, D_Y) with LSGAN + cycle losses.

    Uses the module-level models, optimizers, loaders, and device.

    Args:
        n_epochs: number of passes over the zipped domain loaders.
        lambda_cyc: weight of the cycle-consistency term.

    Returns:
        (losses_G, losses_D): per-iteration generator and discriminator losses.
    """
    losses_G = []
    losses_D = []
    for epoch in range(n_epochs):
        for (x_real, _), (y_real, _) in zip(loader_X, loader_Y):
            x_real = x_real.to(device)
            y_real = y_real.to(device)
            # zip() can pair a full batch with a smaller final batch;
            # trim both to the common size so shapes match everywhere.
            batch_size = min(x_real.size(0), y_real.size(0))
            x_real = x_real[:batch_size]
            y_real = y_real[:batch_size]
            # ============ Train Generators ============
            optimizer_G.zero_grad()
            # Forward cycle: X -> Y -> X'
            y_fake = G_XY(x_real)
            x_recon = G_YX(y_fake)
            # Backward cycle: Y -> X -> Y'
            x_fake = G_YX(y_real)
            y_recon = G_XY(x_fake)
            # Least-squares adversarial losses. Build targets with
            # ones_like/zeros_like instead of a hard-coded (B, 1, 2, 2):
            # the PatchGAN output shape depends on the input resolution,
            # so the original constant only worked for 32x32 images.
            pred_fake_y = D_Y(y_fake)
            loss_G_XY = F.mse_loss(pred_fake_y, torch.ones_like(pred_fake_y))
            pred_fake_x = D_X(x_fake)
            loss_G_YX = F.mse_loss(pred_fake_x, torch.ones_like(pred_fake_x))
            # Cycle consistency: translating there and back must recover input.
            loss_cycle_X = F.l1_loss(x_recon, x_real)
            loss_cycle_Y = F.l1_loss(y_recon, y_real)
            loss_cycle = loss_cycle_X + loss_cycle_Y
            # Total generator loss
            loss_G = loss_G_XY + loss_G_YX + lambda_cyc * loss_cycle
            loss_G.backward()
            optimizer_G.step()
            # ============ Train Discriminators ============
            optimizer_D.zero_grad()
            # D_X: real X vs. detached fakes (no gradients into generators).
            pred_real_x = D_X(x_real)
            loss_D_X_real = F.mse_loss(pred_real_x, torch.ones_like(pred_real_x))
            pred_det_x = D_X(x_fake.detach())
            loss_D_X_fake = F.mse_loss(pred_det_x, torch.zeros_like(pred_det_x))
            loss_D_X = (loss_D_X_real + loss_D_X_fake) / 2
            # D_Y
            pred_real_y = D_Y(y_real)
            loss_D_Y_real = F.mse_loss(pred_real_y, torch.ones_like(pred_real_y))
            pred_det_y = D_Y(y_fake.detach())
            loss_D_Y_fake = F.mse_loss(pred_det_y, torch.zeros_like(pred_det_y))
            loss_D_Y = (loss_D_Y_real + loss_D_Y_fake) / 2
            loss_D = loss_D_X + loss_D_Y
            loss_D.backward()
            optimizer_D.step()
            losses_G.append(loss_G.item())
            losses_D.append(loss_D.item())
        print(f"Epoch {epoch+1}/{n_epochs}, G: {loss_G.item():.4f}, D: {loss_D.item():.4f}")
    return losses_G, losses_D
losses_G, losses_D = train_cyclegan(n_epochs=10, lambda_cyc=10.0)
Visualize Results
Displaying input images alongside their translated outputs and their cycle-reconstructions provides a comprehensive view of the model's performance. Good translations should look natural in the target domain while preserving the semantic content of the source; the cycle-reconstructed images should closely match the originals. Artifacts like color shifts, structural distortions, or content hallucination indicate training instability or insufficient cycle consistency weight.
# Switch generators to eval mode for inference-time visualization.
G_XY.eval()
G_YX.eval()
# Grab one unpaired batch from each domain.
x_samples, _ = next(iter(loader_X))
y_samples, _ = next(iter(loader_Y))
x_samples = x_samples[:8].to(device)
y_samples = y_samples[:8].to(device)
with torch.no_grad():
    y_fake = G_XY(x_samples)
    x_recon = G_YX(y_fake)
    x_fake = G_YX(y_samples)
    y_recon = G_XY(x_fake)
# Plot one row per tensor. The original code called set_ylabel after
# axis('off'), which hides ALL axis artists including the label, so the
# row labels never appeared. Hide only ticks and spines instead, keeping
# the axis alive so set_ylabel renders.
rows = [
    (x_samples, 'X'),
    (y_fake, 'G(X)'),
    (y_samples, 'Y'),
    (x_fake, 'F(Y)'),
]
fig, axes = plt.subplots(4, 8, figsize=(16, 8))
for r, (images, label) in enumerate(rows):
    for i in range(8):
        ax = axes[r, i]
        ax.imshow(images[i, 0].cpu(), cmap='gray')
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    axes[r, 0].set_ylabel(label, fontsize=11)
plt.suptitle('CycleGAN Translations', fontsize=13)
plt.tight_layout()
plt.show()
Loss Curves
Plotting the generator losses, discriminator losses, and cycle consistency losses over training reveals the dynamics of the adversarial training process. Unlike standard GANs where one loss decreasing may be sufficient, CycleGAN requires all losses to remain balanced. If a discriminator becomes too strong (loss near zero), the corresponding generator struggles to learn; if cycle consistency loss remains high, the translations are not content-preserving. Stable training typically shows all losses fluctuating in a bounded range without sustained divergence.
# Plot raw per-iteration losses (faint) with a 50-step moving average (bold)
# for each network, side by side.
window = 50
kernel = np.ones(window) / window
curves = [
    (losses_G, 'Generator Loss', 'Generator Training'),
    (losses_D, 'Discriminator Loss', 'Discriminator Training'),
]
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for ax, (series, ylabel, title) in zip(axes, curves):
    ax.plot(series, alpha=0.3)
    ax.plot(np.convolve(series, kernel, mode='valid'), linewidth=2)
    ax.set_xlabel('Iteration', fontsize=11)
    ax.set_ylabel(ylabel, fontsize=11)
    ax.set_title(title, fontsize=12)
    ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
Summary
CycleGAN Key Ideas:
Unpaired training - No paired examples needed
Cycle consistency - \(F(G(x)) \approx x\)
Dual generators - \(G: X \rightarrow Y\), \(F: Y \rightarrow X\)
Two discriminators - \(D_X\), \(D_Y\)
Loss Components:
Adversarial loss (2 directions)
Cycle consistency loss
Identity loss (optional)
Applications:
Style transfer (photo ↔ painting)
Season transfer (summer ↔ winter)
Object transfiguration (horse ↔ zebra)
Domain adaptation
Advantages:
No paired data required
Preserves content structure
Bidirectional mapping
Extensions:
StarGAN: Multi-domain translation
UNIT: Shared latent space
MUNIT: Multimodal translation