import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from itertools import chain

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

1. CycleGAN Theory

Unpaired Translation

Learn mappings \(G: X \rightarrow Y\) and \(F: Y \rightarrow X\) between two domains from unpaired sets of samples: no \(x\) is ever given together with its corresponding \(y\).

Cycle Consistency Loss

\[\mathcal{L}_{\text{cyc}}(G, F) = \mathbb{E}_{x \sim p(X)}\|F(G(x)) - x\|_1 + \mathbb{E}_{y \sim p(Y)}\|G(F(y)) - y\|_1\]

Full Objective

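\[\mathcal{L} = \mathcal{L}_{\text{GAN}}(G, D_Y) + \mathcal{L}_{\text{GAN}}(F, D_X) + \lambda \mathcal{L}_{\text{cyc}}(G, F)\]

where \(\lambda\) weights cycle consistency against the two adversarial terms (the training run below uses \(\lambda = 10\)).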

📚 Reference Materials:

class ResidualBlock(nn.Module):
    """Conv-norm-ReLU-conv-norm branch added back onto its input.

    Both convolutions are 3x3 with padding 1 and keep ``channels`` in and
    out, so the branch output has the input's shape and the skip
    connection is a plain elementwise sum.
    """

    def __init__(self, channels):
        super().__init__()
        # Attribute names (and creation order) define the state_dict keys and
        # the parameter-initialization order; keep them stable.
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm1 = nn.InstanceNorm2d(channels)
        self.norm2 = nn.InstanceNorm2d(channels)

    def forward(self, x):
        branch = self.conv1(x)
        branch = F.relu(self.norm1(branch))
        branch = self.conv2(branch)
        branch = self.norm2(branch)
        # No activation after the sum.
        return branch + x

class Generator(nn.Module):
    """CycleGAN generator: downsampling encoder, residual trunk, upsampling decoder.

    The encoder is a 7x7 stem followed by two stride-2 3x3 convolutions
    (64 -> 128 -> 256 channels). ``n_residual`` ResidualBlocks then operate at
    256 channels. Two stride-2 transposed convolutions upsample back to 64
    channels, and a 7x7 convolution plus Tanh produces the output, so values lie
    in [-1, 1]. Output height/width equal the input's when both are divisible
    by 4.
    """

    def __init__(self, input_channels=1, output_channels=1, n_residual=6):
        super().__init__()

        def stage(conv, channels):
            # conv -> InstanceNorm -> ReLU as a flat list, so the Sequential
            # indices (and therefore state_dict keys) stay 0,1,2 / 3,4,5 / ...
            return [conv, nn.InstanceNorm2d(channels), nn.ReLU(inplace=True)]

        # Encoder
        self.encoder = nn.Sequential(
            *stage(nn.Conv2d(input_channels, 64, 7, padding=3), 64),
            *stage(nn.Conv2d(64, 128, 3, stride=2, padding=1), 128),
            *stage(nn.Conv2d(128, 256, 3, stride=2, padding=1), 256),
        )

        # Residual trunk at the bottleneck resolution
        self.residual = nn.Sequential(
            *(ResidualBlock(256) for _ in range(n_residual))
        )

        # Decoder: output_padding=1 makes each transposed conv exactly double H and W.
        self.decoder = nn.Sequential(
            *stage(nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1), 128),
            *stage(nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1), 64),
            nn.Conv2d(64, output_channels, 7, padding=3),
            nn.Tanh(),
        )

    def forward(self, x):
        features = self.encoder(x)
        features = self.residual(features)
        return self.decoder(features)

class Discriminator(nn.Module):
    """PatchGAN discriminator producing a grid of per-patch realness scores.

    Three stride-2 4x4 convolutions and two stride-1 4x4 convolutions, with
    LeakyReLU(0.2) and instance norm (none after the first layer). The output
    has one channel; for a 32x32 input the score map is 2x2.
    """

    def __init__(self, input_channels=1):
        super().__init__()

        # First layer: no normalization.
        layers = [
            nn.Conv2d(input_channels, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # conv -> InstanceNorm -> LeakyReLU stages; the last one keeps resolution.
        for in_ch, out_ch, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers.extend([
                nn.Conv2d(in_ch, out_ch, 4, stride=stride, padding=1),
                nn.InstanceNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # 1-channel score map (no activation; the loss treats it as a raw score).
        layers.append(nn.Conv2d(512, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        return scores

# Create models: one generator per translation direction and one
# discriminator per domain, built in the order G_XY, G_YX, D_X, D_Y.
G_XY, G_YX = (Generator().to(device) for _ in range(2))  # X -> Y, Y -> X
D_X, D_Y = (Discriminator().to(device) for _ in range(2))

print("Models created")

Prepare Data

CycleGAN requires unpaired data from two domains – for example, photographs and paintings, or summer and winter scenes. Unlike paired image translation (pix2pix), CycleGAN does not need aligned input-output pairs, making it applicable to many more real-world scenarios where paired data is unavailable. Data preparation involves loading images from both domains, applying standard preprocessing (resize, normalize, optional augmentation), and creating two independent data loaders. The training set size per domain can differ, and the images need not depict the same scenes. In this notebook the two domains are a toy stand-in built from MNIST: digits 0–4 form domain X and digits 5–9 form domain Y, with 1,000 training images each, resized to 32 pixels and normalized to [-1, 1].

# ToTensor gives pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1],
# the same range as the generators' Tanh output.
transform = transforms.Compose([
    transforms.Resize(32),  # resize so the shorter side is 32 px
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

mnist = datasets.MNIST('./data', train=True, download=True, transform=transform)

# Create two domains: digits 0-4 (domain X), 5-9 (domain Y).
# Split on the label tensor directly: iterating `mnist` would load and
# transform every image just to read its label. nonzero() returns the
# indices in ascending order, i.e. the same order enumerate() would.
indices_X = torch.nonzero(mnist.targets < 5).flatten().tolist()
indices_Y = torch.nonzero(mnist.targets >= 5).flatten().tolist()

# Keep the first 1000 images of each domain.
dataset_X = torch.utils.data.Subset(mnist, indices_X[:1000])
dataset_Y = torch.utils.data.Subset(mnist, indices_Y[:1000])

# Independently shuffled loaders: X and Y samples are never paired.
loader_X = torch.utils.data.DataLoader(dataset_X, batch_size=32, shuffle=True)
loader_Y = torch.utils.data.DataLoader(dataset_Y, batch_size=32, shuffle=True)

print(f"Domain X: {len(dataset_X)}, Domain Y: {len(dataset_Y)}")

Training Loop

CycleGAN trains four networks simultaneously: two generators (\(G: X \to Y\) and \(F: Y \to X\), named `G_XY` and `G_YX` in the code) and two discriminators (\(D_X\) and \(D_Y\)). The total loss combines adversarial losses (each discriminator classifies real vs. translated images of its own domain; the code uses the least-squares form, i.e. MSE against a target of 1 for real and 0 for fake), cycle consistency losses (\(\|F(G(x)) - x\|_1\) and \(\|G(F(y)) - y\|_1\)), and optionally an identity loss (\(\|G(y) - y\|_1\), \(\|F(x) - x\|_1\)). Cycle consistency is the key innovation: it enforces that translating an image to the other domain and back should recover the original, preventing the generators from producing arbitrary outputs that fool the discriminators.

# Optimizers: both generators share one Adam instance and both
# discriminators share another, so each training phase is a single step().
# beta1 is lowered from Adam's default 0.9 to 0.5.
optimizer_G = torch.optim.Adam(
    [*G_XY.parameters(), *G_YX.parameters()], lr=2e-4, betas=(0.5, 0.999)
)
optimizer_D = torch.optim.Adam(
    [*D_X.parameters(), *D_Y.parameters()], lr=2e-4, betas=(0.5, 0.999)
)

def train_cyclegan(n_epochs=10, lambda_cyc=10.0, lambda_id=0.0):
    """Jointly train G_XY, G_YX, D_X and D_Y on the two unpaired loaders.

    Uses the module-level models (G_XY, G_YX, D_X, D_Y), optimizers
    (optimizer_G, optimizer_D), loaders (loader_X, loader_Y) and ``device``.
    Each iteration takes one generator step, then one discriminator step,
    with least-squares (MSE) adversarial losses.

    Args:
        n_epochs: Number of passes over ``zip(loader_X, loader_Y)``; each pass
            stops at the end of the shorter loader.
        lambda_cyc: Weight of the L1 cycle-consistency loss.
        lambda_id: Weight of the optional L1 identity loss
            (G_XY(y) ~ y and G_YX(x) ~ x). The default 0.0 disables it.

    Returns:
        (losses_G, losses_D): per-iteration total generator and discriminator
        losses as Python floats. The per-epoch printout shows their epoch means.
    """
    losses_G = []
    losses_D = []

    # A previous eval() call (e.g. from a visualization cell) must not leak
    # into training.
    for model in (G_XY, G_YX, D_X, D_Y):
        model.train()

    for epoch in range(n_epochs):
        epoch_G = []
        epoch_D = []

        for (x_real, _), (y_real, _) in zip(loader_X, loader_Y):
            x_real = x_real.to(device)
            y_real = y_real.to(device)

            # The loaders can yield different batch sizes (e.g. a final partial
            # batch); trim both to the common size.
            batch_size = min(x_real.size(0), y_real.size(0))
            x_real = x_real[:batch_size]
            y_real = y_real[:batch_size]

            # ============ Train Generators ============
            optimizer_G.zero_grad()

            # Forward cycle: X -> Y -> X'
            y_fake = G_XY(x_real)
            x_recon = G_YX(y_fake)

            # Backward cycle: Y -> X -> Y'
            x_fake = G_YX(y_real)
            y_recon = G_XY(x_fake)

            # Adversarial losses: push D's patch scores on fakes toward 1.
            # Targets are built from the prediction itself so they match the
            # patch-map shape at any input resolution.
            pred_fake_Y = D_Y(y_fake)
            pred_fake_X = D_X(x_fake)
            loss_G_XY = F.mse_loss(pred_fake_Y, torch.ones_like(pred_fake_Y))
            loss_G_YX = F.mse_loss(pred_fake_X, torch.ones_like(pred_fake_X))

            # Cycle consistency
            loss_cycle = F.l1_loss(x_recon, x_real) + F.l1_loss(y_recon, y_real)

            loss_G = loss_G_XY + loss_G_YX + lambda_cyc * loss_cycle
            if lambda_id:
                # Identity: a generator fed an image already in its target
                # domain should leave it unchanged.
                loss_identity = (F.l1_loss(G_XY(y_real), y_real)
                                 + F.l1_loss(G_YX(x_real), x_real))
                loss_G = loss_G + lambda_id * loss_identity

            loss_G.backward()
            optimizer_G.step()

            # ============ Train Discriminators ============
            # The generator backward pass also wrote gradients into D's
            # parameters; zero_grad discards them before D's own update.
            optimizer_D.zero_grad()

            # D_X: real X -> 1, translated F(y) -> 0 (fakes detached so no
            # gradient reaches the generators).
            d_real_X = D_X(x_real)
            d_fake_X = D_X(x_fake.detach())
            loss_D_X = (F.mse_loss(d_real_X, torch.ones_like(d_real_X))
                        + F.mse_loss(d_fake_X, torch.zeros_like(d_fake_X))) / 2

            # D_Y: real Y -> 1, translated G(x) -> 0
            d_real_Y = D_Y(y_real)
            d_fake_Y = D_Y(y_fake.detach())
            loss_D_Y = (F.mse_loss(d_real_Y, torch.ones_like(d_real_Y))
                        + F.mse_loss(d_fake_Y, torch.zeros_like(d_fake_Y))) / 2

            loss_D = loss_D_X + loss_D_Y

            loss_D.backward()
            optimizer_D.step()

            epoch_G.append(loss_G.item())
            epoch_D.append(loss_D.item())

        losses_G.extend(epoch_G)
        losses_D.extend(epoch_D)

        if epoch_G:
            mean_G = sum(epoch_G) / len(epoch_G)
            mean_D = sum(epoch_D) / len(epoch_D)
            print(f"Epoch {epoch+1}/{n_epochs}, G: {mean_G:.4f}, D: {mean_D:.4f}")
        else:
            print(f"Epoch {epoch+1}/{n_epochs}: no batches (empty loader)")

    return losses_G, losses_D

# Train for 10 epochs with cycle weight 10; keep the per-iteration losses for plotting.
losses_G, losses_D = train_cyclegan(lambda_cyc=10.0, n_epochs=10)

Visualize Results

Displaying input images alongside their translated outputs and their cycle-reconstructions provides a comprehensive view of the model’s performance. Good translations should look natural in the target domain while preserving the semantic content of the source; the cycle-reconstructed images should closely match the originals. Artifacts like color shifts, structural distortions, or content hallucination indicate training instability or insufficient cycle consistency weight.

G_XY.eval()
G_YX.eval()

# Get one batch from each domain and keep (at most) the first 8 images.
x_samples, _ = next(iter(loader_X))
y_samples, _ = next(iter(loader_Y))

x_samples = x_samples[:8].to(device)
y_samples = y_samples[:8].to(device)

with torch.no_grad():
    y_fake = G_XY(x_samples)   # G(X)
    x_recon = G_YX(y_fake)     # F(G(X)), should resemble X

    x_fake = G_YX(y_samples)   # F(Y)
    y_recon = G_XY(x_fake)     # G(F(Y)), should resemble Y

# Rows: forward cycle (X, G(X), F(G(X))) then backward cycle (Y, F(Y), G(F(Y))).
rows = [
    (x_samples, 'X'),
    (y_fake, 'G(X)'),
    (x_recon, 'F(G(X))'),
    (y_samples, 'Y'),
    (x_fake, 'F(Y)'),
    (y_recon, 'G(F(Y))'),
]
# Guard against batches smaller than 8.
n_show = min(8, x_samples.size(0), y_samples.size(0))

# squeeze=False keeps `axes` 2-D even when n_show == 1.
fig, axes = plt.subplots(len(rows), n_show, figsize=(2 * n_show, 2 * len(rows)), squeeze=False)

for r, (images, label) in enumerate(rows):
    for i in range(n_show):
        ax = axes[r, i]
        # Inputs are normalized to [-1, 1] and the generators end in Tanh, so a
        # fixed [-1, 1] gray scale keeps every panel comparable.
        ax.imshow(images[i, 0].cpu().numpy(), cmap='gray', vmin=-1, vmax=1)
        # Hide ticks only: axis('off') would also hide the row label below.
        ax.set_xticks([])
        ax.set_yticks([])
    axes[r, 0].set_ylabel(label, fontsize=11)

plt.suptitle('CycleGAN Translations', fontsize=13)
plt.tight_layout()
plt.show()

Loss Curves

Plotting the total generator loss (adversarial terms plus the weighted cycle-consistency term) and the total discriminator loss over training reveals the dynamics of the adversarial training process. Unlike standard GANs where one loss decreasing may be sufficient, CycleGAN requires all losses to remain balanced. If a discriminator becomes too strong (loss near zero), the corresponding generator struggles to learn; if the cycle-consistency term stays high (log it separately to see it, since here it is folded into the generator total), the translations are not content-preserving. Stable training typically shows all losses fluctuating in a bounded range without sustained divergence.

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

window = 50  # moving-average length in iterations

panels = [
    (axes[0], losses_G, 'Generator Loss', 'Generator Training'),
    (axes[1], losses_D, 'Discriminator Loss', 'Discriminator Training'),
]

for ax, losses, ylabel, title in panels:
    ax.plot(losses, alpha=0.3, label='per iteration')

    # Shrink the window for short runs: np.convolve(mode='valid') with a
    # kernel longer than the signal swaps the operands and returns garbage.
    w = min(window, len(losses))
    if w > 0:
        smooth = np.convolve(losses, np.ones(w) / w, mode='valid')
        # Output sample k averages iterations k .. k+w-1; plot it at the
        # window's right edge so it lines up with the raw curve.
        ax.plot(np.arange(w - 1, len(losses)), smooth, linewidth=2,
                label=f'{w}-iteration moving average')

    ax.set_xlabel('Iteration', fontsize=11)
    ax.set_ylabel(ylabel, fontsize=11)
    ax.set_title(title, fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.legend()

plt.tight_layout()
plt.show()

Summary

CycleGAN Key Ideas:

  1. Unpaired training - No paired examples needed

  2. Cycle consistency - \(F(G(x)) \approx x\)

  3. Dual generators - \(G: X \rightarrow Y\), \(F: Y \rightarrow X\)

  4. Two discriminators - \(D_X\), \(D_Y\)

Loss Components:

  • Adversarial loss (2 directions)

  • Cycle consistency loss

  • Identity loss (optional)

Applications:

  • Style transfer (photo ↔ painting)

  • Season transfer (summer ↔ winter)

  • Object transfiguration (horse ↔ zebra)

  • Domain adaptation

Advantages:

  • No paired data required

  • Preserves content structure

  • Bidirectional mapping

Extensions:

  • StarGAN: Multi-domain translation

  • UNIT: Shared latent space

  • MUNIT: Multimodal translation