Conditional GANs (cGAN): Comprehensive Theory
Introduction
Conditional Generative Adversarial Networks (cGANs) extend the standard GAN framework by conditioning both the generator and discriminator on additional information \(y\). This conditioning enables controlled generation, where we can specify what kind of samples to generate (e.g., class labels, text descriptions, images, or other auxiliary data).
Key Innovation: Transform unsupervised GAN into a supervised framework with explicit control over generation.
Applications:
Class-conditional image generation (generate specific digits, objects, faces)
Image-to-image translation (pix2pix: edges → photos, day → night)
Text-to-image synthesis
Super-resolution
Style transfer
Inpainting and outpainting
Mathematical Foundation
Standard GAN (Unconditional)
Recall the standard GAN minimax objective:

$$\min_G \max_D V(D, G) = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}[\log D(\mathbf{x})] + \mathbb{E}_{\mathbf{z} \sim p_z}[\log(1 - D(G(\mathbf{z})))]$$

Where:
\(G(\mathbf{z})\): Generator maps noise \(\mathbf{z}\) to fake samples
\(D(\mathbf{x})\): Discriminator predicts probability that \(\mathbf{x}\) is real
No explicit control over what is generated
Conditional GAN Objective
In cGAN, we condition both \(G\) and \(D\) on additional information \(y\):

$$\min_G \max_D V(D, G) = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}[\log D(\mathbf{x}, y)] + \mathbb{E}_{\mathbf{z} \sim p_z}[\log(1 - D(G(\mathbf{z}, y), y))]$$

Where:
\(G(\mathbf{z}, y)\): Generator conditioned on \(y\) (e.g., class label, text, image)
\(D(\mathbf{x}, y)\): Discriminator predicts if \(\mathbf{x}\) is real given condition \(y\)
\(y\) can be discrete (class labels) or continuous (images, embeddings)
Key Insight: Discriminator must verify two things:
Is the sample realistic?
Does the sample match the condition \(y\)?
Conditioning Mechanisms
1. Label Conditioning (Discrete Classes)
For class-conditional generation (e.g., MNIST digits 0-9):
Generator Conditioning:

$$G(\mathbf{z}, y) = G([\mathbf{z}; \mathbf{e}(y)])$$

Where \(\mathbf{e}(y)\) is an embedding of class label \(y\) (one-hot or learned embedding).
Discriminator Conditioning:

$$D(\mathbf{x}, y) = D([\mathbf{x}; \mathbf{e}(y)])$$
Implementation:
Concatenate embedded label with noise vector (generator input)
Concatenate embedded label with image (discriminator input)
Can use projection discriminator for efficiency (discussed later)
2. Image Conditioning (Paired Image-to-Image)
For tasks like pix2pix (edges → photo):
Generator: U-Net architecture

$$G(\mathbf{x}_{\text{input}}) = \mathbf{x}_{\text{output}}$$

Discriminator: PatchGAN (evaluates \(N \times N\) patches)

$$D(\mathbf{x}_{\text{output}}, \mathbf{x}_{\text{input}}) \rightarrow \mathbb{R}^{H \times W}$$

Loss: Combines adversarial loss + L1 reconstruction

$$\mathcal{L} = \mathcal{L}_{\text{GAN}} + \lambda \mathbb{E}[\|\mathbf{x}_{\text{real}} - G(\mathbf{x}_{\text{input}})\|_1]$$
3. Text Conditioning (Text-to-Image)
For generating images from captions:
Text Embedding:
Use pre-trained text encoder (e.g., BERT, CLIP)
\(\mathbf{t} = \text{Encoder}(\text{"a red car on the street"})\)
Generator:

$$G(\mathbf{z}, \mathbf{t}) = G([\mathbf{z}; \mathbf{t}])$$

Discriminator (matching-aware):

$$D(\mathbf{x}, \mathbf{t}) = \begin{cases} \text{real} & \text{if } \mathbf{x} \text{ is real and matches } \mathbf{t} \\ \text{fake} & \text{otherwise} \end{cases}$$
4. Continuous Conditioning (Attributes)
For generating faces with specific attributes (age, gender, expression):
Attribute Vector: \(\mathbf{a} \in \mathbb{R}^d\) (e.g., [age=30, male=1, smile=0.5])
Generator:

$$G(\mathbf{z}, \mathbf{a}) = G([\mathbf{z}; \mathbf{a}])$$
Advanced Conditioning Architectures
Projection Discriminator
Problem: Naive concatenation is inefficient and limits discriminator capacity.
Solution (Miyato & Koyama, 2018): Project conditioning information into the discriminator's intermediate features.
Architecture:

$$D(\mathbf{x}, y) = \sigma(\mathbf{w}^T \phi(\mathbf{x}) + \mathbf{e}(y)^T \phi(\mathbf{x}))$$
Where:
\(\phi(\mathbf{x})\): Feature extractor (CNN)
\(\mathbf{e}(y)\): Embedding of condition \(y\)
Inner product \(\mathbf{e}(y)^T \phi(\mathbf{x})\) captures correlation
Benefits:
More parameter-efficient
Better gradient flow
Improved conditioning signal
State-of-the-art for class-conditional generation
Auxiliary Classifier GAN (AC-GAN)
Innovation: Discriminator also predicts the class label.
Discriminator Outputs:
Real/Fake probability: \(D(\mathbf{x})\)
Class probability: \(C(\mathbf{x})\)
Losses:
Adversarial Loss:

$$\mathcal{L}_{\text{adv}} = \mathbb{E}[\log D(\mathbf{x}_{\text{real}})] + \mathbb{E}[\log(1 - D(G(\mathbf{z}, y)))]$$

Classification Loss (on both real and fake samples):

$$\mathcal{L}_{\text{cls}} = \mathbb{E}[\log C(y | \mathbf{x}_{\text{real}})] + \mathbb{E}[\log C(y | G(\mathbf{z}, y))]$$

Total Objective: the discriminator is trained to maximize \(\mathcal{L}_{\text{adv}} + \mathcal{L}_{\text{cls}}\), while the generator is trained to maximize \(\mathcal{L}_{\text{cls}} - \mathcal{L}_{\text{adv}}\) (fool the discriminator while keeping samples classifiable as \(y\)).
Trade-off:
Pros: Better class separation, higher quality
Cons: Can reduce diversity within classes (mode collapse per class)
Conditional Batch Normalization (CBN)
Idea: Modulate batch normalization parameters based on the condition.
Standard Batch Normalization:

$$\text{BN}(\mathbf{h}) = \gamma \frac{\mathbf{h} - \mu}{\sigma} + \beta$$

Conditional Batch Normalization:

$$\text{CBN}(\mathbf{h}, y) = \gamma(y) \frac{\mathbf{h} - \mu}{\sigma} + \beta(y)$$
Where \(\gamma(y)\) and \(\beta(y)\) are predicted from condition \(y\) via small MLPs.
Benefits:
Efficient conditioning mechanism
Used in modern architectures (BigGAN, StyleGAN)
Allows fine-grained control over feature maps
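A minimal conditional batch norm layer can be sketched as follows, assuming per-class embeddings for \(\gamma(y)\) and \(\beta(y)\) (BigGAN-style models predict them with small MLPs instead):

```python
import torch
import torch.nn as nn

class ConditionalBatchNorm2d(nn.Module):
    """CBN(h, y) = gamma(y) * BN(h) + beta(y), with per-class gamma/beta."""
    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, affine=False)  # parameter-free BN
        self.gamma = nn.Embedding(num_classes, num_features)
        self.beta = nn.Embedding(num_classes, num_features)
        nn.init.ones_(self.gamma.weight)   # start as identity modulation
        nn.init.zeros_(self.beta.weight)

    def forward(self, h: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        out = self.bn(h)
        g = self.gamma(y).view(-1, out.size(1), 1, 1)
        b = self.beta(y).view(-1, out.size(1), 1, 1)
        return g * out + b

cbn = ConditionalBatchNorm2d(num_features=8, num_classes=10)
out = cbn(torch.randn(4, 8, 16, 16), torch.randint(0, 10, (4,)))
```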
Pix2Pix: Image-to-Image Translation
Architecture
Generator: U-Net
Encoder: Downsample input image (extract features)
Decoder: Upsample to output resolution
Skip connections: Preserve spatial information from encoder to decoder
Output: Generated image matching input resolution
Structure:
Encoder: C64-C128-C256-C512-C512-C512-C512
Decoder: CD512-CD512-CD512-C512-C256-C128-C64
Where:
C\(k\): Convolution-BatchNorm-ReLU with \(k\) filters
CD\(k\): Convolution-BatchNorm-Dropout-ReLU with \(k\) filters
Discriminator: PatchGAN
Evaluates whether \(N \times N\) patches are real/fake
Typical patch size: 70×70
Smaller receptive field → faster, more local
Output: spatial map of patch predictions (e.g., 30×30 for a 256×256 input)
Loss Function
Adversarial Loss (cGAN):

$$\mathcal{L}_{\text{cGAN}}(G, D) = \mathbb{E}_{x,y}[\log D(x, y)] + \mathbb{E}_{x,z}[\log(1 - D(x, G(x, z)))]$$

L1 Reconstruction Loss:

$$\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}[\|y - G(x, z)\|_1]$$

Total Generator Loss:

$$G^* = \arg\min_G \max_D \mathcal{L}_{\text{cGAN}}(G, D) + \lambda \mathcal{L}_{L1}(G)$$
Typical \(\lambda = 100\) (heavily weight reconstruction).
Why L1 instead of L2?
L1 encourages less blurring than L2
L2 averages all plausible outputs β blurry
L1 picks mode closer to ground truth
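As a sketch, the combined generator objective follows directly from the formulas above; the tensor names (`d_fake_patches`, `fake_img`, `real_img`) are illustrative stand-ins for a pix2pix batch, assuming a sigmoid-output PatchGAN like the one implemented later in this notebook:

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()  # matches a sigmoid-output PatchGAN
l1 = nn.L1Loss()

def generator_loss(d_fake_patches, fake_img, real_img, lam=100.0):
    """Adversarial term (fool D on every patch) + lambda-weighted L1 term."""
    adv = bce(d_fake_patches, torch.ones_like(d_fake_patches))
    rec = l1(fake_img, real_img)
    return adv + lam * rec

# Illustrative tensors standing in for a pix2pix batch
d_fake = torch.rand(4, 1, 30, 30).clamp(0.01, 0.99)  # PatchGAN probabilities
fake = torch.randn(4, 3, 256, 256)
real = torch.randn(4, 3, 256, 256)
loss = generator_loss(d_fake, fake, real)
```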
Training Strategy
Paired Data: Requires input-output pairs \((x, y)\)
Preprocessing: Jittering (resize to 286×286, random crop to 256×256)
Augmentation: Random flips
Optimizer: Adam with \(\beta_1 = 0.5\) (lower momentum for stability)
Learning Rate: \(2 \times 10^{-4}\), constant for first 100 epochs, then linear decay
CycleGAN: Unpaired Image-to-Image Translation
Motivation
Problem with pix2pix: Requires paired training data \((x, y)\), which is expensive or impossible to collect (e.g., photo → painting).
Solution: CycleGAN learns mappings between domains \(X\) and \(Y\) using unpaired data.
Cycle Consistency
Key Idea: If we translate \(x \rightarrow y \rightarrow x'\), we should get back \(x' \approx x\).
Cycle Consistency Loss:

$$\mathcal{L}_{\text{cyc}}(G, F) = \mathbb{E}_{x \sim X}[\|F(G(x)) - x\|_1] + \mathbb{E}_{y \sim Y}[\|G(F(y)) - y\|_1]$$
Where:
\(G: X \rightarrow Y\) (e.g., horse → zebra)
\(F: Y \rightarrow X\) (e.g., zebra → horse)
Interpretation:
Forward cycle: \(x \xrightarrow{G} y' \xrightarrow{F} x'\), minimize \(\|x - x'\|\)
Backward cycle: \(y \xrightarrow{F} x' \xrightarrow{G} y'\), minimize \(\|y - y'\|\)
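Both cycles can be computed in a few lines; `G` and `F` here are placeholder identity maps just to exercise the loss:

```python
import torch

def cycle_consistency_loss(G, F, x, y):
    """L_cyc = E[|F(G(x)) - x|_1] + E[|G(F(y)) - y|_1]."""
    forward = (F(G(x)) - x).abs().mean()   # X -> Y -> X
    backward = (G(F(y)) - y).abs().mean()  # Y -> X -> Y
    return forward + backward

# Placeholder identity 'translators' just to exercise the loss
G = lambda t: t
F = lambda t: t
x = torch.randn(2, 3, 8, 8)
y = torch.randn(2, 3, 8, 8)
loss = cycle_consistency_loss(G, F, x, y)  # identity maps give zero loss
```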
Full CycleGAN Objective
Adversarial Losses (two GANs):

$$\mathcal{L}_{\text{GAN}}(G, D_Y) = \mathbb{E}_{y}[\log D_Y(y)] + \mathbb{E}_{x}[\log(1 - D_Y(G(x)))]$$

$$\mathcal{L}_{\text{GAN}}(F, D_X) = \mathbb{E}_{x}[\log D_X(x)] + \mathbb{E}_{y}[\log(1 - D_X(F(y)))]$$

Total Objective:

$$\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{\text{GAN}}(G, D_Y) + \mathcal{L}_{\text{GAN}}(F, D_X) + \lambda \mathcal{L}_{\text{cyc}}(G, F)$$
Typical \(\lambda = 10\).
Identity Loss (optional, for color preservation):

$$\mathcal{L}_{\text{identity}}(G, F) = \mathbb{E}_{y}[\|G(y) - y\|_1] + \mathbb{E}_{x}[\|F(x) - x\|_1]$$
Encourages \(G\) to be close to identity when given real images from target domain.
Architecture
Generator: ResNet-based
Downsampling: 2 strided convolutions
Transformation: 6-9 residual blocks
Upsampling: 2 fractionally-strided convolutions
Output: tanh activation (range [-1, 1])
Discriminator: PatchGAN (70×70 patches)
Limitations
Mode Collapse: Can map all inputs to a single output
Geometric Changes: Struggles with large shape changes (horse → zebra works since it is mostly a texture change; cat → dog fails)
Texture vs. Shape: Often changes texture while preserving shape
StarGAN: Multi-Domain Translation
Motivation
Problem: CycleGAN requires a separate model for each domain pair (N domains → N(N−1) models).
Solution: StarGAN uses single generator for all domain translations.
Architecture
Generator:

$$G(\mathbf{x}, \mathbf{c}) \rightarrow \mathbf{y}$$
Where \(\mathbf{c}\) is target domain label (one-hot or binary vector).
Example (CelebA attributes):
\(\mathbf{c} = [1, 0, 1, 0]\) → male, no beard, smiling, not young
Discriminator:
Real/Fake prediction: \(D_{\text{src}}(\mathbf{x})\)
Domain classification: \(D_{\text{cls}}(\mathbf{x}) \rightarrow \mathbf{c}\)
Loss Functions
Adversarial Loss:

$$\mathcal{L}_{\text{adv}} = \mathbb{E}_{\mathbf{x}}[\log D_{\text{src}}(\mathbf{x})] + \mathbb{E}_{\mathbf{x}, \mathbf{c}}[\log(1 - D_{\text{src}}(G(\mathbf{x}, \mathbf{c})))]$$
Domain Classification Loss:
Real images: \(\mathcal{L}_{\text{cls}}^r = \mathbb{E}_{\mathbf{x}, \mathbf{c}'}[-\log D_{\text{cls}}(\mathbf{c}' | \mathbf{x})]\)
Fake images: \(\mathcal{L}_{\text{cls}}^f = \mathbb{E}_{\mathbf{x}, \mathbf{c}}[-\log D_{\text{cls}}(\mathbf{c} | G(\mathbf{x}, \mathbf{c}))]\)
Reconstruction Loss:

$$\mathcal{L}_{\text{rec}} = \mathbb{E}_{\mathbf{x}, \mathbf{c}, \mathbf{c}'}[\|\mathbf{x} - G(G(\mathbf{x}, \mathbf{c}), \mathbf{c}')\|_1]$$

Full Objective:

$$\mathcal{L}_D = -\mathcal{L}_{\text{adv}} + \lambda_{\text{cls}} \mathcal{L}_{\text{cls}}^r$$

$$\mathcal{L}_G = \mathcal{L}_{\text{adv}} + \lambda_{\text{cls}} \mathcal{L}_{\text{cls}}^f + \lambda_{\text{rec}} \mathcal{L}_{\text{rec}}$$
Pix2PixHD: High-Resolution Image Synthesis
Challenges at High Resolution
Training Instability: Harder to train GANs at 2048×1024 resolution
Mode Collapse: More severe at high resolution
Computational Cost: Memory and compute requirements
InnovationsΒΆ
1. Coarse-to-Fine Generator
\(G_1\): Global generator (low resolution)
\(G_2\): Local enhancer (high resolution)
Train \(G_1\) first, then fix and train \(G_2\)
2. Multi-Scale Discriminators
Use three discriminators at different scales:
\(D_1\): Original resolution
\(D_2\): Downsampled 2×
\(D_3\): Downsampled 4×
Each discriminator has same architecture (PatchGAN) but operates at different scales.
Benefit:
\(D_1\) focuses on fine details
\(D_3\) guides overall structure
Improved gradient signal
3. Feature Matching Loss
Instead of only fooling the discriminator, match its intermediate features:

$$\mathcal{L}_{\text{FM}}(G, D_k) = \mathbb{E}_{(x, y)} \sum_{i=1}^{T} \frac{1}{N_i} \left\|D_k^{(i)}(x, y) - D_k^{(i)}(x, G(x))\right\|_1$$

Where \(D_k^{(i)}\) is the \(i\)-th layer of discriminator \(D_k\), \(T\) the number of layers, and \(N_i\) the number of elements in layer \(i\).
Benefit: More stable training, less mode collapse.
SPADE: Semantic Image Synthesis
Spatially-Adaptive Normalization
Problem: Standard normalization (BatchNorm, InstanceNorm) washes out semantic information in segmentation masks.
Solution: SPADE (Spatially-Adaptive DEnormalization)
Standard Instance Normalization:

$$\gamma \frac{\mathbf{h} - \mu(\mathbf{h})}{\sigma(\mathbf{h})} + \beta$$

SPADE:

$$\gamma(\mathbf{m}) \odot \frac{\mathbf{h} - \mu(\mathbf{h})}{\sigma(\mathbf{h})} + \beta(\mathbf{m})$$
Where:
\(\mathbf{m}\): Semantic segmentation mask
\(\gamma(\mathbf{m}), \beta(\mathbf{m})\): Spatially-varying parameters predicted from \(\mathbf{m}\)
\(\odot\): Element-wise multiplication
Architecture:
m → Conv3×3 → ReLU → Conv3×3 → γ(m)
                  └→ Conv3×3 → β(m)
Benefit: Preserves semantic information while normalizing activations.
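The diagram above can be sketched as a small module; the hidden width and nearest-neighbor mask resizing below are assumptions for illustration, not the exact settings of the SPADE paper:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SPADE(nn.Module):
    """gamma(m) * norm(h) + beta(m), with gamma/beta predicted per pixel."""
    def __init__(self, num_features: int, mask_channels: int, hidden: int = 32):
        super().__init__()
        self.norm = nn.BatchNorm2d(num_features, affine=False)
        self.shared = nn.Sequential(
            nn.Conv2d(mask_channels, hidden, 3, padding=1), nn.ReLU())
        self.gamma = nn.Conv2d(hidden, num_features, 3, padding=1)
        self.beta = nn.Conv2d(hidden, num_features, 3, padding=1)

    def forward(self, h: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
        m = F.interpolate(m, size=h.shape[2:], mode='nearest')  # match feature size
        a = self.shared(m)
        return self.gamma(a) * self.norm(h) + self.beta(a)

spade = SPADE(num_features=16, mask_channels=5)
h = torch.randn(2, 16, 32, 32)   # intermediate generator features
m = torch.randn(2, 5, 64, 64)    # segmentation mask at higher resolution
out = spade(h, m)
```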
Applications
GauGAN: Segmentation mask → photorealistic image
Interactive editing: Draw segmentation, get realistic scene
Style control: Vary appearance while maintaining layout
Text-to-Image Generation
StackGAN
Idea: Generate images in two stages (low-res → high-res).
Stage-I:
Input: Text embedding \(\mathbf{t}\), noise \(\mathbf{z}\)
Output: 64Γ64 low-resolution image
Captures rough shape and color
Stage-II:
Input: Stage-I image + text embedding
Output: 256Γ256 high-resolution image
Adds details and fixes defects
Conditioning Augmentation:
Problem: Limited training data β overfitting to text embeddings.
Solution: Sample from a Gaussian conditioned on the text:

$$\mathbf{c} \sim \mathcal{N}(\mu(\mathbf{t}), \Sigma(\mathbf{t}))$$

Where \(\mu, \Sigma\) are learned.
Regularization:

$$\mathcal{L}_{\text{CA}} = D_{KL}(\mathcal{N}(\mu(\mathbf{t}), \Sigma(\mathbf{t})) \,\|\, \mathcal{N}(0, I))$$
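A sketch of conditioning augmentation with the reparameterization trick; the dimensions are illustrative:

```python
import torch
import torch.nn as nn

class ConditioningAugmentation(nn.Module):
    """Map text embedding t to (mu, logvar), sample c by reparameterization,
    and return the KL term against N(0, I)."""
    def __init__(self, text_dim: int, cond_dim: int):
        super().__init__()
        self.fc = nn.Linear(text_dim, 2 * cond_dim)

    def forward(self, t: torch.Tensor):
        mu, logvar = self.fc(t).chunk(2, dim=1)
        c = mu + torch.randn_like(mu) * (0.5 * logvar).exp()  # reparameterize
        kl = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1.0).sum(dim=1).mean()
        return c, kl

ca = ConditioningAugmentation(text_dim=256, cond_dim=128)
c, kl = ca(torch.randn(4, 256))
```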
AttnGAN
Innovation: Fine-grained attention mechanism.
Word-Level Attention:
Extract word features from text: \(\{\mathbf{w}_1, \ldots, \mathbf{w}_T\}\)
For each spatial location in image features, compute attention over words
Weighted combination of word features guides generation
Attention Mechanism:

$$\alpha_{ij} = \frac{\exp(\mathbf{h}_i^T \mathbf{w}_j)}{\sum_{k} \exp(\mathbf{h}_i^T \mathbf{w}_k)}$$
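The attention step can be sketched with batched matrix products; the shapes are illustrative:

```python
import torch

def word_attention(h: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """h: image region features (B, N, d); w: word features (B, T, d).
    Returns an attention-weighted word context for every region."""
    scores = torch.bmm(h, w.transpose(1, 2))  # (B, N, T) similarity h_i^T w_j
    alpha = torch.softmax(scores, dim=-1)     # normalize over words
    return torch.bmm(alpha, w)                # (B, N, d) context vectors

h = torch.randn(2, 64, 32)  # 64 spatial locations
w = torch.randn(2, 7, 32)   # 7 words
ctx = word_attention(h, w)
```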
Deep Attentional Multimodal Similarity Model (DAMSM):
Measures image-text matching at the word level:

$$\mathcal{L}_{\text{DAMSM}} = \log P(y | \mathbf{I}) + \log P(y | \mathbf{w}_{1:T})$$
Encourages generated images to be semantically consistent with input text.
Advanced Training Techniques for cGAN
Spectral Normalization
Problem: Training instability, especially for conditional models.
Solution: Constrain Lipschitz constant of discriminator.
Spectral Norm:

$$W_{\text{SN}} = \frac{W}{\sigma(W)}$$
Where \(\sigma(W)\) is largest singular value of weight matrix \(W\).
Implementation: Power iteration to estimate \(\sigma(W)\) efficiently.
Benefit:
More stable training
Better gradient flow
Can remove batch normalization from discriminator
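PyTorch ships this as `torch.nn.utils.spectral_norm`; a quick check that the normalized weight ends up with spectral norm close to 1 after a few power-iteration updates:

```python
import torch
import torch.nn as nn

# Wrap a layer; its weight is divided by a power-iteration estimate of sigma(W)
layer = nn.utils.spectral_norm(nn.Linear(64, 32))
x = torch.randn(8, 64)
for _ in range(30):  # each forward (in training mode) runs one power iteration
    _ = layer(x)

# The effective weight now has largest singular value close to 1
sigma = torch.linalg.matrix_norm(layer.weight.detach(), ord=2)
```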
Hinge Loss
Standard GAN (cross-entropy) discriminator loss:

$$\mathcal{L}_D = -\mathbb{E}[\log D(\mathbf{x})] - \mathbb{E}[\log(1 - D(G(\mathbf{z})))]$$

Hinge loss:

$$\mathcal{L}_D = \mathbb{E}[\max(0, 1 - D(\mathbf{x}))] + \mathbb{E}[\max(0, 1 + D(G(\mathbf{z})))]$$

$$\mathcal{L}_G = -\mathbb{E}[D(G(\mathbf{z}))]$$

Here \(D\) outputs raw scores (logits) rather than probabilities.
Benefits:
Better gradient signal
Used in state-of-the-art models (BigGAN, StyleGAN2)
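A hinge-loss sketch; note that it expects raw discriminator scores (logits), not sigmoid outputs:

```python
import torch

def d_hinge_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor:
    """Discriminator hinge loss on raw scores."""
    return (torch.relu(1.0 - real_logits).mean()
            + torch.relu(1.0 + fake_logits).mean())

def g_hinge_loss(fake_logits: torch.Tensor) -> torch.Tensor:
    return -fake_logits.mean()

real = torch.tensor([2.0, 0.5])   # one confident, one borderline real score
fake = torch.tensor([-2.0, 0.5])
d_loss = d_hinge_loss(real, fake)  # mean(0, 0.5) + mean(0, 1.5) = 0.25 + 0.75 = 1.0
g_loss = g_hinge_loss(fake)        # -mean(-2.0, 0.5) = 0.75
```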
Two Time-Scale Update Rule (TTUR)
Observation: The discriminator benefits from being updated on a faster time scale than the generator.
Solution: Use different learning rates:
Generator: \(\alpha_G = 10^{-4}\)
Discriminator: \(\alpha_D = 4 \times 10^{-4}\)
Benefit: Maintains balance between \(G\) and \(D\).
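In PyTorch, TTUR is simply two optimizers with different learning rates; the `betas=(0.0, 0.9)` setting below is an assumption borrowed from common TTUR setups, not a requirement:

```python
import torch
import torch.nn as nn

# Stand-in G and D; only the optimizer setup matters here
G = nn.Linear(100, 784)
D = nn.Linear(784, 1)

opt_g = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(0.0, 0.9))
opt_d = torch.optim.Adam(D.parameters(), lr=4e-4, betas=(0.0, 0.9))
```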
Gradient Penalty
R1 Regularization (on real data):

$$\mathcal{L}_{\text{R1}} = \frac{\gamma}{2} \mathbb{E}_{\mathbf{x}}\left[\|\nabla_{\mathbf{x}} D(\mathbf{x})\|^2\right]$$
Benefits:
Stabilizes training
Prevents discriminator from becoming too confident
Used in StyleGAN2
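A sketch of the R1 penalty; the toy linear discriminator is only for illustration:

```python
import torch
import torch.nn as nn

def r1_penalty(discriminator, real_imgs, gamma: float = 10.0):
    """(gamma / 2) * E[ ||grad_x D(x)||^2 ], evaluated on real samples only."""
    real_imgs = real_imgs.detach().requires_grad_(True)
    scores = discriminator(real_imgs).sum()
    (grad,) = torch.autograd.grad(scores, real_imgs, create_graph=True)
    return 0.5 * gamma * grad.pow(2).sum(dim=[1, 2, 3]).mean()

D = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))  # toy discriminator
penalty = r1_penalty(D, torch.randn(4, 3, 8, 8))
```

`create_graph=True` keeps the gradient differentiable, so the penalty can be backpropagated into the discriminator's parameters.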
Evaluation Metrics for Conditional Generation
Inception Score (IS)

$$\text{IS} = \exp\left(\mathbb{E}_{\mathbf{x} \sim p_G}\left[D_{KL}(p(y|\mathbf{x}) \,\|\, p(y))\right]\right)$$

Interpretation:
\(p(y|\mathbf{x})\) should be peaked (high quality)
\(p(y) = \mathbb{E}_{\mathbf{x}}[p(y|\mathbf{x})]\) should be uniform (diversity)
Limitations:
Only for class-conditional generation
Does not compare generated samples against real data statistics
Can be fooled by adversarial examples
Fréchet Inception Distance (FID)

$$\text{FID} = \|\mu_r - \mu_g\|_2^2 + \text{Tr}\left(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2}\right)$$

Where:
\((\mu_r, \Sigma_r)\): Mean and covariance of real image features
\((\mu_g, \Sigma_g)\): Mean and covariance of generated image features
Features: Inception-v3 pool3 layer
Benefits:
More robust than IS
Measures both quality and diversity
Correlates better with human judgment
Lower is better (unlike IS where higher is better).
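FID from precomputed feature statistics can be sketched in plain NumPy, using the equivalent symmetric form \(\text{Tr}((\Sigma_g^{1/2}\Sigma_r\Sigma_g^{1/2})^{1/2})\) to avoid a general (non-symmetric) matrix square root:

```python
import numpy as np

def _sqrtm_psd(a: np.ndarray) -> np.ndarray:
    """Matrix square root of a symmetric PSD matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(a)
    vals = np.clip(vals, 0.0, None)  # clip tiny negative eigenvalues
    return (vecs * np.sqrt(vals)) @ vecs.T

def fid_from_stats(mu_r, sigma_r, mu_g, sigma_g) -> float:
    diff = mu_r - mu_g
    sg_half = _sqrtm_psd(sigma_g)
    covmean = _sqrtm_psd(sg_half @ sigma_r @ sg_half)  # trace equals Tr((Sr Sg)^1/2)
    return float(diff @ diff + np.trace(sigma_r + sigma_g - 2.0 * covmean))

mu, sigma = np.zeros(4), np.eye(4)
fid = fid_from_stats(mu, sigma, mu, sigma)  # identical stats -> 0
```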
LPIPS (Learned Perceptual Image Patch Similarity)

$$\text{LPIPS}(\mathbf{x}, \mathbf{x}_0) = \sum_l \frac{1}{H_l W_l} \sum_{h,w} \left\| w_l \odot \left(\mathbf{F}_l^{hw}(\mathbf{x}) - \mathbf{F}_l^{hw}(\mathbf{x}_0)\right) \right\|_2^2$$

Where \(\mathbf{F}_l\) are unit-normalized features from layer \(l\) of VGG or AlexNet and \(w_l\) are learned channel weights.
Use Case: Paired image translation (pix2pix) - measures perceptual similarity.
User Studies
Amazon Mechanical Turk (AMT):
Real vs. Fake discrimination
Fooling rate (percentage of times humans are fooled)
Preference tests (A vs. B)
Limitations: Expensive, time-consuming, subjective.
Applications and Case Studies
1. Medical Imaging
MRI Synthesis: T1 → T2, CT → MRI
Segmentation: Image → organ masks
Data Augmentation: Generate synthetic medical images
Challenge: Limited paired data, high stakes (safety-critical).
2. Fashion and Retail
Virtual Try-On: Transfer clothing to different models
Attribute Editing: Change color, pattern, style
Sketch-to-Product: Design sketches → realistic renders
3. Video Game Development
Texture Generation: Semantic labels → textures
Scene Synthesis: Layout → photorealistic environments
Character Creation: Concept art → 3D models
4. Autonomous Driving
Domain Adaptation: Sim2Real transfer (synthetic → real)
Data Augmentation: Generate rare scenarios (rain, night)
Sensor Fusion: LiDAR → camera, camera → segmentation
5. Creative Tools
Photo Editing: Style transfer, colorization, inpainting
Art Generation: Text/sketch → artwork
Animation: Keyframe interpolation, motion transfer
Theoretical Analysis
Nash Equilibrium in cGAN
Definition: \((G^*, D^*)\) is a Nash equilibrium if:

$$G^* = \arg\min_G V(G, D^*), \quad D^* = \arg\max_D V(G^*, D)$$

Optimal Discriminator (given \(G\) and condition \(y\)):

$$D^*(\mathbf{x}, y) = \frac{p_{\text{data}}(\mathbf{x} | y)}{p_{\text{data}}(\mathbf{x} | y) + p_G(\mathbf{x} | y)}$$
Global Optimum: \(G^*\) such that \(p_G(\mathbf{x} | y) = p_{\text{data}}(\mathbf{x} | y)\) for all \(y\).
Conditional vs. Unconditional Learning
Theorem (Mirza & Osindero, 2014): Conditioning can improve learning efficiency.
Intuition:
Unconditional GAN must learn \(p(\mathbf{x})\) (entire data distribution)
Conditional GAN learns \(p(\mathbf{x} | y)\) (simpler conditional distributions)
Decomposition: \(p(\mathbf{x}) = \sum_y p(\mathbf{x} | y) p(y)\)
Benefit: Each conditional distribution can be simpler than full distribution.
Limitations and Challenges
1. Paired Data Requirement
Pix2pix: Needs aligned pairs \((x, y)\)
Solutions: CycleGAN (unpaired), self-supervision
2. Mode Collapse
Generator ignores condition, produces single output for all inputs
Mitigation: Minibatch discrimination, unrolled GAN, spectral normalization
3. Condition Leakage
Generator copies condition without transformation
Example: In pix2pix, generator may copy input edges to output
Solution: Careful architecture design, regularization
4. Evaluation Difficulty
No single metric captures all aspects (quality, diversity, conditioning)
Best Practice: Combine multiple metrics + user studies
5. Computational Cost
High-resolution generation requires large models
Solutions: Progressive growing, multi-scale training, efficient architectures
Future Directions
1. Diffusion Models for Conditional Generation
Classifier guidance: \(\nabla \log p(\mathbf{x} | y) \approx \nabla \log p(\mathbf{x}) + s \cdot \nabla \log p(y | \mathbf{x})\); classifier-free guidance achieves the same effect without training a separate classifier
Superior sample quality compared to GANs
Examples: DALL-E 2, Imagen, Stable Diffusion
2. Transformer-Based Architectures
Attention mechanisms for long-range dependencies
Examples: VQGAN + CLIP, Parti, Phenaki
3. 3D-Aware Conditional Generation
Generate 3D-consistent images from text/labels
Examples: EG3D, GET3D
4. Semantic Control
Fine-grained attribute editing
Disentangled representations
Compositional generation
5. Few-Shot and Zero-Shot Conditional Generation
Generate novel categories with few examples
Leverage pre-trained vision-language models (CLIP)
Conclusion
Conditional GANs extend the standard GAN framework with explicit control, enabling a wide range of applications from image-to-image translation to text-to-image synthesis. Key innovations include:
Architectural: U-Net, PatchGAN, projection discriminator, SPADE
Training: Spectral normalization, hinge loss, feature matching, gradient penalty
Loss Functions: Cycle consistency, reconstruction loss, perceptual loss
Evolution:
cGAN (2014) → Pix2pix (2017) → CycleGAN (2017) → Pix2pixHD (2018) → SPADE (2019)
Text-to-Image: StackGAN (2017) → AttnGAN (2018) → DALL-E (2021) → Stable Diffusion (2022)
Current State: While GANs pioneered conditional generation, diffusion models now dominate due to superior sample quality and training stability. However, GANs remain relevant for:
Real-time generation (faster inference)
Specific applications (super-resolution, style transfer)
Hybrid approaches (combining GANs with diffusion)
The principles of conditional generation (controllability, disentanglement, semantic consistency) continue to drive innovation in generative AI.
"""
Conditional GAN - Complete Implementation
This implementation includes:
1. Basic cGAN with label conditioning
2. Pix2Pix with U-Net generator and PatchGAN discriminator
3. Projection discriminator
4. AC-GAN (Auxiliary Classifier GAN)
5. Training utilities and visualization
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional, List
from dataclasses import dataclass
# ============================================================================
# 1. BASIC CONDITIONAL GAN (LABEL CONDITIONING)
# ============================================================================
class ConditionalGenerator(nn.Module):
"""
Basic conditional generator for class-conditional generation (e.g., MNIST).
Architecture:
- Concatenate noise z and label embedding
- MLP with multiple hidden layers
- Output image of specified size
Args:
latent_dim: Dimension of noise vector z
num_classes: Number of classes
embed_dim: Dimension of label embedding
img_channels: Number of image channels
img_size: Size of output image (assumes square)
"""
def __init__(
self,
latent_dim: int = 100,
num_classes: int = 10,
embed_dim: int = 100,
img_channels: int = 1,
img_size: int = 28
):
super().__init__()
self.latent_dim = latent_dim
self.num_classes = num_classes
self.img_channels = img_channels
self.img_size = img_size
# Label embedding
self.label_embedding = nn.Embedding(num_classes, embed_dim)
# Generator network
input_dim = latent_dim + embed_dim
self.model = nn.Sequential(
nn.Linear(input_dim, 256),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(256),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(512),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(1024),
nn.Linear(1024, img_channels * img_size * img_size),
nn.Tanh()
)
def forward(self, z: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""
Generate images conditioned on labels.
Args:
z: Noise vectors (batch_size, latent_dim)
labels: Class labels (batch_size,)
Returns:
Generated images (batch_size, img_channels, img_size, img_size)
"""
# Embed labels
label_embed = self.label_embedding(labels) # (batch_size, embed_dim)
# Concatenate noise and label embedding
gen_input = torch.cat([z, label_embed], dim=1) # (batch_size, latent_dim + embed_dim)
# Generate image
img = self.model(gen_input)
img = img.view(img.size(0), self.img_channels, self.img_size, self.img_size)
return img
class ConditionalDiscriminator(nn.Module):
"""
Basic conditional discriminator.
Architecture:
- Concatenate image and label embedding
- MLP with multiple hidden layers
- Output real/fake probability
Args:
num_classes: Number of classes
embed_dim: Dimension of label embedding
img_channels: Number of image channels
img_size: Size of input image
"""
def __init__(
self,
num_classes: int = 10,
embed_dim: int = 100,
img_channels: int = 1,
img_size: int = 28
):
super().__init__()
self.num_classes = num_classes
self.img_channels = img_channels
self.img_size = img_size
# Label embedding
self.label_embedding = nn.Embedding(num_classes, embed_dim)
# Discriminator network
input_dim = img_channels * img_size * img_size + embed_dim
self.model = nn.Sequential(
nn.Linear(input_dim, 1024),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, img: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""
Discriminate real/fake conditioned on labels.
Args:
img: Images (batch_size, img_channels, img_size, img_size)
labels: Class labels (batch_size,)
Returns:
Real/fake probabilities (batch_size, 1)
"""
# Flatten image
img_flat = img.view(img.size(0), -1) # (batch_size, img_channels * img_size * img_size)
# Embed labels
label_embed = self.label_embedding(labels) # (batch_size, embed_dim)
# Concatenate image and label
disc_input = torch.cat([img_flat, label_embed], dim=1)
# Discriminate
validity = self.model(disc_input)
return validity
# ============================================================================
# 2. PROJECTION DISCRIMINATOR
# ============================================================================
class ProjectionDiscriminator(nn.Module):
"""
Projection discriminator for class-conditional GAN.
    More efficient than concatenation:
        D(x, y) = w^T φ(x) + e(y)^T φ(x)   (returned as raw logits)
Args:
num_classes: Number of classes
img_channels: Number of image channels
img_size: Image size
hidden_dim: Hidden dimension
"""
def __init__(
self,
num_classes: int = 10,
img_channels: int = 1,
img_size: int = 28,
hidden_dim: int = 512
):
super().__init__()
        # Feature extractor φ(x)
self.feature_extractor = nn.Sequential(
nn.Flatten(),
nn.Linear(img_channels * img_size * img_size, 1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, hidden_dim),
nn.LeakyReLU(0.2)
)
        # Real/fake classifier: w^T φ(x)
self.classifier = nn.Linear(hidden_dim, 1)
        # Projection head: e(y)^T φ(x)
self.projection = nn.Embedding(num_classes, hidden_dim)
def forward(self, img: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""
Args:
img: Images (batch_size, img_channels, img_size, img_size)
labels: Class labels (batch_size,)
Returns:
Logits (batch_size, 1)
"""
# Extract features
features = self.feature_extractor(img) # (batch_size, hidden_dim)
# Real/fake prediction
out = self.classifier(features) # (batch_size, 1)
# Projection (conditioning)
label_embed = self.projection(labels) # (batch_size, hidden_dim)
proj = torch.sum(features * label_embed, dim=1, keepdim=True) # (batch_size, 1)
        # Combine into a single logit (pair with BCEWithLogitsLoss or hinge loss)
        logits = out + proj
        return logits
# ============================================================================
# 3. AUXILIARY CLASSIFIER GAN (AC-GAN)
# ============================================================================
class ACGANGenerator(nn.Module):
"""
AC-GAN Generator.
Similar to basic cGAN but trained with auxiliary classification loss.
"""
def __init__(
self,
latent_dim: int = 100,
num_classes: int = 10,
img_channels: int = 1,
img_size: int = 28
):
super().__init__()
self.latent_dim = latent_dim
self.num_classes = num_classes
# Label embedding
self.label_embedding = nn.Embedding(num_classes, num_classes)
# Generator
input_dim = latent_dim + num_classes
self.model = nn.Sequential(
nn.Linear(input_dim, 256),
nn.BatchNorm1d(256),
nn.ReLU(),
nn.Linear(256, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, 1024),
nn.BatchNorm1d(1024),
nn.ReLU(),
nn.Linear(1024, img_channels * img_size * img_size),
nn.Tanh()
)
self.img_channels = img_channels
self.img_size = img_size
def forward(self, z: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
label_embed = self.label_embedding(labels)
gen_input = torch.cat([z, label_embed], dim=1)
img = self.model(gen_input)
return img.view(img.size(0), self.img_channels, self.img_size, self.img_size)
class ACGANDiscriminator(nn.Module):
"""
AC-GAN Discriminator with auxiliary classifier.
Outputs:
- Real/fake probability
- Class probabilities
"""
def __init__(
self,
num_classes: int = 10,
img_channels: int = 1,
img_size: int = 28
):
super().__init__()
# Shared feature extractor
self.features = nn.Sequential(
nn.Flatten(),
nn.Linear(img_channels * img_size * img_size, 1024),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.3)
)
# Real/fake head
self.adv_head = nn.Sequential(
nn.Linear(512, 1),
nn.Sigmoid()
)
# Auxiliary classifier head
self.aux_head = nn.Sequential(
nn.Linear(512, num_classes),
nn.Softmax(dim=1)
)
def forward(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
img: Images (batch_size, img_channels, img_size, img_size)
Returns:
validity: Real/fake probabilities (batch_size, 1)
class_probs: Class probabilities (batch_size, num_classes)
"""
features = self.features(img)
validity = self.adv_head(features)
class_probs = self.aux_head(features)
return validity, class_probs
# ============================================================================
# 4. PIX2PIX COMPONENTS
# ============================================================================
class UNetGenerator(nn.Module):
"""
U-Net generator for pix2pix.
Architecture:
- Encoder: Downsample input image
- Decoder: Upsample with skip connections from encoder
- Output: Image of same size as input
Args:
in_channels: Input image channels
out_channels: Output image channels
features: Base number of features (doubled at each layer)
"""
def __init__(self, in_channels: int = 3, out_channels: int = 3, features: int = 64):
super().__init__()
# Encoder (downsampling)
        self.enc1 = self._block(in_channels, features, normalize=False)  # 256 → 128
        self.enc2 = self._block(features, features * 2)  # 128 → 64
        self.enc3 = self._block(features * 2, features * 4)  # 64 → 32
        self.enc4 = self._block(features * 4, features * 8)  # 32 → 16
        self.enc5 = self._block(features * 8, features * 8)  # 16 → 8
        self.enc6 = self._block(features * 8, features * 8)  # 8 → 4
        self.enc7 = self._block(features * 8, features * 8)  # 4 → 2
# Bottleneck
self.bottleneck = nn.Sequential(
nn.Conv2d(features * 8, features * 8, 4, 2, 1),
nn.ReLU()
        )  # 2 → 1
# Decoder (upsampling with skip connections)
self.dec1 = self._up_block(features * 8, features * 8, dropout=True)
self.dec2 = self._up_block(features * 16, features * 8, dropout=True)
self.dec3 = self._up_block(features * 16, features * 8, dropout=True)
self.dec4 = self._up_block(features * 16, features * 8)
self.dec5 = self._up_block(features * 16, features * 4)
self.dec6 = self._up_block(features * 8, features * 2)
self.dec7 = self._up_block(features * 4, features)
# Final layer
self.final = nn.Sequential(
nn.ConvTranspose2d(features * 2, out_channels, 4, 2, 1),
nn.Tanh()
)
def _block(self, in_channels: int, out_channels: int, normalize: bool = True):
"""Encoder block: Conv β BatchNorm β LeakyReLU β Downsample"""
layers = [nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False)]
if normalize:
layers.append(nn.BatchNorm2d(out_channels))
layers.append(nn.LeakyReLU(0.2))
return nn.Sequential(*layers)
def _up_block(self, in_channels: int, out_channels: int, dropout: bool = False):
"""Decoder block: ConvTranspose β BatchNorm β Dropout β ReLU β Upsample"""
layers = [
nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
nn.BatchNorm2d(out_channels)
]
if dropout:
layers.append(nn.Dropout(0.5))
layers.append(nn.ReLU())
return nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Input image (batch_size, in_channels, H, W)
Returns:
Generated image (batch_size, out_channels, H, W)
"""
# Encoder with skip connections
e1 = self.enc1(x)
e2 = self.enc2(e1)
e3 = self.enc3(e2)
e4 = self.enc4(e3)
e5 = self.enc5(e4)
e6 = self.enc6(e5)
e7 = self.enc7(e6)
# Bottleneck
b = self.bottleneck(e7)
# Decoder with skip connections (concatenate)
d1 = self.dec1(b)
d1 = torch.cat([d1, e7], dim=1)
d2 = self.dec2(d1)
d2 = torch.cat([d2, e6], dim=1)
d3 = self.dec3(d2)
d3 = torch.cat([d3, e5], dim=1)
d4 = self.dec4(d3)
d4 = torch.cat([d4, e4], dim=1)
d5 = self.dec5(d4)
d5 = torch.cat([d5, e3], dim=1)
d6 = self.dec6(d5)
d6 = torch.cat([d6, e2], dim=1)
d7 = self.dec7(d6)
d7 = torch.cat([d7, e1], dim=1)
# Final output
return self.final(d7)
class PatchGANDiscriminator(nn.Module):
"""
PatchGAN discriminator for pix2pix.
Classifies whether N×N patches are real or fake (70×70 receptive field).
Args:
in_channels: Input channels (img_channels * 2 for paired images)
features: Base number of features
"""
def __init__(self, in_channels: int = 6, features: int = 64):
super().__init__()
self.model = nn.Sequential(
# C64: No BatchNorm in first layer
nn.Conv2d(in_channels, features, 4, 2, 1),
nn.LeakyReLU(0.2),
# C128
nn.Conv2d(features, features * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(features * 2),
nn.LeakyReLU(0.2),
# C256
nn.Conv2d(features * 2, features * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(features * 4),
nn.LeakyReLU(0.2),
# C512
nn.Conv2d(features * 4, features * 8, 4, 1, 1, bias=False),
nn.BatchNorm2d(features * 8),
nn.LeakyReLU(0.2),
# Output: 1 channel (real/fake for each patch)
nn.Conv2d(features * 8, 1, 4, 1, 1),
nn.Sigmoid()
)
def forward(self, img_input: torch.Tensor, img_target: torch.Tensor) -> torch.Tensor:
"""
Args:
img_input: Input image (batch_size, C, H, W)
img_target: Target image (batch_size, C, H, W)
Returns:
Patch predictions (batch_size, 1, H/16, W/16)
"""
# Concatenate input and target
x = torch.cat([img_input, img_target], dim=1)
return self.model(x)
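The 70×70 figure in the docstring above can be verified from the layer configuration alone. A small sanity check (our own helper, not part of the pix2pix code) computes the receptive field of the five 4×4 convolutions, with strides 2, 2, 2, 1, 1:

```python
# Sanity check for the PatchGAN docstring's 70x70 receptive-field claim.
# Each entry is (kernel_size, stride) for one conv layer, in forward order.
def receptive_field(layers):
    rf, jump = 1, 1
    for k, s in layers:
        rf += (k - 1) * jump  # each layer widens the field by (k-1) effective strides
        jump *= s             # effective stride of the stack so far
    return rf

print(receptive_field([(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]))  # 70
```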
# ============================================================================
# 5. TRAINING UTILITIES
# ============================================================================
def train_cgan_step(
generator: nn.Module,
discriminator: nn.Module,
real_imgs: torch.Tensor,
labels: torch.Tensor,
optimizer_g: torch.optim.Optimizer,
optimizer_d: torch.optim.Optimizer,
latent_dim: int = 100,
device: str = 'cpu'
) -> Tuple[float, float]:
"""
Single training step for basic conditional GAN.
Returns:
d_loss: Discriminator loss
g_loss: Generator loss
"""
batch_size = real_imgs.size(0)
# Real and fake labels
real_labels = torch.ones(batch_size, 1, device=device)
fake_labels = torch.zeros(batch_size, 1, device=device)
# ---------------------
# Train Discriminator
# ---------------------
optimizer_d.zero_grad()
# Real images
real_validity = discriminator(real_imgs, labels)
d_real_loss = F.binary_cross_entropy(real_validity, real_labels)
# Fake images
z = torch.randn(batch_size, latent_dim, device=device)
fake_imgs = generator(z, labels)
fake_validity = discriminator(fake_imgs.detach(), labels)
d_fake_loss = F.binary_cross_entropy(fake_validity, fake_labels)
# Total discriminator loss
d_loss = (d_real_loss + d_fake_loss) / 2
d_loss.backward()
optimizer_d.step()
# -----------------
# Train Generator
# -----------------
optimizer_g.zero_grad()
# Generate fake images and fool discriminator
z = torch.randn(batch_size, latent_dim, device=device)
fake_imgs = generator(z, labels)
fake_validity = discriminator(fake_imgs, labels)
g_loss = F.binary_cross_entropy(fake_validity, real_labels)
g_loss.backward()
optimizer_g.step()
return d_loss.item(), g_loss.item()
def train_acgan_step(
generator: nn.Module,
discriminator: nn.Module,
real_imgs: torch.Tensor,
labels: torch.Tensor,
optimizer_g: torch.optim.Optimizer,
optimizer_d: torch.optim.Optimizer,
latent_dim: int = 100,
device: str = 'cpu'
) -> Tuple[float, float]:
"""
Single training step for AC-GAN.
Returns:
d_loss: Discriminator loss (adversarial + classification)
g_loss: Generator loss (adversarial + classification)
"""
batch_size = real_imgs.size(0)
real_labels = torch.ones(batch_size, 1, device=device)
fake_labels = torch.zeros(batch_size, 1, device=device)
# ---------------------
# Train Discriminator
# ---------------------
optimizer_d.zero_grad()
# Real images
real_validity, real_class = discriminator(real_imgs)
d_real_loss = F.binary_cross_entropy(real_validity, real_labels)
d_real_cls_loss = F.cross_entropy(real_class, labels)
# Fake images
z = torch.randn(batch_size, latent_dim, device=device)
fake_imgs = generator(z, labels)
fake_validity, fake_class = discriminator(fake_imgs.detach())
d_fake_loss = F.binary_cross_entropy(fake_validity, fake_labels)
d_fake_cls_loss = F.cross_entropy(fake_class, labels)
# Total discriminator loss
d_loss = (d_real_loss + d_fake_loss) / 2 + (d_real_cls_loss + d_fake_cls_loss) / 2
d_loss.backward()
optimizer_d.step()
# -----------------
# Train Generator
# -----------------
optimizer_g.zero_grad()
z = torch.randn(batch_size, latent_dim, device=device)
fake_imgs = generator(z, labels)
fake_validity, fake_class = discriminator(fake_imgs)
g_loss = F.binary_cross_entropy(fake_validity, real_labels) + F.cross_entropy(fake_class, labels)
g_loss.backward()
optimizer_g.step()
return d_loss.item(), g_loss.item()
def train_pix2pix_step(
generator: nn.Module,
discriminator: nn.Module,
input_imgs: torch.Tensor,
target_imgs: torch.Tensor,
optimizer_g: torch.optim.Optimizer,
optimizer_d: torch.optim.Optimizer,
lambda_l1: float = 100.0,
device: str = 'cpu'
) -> Tuple[float, float]:
"""
Single training step for pix2pix.
Args:
lambda_l1: Weight for L1 reconstruction loss
Returns:
d_loss: Discriminator loss
g_loss: Generator loss (adversarial + L1)
"""
# ---------------------
# Train Discriminator
# ---------------------
optimizer_d.zero_grad()
# Real pair (the PatchGAN outputs a grid of per-patch predictions, so targets
# are built with ones_like/zeros_like to match the patch-map shape; expanding
# a (batch, 1) tensor to 4-D would fail)
real_validity = discriminator(input_imgs, target_imgs)
d_real_loss = F.binary_cross_entropy(real_validity, torch.ones_like(real_validity))
# Fake pair
fake_imgs = generator(input_imgs)
fake_validity = discriminator(input_imgs, fake_imgs.detach())
d_fake_loss = F.binary_cross_entropy(fake_validity, torch.zeros_like(fake_validity))
d_loss = (d_real_loss + d_fake_loss) / 2
d_loss.backward()
optimizer_d.step()
# -----------------
# Train Generator
# -----------------
optimizer_g.zero_grad()
fake_imgs = generator(input_imgs)
fake_validity = discriminator(input_imgs, fake_imgs)
# Adversarial loss
g_adv_loss = F.binary_cross_entropy(fake_validity, torch.ones_like(fake_validity))
# L1 reconstruction loss
g_l1_loss = F.l1_loss(fake_imgs, target_imgs)
# Total generator loss
g_loss = g_adv_loss + lambda_l1 * g_l1_loss
g_loss.backward()
optimizer_g.step()
return d_loss.item(), g_loss.item()
# ============================================================================
# 6. VISUALIZATION
# ============================================================================
def visualize_conditional_generation(
generator: nn.Module,
num_classes: int,
latent_dim: int,
device: str = 'cpu',
samples_per_class: int = 5
):
"""
Visualize conditional generation for all classes.
Args:
generator: Trained generator
num_classes: Number of classes
latent_dim: Latent dimension
device: Device
samples_per_class: Samples to generate per class
"""
generator.eval()
fig, axes = plt.subplots(num_classes, samples_per_class, figsize=(samples_per_class * 2, num_classes * 2))
with torch.no_grad():
for class_idx in range(num_classes):
# Generate samples for this class
z = torch.randn(samples_per_class, latent_dim, device=device)
labels = torch.full((samples_per_class,), class_idx, dtype=torch.long, device=device)
fake_imgs = generator(z, labels)
fake_imgs = fake_imgs.cpu()
# Plot samples
for sample_idx in range(samples_per_class):
ax = axes[class_idx, sample_idx] if num_classes > 1 else axes[sample_idx]
img = fake_imgs[sample_idx].squeeze()
if img.dim() == 3: # Color image
img = img.permute(1, 2, 0)
img = (img + 1) / 2 # Denormalize from [-1, 1] to [0, 1]
ax.imshow(img, cmap='gray' if img.dim() == 2 else None)
ax.axis('off')
if sample_idx == 0:
ax.set_title(f'Class {class_idx}', fontsize=10)
plt.tight_layout()
plt.suptitle('Conditional Generation (All Classes)', y=1.01, fontsize=14, fontweight='bold')
plt.show()
def visualize_pix2pix_results(
generator: nn.Module,
input_imgs: torch.Tensor,
target_imgs: torch.Tensor,
num_samples: int = 4
):
"""
Visualize pix2pix results (input → generated → target).
Args:
generator: Trained generator
input_imgs: Input images (batch_size, C, H, W)
target_imgs: Target images (batch_size, C, H, W)
num_samples: Number of samples to visualize
"""
generator.eval()
with torch.no_grad():
generated_imgs = generator(input_imgs[:num_samples])
# Move to CPU and denormalize
input_imgs = input_imgs[:num_samples].cpu()
generated_imgs = generated_imgs.cpu()
target_imgs = target_imgs[:num_samples].cpu()
# Denormalize from [-1, 1] to [0, 1]
input_imgs = (input_imgs + 1) / 2
generated_imgs = (generated_imgs + 1) / 2
target_imgs = (target_imgs + 1) / 2
fig, axes = plt.subplots(num_samples, 3, figsize=(9, num_samples * 3))
for i in range(num_samples):
# Input
axes[i, 0].imshow(input_imgs[i].permute(1, 2, 0))
axes[i, 0].set_title('Input')
axes[i, 0].axis('off')
# Generated
axes[i, 1].imshow(generated_imgs[i].permute(1, 2, 0))
axes[i, 1].set_title('Generated')
axes[i, 1].axis('off')
# Target
axes[i, 2].imshow(target_imgs[i].permute(1, 2, 0))
axes[i, 2].set_title('Target')
axes[i, 2].axis('off')
plt.tight_layout()
plt.show()
# ============================================================================
# 7. DEMONSTRATION
# ============================================================================
def demo_conditional_gan():
"""Demonstrate conditional GAN components."""
print("=" * 80)
print("Conditional GAN Demonstration")
print("=" * 80)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"\nDevice: {device}")
# 1. Basic cGAN
print("\n1. Basic Conditional GAN (Label Conditioning)")
print("-" * 40)
latent_dim = 100
num_classes = 10
img_size = 28
generator = ConditionalGenerator(
latent_dim=latent_dim,
num_classes=num_classes,
img_size=img_size
).to(device)
discriminator = ConditionalDiscriminator(
num_classes=num_classes,
img_size=img_size
).to(device)
print(f"Generator parameters: {sum(p.numel() for p in generator.parameters()):,}")
print(f"Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}")
# Test forward pass
z = torch.randn(4, latent_dim, device=device)
labels = torch.tensor([0, 1, 2, 3], device=device)
fake_imgs = generator(z, labels)
validity = discriminator(fake_imgs, labels)
print(f"\nGenerated images shape: {fake_imgs.shape}")
print(f"Discriminator output shape: {validity.shape}")
print(f"Discriminator output range: [{validity.min():.3f}, {validity.max():.3f}]")
# 2. Projection Discriminator
print("\n2. Projection Discriminator")
print("-" * 40)
proj_disc = ProjectionDiscriminator(num_classes=num_classes).to(device)
print(f"Projection discriminator parameters: {sum(p.numel() for p in proj_disc.parameters()):,}")
validity_proj = proj_disc(fake_imgs, labels)
print(f"Projection discriminator output shape: {validity_proj.shape}")
# 3. AC-GAN
print("\n3. Auxiliary Classifier GAN (AC-GAN)")
print("-" * 40)
acgan_gen = ACGANGenerator(latent_dim=latent_dim, num_classes=num_classes).to(device)
acgan_disc = ACGANDiscriminator(num_classes=num_classes).to(device)
print(f"AC-GAN Generator parameters: {sum(p.numel() for p in acgan_gen.parameters()):,}")
print(f"AC-GAN Discriminator parameters: {sum(p.numel() for p in acgan_disc.parameters()):,}")
fake_imgs_ac = acgan_gen(z, labels)
validity_ac, class_probs_ac = acgan_disc(fake_imgs_ac)
print(f"\nAC-GAN validity shape: {validity_ac.shape}")
print(f"AC-GAN class probabilities shape: {class_probs_ac.shape}")
print(f"Class probabilities sum: {class_probs_ac.sum(dim=1)}")
# 4. Pix2Pix (U-Net + PatchGAN)
print("\n4. Pix2Pix (Image-to-Image Translation)")
print("-" * 40)
unet_gen = UNetGenerator(in_channels=3, out_channels=3, features=64).to(device)
patch_disc = PatchGANDiscriminator(in_channels=6, features=64).to(device)
print(f"U-Net Generator parameters: {sum(p.numel() for p in unet_gen.parameters()):,}")
print(f"PatchGAN Discriminator parameters: {sum(p.numel() for p in patch_disc.parameters()):,}")
# Test with larger images
input_img = torch.randn(2, 3, 256, 256, device=device)
generated_img = unet_gen(input_img)
patch_validity = patch_disc(input_img, generated_img)
print(f"\nInput image shape: {input_img.shape}")
print(f"Generated image shape: {generated_img.shape}")
print(f"Patch predictions shape: {patch_validity.shape}")
print(f"Number of patches: {patch_validity.numel() // patch_validity.size(0)}")
def demo_training_comparison():
"""Compare different conditioning strategies."""
print("\n" + "=" * 80)
print("Conditioning Strategy Comparison")
print("=" * 80)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Create sample data
batch_size = 16
latent_dim = 100
num_classes = 10
img_size = 28
# 1. Concatenation (Basic cGAN)
print("\n1. Concatenation Strategy")
print("-" * 40)
gen_concat = ConditionalGenerator(latent_dim, num_classes).to(device)
disc_concat = ConditionalDiscriminator(num_classes).to(device)
z = torch.randn(batch_size, latent_dim, device=device)
labels = torch.randint(0, num_classes, (batch_size,), device=device)
real_imgs = torch.randn(batch_size, 1, img_size, img_size, device=device)
fake_imgs = gen_concat(z, labels)
validity = disc_concat(fake_imgs, labels)
print(f"Generated images: {fake_imgs.shape}")
print(f"Validity: {validity.shape}")
# 2. Projection
print("\n2. Projection Strategy")
print("-" * 40)
disc_proj = ProjectionDiscriminator(num_classes).to(device)
validity_proj = disc_proj(fake_imgs, labels)
print(f"Projection validity: {validity_proj.shape}")
# Parameter comparison
print("\n3. Parameter Efficiency")
print("-" * 40)
concat_params = sum(p.numel() for p in disc_concat.parameters())
proj_params = sum(p.numel() for p in disc_proj.parameters())
print(f"Concatenation discriminator: {concat_params:,} parameters")
print(f"Projection discriminator: {proj_params:,} parameters")
print(f"Reduction: {(1 - proj_params / concat_params) * 100:.1f}%")
# Run demonstrations
if __name__ == "__main__":
print("Starting Conditional GAN demonstrations...\n")
# Main demonstration
demo_conditional_gan()
# Training comparison
demo_training_comparison()
print("\n" + "=" * 80)
print("Conditional GAN Implementation Complete!")
print("=" * 80)
print("\nKey Components Implemented:")
print("β Basic cGAN with label conditioning")
print("β Projection discriminator (efficient conditioning)")
print("β AC-GAN (auxiliary classifier)")
print("β U-Net generator for image-to-image translation")
print("β PatchGAN discriminator")
print("β Training utilities for all variants")
print("β Visualization functions")
print("\nExtensions to explore:")
print("- CycleGAN for unpaired translation")
print("- StarGAN for multi-domain translation")
print("- SPADE for semantic image synthesis")
print("- Text-to-image with attention (AttnGAN)")
print("- High-resolution synthesis (Pix2PixHD)")
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import seaborn as sns
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
sns.set_style('whitegrid')
1. Motivation: Controlled Generation
Standard GAN Limitation
Vanilla GAN: \(\min_G \max_D \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]\)
Problem: No control over what gets generated!
Random samples from learned distribution
Cannot specify desired attributes
Conditional GAN Solution (Mirza & Osindero, 2014)
Add conditioning variable \(y\) (class label, attributes, etc.):
\(\min_G \max_D \mathbb{E}_{x \sim p_{data}}[\log D(x, y)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z, y), y))]\)
Key changes:
Generator: \(G(z, y)\) produces \(x\) conditioned on \(y\)
Discriminator: \(D(x, y)\) checks if \(x\) matches condition \(y\)
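In practice, each expectation in the objective is optimized as binary cross-entropy against hard targets. A minimal check that the BCE terms equal the negated log terms (illustrative probabilities, assuming a sigmoid-output discriminator):

```python
import torch
import torch.nn.functional as F

# Maximizing log D(x, y) over real pairs is minimizing BCE(D(x, y), 1);
# maximizing log(1 - D(G(z, y), y)) is minimizing BCE(D(fake, y), 0).
# The probabilities below are illustrative, not from a trained D.
d_real = torch.tensor([0.9, 0.8])   # D(x, y) on real pairs
d_fake = torch.tensor([0.2, 0.1])   # D(G(z, y), y) on fake pairs

bce_real = F.binary_cross_entropy(d_real, torch.ones_like(d_real))
bce_fake = F.binary_cross_entropy(d_fake, torch.zeros_like(d_fake))

assert torch.allclose(bce_real, -torch.log(d_real).mean())
assert torch.allclose(bce_fake, -torch.log(1 - d_fake).mean())
```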
Applications
Class-conditional: Generate specific digit/object
Text-to-image: Generate image from description
Image-to-image: Style transfer, super-resolution
Attribute control: Age, expression, hair color
Reference Materials:
gan.pdf - GANs
generative_models.pdf - Generative Models
2. Architecture Design
Conditioning Methods
1. Concatenation:
Embed \(y\) and concatenate with \(z\) or \(x\)
Simple, widely used
2. Conditional Batch Normalization:
Modulate BN parameters based on \(y\)
More powerful for complex conditioning
3. Projection:
Project \(y\) embedding into discriminator layers
Better gradient flow
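Of the three methods above, conditional batch normalization is the only one not implemented elsewhere in this notebook, so a minimal sketch follows. The class name `ConditionalBatchNorm2d` and the per-class gain/bias lookup are our own rendering of the common recipe (normalize without affine parameters, then apply a class-specific gamma and beta):

```python
import torch
import torch.nn as nn

class ConditionalBatchNorm2d(nn.Module):
    """BatchNorm whose affine parameters (gamma, beta) are looked up per class."""
    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, affine=False)  # normalization only
        self.embed = nn.Embedding(num_classes, num_features * 2)
        # Initialize gamma near 1 and beta near 0 so conditioning starts as identity
        self.embed.weight.data[:, :num_features].fill_(1.0)
        self.embed.weight.data[:, num_features:].zero_()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        gamma, beta = self.embed(y).chunk(2, dim=1)
        return gamma[:, :, None, None] * self.bn(x) + beta[:, :, None, None]

cbn = ConditionalBatchNorm2d(64, num_classes=10)
h = torch.randn(4, 64, 8, 8)
y = torch.tensor([0, 3, 7, 9])
print(cbn(h, y).shape)  # torch.Size([4, 64, 8, 8])
```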
Class-Conditional Architecture
For one-hot encoded labels \(y \in \{0, 1\}^C\):
Generator: \(h = [z; \text{Embed}(y)]\), then \(x = G(h)\)
Discriminator: \(h = [x; \text{Embed}(y)]\), then \(\text{score} = D(h)\)
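The concatenation pattern can be traced through tensor shapes with MNIST-sized inputs (a learned 10-d embedding standing in for \(\text{Embed}(y)\), matching the implementation that follows):

```python
import torch
import torch.nn as nn

# Shape trace for the concatenation pattern: embed y, then stack with z (or x)
latent_dim, n_classes, img_size = 100, 10, 28
embed = nn.Embedding(n_classes, n_classes)   # Embed(y): 10-d learned embedding

z = torch.randn(16, latent_dim)              # noise batch
y = torch.randint(0, n_classes, (16,))       # class labels
gen_input = torch.cat([z, embed(y)], dim=1)  # h = [z; Embed(y)]
print(gen_input.shape)                       # torch.Size([16, 110])

x_flat = torch.randn(16, img_size * img_size)
disc_input = torch.cat([x_flat, embed(y)], dim=1)  # h = [x; Embed(y)]
print(disc_input.shape)                      # torch.Size([16, 794])
```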
Conditional GAN for MNIST
A Conditional GAN (cGAN) augments both the generator and discriminator with class label information, enabling controlled generation of specific digit classes. The generator receives a noise vector \(z\) concatenated with an embedding of the label \(y\), and learns the conditional distribution \(p(x | y)\). The discriminator similarly receives the label as additional input and must judge whether an (image, label) pair is real or fabricated. Conditioning on labels transforms the GAN from an unconditional density estimator into a powerful tool for class-specific synthesis, which is the foundation for applications like text-to-image generation and data augmentation.
class ConditionalGenerator(nn.Module):
"""Generator conditioned on class labels."""
def __init__(self, latent_dim=100, n_classes=10, img_size=28):
super().__init__()
self.latent_dim = latent_dim
self.n_classes = n_classes
self.img_size = img_size
# Label embedding
self.label_emb = nn.Embedding(n_classes, n_classes)
# Generator network
self.model = nn.Sequential(
nn.Linear(latent_dim + n_classes, 256),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(256),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(512),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(1024),
nn.Linear(1024, img_size * img_size),
nn.Tanh()
)
def forward(self, z, labels):
# Embed labels
label_input = self.label_emb(labels)
# Concatenate noise and labels
gen_input = torch.cat([z, label_input], dim=1)
# Generate image
img = self.model(gen_input)
img = img.view(img.size(0), 1, self.img_size, self.img_size)
return img
class ConditionalDiscriminator(nn.Module):
"""Discriminator conditioned on class labels."""
def __init__(self, n_classes=10, img_size=28):
super().__init__()
self.n_classes = n_classes
self.img_size = img_size
# Label embedding
self.label_emb = nn.Embedding(n_classes, n_classes)
# Discriminator network
self.model = nn.Sequential(
nn.Linear(img_size * img_size + n_classes, 1024),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, img, labels):
# Flatten image
img_flat = img.view(img.size(0), -1)
# Embed labels
label_input = self.label_emb(labels)
# Concatenate image and labels
disc_input = torch.cat([img_flat, label_input], dim=1)
# Discriminate
validity = self.model(disc_input)
return validity
# Initialize models
latent_dim = 100
n_classes = 10
generator = ConditionalGenerator(latent_dim, n_classes).to(device)
discriminator = ConditionalDiscriminator(n_classes).to(device)
print(f"Generator parameters: {sum(p.numel() for p in generator.parameters()):,}")
print(f"Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}")
Training Conditional GAN
The training loop alternates between discriminator and generator updates, following the standard GAN protocol but with labels passed to both networks. The discriminator loss encourages it to accept real (image, label) pairs and reject fake ones, while the generator loss pushes it to produce images that the discriminator accepts for a given label. Tracking per-class generation quality during training (e.g., via FID or visual inspection per digit) helps ensure the model does not mode-collapse to only generating a subset of classes. Label smoothing and spectral normalization are common stabilization techniques used in practice.
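Neither label smoothing nor spectral normalization appears in the loop below; a hedged sketch of how each would be wired in (a one-sided smoothing target of 0.9 is a common choice, not the only one):

```python
import torch
import torch.nn as nn

# One-sided label smoothing: train D against 0.9 instead of 1.0 for real pairs
# (fake targets stay at 0.0). This softens D's confidence on real data.
batch_size = 128
real_targets = torch.full((batch_size, 1), 0.9)
fake_targets = torch.zeros(batch_size, 1)

# Spectral normalization: wrap a discriminator layer so its weight's spectral
# norm is constrained to ~1, bounding the Lipschitz constant of D.
layer = nn.utils.spectral_norm(nn.Linear(794, 1024))
out = layer(torch.randn(batch_size, 794))
print(out.shape)  # torch.Size([128, 1024])
```

The smoothed `real_targets` would replace `real_labels` in the discriminator's real-pair loss only; the generator still trains against a target of 1.0.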
# Load MNIST
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]) # Scale to [-1, 1]
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=True)
# Optimizers
lr = 0.0002
beta1 = 0.5
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
# Loss
adversarial_loss = nn.BCELoss()
# Training
n_epochs = 30
history = {'D_loss': [], 'G_loss': [], 'D_real': [], 'D_fake': []}
for epoch in range(n_epochs):
epoch_D_loss = 0
epoch_G_loss = 0
epoch_D_real = 0
epoch_D_fake = 0
for i, (imgs, labels) in enumerate(train_loader):
batch_size = imgs.size(0)
# Labels
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
imgs = imgs.to(device)
labels = labels.to(device)
# -----------------
# Train Discriminator
# -----------------
optimizer_D.zero_grad()
# Real images
real_validity = discriminator(imgs, labels)
d_real_loss = adversarial_loss(real_validity, real_labels)
# Fake images
z = torch.randn(batch_size, latent_dim).to(device)
fake_imgs = generator(z, labels)
fake_validity = discriminator(fake_imgs.detach(), labels)
d_fake_loss = adversarial_loss(fake_validity, fake_labels)
# Total discriminator loss
d_loss = (d_real_loss + d_fake_loss) / 2
d_loss.backward()
optimizer_D.step()
# -----------------
# Train Generator
# -----------------
optimizer_G.zero_grad()
# Generate images
z = torch.randn(batch_size, latent_dim).to(device)
gen_imgs = generator(z, labels)
# Generator loss
validity = discriminator(gen_imgs, labels)
g_loss = adversarial_loss(validity, real_labels)
g_loss.backward()
optimizer_G.step()
# Statistics
epoch_D_loss += d_loss.item()
epoch_G_loss += g_loss.item()
epoch_D_real += real_validity.mean().item()
epoch_D_fake += fake_validity.mean().item()
# Record epoch stats
n_batches = len(train_loader)
history['D_loss'].append(epoch_D_loss / n_batches)
history['G_loss'].append(epoch_G_loss / n_batches)
history['D_real'].append(epoch_D_real / n_batches)
history['D_fake'].append(epoch_D_fake / n_batches)
print(f"Epoch [{epoch+1}/{n_epochs}] "
f"D_loss: {history['D_loss'][-1]:.4f} "
f"G_loss: {history['G_loss'][-1]:.4f} "
f"D(real): {history['D_real'][-1]:.4f} "
f"D(fake): {history['D_fake'][-1]:.4f}")
print("\nTraining complete!")
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
axes[0].plot(history['D_loss'], label='Discriminator', linewidth=2)
axes[0].plot(history['G_loss'], label='Generator', linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training Loss', fontsize=13)
axes[0].legend()
axes[0].grid(True, alpha=0.3)
axes[1].plot(history['D_real'], label='D(real)', linewidth=2)
axes[1].plot(history['D_fake'], label='D(fake)', linewidth=2)
axes[1].axhline(y=0.5, color='r', linestyle='--', label='Equilibrium')
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Discriminator Output', fontsize=12)
axes[1].set_title('Discriminator Performance', fontsize=13)
axes[1].legend()
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
Class-Conditional Generation
With a trained cGAN, generating images for a specific class is as simple as fixing the label input and sampling different noise vectors. Generating a grid of samples per class provides a quick visual assessment of generation quality and diversity. A well-trained cGAN should produce distinct, recognizable digits for each class with meaningful intra-class variation driven by the noise vector. This capability makes cGANs directly applicable to data augmentation (generating more samples for underrepresented classes) and controllable content creation.
generator.eval()
# Generate specific digits
with torch.no_grad():
# Generate one sample for each class (0-9)
z = torch.randn(10, latent_dim).to(device)
labels = torch.arange(0, 10).to(device)
generated = generator(z, labels).cpu()
fig, axes = plt.subplots(1, 10, figsize=(15, 2))
for i in range(10):
img = generated[i].squeeze()
img = (img + 1) / 2 # Denormalize from [-1,1] to [0,1]
axes[i].imshow(img, cmap='gray')
axes[i].set_title(f'Class {i}', fontsize=11)
axes[i].axis('off')
plt.suptitle('Class-Conditional Generation', fontsize=14)
plt.tight_layout()
plt.show()
# Generate multiple samples per class
print("\nMultiple samples per class:")
with torch.no_grad():
n_samples = 5
fig, axes = plt.subplots(10, n_samples, figsize=(12, 20))
for digit in range(10):
z = torch.randn(n_samples, latent_dim).to(device)
labels = torch.full((n_samples,), digit, dtype=torch.long).to(device)
samples = generator(z, labels).cpu()
for i in range(n_samples):
img = samples[i].squeeze()
img = (img + 1) / 2
axes[digit, i].imshow(img, cmap='gray')
axes[digit, i].axis('off')
if i == 0:
axes[digit, i].set_ylabel(f'{digit}', fontsize=13, rotation=0, labelpad=20)
plt.suptitle('Diversity within Each Class', fontsize=14)
plt.tight_layout()
plt.show()
Latent Space Interpolation with Fixed Class
By fixing the class label and smoothly interpolating the noise vector \(z\) between two random endpoints, we can visualize the diversity captured within a single class. Smooth transitions indicate that the generator has learned a continuous mapping from latent space to image space, without sharp discontinuities or mode collapse. Comparing interpolations across different classes also reveals whether the model allocates similar latent capacity to each digit, which is a useful diagnostic for balanced generation quality.
# Interpolate between two latent codes for same class
with torch.no_grad():
digit = 7
n_steps = 10
# Two random starting points
z1 = torch.randn(1, latent_dim).to(device)
z2 = torch.randn(1, latent_dim).to(device)
# Interpolation
alphas = np.linspace(0, 1, n_steps)
fig, axes = plt.subplots(1, n_steps, figsize=(15, 2))
for i, alpha in enumerate(alphas):
z_interp = (1 - alpha) * z1 + alpha * z2
label = torch.tensor([digit]).to(device)
img = generator(z_interp, label).cpu().squeeze()
img = (img + 1) / 2
axes[i].imshow(img, cmap='gray')
axes[i].axis('off')
plt.suptitle(f'Latent Interpolation for Digit {digit}', fontsize=14)
plt.tight_layout()
plt.show()
# Fixed noise, vary class
print("\nFixed noise, varying class:")
with torch.no_grad():
z_fixed = torch.randn(1, latent_dim).to(device)
fig, axes = plt.subplots(1, 10, figsize=(15, 2))
for digit in range(10):
label = torch.tensor([digit]).to(device)
img = generator(z_fixed, label).cpu().squeeze()
img = (img + 1) / 2
axes[digit].imshow(img, cmap='gray')
axes[digit].set_title(f'{digit}', fontsize=11)
axes[digit].axis('off')
plt.suptitle('Same Noise, Different Classes', fontsize=14)
plt.tight_layout()
plt.show()
print("\nObservation: Style (from z) is consistent, but class changes")
Summary
Key Innovations:
Conditional generation: Control output via conditioning variable \(y\)
Flexible conditioning: Class labels, attributes, images, text
Discriminator conditioning: Ensures generated samples match condition
Disentanglement: Separate noise (style) from condition (content)
Architecture Patterns:
Concatenation:
Generator: \(G([z; \text{emb}(y)])\)
Discriminator: \(D([x; \text{emb}(y)])\)
Projection (Miyato & Koyama, 2018): \(D(x, y) = \phi(x)^T \mathbf{w} + \psi(y)^T \phi(x)\), where \(\phi(x)\) are the image features, \(\mathbf{w}\) an unconditional output weight, and \(\psi(y)\) a learned label embedding
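The projection score can be computed in a few lines. Here \(\phi(x)\) is stood in by random pooled features, `linear` carries the unconditional term, and `psi` is the label embedding (all names our own, illustrative rather than a full discriminator):

```python
import torch
import torch.nn as nn

# Projection scoring: score = phi(x)^T w + psi(y)^T phi(x)
feat_dim, num_classes, batch = 128, 10, 4
linear = nn.Linear(feat_dim, 1)            # unconditional term phi(x)^T w (+ bias)
psi = nn.Embedding(num_classes, feat_dim)  # per-class embedding psi(y)

phi_x = torch.randn(batch, feat_dim)       # stand-in for pooled CNN features
y = torch.tensor([0, 2, 5, 9])
score = linear(phi_x) + (psi(y) * phi_x).sum(dim=1, keepdim=True)
print(score.shape)  # torch.Size([4, 1])
```

Note the label enters only through an inner product with the features, which is why this conditioning adds far fewer parameters than concatenating the label into the input.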
Advanced Variants:
AC-GAN (Odena et al., 2017): Auxiliary classifier in discriminator
Pix2Pix (Isola et al., 2017): Image-to-image translation
StarGAN (Choi et al., 2018): Multi-domain translation
BigGAN (Brock et al., 2019): Class-conditional at scale
Applications:
Controllable generation: Specify what to generate
Data augmentation: Generate labeled samples
Transfer learning: Condition on domains
Creative tools: Interactive generation
Training Tips:
Balance discriminator and generator updates
Use label smoothing (soft labels)
Spectral normalization for stability
Progressive growing for high resolution
Next Steps:
08_stylegan.ipynb - Style-based generation
09_pix2pix.ipynb - Image-to-image translation
01_gan_mathematics.ipynb - GAN theory review