Conditional GANs (cGAN): Comprehensive Theory

Introduction

Conditional Generative Adversarial Networks (cGANs) extend the standard GAN framework by conditioning both the generator and discriminator on additional information \(y\). This conditioning enables controlled generation, where we can specify what kind of samples to generate (e.g., class labels, text descriptions, images, or other auxiliary data).

Key Innovation: Transform unsupervised GAN into a supervised framework with explicit control over generation.

Applications:

  • Class-conditional image generation (generate specific digits, objects, faces)

  • Image-to-image translation (pix2pix: edges → photos, day → night)

  • Text-to-image synthesis

  • Super-resolution

  • Style transfer

  • Inpainting and outpainting

Mathematical Foundation

Standard GAN (Unconditional)

Recall the standard GAN minimax objective:

\[\min_G \max_D V(D, G) = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}[\log D(\mathbf{x})] + \mathbb{E}_{\mathbf{z} \sim p_z}[\log(1 - D(G(\mathbf{z})))]\]

Where:

  • \(G(\mathbf{z})\): Generator maps noise \(\mathbf{z}\) to fake samples

  • \(D(\mathbf{x})\): Discriminator predicts probability that \(\mathbf{x}\) is real

  • No explicit control over what is generated

Conditional GAN Objective

In cGAN, we condition both \(G\) and \(D\) on additional information \(y\):

\[\min_G \max_D V(D, G) = \mathbb{E}_{\mathbf{x}, y \sim p_{\text{data}}}[\log D(\mathbf{x}, y)] + \mathbb{E}_{\mathbf{z} \sim p_z, y \sim p_y}[\log(1 - D(G(\mathbf{z}, y), y))]\]

Where:

  • \(G(\mathbf{z}, y)\): Generator conditioned on \(y\) (e.g., class label, text, image)

  • \(D(\mathbf{x}, y)\): Discriminator predicts if \(\mathbf{x}\) is real given condition \(y\)

  • \(y\) can be discrete (class labels) or continuous (images, embeddings)

Key Insight: Discriminator must verify two things:

  1. Is the sample realistic?

  2. Does the sample match the condition \(y\)?

Conditioning Mechanisms

1. Label Conditioning (Discrete Classes)

For class-conditional generation (e.g., MNIST digits 0-9):

Generator Conditioning: \(G(\mathbf{z}, y) = G([\mathbf{z}; \mathbf{e}(y)])\)

Where \(\mathbf{e}(y)\) is an embedding of class label \(y\) (one-hot or learned embedding).

Discriminator Conditioning: \(D(\mathbf{x}, y) = D([\mathbf{x}; \mathbf{e}(y)])\)

Implementation:

  • Concatenate embedded label with noise vector (generator input)

  • Concatenate embedded label with image (discriminator input)

  • Can use projection discriminator for efficiency (discussed later)

2. Image Conditioning (Paired Image-to-Image)

For tasks like pix2pix (edges → photo):

Generator: U-Net architecture \(G(\mathbf{x}_{\text{input}}) = \mathbf{x}_{\text{output}}\)

Discriminator: PatchGAN (evaluates \(N \times N\) patches) \(D(\mathbf{x}_{\text{output}}, \mathbf{x}_{\text{input}}) \rightarrow \mathbb{R}^{H \times W}\)

Loss: Combines adversarial loss + L1 reconstruction: \(\mathcal{L} = \mathcal{L}_{\text{GAN}} + \lambda \mathbb{E}[\|\mathbf{x}_{\text{real}} - G(\mathbf{x}_{\text{input}})\|_1]\)

3. Text Conditioning (Text-to-Image)

For generating images from captions:

Text Embedding:

  • Use pre-trained text encoder (e.g., BERT, CLIP)

  • \(\mathbf{t} = \text{Encoder}(\text{"a red car on the street"})\)

Generator: \(G(\mathbf{z}, \mathbf{t}) = G([\mathbf{z}; \mathbf{t}])\)

Discriminator (matching-aware): \(D(\mathbf{x}, \mathbf{t}) = \begin{cases} \text{real} & \text{if } \mathbf{x} \text{ is real and matches } \mathbf{t} \\ \text{fake} & \text{otherwise} \end{cases}\)

4. Continuous Conditioning (Attributes)

For generating faces with specific attributes (age, gender, expression):

Attribute Vector: \(\mathbf{a} \in \mathbb{R}^d\) (e.g., [age=30, male=1, smile=0.5])

Generator: \(G(\mathbf{z}, \mathbf{a}) = G([\mathbf{z}; \mathbf{a}])\)

Advanced Conditioning Architectures

Projection Discriminator

Problem: Naive concatenation is inefficient and limits discriminator capacity.

Solution (Miyato & Koyama, 2018): Project conditioning information into the discriminator's intermediate features.

Architecture: \(D(\mathbf{x}, y) = \sigma(\mathbf{w}^T \phi(\mathbf{x}) + \mathbf{e}(y)^T \phi(\mathbf{x}))\)

Where:

  • \(\phi(\mathbf{x})\): Feature extractor (CNN)

  • \(\mathbf{e}(y)\): Embedding of condition \(y\)

  • Inner product \(\mathbf{e}(y)^T \phi(\mathbf{x})\) captures correlation

Benefits:

  • More parameter-efficient

  • Better gradient flow

  • Improved conditioning signal

  • State-of-the-art for class-conditional generation

Auxiliary Classifier GAN (AC-GAN)

Innovation: Discriminator also predicts the class label.

Discriminator Outputs:

  1. Real/Fake probability: \(D(\mathbf{x})\)

  2. Class probability: \(C(\mathbf{x})\)

Losses:

Adversarial Loss: \(\mathcal{L}_{\text{adv}} = \mathbb{E}[\log D(\mathbf{x}_{\text{real}})] + \mathbb{E}[\log(1 - D(G(\mathbf{z}, y)))]\)

Classification Loss (both real and fake): \(\mathcal{L}_{\text{cls}} = \mathbb{E}[\log C(y | \mathbf{x}_{\text{real}})] + \mathbb{E}[\log C(y | G(\mathbf{z}, y))]\)

Total Objective (both minimized; the discriminator maximizes both the adversarial and classification terms, while the generator fools the discriminator yet still matches the class):

\[\mathcal{L}_D = -\mathcal{L}_{\text{adv}} - \mathcal{L}_{\text{cls}}\]

\[\mathcal{L}_G = \mathcal{L}_{\text{adv}} - \mathcal{L}_{\text{cls}}\]

Trade-off:

  • Pros: Better class separation, higher quality

  • Cons: Can reduce diversity within classes (mode collapse per class)

Conditional Batch Normalization (CBN)

Idea: Modulate batch normalization parameters based on condition.

Standard Batch Normalization: \(\text{BN}(\mathbf{h}) = \gamma \frac{\mathbf{h} - \mu}{\sigma} + \beta\)

Conditional Batch Normalization: \(\text{CBN}(\mathbf{h}, y) = \gamma(y) \frac{\mathbf{h} - \mu}{\sigma} + \beta(y)\)

Where \(\gamma(y)\) and \(\beta(y)\) are predicted from condition \(y\) via small MLPs.

Benefits:

  • Efficient conditioning mechanism

  • Used in modern architectures (BigGAN, StyleGAN)

  • Allows fine-grained control over feature maps
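
A minimal CBN layer can be sketched in PyTorch. This is a sketch, not the reference implementation: the per-class embedding that stores \(\gamma(y)\) and \(\beta(y)\) is one common parameterization (a small MLP over the condition works equally well).

```python
import torch
import torch.nn as nn

class ConditionalBatchNorm2d(nn.Module):
    """Batch norm whose affine parameters gamma(y), beta(y) depend on a class label."""
    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        # Normalize without the layer's own affine parameters
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        # One (gamma, beta) pair per class, stored in a single embedding
        self.embed = nn.Embedding(num_classes, num_features * 2)
        # Initialize gamma near 1 and beta near 0
        self.embed.weight.data[:, :num_features].fill_(1.0)
        self.embed.weight.data[:, num_features:].zero_()

    def forward(self, h: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        gamma, beta = self.embed(y).chunk(2, dim=1)
        out = self.bn(h)
        # Broadcast the per-sample gamma/beta over spatial dimensions
        return gamma.view(-1, gamma.size(1), 1, 1) * out + beta.view(-1, beta.size(1), 1, 1)
```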

Pix2Pix: Image-to-Image Translation

Architecture

Generator: U-Net

  • Encoder: Downsample input image (extract features)

  • Decoder: Upsample to output resolution

  • Skip connections: Preserve spatial information from encoder to decoder

  • Output: Generated image matching input resolution

Structure:

Encoder:  C64-C128-C256-C512-C512-C512-C512
Decoder:  CD512-CD512-CD512-C512-C256-C128-C64

Where:

  • C\(k\): Convolution-BatchNorm-ReLU with \(k\) filters

  • CD\(k\): Convolution-BatchNorm-Dropout-ReLU with \(k\) filters

Discriminator: PatchGAN

  • Evaluates whether \(N \times N\) patches are real/fake

  • Typical patch size: 70×70

  • Smaller receptive field → faster, more local

  • Output: \(H/16 \times W/16\) map of patch predictions

Loss Function

Adversarial Loss (cGAN): \(\mathcal{L}_{\text{cGAN}}(G, D) = \mathbb{E}_{x,y}[\log D(x, y)] + \mathbb{E}_{x,z}[\log(1 - D(x, G(x, z)))]\)

L1 Reconstruction Loss: \(\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}[\|y - G(x, z)\|_1]\)

Total Generator Loss: \(G^* = \arg\min_G \max_D \mathcal{L}_{\text{cGAN}}(G, D) + \lambda \mathcal{L}_{L1}(G)\)

Typical \(\lambda = 100\) (heavily weight reconstruction).

Why L1 instead of L2?

  • L1 encourages less blurring than L2

  • L2 averages all plausible outputs → blurry

  • L1 picks mode closer to ground truth
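
The combined generator objective above can be written as a small helper. A hedged sketch: `d_fake_logits` is assumed to be the PatchGAN's raw logit map for generated images (the adversarial term uses the non-saturating "label fakes as real" form), with \(\lambda = 100\) as in the text.

```python
import torch
import torch.nn.functional as F

def pix2pix_g_loss(d_fake_logits: torch.Tensor,
                   fake: torch.Tensor,
                   target: torch.Tensor,
                   lam: float = 100.0) -> torch.Tensor:
    """Generator loss: fool the PatchGAN on every patch + lambda-weighted L1."""
    adv = F.binary_cross_entropy_with_logits(
        d_fake_logits, torch.ones_like(d_fake_logits))
    l1 = F.l1_loss(fake, target)
    return adv + lam * l1
```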

Training Strategy

  1. Paired Data: Requires input-output pairs \((x, y)\)

  2. Preprocessing: Jittering (resize to 286×286, random crop to 256×256)

  3. Augmentation: Random flips

  4. Optimizer: Adam with \(\beta_1 = 0.5\) (lower momentum for stability)

  5. Learning Rate: \(2 \times 10^{-4}\), constant for first 100 epochs, then linear decay

CycleGAN: Unpaired Image-to-Image Translation

Motivation

Problem with pix2pix: Requires paired training data \((x, y)\), which is expensive or impossible to collect (e.g., photo ↔ painting).

Solution: CycleGAN learns mappings between domains \(X\) and \(Y\) using unpaired data.

Cycle Consistency

Key Idea: If we translate \(x \rightarrow y \rightarrow x'\), we should get back \(x' \approx x\).

Cycle Consistency Loss: \(\mathcal{L}_{\text{cyc}}(G, F) = \mathbb{E}_{x \sim X}[\|F(G(x)) - x\|_1] + \mathbb{E}_{y \sim Y}[\|G(F(y)) - y\|_1]\)

Where:

  • \(G: X \rightarrow Y\) (e.g., horse → zebra)

  • \(F: Y \rightarrow X\) (e.g., zebra → horse)

Interpretation:

  • Forward cycle: \(x \xrightarrow{G} y' \xrightarrow{F} x'\), minimize \(\|x - x'\|\)

  • Backward cycle: \(y \xrightarrow{F} x' \xrightarrow{G} y'\), minimize \(\|y - y'\|\)

Full CycleGAN Objective

Adversarial Losses (two GANs):

\[\mathcal{L}_{\text{GAN}}(G, D_Y) = \mathbb{E}_{y}[\log D_Y(y)] + \mathbb{E}_{x}[\log(1 - D_Y(G(x)))]\]

\[\mathcal{L}_{\text{GAN}}(F, D_X) = \mathbb{E}_{x}[\log D_X(x)] + \mathbb{E}_{y}[\log(1 - D_X(F(y)))]\]

Total Objective: \(\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{\text{GAN}}(G, D_Y) + \mathcal{L}_{\text{GAN}}(F, D_X) + \lambda \mathcal{L}_{\text{cyc}}(G, F)\)

Typical \(\lambda = 10\).

Identity Loss (optional, for color preservation): \(\mathcal{L}_{\text{identity}}(G, F) = \mathbb{E}_{y}[\|G(y) - y\|_1] + \mathbb{E}_{x}[\|F(x) - x\|_1]\)

Encourages \(G\) to be close to identity when given real images from target domain.
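
The cycle-consistency and identity terms translate directly into code. A sketch: `G` and `F_map` stand for the two mappings (named `F_map` to avoid clashing with `torch.nn.functional`, imported here as `F_loss`).

```python
import torch
import torch.nn.functional as F_loss

def cycle_consistency_loss(G, F_map, x, y):
    """L_cyc: reconstruct x through G then F_map, and y through F_map then G (L1 both ways)."""
    forward = F_loss.l1_loss(F_map(G(x)), x)    # x -> G -> F_map -> ~x
    backward = F_loss.l1_loss(G(F_map(y)), y)   # y -> F_map -> G -> ~y
    return forward + backward

def identity_loss(G, F_map, x, y):
    """Optional identity term: G should leave real Y-domain images unchanged (and F_map, X-domain)."""
    return F_loss.l1_loss(G(y), y) + F_loss.l1_loss(F_map(x), x)
```

With perfect (identity) mappings both losses vanish, which is a quick sanity check when wiring up training.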

Architecture

Generator: ResNet-based

  • Downsampling: 2 strided convolutions

  • Transformation: 6-9 residual blocks

  • Upsampling: 2 fractionally-strided convolutions

  • Output: tanh activation (range [-1, 1])

Discriminator: PatchGAN (70×70 patches)

Limitations

  1. Mode Collapse: Can map all inputs to a single output

  2. Geometric Changes: Struggles with large shape changes (horse → zebra ✓, cat → dog ✗)

  3. Texture vs. Shape: Often changes texture while preserving shape

StarGAN: Multi-Domain Translation

Motivation

Problem: CycleGAN requires a separate model for each domain pair (N domains → N(N-1) models).

Solution: StarGAN uses single generator for all domain translations.

Architecture

Generator: \(G(\mathbf{x}, \mathbf{c}) \rightarrow \mathbf{y}\)

Where \(\mathbf{c}\) is target domain label (one-hot or binary vector).

Example (CelebA attributes):

  • \(\mathbf{c} = [1, 0, 1, 0]\) → male, no beard, smiling, young

Discriminator:

  • Real/Fake prediction: \(D_{\text{src}}(\mathbf{x})\)

  • Domain classification: \(D_{\text{cls}}(\mathbf{x}) \rightarrow \mathbf{c}\)

Loss Functions

Adversarial Loss: \(\mathcal{L}_{\text{adv}} = \mathbb{E}_{\mathbf{x}}[\log D_{\text{src}}(\mathbf{x})] + \mathbb{E}_{\mathbf{x}, \mathbf{c}}[\log(1 - D_{\text{src}}(G(\mathbf{x}, \mathbf{c})))]\)

Domain Classification Loss:

  • Real images: \(\mathcal{L}_{\text{cls}}^r = \mathbb{E}_{\mathbf{x}, \mathbf{c}'}[-\log D_{\text{cls}}(\mathbf{c}' | \mathbf{x})]\)

  • Fake images: \(\mathcal{L}_{\text{cls}}^f = \mathbb{E}_{\mathbf{x}, \mathbf{c}}[-\log D_{\text{cls}}(\mathbf{c} | G(\mathbf{x}, \mathbf{c}))]\)

Reconstruction Loss: \(\mathcal{L}_{\text{rec}} = \mathbb{E}_{\mathbf{x}, \mathbf{c}, \mathbf{c}'}[\|\mathbf{x} - G(G(\mathbf{x}, \mathbf{c}), \mathbf{c}')\|_1]\)

Full Objective:

\[\mathcal{L}_D = -\mathcal{L}_{\text{adv}} + \lambda_{\text{cls}} \mathcal{L}_{\text{cls}}^r\]

\[\mathcal{L}_G = \mathcal{L}_{\text{adv}} + \lambda_{\text{cls}} \mathcal{L}_{\text{cls}}^f + \lambda_{\text{rec}} \mathcal{L}_{\text{rec}}\]

Pix2PixHD: High-Resolution Image Synthesis

Challenges at High Resolution

  1. Training Instability: Harder to train GANs at 2048×1024 resolution

  2. Mode Collapse: More severe at high resolution

  3. Computational Cost: Memory and compute requirements

Innovations

1. Coarse-to-Fine Generator

  • \(G_1\): Global generator (low resolution)

  • \(G_2\): Local enhancer (high resolution)

  • Train \(G_1\) first, then fix and train \(G_2\)

2. Multi-Scale Discriminators

Use three discriminators at different scales:

  • \(D_1\): Original resolution

  • \(D_2\): Downsampled 2×

  • \(D_3\): Downsampled 4×

Each discriminator has same architecture (PatchGAN) but operates at different scales.

Benefit:

  • \(D_1\) focuses on fine details

  • \(D_3\) guides overall structure

  • Improved gradient signal

3. Feature Matching Loss

Instead of only fooling discriminator, match intermediate features:

\[\mathcal{L}_{\text{FM}}(G, D_k) = \mathbb{E} \sum_{i=1}^{T} \frac{1}{N_i}[\|D_k^{(i)}(\mathbf{x}) - D_k^{(i)}(G(\mathbf{s}))\|_1]\]

Where \(D_k^{(i)}\) is \(i\)-th layer of discriminator \(D_k\).

Benefit: More stable training, less mode collapse.
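
A sketch of the feature-matching term. Two simplifications relative to the formula above: it averages over layers as well as elements (rather than the per-layer \(1/N_i\) weighting alone), and the per-layer feature lists are assumed to come from forward hooks on the discriminator.

```python
import torch

def feature_matching_loss(real_feats, fake_feats):
    """L1 distance between discriminator features of real and generated
    images; real_feats/fake_feats are lists of tensors, one per layer."""
    loss = 0.0
    for fr, ff in zip(real_feats, fake_feats):
        # Real features are targets only: no gradient flows through them
        loss = loss + torch.mean(torch.abs(fr.detach() - ff))
    return loss / len(real_feats)
```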

SPADE: Semantic Image Synthesis

Spatially-Adaptive Normalization

Problem: Standard normalization (BatchNorm, InstanceNorm) washes out semantic information in segmentation masks.

Solution: SPADE (Spatially-Adaptive DEnormalization)

Standard Instance Normalization: \(\gamma \frac{\mathbf{h} - \mu(\mathbf{h})}{\sigma(\mathbf{h})} + \beta\)

SPADE: \(\gamma(\mathbf{m}) \odot \frac{\mathbf{h} - \mu(\mathbf{h})}{\sigma(\mathbf{h})} + \beta(\mathbf{m})\)

Where:

  • \(\mathbf{m}\): Semantic segmentation mask

  • \(\gamma(\mathbf{m}), \beta(\mathbf{m})\): Spatially-varying parameters predicted from \(\mathbf{m}\)

  • \(\odot\): Element-wise multiplication

Architecture:

m → Conv3×3 → ReLU → Conv3×3 → γ(m)
                    └→ Conv3×3 → β(m)

Benefit: Preserves semantic information while normalizing activations.
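
The diagram above corresponds to a small module like the following sketch (the hidden width of the shared convolution is an arbitrary assumption here):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SPADE(nn.Module):
    """Spatially-adaptive normalization: gamma(m), beta(m) are spatial maps
    predicted from the segmentation mask m."""
    def __init__(self, num_features: int, mask_channels: int, hidden: int = 64):
        super().__init__()
        self.norm = nn.BatchNorm2d(num_features, affine=False)
        self.shared = nn.Sequential(
            nn.Conv2d(mask_channels, hidden, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.gamma = nn.Conv2d(hidden, num_features, kernel_size=3, padding=1)
        self.beta = nn.Conv2d(hidden, num_features, kernel_size=3, padding=1)

    def forward(self, h: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
        # Resize the mask to match the feature-map resolution
        m = F.interpolate(m, size=h.shape[2:], mode='nearest')
        actv = self.shared(m)
        return self.gamma(actv) * self.norm(h) + self.beta(actv)
```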

Applications

  • GauGAN: Segmentation mask → photorealistic image

  • Interactive editing: Draw segmentation, get realistic scene

  • Style control: Vary appearance while maintaining layout

Text-to-Image Generation

StackGAN

Idea: Generate images in two stages (low-res → high-res).

Stage-I:

  • Input: Text embedding \(\mathbf{t}\), noise \(\mathbf{z}\)

  • Output: 64×64 low-resolution image

  • Captures rough shape and color

Stage-II:

  • Input: Stage-I image + text embedding

  • Output: 256×256 high-resolution image

  • Adds details and fixes defects

Conditioning Augmentation:

Problem: Limited training data → overfitting to text embeddings.

Solution: Sample from a Gaussian conditioned on the text: \(\mathbf{c} \sim \mathcal{N}(\mu(\mathbf{t}), \Sigma(\mathbf{t}))\)

Where \(\mu, \Sigma\) are learned.

Regularization: \(\mathcal{L}_{\text{CA}} = D_{KL}(\mathcal{N}(\mu(\mathbf{t}), \Sigma(\mathbf{t})) \| \mathcal{N}(0, I))\)

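The conditioning-augmentation step and its KL regularizer can be sketched with the reparameterization trick. Assumptions: a single linear layer predicts \(\mu\) and \(\log \sigma^2\), and \(\Sigma\) is diagonal (StackGAN uses a comparably small head).

```python
import torch
import torch.nn as nn

class ConditioningAugmentation(nn.Module):
    """Sample c ~ N(mu(t), diag(sigma(t)^2)) and return the KL term to N(0, I)."""
    def __init__(self, text_dim: int, c_dim: int):
        super().__init__()
        self.fc = nn.Linear(text_dim, c_dim * 2)  # predicts mu and log-variance

    def forward(self, t: torch.Tensor):
        mu, logvar = self.fc(t).chunk(2, dim=1)
        # Reparameterization: differentiable sample from the Gaussian
        c = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
        # KL(N(mu, sigma^2) || N(0, I)), summed over dims, averaged over batch
        kl = -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
        return c, kl
```
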
AttnGAN

Innovation: Fine-grained attention mechanism.

Word-Level Attention:

  1. Extract word features from text: \(\{\mathbf{w}_1, \ldots, \mathbf{w}_T\}\)

  2. For each spatial location in image features, compute attention over words

  3. Weighted combination of word features guides generation

Attention Mechanism: \(\alpha_{ij} = \frac{\exp(\mathbf{h}_i^T \mathbf{w}_j)}{\sum_{k} \exp(\mathbf{h}_i^T \mathbf{w}_k)}\)

\[\mathbf{c}_i = \sum_{j} \alpha_{ij} \mathbf{w}_j\]
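
The attention step above can be sketched as a batched matrix product. The shapes are assumptions for illustration: `h` holds image-region features, `w` holds word features, both already projected to a common dimension.

```python
import torch
import torch.nn.functional as F

def word_attention(h: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """h: (batch, num_regions, dim) image features,
    w: (batch, num_words, dim) word features.
    Returns per-region context vectors c_i = sum_j alpha_ij * w_j."""
    scores = torch.bmm(h, w.transpose(1, 2))   # (batch, regions, words)
    alpha = F.softmax(scores, dim=-1)          # softmax over words, per region
    return torch.bmm(alpha, w)                 # (batch, regions, dim)
```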

Deep Attentional Multimodal Similarity Model (DAMSM):

Measures image-text matching at the word level: \(\mathcal{L}_{\text{DAMSM}} = \log P(y | \mathbf{I}) + \log P(y | \mathbf{w}_{1:T})\)

Encourages generated images to be semantically consistent with input text.

Advanced Training Techniques for cGAN

Spectral Normalization

Problem: Training instability, especially for conditional models.

Solution: Constrain Lipschitz constant of discriminator.

Spectral Norm: \(W_{\text{SN}} = \frac{W}{\sigma(W)}\)

Where \(\sigma(W)\) is largest singular value of weight matrix \(W\).

Implementation: Power iteration to estimate \(\sigma(W)\) efficiently.

Benefit:

  • More stable training

  • Better gradient flow

  • Can remove batch normalization from discriminator
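
PyTorch ships this as `nn.utils.spectral_norm`, a wrapper that runs one power-iteration step per forward pass and uses \(W / \sigma(W)\) as the effective weight:

```python
import torch
import torch.nn as nn

# Wrap any weight-bearing layer; the stored weight stays unconstrained,
# only the weight used in forward is normalized.
layer = nn.utils.spectral_norm(nn.Linear(64, 32))

x = torch.randn(8, 64)
for _ in range(50):        # let the power-iteration estimate converge
    out = layer(x)

# After convergence, the effective weight has spectral norm close to 1
sigma = torch.linalg.svdvals(layer.weight.detach())[0]
```

In practice one simply wraps every convolution in the discriminator this way; no extra loss term is needed.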

Hinge Loss

Standard GAN discriminator loss (binary cross-entropy): \(\mathcal{L}_D = -\mathbb{E}[\log D(\mathbf{x})] - \mathbb{E}[\log(1 - D(G(\mathbf{z})))]\)

Hinge loss (on raw discriminator logits):

\[\mathcal{L}_D = \mathbb{E}[\max(0, 1 - D(\mathbf{x}))] + \mathbb{E}[\max(0, 1 + D(G(\mathbf{z})))]\]

\[\mathcal{L}_G = -\mathbb{E}[D(G(\mathbf{z}))]\]

Benefits:

  • Better gradient signal

  • Used in state-of-the-art models (BigGAN, StyleGAN2)
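
The hinge losses above, written as small helpers over raw discriminator logits:

```python
import torch
import torch.nn.functional as F

def d_hinge_loss(real_logits: torch.Tensor, fake_logits: torch.Tensor) -> torch.Tensor:
    """Discriminator hinge loss: push real logits above +1 and fake logits below -1."""
    return F.relu(1.0 - real_logits).mean() + F.relu(1.0 + fake_logits).mean()

def g_hinge_loss(fake_logits: torch.Tensor) -> torch.Tensor:
    """Generator hinge loss: raise the discriminator's score on generated samples."""
    return -fake_logits.mean()
```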

Two Time-Scale Update Rule (TTUR)

Observation: Discriminator learns faster than generator.

Solution: Use different learning rates:

  • Generator: \(\alpha_G = 10^{-4}\)

  • Discriminator: \(\alpha_D = 4 \times 10^{-4}\)

Benefit: Maintains balance between \(G\) and \(D\).
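
TTUR is simply two optimizers with different learning rates. The networks below are stand-ins, and the Adam betas are an assumption (not specified by the text):

```python
import torch
import torch.nn as nn

# Hypothetical generator/discriminator stand-ins, just to build the optimizers
G = nn.Linear(100, 784)
D = nn.Linear(784, 1)

# TTUR: slower generator, faster discriminator (rates from the text above)
opt_G = torch.optim.Adam(G.parameters(), lr=1e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(D.parameters(), lr=4e-4, betas=(0.5, 0.999))
```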

Gradient Penalty

R1 Regularization (on real data): \(\mathcal{L}_{\text{R1}} = \frac{\gamma}{2} \mathbb{E}_{\mathbf{x}}[\|\nabla D(\mathbf{x})\|^2]\)

Benefits:

  • Stabilizes training

  • Prevents discriminator from becoming too confident

  • Used in StyleGAN2
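
A sketch of the R1 penalty. The real batch must carry `requires_grad=True` before the discriminator is applied, so the gradient with respect to the inputs can be taken:

```python
import torch

def r1_penalty(d_out: torch.Tensor, x_real: torch.Tensor, gamma: float = 10.0) -> torch.Tensor:
    """(gamma/2) * E[ ||grad_x D(x)||^2 ] on real images.
    d_out: discriminator outputs on x_real; x_real: real batch with requires_grad=True."""
    grad = torch.autograd.grad(
        outputs=d_out.sum(), inputs=x_real, create_graph=True)[0]
    # Per-sample squared gradient norm, averaged over the batch
    return (gamma / 2.0) * grad.pow(2).view(grad.size(0), -1).sum(dim=1).mean()
```

`create_graph=True` keeps the penalty differentiable so it can be backpropagated through the discriminator's parameters.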

Evaluation Metrics for Conditional Generation

Inception Score (IS)

\[\text{IS} = \exp(\mathbb{E}_{\mathbf{x}}[D_{KL}(p(y|\mathbf{x}) \| p(y))])\]

Interpretation:

  • \(p(y|\mathbf{x})\) should be peaked (high quality)

  • \(p(y) = \mathbb{E}_{\mathbf{x}}[p(y|\mathbf{x})]\) should be uniform (diversity)

Limitations:

  • Only for class-conditional generation

  • Doesn't measure realism well

  • Can be fooled by adversarial examples

Fréchet Inception Distance (FID)

\[\text{FID} = \|\mu_r - \mu_g\|^2 + \text{Tr}(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2})\]

Where:

  • \((\mu_r, \Sigma_r)\): Mean and covariance of real image features

  • \((\mu_g, \Sigma_g)\): Mean and covariance of generated image features

  • Features: Inception-v3 pool3 layer

Benefits:

  • More robust than IS

  • Measures both quality and diversity

  • Correlates better with human judgment

Lower is better (unlike IS where higher is better).
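
Given the feature statistics, FID reduces to a few lines of NumPy. A sketch: it uses the identity \(\text{Tr}((\Sigma_r \Sigma_g)^{1/2}) = \text{Tr}((\Sigma_g^{1/2} \Sigma_r \Sigma_g^{1/2})^{1/2})\) so only symmetric-PSD square roots are needed; extracting the statistics from Inception-v3 features is omitted here.

```python
import numpy as np

def _sqrtm_psd(a: np.ndarray) -> np.ndarray:
    """Matrix square root of a symmetric PSD matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(a)
    vals = np.clip(vals, 0.0, None)  # guard against tiny negative eigenvalues
    return (vecs * np.sqrt(vals)) @ vecs.T

def fid_from_stats(mu_r, sigma_r, mu_g, sigma_g) -> float:
    diff = mu_r - mu_g
    sg_half = _sqrtm_psd(sigma_g)
    tr_covmean = np.trace(_sqrtm_psd(sg_half @ sigma_r @ sg_half))
    return float(diff @ diff + np.trace(sigma_r) + np.trace(sigma_g) - 2.0 * tr_covmean)
```

Identical statistics give FID 0; a pure mean shift of \(\delta\) contributes \(\|\delta\|^2\).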

LPIPS (Learned Perceptual Image Patch Similarity)

\[\text{LPIPS}(\mathbf{x}, \mathbf{y}) = \sum_l \frac{1}{H_l W_l} \|\mathbf{F}_l(\mathbf{x}) - \mathbf{F}_l(\mathbf{y})\|^2\]

Where \(\mathbf{F}_l\) are features from layer \(l\) of VGG or AlexNet.

Use Case: Paired image translation (pix2pix) - measures perceptual similarity.

User Studies

Amazon Mechanical Turk (AMT):

  • Real vs. Fake discrimination

  • Fooling rate (percentage of times humans are fooled)

  • Preference tests (A vs. B)

Limitations: Expensive, time-consuming, subjective.

Applications and Case Studies

1. Medical Imaging

  • MRI Synthesis: T1 → T2, CT → MRI

  • Segmentation: Image → organ masks

  • Data Augmentation: Generate synthetic medical images

Challenge: Limited paired data, high stakes (safety-critical).

2. Fashion and Retail

  • Virtual Try-On: Transfer clothing to different models

  • Attribute Editing: Change color, pattern, style

  • Sketch-to-Product: Design sketches → realistic renders

3. Video Game Development

  • Texture Generation: Semantic labels → textures

  • Scene Synthesis: Layout → photorealistic environments

  • Character Creation: Concept art → 3D models

4. Autonomous Driving

  • Domain Adaptation: Sim2Real transfer (synthetic → real)

  • Data Augmentation: Generate rare scenarios (rain, night)

  • Sensor Fusion: LiDAR → camera, camera → segmentation

5. Creative Tools

  • Photo Editing: Style transfer, colorization, inpainting

  • Art Generation: Text/sketch → artwork

  • Animation: Keyframe interpolation, motion transfer

Theoretical Analysis

Nash Equilibrium in cGAN

Definition: \((G^*, D^*)\) is a Nash equilibrium if: \(G^* = \arg\min_G V(G, D^*), \quad D^* = \arg\max_D V(G^*, D)\)

Optimal Discriminator (given \(G\) and condition \(y\)): \(D^*(\mathbf{x}, y) = \frac{p_{\text{data}}(\mathbf{x} | y)}{p_{\text{data}}(\mathbf{x} | y) + p_G(\mathbf{x} | y)}\)

Global Optimum: \(G^*\) such that \(p_G(\mathbf{x} | y) = p_{\text{data}}(\mathbf{x} | y)\) for all \(y\).

Conditional vs. Unconditional Learning

Claim (Mirza & Osindero, 2014): Conditioning can improve learning efficiency.

Intuition:

  • Unconditional GAN must learn \(p(\mathbf{x})\) (entire data distribution)

  • Conditional GAN learns \(p(\mathbf{x} | y)\) (simpler conditional distributions)

  • Decomposition: \(p(\mathbf{x}) = \sum_y p(\mathbf{x} | y) p(y)\)

Benefit: Each conditional distribution can be simpler than full distribution.

Limitations and Challenges

1. Paired Data Requirement

  • Pix2pix: Needs aligned pairs \((x, y)\)

  • Solutions: CycleGAN (unpaired), self-supervision

2. Mode Collapse

  • Generator ignores condition, produces single output for all inputs

  • Mitigation: Minibatch discrimination, unrolled GAN, spectral normalization

3. Condition Leakage

  • Generator copies condition without transformation

  • Example: In pix2pix, generator may copy input edges to output

  • Solution: Careful architecture design, regularization

4. Evaluation Difficulty

  • No single metric captures all aspects (quality, diversity, conditioning)

  • Best Practice: Combine multiple metrics + user studies

5. Computational Cost

  • High-resolution generation requires large models

  • Solutions: Progressive growing, multi-scale training, efficient architectures

Future Directions

1. Diffusion Models for Conditional Generation

  • Classifier-free guidance: \(\nabla \log p(\mathbf{x} | y) \approx \nabla \log p(\mathbf{x}) + s \cdot \nabla \log p(y | \mathbf{x})\)

  • Superior sample quality compared to GANs

  • Examples: DALL-E 2, Imagen, Stable Diffusion

2. Transformer-Based Architectures

  • Attention mechanisms for long-range dependencies

  • Examples: VQGAN + CLIP, Parti, Phenaki

3. 3D-Aware Conditional Generation

  • Generate 3D-consistent images from text/labels

  • Examples: EG3D, GET3D

4. Semantic Control

  • Fine-grained attribute editing

  • Disentangled representations

  • Compositional generation

5. Few-Shot and Zero-Shot Conditional Generation

  • Generate novel categories with few examples

  • Leverage pre-trained vision-language models (CLIP)

Conclusion

Conditional GANs extend the standard GAN framework with explicit control, enabling a wide range of applications from image-to-image translation to text-to-image synthesis. Key innovations include:

  1. Architectural: U-Net, PatchGAN, projection discriminator, SPADE

  2. Training: Spectral normalization, hinge loss, feature matching, gradient penalty

  3. Loss Functions: Cycle consistency, reconstruction loss, perceptual loss

Evolution:

  • cGAN (2014) → Pix2pix (2017) → CycleGAN (2017) → Pix2pixHD (2018) → SPADE (2019)

  • Text-to-Image: StackGAN (2017) → AttnGAN (2018) → DALL-E (2021) → Stable Diffusion (2022)

Current State: While GANs pioneered conditional generation, diffusion models now dominate due to superior sample quality and training stability. However, GANs remain relevant for:

  • Real-time generation (faster inference)

  • Specific applications (super-resolution, style transfer)

  • Hybrid approaches (combining GANs with diffusion)

The principles of conditional generation (controllability, disentanglement, semantic consistency) continue to drive innovation in generative AI.

"""
Conditional GAN - Complete Implementation

This implementation includes:
1. Basic cGAN with label conditioning
2. Pix2Pix with U-Net generator and PatchGAN discriminator
3. Projection discriminator
4. AC-GAN (Auxiliary Classifier GAN)
5. Training utilities and visualization
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional, List
from dataclasses import dataclass


# ============================================================================
# 1. BASIC CONDITIONAL GAN (LABEL CONDITIONING)
# ============================================================================

class ConditionalGenerator(nn.Module):
    """
    Basic conditional generator for class-conditional generation (e.g., MNIST).
    
    Architecture:
    - Concatenate noise z and label embedding
    - MLP with multiple hidden layers
    - Output image of specified size
    
    Args:
        latent_dim: Dimension of noise vector z
        num_classes: Number of classes
        embed_dim: Dimension of label embedding
        img_channels: Number of image channels
        img_size: Size of output image (assumes square)
    """
    def __init__(
        self,
        latent_dim: int = 100,
        num_classes: int = 10,
        embed_dim: int = 100,
        img_channels: int = 1,
        img_size: int = 28
    ):
        super().__init__()
        
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.img_channels = img_channels
        self.img_size = img_size
        
        # Label embedding
        self.label_embedding = nn.Embedding(num_classes, embed_dim)
        
        # Generator network
        input_dim = latent_dim + embed_dim
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            
            nn.Linear(1024, img_channels * img_size * img_size),
            nn.Tanh()
        )
        
    def forward(self, z: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Generate images conditioned on labels.
        
        Args:
            z: Noise vectors (batch_size, latent_dim)
            labels: Class labels (batch_size,)
            
        Returns:
            Generated images (batch_size, img_channels, img_size, img_size)
        """
        # Embed labels
        label_embed = self.label_embedding(labels)  # (batch_size, embed_dim)
        
        # Concatenate noise and label embedding
        gen_input = torch.cat([z, label_embed], dim=1)  # (batch_size, latent_dim + embed_dim)
        
        # Generate image
        img = self.model(gen_input)
        img = img.view(img.size(0), self.img_channels, self.img_size, self.img_size)
        
        return img


class ConditionalDiscriminator(nn.Module):
    """
    Basic conditional discriminator.
    
    Architecture:
    - Concatenate image and label embedding
    - MLP with multiple hidden layers
    - Output real/fake probability
    
    Args:
        num_classes: Number of classes
        embed_dim: Dimension of label embedding
        img_channels: Number of image channels
        img_size: Size of input image
    """
    def __init__(
        self,
        num_classes: int = 10,
        embed_dim: int = 100,
        img_channels: int = 1,
        img_size: int = 28
    ):
        super().__init__()
        
        self.num_classes = num_classes
        self.img_channels = img_channels
        self.img_size = img_size
        
        # Label embedding
        self.label_embedding = nn.Embedding(num_classes, embed_dim)
        
        # Discriminator network
        input_dim = img_channels * img_size * img_size + embed_dim
        self.model = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
    def forward(self, img: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Discriminate real/fake conditioned on labels.
        
        Args:
            img: Images (batch_size, img_channels, img_size, img_size)
            labels: Class labels (batch_size,)
            
        Returns:
            Real/fake probabilities (batch_size, 1)
        """
        # Flatten image
        img_flat = img.view(img.size(0), -1)  # (batch_size, img_channels * img_size * img_size)
        
        # Embed labels
        label_embed = self.label_embedding(labels)  # (batch_size, embed_dim)
        
        # Concatenate image and label
        disc_input = torch.cat([img_flat, label_embed], dim=1)
        
        # Discriminate
        validity = self.model(disc_input)
        
        return validity


# ============================================================================
# 2. PROJECTION DISCRIMINATOR
# ============================================================================

class ProjectionDiscriminator(nn.Module):
    """
    Projection discriminator for class-conditional GAN.
    
    More efficient than concatenation:
    D(x, y) = σ(w^T φ(x) + e(y)^T φ(x))
    
    Args:
        num_classes: Number of classes
        img_channels: Number of image channels
        img_size: Image size
        hidden_dim: Hidden dimension
    """
    def __init__(
        self,
        num_classes: int = 10,
        img_channels: int = 1,
        img_size: int = 28,
        hidden_dim: int = 512
    ):
        super().__init__()
        
        # Feature extractor φ(x)
        self.feature_extractor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(img_channels * img_size * img_size, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, hidden_dim),
            nn.LeakyReLU(0.2)
        )
        
        # Real/fake classifier: w^T φ(x)
        self.classifier = nn.Linear(hidden_dim, 1)
        
        # Projection head: e(y)^T φ(x)
        self.projection = nn.Embedding(num_classes, hidden_dim)
        
    def forward(self, img: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Args:
            img: Images (batch_size, img_channels, img_size, img_size)
            labels: Class labels (batch_size,)
            
        Returns:
            Logits (batch_size, 1)
        """
        # Extract features
        features = self.feature_extractor(img)  # (batch_size, hidden_dim)
        
        # Real/fake prediction
        out = self.classifier(features)  # (batch_size, 1)
        
        # Projection (conditioning)
        label_embed = self.projection(labels)  # (batch_size, hidden_dim)
        proj = torch.sum(features * label_embed, dim=1, keepdim=True)  # (batch_size, 1)
        
        # Combine into the final logit (pair with a logit-based loss such as
        # BCEWithLogitsLoss or hinge loss rather than applying a sigmoid here,
        # to match the "Logits" contract in the docstring)
        logits = out + proj
        
        return logits


# ============================================================================
# 3. AUXILIARY CLASSIFIER GAN (AC-GAN)
# ============================================================================

class ACGANGenerator(nn.Module):
    """
    AC-GAN Generator.
    
    Similar to basic cGAN but trained with auxiliary classification loss.
    """
    def __init__(
        self,
        latent_dim: int = 100,
        num_classes: int = 10,
        img_channels: int = 1,
        img_size: int = 28
    ):
        super().__init__()
        
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        
        # Label embedding
        self.label_embedding = nn.Embedding(num_classes, num_classes)
        
        # Generator
        input_dim = latent_dim + num_classes
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            
            nn.Linear(1024, img_channels * img_size * img_size),
            nn.Tanh()
        )
        
        self.img_channels = img_channels
        self.img_size = img_size
        
    def forward(self, z: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        label_embed = self.label_embedding(labels)
        gen_input = torch.cat([z, label_embed], dim=1)
        img = self.model(gen_input)
        return img.view(img.size(0), self.img_channels, self.img_size, self.img_size)


class ACGANDiscriminator(nn.Module):
    """
    AC-GAN Discriminator with auxiliary classifier.
    
    Outputs:
    - Real/fake probability
    - Class probabilities
    """
    def __init__(
        self,
        num_classes: int = 10,
        img_channels: int = 1,
        img_size: int = 28
    ):
        super().__init__()
        
        # Shared feature extractor
        self.features = nn.Sequential(
            nn.Flatten(),
            nn.Linear(img_channels * img_size * img_size, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        
        # Real/fake head
        self.adv_head = nn.Sequential(
            nn.Linear(512, 1),
            nn.Sigmoid()
        )
        
        # Auxiliary classifier head
        self.aux_head = nn.Sequential(
            nn.Linear(512, num_classes),
            nn.Softmax(dim=1)
        )
        
    def forward(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            img: Images (batch_size, img_channels, img_size, img_size)
            
        Returns:
            validity: Real/fake probabilities (batch_size, 1)
            class_probs: Class probabilities (batch_size, num_classes)
        """
        features = self.features(img)
        validity = self.adv_head(features)
        class_probs = self.aux_head(features)
        
        return validity, class_probs


# ============================================================================
# 4. PIX2PIX COMPONENTS
# ============================================================================

class UNetGenerator(nn.Module):
    """
    U-Net generator for pix2pix.
    
    Architecture:
    - Encoder: Downsample input image
    - Decoder: Upsample with skip connections from encoder
    - Output: Image of same size as input
    
    Args:
        in_channels: Input image channels
        out_channels: Output image channels
        features: Base number of features (doubled at each layer)
    """
    def __init__(self, in_channels: int = 3, out_channels: int = 3, features: int = 64):
        super().__init__()
        
        # Encoder (downsampling)
        self.enc1 = self._block(in_channels, features, normalize=False)  # 256 → 128
        self.enc2 = self._block(features, features * 2)                   # 128 → 64
        self.enc3 = self._block(features * 2, features * 4)               # 64 → 32
        self.enc4 = self._block(features * 4, features * 8)               # 32 → 16
        self.enc5 = self._block(features * 8, features * 8)               # 16 → 8
        self.enc6 = self._block(features * 8, features * 8)               # 8 → 4
        self.enc7 = self._block(features * 8, features * 8)               # 4 → 2
        
        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features * 8, features * 8, 4, 2, 1),
            nn.ReLU()
        )  # 2 → 1
        
        # Decoder (upsampling with skip connections)
        self.dec1 = self._up_block(features * 8, features * 8, dropout=True)
        self.dec2 = self._up_block(features * 16, features * 8, dropout=True)
        self.dec3 = self._up_block(features * 16, features * 8, dropout=True)
        self.dec4 = self._up_block(features * 16, features * 8)
        self.dec5 = self._up_block(features * 16, features * 4)
        self.dec6 = self._up_block(features * 8, features * 2)
        self.dec7 = self._up_block(features * 4, features)
        
        # Final layer
        self.final = nn.Sequential(
            nn.ConvTranspose2d(features * 2, out_channels, 4, 2, 1),
            nn.Tanh()
        )
        
    def _block(self, in_channels: int, out_channels: int, normalize: bool = True):
        """Encoder block: Conv β†’ BatchNorm β†’ LeakyReLU β†’ Downsample"""
        layers = [nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2))
        return nn.Sequential(*layers)
    
    def _up_block(self, in_channels: int, out_channels: int, dropout: bool = False):
        """Decoder block: ConvTranspose β†’ BatchNorm β†’ Dropout β†’ ReLU β†’ Upsample"""
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_channels)
        ]
        if dropout:
            layers.append(nn.Dropout(0.5))
        layers.append(nn.ReLU())
        return nn.Sequential(*layers)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input image (batch_size, in_channels, H, W)
            
        Returns:
            Generated image (batch_size, out_channels, H, W)
        """
        # Encoder with skip connections
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        e5 = self.enc5(e4)
        e6 = self.enc6(e5)
        e7 = self.enc7(e6)
        
        # Bottleneck
        b = self.bottleneck(e7)
        
        # Decoder with skip connections (concatenate)
        d1 = self.dec1(b)
        d1 = torch.cat([d1, e7], dim=1)
        
        d2 = self.dec2(d1)
        d2 = torch.cat([d2, e6], dim=1)
        
        d3 = self.dec3(d2)
        d3 = torch.cat([d3, e5], dim=1)
        
        d4 = self.dec4(d3)
        d4 = torch.cat([d4, e4], dim=1)
        
        d5 = self.dec5(d4)
        d5 = torch.cat([d5, e3], dim=1)
        
        d6 = self.dec6(d5)
        d6 = torch.cat([d6, e2], dim=1)
        
        d7 = self.dec7(d6)
        d7 = torch.cat([d7, e1], dim=1)
        
        # Final output
        return self.final(d7)


class PatchGANDiscriminator(nn.Module):
    """
    PatchGAN discriminator for pix2pix.
    
    Classifies whether N×N patches are real or fake (70×70 receptive field).
    
    Args:
        in_channels: Input channels (img_channels * 2 for paired images)
        features: Base number of features
    """
    def __init__(self, in_channels: int = 6, features: int = 64):
        super().__init__()
        
        self.model = nn.Sequential(
            # C64: No BatchNorm in first layer
            nn.Conv2d(in_channels, features, 4, 2, 1),
            nn.LeakyReLU(0.2),
            
            # C128
            nn.Conv2d(features, features * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 2),
            nn.LeakyReLU(0.2),
            
            # C256
            nn.Conv2d(features * 2, features * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 4),
            nn.LeakyReLU(0.2),
            
            # C512
            nn.Conv2d(features * 4, features * 8, 4, 1, 1, bias=False),
            nn.BatchNorm2d(features * 8),
            nn.LeakyReLU(0.2),
            
            # Output: 1 channel (real/fake for each patch)
            nn.Conv2d(features * 8, 1, 4, 1, 1),
            nn.Sigmoid()
        )
        
    def forward(self, img_input: torch.Tensor, img_target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            img_input: Input image (batch_size, C, H, W)
            img_target: Target image (batch_size, C, H, W)
            
        Returns:
            Patch predictions (batch_size, 1, H', W'), e.g. 30x30 for 256x256 inputs
        """
        # Concatenate input and target
        x = torch.cat([img_input, img_target], dim=1)
        return self.model(x)
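The 70×70 receptive field quoted in the docstring can be verified with a small helper (an illustrative sketch, not part of the pix2pix code) that applies the standard backward recurrence r ← r * stride + (kernel - stride) over the stack's (kernel, stride) pairs:

```python
def receptive_field(layers):
    """Receptive field of a stack of convolutions, given (kernel, stride) pairs.

    Walks backward through the stack using r = r * stride + (kernel - stride).
    """
    r = 1
    for kernel, stride in reversed(layers):
        r = r * stride + (kernel - stride)
    return r


# The PatchGAN above: three stride-2 convs, then two stride-1 convs, all 4x4
patchgan_layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
print(receptive_field(patchgan_layers))  # 70
```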


# ============================================================================
# 5. TRAINING UTILITIES
# ============================================================================

def train_cgan_step(
    generator: nn.Module,
    discriminator: nn.Module,
    real_imgs: torch.Tensor,
    labels: torch.Tensor,
    optimizer_g: torch.optim.Optimizer,
    optimizer_d: torch.optim.Optimizer,
    latent_dim: int = 100,
    device: str = 'cpu'
) -> Tuple[float, float]:
    """
    Single training step for basic conditional GAN.
    
    Returns:
        d_loss: Discriminator loss
        g_loss: Generator loss
    """
    batch_size = real_imgs.size(0)
    
    # Real and fake labels
    real_labels = torch.ones(batch_size, 1, device=device)
    fake_labels = torch.zeros(batch_size, 1, device=device)
    
    # ---------------------
    # Train Discriminator
    # ---------------------
    optimizer_d.zero_grad()
    
    # Real images
    real_validity = discriminator(real_imgs, labels)
    d_real_loss = F.binary_cross_entropy(real_validity, real_labels)
    
    # Fake images
    z = torch.randn(batch_size, latent_dim, device=device)
    fake_imgs = generator(z, labels)
    fake_validity = discriminator(fake_imgs.detach(), labels)
    d_fake_loss = F.binary_cross_entropy(fake_validity, fake_labels)
    
    # Total discriminator loss
    d_loss = (d_real_loss + d_fake_loss) / 2
    d_loss.backward()
    optimizer_d.step()
    
    # -----------------
    # Train Generator
    # -----------------
    optimizer_g.zero_grad()
    
    # Generate fake images and fool discriminator
    z = torch.randn(batch_size, latent_dim, device=device)
    fake_imgs = generator(z, labels)
    fake_validity = discriminator(fake_imgs, labels)
    
    g_loss = F.binary_cross_entropy(fake_validity, real_labels)
    g_loss.backward()
    optimizer_g.step()
    
    return d_loss.item(), g_loss.item()


def train_acgan_step(
    generator: nn.Module,
    discriminator: nn.Module,
    real_imgs: torch.Tensor,
    labels: torch.Tensor,
    optimizer_g: torch.optim.Optimizer,
    optimizer_d: torch.optim.Optimizer,
    latent_dim: int = 100,
    device: str = 'cpu'
) -> Tuple[float, float]:
    """
    Single training step for AC-GAN.
    
    Returns:
        d_loss: Discriminator loss (adversarial + classification)
        g_loss: Generator loss (adversarial + classification)
    """
    batch_size = real_imgs.size(0)
    
    real_labels = torch.ones(batch_size, 1, device=device)
    fake_labels = torch.zeros(batch_size, 1, device=device)
    
    # ---------------------
    # Train Discriminator
    # ---------------------
    optimizer_d.zero_grad()
    
    # Real images
    real_validity, real_class = discriminator(real_imgs)
    d_real_loss = F.binary_cross_entropy(real_validity, real_labels)
    d_real_cls_loss = F.nll_loss(torch.log(real_class + 1e-8), labels)  # class head outputs probabilities, not logits
    
    # Fake images
    z = torch.randn(batch_size, latent_dim, device=device)
    fake_imgs = generator(z, labels)
    fake_validity, fake_class = discriminator(fake_imgs.detach())
    d_fake_loss = F.binary_cross_entropy(fake_validity, fake_labels)
    d_fake_cls_loss = F.nll_loss(torch.log(fake_class + 1e-8), labels)  # class head outputs probabilities, not logits
    
    # Total discriminator loss
    d_loss = (d_real_loss + d_fake_loss) / 2 + (d_real_cls_loss + d_fake_cls_loss) / 2
    d_loss.backward()
    optimizer_d.step()
    
    # -----------------
    # Train Generator
    # -----------------
    optimizer_g.zero_grad()
    
    z = torch.randn(batch_size, latent_dim, device=device)
    fake_imgs = generator(z, labels)
    fake_validity, fake_class = discriminator(fake_imgs)
    
    # Adversarial loss plus auxiliary classification loss (class head outputs probabilities, not logits)
    g_loss = F.binary_cross_entropy(fake_validity, real_labels) + F.nll_loss(torch.log(fake_class + 1e-8), labels)
    g_loss.backward()
    optimizer_g.step()
    
    return d_loss.item(), g_loss.item()


def train_pix2pix_step(
    generator: nn.Module,
    discriminator: nn.Module,
    input_imgs: torch.Tensor,
    target_imgs: torch.Tensor,
    optimizer_g: torch.optim.Optimizer,
    optimizer_d: torch.optim.Optimizer,
    lambda_l1: float = 100.0,
    device: str = 'cpu'
) -> Tuple[float, float]:
    """
    Single training step for pix2pix.
    
    Args:
        lambda_l1: Weight for L1 reconstruction loss
        
    Returns:
        d_loss: Discriminator loss
        g_loss: Generator loss (adversarial + L1)
    """
    # ---------------------
    # Train Discriminator
    # ---------------------
    optimizer_d.zero_grad()
    
    # Real pair (targets shaped like the PatchGAN's spatial output map)
    real_validity = discriminator(input_imgs, target_imgs)
    d_real_loss = F.binary_cross_entropy(real_validity, torch.ones_like(real_validity))
    
    # Fake pair
    fake_imgs = generator(input_imgs)
    fake_validity = discriminator(input_imgs, fake_imgs.detach())
    d_fake_loss = F.binary_cross_entropy(fake_validity, torch.zeros_like(fake_validity))
    
    d_loss = (d_real_loss + d_fake_loss) / 2
    d_loss.backward()
    optimizer_d.step()
    
    # -----------------
    # Train Generator
    # -----------------
    optimizer_g.zero_grad()
    
    fake_imgs = generator(input_imgs)
    fake_validity = discriminator(input_imgs, fake_imgs)
    
    # Adversarial loss
    g_adv_loss = F.binary_cross_entropy(fake_validity, torch.ones_like(fake_validity))
    
    # L1 reconstruction loss
    g_l1_loss = F.l1_loss(fake_imgs, target_imgs)
    
    # Total generator loss
    g_loss = g_adv_loss + lambda_l1 * g_l1_loss
    g_loss.backward()
    optimizer_g.step()
    
    return d_loss.item(), g_loss.item()


# ============================================================================
# 6. VISUALIZATION
# ============================================================================

def visualize_conditional_generation(
    generator: nn.Module,
    num_classes: int,
    latent_dim: int,
    device: str = 'cpu',
    samples_per_class: int = 5
):
    """
    Visualize conditional generation for all classes.
    
    Args:
        generator: Trained generator
        num_classes: Number of classes
        latent_dim: Latent dimension
        device: Device
        samples_per_class: Samples to generate per class
    """
    generator.eval()
    
    fig, axes = plt.subplots(num_classes, samples_per_class, figsize=(samples_per_class * 2, num_classes * 2))
    
    with torch.no_grad():
        for class_idx in range(num_classes):
            # Generate samples for this class
            z = torch.randn(samples_per_class, latent_dim, device=device)
            labels = torch.full((samples_per_class,), class_idx, dtype=torch.long, device=device)
            
            fake_imgs = generator(z, labels)
            fake_imgs = fake_imgs.cpu()
            
            # Plot samples
            for sample_idx in range(samples_per_class):
                ax = axes[class_idx, sample_idx] if num_classes > 1 else axes[sample_idx]
                img = fake_imgs[sample_idx].squeeze()
                img = (img + 1) / 2  # Denormalize from [-1, 1] to [0, 1]
                
                if img.dim() == 3:  # Color image: channels-last for imshow
                    img = img.permute(1, 2, 0)
                
                ax.imshow(img, cmap='gray' if img.dim() == 2 else None)
                ax.axis('off')
                
                if sample_idx == 0:
                    ax.set_title(f'Class {class_idx}', fontsize=10)
    
    plt.tight_layout()
    plt.suptitle('Conditional Generation (All Classes)', y=1.01, fontsize=14, fontweight='bold')
    plt.show()


def visualize_pix2pix_results(
    generator: nn.Module,
    input_imgs: torch.Tensor,
    target_imgs: torch.Tensor,
    num_samples: int = 4
):
    """
    Visualize pix2pix results (input β†’ generated β†’ target).
    
    Args:
        generator: Trained generator
        input_imgs: Input images (batch_size, C, H, W)
        target_imgs: Target images (batch_size, C, H, W)
        num_samples: Number of samples to visualize
    """
    generator.eval()
    
    with torch.no_grad():
        generated_imgs = generator(input_imgs[:num_samples])
    
    # Move to CPU and denormalize
    input_imgs = input_imgs[:num_samples].cpu()
    generated_imgs = generated_imgs.cpu()
    target_imgs = target_imgs[:num_samples].cpu()
    
    # Denormalize from [-1, 1] to [0, 1]
    input_imgs = (input_imgs + 1) / 2
    generated_imgs = (generated_imgs + 1) / 2
    target_imgs = (target_imgs + 1) / 2
    
    fig, axes = plt.subplots(num_samples, 3, figsize=(9, num_samples * 3))
    
    for i in range(num_samples):
        # Input
        axes[i, 0].imshow(input_imgs[i].permute(1, 2, 0))
        axes[i, 0].set_title('Input')
        axes[i, 0].axis('off')
        
        # Generated
        axes[i, 1].imshow(generated_imgs[i].permute(1, 2, 0))
        axes[i, 1].set_title('Generated')
        axes[i, 1].axis('off')
        
        # Target
        axes[i, 2].imshow(target_imgs[i].permute(1, 2, 0))
        axes[i, 2].set_title('Target')
        axes[i, 2].axis('off')
    
    plt.tight_layout()
    plt.show()


# ============================================================================
# 7. DEMONSTRATION
# ============================================================================

def demo_conditional_gan():
    """Demonstrate conditional GAN components."""
    
    print("=" * 80)
    print("Conditional GAN Demonstration")
    print("=" * 80)
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"\nDevice: {device}")
    
    # 1. Basic cGAN
    print("\n1. Basic Conditional GAN (Label Conditioning)")
    print("-" * 40)
    
    latent_dim = 100
    num_classes = 10
    img_size = 28
    
    generator = ConditionalGenerator(
        latent_dim=latent_dim,
        num_classes=num_classes,
        img_size=img_size
    ).to(device)
    
    discriminator = ConditionalDiscriminator(
        num_classes=num_classes,
        img_size=img_size
    ).to(device)
    
    print(f"Generator parameters: {sum(p.numel() for p in generator.parameters()):,}")
    print(f"Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}")
    
    # Test forward pass
    z = torch.randn(4, latent_dim, device=device)
    labels = torch.tensor([0, 1, 2, 3], device=device)
    
    fake_imgs = generator(z, labels)
    validity = discriminator(fake_imgs, labels)
    
    print(f"\nGenerated images shape: {fake_imgs.shape}")
    print(f"Discriminator output shape: {validity.shape}")
    print(f"Discriminator output range: [{validity.min():.3f}, {validity.max():.3f}]")
    
    # 2. Projection Discriminator
    print("\n2. Projection Discriminator")
    print("-" * 40)
    
    proj_disc = ProjectionDiscriminator(num_classes=num_classes).to(device)
    
    print(f"Projection discriminator parameters: {sum(p.numel() for p in proj_disc.parameters()):,}")
    
    validity_proj = proj_disc(fake_imgs, labels)
    print(f"Projection discriminator output shape: {validity_proj.shape}")
    
    # 3. AC-GAN
    print("\n3. Auxiliary Classifier GAN (AC-GAN)")
    print("-" * 40)
    
    acgan_gen = ACGANGenerator(latent_dim=latent_dim, num_classes=num_classes).to(device)
    acgan_disc = ACGANDiscriminator(num_classes=num_classes).to(device)
    
    print(f"AC-GAN Generator parameters: {sum(p.numel() for p in acgan_gen.parameters()):,}")
    print(f"AC-GAN Discriminator parameters: {sum(p.numel() for p in acgan_disc.parameters()):,}")
    
    fake_imgs_ac = acgan_gen(z, labels)
    validity_ac, class_probs_ac = acgan_disc(fake_imgs_ac)
    
    print(f"\nAC-GAN validity shape: {validity_ac.shape}")
    print(f"AC-GAN class probabilities shape: {class_probs_ac.shape}")
    print(f"Class probabilities sum: {class_probs_ac.sum(dim=1)}")
    
    # 4. Pix2Pix (U-Net + PatchGAN)
    print("\n4. Pix2Pix (Image-to-Image Translation)")
    print("-" * 40)
    
    unet_gen = UNetGenerator(in_channels=3, out_channels=3, features=64).to(device)
    patch_disc = PatchGANDiscriminator(in_channels=6, features=64).to(device)
    
    print(f"U-Net Generator parameters: {sum(p.numel() for p in unet_gen.parameters()):,}")
    print(f"PatchGAN Discriminator parameters: {sum(p.numel() for p in patch_disc.parameters()):,}")
    
    # Test with larger images
    input_img = torch.randn(2, 3, 256, 256, device=device)
    generated_img = unet_gen(input_img)
    patch_validity = patch_disc(input_img, generated_img)
    
    print(f"\nInput image shape: {input_img.shape}")
    print(f"Generated image shape: {generated_img.shape}")
    print(f"Patch predictions shape: {patch_validity.shape}")
    print(f"Number of patches: {patch_validity.numel() // patch_validity.size(0)}")


def demo_training_comparison():
    """Compare different conditioning strategies."""
    
    print("\n" + "=" * 80)
    print("Conditioning Strategy Comparison")
    print("=" * 80)
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Create sample data
    batch_size = 16
    latent_dim = 100
    num_classes = 10
    img_size = 28
    
    # 1. Concatenation (Basic cGAN)
    print("\n1. Concatenation Strategy")
    print("-" * 40)
    
    gen_concat = ConditionalGenerator(latent_dim, num_classes).to(device)
    disc_concat = ConditionalDiscriminator(num_classes).to(device)
    
    z = torch.randn(batch_size, latent_dim, device=device)
    labels = torch.randint(0, num_classes, (batch_size,), device=device)
    real_imgs = torch.randn(batch_size, 1, img_size, img_size, device=device)
    
    fake_imgs = gen_concat(z, labels)
    validity = disc_concat(fake_imgs, labels)
    
    print(f"Generated images: {fake_imgs.shape}")
    print(f"Validity: {validity.shape}")
    
    # 2. Projection
    print("\n2. Projection Strategy")
    print("-" * 40)
    
    disc_proj = ProjectionDiscriminator(num_classes).to(device)
    validity_proj = disc_proj(fake_imgs, labels)
    
    print(f"Projection validity: {validity_proj.shape}")
    
    # Parameter comparison
    print("\n3. Parameter Efficiency")
    print("-" * 40)
    
    concat_params = sum(p.numel() for p in disc_concat.parameters())
    proj_params = sum(p.numel() for p in disc_proj.parameters())
    
    print(f"Concatenation discriminator: {concat_params:,} parameters")
    print(f"Projection discriminator: {proj_params:,} parameters")
    print(f"Reduction: {(1 - proj_params / concat_params) * 100:.1f}%")


# Run demonstrations
if __name__ == "__main__":
    print("Starting Conditional GAN demonstrations...\n")
    
    # Main demonstration
    demo_conditional_gan()
    
    # Training comparison
    demo_training_comparison()
    
    print("\n" + "=" * 80)
    print("Conditional GAN Implementation Complete!")
    print("=" * 80)
    print("\nKey Components Implemented:")
    print("βœ“ Basic cGAN with label conditioning")
    print("βœ“ Projection discriminator (efficient conditioning)")
    print("βœ“ AC-GAN (auxiliary classifier)")
    print("βœ“ U-Net generator for image-to-image translation")
    print("βœ“ PatchGAN discriminator")
    print("βœ“ Training utilities for all variants")
    print("βœ“ Visualization functions")
    print("\nExtensions to explore:")
    print("- CycleGAN for unpaired translation")
    print("- StarGAN for multi-domain translation")
    print("- SPADE for semantic image synthesis")
    print("- Text-to-image with attention (AttnGAN)")
    print("- High-resolution synthesis (Pix2PixHD)")
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import seaborn as sns

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
sns.set_style('whitegrid')

1. Motivation: Controlled Generation¶

Standard GAN Limitation¶

Vanilla GAN:

\[\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]\]

Problem: No control over what gets generated!

  • Random samples from learned distribution

  • Cannot specify desired attributes

Conditional GAN Solution (Mirza & Osindero, 2014)¶

Add conditioning variable \(y\) (class label, attributes, etc.):

\[\min_G \max_D \mathbb{E}_{x, y}[\log D(x, y)] + \mathbb{E}_{z, y}[\log(1 - D(G(z, y), y))]\]

Key changes:

  • Generator: \(G(z, y)\) produces \(x\) conditioned on \(y\)

  • Discriminator: \(D(x, y)\) checks if \(x\) matches condition \(y\)

Applications¶

  1. Class-conditional: Generate specific digit/object

  2. Text-to-image: Generate image from description

  3. Image-to-image: Style transfer, super-resolution

  4. Attribute control: Age, expression, hair color

2. Architecture Design¶

Conditioning Methods¶

1. Concatenation:

  • Embed \(y\) and concatenate with \(z\) or \(x\)

  • Simple, widely used

2. Conditional Batch Normalization:

  • Modulate BN parameters based on \(y\)

  • More powerful for complex conditioning

3. Projection:

  • Project \(y\) embedding into discriminator layers

  • Better gradient flow
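Concatenation and projection both appear in code later; as a minimal sketch of the second method, conditional batch normalization (class and parameter names here are illustrative, not from this notebook), the idea is to replace BN's single affine pair with one learned (gamma, beta) pair per class:

```python
import torch
import torch.nn as nn

class ConditionalBatchNorm2d(nn.Module):
    """Batch norm whose affine parameters (gamma, beta) are selected by class label."""
    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        # Normalize without BN's own affine parameters
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        # One (gamma, beta) pair per class, stored side by side
        self.embed = nn.Embedding(num_classes, num_features * 2)
        self.embed.weight.data[:, :num_features].fill_(1.0)  # gamma init = 1
        self.embed.weight.data[:, num_features:].zero_()     # beta init = 0

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        out = self.bn(x)
        gamma, beta = self.embed(y).chunk(2, dim=1)
        # Broadcast per-sample (gamma, beta) over spatial dimensions
        return gamma.view(-1, gamma.size(1), 1, 1) * out + beta.view(-1, beta.size(1), 1, 1)
```

Each class thus gets its own scale and shift of the normalized features, which is the mechanism used by several class-conditional image generators.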

Class-Conditional Architecture¶

For one-hot encoded labels \(y \in \{0, 1\}^C\):

Generator:

\[h = [z; \text{Embed}(y)], \qquad x = G(h)\]

Discriminator:

\[h = [x; \text{Embed}(y)], \qquad \text{score} = D(h)\]

Conditional GAN for MNIST¶

A Conditional GAN (cGAN) augments both the generator and discriminator with class label information, enabling controlled generation of specific digit classes. The generator receives a noise vector \(z\) concatenated with an embedding of the label \(y\), and learns the conditional distribution \(p(x \mid y)\). The discriminator likewise receives the label as additional input and must judge whether an (image, label) pair is real or fabricated. Conditioning on labels turns the GAN from an unconditional density estimator into a tool for class-specific synthesis, the foundation for applications such as text-to-image generation and data augmentation.

class ConditionalGenerator(nn.Module):
    """Generator conditioned on class labels."""
    def __init__(self, latent_dim=100, n_classes=10, img_size=28):
        super().__init__()
        self.latent_dim = latent_dim
        self.n_classes = n_classes
        self.img_size = img_size
        
        # Label embedding
        self.label_emb = nn.Embedding(n_classes, n_classes)
        
        # Generator network
        self.model = nn.Sequential(
            nn.Linear(latent_dim + n_classes, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            
            nn.Linear(1024, img_size * img_size),
            nn.Tanh()
        )
    
    def forward(self, z, labels):
        # Embed labels
        label_input = self.label_emb(labels)
        
        # Concatenate noise and labels
        gen_input = torch.cat([z, label_input], dim=1)
        
        # Generate image
        img = self.model(gen_input)
        img = img.view(img.size(0), 1, self.img_size, self.img_size)
        return img

class ConditionalDiscriminator(nn.Module):
    """Discriminator conditioned on class labels."""
    def __init__(self, n_classes=10, img_size=28):
        super().__init__()
        self.n_classes = n_classes
        self.img_size = img_size
        
        # Label embedding
        self.label_emb = nn.Embedding(n_classes, n_classes)
        
        # Discriminator network
        self.model = nn.Sequential(
            nn.Linear(img_size * img_size + n_classes, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, img, labels):
        # Flatten image
        img_flat = img.view(img.size(0), -1)
        
        # Embed labels
        label_input = self.label_emb(labels)
        
        # Concatenate image and labels
        disc_input = torch.cat([img_flat, label_input], dim=1)
        
        # Discriminate
        validity = self.model(disc_input)
        return validity

# Initialize models
latent_dim = 100
n_classes = 10

generator = ConditionalGenerator(latent_dim, n_classes).to(device)
discriminator = ConditionalDiscriminator(n_classes).to(device)

print(f"Generator parameters: {sum(p.numel() for p in generator.parameters()):,}")
print(f"Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}")

Training Conditional GAN¶

The training loop alternates between discriminator and generator updates, following the standard GAN protocol but with labels passed to both networks. The discriminator loss encourages it to accept real (image, label) pairs and reject fake ones, while the generator loss pushes it to produce images that the discriminator accepts for a given label. Tracking per-class generation quality during training (e.g., via FID or visual inspection per digit) helps ensure the model does not mode-collapse to only generating a subset of classes. Label smoothing and spectral normalization are common stabilization techniques used in practice.
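The two stabilization tricks mentioned above can be sketched in a few lines (the smoothing value 0.9 and layer sizes are common choices for illustration, not prescribed by this notebook):

```python
import torch
import torch.nn as nn

batch_size = 8

# One-sided label smoothing: soften only the "real" targets
real_targets = torch.full((batch_size, 1), 0.9)  # instead of 1.0
fake_targets = torch.zeros(batch_size, 1)        # fake targets stay at 0

# Spectral normalization: constrain a discriminator layer's Lipschitz constant
layer = nn.utils.spectral_norm(nn.Linear(512, 1))
scores = layer(torch.randn(batch_size, 512))
print(scores.shape)  # torch.Size([8, 1])
```

Both would slot into the discriminator update below: smoothed targets replace `real_labels`, and `spectral_norm` wraps the discriminator's linear layers at construction time.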

# Load MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # Scale to [-1, 1]
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=True)

# Optimizers
lr = 0.0002
beta1 = 0.5

optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

# Loss
adversarial_loss = nn.BCELoss()

# Training
n_epochs = 30
history = {'D_loss': [], 'G_loss': [], 'D_real': [], 'D_fake': []}

for epoch in range(n_epochs):
    epoch_D_loss = 0
    epoch_G_loss = 0
    epoch_D_real = 0
    epoch_D_fake = 0
    
    for i, (imgs, labels) in enumerate(train_loader):
        batch_size = imgs.size(0)
        
        # Labels
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)
        
        imgs = imgs.to(device)
        labels = labels.to(device)
        
        # -----------------
        # Train Discriminator
        # -----------------
        optimizer_D.zero_grad()
        
        # Real images
        real_validity = discriminator(imgs, labels)
        d_real_loss = adversarial_loss(real_validity, real_labels)
        
        # Fake images
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = generator(z, labels)
        fake_validity = discriminator(fake_imgs.detach(), labels)
        d_fake_loss = adversarial_loss(fake_validity, fake_labels)
        
        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()
        
        # -----------------
        # Train Generator
        # -----------------
        optimizer_G.zero_grad()
        
        # Generate images
        z = torch.randn(batch_size, latent_dim).to(device)
        gen_imgs = generator(z, labels)
        
        # Generator loss
        validity = discriminator(gen_imgs, labels)
        g_loss = adversarial_loss(validity, real_labels)
        
        g_loss.backward()
        optimizer_G.step()
        
        # Statistics
        epoch_D_loss += d_loss.item()
        epoch_G_loss += g_loss.item()
        epoch_D_real += real_validity.mean().item()
        epoch_D_fake += fake_validity.mean().item()
    
    # Record epoch stats
    n_batches = len(train_loader)
    history['D_loss'].append(epoch_D_loss / n_batches)
    history['G_loss'].append(epoch_G_loss / n_batches)
    history['D_real'].append(epoch_D_real / n_batches)
    history['D_fake'].append(epoch_D_fake / n_batches)
    
    print(f"Epoch [{epoch+1}/{n_epochs}] "
          f"D_loss: {history['D_loss'][-1]:.4f} "
          f"G_loss: {history['G_loss'][-1]:.4f} "
          f"D(real): {history['D_real'][-1]:.4f} "
          f"D(fake): {history['D_fake'][-1]:.4f}")

print("\nTraining complete!")

# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].plot(history['D_loss'], label='Discriminator', linewidth=2)
axes[0].plot(history['G_loss'], label='Generator', linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training Loss', fontsize=13)
axes[0].legend()
axes[0].grid(True, alpha=0.3)

axes[1].plot(history['D_real'], label='D(real)', linewidth=2)
axes[1].plot(history['D_fake'], label='D(fake)', linewidth=2)
axes[1].axhline(y=0.5, color='r', linestyle='--', label='Equilibrium')
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Discriminator Output', fontsize=12)
axes[1].set_title('Discriminator Performance', fontsize=13)
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

Class-Conditional GenerationΒΆ

With a trained cGAN, generating images for a specific class is as simple as fixing the label input and sampling different noise vectors. Generating a grid of samples per class provides a quick visual assessment of generation quality and diversity. A well-trained cGAN should produce distinct, recognizable digits for each class with meaningful intra-class variation driven by the noise vector. This capability makes cGANs directly applicable to data augmentation (generating more samples for underrepresented classes) and controllable content creation.

generator.eval()

# Generate specific digits
with torch.no_grad():
    # Generate one sample for each class (0-9)
    z = torch.randn(10, latent_dim).to(device)
    labels = torch.arange(0, 10).to(device)
    
    generated = generator(z, labels).cpu()
    
    fig, axes = plt.subplots(1, 10, figsize=(15, 2))
    for i in range(10):
        img = generated[i].squeeze()
        img = (img + 1) / 2  # Denormalize from [-1,1] to [0,1]
        axes[i].imshow(img, cmap='gray')
        axes[i].set_title(f'Class {i}', fontsize=11)
        axes[i].axis('off')
    plt.suptitle('Class-Conditional Generation', fontsize=14)
    plt.tight_layout()
    plt.show()

# Generate multiple samples per class
print("\nMultiple samples per class:")
with torch.no_grad():
    n_samples = 5
    fig, axes = plt.subplots(10, n_samples, figsize=(12, 20))
    
    for digit in range(10):
        z = torch.randn(n_samples, latent_dim).to(device)
        labels = torch.full((n_samples,), digit, dtype=torch.long).to(device)
        
        samples = generator(z, labels).cpu()
        
        for i in range(n_samples):
            img = samples[i].squeeze()
            img = (img + 1) / 2
            axes[digit, i].imshow(img, cmap='gray')
            if i == 0:
                # Keep the first column's axis alive so the row label renders;
                # axis('off') would hide set_ylabel along with the ticks
                axes[digit, i].set_xticks([])
                axes[digit, i].set_yticks([])
                axes[digit, i].set_ylabel(f'{digit}', fontsize=13, rotation=0, labelpad=20)
            else:
                axes[digit, i].axis('off')
    
    plt.suptitle('Diversity within Each Class', fontsize=14)
    plt.tight_layout()
    plt.show()

Latent Space Interpolation with Fixed ClassΒΆ

By fixing the class label and smoothly interpolating the noise vector \(z\) between two random endpoints, we can visualize the diversity captured within a single class. Smooth transitions indicate that the generator has learned a continuous mapping from latent space to image space, without sharp discontinuities or mode collapse. Comparing interpolations across different classes also reveals whether the model allocates similar latent capacity to each digit, which is a useful diagnostic for balanced generation quality.

# Interpolate between two latent codes for same class
with torch.no_grad():
    digit = 7
    n_steps = 10
    
    # Two random starting points
    z1 = torch.randn(1, latent_dim).to(device)
    z2 = torch.randn(1, latent_dim).to(device)
    
    # Interpolation
    alphas = np.linspace(0, 1, n_steps)
    
    fig, axes = plt.subplots(1, n_steps, figsize=(15, 2))
    
    for i, alpha in enumerate(alphas):
        z_interp = (1 - alpha) * z1 + alpha * z2
        label = torch.tensor([digit]).to(device)
        
        img = generator(z_interp, label).cpu().squeeze()
        img = (img + 1) / 2
        
        axes[i].imshow(img, cmap='gray')
        axes[i].axis('off')
    
    plt.suptitle(f'Latent Interpolation for Digit {digit}', fontsize=14)
    plt.tight_layout()
    plt.show()

# Fixed noise, vary class
print("\nFixed noise, varying class:")
with torch.no_grad():
    z_fixed = torch.randn(1, latent_dim).to(device)
    
    fig, axes = plt.subplots(1, 10, figsize=(15, 2))
    
    for digit in range(10):
        label = torch.tensor([digit]).to(device)
        img = generator(z_fixed, label).cpu().squeeze()
        img = (img + 1) / 2
        
        axes[digit].imshow(img, cmap='gray')
        axes[digit].set_title(f'{digit}', fontsize=11)
        axes[digit].axis('off')
    
    plt.suptitle('Same Noise, Different Classes', fontsize=14)
    plt.tight_layout()
    plt.show()

print("\nObservation: with z fixed, style traits (stroke width, slant) stay consistent while the label changes the digit identity")

SummaryΒΆ

Key Innovations:ΒΆ

  1. Conditional generation: Control output via conditioning variable \(y\)

  2. Flexible conditioning: Class labels, attributes, images, text

  3. Discriminator conditioning: Ensures generated samples match condition

  4. Disentanglement: Separate noise (style) from condition (content)

Architecture Patterns:ΒΆ

Concatenation:

  • Generator: \(G([z; \text{emb}(y)])\)

  • Discriminator: \(D([x; \text{emb}(y)])\)
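The concatenation pattern above can be sketched as a pair of minimal MLP modules. This is an illustrative skeleton, not the exact classes defined earlier in this notebook; dimensions assume a 10-class problem with 28Γ—28 images.

```python
import torch
import torch.nn as nn

class ConcatGenerator(nn.Module):
    """G([z; emb(y)]): condition by concatenating noise with a label embedding."""
    def __init__(self, latent_dim=100, n_classes=10, img_dim=28 * 28):
        super().__init__()
        self.label_emb = nn.Embedding(n_classes, n_classes)  # emb(y)
        self.net = nn.Sequential(
            nn.Linear(latent_dim + n_classes, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, img_dim),
            nn.Tanh(),  # outputs in [-1, 1], matching the data normalization
        )

    def forward(self, z, y):
        return self.net(torch.cat([z, self.label_emb(y)], dim=1))

class ConcatDiscriminator(nn.Module):
    """D([x; emb(y)]): condition by concatenating the flattened image with emb(y)."""
    def __init__(self, n_classes=10, img_dim=28 * 28):
        super().__init__()
        self.label_emb = nn.Embedding(n_classes, n_classes)
        self.net = nn.Sequential(
            nn.Linear(img_dim + n_classes, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # probability that x is real given y
        )

    def forward(self, x, y):
        x_flat = x.view(x.size(0), -1)
        return self.net(torch.cat([x_flat, self.label_emb(y)], dim=1))
```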

Projection (Miyato & Koyama, 2018): \(D(x, y) = \psi(\phi(x)) + y^T V \phi(x)\), where \(\phi\) is a shared feature extractor, \(\psi\) an unconditional scoring head, and \(V\) a learned class-embedding matrix
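The projection discriminator replaces concatenation with an inner product between a class embedding and the image features. Below is a minimal sketch under stated assumptions: `phi` stands in for the real convolutional backbone, and all dimensions are illustrative.

```python
import torch
import torch.nn as nn

class ProjectionDiscriminator(nn.Module):
    """Sketch of the projection discriminator (Miyato & Koyama, 2018)."""
    def __init__(self, n_classes=10, img_dim=28 * 28, feat_dim=128):
        super().__init__()
        self.phi = nn.Sequential(          # feature extractor phi(x) (toy MLP here)
            nn.Linear(img_dim, feat_dim),
            nn.LeakyReLU(0.2),
        )
        self.psi = nn.Linear(feat_dim, 1)  # unconditional head psi(phi(x))
        self.embed = nn.Embedding(n_classes, feat_dim)  # rows of V, one per class

    def forward(self, x, y):
        h = self.phi(x.view(x.size(0), -1))
        # logit = psi(phi(x)) + <V_y, phi(x)>  -- the projection term
        return self.psi(h) + (self.embed(y) * h).sum(dim=1, keepdim=True)
```

Note the output is an unbounded logit (no sigmoid); projection discriminators are typically trained with hinge or non-saturating losses rather than BCE on probabilities.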

Advanced Variants:ΒΆ

  • AC-GAN (Odena et al., 2017): Auxiliary classifier in discriminator

  • Pix2Pix (Isola et al., 2017): Image-to-image translation

  • StarGAN (Choi et al., 2018): Multi-domain translation

  • BigGAN (Brock et al., 2019): Class-conditional at scale

Applications:ΒΆ

  1. Controllable generation: Specify what to generate

  2. Data augmentation: Generate labeled samples

  3. Transfer learning: Condition on domains

  4. Creative tools: Interactive generation

Training Tips:ΒΆ

  • Balance discriminator and generator updates

  • Use label smoothing (soft labels)

  • Spectral normalization for stability

  • Progressive growing for high resolution

Next Steps:ΒΆ

  • 08_stylegan.ipynb - Style-based generation

  • 09_pix2pix.ipynb - Image-to-image translation

  • 01_gan_mathematics.ipynb - GAN theory review