InfoGAN: Information-Maximizing Generative Adversarial Networks - Comprehensive Theory
Introduction
InfoGAN is an unsupervised learning method that extends the standard GAN framework to learn disentangled and interpretable representations. By maximizing mutual information between a subset of latent variables and generated samples, InfoGAN discovers meaningful semantic features (e.g., digit identity, rotation, width) without requiring labeled data.
Key Innovation: Automatic discovery of interpretable latent codes through information-theoretic regularization.
Applications:
Unsupervised disentanglement learning
Controllable generation (manipulate specific attributes)
Feature extraction
Semi-supervised learning
Background: Disentangled Representations
What is Disentanglement?
A disentangled representation factorizes the underlying explanatory factors of variation in data: $\(\mathbf{z} = (\mathbf{z}_1, \mathbf{z}_2, \ldots, \mathbf{z}_K)\)$
Where each \(\mathbf{z}_i\) controls a distinct, interpretable property (e.g., shape, color, size).
Example (MNIST):
\(c_1\): Digit identity (0-9)
\(c_2\): Rotation angle
\(c_3\): Stroke thickness
Benefits:
Interpretability: Each dimension has semantic meaning
Transfer Learning: Disentangled features generalize better
Controllability: Manipulate specific properties independently
Sample Efficiency: Simpler representations require less data
Why Standard GANs Fail at Disentanglement
Problem: In a vanilla GAN, the generator input \(\mathbf{z} \sim p(\mathbf{z})\) is typically unstructured noise: $\(\mathbf{z} \sim \mathcal{N}(0, I)\)$
All dimensions are treated equally, with no structure encouraging interpretability.
Result: The generator learns entangled representations in which individual dimensions don't correspond to meaningful factors.
Information Theory Primer
Entropy
Shannon Entropy measures uncertainty in a random variable: $\(H(X) = -\mathbb{E}_{x \sim p(x)}[\log p(x)] = -\sum_x p(x) \log p(x)\)$
Interpretation:
High entropy → high uncertainty (uniform distribution)
Low entropy → low uncertainty (peaked distribution)
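These two regimes are easy to check numerically; a minimal NumPy sketch (the example distributions are made up for illustration):

```python
import numpy as np

def entropy(p):
    """Shannon entropy H(X) = -sum p(x) log p(x), in nats."""
    p = np.asarray(p, dtype=float)
    p = p[p > 0]  # convention: 0 log 0 = 0
    return float(-np.sum(p * np.log(p)))

uniform = [0.25, 0.25, 0.25, 0.25]  # maximal uncertainty over 4 outcomes
peaked = [0.97, 0.01, 0.01, 0.01]   # nearly deterministic

print(entropy(uniform))  # log(4) ~= 1.386 nats
print(entropy(peaked))   # ~= 0.168 nats
```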
Mutual Information
Definition: Mutual information \(I(X; Y)\) measures how much knowing \(X\) reduces uncertainty about \(Y\): $\(I(X; Y) = H(X) - H(X | Y) = \sum_{x, y} p(x, y) \log \frac{p(x, y)}{p(x) p(y)}\)$
Properties:
Non-negative: \(I(X; Y) \geq 0\)
Symmetric: \(I(X; Y) = I(Y; X)\)
Zero iff independent: \(I(X; Y) = 0 \iff p(x,y) = p(x)p(y)\)
Upper bound: \(I(X; Y) \leq \min(H(X), H(Y))\)
Intuition:
\(I(X; Y) = 0\): \(X\) and \(Y\) are independent (knowing \(X\) tells nothing about \(Y\))
\(I(X; Y)\) large: \(X\) and \(Y\) are highly dependent
Conditional Entropy
Definition: Conditional entropy is the expected remaining uncertainty in \(X\) once \(Y\) is observed: $\(H(X | Y) = -\mathbb{E}_{(x, y) \sim p(x, y)}[\log p(x | y)]\)$
Interpretation: Average uncertainty in \(X\) given knowledge of \(Y\).
Chain Rule: $\(H(X, Y) = H(X) + H(Y | X) = H(Y) + H(X | Y)\)$
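The identities above (the entropy-based form of mutual information and the chain rule) can be verified on a toy joint distribution; the table here is arbitrary:

```python
import numpy as np

# Arbitrary joint distribution p(x, y) over two binary variables
# (rows index x, columns index y).
pxy = np.array([[0.4, 0.1],
                [0.1, 0.4]])
px, py = pxy.sum(axis=1), pxy.sum(axis=0)

def H(p):
    p = p[p > 0]
    return float(-np.sum(p * np.log(p)))

H_x, H_y, H_xy = H(px), H(py), H(pxy.ravel())

# I(X;Y) = H(X) + H(Y) - H(X,Y)
I_xy = H_x + H_y - H_xy
# Chain rule: H(X,Y) = H(Y) + H(X|Y)  =>  H(X|Y) = H(X,Y) - H(Y)
H_x_given_y = H_xy - H_y

print(I_xy > 0)                         # True: X and Y are dependent
print(abs(I_xy - (H_x - H_x_given_y)))  # ~0: I(X;Y) = H(X) - H(X|Y)
```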
InfoGAN: Core Idea
Standard GAN Objective
$\(\min_G \max_D V(D, G) = \mathbb{E}_{\mathbf{x} \sim p_{\text{data}}}[\log D(\mathbf{x})] + \mathbb{E}_{\mathbf{z} \sim p(\mathbf{z})}[\log(1 - D(G(\mathbf{z})))]\)$
Problem: No incentive for \(G\) to use all dimensions of \(\mathbf{z}\) meaningfully.
InfoGAN: Latent Code Decomposition
Decompose generator input into: $\(\mathbf{z} = (\mathbf{c}, \mathbf{n})\)$
Where:
\(\mathbf{c} = (c_1, c_2, \ldots, c_L)\): Latent codes (structured, interpretable)
Can be discrete (categorical) or continuous
Example: \(c_1 \sim \text{Cat}(K=10)\) for digit identity
Example: \(c_2 \sim \text{Unif}(-1, 1)\) for rotation
\(\mathbf{n}\): Noise (incompressible, random)
Traditional noise: \(\mathbf{n} \sim \mathcal{N}(0, I)\)
Mutual Information Objective
Goal: Maximize the mutual information between latent codes \(\mathbf{c}\) and generated data \(G(\mathbf{c}, \mathbf{n})\): $\(\max_G I(\mathbf{c}; G(\mathbf{c}, \mathbf{n}))\)$
Interpretation:
Force generator to use latent codes \(\mathbf{c}\)
Prevent \(G\) from ignoring \(\mathbf{c}\) (information bottleneck)
\(\mathbf{c}\) should be recoverable from \(G(\mathbf{c}, \mathbf{n})\)
Full InfoGAN Objective
$\(\min_G \max_D V_I(D, G) = V(D, G) - \lambda I(\mathbf{c}; G(\mathbf{c}, \mathbf{n}))\)$
Where:
\(V(D, G)\): Standard GAN objective
\(I(\mathbf{c}; G(\mathbf{c}, \mathbf{n}))\): Mutual information term
\(\lambda\): Hyperparameter (typically \(\lambda = 1\))
\(Q\): Auxiliary network approximating \(p(\mathbf{c} | \mathbf{x})\)
Challenge: Computing \(I(\mathbf{c}; G(\mathbf{c}, \mathbf{n}))\) directly is intractable.
Variational Mutual Information Maximization
The Challenge
Problem: Computing \(I(\mathbf{c}; G(\mathbf{c}, \mathbf{n})) = H(\mathbf{c}) - H(\mathbf{c} \mid G(\mathbf{c}, \mathbf{n}))\) requires the posterior \(p(\mathbf{c} | \mathbf{x})\), which is intractable.
Variational Lower Bound
Introduce auxiliary distribution \(Q(\mathbf{c} | \mathbf{x})\) to approximate \(p(\mathbf{c} | \mathbf{x})\).
Lemma (Variational Lower Bound): $\(I(\mathbf{c}; G(\mathbf{c}, \mathbf{n})) \geq \mathbb{E}_{\mathbf{c} \sim p(\mathbf{c}), \mathbf{x} \sim G(\mathbf{c}, \mathbf{n})}[\log Q(\mathbf{c} | \mathbf{x})] + H(\mathbf{c}) = L_I(G, Q)\)$
Proof:
Inequality: Follows from non-negativity of KL divergence.
Equality: When \(Q(\mathbf{c} | \mathbf{x}) = p(\mathbf{c} | \mathbf{x})\).
Practical Objective
$\(\min_{G, Q} \max_D V_{\text{InfoGAN}}(D, G, Q) = V(D, G) - \lambda L_I(G, Q)\)$
Where the mutual information lower bound is: $\(L_I(G, Q) = \mathbb{E}_{\mathbf{c} \sim p(\mathbf{c}), \mathbf{n} \sim p(\mathbf{n})}[\log Q(\mathbf{c} | G(\mathbf{c}, \mathbf{n}))] + H(\mathbf{c})\)$
Simplification: Since \(H(\mathbf{c})\) is constant w.r.t. \(G\) and \(Q\): $\(\max_{G, Q} L_I = \max_{G,Q} \mathbb{E}_{\mathbf{c}, \mathbf{n}}[\log Q(\mathbf{c} | G(\mathbf{c}, \mathbf{n}))]\)$
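A discrete toy example makes the bound concrete: with a small hand-built "generator" channel \(p(x|c)\) (probabilities made up for illustration), the bound is tight when \(Q\) equals the true posterior and strictly looser for any other \(Q\):

```python
import numpy as np

# Toy setup: c ~ Uniform{0, 1}; a "generator" channel emits x from p(x|c).
p_c = np.array([0.5, 0.5])
p_x_given_c = np.array([[0.9, 0.1],   # p(x | c=0)
                        [0.2, 0.8]])  # p(x | c=1)

p_cx = p_c[:, None] * p_x_given_c     # joint p(c, x)
p_x = p_cx.sum(axis=0)
posterior = p_cx / p_x                # true p(c | x)

H_c = float(-np.sum(p_c * np.log(p_c)))

def L_I(Q):
    """Variational bound: E_{c,x}[log Q(c|x)] + H(c)."""
    return float(np.sum(p_cx * np.log(Q)) + H_c)

# Exact mutual information I(c; x)
I_cx = float(np.sum(p_cx * np.log(p_cx / (p_c[:, None] * p_x[None, :]))))

print(I_cx)                               # the quantity we want to maximize
print(L_I(posterior))                     # equals I_cx: bound is tight
print(L_I(np.full_like(posterior, 0.5)))  # strictly smaller for any other Q
```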
Implementation Details
Q-Network Architecture
The auxiliary network \(Q\) shares parameters with discriminator \(D\) for efficiency:
Shared Layers: \(D\) and \(Q\) share convolutional feature extractor
Separate Heads:
\(D\) head: Output \(D(\mathbf{x}) \in [0, 1]\) (real/fake)
\(Q\) head: Output \(Q(\mathbf{c} | \mathbf{x})\) (latent code distribution)
Benefits:
Parameter efficiency (shared representation)
Minimal computational overhead
Natural integration with discriminator
Latent Code Types
1. Discrete Categorical (\(c_i \sim \text{Cat}(K)\)):
Prior: \(p(c_i = k) = 1/K\) (uniform over \(K\) classes)
\(Q\) parameterization: Softmax over \(K\) logits $\(Q(c_i | \mathbf{x}) = \text{Softmax}(f_Q(\mathbf{x}))\)$
Loss: Cross-entropy $\(\mathcal{L}_{\text{cat}} = -\mathbb{E}[\log Q(c_i | \mathbf{x})]\)$
2. Continuous (\(c_j \sim \mathcal{N}(0, 1)\) or \(\text{Unif}(-1, 1)\)):
Prior: Gaussian \(p(c_j) = \mathcal{N}(0, 1)\) or uniform
\(Q\) parameterization: Gaussian with learned mean and variance $\(Q(c_j | \mathbf{x}) = \mathcal{N}(\mu_Q(\mathbf{x}), \sigma_Q^2(\mathbf{x}))\)$
Loss: Negative log-likelihood (equivalent to MSE for Gaussian) $\(\mathcal{L}_{\text{cont}} = \frac{1}{2}\mathbb{E}\left[\frac{(c_j - \mu_Q(\mathbf{x}))^2}{\sigma_Q^2(\mathbf{x})} + \log \sigma_Q^2(\mathbf{x})\right]\)$
Simplification (fixed variance): If we fix \(\sigma_Q^2 = 1\), the loss simplifies to MSE: $\(\mathcal{L}_{\text{cont}} = \mathbb{E}[\|c_j - \mu_Q(\mathbf{x})\|^2]\)$
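The fixed-variance simplification can be checked directly; a short NumPy sketch with synthetic codes and predictions:

```python
import numpy as np

rng = np.random.default_rng(0)
c = rng.uniform(-1, 1, size=1000)        # true continuous codes
mu = c + rng.normal(0, 0.1, size=1000)   # Q's predicted means (synthetic)

def gaussian_nll(c, mu, var):
    """Mean of -log N(c; mu, var), dropping the 2*pi constant."""
    return float(np.mean((c - mu) ** 2 / var + np.log(var)) / 2)

mse = float(np.mean((c - mu) ** 2))

# With sigma^2 fixed at 1, the NLL equals MSE/2 plus a constant, so
# minimizing it is exactly equivalent to minimizing the MSE.
print(abs(gaussian_nll(c, mu, 1.0) - mse / 2))  # 0.0
```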
Training Algorithm
For each iteration:
Sample minibatch:
Latent codes: \(\mathbf{c} \sim p(\mathbf{c})\)
Noise: \(\mathbf{n} \sim \mathcal{N}(0, I)\)
Real data: \(\mathbf{x} \sim p_{\text{data}}\)
Generate fake samples: $\(\mathbf{x}_{\text{fake}} = G(\mathbf{c}, \mathbf{n})\)$
Update Discriminator \(D\): $\(\max_D \mathbb{E}_{\mathbf{x}}[\log D(\mathbf{x})] + \mathbb{E}_{\mathbf{c}, \mathbf{n}}[\log(1 - D(G(\mathbf{c}, \mathbf{n})))]\)$
Update Generator \(G\) and \(Q\) jointly: $\(\min_{G,Q} \mathbb{E}_{\mathbf{c}, \mathbf{n}}[\log(1 - D(G(\mathbf{c}, \mathbf{n})))] - \lambda \mathbb{E}_{\mathbf{c}, \mathbf{n}}[\log Q(\mathbf{c} | G(\mathbf{c}, \mathbf{n}))]\)$
Key Point: The mutual information loss acts as a regularizer, penalizing \(G\) if \(Q\) cannot recover \(\mathbf{c}\) from \(G(\mathbf{c}, \mathbf{n})\).
Theoretical Analysis
Why Does Mutual Information Encourage Disentanglement?
Information Bottleneck Principle:
Maximize \(I(\mathbf{c}; \mathbf{x})\): Generator must use \(\mathbf{c}\) to create \(\mathbf{x}\)
Structured prior \(p(\mathbf{c})\): Factorized prior encourages independence $\(p(\mathbf{c}) = \prod_{i=1}^L p(c_i)\)$
Result: To maximize \(I(\mathbf{c}; \mathbf{x})\) under factorized prior, each \(c_i\) specializes to different aspects of \(\mathbf{x}\).
Informal Argument:
If \(c_1\) and \(c_2\) encode the same information, \(I(\mathbf{c}; \mathbf{x})\) is not maximized
Redundancy reduces effective capacity
Maximizing MI incentivizes efficient, non-redundant use of \(c_i\)
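This redundancy argument can be made concrete with a toy discrete example: when two codes duplicate the same factor, their joint mutual information with \(\mathbf{x}\) is capped at one bit, while non-redundant codes capture two (the setup is constructed for illustration):

```python
import numpy as np

def mi(p_joint):
    """Mutual information between the row and column variables of a joint table."""
    pr = p_joint.sum(axis=1, keepdims=True)
    pc = p_joint.sum(axis=0, keepdims=True)
    outer = pr * pc
    mask = p_joint > 0
    return float(np.sum(p_joint[mask] * np.log(p_joint[mask] / outer[mask])))

# x is determined by two independent fair bits (f1, f2): x = 2*f1 + f2.
# Case A (redundant): the code pair is (f1, f1) -- only 2 effective states.
# Case B (non-redundant): the code pair is (f1, f2) -- 4 states.
p_A = np.zeros((4, 4))
p_B = np.zeros((4, 4))
for f1 in (0, 1):
    for f2 in (0, 1):
        x = 2 * f1 + f2
        p_A[2 * f1 + f1, x] += 0.25  # c encodes f1 twice
        p_B[2 * f1 + f2, x] += 0.25  # c encodes both factors

print(mi(p_A))  # log 2 ~= 0.693: redundant codes capture one bit of x
print(mi(p_B))  # log 4 ~= 1.386: non-redundant codes capture both bits
```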
Connection to Information Bottleneck
Information Bottleneck Method (Tishby et al.): $\(\max_{p(\tilde{X} | X)} I(\tilde{X}; Y) - \beta I(\tilde{X}; X)\)$
Where:
\(\tilde{X}\): Compressed representation
\(Y\): Target variable
\(\beta\): Trade-off between compression and prediction
InfoGAN Perspective:
\(\mathbf{c}\): Compressed representation
Maximize \(I(\mathbf{c}; G(\mathbf{c}, \mathbf{n}))\) subject to factorized prior
Factorization acts as implicit compression constraint
Relation to β-VAE
β-VAE Objective: $\(\mathcal{L}_{\beta\text{-VAE}} = \mathbb{E}[\log p(\mathbf{x} | \mathbf{z})] - \beta D_{KL}(q(\mathbf{z} | \mathbf{x}) \| p(\mathbf{z}))\)$
Similarities:
Both encourage disentanglement via information-theoretic constraints
InfoGAN: Maximize \(I(\mathbf{c}; \mathbf{x})\)
β-VAE: Constrain \(I(\mathbf{z}; \mathbf{x})\) via KL penalty
Differences:
β-VAE: Variational inference framework (encoder-decoder)
InfoGAN: GAN framework (adversarial training)
β-VAE: Explicit reconstruction loss
InfoGAN: Implicit via adversarial loss
Experiments and Results
MNIST
Latent Code Configuration:
\(c_1 \sim \text{Cat}(K=10)\): Discrete (10 classes)
\(c_2, c_3 \sim \text{Unif}(-1, 1)\): Continuous (2 dimensions)
\(\mathbf{n} \sim \mathcal{N}(0, I_{62})\): Noise (62 dimensions)
Discovered Meanings:
\(c_1\): Digit identity (0-9)
\(c_2\): Rotation angle (continuous variation)
\(c_3\): Stroke thickness (width)
Key Observation: Without any labels, InfoGAN discovers semantic factors!
3D Faces
Latent Codes:
\(c_1 \sim \text{Cat}(K=10)\): Discrete
\(c_2, c_3, c_4, c_5 \sim \text{Unif}(-1, 1)\): Continuous
Discovered Meanings:
\(c_1\): Face identity / person
\(c_2\): Elevation (vertical rotation)
\(c_3\): Azimuth (horizontal rotation)
\(c_4\): Lighting direction
\(c_5\): (Less interpretable, possibly expression)
3D Chairs
Discovered:
Rotation angle
Chair type
Width
Quantitative Evaluation of Disentanglement
1. Mutual Information Gap (MIG)
$\(\text{MIG} = \frac{1}{K} \sum_{k=1}^{K} \frac{1}{H(v_k)} \left( I(v_k; c_{j(k)}) - \max_{j \neq j(k)} I(v_k; c_j) \right)\)$
Where:
\(v_k\): Ground-truth factor
\(c_j\): Learned latent code
\(j(k)\): Best-matching code for factor \(k\)
Interpretation: Measures how much more information the best code provides about each factor compared to the next-best code.
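A minimal empirical MIG on synthetic discrete data (the helper functions and data are illustrative, not the metric's reference implementation):

```python
import numpy as np

def discrete_mi(a, b):
    """Empirical mutual information (nats) between two discrete samples."""
    _, ia = np.unique(a, return_inverse=True)
    _, ib = np.unique(b, return_inverse=True)
    joint = np.zeros((ia.max() + 1, ib.max() + 1))
    np.add.at(joint, (ia, ib), 1)
    joint /= joint.sum()
    pa = joint.sum(axis=1, keepdims=True)
    pb = joint.sum(axis=0, keepdims=True)
    m = joint > 0
    return float(np.sum(joint[m] * np.log(joint[m] / (pa * pb)[m])))

def entropy_of(a):
    _, counts = np.unique(a, return_counts=True)
    p = counts / counts.sum()
    return float(-np.sum(p * np.log(p)))

def mig(factors, codes):
    """Mean over factors of (top MI - second MI) / H(factor)."""
    gaps = []
    for v in factors:
        mis = sorted((discrete_mi(v, c) for c in codes), reverse=True)
        gaps.append((mis[0] - mis[1]) / entropy_of(v))
    return float(np.mean(gaps))

rng = np.random.default_rng(0)
v1 = rng.integers(0, 4, 2000)                 # synthetic ground-truth factor
codes_good = [v1, rng.integers(0, 4, 2000)]   # c0 captures v1; c1 is noise
codes_bad = [v1, v1.copy()]                   # both codes capture v1

print(mig([v1], codes_good))  # near 1: a clean one-to-one assignment
print(mig([v1], codes_bad))   # 0: the gap vanishes when codes are redundant
```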
2. SAP Score (Separated Attribute Predictability)
Train linear classifier to predict each ground-truth factor from each latent code
Compute difference between top two scores
Average over factors
Higher score → better disentanglement.
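A sketch of the SAP computation for continuous factors, using the training \(R^2\) of a 1-D linear fit as the predictability score (a common simplification; the data here is synthetic):

```python
import numpy as np

def r2_linear(code, factor):
    """Training R^2 of a 1-D least-squares fit predicting factor from code."""
    A = np.stack([code, np.ones_like(code)], axis=1)
    coef, *_ = np.linalg.lstsq(A, factor, rcond=None)
    resid = factor - A @ coef
    return float(1 - resid.var() / factor.var())

def sap(codes, factors):
    """Average over factors of (best - second-best) predictability score."""
    S = np.array([[r2_linear(c, v) for v in factors] for c in codes])
    gaps = []
    for j in range(S.shape[1]):
        top2 = np.sort(S[:, j])[::-1][:2]
        gaps.append(top2[0] - top2[1])
    return float(np.mean(gaps))

rng = np.random.default_rng(0)
v = rng.uniform(-1, 1, 1000)                 # synthetic ground-truth factor
codes = [v + 0.05 * rng.normal(size=1000),   # c0 tracks the factor closely
         rng.uniform(-1, 1, 1000)]           # c1 is unrelated noise

print(sap(codes, [v]))  # near 1: the factor is predictable from only one code
```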
3. DCI (Disentanglement, Completeness, Informativeness)
Three metrics:
Disentanglement: Each code controls single factor
Completeness: Each factor controlled by single code
Informativeness: Total predictiveness of factors
4. Higgins Metric
Classifier-based: Train classifier on latent traversals, measure accuracy.
Challenges with Metrics
Problem: Different metrics can disagree on ranking of models.
Recent Work: Focus on downstream task performance rather than disentanglement metrics.
Variants and Extensions
1. ss-InfoGAN (Semi-Supervised)
Combine InfoGAN with a small amount of labeled data (here \(\lambda_s\) weights the supervised term): $\(\min_{G, Q} \max_D V(D, G) - \lambda L_I(G, Q) + \lambda_s \mathcal{L}_{\text{supervised}}\)$
Where: $\(\mathcal{L}_{\text{supervised}} = -\mathbb{E}_{(\mathbf{x}, y) \sim p_{\text{labeled}}}[\log Q(y | \mathbf{x})]\)$
Benefit: Even a small amount of labels significantly improves semantic alignment.
2. InfoGAN-CR (Contrastive Regularization)
Problem: InfoGAN can sometimes ignore some latent codes.
Solution: Add a contrastive loss: $\(\mathcal{L}_{\text{CR}} = -\log \frac{\exp(s(\mathbf{x}, \mathbf{c}))}{\exp(s(\mathbf{x}, \mathbf{c})) + \sum_{\mathbf{c}' \neq \mathbf{c}} \exp(s(\mathbf{x}, \mathbf{c}'))}\)$
Where \(s(\mathbf{x}, \mathbf{c})\) is a similarity score (e.g., cosine similarity) and the sum ranges over mismatched (negative) codes.
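This is the InfoNCE form of contrastive loss; a small NumPy sketch with hypothetical similarity scores:

```python
import numpy as np

def info_nce(sim_pos, sim_negs):
    """-log exp(s+) / (exp(s+) + sum exp(s-)), with s- over mismatched codes."""
    logits = np.concatenate([[sim_pos], sim_negs])
    logits = logits - logits.max()  # subtract max for numerical stability
    return float(-np.log(np.exp(logits[0]) / np.exp(logits).sum()))

# Hypothetical similarity scores s(x, c), e.g. cosine similarities.
print(info_nce(0.9, np.array([0.1, -0.3, 0.2])))  # small: positive stands out
print(info_nce(0.1, np.array([0.9, 0.8, 0.7])))   # large: positive is buried
```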
3. AC-InfoGAN (Auxiliary Classifier)
Combine InfoGAN with AC-GAN:
Auxiliary classifier for labeled data
InfoGAN objective for unsupervised codes
4. Variational InfoGAN
Use variational inference for more flexible \(Q\) distributions: $\(Q(\mathbf{c} | \mathbf{x}) = \mathcal{N}(\mu_Q(\mathbf{x}), \text{diag}(\sigma_Q^2(\mathbf{x})))\)$
With reparameterization trick for backpropagation.
Limitations
1. No Guarantee of Semantic Alignment
Problem: Discovered codes may not align with human-interpretable factors.
Example: On complex datasets (e.g., ImageNet), codes may capture texture or lighting rather than object identity.
Mitigation: Use semi-supervised variants (ss-InfoGAN).
2. Mode Collapse
InfoGAN inherits GAN training instabilities:
Mode collapse
Training instability
Hyperparameter sensitivity
Mitigation: Spectral normalization, progressive growing, Wasserstein loss.
3. Limited to Factorized Priors
Assumption: \(p(\mathbf{c}) = \prod_i p(c_i)\) (independence)
Problem: Real-world factors may be correlated (e.g., age and wrinkles).
Extension: Structured priors (e.g., hierarchical).
4. Difficulty with High-Dimensional Data
Challenge: On complex datasets (e.g., CelebA-HQ), harder to achieve clean disentanglement.
Recent Approach: Combine with StyleGAN (StyleGAN2-InfoGAN).
5. Evaluation Challenges
Problem: No ground-truth factors for real-world data.
Solutions:
Qualitative inspection (latent traversals)
Downstream task performance
User studies
Applications
1. Controllable Image Generation
Generate images with specific attributes
Interactive editing (e.g., change pose, lighting, expression)
2. Feature Extraction
Use learned codes \(\mathbf{c}\) as features for downstream tasks
Often more interpretable than standard GAN latent space
3. Data Augmentation
Generate new samples by varying \(\mathbf{c}\)
Useful when data is limited
4. Anomaly Detection
Model normal data with disentangled codes
Detect anomalies as samples with unusual code distributions
5. Fairness in ML
Disentangle sensitive attributes (gender, race)
Enable fair decision-making by controlling for biases
Comparison with Other Disentanglement Methods
InfoGAN vs. β-VAE
| Aspect | InfoGAN | β-VAE |
|---|---|---|
| Framework | GAN (adversarial) | VAE (variational) |
| Objective | Maximize \(I(\mathbf{c}; \mathbf{x})\) | Constrain \(I(\mathbf{z}; \mathbf{x})\) via KL |
| Sample Quality | Higher (GANs generally sharper) | Lower (VAEs tend to blur) |
| Training Stability | Lower (GAN instability) | Higher (VAE more stable) |
| Reconstruction | No explicit reconstruction | Explicit reconstruction loss |
| Inference | No encoder (generation only) | Has encoder (both directions) |
InfoGAN vs. Factor-VAE
Factor-VAE: Adds total correlation penalty: $\(\mathcal{L} = \mathcal{L}_{\text{VAE}} + \gamma \cdot TC(q(\mathbf{z}))\)$
Where \(TC\) is total correlation (measures deviation from factorial distribution).
Trade-off: Factor-VAE often achieves better disentanglement metrics than β-VAE, but InfoGAN has better sample quality.
InfoGAN vs. ControlVAE
ControlVAE: Directly controls the capacity of information flow with a PID controller.
Comparison: ControlVAE is more stable than β-VAE, but InfoGAN retains the sample-quality advantages of adversarial training.
Future Directions
1. Hierarchical Disentanglement
Learn hierarchical structure of factors:
Coarse level: Object category
Medium level: Pose, color
Fine level: Texture, lighting
2. Causal Disentanglement
Go beyond statistical independence to causal independence: $\(\mathbf{c}_i = f_i(\text{parents}(\mathbf{c}_i), \epsilon_i)\)$
3. Group-Based Disentanglement
Leverage group theory (e.g., symmetries, transformations):
Rotation group for pose
Translation group for position
4. Disentanglement for Video
Separate:
Content (what): object identity
Motion (how): dynamics, actions
Style (appearance): lighting, texture
5. Multi-Modal Disentanglement
Disentangle factors across modalities (image, text, audio):
Shared factors (e.g., emotion)
Modality-specific factors (e.g., pitch in audio)
6. Intervention-Based Learning
Use interventions (counterfactuals) to discover causal factors:
"What if we change pose but not identity?"
Practical Tips
Hyperparameter Selection
\(\lambda\) (MI weight): Start with \(\lambda = 1.0\), adjust if codes are ignored
Number of codes: Start small (2-5), increase gradually
Code types: Mix discrete and continuous based on data
Learning rates: Use separate LR for \(G\), \(D\), \(Q\) (e.g., \(10^{-4}\))
Architecture Choices
Shared \(D\) and \(Q\): Efficient, works well in practice
Q network depth: 1-2 extra layers after shared features
Activation: LeakyReLU for \(D\)/\(Q\), Tanh for \(G\) output
Debugging
Visualize latent traversals: Fix all codes except one, vary it, generate images
Check MI loss: Should be non-zero and decreasing
Monitor code usage: Ensure all codes are being used (not mode collapse)
Conclusion
InfoGAN elegantly extends GANs with information-theoretic regularization to discover interpretable, disentangled representations in an unsupervised manner. By maximizing mutual information between latent codes and generated data, InfoGAN incentivizes the generator to use codes for distinct, semantic factors.
Key Contributions:
Unsupervised disentanglement: No labels required
Variational lower bound: Tractable mutual information approximation
Efficient implementation: Shared \(D\)-\(Q\) architecture
Empirical success: Discovers meaningful factors on various datasets
Impact:
Pioneered unsupervised disentanglement in GANs
Inspired numerous follow-up works (β-VAE, Factor-VAE, ControlVAE)
Applications in controllable generation, representation learning, fairness
Current State: While newer methods (StyleGAN, diffusion models) have surpassed InfoGAN in generation quality, the core idea of information-theoretic regularization remains influential in modern disentanglement research.
"""
InfoGAN - Complete Implementation
This implementation includes:
1. Standard GAN components (Generator, Discriminator)
2. Q-Network for approximating P(c|x)
3. Categorical and continuous latent codes
4. Mutual information maximization
5. Training utilities
6. Visualization of disentangled representations
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, List, Dict
from dataclasses import dataclass
# ============================================================================
# 1. NETWORK ARCHITECTURES
# ============================================================================
class InfoGANGenerator(nn.Module):
"""
InfoGAN Generator.
Takes noise z and latent codes c as input, generates images.
Args:
latent_dim: Dimension of noise z
code_dim: Total dimension of latent codes c
img_channels: Number of image channels
img_size: Output image size (assumes square)
"""
def __init__(
self,
latent_dim: int = 62,
code_dim: int = 12, # e.g., 10 (categorical) + 2 (continuous)
img_channels: int = 1,
img_size: int = 28
):
super().__init__()
self.latent_dim = latent_dim
self.code_dim = code_dim
self.img_channels = img_channels
self.img_size = img_size
input_dim = latent_dim + code_dim
self.model = nn.Sequential(
# Input: (latent_dim + code_dim)
nn.Linear(input_dim, 1024),
nn.BatchNorm1d(1024),
nn.ReLU(),
nn.Linear(1024, 128 * 7 * 7),
nn.BatchNorm1d(128 * 7 * 7),
nn.ReLU()
)
# Reshape to (batch, 128, 7, 7) for conv layers
self.conv_layers = nn.Sequential(
nn.ConvTranspose2d(128, 64, 4, 2, 1), # 7x7 -> 14x14
nn.BatchNorm2d(64),
nn.ReLU(),
nn.ConvTranspose2d(64, img_channels, 4, 2, 1), # 14x14 -> 28x28
nn.Tanh()
)
def forward(self, z: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
"""
Generate images from noise and latent codes.
Args:
z: Noise (batch_size, latent_dim)
c: Latent codes (batch_size, code_dim)
Returns:
Generated images (batch_size, img_channels, img_size, img_size)
"""
# Concatenate noise and codes
gen_input = torch.cat([z, c], dim=1)
# MLP layers
x = self.model(gen_input)
# Reshape for conv layers
x = x.view(x.size(0), 128, 7, 7)
# Convolutional layers
img = self.conv_layers(x)
return img
class SharedDiscriminatorQ(nn.Module):
"""
Shared Discriminator and Q-network.
The discriminator and Q-network share convolutional layers,
then split into separate heads:
- D head: Real/fake classification
- Q head: Latent code prediction
Args:
img_channels: Number of image channels
img_size: Input image size
num_categorical: Number of categorical classes
num_continuous: Number of continuous codes
"""
def __init__(
self,
img_channels: int = 1,
img_size: int = 28,
num_categorical: int = 10,
num_continuous: int = 2
):
super().__init__()
self.num_categorical = num_categorical
self.num_continuous = num_continuous
# Shared feature extractor
self.shared = nn.Sequential(
nn.Conv2d(img_channels, 64, 4, 2, 1), # 28x28 -> 14x14
nn.LeakyReLU(0.1),
nn.Conv2d(64, 128, 4, 2, 1), # 14x14 -> 7x7
nn.BatchNorm2d(128),
nn.LeakyReLU(0.1),
nn.Flatten(),
nn.Linear(128 * 7 * 7, 1024),
nn.BatchNorm1d(1024),
nn.LeakyReLU(0.1)
)
# Discriminator head (real/fake)
self.d_head = nn.Sequential(
nn.Linear(1024, 1),
nn.Sigmoid()
)
# Q head (predict latent codes)
# For categorical code: output logits
# For continuous codes: output mean (and optionally variance)
# Shared Q layer
self.q_shared = nn.Sequential(
nn.Linear(1024, 128),
nn.BatchNorm1d(128),
nn.LeakyReLU(0.1)
)
# Categorical code head (logits)
self.q_categorical = nn.Linear(128, num_categorical)
# Continuous code head (mean)
self.q_continuous_mean = nn.Linear(128, num_continuous)
# Continuous code head (log variance) - optional
# For simplicity, we'll use fixed variance
# self.q_continuous_logvar = nn.Linear(128, num_continuous)
def forward(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
img: Input images (batch_size, img_channels, img_size, img_size)
Returns:
d_out: Discriminator output (batch_size, 1) - probability of real
q_cat: Categorical code logits (batch_size, num_categorical)
q_cont: Continuous code means (batch_size, num_continuous)
"""
# Shared features
features = self.shared(img)
# Discriminator output
d_out = self.d_head(features)
# Q network features
q_features = self.q_shared(features)
# Categorical code (logits for cross-entropy)
q_cat = self.q_categorical(q_features)
# Continuous code (means)
q_cont = self.q_continuous_mean(q_features)
return d_out, q_cat, q_cont
# ============================================================================
# 2. LATENT CODE SAMPLING
# ============================================================================
def sample_latent_codes(
batch_size: int,
num_categorical: int,
num_continuous: int,
device: str = 'cpu'
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sample latent codes from prior distributions.
Args:
batch_size: Number of samples
num_categorical: Number of categorical classes
num_continuous: Number of continuous codes
device: Device to create tensors on
Returns:
c_cat_idx: Categorical code indices (batch_size,)
c_cat_onehot: Categorical code one-hot (batch_size, num_categorical)
c_cont: Continuous codes (batch_size, num_continuous)
"""
# Categorical: sample uniformly from K classes
c_cat_idx = torch.randint(0, num_categorical, (batch_size,), device=device)
c_cat_onehot = F.one_hot(c_cat_idx, num_classes=num_categorical).float()
# Continuous: sample from Uniform(-1, 1)
c_cont = torch.rand(batch_size, num_continuous, device=device) * 2 - 1
return c_cat_idx, c_cat_onehot, c_cont
def combine_codes(c_cat_onehot: torch.Tensor, c_cont: torch.Tensor) -> torch.Tensor:
"""
Combine categorical and continuous codes into single vector.
Args:
c_cat_onehot: Categorical code one-hot (batch_size, num_categorical)
c_cont: Continuous codes (batch_size, num_continuous)
Returns:
Combined codes (batch_size, num_categorical + num_continuous)
"""
return torch.cat([c_cat_onehot, c_cont], dim=1)
# ============================================================================
# 3. LOSS FUNCTIONS
# ============================================================================
def discriminator_loss(
d_real: torch.Tensor,
d_fake: torch.Tensor
) -> torch.Tensor:
"""
Standard GAN discriminator loss (binary cross-entropy).
Args:
d_real: Discriminator output on real images (batch_size, 1)
d_fake: Discriminator output on fake images (batch_size, 1)
Returns:
Discriminator loss
"""
real_labels = torch.ones_like(d_real)
fake_labels = torch.zeros_like(d_fake)
loss_real = F.binary_cross_entropy(d_real, real_labels)
loss_fake = F.binary_cross_entropy(d_fake, fake_labels)
return loss_real + loss_fake
def generator_loss(d_fake: torch.Tensor) -> torch.Tensor:
"""
Standard GAN generator loss.
Args:
d_fake: Discriminator output on fake images (batch_size, 1)
Returns:
Generator loss
"""
real_labels = torch.ones_like(d_fake)
return F.binary_cross_entropy(d_fake, real_labels)
def mutual_information_loss(
q_cat_logits: torch.Tensor,
q_cont_mean: torch.Tensor,
c_cat_idx: torch.Tensor,
c_cont: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Mutual information loss (lower bound approximation).
For categorical code: Cross-entropy
For continuous code: MSE (assuming Gaussian with unit variance)
Args:
q_cat_logits: Q network categorical output (batch_size, num_categorical)
q_cont_mean: Q network continuous output (batch_size, num_continuous)
c_cat_idx: True categorical code indices (batch_size,)
c_cont: True continuous codes (batch_size, num_continuous)
Returns:
mi_cat: Mutual information loss for categorical code
mi_cont: Mutual information loss for continuous code
"""
# Categorical: cross-entropy
# -E[log Q(c_cat | x)]
mi_cat = F.cross_entropy(q_cat_logits, c_cat_idx)
# Continuous: MSE (equivalent to Gaussian NLL with fixed variance)
    # -E[log Q(c_cont | x)] = ||c_cont - mu_Q(x)||^2 / 2 + const (unit variance)
mi_cont = F.mse_loss(q_cont_mean, c_cont)
return mi_cat, mi_cont
# ============================================================================
# 4. TRAINING
# ============================================================================
@dataclass
class InfoGANConfig:
"""Configuration for InfoGAN."""
latent_dim: int = 62
num_categorical: int = 10
num_continuous: int = 2
img_channels: int = 1
img_size: int = 28
# Training
lr_g: float = 2e-4
lr_d: float = 2e-4
beta1: float = 0.5
beta2: float = 0.999
# Loss weights
lambda_cat: float = 1.0 # Weight for categorical MI loss
lambda_cont: float = 0.1 # Weight for continuous MI loss
def train_infogan_step(
generator: nn.Module,
discriminator_q: nn.Module,
real_imgs: torch.Tensor,
optimizer_g: optim.Optimizer,
optimizer_d: optim.Optimizer,
config: InfoGANConfig,
device: str = 'cpu'
) -> Dict[str, float]:
"""
Single training step for InfoGAN.
Args:
generator: Generator network
discriminator_q: Shared discriminator and Q network
real_imgs: Batch of real images
        optimizer_g: Generator optimizer (should also cover the Q-head
            parameters so the mutual information loss updates Q)
optimizer_d: Discriminator optimizer
config: InfoGAN configuration
device: Device
Returns:
Dictionary of losses
"""
batch_size = real_imgs.size(0)
# ---------------------
# Train Discriminator
# ---------------------
optimizer_d.zero_grad()
# Real images
d_real, _, _ = discriminator_q(real_imgs)
# Fake images
z = torch.randn(batch_size, config.latent_dim, device=device)
c_cat_idx, c_cat_onehot, c_cont = sample_latent_codes(
batch_size, config.num_categorical, config.num_continuous, device
)
c = combine_codes(c_cat_onehot, c_cont)
fake_imgs = generator(z, c)
d_fake, _, _ = discriminator_q(fake_imgs.detach())
# Discriminator loss
d_loss = discriminator_loss(d_real, d_fake)
d_loss.backward()
optimizer_d.step()
# -------------------------
# Train Generator and Q
# -------------------------
optimizer_g.zero_grad()
# Generate fake images
z = torch.randn(batch_size, config.latent_dim, device=device)
c_cat_idx, c_cat_onehot, c_cont = sample_latent_codes(
batch_size, config.num_categorical, config.num_continuous, device
)
c = combine_codes(c_cat_onehot, c_cont)
fake_imgs = generator(z, c)
d_fake, q_cat_logits, q_cont_mean = discriminator_q(fake_imgs)
# Generator adversarial loss
g_loss = generator_loss(d_fake)
# Mutual information loss
mi_cat, mi_cont = mutual_information_loss(q_cat_logits, q_cont_mean, c_cat_idx, c_cont)
# Total generator loss
total_loss = g_loss + config.lambda_cat * mi_cat + config.lambda_cont * mi_cont
total_loss.backward()
optimizer_g.step()
return {
'd_loss': d_loss.item(),
'g_loss': g_loss.item(),
'mi_cat': mi_cat.item(),
'mi_cont': mi_cont.item(),
'total_g': total_loss.item()
}
# ============================================================================
# 5. VISUALIZATION
# ============================================================================
def visualize_categorical_code(
generator: nn.Module,
num_categorical: int,
num_continuous: int,
latent_dim: int,
device: str = 'cpu',
samples_per_class: int = 10
):
"""
Visualize effect of categorical code by generating samples for each class.
Args:
generator: Trained generator
num_categorical: Number of categorical classes
num_continuous: Number of continuous codes
latent_dim: Noise dimension
device: Device
samples_per_class: Samples per class
"""
generator.eval()
fig, axes = plt.subplots(num_categorical, samples_per_class,
figsize=(samples_per_class * 1.5, num_categorical * 1.5))
with torch.no_grad():
for class_idx in range(num_categorical):
# Fixed noise for each class
z = torch.randn(samples_per_class, latent_dim, device=device)
# Set categorical code to class_idx
c_cat_onehot = F.one_hot(
torch.full((samples_per_class,), class_idx, dtype=torch.long, device=device),
num_classes=num_categorical
).float()
# Random continuous codes
c_cont = torch.rand(samples_per_class, num_continuous, device=device) * 2 - 1
# Combine codes
c = combine_codes(c_cat_onehot, c_cont)
# Generate
imgs = generator(z, c).cpu()
# Plot
for i in range(samples_per_class):
ax = axes[class_idx, i] if num_categorical > 1 else axes[i]
img = imgs[i, 0] # Grayscale
ax.imshow(img, cmap='gray')
ax.axis('off')
if i == 0:
ax.set_title(f'Class {class_idx}', fontsize=9)
plt.tight_layout()
plt.suptitle('Categorical Code Visualization', y=1.01, fontsize=12, fontweight='bold')
plt.show()
def visualize_continuous_code(
generator: nn.Module,
num_categorical: int,
num_continuous: int,
latent_dim: int,
continuous_idx: int = 0,
device: str = 'cpu',
num_steps: int = 10,
num_rows: int = 5
):
"""
Visualize effect of continuous code by varying it while keeping others fixed.
Args:
generator: Trained generator
num_categorical: Number of categorical classes
num_continuous: Number of continuous codes
latent_dim: Noise dimension
continuous_idx: Which continuous code to vary
device: Device
num_steps: Number of steps in continuous variation
num_rows: Number of different samples (rows)
"""
generator.eval()
fig, axes = plt.subplots(num_rows, num_steps, figsize=(num_steps * 1.5, num_rows * 1.5))
# Range for continuous code
c_range = torch.linspace(-2, 2, num_steps, device=device)
with torch.no_grad():
for row in range(num_rows):
# Fixed noise and codes for this row
z = torch.randn(1, latent_dim, device=device).repeat(num_steps, 1)
# Random categorical code
cat_class = torch.randint(0, num_categorical, (1,), device=device).item()
c_cat_onehot = F.one_hot(
torch.full((num_steps,), cat_class, dtype=torch.long, device=device),
num_classes=num_categorical
).float()
# Fixed continuous codes
c_cont = torch.rand(1, num_continuous, device=device).repeat(num_steps, 1)
# Vary the selected continuous code
c_cont[:, continuous_idx] = c_range
# Combine
c = combine_codes(c_cat_onehot, c_cont)
# Generate
imgs = generator(z, c).cpu()
# Plot
for col in range(num_steps):
ax = axes[row, col] if num_rows > 1 else axes[col]
img = imgs[col, 0]
ax.imshow(img, cmap='gray')
ax.axis('off')
if row == 0:
ax.set_title(f'{c_range[col]:.1f}', fontsize=8)
plt.tight_layout()
plt.suptitle(f'Continuous Code {continuous_idx} Variation', y=1.01,
fontsize=12, fontweight='bold')
plt.show()
def visualize_2d_continuous_space(
generator: nn.Module,
num_categorical: int,
num_continuous: int,
latent_dim: int,
device: str = 'cpu',
grid_size: int = 10
):
"""
Visualize 2D continuous code space (if num_continuous >= 2).
Args:
generator: Trained generator
num_categorical: Number of categorical classes
num_continuous: Number of continuous codes
latent_dim: Noise dimension
device: Device
grid_size: Grid resolution
"""
if num_continuous < 2:
print("Need at least 2 continuous codes for 2D visualization")
return
generator.eval()
fig, axes = plt.subplots(grid_size, grid_size, figsize=(12, 12))
# Create grid
c1_range = torch.linspace(-2, 2, grid_size, device=device)
c2_range = torch.linspace(-2, 2, grid_size, device=device)
with torch.no_grad():
# Fixed noise and categorical code
z = torch.randn(1, latent_dim, device=device)
cat_class = torch.randint(0, num_categorical, (1,), device=device).item()
for i, c1 in enumerate(c1_range):
for j, c2 in enumerate(c2_range):
# Set continuous codes
c_cat_onehot = F.one_hot(
torch.tensor([cat_class], device=device),
num_classes=num_categorical
).float()
c_cont = torch.zeros(1, num_continuous, device=device)
c_cont[0, 0] = c1
c_cont[0, 1] = c2
c = combine_codes(c_cat_onehot, c_cont)
# Generate
img = generator(z, c).cpu()[0, 0]
# Plot
axes[i, j].imshow(img, cmap='gray')
axes[i, j].axis('off')
plt.tight_layout()
plt.suptitle('2D Continuous Code Space', y=1.01, fontsize=14, fontweight='bold')
plt.show()
# ============================================================================
# 6. DEMONSTRATION
# ============================================================================
def demo_infogan():
"""Demonstrate InfoGAN components."""
print("=" * 80)
print("InfoGAN Demonstration")
print("=" * 80)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"\nDevice: {device}")
# Configuration
config = InfoGANConfig()
print("\n1. Configuration")
print("-" * 40)
print(f"Latent dimension (noise): {config.latent_dim}")
print(f"Categorical codes: {config.num_categorical} classes")
print(f"Continuous codes: {config.num_continuous} dimensions")
print(f"Total code dimension: {config.num_categorical + config.num_continuous}")
# Initialize networks
print("\n2. Network Architecture")
print("-" * 40)
generator = InfoGANGenerator(
latent_dim=config.latent_dim,
code_dim=config.num_categorical + config.num_continuous
).to(device)
discriminator_q = SharedDiscriminatorQ(
num_categorical=config.num_categorical,
num_continuous=config.num_continuous
).to(device)
g_params = sum(p.numel() for p in generator.parameters())
dq_params = sum(p.numel() for p in discriminator_q.parameters())
print(f"Generator parameters: {g_params:,}")
print(f"Discriminator + Q parameters: {dq_params:,}")
print(f"Total parameters: {g_params + dq_params:,}")
# Test forward pass
print("\n3. Forward Pass")
print("-" * 40)
batch_size = 4
z = torch.randn(batch_size, config.latent_dim, device=device)
c_cat_idx, c_cat_onehot, c_cont = sample_latent_codes(
batch_size, config.num_categorical, config.num_continuous, device
)
c = combine_codes(c_cat_onehot, c_cont)
print(f"Noise z shape: {z.shape}")
print(f"Categorical code (one-hot) shape: {c_cat_onehot.shape}")
print(f"Continuous code shape: {c_cont.shape}")
print(f"Combined code shape: {c.shape}")
# Generate
fake_imgs = generator(z, c)
print(f"\nGenerated images shape: {fake_imgs.shape}")
# Discriminate and predict codes
d_out, q_cat, q_cont = discriminator_q(fake_imgs)
print(f"Discriminator output shape: {d_out.shape}")
print(f"Q categorical logits shape: {q_cat.shape}")
print(f"Q continuous mean shape: {q_cont.shape}")
# Compute losses
print("\n4. Loss Computation")
print("-" * 40)
real_imgs = torch.randn_like(fake_imgs)
d_real, _, _ = discriminator_q(real_imgs)
d_loss = discriminator_loss(d_real, d_out)
g_loss = generator_loss(d_out)
mi_cat, mi_cont = mutual_information_loss(q_cat, q_cont, c_cat_idx, c_cont)
print(f"Discriminator loss: {d_loss.item():.4f}")
print(f"Generator loss: {g_loss.item():.4f}")
print(f"MI loss (categorical): {mi_cat.item():.4f}")
print(f"MI loss (continuous): {mi_cont.item():.4f}")
# Code recovery accuracy
q_cat_pred = q_cat.argmax(dim=1)
accuracy = (q_cat_pred == c_cat_idx).float().mean()
print(f"\nCategorical code recovery accuracy: {accuracy.item():.2%}")
print("\n5. Mutual Information Interpretation")
print("-" * 40)
print("The mutual information loss encourages:")
print("β Generator to use latent codes c")
print("β Q network to accurately predict c from generated images")
print("β Disentangled representations (each code controls distinct factor)")
print("\nWithout MI loss:")
print("β Generator might ignore codes c")
print("β All variation would be in noise z")
print("β No interpretable structure")
def demo_code_disentanglement():
"""Demonstrate the concept of code disentanglement."""
print("\n" + "=" * 80)
print("Code Disentanglement Demonstration")
print("=" * 80)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = InfoGANConfig()
generator = InfoGANGenerator(
latent_dim=config.latent_dim,
code_dim=config.num_categorical + config.num_continuous
).to(device)
print("\n1. Categorical Code Traversal")
print("-" * 40)
print(f"Varying categorical code across {config.num_categorical} classes")
print("Expected: Different digit identities (if trained on MNIST)")
# Demonstrate categorical traversal
z = torch.randn(config.num_categorical, config.latent_dim, device=device)
c_cont_fixed = torch.zeros(config.num_categorical, config.num_continuous, device=device)
imgs_list = []
for cat_idx in range(config.num_categorical):
c_cat_onehot = F.one_hot(
torch.tensor([cat_idx], device=device),
num_classes=config.num_categorical
).float()
c = combine_codes(c_cat_onehot, c_cont_fixed[:1])
img = generator(z[:1], c)
imgs_list.append(img)
print(f"Generated {len(imgs_list)} images with different categorical codes")
print("\n2. Continuous Code Traversal")
print("-" * 40)
print(f"Varying continuous code 0 from -2 to +2")
print("Expected: Smooth variation in rotation/thickness (if trained)")
# Demonstrate continuous traversal
num_steps = 10
c_range = torch.linspace(-2, 2, num_steps, device=device)
z_fixed = torch.randn(1, config.latent_dim, device=device).repeat(num_steps, 1)
c_cat_fixed = F.one_hot(
torch.zeros(num_steps, dtype=torch.long, device=device),
num_classes=config.num_categorical
).float()
c_cont = torch.zeros(num_steps, config.num_continuous, device=device)
c_cont[:, 0] = c_range
c_combined = combine_codes(c_cat_fixed, c_cont)
imgs_cont = generator(z_fixed, c_combined)
print(f"Generated {num_steps} images with varying continuous code")
print("\n3. Independence of Codes")
print("-" * 40)
print("Key property: Changing one code should not affect others")
print("Example: Varying rotation should not change digit identity")
print("This is enforced by:")
print(" - Factorized prior: p(c) = p(c_cat) * p(c_cont)")
print(" - Mutual information maximization")
# Run demonstrations
if __name__ == "__main__":
print("Starting InfoGAN demonstrations...\n")
# Main demonstration
demo_infogan()
# Disentanglement concept
demo_code_disentanglement()
print("\n" + "=" * 80)
print("InfoGAN Implementation Complete!")
print("=" * 80)
print("\nKey Components Implemented:")
print("β Generator with latent code conditioning")
print("β Shared Discriminator + Q network")
print("β Categorical and continuous latent codes")
print("β Mutual information loss (variational lower bound)")
print("β Training utilities")
print("β Visualization functions")
print("\nTo train on real data (e.g., MNIST):")
print("1. Load dataset")
print("2. Create data loader")
print("3. Train for ~50 epochs")
print("4. Visualize learned disentangled codes")
print("\nExpected discoveries on MNIST:")
print("- Categorical code: Digit identity (0-9)")
print("- Continuous code 1: Rotation angle")
print("- Continuous code 2: Stroke thickness/width")
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import seaborn as sns
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
sns.set_style('whitegrid')
1. Motivation: Disentangled Representations
Problem with Standard GANs:
In vanilla GAN, the latent code \(z\) is unstructured:
No control over generated features
Entangled representations (changing \(z\) affects multiple attributes)
No interpretability
InfoGAN Solution:
Decompose the latent variables: $\(z = (z_c, z_n)\)$
where:
\(z_c\) = latent code (categorical or continuous) - should be interpretable
\(z_n\) = noise (unstructured, incompressible)
Goal: Learn disentangled \(z_c\) such that each dimension controls a specific factor of variation.
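As a concrete illustration of this decomposition, here is a minimal sampling sketch. The helper name `split_latent` and the dimensions (62 noise dims plus a 10-way categorical and 2 continuous codes) are illustrative, chosen to match the MNIST setup used later in this notebook:

```python
import torch
import torch.nn.functional as F

def split_latent(batch_size: int, noise_dim: int = 62,
                 n_categorical: int = 10, n_continuous: int = 2):
    """Sample the decomposed latent z = (z_c, z_n)."""
    z_n = torch.randn(batch_size, noise_dim)                        # incompressible noise
    cat = torch.randint(0, n_categorical, (batch_size,))            # discrete factor
    z_cat = F.one_hot(cat, num_classes=n_categorical).float()       # one-hot code
    z_cont = torch.empty(batch_size, n_continuous).uniform_(-1, 1)  # smooth factors
    z_c = torch.cat([z_cat, z_cont], dim=1)                         # structured code
    return z_c, z_n

z_c, z_n = split_latent(4)
print(z_c.shape, z_n.shape)  # torch.Size([4, 12]) torch.Size([4, 62])
```

The generator then receives the concatenation of both parts; only \(z_c\) is tied to the mutual-information objective.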
2. Information-Theoretic Objective
Mutual Information:
Measures the statistical dependence between the latent code \(c\) and the generated data \(G(z, c)\): $\(I(c; G(z, c)) = H(c) - H(c | G(z, c))\)$
where:
\(H(c)\) = entropy of the latent code (fixed by design)
\(H(c | G(z, c))\) = conditional entropy (low when \(c\) is recoverable from \(G(z, c)\))
Intuition: If we can recover \(c\) from the generated image, then \(c\) is meaningful.
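The identity \(I(c; x) = H(c) - H(c | x)\) can be checked numerically. The toy joint distribution below (a hypothetical binary code and binary observation, not from the text) verifies that the entropy form agrees with the KL form of mutual information:

```python
import numpy as np

# Hypothetical joint distribution p(c, x) over a binary code c and observation x
p_cx = np.array([[0.40, 0.10],   # c = 0
                 [0.05, 0.45]])  # c = 1
p_c = p_cx.sum(axis=1)   # marginal p(c)
p_x = p_cx.sum(axis=0)   # marginal p(x)

H_c = -np.sum(p_c * np.log(p_c))                  # entropy H(c)
H_c_given_x = -np.sum(p_cx * np.log(p_cx / p_x))  # conditional entropy H(c|x)
I_diff = H_c - H_c_given_x                        # I(c; x) = H(c) - H(c|x)

# Same quantity via the KL form: sum_{c,x} p(c,x) log[ p(c,x) / (p(c) p(x)) ]
I_kl = np.sum(p_cx * np.log(p_cx / np.outer(p_c, p_x)))

print(f"I(c;x) = {I_diff:.4f} (entropy form), {I_kl:.4f} (KL form)")
```

Here \(c\) and \(x\) are strongly correlated, so the mutual information is a large fraction of \(H(c)\); for an independent joint it would be zero.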
InfoGAN Objective:
Standard GAN minimax objective augmented with mutual information maximization: $\(\min_G \max_D V_I(D, G) = V(D, G) - \lambda I(c; G(z, c))\)$
Challenge:
Computing \(I(c; G(z, c))\) requires the posterior \(p(c | x)\), which is intractable.
Solution: Use a variational lower bound with an auxiliary distribution \(Q(c | x)\).
3. Variational Mutual Information Maximization
Lemma (Variational Lower Bound): $\(I(c; G(z, c)) \geq \mathbb{E}_{c \sim p(c),\, x \sim G(z, c)}[\log Q(c | x)] + H(c) = L_I(G, Q)\)$
where \(Q(c | x)\) is an auxiliary network that tries to predict \(c\) from \(x = G(z, c)\).
Proof Sketch:
Write \(I(c; x) = H(c) - H(c | x)\). For any distribution \(Q\), $\(\mathbb{E}_{c \sim p(c|x)}[\log p(c | x)] - \mathbb{E}_{c \sim p(c|x)}[\log Q(c | x)] = D_{KL}(p(c | x) \,\|\, Q(c | x)) \geq 0\)$
so replacing the intractable posterior \(p(c | x)\) with \(Q(c | x)\) inside the expectation can only decrease the value, yielding the lower bound. The bound is tight when \(Q(c | x) = p(c | x)\).
Practical Objective: $\(\min_{G, Q} \max_D V_{\text{InfoGAN}}(D, G, Q) = V(D, G) - \lambda L_I(G, Q)\)$
Implementation: Add an auxiliary network \(Q\) that shares parameters with the discriminator \(D\).
4. InfoGAN Architecture
InfoGAN extends the standard GAN by splitting the input to the generator into three components: the standard noise vector \(z\), a categorical code \(c_1\) that captures discrete variation (like digit identity in MNIST), and continuous codes \(c_2\) that capture smooth factors (like rotation angle or stroke width). The key architectural addition is a Q-network (often sharing layers with the discriminator) that predicts the latent codes from a generated image. By maximizing the mutual information \(I(c; G(z, c))\) between the codes and the generated output, InfoGAN forces the generator to use the codes in a semantically meaningful way, learning disentangled representations without any labels.
class Generator(nn.Module):
def __init__(self, latent_dim, code_dim, img_shape):
super().__init__()
self.img_shape = img_shape
input_dim = latent_dim + code_dim
self.fc = nn.Sequential(
nn.Linear(input_dim, 256),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(256),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(512),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh()
)
def forward(self, z, c):
"""Generate image from noise z and latent code c."""
x = torch.cat([z, c], dim=1)
x = self.fc(x)
return x.view(x.size(0), *self.img_shape)
class Discriminator(nn.Module):
def __init__(self, img_shape):
super().__init__()
self.shared = nn.Sequential(
nn.Linear(int(np.prod(img_shape)), 1024),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2)
)
# Discriminator head
self.adv_head = nn.Sequential(
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, img):
x = img.view(img.size(0), -1)
features = self.shared(x)
validity = self.adv_head(features)
return validity, features
class QNetwork(nn.Module):
"""Auxiliary network Q to predict latent code from generated images."""
def __init__(self, n_categorical, n_continuous):
super().__init__()
self.n_categorical = n_categorical
self.n_continuous = n_continuous
# Takes features from discriminator
if n_categorical > 0:
self.categorical = nn.Linear(256, n_categorical)
if n_continuous > 0:
self.continuous_mu = nn.Linear(256, n_continuous)
self.continuous_logvar = nn.Linear(256, n_continuous)
def forward(self, features):
"""Predict latent code distribution from discriminator features."""
outputs = {}
if self.n_categorical > 0:
# Return raw logits; F.cross_entropy in the MI loss applies log-softmax internally
outputs['categorical'] = self.categorical(features)
if self.n_continuous > 0:
outputs['continuous_mu'] = self.continuous_mu(features)
outputs['continuous_logvar'] = self.continuous_logvar(features)
return outputs
# Test architecture
latent_dim = 62
n_categorical = 10 # e.g., digit identity for MNIST
n_continuous = 2 # e.g., rotation, width
code_dim = n_categorical + n_continuous
img_shape = (1, 28, 28)
G = Generator(latent_dim, code_dim, img_shape).to(device)
D = Discriminator(img_shape).to(device)
Q = QNetwork(n_categorical, n_continuous).to(device)
print(f"Generator parameters: {sum(p.numel() for p in G.parameters()):,}")
print(f"Discriminator parameters: {sum(p.numel() for p in D.parameters()):,}")
print(f"Q Network parameters: {sum(p.numel() for p in Q.parameters()):,}")
5. Loss Functions
Discriminator Loss (Standard GAN): $\(L_D = -\mathbb{E}_{x \sim p_{data}}[\log D(x)] - \mathbb{E}_{z, c}[\log(1 - D(G(z, c)))]\)$
Generator Loss: $\(L_G = -\mathbb{E}_{z, c}[\log D(G(z, c))] - \lambda L_I(G, Q)\)$
Two components:
Fool the discriminator (standard GAN)
Maximize mutual information (InfoGAN)
Q Network Loss:
For categorical code \(c_{cat}\): $\(L_{cat} = -\mathbb{E} [\log Q(c_{cat} | G(z, c))] = \text{CrossEntropy}\)$
For continuous code \(c_{cont}\) (assumed Gaussian): $\(L_{cont} = \mathbb{E} \left[ \frac{(c_{cont} - \mu_{Q})^2}{2\sigma_Q^2} + \log \sigma_Q \right]\)$
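As a sanity check (not from the original text), the per-dimension Gaussian loss above can be compared against `torch.distributions.Normal`: the two agree up to the additive constant \(\frac{1}{2}\log 2\pi\), which is usually dropped since it does not affect gradients:

```python
import math
import torch
from torch.distributions import Normal

torch.manual_seed(0)
c_true = torch.randn(8, 2)          # "true" continuous codes
mu = torch.randn(8, 2)              # Q's predicted mean
logvar = torch.randn(8, 2) * 0.5    # Q's predicted log-variance

# NLL as used in the training code below (constant 0.5*log(2*pi) dropped)
nll_manual = 0.5 * (logvar + (c_true - mu) ** 2 / torch.exp(logvar))

# Reference: exact negative log-likelihood from torch.distributions
sigma = torch.exp(0.5 * logvar)
nll_exact = -Normal(mu, sigma).log_prob(c_true)

# The difference is the same additive constant everywhere
diff = nll_exact - nll_manual
print(diff.mean().item(), 0.5 * math.log(2 * math.pi))
```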
def sample_latent_code(batch_size, n_categorical, n_continuous):
"""Sample latent code: categorical (one-hot) + continuous (uniform)."""
code = []
if n_categorical > 0:
# Sample categorical code as one-hot
cat_code = np.random.randint(0, n_categorical, batch_size)
cat_code_onehot = np.zeros((batch_size, n_categorical))
cat_code_onehot[np.arange(batch_size), cat_code] = 1
code.append(torch.FloatTensor(cat_code_onehot))
if n_continuous > 0:
# Sample continuous code from uniform [-1, 1]
cont_code = torch.FloatTensor(batch_size, n_continuous).uniform_(-1, 1)
code.append(cont_code)
return torch.cat(code, dim=1)
def mutual_info_loss(Q_outputs, true_code, n_categorical, n_continuous):
"""Compute mutual information loss for Q network."""
loss = 0
if n_categorical > 0:
# Categorical: cross entropy
cat_target = torch.argmax(true_code[:, :n_categorical], dim=1)
loss += F.cross_entropy(Q_outputs['categorical'], cat_target)
if n_continuous > 0:
# Continuous: Gaussian negative log-likelihood
cont_target = true_code[:, n_categorical:]
mu = Q_outputs['continuous_mu']
logvar = Q_outputs['continuous_logvar']
# Gaussian NLL (dropping the constant 0.5*log(2*pi)):
# 0.5 * (logvar + (x - mu)^2 / exp(logvar))
nll = 0.5 * (logvar + (cont_target - mu)**2 / torch.exp(logvar))
loss += nll.mean()
return loss
# Test loss computation
batch_size = 64
z = torch.randn(batch_size, latent_dim).to(device)
c = sample_latent_code(batch_size, n_categorical, n_continuous).to(device)
with torch.no_grad():
fake_imgs = G(z, c)
_, features = D(fake_imgs)
Q_outputs = Q(features)
mi_loss = mutual_info_loss(Q_outputs, c, n_categorical, n_continuous)
print(f"Test MI loss: {mi_loss.item():.4f}")
print(f"Categorical prediction shape: {Q_outputs['categorical'].shape}")
print(f"Continuous mu shape: {Q_outputs['continuous_mu'].shape}")
6. Training InfoGAN on MNIST
Training proceeds with three simultaneous objectives: the standard adversarial loss for the generator and discriminator, plus an information-theoretic regularizer that maximizes a lower bound on the mutual information \(I(c; G(z, c))\). In practice the Q-network outputs parameters of the code distributions (a softmax over classes for categorical codes, and a mean and variance for continuous codes), and the mutual information is estimated via the variational bound. The hyperparameter \(\lambda\) controls the weight of the information loss relative to the adversarial loss. Watching the categorical accuracy and continuous code reconstruction during training confirms whether the network is learning meaningful disentanglement.
# Load MNIST
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]) # [-1, 1]
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=True)
# Optimizers
lr = 0.0002
optimizer_G = optim.Adam(list(G.parameters()) + list(Q.parameters()), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
# Training loop
n_epochs = 15
lambda_mi = 1.0 # Weight for mutual information loss
G.train()
D.train()
Q.train()
history = {'D_loss': [], 'G_loss': [], 'MI_loss': []}
for epoch in range(n_epochs):
epoch_D_loss = 0
epoch_G_loss = 0
epoch_MI_loss = 0
for batch_idx, (real_imgs, _) in enumerate(train_loader):
batch_size = real_imgs.size(0)
real_imgs = real_imgs.to(device)
# Labels
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
# =====================
# Train Discriminator
# =====================
optimizer_D.zero_grad()
# Real images
real_validity, _ = D(real_imgs)
d_real_loss = F.binary_cross_entropy(real_validity, real_labels)
# Fake images
z = torch.randn(batch_size, latent_dim).to(device)
c = sample_latent_code(batch_size, n_categorical, n_continuous).to(device)
fake_imgs = G(z, c)
fake_validity, _ = D(fake_imgs.detach())
d_fake_loss = F.binary_cross_entropy(fake_validity, fake_labels)
d_loss = d_real_loss + d_fake_loss
d_loss.backward()
optimizer_D.step()
# =====================
# Train Generator + Q
# =====================
optimizer_G.zero_grad()
# Generate new samples
z = torch.randn(batch_size, latent_dim).to(device)
c = sample_latent_code(batch_size, n_categorical, n_continuous).to(device)
fake_imgs = G(z, c)
# Adversarial loss
fake_validity, features = D(fake_imgs)
g_loss = F.binary_cross_entropy(fake_validity, real_labels)
# Mutual information loss
Q_outputs = Q(features)
mi_loss = mutual_info_loss(Q_outputs, c, n_categorical, n_continuous)
# Total generator loss
total_g_loss = g_loss + lambda_mi * mi_loss
total_g_loss.backward()
optimizer_G.step()
# Track losses
epoch_D_loss += d_loss.item()
epoch_G_loss += g_loss.item()
epoch_MI_loss += mi_loss.item()
# Average losses
epoch_D_loss /= len(train_loader)
epoch_G_loss /= len(train_loader)
epoch_MI_loss /= len(train_loader)
history['D_loss'].append(epoch_D_loss)
history['G_loss'].append(epoch_G_loss)
history['MI_loss'].append(epoch_MI_loss)
print(f"Epoch [{epoch+1}/{n_epochs}] D_loss: {epoch_D_loss:.4f} | "
f"G_loss: {epoch_G_loss:.4f} | MI_loss: {epoch_MI_loss:.4f}")
print("\nTraining complete!")
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 4))
ax1 = axes[0]
ax1.plot(history['D_loss'], label='Discriminator', linewidth=2)
ax1.plot(history['G_loss'], label='Generator', linewidth=2)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('Adversarial Losses', fontsize=13)
ax1.legend()
ax1.grid(True, alpha=0.3)
ax2 = axes[1]
ax2.plot(history['MI_loss'], color='green', linewidth=2)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('MI Loss', fontsize=12)
ax2.set_title('Mutual Information Loss', fontsize=13)
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
7. Analyzing Learned Latent Codes
Categorical Code Variation
Fix the noise \(z\) and continuous code, then vary the categorical code to see if it captures digit identity.
G.eval()
# Fix noise and continuous code
z_fixed = torch.randn(1, latent_dim).to(device)
c_cont_fixed = torch.zeros(1, n_continuous).to(device)
# Vary categorical code
fig, axes = plt.subplots(1, n_categorical, figsize=(15, 2))
with torch.no_grad():
for cat_idx in range(n_categorical):
# Create one-hot categorical code
c_cat = torch.zeros(1, n_categorical).to(device)
c_cat[0, cat_idx] = 1
# Combine codes
c = torch.cat([c_cat, c_cont_fixed], dim=1)
# Generate
img = G(z_fixed, c)
img = img.cpu().numpy().squeeze()
axes[cat_idx].imshow(img, cmap='gray')
axes[cat_idx].set_title(f'Cat={cat_idx}', fontsize=10)
axes[cat_idx].axis('off')
plt.suptitle('Categorical Code Variation (Fixed Noise)', fontsize=14)
plt.tight_layout()
plt.show()
print("If trained well, each column should show a different digit class.")
Continuous Code Variation
To verify that the continuous latent codes have learned interpretable factors of variation, we fix the noise vector \(z\) and the categorical code, then sweep each continuous code across a range (typically \([-2, 2]\)). If training was successful, each continuous code should control a single, smooth visual attribute; for example, one code might govern digit width while another controls slant. This kind of traversal visualization is the standard evaluation protocol for disentangled representations and directly demonstrates why InfoGAN is valuable: it discovers meaningful structure in the data without any supervision.
# Fix noise and categorical code (digit 3 for example)
z_fixed = torch.randn(1, latent_dim).to(device)
c_cat_fixed = torch.zeros(1, n_categorical).to(device)
c_cat_fixed[0, 3] = 1 # Choose digit 3
# Vary first continuous code
n_samples = 10
c1_values = np.linspace(-2, 2, n_samples)
fig, axes = plt.subplots(2, n_samples, figsize=(15, 4))
with torch.no_grad():
# Row 1: Vary c1, fix c2=0
for i, c1 in enumerate(c1_values):
c_cont = torch.tensor([[c1, 0]], dtype=torch.float32).to(device)
c = torch.cat([c_cat_fixed, c_cont], dim=1)
img = G(z_fixed, c).cpu().numpy().squeeze()
axes[0, i].imshow(img, cmap='gray')
axes[0, i].set_title(f'c1={c1:.1f}', fontsize=9)
axes[0, i].axis('off')
# Row 2: Vary c2, fix c1=0
for i, c2 in enumerate(c1_values):
c_cont = torch.tensor([[0, c2]], dtype=torch.float32).to(device)
c = torch.cat([c_cat_fixed, c_cont], dim=1)
img = G(z_fixed, c).cpu().numpy().squeeze()
axes[1, i].imshow(img, cmap='gray')
axes[1, i].set_title(f'c2={c2:.1f}', fontsize=9)
axes[1, i].axis('off')
axes[0, 0].set_ylabel('Vary c1', fontsize=11)
axes[1, 0].set_ylabel('Vary c2', fontsize=11)
plt.suptitle('Continuous Code Variation (e.g., rotation, width)', fontsize=14)
plt.tight_layout()
plt.show()
print("Continuous codes should capture smooth transformations (e.g., rotation, thickness).")
Summary
Key Contributions of InfoGAN:
Unsupervised disentanglement - learns interpretable factors without labels
Mutual information maximization - ensures latent codes are meaningful
Variational bound - tractable optimization via auxiliary network \(Q\)
Minimal overhead - only adds small network \(Q\), shares features with \(D\)
Applications:
Controllable generation - manipulate specific attributes
Representation learning - discover data structure automatically
Data augmentation - generate variations along learned factors
Interpretability - understand what GAN has learned
Limitations:
Requires choosing number of codes \(|c|\) in advance
No guarantee which code learns which factor
Training can be unstable (GAN issues + MI objective)
Extensions:
β-VAE + InfoGAN - combine strengths
Multi-level InfoGAN - hierarchical disentanglement
InfoGAN for video - temporal disentanglement
Next Steps:
06_conditional_gan.ipynb - Class-conditional generation
07_bayesian_gan.ipynb - Uncertainty in generative models
Return to 03_variational_autoencoders_advanced.ipynb for disentanglement comparison