InfoGAN: Information-Maximizing Generative Adversarial Networks - Comprehensive Theory

Introduction

InfoGAN is an unsupervised learning method that extends the standard GAN framework to learn disentangled and interpretable representations. By maximizing mutual information between a subset of latent variables and generated samples, InfoGAN discovers meaningful semantic features (e.g., digit identity, rotation, width) without requiring labeled data.

Key Innovation: Automatic discovery of interpretable latent codes through information-theoretic regularization.

Applications:

  • Unsupervised disentanglement learning

  • Controllable generation (manipulate specific attributes)

  • Feature extraction

  • Semi-supervised learning

Background: Disentangled Representations

What is Disentanglement?

A disentangled representation factorizes the underlying explanatory factors of variation in data: \[\mathbf{z} = (\mathbf{z}_1, \mathbf{z}_2, \ldots, \mathbf{z}_K)\]

Where each \(\mathbf{z}_i\) controls a distinct, interpretable property (e.g., shape, color, size).

Example (MNIST):

  • \(c_1\): Digit identity (0-9)

  • \(c_2\): Rotation angle

  • \(c_3\): Stroke thickness

Benefits:

  1. Interpretability: Each dimension has semantic meaning

  2. Transfer Learning: Disentangled features generalize better

  3. Controllability: Manipulate specific properties independently

  4. Sample Efficiency: Simpler representations require less data

Why Standard GANs Fail at Disentanglement

Problem: In a vanilla GAN, the generator input \(\mathbf{z} \sim p(\mathbf{z})\) is typically unstructured noise: \[\mathbf{z} \sim \mathcal{N}(0, I)\]

All dimensions are treated equally; nothing in the objective encourages interpretability.

Result: Generator learns entangled representations where individual dimensions don’t correspond to meaningful factors.

Information Theory Primer

Entropy

Shannon entropy measures the uncertainty in a random variable: \[H(X) = -\mathbb{E}_{x \sim p(x)}[\log p(x)] = -\sum_x p(x) \log p(x)\]

Interpretation:

  • High entropy → high uncertainty (uniform distribution)

  • Low entropy → low uncertainty (peaked distribution)
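These two regimes are easy to check numerically (a minimal NumPy sketch; the distributions are illustrative):

```python
import numpy as np

def entropy(p: np.ndarray) -> float:
    """Shannon entropy H(X) = -sum_x p(x) log p(x), in nats."""
    p = p[p > 0]  # 0 log 0 is taken as 0
    return float(-np.sum(p * np.log(p)))

uniform = np.array([0.25, 0.25, 0.25, 0.25])  # maximal uncertainty
peaked = np.array([0.97, 0.01, 0.01, 0.01])   # nearly deterministic

print(entropy(uniform))  # log(4) ≈ 1.386
print(entropy(peaked))   # ≈ 0.168
```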

Mutual Information

Definition: Mutual information \(I(X; Y)\) measures how much knowing \(X\) reduces uncertainty about \(Y\):

\[I(X; Y) = H(X) - H(X | Y) = H(Y) - H(Y | X)\]
\[I(X; Y) = \mathbb{E}_{x,y \sim p(x,y)}\left[\log \frac{p(x,y)}{p(x)p(y)}\right] = D_{KL}(p(x,y) \| p(x)p(y))\]

Properties:

  1. Non-negative: \(I(X; Y) \geq 0\)

  2. Symmetric: \(I(X; Y) = I(Y; X)\)

  3. Zero iff independent: \(I(X; Y) = 0 \iff p(x,y) = p(x)p(y)\)

  4. Upper bound: \(I(X; Y) \leq \min(H(X), H(Y))\)

Intuition:

  • \(I(X; Y) = 0\): \(X\) and \(Y\) are independent (knowing \(X\) tells nothing about \(Y\))

  • \(I(X; Y)\) large: \(X\) and \(Y\) are highly dependent
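Both extremes can be verified directly from a joint probability table (a minimal NumPy sketch; `mutual_information` is a helper defined here, not a library function):

```python
import numpy as np

def mutual_information(pxy: np.ndarray) -> float:
    """I(X;Y) = sum_{x,y} p(x,y) log[p(x,y) / (p(x) p(y))], in nats."""
    px = pxy.sum(axis=1, keepdims=True)   # marginal p(x)
    py = pxy.sum(axis=0, keepdims=True)   # marginal p(y)
    mask = pxy > 0
    return float(np.sum(pxy[mask] * np.log(pxy[mask] / (px * py)[mask])))

# Independent: p(x,y) = p(x) p(y)  ->  I = 0
independent = np.outer([0.5, 0.5], [0.3, 0.7])
# Perfectly dependent: X determines Y  ->  I = H(X) = log 2
dependent = np.array([[0.5, 0.0], [0.0, 0.5]])

print(mutual_information(independent))  # ≈ 0.0
print(mutual_information(dependent))    # log(2) ≈ 0.693
```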

Conditional Entropy

\[H(X | Y) = \mathbb{E}_{y \sim p(y)}[H(X | Y = y)] = -\mathbb{E}_{x,y \sim p(x,y)}[\log p(x | y)]\]

Interpretation: Average uncertainty in \(X\) given knowledge of \(Y\).

Chain Rule: \[H(X, Y) = H(X) + H(Y | X) = H(Y) + H(X | Y)\]
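The chain rule can be checked on any joint table by computing each term directly (NumPy sketch with an arbitrary joint distribution; rows index \(x\), columns index \(y\)):

```python
import numpy as np

pxy = np.array([[0.1, 0.2],
                [0.3, 0.4]])  # an arbitrary joint distribution p(x, y)

def H(p):
    """Shannon entropy of a (possibly joint) distribution, in nats."""
    p = p[p > 0]
    return float(-np.sum(p * np.log(p)))

def cond_entropy(pxy):
    """H(Y | X) = -sum_{x,y} p(x,y) log p(y|x), with p(y|x) = p(x,y)/p(x)."""
    px = pxy.sum(axis=1, keepdims=True)
    mask = pxy > 0
    return float(-np.sum(pxy[mask] * np.log((pxy / px)[mask])))

px, py = pxy.sum(axis=1), pxy.sum(axis=0)
print(H(pxy))                        # H(X, Y)
print(H(px) + cond_entropy(pxy))     # H(X) + H(Y | X): same value
print(H(py) + cond_entropy(pxy.T))   # H(Y) + H(X | Y): same value
```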

InfoGAN: Core Idea

Standard GAN Objective

\[\min_G \max_D V(D, G) = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}[\log D(\mathbf{x})] + \mathbb{E}_{\mathbf{z} \sim p_z}[\log(1 - D(G(\mathbf{z})))]\]

Problem: No incentive for \(G\) to use all dimensions of \(\mathbf{z}\) meaningfully.

InfoGAN: Latent Code Decomposition

Decompose the generator input into: \[\mathbf{z} = (\mathbf{c}, \mathbf{n})\]

Where:

  • \(\mathbf{c} = (c_1, c_2, \ldots, c_L)\): Latent codes (structured, interpretable)

    • Can be discrete (categorical) or continuous

    • Example: \(c_1 \sim \text{Cat}(K=10)\) for digit identity

    • Example: \(c_2 \sim \text{Unif}(-1, 1)\) for rotation

  • \(\mathbf{n}\): Noise (incompressible, random)

    • Traditional noise: \(\mathbf{n} \sim \mathcal{N}(0, I)\)

Mutual Information Objective

Goal: Maximize mutual information between latent codes \(\mathbf{c}\) and generated data \(G(\mathbf{c}, \mathbf{n})\):

\[\max_{G} I(\mathbf{c}; G(\mathbf{c}, \mathbf{n}))\]

Interpretation:

  • Force generator to use latent codes \(\mathbf{c}\)

  • Prevent \(G\) from ignoring \(\mathbf{c}\) (information bottleneck)

  • \(\mathbf{c}\) should be recoverable from \(G(\mathbf{c}, \mathbf{n})\)

Full InfoGAN Objective

\[\min_{G,Q} \max_D V_{\text{InfoGAN}}(D, G, Q) = V(D, G) - \lambda I(\mathbf{c}; G(\mathbf{c}, \mathbf{n}))\]

Where:

  • \(V(D, G)\): Standard GAN objective

  • \(I(\mathbf{c}; G(\mathbf{c}, \mathbf{n}))\): Mutual information term

  • \(\lambda\): Hyperparameter (typically \(\lambda = 1\))

  • \(Q\): Auxiliary network approximating \(p(\mathbf{c} | \mathbf{x})\)

Challenge: Computing \(I(\mathbf{c}; G(\mathbf{c}, \mathbf{n}))\) directly is intractable.

Variational Mutual Information Maximization

The Challenge

\[I(\mathbf{c}; G(\mathbf{c}, \mathbf{n})) = H(\mathbf{c}) - H(\mathbf{c} | G(\mathbf{c}, \mathbf{n}))\]
\[= \mathbb{E}_{\mathbf{x} \sim G}[\mathbb{E}_{\mathbf{c}' \sim p(\mathbf{c} | \mathbf{x})}[\log p(\mathbf{c}' | \mathbf{x})]] + H(\mathbf{c})\]

Problem: Posterior \(p(\mathbf{c} | \mathbf{x})\) is intractable.

Variational Lower Bound

Introduce auxiliary distribution \(Q(\mathbf{c} | \mathbf{x})\) to approximate \(p(\mathbf{c} | \mathbf{x})\).

Lemma (Variational Lower Bound): \[I(\mathbf{c}; G(\mathbf{c}, \mathbf{n})) \geq \mathbb{E}_{\mathbf{c} \sim p(\mathbf{c}), \mathbf{x} \sim G(\mathbf{c}, \mathbf{n})}[\log Q(\mathbf{c} | \mathbf{x})] + H(\mathbf{c}) = L_I(G, Q)\]

Proof:

\[\begin{align} I(\mathbf{c}; G(\mathbf{c}, \mathbf{n})) &= H(\mathbf{c}) - H(\mathbf{c} \mid G(\mathbf{c}, \mathbf{n})) \\ &= \mathbb{E}_{\mathbf{x} \sim G}\left[\mathbb{E}_{\mathbf{c}' \sim p(\mathbf{c} \mid \mathbf{x})}[\log p(\mathbf{c}' \mid \mathbf{x})]\right] + H(\mathbf{c}) \\ &= \mathbb{E}_{\mathbf{x} \sim G}\left[D_{KL}(p(\cdot \mid \mathbf{x}) \,\|\, Q(\cdot \mid \mathbf{x})) + \mathbb{E}_{\mathbf{c}' \sim p(\mathbf{c} \mid \mathbf{x})}[\log Q(\mathbf{c}' \mid \mathbf{x})]\right] + H(\mathbf{c}) \\ &\geq \mathbb{E}_{\mathbf{x} \sim G}\left[\mathbb{E}_{\mathbf{c}' \sim p(\mathbf{c} \mid \mathbf{x})}[\log Q(\mathbf{c}' \mid \mathbf{x})]\right] + H(\mathbf{c}) \end{align}\]

Inequality: Follows from non-negativity of KL divergence.

Equality: When \(Q(\mathbf{c} | \mathbf{x}) = p(\mathbf{c} | \mathbf{x})\).

Practical Objective

\[\min_{G,Q} \max_D V_{\text{InfoGAN}} = V(D, G) - \lambda L_I(G, Q)\]

Where the mutual information lower bound is: \[L_I(G, Q) = \mathbb{E}_{\mathbf{c} \sim p(\mathbf{c}), \mathbf{n} \sim p(\mathbf{n})}[\log Q(\mathbf{c} | G(\mathbf{c}, \mathbf{n}))] + H(\mathbf{c})\]

Simplification: Since \(H(\mathbf{c})\) is constant w.r.t. \(G\) and \(Q\): \[\max_{G, Q} L_I = \max_{G,Q} \mathbb{E}_{\mathbf{c}, \mathbf{n}}[\log Q(\mathbf{c} | G(\mathbf{c}, \mathbf{n}))]\]
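For a fully discrete toy model the bound can be evaluated exactly: plugging in the true posterior \(p(\mathbf{c} \mid \mathbf{x})\) attains \(I\) itself, while any other \(Q\) falls below it (a NumPy sketch; the 2-state channel stands in for the generator and is chosen arbitrarily):

```python
import numpy as np

p_c = np.array([0.5, 0.5])            # prior p(c); H(c) = log 2
p_x_given_c = np.array([[0.9, 0.1],   # toy "generator" channel p(x | c)
                        [0.2, 0.8]])
p_cx = p_c[:, None] * p_x_given_c     # joint p(c, x), rows c, columns x
p_x = p_cx.sum(axis=0)

H_c = float(-np.sum(p_c * np.log(p_c)))

def bound(Q):
    """L_I = E_{c,x}[log Q(c | x)] + H(c); Q has shape (c, x)."""
    return float(np.sum(p_cx * np.log(Q)) + H_c)

true_posterior = p_cx / p_x           # p(c | x): the optimal Q
exact_I = float(np.sum(p_cx * np.log(p_cx / (p_c[:, None] * p_x))))

uniform_Q = np.full((2, 2), 0.5)      # a deliberately uninformative Q
# The bound with the true posterior matches I; the uniform Q gives 0.0
print(exact_I, bound(true_posterior), bound(uniform_Q))
```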

Implementation Details

Q-Network Architecture

The auxiliary network \(Q\) shares parameters with discriminator \(D\) for efficiency:

  1. Shared Layers: \(D\) and \(Q\) share convolutional feature extractor

  2. Separate Heads:

    • \(D\) head: Output \(D(\mathbf{x}) \in [0, 1]\) (real/fake)

    • \(Q\) head: Output \(Q(\mathbf{c} | \mathbf{x})\) (latent code distribution)

Benefits:

  • Parameter efficiency (shared representation)

  • Minimal computational overhead

  • Natural integration with discriminator

Latent Code Types

1. Discrete Categorical (\(c_i \sim \text{Cat}(K)\)):

  • Prior: \(p(c_i = k) = 1/K\) (uniform over \(K\) classes)

  • \(Q\) parameterization: Softmax over \(K\) logits: \[Q(c_i | \mathbf{x}) = \text{Softmax}(f_Q(\mathbf{x}))\]

  • Loss: Cross-entropy: \[\mathcal{L}_{\text{cat}} = -\mathbb{E}[\log Q(c_i | \mathbf{x})]\]

2. Continuous (\(c_j \sim \mathcal{N}(0, 1)\) or \(\text{Unif}(-1, 1)\)):

  • Prior: Gaussian \(p(c_j) = \mathcal{N}(0, 1)\) or uniform

  • \(Q\) parameterization: Gaussian with learned mean and variance: \[Q(c_j | \mathbf{x}) = \mathcal{N}(\mu_Q(\mathbf{x}), \sigma_Q^2(\mathbf{x}))\]

  • Loss: Negative log-likelihood (up to constants, a variance-weighted squared error): \[\mathcal{L}_{\text{cont}} = \frac{1}{2}\mathbb{E}\left[\frac{(c_j - \mu_Q(\mathbf{x}))^2}{\sigma_Q^2(\mathbf{x})} + \log \sigma_Q^2(\mathbf{x})\right]\]

Simplification (fixed variance): If we fix \(\sigma_Q^2 = 1\), the loss reduces to MSE: \[\mathcal{L}_{\text{cont}} = \mathbb{E}[\|c_j - \mu_Q(\mathbf{x})\|^2]\]
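This reduction is easy to verify in PyTorch: with the log-variance held at zero, the Gaussian NLL above equals half the mean squared error (a sketch; `gaussian_code_nll` is a helper named here for illustration):

```python
import torch

def gaussian_code_nll(c, mu, logvar):
    """0.5 * E[(c - mu)^2 / sigma^2 + log sigma^2], constants dropped."""
    return 0.5 * ((c - mu) ** 2 / logvar.exp() + logvar).mean()

c = torch.randn(64, 2)          # "true" continuous codes
mu = torch.randn(64, 2)         # Q's predicted means
fixed = torch.zeros_like(mu)    # log sigma^2 = 0  ->  sigma^2 = 1

nll = gaussian_code_nll(c, mu, fixed)
mse = torch.nn.functional.mse_loss(mu, c)
print(torch.allclose(nll, 0.5 * mse))  # True
```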

Training Algorithm

For each iteration:

  1. Sample minibatch:

    • Latent codes: \(\mathbf{c} \sim p(\mathbf{c})\)

    • Noise: \(\mathbf{n} \sim \mathcal{N}(0, I)\)

    • Real data: \(\mathbf{x} \sim p_{\text{data}}\)

  2. Generate fake samples: \[\mathbf{x}_{\text{fake}} = G(\mathbf{c}, \mathbf{n})\]

  3. Update Discriminator \(D\): \[\max_D \mathbb{E}_{\mathbf{x}}[\log D(\mathbf{x})] + \mathbb{E}_{\mathbf{c}, \mathbf{n}}[\log(1 - D(G(\mathbf{c}, \mathbf{n})))]\]

  4. Update Generator \(G\) and \(Q\) jointly: \[\min_{G,Q} \mathbb{E}_{\mathbf{c}, \mathbf{n}}[\log(1 - D(G(\mathbf{c}, \mathbf{n})))] - \lambda \mathbb{E}_{\mathbf{c}, \mathbf{n}}[\log Q(\mathbf{c} | G(\mathbf{c}, \mathbf{n}))]\]

Key Point: The mutual information loss acts as a regularizer, penalizing \(G\) if \(Q\) cannot recover \(\mathbf{c}\) from \(G(\mathbf{c}, \mathbf{n})\).

Theoretical Analysis

Why Mutual Information Encourages Disentanglement

Information Bottleneck Principle:

  1. Maximize \(I(\mathbf{c}; \mathbf{x})\): Generator must use \(\mathbf{c}\) to create \(\mathbf{x}\)

  2. Structured prior \(p(\mathbf{c})\): A factorized prior encourages independence: \[p(\mathbf{c}) = \prod_{i=1}^L p(c_i)\]

Result: To maximize \(I(\mathbf{c}; \mathbf{x})\) under factorized prior, each \(c_i\) specializes to different aspects of \(\mathbf{x}\).

Informal Argument:

  • If \(c_1\) and \(c_2\) encode the same information, \(I(\mathbf{c}; \mathbf{x})\) is not maximized

  • Redundancy reduces effective capacity

  • Maximizing MI incentivizes efficient, non-redundant use of \(c_i\)

Connection to Information Bottleneck

Information Bottleneck Method (Tishby et al.): \[\max_{p(\tilde{X} | X)} I(\tilde{X}; Y) - \beta I(\tilde{X}; X)\]

Where:

  • \(\tilde{X}\): Compressed representation

  • \(Y\): Target variable

  • \(\beta\): Trade-off between compression and prediction

InfoGAN Perspective:

  • \(\mathbf{c}\): Compressed representation

  • Maximize \(I(\mathbf{c}; G(\mathbf{c}, \mathbf{n}))\) subject to factorized prior

  • Factorization acts as implicit compression constraint

Relation to β-VAE

β-VAE Objective: \[\mathcal{L}_{\beta\text{-VAE}} = \mathbb{E}[\log p(\mathbf{x} | \mathbf{z})] - \beta D_{KL}(q(\mathbf{z} | \mathbf{x}) \| p(\mathbf{z}))\]

Similarities:

  • Both encourage disentanglement via information-theoretic constraints

  • InfoGAN: Maximize \(I(\mathbf{c}; \mathbf{x})\)

  • β-VAE: Constrain \(I(\mathbf{z}; \mathbf{x})\) via KL penalty

Differences:

  • β-VAE: Variational inference framework (encoder-decoder)

  • InfoGAN: GAN framework (adversarial training)

  • β-VAE: Explicit reconstruction loss

  • InfoGAN: Implicit via adversarial loss

Experiments and Results

MNIST

Latent Code Configuration:

  • \(c_1 \sim \text{Cat}(K=10)\): Discrete (10 classes)

  • \(c_2, c_3 \sim \text{Unif}(-1, 1)\): Continuous (2 dimensions)

  • \(\mathbf{n} \sim \mathcal{N}(0, I_{62})\): Noise (62 dimensions)

Discovered Meanings:

  • \(c_1\): Digit identity (0-9)

  • \(c_2\): Rotation angle (continuous variation)

  • \(c_3\): Stroke thickness (width)

Key Observation: Without any labels, InfoGAN discovers semantic factors!

3D Faces

Latent Codes:

  • \(c_1 \sim \text{Cat}(K=10)\): Discrete

  • \(c_2, c_3, c_4, c_5 \sim \text{Unif}(-1, 1)\): Continuous

Discovered Meanings:

  • \(c_1\): Face identity / person

  • \(c_2\): Elevation (vertical rotation)

  • \(c_3\): Azimuth (horizontal rotation)

  • \(c_4\): Lighting direction

  • \(c_5\): (Less interpretable, possibly expression)

3D Chairs

Discovered:

  • Rotation angle

  • Chair type

  • Width

Quantitative Evaluation of Disentanglement

1. Mutual Information Gap (MIG)

\[\text{MIG} = \frac{1}{K} \sum_{k=1}^K \frac{I(v_k; c_{j(k)}) - \max_{j \neq j(k)} I(v_k; c_j)}{H(v_k)}\]

Where:

  • \(v_k\): Ground-truth factor

  • \(c_j\): Learned latent code

  • \(j(k)\): Best-matching code for factor \(k\)

Interpretation: Measures how much more information the best code provides about each factor compared to the next-best code.
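A minimal plug-in estimator of MIG for discrete factors and codes might look as follows (a sketch under simplifying assumptions; `mig_score`, and using integer-valued codes instead of binned continuous ones, are choices made here):

```python
import numpy as np

def discrete_mi(a, b):
    """Plug-in estimate of I(a; b) for integer-valued arrays, in nats."""
    joint = np.zeros((a.max() + 1, b.max() + 1))
    for ai, bi in zip(a, b):
        joint[ai, bi] += 1
    joint /= joint.sum()
    pa = joint.sum(axis=1, keepdims=True)
    pb = joint.sum(axis=0, keepdims=True)
    mask = joint > 0
    return float(np.sum(joint[mask] * np.log(joint[mask] / (pa * pb)[mask])))

def mig_score(factors, codes):
    """factors: (n, K) ints; codes: (n, L) ints. Mean normalized MI gap."""
    gaps = []
    for k in range(factors.shape[1]):
        v = factors[:, k]
        mis = sorted((discrete_mi(v, codes[:, j]) for j in range(codes.shape[1])),
                     reverse=True)
        h_v = discrete_mi(v, v)  # H(v) = I(v; v)
        gaps.append((mis[0] - mis[1]) / h_v)
    return float(np.mean(gaps))

rng = np.random.default_rng(0)
factor = rng.integers(0, 10, size=2000)            # one ground-truth factor
codes = np.stack([factor,                          # code 0 captures it exactly
                  rng.integers(0, 10, size=2000)], # code 1 is unrelated
                 axis=1)
print(mig_score(factor[:, None], codes))  # close to 1
```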

2. SAP Score (Separated Attribute Predictability)

  1. Train linear classifier to predict each ground-truth factor from each latent code

  2. Compute difference between top two scores

  3. Average over factors

Higher score → better disentanglement.
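The recipe above can be sketched with per-pair linear \(R^2\) as the predictability score (an illustrative simplification for continuous factors; `sap_score` is a name introduced here):

```python
import numpy as np

def r2(code, factor):
    """R^2 of predicting `factor` from `code` with 1-D linear regression."""
    A = np.stack([code, np.ones_like(code)], axis=1)
    coef, *_ = np.linalg.lstsq(A, factor, rcond=None)
    resid = factor - A @ coef
    return 1.0 - resid.var() / factor.var()

def sap_score(factors, codes):
    """Mean (over factors) gap between the two most predictive codes."""
    gaps = []
    for k in range(factors.shape[1]):
        scores = sorted((r2(codes[:, j], factors[:, k])
                         for j in range(codes.shape[1])), reverse=True)
        gaps.append(scores[0] - scores[1])
    return float(np.mean(gaps))

rng = np.random.default_rng(0)
v = rng.normal(size=(5000, 1))                                 # ground-truth factor
codes = np.concatenate([v + 0.1 * rng.normal(size=(5000, 1)),  # aligned code
                        rng.normal(size=(5000, 1))], axis=1)   # unrelated code
print(sap_score(v, codes))  # close to 1
```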

3. DCI (Disentanglement, Completeness, Informativeness)

Three metrics:

  • Disentanglement: Each code controls single factor

  • Completeness: Each factor controlled by single code

  • Informativeness: Total predictiveness of factors

4. Higgins Metric

Classifier-based: Train classifier on latent traversals, measure accuracy.

Challenges with Metrics

Problem: Different metrics can disagree on ranking of models.

Recent Work: Focus on downstream task performance rather than disentanglement metrics.

Variants and Extensions

1. ss-InfoGAN (Semi-Supervised)

Combine InfoGAN with small amount of labeled data:

\[\mathcal{L}_{\text{ss-InfoGAN}} = \mathcal{L}_{\text{InfoGAN}} + \lambda_{\text{sup}} \mathcal{L}_{\text{supervised}}\]

Where: \[\mathcal{L}_{\text{supervised}} = -\mathbb{E}_{(\mathbf{x}, y) \sim p_{\text{labeled}}}[\log Q(y | \mathbf{x})]\]

Benefit: Even small amount of labels significantly improves semantic alignment.

2. InfoGAN-CR (Contrastive Regularization)

Problem: InfoGAN can sometimes ignore some latent codes.

Solution: Add a contrastive loss: \[\mathcal{L}_{\text{CR}} = -\log \frac{\exp(s(\mathbf{x}, \mathbf{c}))}{\exp(s(\mathbf{x}, \mathbf{c})) + \sum_{\mathbf{c}'} \exp(s(\mathbf{x}, \mathbf{c}'))}\]

Where \(s(\mathbf{x}, \mathbf{c})\) is similarity score (e.g., cosine similarity).

3. AC-InfoGAN (Auxiliary Classifier)

Combine InfoGAN with AC-GAN:

  • Auxiliary classifier for labeled data

  • InfoGAN objective for unsupervised codes

4. Variational InfoGAN

Use variational inference for more flexible \(Q\) distributions: \[Q(\mathbf{c} | \mathbf{x}) = \mathcal{N}(\mu_Q(\mathbf{x}), \text{diag}(\sigma_Q^2(\mathbf{x})))\]

The reparameterization trick makes the sampling step differentiable for backpropagation.
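Concretely, the reparameterization trick draws \(\mathbf{c} = \mu_Q(\mathbf{x}) + \sigma_Q(\mathbf{x}) \odot \epsilon\) with \(\epsilon \sim \mathcal{N}(0, I)\), so gradients flow through \(\mu_Q\) and \(\sigma_Q\) (a standard sketch):

```python
import torch

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Sample c ~ N(mu, diag(sigma^2)) differentiably: c = mu + sigma * eps."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + std * eps

mu = torch.zeros(4, 2, requires_grad=True)
logvar = torch.zeros(4, 2, requires_grad=True)
c = reparameterize(mu, logvar)
c.sum().backward()           # gradients reach both mu and logvar
print(mu.grad is not None)   # True
```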

Limitations

1. No Guarantee of Semantic Alignment

Problem: Discovered codes may not align with human-interpretable factors.

Example: On complex datasets (e.g., ImageNet), codes may capture texture or lighting rather than object identity.

Mitigation: Use semi-supervised variants (ss-InfoGAN).

2. Mode Collapse

InfoGAN inherits GAN training instabilities:

  • Mode collapse

  • Training instability

  • Hyperparameter sensitivity

Mitigation: Spectral normalization, progressive growing, Wasserstein loss.

3. Limited to Factorized Priors

Assumption: \(p(\mathbf{c}) = \prod_i p(c_i)\) (independence)

Problem: Real-world factors may be correlated (e.g., age and wrinkles).

Extension: Structured priors (e.g., hierarchical).

4. Difficulty with High-Dimensional Data

Challenge: On complex datasets (e.g., CelebA-HQ), harder to achieve clean disentanglement.

Recent Approach: Combine with StyleGAN (StyleGAN2-InfoGAN).

5. Evaluation Challenges

Problem: No ground-truth factors for real-world data.

Solutions:

  • Qualitative inspection (latent traversals)

  • Downstream task performance

  • User studies

Applications

1. Controllable Image Generation

  • Generate images with specific attributes

  • Interactive editing (e.g., change pose, lighting, expression)

2. Feature Extraction

  • Use learned codes \(\mathbf{c}\) as features for downstream tasks

  • Often more interpretable than standard GAN latent space

3. Data Augmentation

  • Generate new samples by varying \(\mathbf{c}\)

  • Useful when data is limited

4. Anomaly Detection

  • Model normal data with disentangled codes

  • Detect anomalies as samples with unusual code distributions

5. Fairness in ML

  • Disentangle sensitive attributes (gender, race)

  • Enable fair decision-making by controlling for biases

Comparison with Other Disentanglement Methods

InfoGAN vs. β-VAE

| Aspect | InfoGAN | β-VAE |
|---|---|---|
| Framework | GAN (adversarial) | VAE (variational) |
| Objective | Maximize \(I(\mathbf{c}; \mathbf{x})\) | Constrain \(I(\mathbf{z}; \mathbf{x})\) via KL |
| Sample quality | Higher (GANs generally sharper) | Lower (VAEs tend to blur) |
| Training stability | Lower (GAN instability) | Higher (VAEs more stable) |
| Reconstruction | No explicit reconstruction | Explicit reconstruction loss |
| Inference | No encoder (generation only) | Has encoder (both directions) |

InfoGAN vs. Factor-VAE

Factor-VAE: Adds a total correlation penalty: \[\mathcal{L} = \mathcal{L}_{\text{VAE}} + \gamma \cdot TC(q(\mathbf{z}))\]

Where \(TC\) is the total correlation (the KL divergence from \(q(\mathbf{z})\) to the product of its marginals).

Trade-off: Factor-VAE often achieves better disentanglement metrics than β-VAE, but InfoGAN has better sample quality.

InfoGAN vs. ControlVAE

ControlVAE: Directly controls the capacity of information flow with a PID controller.

Comparison: ControlVAE trains more stably than β-VAE, while InfoGAN retains the sample-quality advantages of adversarial training.

Future Directions

1. Hierarchical Disentanglement

Learn hierarchical structure of factors:

  • Coarse level: Object category

  • Medium level: Pose, color

  • Fine level: Texture, lighting

2. Causal Disentanglement

Go beyond statistical independence to causal independence, where each code is generated from its causal parents: \[\mathbf{c}_i = f_i(\text{parents}(\mathbf{c}_i), \epsilon_i)\]

3. Group-Based Disentanglement

Leverage group theory (e.g., symmetries, transformations):

  • Rotation group for pose

  • Translation group for position

4. Disentanglement for Video

Separate:

  • Content (what): object identity

  • Motion (how): dynamics, actions

  • Style (appearance): lighting, texture

5. Multi-Modal Disentanglement

Disentangle factors across modalities (image, text, audio):

  • Shared factors (e.g., emotion)

  • Modality-specific factors (e.g., pitch in audio)

6. Intervention-Based Learning

Use interventions (counterfactuals) to discover causal factors:

  • “What if we change pose but not identity?”

Practical Tips

Hyperparameter Selection

  • \(\lambda\) (MI weight): Start with \(\lambda = 1.0\), adjust if codes are ignored

  • Number of codes: Start small (2-5), increase gradually

  • Code types: Mix discrete and continuous based on data

  • Learning rates: Use separate LR for \(G\), \(D\), \(Q\) (e.g., \(10^{-4}\))

Architecture Choices

  • Shared \(D\) and \(Q\): Efficient, works well in practice

  • Q network depth: 1-2 extra layers after shared features

  • Activation: LeakyReLU for \(D\)/\(Q\), Tanh for \(G\) output

Debugging

  • Visualize latent traversals: Fix all codes except one, vary it, generate images

  • Check MI loss: Should be non-zero and decreasing

  • Monitor code usage: Ensure all codes are being used (not mode collapse)

Conclusion

InfoGAN elegantly extends GANs with information-theoretic regularization to discover interpretable, disentangled representations in an unsupervised manner. By maximizing mutual information between latent codes and generated data, InfoGAN incentivizes the generator to use codes for distinct, semantic factors.

Key Contributions:

  1. Unsupervised disentanglement: No labels required

  2. Variational lower bound: Tractable mutual information approximation

  3. Efficient implementation: Shared \(D\)-\(Q\) architecture

  4. Empirical success: Discovers meaningful factors on various datasets

Impact:

  • Pioneered unsupervised disentanglement in GANs

  • Inspired numerous follow-up works (β-VAE, Factor-VAE, ControlVAE)

  • Applications in controllable generation, representation learning, fairness

Current State: While newer methods (StyleGAN, diffusion models) have surpassed InfoGAN in generation quality, the core idea of information-theoretic regularization remains influential in modern disentanglement research.

"""
InfoGAN - Complete Implementation

This implementation includes:
1. Standard GAN components (Generator, Discriminator)
2. Q-Network for approximating P(c|x)
3. Categorical and continuous latent codes
4. Mutual information maximization
5. Training utilities
6. Visualization of disentangled representations
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, List, Dict
from dataclasses import dataclass


# ============================================================================
# 1. NETWORK ARCHITECTURES
# ============================================================================

class InfoGANGenerator(nn.Module):
    """
    InfoGAN Generator.
    
    Takes noise z and latent codes c as input, generates images.
    
    Args:
        latent_dim: Dimension of noise z
        code_dim: Total dimension of latent codes c
        img_channels: Number of image channels
        img_size: Output image size (assumes square)
    """
    def __init__(
        self,
        latent_dim: int = 62,
        code_dim: int = 12,  # e.g., 10 (categorical) + 2 (continuous)
        img_channels: int = 1,
        img_size: int = 28
    ):
        super().__init__()
        
        self.latent_dim = latent_dim
        self.code_dim = code_dim
        self.img_channels = img_channels
        self.img_size = img_size
        
        input_dim = latent_dim + code_dim
        
        self.model = nn.Sequential(
            # Input: (latent_dim + code_dim)
            nn.Linear(input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            
            nn.Linear(1024, 128 * 7 * 7),
            nn.BatchNorm1d(128 * 7 * 7),
            nn.ReLU()
        )
        
        # Reshape to (batch, 128, 7, 7) for conv layers
        self.conv_layers = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # 7x7 -> 14x14
            nn.BatchNorm2d(64),
            nn.ReLU(),
            
            nn.ConvTranspose2d(64, img_channels, 4, 2, 1),  # 14x14 -> 28x28
            nn.Tanh()
        )
        
    def forward(self, z: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """
        Generate images from noise and latent codes.
        
        Args:
            z: Noise (batch_size, latent_dim)
            c: Latent codes (batch_size, code_dim)
            
        Returns:
            Generated images (batch_size, img_channels, img_size, img_size)
        """
        # Concatenate noise and codes
        gen_input = torch.cat([z, c], dim=1)
        
        # MLP layers
        x = self.model(gen_input)
        
        # Reshape for conv layers
        x = x.view(x.size(0), 128, 7, 7)
        
        # Convolutional layers
        img = self.conv_layers(x)
        
        return img


class SharedDiscriminatorQ(nn.Module):
    """
    Shared Discriminator and Q-network.
    
    The discriminator and Q-network share convolutional layers,
    then split into separate heads:
    - D head: Real/fake classification
    - Q head: Latent code prediction
    
    Args:
        img_channels: Number of image channels
        img_size: Input image size
        num_categorical: Number of categorical classes
        num_continuous: Number of continuous codes
    """
    def __init__(
        self,
        img_channels: int = 1,
        img_size: int = 28,
        num_categorical: int = 10,
        num_continuous: int = 2
    ):
        super().__init__()
        
        self.num_categorical = num_categorical
        self.num_continuous = num_continuous
        
        # Shared feature extractor
        self.shared = nn.Sequential(
            nn.Conv2d(img_channels, 64, 4, 2, 1),  # 28x28 -> 14x14
            nn.LeakyReLU(0.1),
            
            nn.Conv2d(64, 128, 4, 2, 1),  # 14x14 -> 7x7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1),
            
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.1)
        )
        
        # Discriminator head (real/fake)
        self.d_head = nn.Sequential(
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )
        
        # Q head (predict latent codes)
        # For categorical code: output logits
        # For continuous codes: output mean (and optionally variance)
        
        # Shared Q layer
        self.q_shared = nn.Sequential(
            nn.Linear(1024, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.1)
        )
        
        # Categorical code head (logits)
        self.q_categorical = nn.Linear(128, num_categorical)
        
        # Continuous code head (mean)
        self.q_continuous_mean = nn.Linear(128, num_continuous)
        
        # Continuous code head (log variance) - optional
        # For simplicity, we'll use fixed variance
        # self.q_continuous_logvar = nn.Linear(128, num_continuous)
        
    def forward(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            img: Input images (batch_size, img_channels, img_size, img_size)
            
        Returns:
            d_out: Discriminator output (batch_size, 1) - probability of real
            q_cat: Categorical code logits (batch_size, num_categorical)
            q_cont: Continuous code means (batch_size, num_continuous)
        """
        # Shared features
        features = self.shared(img)
        
        # Discriminator output
        d_out = self.d_head(features)
        
        # Q network features
        q_features = self.q_shared(features)
        
        # Categorical code (logits for cross-entropy)
        q_cat = self.q_categorical(q_features)
        
        # Continuous code (means)
        q_cont = self.q_continuous_mean(q_features)
        
        return d_out, q_cat, q_cont


# ============================================================================
# 2. LATENT CODE SAMPLING
# ============================================================================

def sample_latent_codes(
    batch_size: int,
    num_categorical: int,
    num_continuous: int,
    device: str = 'cpu'
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sample latent codes from prior distributions.
    
    Args:
        batch_size: Number of samples
        num_categorical: Number of categorical classes
        num_continuous: Number of continuous codes
        device: Device to create tensors on
        
    Returns:
        c_cat_idx: Categorical code indices (batch_size,)
        c_cat_onehot: Categorical code one-hot (batch_size, num_categorical)
        c_cont: Continuous codes (batch_size, num_continuous)
    """
    # Categorical: sample uniformly from K classes
    c_cat_idx = torch.randint(0, num_categorical, (batch_size,), device=device)
    c_cat_onehot = F.one_hot(c_cat_idx, num_classes=num_categorical).float()
    
    # Continuous: sample from Uniform(-1, 1)
    c_cont = torch.rand(batch_size, num_continuous, device=device) * 2 - 1
    
    return c_cat_idx, c_cat_onehot, c_cont


def combine_codes(c_cat_onehot: torch.Tensor, c_cont: torch.Tensor) -> torch.Tensor:
    """
    Combine categorical and continuous codes into single vector.
    
    Args:
        c_cat_onehot: Categorical code one-hot (batch_size, num_categorical)
        c_cont: Continuous codes (batch_size, num_continuous)
        
    Returns:
        Combined codes (batch_size, num_categorical + num_continuous)
    """
    return torch.cat([c_cat_onehot, c_cont], dim=1)


# ============================================================================
# 3. LOSS FUNCTIONS
# ============================================================================

def discriminator_loss(
    d_real: torch.Tensor,
    d_fake: torch.Tensor
) -> torch.Tensor:
    """
    Standard GAN discriminator loss (binary cross-entropy).
    
    Args:
        d_real: Discriminator output on real images (batch_size, 1)
        d_fake: Discriminator output on fake images (batch_size, 1)
        
    Returns:
        Discriminator loss
    """
    real_labels = torch.ones_like(d_real)
    fake_labels = torch.zeros_like(d_fake)
    
    loss_real = F.binary_cross_entropy(d_real, real_labels)
    loss_fake = F.binary_cross_entropy(d_fake, fake_labels)
    
    return loss_real + loss_fake


def generator_loss(d_fake: torch.Tensor) -> torch.Tensor:
    """
    Standard GAN generator loss.
    
    Args:
        d_fake: Discriminator output on fake images (batch_size, 1)
        
    Returns:
        Generator loss
    """
    real_labels = torch.ones_like(d_fake)
    return F.binary_cross_entropy(d_fake, real_labels)


def mutual_information_loss(
    q_cat_logits: torch.Tensor,
    q_cont_mean: torch.Tensor,
    c_cat_idx: torch.Tensor,
    c_cont: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Mutual information loss (lower bound approximation).
    
    For categorical code: Cross-entropy
    For continuous code: MSE (assuming Gaussian with unit variance)
    
    Args:
        q_cat_logits: Q network categorical output (batch_size, num_categorical)
        q_cont_mean: Q network continuous output (batch_size, num_continuous)
        c_cat_idx: True categorical code indices (batch_size,)
        c_cont: True continuous codes (batch_size, num_continuous)
        
    Returns:
        mi_cat: Mutual information loss for categorical code
        mi_cont: Mutual information loss for continuous code
    """
    # Categorical: cross-entropy
    # -E[log Q(c_cat | x)]
    mi_cat = F.cross_entropy(q_cat_logits, c_cat_idx)
    
    # Continuous: MSE (equivalent to Gaussian NLL with fixed variance)
    # -E[log Q(c_cont | x)] ∝ ||c_cont - μ_Q(x)||^2
    mi_cont = F.mse_loss(q_cont_mean, c_cont)
    
    return mi_cat, mi_cont


# ============================================================================
# 4. TRAINING
# ============================================================================

@dataclass
class InfoGANConfig:
    """Configuration for InfoGAN."""
    latent_dim: int = 62
    num_categorical: int = 10
    num_continuous: int = 2
    img_channels: int = 1
    img_size: int = 28
    
    # Training
    lr_g: float = 2e-4
    lr_d: float = 2e-4
    beta1: float = 0.5
    beta2: float = 0.999
    
    # Loss weights
    lambda_cat: float = 1.0  # Weight for categorical MI loss
    lambda_cont: float = 0.1  # Weight for continuous MI loss


def train_infogan_step(
    generator: nn.Module,
    discriminator_q: nn.Module,
    real_imgs: torch.Tensor,
    optimizer_g: optim.Optimizer,
    optimizer_d: optim.Optimizer,
    config: InfoGANConfig,
    device: str = 'cpu'
) -> Dict[str, float]:
    """
    Single training step for InfoGAN.
    
    Args:
        generator: Generator network
        discriminator_q: Shared discriminator and Q network
        real_imgs: Batch of real images
        optimizer_g: Optimizer over the generator parameters plus the Q-head
            parameters of discriminator_q (the MI loss must update Q as well as G)
        optimizer_d: Discriminator optimizer
        config: InfoGAN configuration
        device: Device
        
    Returns:
        Dictionary of losses
    """
    batch_size = real_imgs.size(0)
    
    # ---------------------
    # Train Discriminator
    # ---------------------
    optimizer_d.zero_grad()
    
    # Real images
    d_real, _, _ = discriminator_q(real_imgs)
    
    # Fake images
    z = torch.randn(batch_size, config.latent_dim, device=device)
    c_cat_idx, c_cat_onehot, c_cont = sample_latent_codes(
        batch_size, config.num_categorical, config.num_continuous, device
    )
    c = combine_codes(c_cat_onehot, c_cont)
    
    fake_imgs = generator(z, c)
    d_fake, _, _ = discriminator_q(fake_imgs.detach())
    
    # Discriminator loss
    d_loss = discriminator_loss(d_real, d_fake)
    d_loss.backward()
    optimizer_d.step()
    
    # -------------------------
    # Train Generator and Q
    # -------------------------
    optimizer_g.zero_grad()
    
    # Generate fake images
    z = torch.randn(batch_size, config.latent_dim, device=device)
    c_cat_idx, c_cat_onehot, c_cont = sample_latent_codes(
        batch_size, config.num_categorical, config.num_continuous, device
    )
    c = combine_codes(c_cat_onehot, c_cont)
    
    fake_imgs = generator(z, c)
    d_fake, q_cat_logits, q_cont_mean = discriminator_q(fake_imgs)
    
    # Generator adversarial loss
    g_loss = generator_loss(d_fake)
    
    # Mutual information loss
    mi_cat, mi_cont = mutual_information_loss(q_cat_logits, q_cont_mean, c_cat_idx, c_cont)
    
    # Total generator loss
    total_loss = g_loss + config.lambda_cat * mi_cat + config.lambda_cont * mi_cont
    total_loss.backward()
    optimizer_g.step()
    
    return {
        'd_loss': d_loss.item(),
        'g_loss': g_loss.item(),
        'mi_cat': mi_cat.item(),
        'mi_cont': mi_cont.item(),
        'total_g': total_loss.item()
    }


# ============================================================================
# 5. VISUALIZATION
# ============================================================================

def visualize_categorical_code(
    generator: nn.Module,
    num_categorical: int,
    num_continuous: int,
    latent_dim: int,
    device: str = 'cpu',
    samples_per_class: int = 10
):
    """
    Visualize effect of categorical code by generating samples for each class.
    
    Args:
        generator: Trained generator
        num_categorical: Number of categorical classes
        num_continuous: Number of continuous codes
        latent_dim: Noise dimension
        device: Device
        samples_per_class: Samples per class
    """
    generator.eval()
    
    fig, axes = plt.subplots(num_categorical, samples_per_class, 
                            figsize=(samples_per_class * 1.5, num_categorical * 1.5))
    
    with torch.no_grad():
        for class_idx in range(num_categorical):
            # Fixed noise for each class
            z = torch.randn(samples_per_class, latent_dim, device=device)
            
            # Set categorical code to class_idx
            c_cat_onehot = F.one_hot(
                torch.full((samples_per_class,), class_idx, dtype=torch.long, device=device),
                num_classes=num_categorical
            ).float()
            
            # Random continuous codes
            c_cont = torch.rand(samples_per_class, num_continuous, device=device) * 2 - 1
            
            # Combine codes
            c = combine_codes(c_cat_onehot, c_cont)
            
            # Generate
            imgs = generator(z, c).cpu()
            
            # Plot
            for i in range(samples_per_class):
                ax = axes[class_idx, i] if num_categorical > 1 else axes[i]
                img = imgs[i, 0]  # Grayscale
                ax.imshow(img, cmap='gray')
                ax.axis('off')
                
                if i == 0:
                    ax.set_title(f'Class {class_idx}', fontsize=9)
    
    plt.tight_layout()
    plt.suptitle('Categorical Code Visualization', y=1.01, fontsize=12, fontweight='bold')
    plt.show()


def visualize_continuous_code(
    generator: nn.Module,
    num_categorical: int,
    num_continuous: int,
    latent_dim: int,
    continuous_idx: int = 0,
    device: str = 'cpu',
    num_steps: int = 10,
    num_rows: int = 5
):
    """
    Visualize effect of continuous code by varying it while keeping others fixed.
    
    Args:
        generator: Trained generator
        num_categorical: Number of categorical classes
        num_continuous: Number of continuous codes
        latent_dim: Noise dimension
        continuous_idx: Which continuous code to vary
        device: Device
        num_steps: Number of steps in continuous variation
        num_rows: Number of different samples (rows)
    """
    generator.eval()
    
    fig, axes = plt.subplots(num_rows, num_steps, figsize=(num_steps * 1.5, num_rows * 1.5))
    
    # Range for continuous code
    c_range = torch.linspace(-2, 2, num_steps, device=device)
    
    with torch.no_grad():
        for row in range(num_rows):
            # Fixed noise and codes for this row
            z = torch.randn(1, latent_dim, device=device).repeat(num_steps, 1)
            
            # Random categorical code
            cat_class = torch.randint(0, num_categorical, (1,), device=device).item()
            c_cat_onehot = F.one_hot(
                torch.full((num_steps,), cat_class, dtype=torch.long, device=device),
                num_classes=num_categorical
            ).float()
            
            # Fixed continuous codes
            c_cont = torch.rand(1, num_continuous, device=device).repeat(num_steps, 1)
            
            # Vary the selected continuous code
            c_cont[:, continuous_idx] = c_range
            
            # Combine
            c = combine_codes(c_cat_onehot, c_cont)
            
            # Generate
            imgs = generator(z, c).cpu()
            
            # Plot
            for col in range(num_steps):
                ax = axes[row, col] if num_rows > 1 else axes[col]
                img = imgs[col, 0]
                ax.imshow(img, cmap='gray')
                ax.axis('off')
                
                if row == 0:
                    ax.set_title(f'{c_range[col]:.1f}', fontsize=8)
    
    plt.tight_layout()
    plt.suptitle(f'Continuous Code {continuous_idx} Variation', y=1.01, 
                fontsize=12, fontweight='bold')
    plt.show()


def visualize_2d_continuous_space(
    generator: nn.Module,
    num_categorical: int,
    num_continuous: int,
    latent_dim: int,
    device: str = 'cpu',
    grid_size: int = 10
):
    """
    Visualize 2D continuous code space (if num_continuous >= 2).
    
    Args:
        generator: Trained generator
        num_categorical: Number of categorical classes
        num_continuous: Number of continuous codes
        latent_dim: Noise dimension
        device: Device
        grid_size: Grid resolution
    """
    if num_continuous < 2:
        print("Need at least 2 continuous codes for 2D visualization")
        return
    
    generator.eval()
    
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(12, 12))
    
    # Create grid
    c1_range = torch.linspace(-2, 2, grid_size, device=device)
    c2_range = torch.linspace(-2, 2, grid_size, device=device)
    
    with torch.no_grad():
        # Fixed noise and categorical code
        z = torch.randn(1, latent_dim, device=device)
        cat_class = torch.randint(0, num_categorical, (1,), device=device).item()
        
        for i, c1 in enumerate(c1_range):
            for j, c2 in enumerate(c2_range):
                # Set continuous codes
                c_cat_onehot = F.one_hot(
                    torch.tensor([cat_class], device=device),
                    num_classes=num_categorical
                ).float()
                
                c_cont = torch.zeros(1, num_continuous, device=device)
                c_cont[0, 0] = c1
                c_cont[0, 1] = c2
                
                c = combine_codes(c_cat_onehot, c_cont)
                
                # Generate
                img = generator(z, c).cpu()[0, 0]
                
                # Plot
                axes[i, j].imshow(img, cmap='gray')
                axes[i, j].axis('off')
    
    plt.tight_layout()
    plt.suptitle('2D Continuous Code Space', y=1.01, fontsize=14, fontweight='bold')
    plt.show()


# ============================================================================
# 6. DEMONSTRATION
# ============================================================================

def demo_infogan():
    """Demonstrate InfoGAN components."""
    
    print("=" * 80)
    print("InfoGAN Demonstration")
    print("=" * 80)
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"\nDevice: {device}")
    
    # Configuration
    config = InfoGANConfig()
    
    print("\n1. Configuration")
    print("-" * 40)
    print(f"Latent dimension (noise): {config.latent_dim}")
    print(f"Categorical codes: {config.num_categorical} classes")
    print(f"Continuous codes: {config.num_continuous} dimensions")
    print(f"Total code dimension: {config.num_categorical + config.num_continuous}")
    
    # Initialize networks
    print("\n2. Network Architecture")
    print("-" * 40)
    
    generator = InfoGANGenerator(
        latent_dim=config.latent_dim,
        code_dim=config.num_categorical + config.num_continuous
    ).to(device)
    
    discriminator_q = SharedDiscriminatorQ(
        num_categorical=config.num_categorical,
        num_continuous=config.num_continuous
    ).to(device)
    
    g_params = sum(p.numel() for p in generator.parameters())
    dq_params = sum(p.numel() for p in discriminator_q.parameters())
    
    print(f"Generator parameters: {g_params:,}")
    print(f"Discriminator + Q parameters: {dq_params:,}")
    print(f"Total parameters: {g_params + dq_params:,}")
    
    # Test forward pass
    print("\n3. Forward Pass")
    print("-" * 40)
    
    batch_size = 4
    z = torch.randn(batch_size, config.latent_dim, device=device)
    c_cat_idx, c_cat_onehot, c_cont = sample_latent_codes(
        batch_size, config.num_categorical, config.num_continuous, device
    )
    c = combine_codes(c_cat_onehot, c_cont)
    
    print(f"Noise z shape: {z.shape}")
    print(f"Categorical code (one-hot) shape: {c_cat_onehot.shape}")
    print(f"Continuous code shape: {c_cont.shape}")
    print(f"Combined code shape: {c.shape}")
    
    # Generate
    fake_imgs = generator(z, c)
    print(f"\nGenerated images shape: {fake_imgs.shape}")
    
    # Discriminate and predict codes
    d_out, q_cat, q_cont = discriminator_q(fake_imgs)
    print(f"Discriminator output shape: {d_out.shape}")
    print(f"Q categorical logits shape: {q_cat.shape}")
    print(f"Q continuous mean shape: {q_cont.shape}")
    
    # Compute losses
    print("\n4. Loss Computation")
    print("-" * 40)
    
    real_imgs = torch.randn_like(fake_imgs)  # random stand-in for real data (demo only)
    d_real, _, _ = discriminator_q(real_imgs)
    
    d_loss = discriminator_loss(d_real, d_out)
    g_loss = generator_loss(d_out)
    mi_cat, mi_cont = mutual_information_loss(q_cat, q_cont, c_cat_idx, c_cont)
    
    print(f"Discriminator loss: {d_loss.item():.4f}")
    print(f"Generator loss: {g_loss.item():.4f}")
    print(f"MI loss (categorical): {mi_cat.item():.4f}")
    print(f"MI loss (continuous): {mi_cont.item():.4f}")
    
    # Code recovery accuracy
    q_cat_pred = q_cat.argmax(dim=1)
    accuracy = (q_cat_pred == c_cat_idx).float().mean()
    print(f"\nCategorical code recovery accuracy: {accuracy.item():.2%}")
    
    print("\n5. Mutual Information Interpretation")
    print("-" * 40)
    print("The mutual information loss encourages:")
    print("βœ“ Generator to use latent codes c")
    print("βœ“ Q network to accurately predict c from generated images")
    print("βœ“ Disentangled representations (each code controls distinct factor)")
    print("\nWithout MI loss:")
    print("βœ— Generator might ignore codes c")
    print("βœ— All variation would be in noise z")
    print("βœ— No interpretable structure")


def demo_code_disentanglement():
    """Demonstrate the concept of code disentanglement."""
    
    print("\n" + "=" * 80)
    print("Code Disentanglement Demonstration")
    print("=" * 80)
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    config = InfoGANConfig()
    
    generator = InfoGANGenerator(
        latent_dim=config.latent_dim,
        code_dim=config.num_categorical + config.num_continuous
    ).to(device)
    
    print("\n1. Categorical Code Traversal")
    print("-" * 40)
    print(f"Varying categorical code across {config.num_categorical} classes")
    print("Expected: Different digit identities (if trained on MNIST)")
    
    # Demonstrate categorical traversal
    z = torch.randn(config.num_categorical, config.latent_dim, device=device)
    c_cont_fixed = torch.zeros(config.num_categorical, config.num_continuous, device=device)
    
    imgs_list = []
    for cat_idx in range(config.num_categorical):
        c_cat_onehot = F.one_hot(
            torch.tensor([cat_idx], device=device),
            num_classes=config.num_categorical
        ).float()
        c = combine_codes(c_cat_onehot, c_cont_fixed[:1])
        img = generator(z[:1], c)
        imgs_list.append(img)
    
    print(f"Generated {len(imgs_list)} images with different categorical codes")
    
    print("\n2. Continuous Code Traversal")
    print("-" * 40)
    print(f"Varying continuous code 0 from -2 to +2")
    print("Expected: Smooth variation in rotation/thickness (if trained)")
    
    # Demonstrate continuous traversal
    num_steps = 10
    c_range = torch.linspace(-2, 2, num_steps, device=device)
    z_fixed = torch.randn(1, config.latent_dim, device=device).repeat(num_steps, 1)
    c_cat_fixed = F.one_hot(
        torch.zeros(num_steps, dtype=torch.long, device=device),
        num_classes=config.num_categorical
    ).float()
    
    c_cont = torch.zeros(num_steps, config.num_continuous, device=device)
    c_cont[:, 0] = c_range
    
    c_combined = combine_codes(c_cat_fixed, c_cont)
    imgs_cont = generator(z_fixed, c_combined)
    
    print(f"Generated {num_steps} images with varying continuous code")
    
    print("\n3. Independence of Codes")
    print("-" * 40)
    print("Key property: Changing one code should not affect others")
    print("Example: Varying rotation should not change digit identity")
    print("This is enforced by:")
    print("  - Factorized prior: p(c) = p(c_cat) * p(c_cont)")
    print("  - Mutual information maximization")


# Run demonstrations
if __name__ == "__main__":
    print("Starting InfoGAN demonstrations...\n")
    
    # Main demonstration
    demo_infogan()
    
    # Disentanglement concept
    demo_code_disentanglement()
    
    print("\n" + "=" * 80)
    print("InfoGAN Implementation Complete!")
    print("=" * 80)
    print("\nKey Components Implemented:")
    print("βœ“ Generator with latent code conditioning")
    print("βœ“ Shared Discriminator + Q network")
    print("βœ“ Categorical and continuous latent codes")
    print("βœ“ Mutual information loss (variational lower bound)")
    print("βœ“ Training utilities")
    print("βœ“ Visualization functions")
    print("\nTo train on real data (e.g., MNIST):")
    print("1. Load dataset")
    print("2. Create data loader")
    print("3. Train for ~50 epochs")
    print("4. Visualize learned disentangled codes")
    print("\nExpected discoveries on MNIST:")
    print("- Categorical code: Digit identity (0-9)")
    print("- Continuous code 1: Rotation angle")
    print("- Continuous code 2: Stroke thickness/width")
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import seaborn as sns

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
sns.set_style('whitegrid')

1. Motivation: Disentangled RepresentationsΒΆ

Problem with Standard GANs:ΒΆ

In a vanilla GAN, the latent code \(z\) is unstructured:

  • No control over generated features

  • Entangled representations (changing \(z\) affects multiple attributes)

  • No interpretability

InfoGAN Solution:ΒΆ

Decompose latent variables: $\(z = (z_c, z_n)\)$

where:

  • \(z_c\) = latent code (categorical or continuous) - should be interpretable

  • \(z_n\) = noise (unstructured, incompressible)

Goal: Learn disentangled \(z_c\) such that each dimension controls a specific factor of variation.

2. Information-Theoretic ObjectiveΒΆ

Mutual Information:ΒΆ

Measure statistical dependence between latent code \(c\) and generated data \(G(z, c)\):

\[I(c; G(z, c)) = H(c) - H(c | G(z, c))\]

where:

  • \(H(c)\) = entropy of latent code (fixed by design)

  • \(H(c | G(z, c))\) = conditional entropy (low when \(c\) is recoverable from \(G(z, c)\))

Intuition: If we can recover \(c\) from the generated image, then \(c\) is meaningful.
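The identity \(I(c; X) = H(c) - H(c \mid X)\) is easy to verify numerically. As a small sketch (independent of the GAN itself), on a toy discrete joint distribution:

```python
import numpy as np

# Toy joint distribution p(c, x) over a binary code c and a binary observation x.
# Rows index c, columns index x.
p_cx = np.array([[0.4, 0.1],
                 [0.1, 0.4]])

p_c = p_cx.sum(axis=1)            # marginal p(c)
p_x = p_cx.sum(axis=0)            # marginal p(x)

H_c = -np.sum(p_c * np.log(p_c))  # entropy of the code

# Conditional entropy H(c|x) = -sum_{c,x} p(c,x) log p(c|x)
p_c_given_x = p_cx / p_x          # each column divided by p(x)
H_c_given_x = -np.sum(p_cx * np.log(p_c_given_x))

I = H_c - H_c_given_x
print(f"H(c) = {H_c:.4f}, H(c|x) = {H_c_given_x:.4f}, I(c;x) = {I:.4f}")
```

Here the code and observation are strongly correlated, so \(H(c \mid x)\) is well below \(H(c)\) and the mutual information is positive; a generator that ignored its code would instead give \(H(c \mid x) = H(c)\) and \(I = 0\).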

InfoGAN Objective:ΒΆ

\[\min_G \max_D \; V(D, G) - \lambda I(c; G(z, c))\]

The standard GAN value function \(V(D, G)\), regularized by a mutual-information term with weight \(\lambda > 0\).

Challenge:ΒΆ

Computing \(I(c; G(z, c))\) requires \(p(c | x)\), which is intractable.

Solution: Use variational lower bound with auxiliary distribution \(Q(c | x)\).

3. Variational Mutual Information MaximizationΒΆ

Lemma (Variational Lower Bound):ΒΆ

\[I(c; G(z, c)) \geq \mathbb{E}_{c \sim p(c), x \sim G(z, c)} [\log Q(c | x)] + H(c)\]

where \(Q(c | x)\) is an auxiliary network that tries to predict \(c\) from \(x = G(z, c)\).

Proof Sketch:ΒΆ

\[\begin{align} I(c; G(z, c)) &= H(c) - H(c \mid G(z, c)) \\ &= H(c) + \mathbb{E}_{x \sim G} [\mathbb{E}_{c \sim p(c|x)} [\log p(c \mid x)]] \\ &= H(c) + \mathbb{E}_{x, c} [\log p(c \mid x)] \\ &= H(c) + \mathbb{E}_{x, c} \left[ \log \left( Q(c \mid x) \cdot \frac{p(c \mid x)}{Q(c \mid x)} \right) \right] \\ &= H(c) + \mathbb{E}_{x, c} [\log Q(c \mid x)] + \mathbb{E}_{x} \left[ D_{KL}\big(p(\cdot \mid x) \,\|\, Q(\cdot \mid x)\big) \right] \\ &\geq H(c) + \mathbb{E}_{x, c} [\log Q(c \mid x)] \quad \text{(KL divergence is non-negative)} \end{align}\]

Practical Objective:ΒΆ

\[L_I(G, Q) = \mathbb{E}_{c \sim p(c), z \sim p(z)} [\log Q(c | G(z, c))]\]

The entropy \(H(c)\) is a constant (the code prior is fixed by design), so it is dropped from the optimization. Implementation: add an auxiliary network \(Q\) that shares most parameters with the discriminator \(D\).
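As a sanity check of the bound (a numerical sketch, not part of the training code), one can evaluate \(H(c) + \mathbb{E}[\log Q(c \mid x)]\) on a toy discrete distribution: with \(Q\) equal to the true posterior the bound is tight, and a mismatched \(Q\) gives a smaller value.

```python
import numpy as np

# Toy check of the bound I(c;X) >= H(c) + E_{(c,x)}[log Q(c|x)]
p_cx = np.array([[0.4, 0.1],
                 [0.1, 0.4]])      # joint p(c, x); rows index c, columns index x
p_c = p_cx.sum(axis=1)
p_x = p_cx.sum(axis=0)
p_c_given_x = p_cx / p_x           # true posterior p(c|x)

H_c = -np.sum(p_c * np.log(p_c))
I = H_c + np.sum(p_cx * np.log(p_c_given_x))   # exact mutual information

def bound(Q):
    """Variational lower bound H(c) + E_{(c,x)~p}[log Q(c|x)]."""
    return H_c + np.sum(p_cx * np.log(Q))

Q_exact = p_c_given_x                 # Q matches the true posterior
Q_bad = np.full_like(p_cx, 0.5)       # uninformative Q(c|x) = 1/2

print(f"I(c;x)         = {I:.4f}")
print(f"bound, exact Q = {bound(Q_exact):.4f}  (tight)")
print(f"bound, bad Q   = {bound(Q_bad):.4f}  (looser)")
```

The gap between the bound and \(I\) is exactly the expected KL divergence from the proof sketch, which is why training \(Q\) to predict the codes tightens the bound.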

4. InfoGAN ArchitectureΒΆ

InfoGAN splits the generator input into three components: the standard noise vector \(z\), a categorical code \(c_1\) that captures discrete variation (such as digit identity in MNIST), and continuous codes \(c_2\) that capture smooth factors (such as rotation angle or stroke width). The key architectural addition is a Q network, typically sharing layers with the discriminator, that predicts the latent codes from a generated image. By maximizing the mutual information \(I(c; G(z, c))\) between the codes and the generated output, InfoGAN forces the generator to use the codes in a semantically meaningful way, learning disentangled representations without any labels.

class Generator(nn.Module):
    def __init__(self, latent_dim, code_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape
        input_dim = latent_dim + code_dim
        
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )
    
    def forward(self, z, c):
        """Generate image from noise z and latent code c."""
        x = torch.cat([z, c], dim=1)
        x = self.fc(x)
        return x.view(x.size(0), *self.img_shape)

class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()
        
        self.shared = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2)
        )
        
        # Discriminator head
        self.adv_head = nn.Sequential(
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, img):
        x = img.view(img.size(0), -1)
        features = self.shared(x)
        validity = self.adv_head(features)
        return validity, features

class QNetwork(nn.Module):
    """Auxiliary network Q to predict latent code from generated images."""
    def __init__(self, n_categorical, n_continuous):
        super().__init__()
        self.n_categorical = n_categorical
        self.n_continuous = n_continuous
        
        # Takes features from discriminator
        if n_categorical > 0:
            self.categorical = nn.Linear(256, n_categorical)
        if n_continuous > 0:
            self.continuous_mu = nn.Linear(256, n_continuous)
            self.continuous_logvar = nn.Linear(256, n_continuous)
    
    def forward(self, features):
        """Predict latent code distribution from discriminator features."""
        outputs = {}
        
        if self.n_categorical > 0:
            # Return raw logits: F.cross_entropy in the MI loss applies
            # log-softmax internally, so applying softmax here would be wrong
            outputs['categorical'] = self.categorical(features)
        
        if self.n_continuous > 0:
            outputs['continuous_mu'] = self.continuous_mu(features)
            outputs['continuous_logvar'] = self.continuous_logvar(features)
        
        return outputs

# Test architecture
latent_dim = 62
n_categorical = 10  # e.g., digit identity for MNIST
n_continuous = 2    # e.g., rotation, width
code_dim = n_categorical + n_continuous
img_shape = (1, 28, 28)

G = Generator(latent_dim, code_dim, img_shape).to(device)
D = Discriminator(img_shape).to(device)
Q = QNetwork(n_categorical, n_continuous).to(device)

print(f"Generator parameters: {sum(p.numel() for p in G.parameters()):,}")
print(f"Discriminator parameters: {sum(p.numel() for p in D.parameters()):,}")
print(f"Q Network parameters: {sum(p.numel() for p in Q.parameters()):,}")

5. Loss FunctionsΒΆ

Discriminator Loss (Standard GAN):ΒΆ

\[L_D = -\mathbb{E}_{x \sim p_{data}} [\log D(x)] - \mathbb{E}_{z, c} [\log(1 - D(G(z, c)))]\]

Generator Loss:ΒΆ

\[L_G = -\mathbb{E}_{z, c} [\log D(G(z, c))] - \lambda \mathbb{E}_{z, c} [\log Q(c | G(z, c))]\]

Two components:

  1. Fool discriminator (standard GAN)

  2. Maximize mutual information (InfoGAN)

Q Network Loss:ΒΆ

For categorical code \(c_{cat}\): $\(L_{cat} = -\mathbb{E} [\log Q(c_{cat} | G(z, c))]\)$, i.e., the cross-entropy between the sampled class and \(Q\)'s predicted distribution.

For continuous code \(c_{cont}\) (assumed Gaussian): $\(L_{cont} = \mathbb{E} \left[ \frac{(c_{cont} - \mu_Q)^2}{2\sigma_Q^2} + \log \sigma_Q \right]\)$, the Gaussian negative log-likelihood up to an additive constant.
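In PyTorch the categorical term is implemented directly with `F.cross_entropy`, which expects raw logits and applies log-softmax internally. A quick check that it equals \(-\log Q(c_{cat} \mid x)\) for the sampled class:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0]])  # Q's unnormalized scores for 3 classes
target = torch.tensor([0])                 # the sampled categorical code

ce = F.cross_entropy(logits, target)
manual = -F.log_softmax(logits, dim=1)[0, target.item()]
print(ce.item(), manual.item())  # the two values match
```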

def sample_latent_code(batch_size, n_categorical, n_continuous):
    """Sample latent code: categorical (one-hot) + continuous (uniform)."""
    code = []
    
    if n_categorical > 0:
        # Sample categorical code as one-hot
        cat_code = np.random.randint(0, n_categorical, batch_size)
        cat_code_onehot = np.zeros((batch_size, n_categorical))
        cat_code_onehot[np.arange(batch_size), cat_code] = 1
        code.append(torch.FloatTensor(cat_code_onehot))
    
    if n_continuous > 0:
        # Sample continuous code from uniform [-1, 1]
        cont_code = torch.FloatTensor(batch_size, n_continuous).uniform_(-1, 1)
        code.append(cont_code)
    
    return torch.cat(code, dim=1)

def mutual_info_loss(Q_outputs, true_code, n_categorical, n_continuous):
    """Compute mutual information loss for Q network."""
    loss = 0
    
    if n_categorical > 0:
        # Categorical: cross entropy
        cat_target = torch.argmax(true_code[:, :n_categorical], dim=1)
        loss += F.cross_entropy(Q_outputs['categorical'], cat_target)
    
    if n_continuous > 0:
        # Continuous: Gaussian negative log-likelihood
        cont_target = true_code[:, n_categorical:]
        mu = Q_outputs['continuous_mu']
        logvar = Q_outputs['continuous_logvar']
        
        # Full Gaussian NLL = 0.5 * (log(2π) + logvar + (x - mu)^2 / exp(logvar));
        # the constant log(2π) term is dropped since it does not affect gradients
        nll = 0.5 * (logvar + (cont_target - mu)**2 / torch.exp(logvar))
        loss += nll.mean()
    
    return loss

# Test loss computation
batch_size = 64
z = torch.randn(batch_size, latent_dim).to(device)
c = sample_latent_code(batch_size, n_categorical, n_continuous).to(device)

with torch.no_grad():
    fake_imgs = G(z, c)
    _, features = D(fake_imgs)
    Q_outputs = Q(features)
    mi_loss = mutual_info_loss(Q_outputs, c, n_categorical, n_continuous)

print(f"Test MI loss: {mi_loss.item():.4f}")
print(f"Categorical prediction shape: {Q_outputs['categorical'].shape}")
print(f"Continuous mu shape: {Q_outputs['continuous_mu'].shape}")

6. Training InfoGAN on MNISTΒΆ

Training proceeds with three simultaneous objectives: the standard adversarial loss for the generator and discriminator, plus an information-theoretic regularizer that maximizes a lower bound on the mutual information \(I(c; G(z, c))\). In practice the Q-network outputs parameters of the code distributions – a softmax for categorical codes and (mean, variance) for continuous codes – and the mutual information is estimated via the variational bound. The hyperparameter \(\lambda\) controls the weight of the information loss relative to the adversarial loss. Watching the categorical accuracy and continuous code reconstruction during training confirms whether the network is learning meaningful disentanglement.

# Load MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # [-1, 1]
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=True)

# Optimizers
lr = 0.0002
optimizer_G = optim.Adam(list(G.parameters()) + list(Q.parameters()), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

# Training loop
n_epochs = 15
lambda_mi = 1.0  # Weight for mutual information loss

G.train()
D.train()
Q.train()

history = {'D_loss': [], 'G_loss': [], 'MI_loss': []}

for epoch in range(n_epochs):
    epoch_D_loss = 0
    epoch_G_loss = 0
    epoch_MI_loss = 0
    
    for batch_idx, (real_imgs, _) in enumerate(train_loader):
        batch_size = real_imgs.size(0)
        real_imgs = real_imgs.to(device)
        
        # Labels
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)
        
        # =====================
        # Train Discriminator
        # =====================
        optimizer_D.zero_grad()
        
        # Real images
        real_validity, _ = D(real_imgs)
        d_real_loss = F.binary_cross_entropy(real_validity, real_labels)
        
        # Fake images
        z = torch.randn(batch_size, latent_dim).to(device)
        c = sample_latent_code(batch_size, n_categorical, n_continuous).to(device)
        fake_imgs = G(z, c)
        fake_validity, _ = D(fake_imgs.detach())
        d_fake_loss = F.binary_cross_entropy(fake_validity, fake_labels)
        
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        optimizer_D.step()
        
        # =====================
        # Train Generator + Q
        # =====================
        optimizer_G.zero_grad()
        
        # Generate new samples
        z = torch.randn(batch_size, latent_dim).to(device)
        c = sample_latent_code(batch_size, n_categorical, n_continuous).to(device)
        fake_imgs = G(z, c)
        
        # Adversarial loss
        fake_validity, features = D(fake_imgs)
        g_loss = F.binary_cross_entropy(fake_validity, real_labels)
        
        # Mutual information loss
        Q_outputs = Q(features)
        mi_loss = mutual_info_loss(Q_outputs, c, n_categorical, n_continuous)
        
        # Total generator loss
        total_g_loss = g_loss + lambda_mi * mi_loss
        total_g_loss.backward()
        optimizer_G.step()
        
        # Track losses
        epoch_D_loss += d_loss.item()
        epoch_G_loss += g_loss.item()
        epoch_MI_loss += mi_loss.item()
    
    # Average losses
    epoch_D_loss /= len(train_loader)
    epoch_G_loss /= len(train_loader)
    epoch_MI_loss /= len(train_loader)
    
    history['D_loss'].append(epoch_D_loss)
    history['G_loss'].append(epoch_G_loss)
    history['MI_loss'].append(epoch_MI_loss)
    
    print(f"Epoch [{epoch+1}/{n_epochs}] D_loss: {epoch_D_loss:.4f} | "
          f"G_loss: {epoch_G_loss:.4f} | MI_loss: {epoch_MI_loss:.4f}")

print("\nTraining complete!")
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

ax1 = axes[0]
ax1.plot(history['D_loss'], label='Discriminator', linewidth=2)
ax1.plot(history['G_loss'], label='Generator', linewidth=2)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('Adversarial Losses', fontsize=13)
ax1.legend()
ax1.grid(True, alpha=0.3)

ax2 = axes[1]
ax2.plot(history['MI_loss'], color='green', linewidth=2)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('MI Loss', fontsize=12)
ax2.set_title('Mutual Information Loss', fontsize=13)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

7. Analyzing Learned Latent CodesΒΆ

Categorical Code VariationΒΆ

Fix noise \(z\) and continuous code, vary categorical code to see if it captures digit identity.

G.eval()

# Fix noise and continuous code
z_fixed = torch.randn(1, latent_dim).to(device)
c_cont_fixed = torch.zeros(1, n_continuous).to(device)

# Vary categorical code
fig, axes = plt.subplots(1, n_categorical, figsize=(15, 2))

with torch.no_grad():
    for cat_idx in range(n_categorical):
        # Create one-hot categorical code
        c_cat = torch.zeros(1, n_categorical).to(device)
        c_cat[0, cat_idx] = 1
        
        # Combine codes
        c = torch.cat([c_cat, c_cont_fixed], dim=1)
        
        # Generate
        img = G(z_fixed, c)
        img = img.cpu().numpy().squeeze()
        
        axes[cat_idx].imshow(img, cmap='gray')
        axes[cat_idx].set_title(f'Cat={cat_idx}', fontsize=10)
        axes[cat_idx].axis('off')

plt.suptitle('Categorical Code Variation (Fixed Noise)', fontsize=14)
plt.tight_layout()
plt.show()

print("If trained well, each column should show a different digit class.")

Continuous Code VariationΒΆ

To verify that the continuous latent codes have learned interpretable factors of variation, we fix the noise vector \(z\) and the categorical code, then sweep each continuous code across a range (here \([-2, 2]\), slightly beyond the \([-1, 1]\) range sampled during training). If training was successful, each continuous code should control a single, smooth visual attribute: for example, one code might govern digit width while another controls slant. This kind of traversal visualization is a standard qualitative check for disentangled representations and directly demonstrates why InfoGAN is valuable: it discovers meaningful structure in the data without any supervision.

# Fix noise and categorical code (digit 3 for example)
z_fixed = torch.randn(1, latent_dim).to(device)
c_cat_fixed = torch.zeros(1, n_categorical).to(device)
c_cat_fixed[0, 3] = 1  # Choose digit 3

# Vary first continuous code
n_samples = 10
c1_values = np.linspace(-2, 2, n_samples)

fig, axes = plt.subplots(2, n_samples, figsize=(15, 4))

with torch.no_grad():
    # Row 1: Vary c1, fix c2=0
    for i, c1 in enumerate(c1_values):
        c_cont = torch.tensor([[c1, 0]], dtype=torch.float32).to(device)
        c = torch.cat([c_cat_fixed, c_cont], dim=1)
        
        img = G(z_fixed, c).cpu().numpy().squeeze()
        axes[0, i].imshow(img, cmap='gray')
        axes[0, i].set_title(f'c1={c1:.1f}', fontsize=9)
        axes[0, i].axis('off')
    
    # Row 2: Vary c2, fix c1=0
    for i, c2 in enumerate(c1_values):
        c_cont = torch.tensor([[0, c2]], dtype=torch.float32).to(device)
        c = torch.cat([c_cat_fixed, c_cont], dim=1)
        
        img = G(z_fixed, c).cpu().numpy().squeeze()
        axes[1, i].imshow(img, cmap='gray')
        axes[1, i].set_title(f'c2={c2:.1f}', fontsize=9)
        axes[1, i].axis('off')

# axis('off') would hide ylabel text, so attach row labels with Axes.text instead
axes[0, 0].text(-0.3, 0.5, 'Vary c1', transform=axes[0, 0].transAxes,
                fontsize=11, rotation=90, va='center')
axes[1, 0].text(-0.3, 0.5, 'Vary c2', transform=axes[1, 0].transAxes,
                fontsize=11, rotation=90, va='center')
plt.suptitle('Continuous Code Variation (e.g., rotation, width)', fontsize=14)
plt.tight_layout()
plt.show()

print("Continuous codes should capture smooth transformations (e.g., rotation, thickness).")

SummaryΒΆ

Key Contributions of InfoGAN:ΒΆ

  1. Unsupervised disentanglement - learns interpretable factors without labels

  2. Mutual information maximization - ensures latent codes are meaningful

  3. Variational bound - tractable optimization via auxiliary network \(Q\)

  4. Minimal overhead - only adds small network \(Q\), shares features with \(D\)

Applications:ΒΆ

  • Controllable generation - manipulate specific attributes

  • Representation learning - discover data structure automatically

  • Data augmentation - generate variations along learned factors

  • Interpretability - understand what GAN has learned

Limitations:ΒΆ

  • Requires choosing number of codes \(|c|\) in advance

  • No guarantee which code learns which factor

  • Training can be unstable (GAN issues + MI objective)

Extensions:ΒΆ

  • β-VAE + InfoGAN - combine strengths of both approaches

  • Multi-level InfoGAN - hierarchical disentanglement

  • InfoGAN for video - temporal disentanglement

Next Steps:ΒΆ

  • 06_conditional_gan.ipynb - Class-conditional generation

  • 07_bayesian_gan.ipynb - Uncertainty in generative models

  • Return to 03_variational_autoencoders_advanced.ipynb for disentanglement comparison