Neural Radiance Fields (NeRF): Comprehensive Theory
Introduction
Neural Radiance Fields (NeRF) revolutionized 3D scene representation and novel view synthesis by representing scenes as continuous volumetric radiance fields encoded by deep neural networks. Instead of explicit 3D representations (meshes, point clouds, voxels), NeRF learns an implicit function that maps 5D coordinates (3D location + 2D viewing direction) to volume density and view-dependent emitted radiance.
Key Innovation: Representing complex 3D scenes as weights of a neural network, achieving photorealistic novel view synthesis from sparse input views.
Mathematical Foundation
Scene Representation
NeRF represents a static scene as a continuous function:
$\(F_\Theta: (\mathbf{x}, \mathbf{d}) \rightarrow (\mathbf{c}, \sigma)\)$
Where:
\(\mathbf{x} = (x, y, z) \in \mathbb{R}^3\) is a 3D location
\(\mathbf{d} = (\theta, \phi) \in \mathbb{S}^2\) is a viewing direction (unit vector)
\(\mathbf{c} = (r, g, b) \in [0,1]^3\) is emitted RGB color
\(\sigma \in \mathbb{R}^+ \) is volume density (opacity)
\(\Theta\) represents network parameters
Key Property: Volume density \(\sigma(\mathbf{x})\) is view-independent (geometric property), while color \(\mathbf{c}(\mathbf{x}, \mathbf{d})\) is view-dependent (captures specular reflections, viewing angle effects).
Volume Rendering
To render a pixel, NeRF performs volumetric rendering along camera rays. For a camera ray \(\mathbf{r}(t) = \mathbf{o} + t\mathbf{d}\) (origin \(\mathbf{o}\), direction \(\mathbf{d}\), parameter \(t\)), the expected color is:
$\(\hat{C}(\mathbf{r}) = \int_{t_n}^{t_f} T(t)\, \sigma(\mathbf{r}(t))\, \mathbf{c}(\mathbf{r}(t), \mathbf{d})\, dt\)$
Where transmittance \(T(t)\) is the accumulated transparency from \(t_n\) to \(t\):
$\(T(t) = \exp\left(-\int_{t_n}^{t} \sigma(\mathbf{r}(s))\, ds\right)\)$
Interpretation:
\(T(t)\) is the probability that the ray travels from \(t_n\) to \(t\) without hitting any particle
\(\sigma(\mathbf{r}(t)) \cdot dt\) is the probability of ray termination in infinitesimal segment
The integral accumulates color contributions weighted by opacity and visibility
Discrete Approximation
In practice, we discretize the continuous integral using quadrature:
$\(\hat{C}(\mathbf{r}) = \sum_{i=1}^{N} T_i\, \alpha_i\, \mathbf{c}_i\)$
Where:
\(\alpha_i = 1 - \exp(-\sigma_i \delta_i)\) is alpha compositing opacity
\(\delta_i = t_{i+1} - t_i\) is the distance between adjacent samples
\(T_i = \exp\left(-\sum_{j=1}^{i-1} \sigma_j \delta_j\right) = \prod_{j=1}^{i-1}(1 - \alpha_j)\) is discrete transmittance
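The discrete weights above can be checked with a tiny hand computation (the density and distance values below are illustrative, not from any real scene):

```python
import math

# Three samples along a toy ray: the first lies in empty space,
# the next two inside an object of increasing density.
sigmas = [0.0, 2.0, 4.0]   # volume densities sigma_i
deltas = [0.5, 0.5, 0.5]   # inter-sample distances delta_i

alphas = [1.0 - math.exp(-s * d) for s, d in zip(sigmas, deltas)]

weights, T = [], 1.0       # T starts at full transmittance
for a in alphas:
    weights.append(T * a)  # w_i = T_i * alpha_i
    T *= (1.0 - a)         # T_{i+1} = T_i * (1 - alpha_i)

print(weights[0])      # 0.0 - empty space contributes nothing
print(sum(weights))    # accumulated opacity, equals 1 - exp(-sum sigma*delta)
```

Note the telescoping identity at the end: the total weight equals one minus the transmittance left after the last sample.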
Sampling Strategy: Sample points along the ray using stratified sampling to ensure continuous coverage:
$\(t_i \sim \mathcal{U}\left[t_n + \frac{i-1}{N}(t_f - t_n),\; t_n + \frac{i}{N}(t_f - t_n)\right]\)$
Network Architecture
Two-Stage MLP
NeRF uses an 8-layer MLP with a skip connection:
Density Branch (layers 1-8):
Input: 3D position \(\mathbf{x}\) (after positional encoding)
Hidden: 256 units per layer, ReLU activations
Skip connection: Concatenate input at layer 5
Output: 256D feature vector + density \(\sigma\)
Color Branch (layer 9):
Input: 256D features + viewing direction \(\mathbf{d}\) (after positional encoding)
Hidden: 128 units, ReLU
Output: RGB color \(\mathbf{c}\)
Architecture Insight: Separating density and color allows the network to learn view-independent geometry while modeling view-dependent appearance.
Positional Encoding
Critical innovation: Map input coordinates to a higher-dimensional space using sinusoidal functions:
$\(\gamma(p) = \left(\sin(2^0 \pi p), \cos(2^0 \pi p), \ldots, \sin(2^{L-1} \pi p), \cos(2^{L-1} \pi p)\right)\)$
For 3D position: \(L = 10\) (60D output)
For 2D direction: \(L = 4\) (24D output)
Why It Works:
Neural networks have spectral bias toward low-frequency functions
Positional encoding projects inputs to high-frequency space
Enables learning of high-frequency variations in color and geometry
Related to Fourier feature mapping and coordinate-based MLPs
Mathematical Intuition: This is similar to a Fourier basis expansion, allowing the MLP to learn arbitrary functions via the universal approximation property in this expanded feature space.
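A minimal sketch of the encoding, applied per scalar coordinate (the function name `gamma` is just for illustration):

```python
import math

def gamma(p, L):
    """Sinusoidal positional encoding of a single scalar coordinate p."""
    out = []
    for l in range(L):
        out.append(math.sin(2.0 ** l * math.pi * p))
        out.append(math.cos(2.0 ** l * math.pi * p))
    return out

# Applied per coordinate: a 3D position with L = 10 yields 3 * 2 * 10 = 60 values.
encoded = [v for coord in (0.25, -0.5, 0.1) for v in gamma(coord, L=10)]
print(len(encoded))  # 60
```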
Hierarchical Volume Sampling
NeRF uses a coarse-to-fine sampling strategy to allocate network capacity efficiently:
Coarse Network (\(F_c\))
Sample \(N_c\) points uniformly (stratified) along each ray
Query coarse network \(F_c\) at these positions
Compute coarse rendering \(\hat{C}_c(\mathbf{r})\)
Fine Network (\(F_f\))
Use the coarse network's density predictions to compute importance sampling weights:
$\(w_i = T_i\, \alpha_i, \qquad \hat{w}_i = \frac{w_i}{\sum_{j=1}^{N_c} w_j}\)$
Sample additional \(N_f\) points from piecewise-constant PDF based on \(w_i\)
Combine coarse and fine samples (total \(N_c + N_f\) points)
Query fine network \(F_f\) and compute fine rendering \(\hat{C}_f(\mathbf{r})\)
Benefit: Focus samples in regions with high expected contribution to final rendering, improving efficiency and quality.
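The importance-sampling step amounts to inverse-CDF sampling from a piecewise-constant PDF. A minimal pure-Python sketch, with illustrative bin values:

```python
import random

def sample_pdf_1d(bins, weights, n_samples):
    """Draw samples from a piecewise-constant PDF by inverting its CDF."""
    total = sum(weights)
    pdf = [w / total for w in weights]
    cdf = [0.0]
    for p in pdf:
        cdf.append(cdf[-1] + p)
    samples = []
    for _ in range(n_samples):
        u = random.random()
        # Locate the bin whose CDF interval contains u (fallback: last bin)
        i = next((k for k in range(len(pdf)) if cdf[k + 1] >= u), len(pdf) - 1)
        denom = (cdf[i + 1] - cdf[i]) or 1.0
        t = (u - cdf[i]) / denom
        # Linear interpolation inside the chosen bin
        samples.append(bins[i] + t * (bins[i + 1] - bins[i]))
    return samples

random.seed(0)
bins = [2.0, 3.0, 4.0, 5.0]    # sample positions t_i from the coarse pass
weights = [0.1, 0.8, 0.1]      # coarse weights w_i: most mass in [3, 4]
pts = sample_pdf_1d(bins, weights, 1000)
print(sum(3.0 <= p <= 4.0 for p in pts))  # roughly 800 of the 1000 samples
```

The `sample_pdf` function in the implementation below is the batched PyTorch version of this idea.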
Loss Function
$\(\mathcal{L} = \sum_{\mathbf{r} \in \mathcal{R}} \left[\left\|\hat{C}_c(\mathbf{r}) - C(\mathbf{r})\right\|_2^2 + \left\|\hat{C}_f(\mathbf{r}) - C(\mathbf{r})\right\|_2^2\right]\)$
Where \(\mathcal{R}\) is a batch of camera rays, and \(C(\mathbf{r})\) is the ground truth pixel color.
Training Strategies
Ray Batching
Sample random batch of camera rays from training images
Typical: 4096 rays per batch
Each ray requires \(N_c + N_f\) network queries (e.g., 64 + 128 = 192)
Total: ~786,432 network evaluations per batch
Optimization
Optimizer: Adam
Learning rate: \(5 \times 10^{-4}\) with exponential decay
Training: 100,000-300,000 iterations
Hardware: Single NVIDIA V100 GPU (~1-2 days per scene)
Regularization
Weight Decay: L2 regularization on network weights prevents overfitting on sparse views.
Ray Jittering: Add small random perturbations to ray directions for augmentation.
Advanced NeRF Variants
1. Mip-NeRF (Multiscale Representation)
Problem: Original NeRF samples single points along rays, causing aliasing and ambiguity at different scales.
Solution: Instead of querying discrete points, query 3D conical frustums that account for pixel footprint.
Integrated Positional Encoding (IPE): $\(\mathbb{E}_{\mathbf{x} \sim p(\mathbf{x})}[\gamma(\mathbf{x})] \approx \gamma(\boldsymbol{\mu}, \boldsymbol{\Sigma})\)$
Where \(\boldsymbol{\mu}, \boldsymbol{\Sigma}\) are mean and covariance of Gaussian approximation to conical frustum.
Benefit: Anti-aliasing, improved rendering at different resolutions, faster convergence.
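A sketch of the attenuation effect for a single Gaussian coordinate, using the standard identity \(\mathbb{E}[\sin(a x)] = \sin(a\mu)\, e^{-a^2 \sigma^2 / 2}\) for \(x \sim \mathcal{N}(\mu, \sigma^2)\) (scalar toy version, not the full frustum computation):

```python
import math

def ipe_scalar(mu, var, L):
    """Expected sinusoidal features of x ~ N(mu, var): each frequency 2^l * pi
    is attenuated by exp(-0.5 * (2^l * pi)^2 * var), so frequencies finer than
    the Gaussian footprint are smoothly suppressed."""
    feats = []
    for l in range(L):
        a = 2.0 ** l * math.pi
        damp = math.exp(-0.5 * a * a * var)
        feats.append((math.sin(a * mu) * damp, math.cos(a * mu) * damp))
    return feats

feats = ipe_scalar(mu=0.3, var=0.01, L=6)
amps = [math.hypot(s, c) for s, c in feats]
print(amps)  # feature amplitude shrinks monotonically with frequency
```

This is exactly the anti-aliasing mechanism: a large pixel footprint (large variance) zeroes out the high-frequency bands that plain positional encoding would pass through unchanged.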
2. Instant NGP (Neural Graphics Primitives)
Problem: NeRF training is slow (~days per scene).
Solution: Multi-resolution hash encoding with small MLP.
Hash Encoding: $\(\mathbf{h} = \bigoplus_{l=1}^{L} \text{Interp}(\mathbf{x}; H_l)\)$
Where \(H_l\) is a hash table at resolution level \(l\), and Interp performs trilinear interpolation.
Performance: 1000× faster training (~5 seconds per scene), interactive rendering.
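A sketch of the per-level spatial hash (the XOR-of-primes form from the Instant NGP paper; the table size `T` here is an arbitrary illustrative choice):

```python
# Spatial hash for one grid level. The three primes are the ones used by
# Instant NGP; the table size T is an assumption for this sketch.
PRIMES = (1, 2654435761, 805459861)
T = 2 ** 14  # hash table entries per level

def hash_voxel(ix, iy, iz):
    """Map integer voxel coordinates to a hash-table slot."""
    h = ix * PRIMES[0] ^ iy * PRIMES[1] ^ iz * PRIMES[2]
    return h % T

# Hash an 8x8x8 block of voxel corners; collisions are allowed by design
# and resolved implicitly by the downstream MLP.
slots = {hash_voxel(x, y, z) for x in range(8) for y in range(8) for z in range(8)}
print(len(slots))
```

At fine resolutions the grid has far more voxels than table entries, so the hash is intentionally lossy; the small MLP learns to disambiguate collisions.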
3. NeRF in the Wild (NeRF-W)
Problem: Original NeRF assumes static scenes with consistent lighting.
Solution: Model appearance variations and transient objects.
Extended Representation: $\(F_\Theta: (\mathbf{x}, \mathbf{d}, \mathbf{l}_i, \boldsymbol{\tau}_i) \rightarrow (\mathbf{c}, \sigma, \beta)\)$
\(\mathbf{l}_i\): Per-image appearance embedding
\(\boldsymbol{\tau}_i\): Per-image transient embedding
\(\beta\): Transient density (for moving objects, lighting changes)
4. Plenoxels
Problem: MLP inference is slow for real-time rendering.
Solution: Replace MLP with sparse voxel grid (spherical harmonics for view-dependence).
Rendering: Direct interpolation from voxel grid (no network inference).
Trade-off: Memory vs. speed (faster rendering but larger storage).
5. TensoRF
Problem: Voxel grids scale poorly to high resolutions (memory \(O(N^3)\)).
Solution: Tensor decomposition of radiance field.
CP Decomposition: $\(\mathcal{F}(x,y,z) = \sum_{r=1}^{R} \mathbf{v}^X_r[x] \cdot \mathbf{v}^Y_r[y] \cdot \mathbf{v}^Z_r[z]\)$
VM Decomposition (Vector-Matrix): $\(\mathcal{F}(x,y,z) = \sum_{r=1}^{R} \mathbf{v}^Z_r[z] \cdot \mathbf{M}^{XY}_r[x,y]\)$
Benefit: Compact representation, fast rendering, improved quality.
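A rank-1 sanity check of the CP form above (illustrative numbers): a field that factors along the axes is reconstructed exactly from its per-axis vectors.

```python
# A 3D field that factors as f(x, y, z) = u[x] * v[y] * w[z]
# is represented exactly by a single rank-1 CP component.
u = [1.0, 2.0, 3.0]
v = [0.5, 1.5]
w = [2.0, 4.0]

def cp_value(x, y, z, factors):
    """Evaluate sum_r u_r[x] * v_r[y] * w_r[z]."""
    return sum(ur[x] * vr[y] * wr[z] for ur, vr, wr in factors)

factors = [(u, v, w)]  # R = 1
dense = [[[u[x] * v[y] * w[z] for z in range(2)] for y in range(2)] for x in range(3)]
ok = all(
    abs(cp_value(x, y, z, factors) - dense[x][y][z]) < 1e-12
    for x in range(3) for y in range(2) for z in range(2)
)
print(ok)  # True: O(3 * N * R) vector entries instead of an O(N^3) grid
```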
6. Nerfies / HyperNeRF
Problem: NeRF assumes rigid static scenes.
Solution: Model non-rigid deformations for dynamic/deformable scenes.
Deformation Field: $\(\mathbf{x}' = \mathbf{x} + \Delta\mathbf{x}(\mathbf{x}, t)\)$
Query canonical NeRF at deformed position \(\mathbf{x}'\).
Applications: Face reenactment, human performance capture, dynamic object modeling.
7. NeRF++ (Unbounded Scenes)
Problem: Original NeRF assumes bounded scenes.
Solution: Separate foreground and background modeling with an inverted sphere parameterization.
Background Representation: a point \(\mathbf{x}\) with \(r = \|\mathbf{x}\| > 1\) is re-parameterized as $\(\left(x/r,\; y/r,\; z/r,\; 1/r\right)\)$
Map infinite rays to a bounded space: the direction lives on the unit sphere and the inverse radius \(1/r \in (0, 1)\) encodes depth.
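A minimal sketch of the inverted-sphere mapping (function name is illustrative):

```python
import math

def invert_sphere(x, y, z):
    """Re-parameterize a background point (||x|| > 1) as a unit direction
    plus an inverse-radius coordinate in (0, 1)."""
    r = math.sqrt(x * x + y * y + z * z)
    return (x / r, y / r, z / r, 1.0 / r)

for p in [(2.0, 0.0, 0.0), (100.0, 100.0, 100.0), (1e6, 0.0, 0.0)]:
    xs, ys, zs, inv_r = invert_sphere(*p)
    # Direction lies on the unit sphere; depth is squashed into (0, 1)
    print(round(xs * xs + ys * ys + zs * zs, 6), 0.0 < inv_r < 1.0)
```

Points arbitrarily far away map to \(1/r \to 0\), so the background network only ever sees bounded inputs.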
Dynamic NeRF Extensions
D-NeRF (Dynamic NeRF)
Representation: $\(F_\Theta: (\mathbf{x}, t) \rightarrow (\Delta\mathbf{x}, \mathbf{c}, \sigma)\)$
Canonical space: Static NeRF
Deformation network: Maps \((\mathbf{x}, t)\) to canonical coordinates
Separate networks for deformation and canonical radiance
Neural Scene Flow Fields
Represent 4D space-time as continuous flow field: $\(F: (\mathbf{x}, t_1, t_2) \rightarrow \Delta\mathbf{x}_{t_1 \rightarrow t_2}\)$
Track 3D points across time for dynamic reconstruction.
Generalization and Few-Shot NeRF
pixelNeRF
Problem: NeRF requires many views per scene (50-100 images).
Solution: Condition NeRF on input image features (meta-learning across scenes).
Architecture: $\(F_\Theta: (\mathbf{x}, \mathbf{d}, \mathbf{W}(\mathbf{x})) \rightarrow (\mathbf{c}, \sigma)\)$
Where \(\mathbf{W}(\mathbf{x})\) is image feature extracted from CNN encoder at projected position.
Benefit: Single or few-view novel view synthesis.
IBRNet / MVSNeRF
Combine multi-view stereo with NeRF for few-shot generalization:
Extract features from source views
Aggregate features along epipolar lines
Condition NeRF decoder on aggregated features
Semantic and Controllable NeRF
Semantic-NeRF
Extended Output: $\(F_\Theta: (\mathbf{x}, \mathbf{d}) \rightarrow (\mathbf{c}, \sigma, \mathbf{s})\)$
Where \(\mathbf{s}\) is semantic logits for object classes.
Rendering: Composite semantics using same volume rendering equation.
Editable NeRF
Object Decomposition: Represent scene as composition of object NeRFs.
Control: Manipulate object poses, appearances, and properties independently.
Text-to-3D and Diffusion NeRF
DreamFusion
Problem: No 3D training data, only 2D text-to-image models.
Solution: Score Distillation Sampling (SDS) loss.
SDS Loss: $\(\mathcal{L}_\text{SDS} = \mathbb{E}_{t,\epsilon}\left[w(t) \left\|\epsilon_\phi(\mathbf{z}_t; t, y) - \epsilon\right\|_2^2\right]\)$
Where:
\(\epsilon_\phi\): Pre-trained diffusion model (e.g., Stable Diffusion)
\(\mathbf{z}_t\): Noised rendering from NeRF
\(y\): Text prompt
\(\epsilon\): Random noise
Interpretation: Optimize NeRF to produce renderings that the diffusion model considers high-probability samples for the text prompt.
Magic3D
Improvements over DreamFusion:
Coarse-to-fine optimization (low-res → high-res)
Explicit mesh extraction
2× faster, higher quality
NeRF for Relighting
NeRF-W (Lighting Variation)
Problem: Model scenes with varying illumination.
Solution: Per-image illumination embedding.
NeRD (Neural Reflectance Decomposition)
Physics-Based Decomposition: $\(\mathbf{c} = \mathbf{k}_d \cdot L_d + \mathbf{k}_s \cdot L_s\)$
\(\mathbf{k}_d\): Diffuse albedo
\(\mathbf{k}_s\): Specular albedo
\(L_d, L_s\): Diffuse and specular lighting
Benefit: Relight scenes under novel illumination conditions.
Acceleration Techniques
1. Baking / Distillation
Distill NeRF into faster representations:
SNeRG: Distill into sparse voxel grid with view-dependent features
MobileNeRF: Distill into textured polygonal mesh
Trade-off: Quality vs. speed
2. Early Ray Termination
Stop marching along ray when accumulated opacity exceeds threshold: $\(T_i < \epsilon \implies \text{terminate}\)$
Benefit: Skip sampling in empty or fully opaque regions.
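A sketch of the termination test inside a ray-marching loop (densities below are illustrative):

```python
import math

def march_with_termination(sigmas, deltas, eps=1e-3):
    """Accumulate transmittance along a ray; stop once T drops below eps,
    since later samples can no longer contribute visibly."""
    T, evaluated = 1.0, 0
    for sigma, delta in zip(sigmas, deltas):
        evaluated += 1
        T *= math.exp(-sigma * delta)
        if T < eps:
            break
    return evaluated, T

# A dense surface near the front of the ray makes the rest of the samples moot.
sigmas = [0.0, 100.0, 100.0] + [0.5] * 125
deltas = [0.05] * 128
n, T = march_with_termination(sigmas, deltas)
print(n)  # 3 - only 3 of the 128 samples were evaluated
```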
3. Occupancy Grids
Maintain a coarse 3D grid indicating occupied regions:
Skip sampling in empty voxels
Update grid periodically during training
Speed-up: 2-10× faster training
4. Neural Sparse Voxel Fields (NSVF)
Structure: Octree of voxels with learned features.
Rendering: Only evaluate MLP in occupied voxels.
Benefit: 10Γ faster than NeRF with comparable quality.
NeRF for Novel Applications
1. Medical Imaging
CT/MRI Reconstruction: Reconstruct 3D volumes from sparse 2D slices
Surgical Planning: Novel viewpoint generation for pre-operative visualization
2. Robotics
Scene Understanding: Implicit 3D maps for navigation
Manipulation: Object pose estimation and grasp planning
3. AR/VR
Telepresence: Real-time novel view synthesis for immersive communication
Content Creation: Simplified 3D asset generation from photos
4. Autonomous Driving
Simulation: Generate realistic training data with controllable viewpoints
Mapping: Compact scene representations for localization
Theoretical Analysis
Optimization Landscape
Theorem (Universal Approximation): A sufficiently large MLP with positional encoding can approximate any continuous radiance field to arbitrary precision.
Proof Sketch:
Positional encoding spans high-frequency Fourier basis
Universal approximation theorem for neural networks
Composition yields arbitrary precision
Challenges:
Non-convex optimization landscape
Local minima (floaters, artifacts)
Requires good initialization and careful hyperparameter tuning
Sample Complexity
Question: How many views are needed for accurate reconstruction?
Analysis:
Bounded scenes: \(O(N^2)\) views for Nyquist sampling
With positional encoding: \(O(N)\) views sufficient (smooth interpolation)
Generalization methods: Single-view possible with meta-learning
Frequency Bias
Spectral Bias: MLPs learn low frequencies first.
Positional Encoding Effect: Shifts learning to high frequencies, but can cause overfitting.
Coarse-to-Fine Strategy: Gradually increase positional encoding frequencies during training.
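A sketch of such a frequency schedule, in the style of the cosine windowing used by Nerfies/BARF (the exact schedule varies by method):

```python
import math

def freq_window(alpha, L):
    """Per-frequency weights for coarse-to-fine positional encoding: alpha
    ramps from 0 to L over training, smoothly enabling one band at a time."""
    weights = []
    for l in range(L):
        x = min(max(alpha - l, 0.0), 1.0)
        weights.append(0.5 * (1.0 - math.cos(math.pi * x)))
    return weights

print(freq_window(0.0, 4))  # all bands off at the start of training
print(freq_window(2.5, 4))  # low bands fully on, band 2 half on, band 3 off
print(freq_window(4.0, 4))  # full positional encoding enabled
```

Each encoded feature is multiplied by its band weight, so early training sees only low frequencies (coarse geometry) and high frequencies are phased in later (fine detail).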
Limitations and Open Problems
Current Limitations
Training Time: Original NeRF takes hours to days per scene
Static Scenes: Difficult to model complex dynamics
View Dependence: Specular surfaces and transparency remain challenging
Memory: Full-resolution NeRF requires significant GPU memory
Generalization: Limited ability to generalize across scenes without meta-learning
Open Research Directions
Real-Time Rendering: Achieve interactive frame rates on consumer hardware
Physical Accuracy: Incorporate physically-based rendering, global illumination
Compositional Scenes: Disentangle objects, materials, lighting automatically
Multi-Modal Fusion: Combine NeRF with LiDAR, depth sensors, semantic segmentation
Uncertainty Quantification: Estimate confidence in novel view predictions
Scalability: Handle large-scale environments (city-scale, planet-scale)
Mathematical Connections
Relation to Signed Distance Functions (SDF)
SDF-based variants (e.g., VolSDF) derive density from a signed distance function, for instance: $\(\sigma(\mathbf{x}) = \alpha \cdot \Psi_\beta\big(-f(\mathbf{x})\big)\)$
Where \(f(\mathbf{x})\) is the signed distance to the surface and \(\Psi_\beta\) is the CDF of a zero-mean Laplace distribution with scale \(\beta\).
Benefit: Easier surface extraction (isosurface at \(f = 0\)).
Connection to Light Fields
NeRF is a compressed, continuous representation of the 5D plenoptic function: $\(L(x, y, z, \theta, \phi)\)$
Advantage: Neural representation is far more compact than discrete light field grids.
Volume Rendering as Integration
The rendering equation is a special case of radiative transfer: $\(\frac{d L}{ds} = -\sigma L + \sigma L_e\)$
Where \(L\) is radiance, \(s\) is distance along ray, \(L_e\) is emission.
Practical Considerations
Dataset Preparation
COLMAP: Structure-from-Motion to estimate camera poses
Intrinsics: Camera calibration (focal length, principal point)
Image Preprocessing: Undistortion, white balancing
Bounds: Estimate near and far planes for ray marching
Hyperparameter Tuning
Network Size: 256 units typical, 128 for faster training
Positional Encoding Frequencies: \(L=10\) for position, \(L=4\) for direction
Sampling Rates: \(N_c = 64\), \(N_f = 128\) balances quality and speed
Learning Rate: \(5 \times 10^{-4}\) with exponential decay to \(5 \times 10^{-5}\)
Debugging Tips
Density Check: Visualize density field (should be concentrated near surfaces)
Color Check: Render from training views (should match ground truth)
Transmittance: Check accumulated opacity (should reach ~1.0 for opaque scenes)
Loss Curves: Monitor coarse and fine losses (should decrease smoothly)
Evaluation Metrics
Image Quality
PSNR (Peak Signal-to-Noise Ratio): $\(\text{PSNR} = 10 \log_{10} \frac{\text{MAX}^2}{\text{MSE}}\)$
SSIM (Structural Similarity): $\(\text{SSIM}(x,y) = \frac{(2\mu_x\mu_y + c_1)(2\sigma_{xy} + c_2)}{(\mu_x^2 + \mu_y^2 + c_1)(\sigma_x^2 + \sigma_y^2 + c_2)}\)$
LPIPS (Learned Perceptual Image Patch Similarity): Uses deep features from VGG network to measure perceptual distance.
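PSNR follows directly from the MSE; a quick sketch:

```python
import math

def psnr(mse, max_val=1.0):
    """Peak signal-to-noise ratio (in dB) from mean squared error."""
    return 10.0 * math.log10(max_val ** 2 / mse)

print(psnr(0.001))  # ~30 dB, typical of a good NeRF reconstruction
print(psnr(0.01))   # ~20 dB, visibly degraded
```

Because PSNR is a pure pixel-error measure, it is usually reported alongside SSIM and LPIPS, which better track perceived quality.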
Geometry Quality
Chamfer Distance: Measures distance between predicted and ground truth point clouds
Mesh Accuracy: After surface extraction (e.g., via marching cubes)
Conclusion
Neural Radiance Fields represent a paradigm shift in 3D scene representation, replacing explicit geometric structures with implicit neural functions. The core innovationβvolumetric rendering with positional encodingβenables photorealistic novel view synthesis from sparse inputs.
Key Takeaways:
Continuous Representation: NeRF models scenes as continuous 5D functions
Differentiable Rendering: End-to-end optimization via gradient descent
Implicit Geometry: Density field encodes 3D structure without explicit meshes
View Synthesis: State-of-the-art quality for novel viewpoint generation
Future Impact: NeRF and its variants are transforming computer vision, graphics, and AR/VR, with applications ranging from content creation to robotics and medical imaging. The field continues to evolve rapidly with improvements in speed, generalization, and controllability.
"""
Neural Radiance Fields (NeRF) - Complete Implementation
This implementation includes:
1. Positional encoding for high-frequency details
2. NeRF MLP architecture with density and color branches
3. Volumetric rendering with hierarchical sampling
4. Coarse and fine networks
5. Training pipeline with ray batching
6. Visualization utilities
7. Extensions: Mip-NeRF components, instant NGP concepts
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional, List, Dict
from dataclasses import dataclass
import math
# ============================================================================
# 1. POSITIONAL ENCODING
# ============================================================================
class PositionalEncoding(nn.Module):
"""
Positional encoding using sinusoidal functions.
γ(p) = (sin(2^0 π p), cos(2^0 π p), ..., sin(2^(L-1) π p), cos(2^(L-1) π p))
Args:
num_freqs: Number of frequency bands (L)
include_input: Whether to concatenate original input
"""
def __init__(self, num_freqs: int, include_input: bool = True):
super().__init__()
self.num_freqs = num_freqs
self.include_input = include_input
# Frequency bands: [2^0, 2^1, ..., 2^(L-1)]
freq_bands = 2.0 ** torch.linspace(0, num_freqs - 1, num_freqs)
self.register_buffer('freq_bands', freq_bands)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Input tensor (..., D)
Returns:
Encoded tensor (..., D * (2 * num_freqs + include_input))
"""
# x: (..., D)
# freq_bands: (L,)
# x[..., None, :]: (..., 1, D)
# freq_bands[..., None]: (L, 1)
# Product: (..., L, D)
encoded = []
if self.include_input:
encoded.append(x)
# Compute sin and cos for each frequency
x_freq = x[..., None, :] * self.freq_bands[..., None] # (..., L, D)
encoded.append(torch.sin(math.pi * x_freq).flatten(-2)) # (..., L*D)
encoded.append(torch.cos(math.pi * x_freq).flatten(-2)) # (..., L*D)
return torch.cat(encoded, dim=-1)
def get_output_dim(self, input_dim: int) -> int:
"""Calculate output dimension after encoding."""
return input_dim * (2 * self.num_freqs + int(self.include_input))
# ============================================================================
# 2. NERF MLP ARCHITECTURE
# ============================================================================
class NeRFMLP(nn.Module):
"""
NeRF MLP architecture with skip connections.
Architecture:
- 8 layers with 256 units (ReLU activation)
- Skip connection at layer 5
- Density head: 1 output (σ)
- Feature vector: 256D
- Color head: Takes features + viewing direction → RGB
Args:
pos_enc_dim: Dimension after positional encoding for position (60 for L=10)
dir_enc_dim: Dimension after positional encoding for direction (24 for L=4)
hidden_dim: Hidden layer dimension (default 256)
num_layers: Number of layers (default 8)
"""
def __init__(
self,
pos_enc_dim: int = 60,
dir_enc_dim: int = 24,
hidden_dim: int = 256,
num_layers: int = 8
):
super().__init__()
self.num_layers = num_layers
self.hidden_dim = hidden_dim
# Density branch (8 layers with skip connection)
self.density_layers = nn.ModuleList()
# First layer: position encoding β hidden
self.density_layers.append(nn.Linear(pos_enc_dim, hidden_dim))
# Intermediate layers
for i in range(1, num_layers):
if i == 4: # Skip connection at layer 5 (index 4)
self.density_layers.append(nn.Linear(hidden_dim + pos_enc_dim, hidden_dim))
else:
self.density_layers.append(nn.Linear(hidden_dim, hidden_dim))
# Density output (σ ≥ 0)
self.density_head = nn.Linear(hidden_dim, 1)
# Feature vector for color computation
self.feature_linear = nn.Linear(hidden_dim, hidden_dim)
# Color branch (takes features + direction encoding)
self.color_layer = nn.Linear(hidden_dim + dir_enc_dim, hidden_dim // 2)
self.color_head = nn.Linear(hidden_dim // 2, 3)
def forward(
self,
pos_encoded: torch.Tensor,
dir_encoded: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
pos_encoded: Positional encoding (..., pos_enc_dim)
dir_encoded: Direction encoding (..., dir_enc_dim)
Returns:
rgb: Color (..., 3)
density: Volume density (..., 1)
"""
# Density branch
h = pos_encoded
for i, layer in enumerate(self.density_layers):
if i == 4: # Skip connection
h = torch.cat([h, pos_encoded], dim=-1)
h = F.relu(layer(h))
# Density output (ReLU to ensure non-negative)
density = F.relu(self.density_head(h))
# Feature vector for color
features = self.feature_linear(h)
# Color branch (view-dependent)
h = torch.cat([features, dir_encoded], dim=-1)
h = F.relu(self.color_layer(h))
rgb = torch.sigmoid(self.color_head(h)) # RGB in [0, 1]
return rgb, density
# ============================================================================
# 3. VOLUME RENDERING
# ============================================================================
def volume_rendering(
rgb: torch.Tensor,
density: torch.Tensor,
t_vals: torch.Tensor,
noise_std: float = 0.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Perform volume rendering along rays.
C(r) = Σ T_i · α_i · c_i
Where:
- α_i = 1 - exp(-σ_i · δ_i)
- T_i = exp(-Σ_{j<i} σ_j · δ_j) = Π_{j<i} (1 - α_j)
- δ_i = t_{i+1} - t_i
Args:
rgb: Color values (N_rays, N_samples, 3)
density: Volume density (N_rays, N_samples, 1)
t_vals: Sample positions along rays (N_rays, N_samples)
noise_std: Add noise to density for regularization
Returns:
rgb_map: Rendered color (N_rays, 3)
depth_map: Expected depth (N_rays,)
acc_map: Accumulated opacity (N_rays,)
weights: Per-sample contribution weights (N_rays, N_samples)
"""
# Add noise to density during training (regularization)
if noise_std > 0.0:
noise = torch.randn_like(density) * noise_std
density = density + noise
# Compute distances between samples (Ξ΄_i)
dists = t_vals[..., 1:] - t_vals[..., :-1] # (N_rays, N_samples-1)
dists = torch.cat([
dists,
torch.full_like(dists[..., :1], 1e10) # Last distance is infinity
], dim=-1) # (N_rays, N_samples)
# Compute alpha (opacity): α_i = 1 - exp(-σ_i · δ_i)
alpha = 1.0 - torch.exp(-F.relu(density[..., 0]) * dists) # (N_rays, N_samples)
# Compute transmittance: T_i = Π_{j<i} (1 - α_j)
# Equivalent to T_i = exp(-Σ_{j<i} σ_j · δ_j), computed via cumulative product
transmittance = torch.cumprod(
torch.cat([
torch.ones_like(alpha[..., :1]),
1.0 - alpha[..., :-1] + 1e-10  # Numerical stability for the cumulative product
], dim=-1),
dim=-1
) # (N_rays, N_samples)
# Compute weights: w_i = T_i · α_i
weights = transmittance * alpha # (N_rays, N_samples)
# Render RGB: C(r) = Σ w_i · c_i
rgb_map = torch.sum(weights[..., None] * rgb, dim=-2) # (N_rays, 3)
# Compute depth map: E[t] = Σ w_i · t_i
depth_map = torch.sum(weights * t_vals, dim=-1) # (N_rays,)
# Accumulated opacity (for background composition)
acc_map = torch.sum(weights, dim=-1) # (N_rays,)
return rgb_map, depth_map, acc_map, weights
# ============================================================================
# 4. RAY GENERATION
# ============================================================================
def get_rays(
H: int,
W: int,
focal: float,
c2w: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Generate camera rays for all pixels.
Args:
H: Image height
W: Image width
focal: Focal length
c2w: Camera-to-world transformation matrix (3, 4) or (4, 4)
Returns:
rays_o: Ray origins (H, W, 3)
rays_d: Ray directions (H, W, 3)
"""
# Pixel coordinates
i, j = torch.meshgrid(
torch.arange(W, dtype=torch.float32),
torch.arange(H, dtype=torch.float32),
indexing='xy'
)
# Normalized image coordinates (assuming principal point at center)
dirs = torch.stack([
(i - W * 0.5) / focal,
-(j - H * 0.5) / focal, # Flip y-axis
-torch.ones_like(i)
], dim=-1) # (H, W, 3)
# Transform ray directions from camera to world coordinates
rays_d = torch.sum(dirs[..., None, :] * c2w[:3, :3], dim=-1) # (H, W, 3)
# Ray origin is camera center
rays_o = c2w[:3, -1].expand(rays_d.shape) # (H, W, 3)
return rays_o, rays_d
# ============================================================================
# 5. SAMPLING STRATEGIES
# ============================================================================
def sample_along_rays(
rays_o: torch.Tensor,
rays_d: torch.Tensor,
near: float,
far: float,
n_samples: int,
perturb: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Sample points along rays using stratified sampling.
Args:
rays_o: Ray origins (N_rays, 3)
rays_d: Ray directions (N_rays, 3)
near: Near plane distance
far: Far plane distance
n_samples: Number of samples per ray
perturb: Whether to add random perturbations
Returns:
pts: Sampled 3D points (N_rays, N_samples, 3)
t_vals: Sample positions along rays (N_rays, N_samples)
"""
N_rays = rays_o.shape[0]
device = rays_o.device
# Linearly spaced samples in [near, far]
t_vals = torch.linspace(0.0, 1.0, n_samples, device=device)
t_vals = near + (far - near) * t_vals # (N_samples,)
# Stratified sampling (add random jitter to each bin)
if perturb:
mids = 0.5 * (t_vals[:-1] + t_vals[1:])
upper = torch.cat([mids, t_vals[-1:]], dim=0)
lower = torch.cat([t_vals[:1], mids], dim=0)
t_rand = torch.rand(N_rays, n_samples, device=device)
t_vals = lower + (upper - lower) * t_rand # (N_rays, N_samples)
else:
t_vals = t_vals.expand(N_rays, n_samples)
# Compute 3D points: p = o + t * d
pts = rays_o[..., None, :] + rays_d[..., None, :] * t_vals[..., :, None]
# (N_rays, N_samples, 3)
return pts, t_vals
def sample_pdf(
bins: torch.Tensor,
weights: torch.Tensor,
n_samples: int,
perturb: bool = True
) -> torch.Tensor:
"""
Sample from piecewise-constant PDF (for hierarchical sampling).
Args:
bins: Bin edges (N_rays, N_bins+1)
weights: Weights for each bin (N_rays, N_bins)
n_samples: Number of samples to draw
perturb: Whether to add random perturbations
Returns:
samples: Sampled positions (N_rays, n_samples)
"""
# Normalize weights to get PDF
weights = weights + 1e-5 # Prevent division by zero
pdf = weights / torch.sum(weights, dim=-1, keepdim=True) # (N_rays, N_bins)
# Compute CDF
cdf = torch.cumsum(pdf, dim=-1) # (N_rays, N_bins)
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1) # (N_rays, N_bins+1)
# Sample uniformly from [0, 1]
if perturb:
u = torch.rand(cdf.shape[0], n_samples, device=cdf.device)
else:
u = torch.linspace(0.0, 1.0, n_samples, device=cdf.device)
u = u.expand(cdf.shape[0], n_samples)
# Invert CDF to get samples
indices = torch.searchsorted(cdf, u, right=True) # (N_rays, n_samples)
below = torch.clamp(indices - 1, min=0)
above = torch.clamp(indices, max=cdf.shape[-1] - 1)
# Gather CDF values
cdf_below = torch.gather(cdf, 1, below)
cdf_above = torch.gather(cdf, 1, above)
# Gather bin edges
bins_below = torch.gather(bins, 1, below)
bins_above = torch.gather(bins, 1, above)
# Linear interpolation
denom = cdf_above - cdf_below
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
t = (u - cdf_below) / denom
samples = bins_below + t * (bins_above - bins_below)
return samples
# ============================================================================
# 6. COMPLETE NERF MODEL
# ============================================================================
@dataclass
class NeRFConfig:
"""Configuration for NeRF model."""
# Network architecture
hidden_dim: int = 256
num_layers: int = 8
# Positional encoding
pos_freq: int = 10 # L for position (60D output)
dir_freq: int = 4 # L for direction (24D output)
# Sampling
n_samples_coarse: int = 64
n_samples_fine: int = 128
# Scene bounds
near: float = 2.0
far: float = 6.0
# Training
noise_std: float = 1.0 # Density noise during training
perturb: bool = True # Stratified sampling
class NeRF(nn.Module):
"""
Complete NeRF model with coarse and fine networks.
Args:
config: NeRF configuration
"""
def __init__(self, config: NeRFConfig):
super().__init__()
self.config = config
# Positional encoders
self.pos_encoder = PositionalEncoding(config.pos_freq, include_input=True)
self.dir_encoder = PositionalEncoding(config.dir_freq, include_input=True)
pos_enc_dim = self.pos_encoder.get_output_dim(3) # 3D position
dir_enc_dim = self.dir_encoder.get_output_dim(3) # 3D direction
# Coarse and fine networks
self.nerf_coarse = NeRFMLP(pos_enc_dim, dir_enc_dim, config.hidden_dim, config.num_layers)
self.nerf_fine = NeRFMLP(pos_enc_dim, dir_enc_dim, config.hidden_dim, config.num_layers)
def forward(
self,
rays_o: torch.Tensor,
rays_d: torch.Tensor,
use_fine: bool = True
) -> Dict[str, torch.Tensor]:
"""
Render rays through the NeRF model.
Args:
rays_o: Ray origins (N_rays, 3)
rays_d: Ray directions (N_rays, 3)
use_fine: Whether to use hierarchical sampling and fine network
Returns:
Dictionary containing:
- rgb_coarse: Coarse RGB (N_rays, 3)
- rgb_fine: Fine RGB (N_rays, 3) if use_fine=True
- depth_coarse: Coarse depth (N_rays,)
- depth_fine: Fine depth (N_rays,) if use_fine=True
"""
# Normalize ray directions
rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
# === Coarse Network ===
# Sample points along rays
pts_coarse, t_vals_coarse = sample_along_rays(
rays_o, rays_d,
self.config.near, self.config.far,
self.config.n_samples_coarse,
perturb=self.config.perturb and self.training
)
# Encode positions and directions
N_rays, N_samples = pts_coarse.shape[:2]
pts_flat = pts_coarse.reshape(-1, 3)
dirs_flat = rays_d[:, None, :].expand(N_rays, N_samples, 3).reshape(-1, 3)
pts_encoded = self.pos_encoder(pts_flat)
dirs_encoded = self.dir_encoder(dirs_flat)
# Query coarse network
rgb_coarse, density_coarse = self.nerf_coarse(pts_encoded, dirs_encoded)
rgb_coarse = rgb_coarse.reshape(N_rays, N_samples, 3)
density_coarse = density_coarse.reshape(N_rays, N_samples, 1)
# Volume rendering
noise_std = self.config.noise_std if self.training else 0.0
rgb_map_coarse, depth_map_coarse, acc_map_coarse, weights_coarse = volume_rendering(
rgb_coarse, density_coarse, t_vals_coarse, noise_std
)
outputs = {
'rgb_coarse': rgb_map_coarse,
'depth_coarse': depth_map_coarse,
'acc_coarse': acc_map_coarse
}
if not use_fine:
return outputs
# === Fine Network (Hierarchical Sampling) ===
# Sample additional points based on coarse weights
with torch.no_grad():
t_vals_mid = 0.5 * (t_vals_coarse[..., :-1] + t_vals_coarse[..., 1:])
t_samples_fine = sample_pdf(
t_vals_mid,
weights_coarse[..., 1:-1], # Exclude first and last weights
self.config.n_samples_fine,
perturb=self.config.perturb and self.training
)
# Combine coarse and fine samples, sort by depth
t_vals_fine, _ = torch.sort(torch.cat([t_vals_coarse, t_samples_fine], dim=-1), dim=-1)
# Compute 3D points for fine samples
pts_fine = rays_o[..., None, :] + rays_d[..., None, :] * t_vals_fine[..., :, None]
# Encode and query fine network
N_samples_fine = pts_fine.shape[1]
pts_flat = pts_fine.reshape(-1, 3)
dirs_flat = rays_d[:, None, :].expand(N_rays, N_samples_fine, 3).reshape(-1, 3)
pts_encoded = self.pos_encoder(pts_flat)
dirs_encoded = self.dir_encoder(dirs_flat)
rgb_fine, density_fine = self.nerf_fine(pts_encoded, dirs_encoded)
rgb_fine = rgb_fine.reshape(N_rays, N_samples_fine, 3)
density_fine = density_fine.reshape(N_rays, N_samples_fine, 1)
# Volume rendering for fine network
rgb_map_fine, depth_map_fine, acc_map_fine, _ = volume_rendering(
rgb_fine, density_fine, t_vals_fine, noise_std
)
outputs.update({
'rgb_fine': rgb_map_fine,
'depth_fine': depth_map_fine,
'acc_fine': acc_map_fine
})
return outputs
# ============================================================================
# 7. TRAINING UTILITIES
# ============================================================================
def nerf_loss(outputs: Dict[str, torch.Tensor], target_rgb: torch.Tensor) -> torch.Tensor:
"""
NeRF loss function (MSE on coarse and fine renderings).
Args:
outputs: Model outputs containing rgb_coarse and optionally rgb_fine
target_rgb: Ground truth RGB (N_rays, 3)
Returns:
total_loss: Combined loss
"""
loss_coarse = F.mse_loss(outputs['rgb_coarse'], target_rgb)
if 'rgb_fine' in outputs:
loss_fine = F.mse_loss(outputs['rgb_fine'], target_rgb)
total_loss = loss_coarse + loss_fine
else:
total_loss = loss_coarse
return total_loss
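Since the loss is plain MSE on pixel values in [0, 1], reconstruction quality is conventionally reported as PSNR, which follows directly from the loss value. A small helper sketch (the name `mse_to_psnr` is ours, not part of the code above):

```python
import torch

def mse_to_psnr(mse: torch.Tensor) -> torch.Tensor:
    """PSNR in dB for images in [0, 1]: PSNR = -10 * log10(MSE)."""
    return -10.0 * torch.log10(mse)

# Example: an MSE of 0.01 corresponds to 20 dB
psnr = mse_to_psnr(torch.tensor(0.01))
print(f"{psnr.item():.1f} dB")  # 20.0 dB
```

Typical converged NeRF models report PSNRs in the high 20s to low 30s on synthetic benchmark scenes.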
def train_step(
model: NeRF,
optimizer: torch.optim.Optimizer,
rays_o: torch.Tensor,
rays_d: torch.Tensor,
target_rgb: torch.Tensor,
chunk_size: int = 1024
) -> float:
"""
Single training step with ray batching.
Args:
model: NeRF model
optimizer: Optimizer
rays_o: Ray origins (N_rays, 3)
rays_d: Ray directions (N_rays, 3)
target_rgb: Target RGB (N_rays, 3)
chunk_size: Process rays in chunks to save memory
Returns:
loss: Training loss
"""
model.train()
optimizer.zero_grad()
N_rays = rays_o.shape[0]
all_outputs = []
# Process rays in chunks
for i in range(0, N_rays, chunk_size):
chunk_rays_o = rays_o[i:i+chunk_size]
chunk_rays_d = rays_d[i:i+chunk_size]
outputs = model(chunk_rays_o, chunk_rays_d, use_fine=True)
all_outputs.append(outputs)
# Combine outputs
combined_outputs = {
'rgb_coarse': torch.cat([o['rgb_coarse'] for o in all_outputs], dim=0),
'rgb_fine': torch.cat([o['rgb_fine'] for o in all_outputs], dim=0)
}
# Compute loss
loss = nerf_loss(combined_outputs, target_rgb)
# Backward pass
loss.backward()
optimizer.step()
return loss.item()
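The original NeRF training recipe decays the learning rate exponentially from 5e-4 toward 5e-5 over the course of optimization. A minimal sketch of how that schedule could be wired up with `LambdaLR` (the 250k-step horizon and the dummy parameter are illustrative assumptions, not taken from this document):

```python
import torch

model_params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for model.parameters()
optimizer = torch.optim.Adam(model_params, lr=5e-4)

decay_rate, decay_steps = 0.1, 250_000  # lr multiplied by 0.1 over 250k steps
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: decay_rate ** (step / decay_steps)
)

for step in range(3):
    optimizer.step()       # (after loss.backward() in a real training loop)
    scheduler.step()

# After decay_steps iterations the learning rate would reach 5e-5
```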
# ============================================================================
# 8. VISUALIZATION UTILITIES
# ============================================================================
def render_full_image(
model: NeRF,
H: int,
W: int,
focal: float,
c2w: torch.Tensor,
chunk_size: int = 1024,
use_fine: bool = True
) -> Tuple[np.ndarray, np.ndarray]:
"""
Render a full image from a camera pose.
Args:
model: NeRF model
H: Image height
W: Image width
focal: Focal length
c2w: Camera-to-world matrix (3, 4) or (4, 4)
chunk_size: Rays per chunk
use_fine: Use fine network
Returns:
rgb: Rendered RGB image (H, W, 3)
depth: Depth map (H, W)
"""
model.eval()
# Generate rays
rays_o, rays_d = get_rays(H, W, focal, c2w)
rays_o = rays_o.reshape(-1, 3)
rays_d = rays_d.reshape(-1, 3)
all_rgb = []
all_depth = []
with torch.no_grad():
for i in range(0, rays_o.shape[0], chunk_size):
chunk_rays_o = rays_o[i:i+chunk_size]
chunk_rays_d = rays_d[i:i+chunk_size]
outputs = model(chunk_rays_o, chunk_rays_d, use_fine=use_fine)
if use_fine:
all_rgb.append(outputs['rgb_fine'].cpu())
all_depth.append(outputs['depth_fine'].cpu())
else:
all_rgb.append(outputs['rgb_coarse'].cpu())
all_depth.append(outputs['depth_coarse'].cpu())
rgb = torch.cat(all_rgb, dim=0).reshape(H, W, 3).numpy()
depth = torch.cat(all_depth, dim=0).reshape(H, W).numpy()
return rgb, depth
def visualize_nerf_rendering(
rgb: np.ndarray,
depth: np.ndarray,
title: str = "NeRF Rendering"
):
"""
Visualize NeRF rendering (RGB and depth).
Args:
rgb: RGB image (H, W, 3)
depth: Depth map (H, W)
title: Plot title
"""
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
# RGB
axes[0].imshow(rgb)
axes[0].set_title('RGB Rendering')
axes[0].axis('off')
# Depth
im = axes[1].imshow(depth, cmap='turbo')
axes[1].set_title('Depth Map')
axes[1].axis('off')
plt.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)
plt.suptitle(title, fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
def create_spiral_path(
n_poses: int = 40,
radius: float = 3.0,
height: float = 0.0,
focal_point: np.ndarray = np.array([0, 0, 0])
) -> List[torch.Tensor]:
"""
Create spiral camera path for novel view synthesis.
Args:
n_poses: Number of poses
radius: Spiral radius
height: Height variation
focal_point: Point to look at
Returns:
List of camera-to-world matrices
"""
poses = []
for i in range(n_poses):
theta = 2 * np.pi * i / n_poses
# Camera position (spiral)
x = radius * np.cos(theta)
y = radius * np.sin(theta)
z = height * np.sin(2 * theta)
cam_pos = np.array([x, y, z])
# Look-at matrix
forward = focal_point - cam_pos
forward = forward / np.linalg.norm(forward)
        right = np.cross(forward, np.array([0, 0, 1]))  # world up = +z
        right = right / np.linalg.norm(right)
        up = np.cross(right, forward)
# Camera-to-world matrix
c2w = np.eye(4)
c2w[:3, 0] = right
c2w[:3, 1] = up
c2w[:3, 2] = -forward
c2w[:3, 3] = cam_pos
poses.append(torch.tensor(c2w, dtype=torch.float32))
return poses
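A useful sanity check on look-at poses like the ones above is that the rotation part of each camera-to-world matrix is orthonormal with determinant +1; a determinant of -1 would mean the rendered views are mirrored. A standalone sketch of that check for a single pose (it re-derives the basis here, assuming world up is +z):

```python
import numpy as np

cam_pos = np.array([3.0, 0.0, 0.0])
target = np.zeros(3)

forward = target - cam_pos
forward = forward / np.linalg.norm(forward)
right = np.cross(forward, np.array([0.0, 0.0, 1.0]))  # world up = +z
right = right / np.linalg.norm(right)
up = np.cross(right, forward)

R = np.stack([right, up, -forward], axis=1)  # camera looks along its -z axis

print(np.round(R.T @ R, 6))        # identity -> columns are orthonormal
print(round(np.linalg.det(R), 6))  # +1 -> proper rotation (no mirroring)
```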
# ============================================================================
# 9. DEMONSTRATION
# ============================================================================
def demo_nerf_components():
"""Demonstrate NeRF components."""
print("=" * 80)
print("NeRF Components Demonstration")
print("=" * 80)
# 1. Positional Encoding
print("\n1. Positional Encoding")
print("-" * 40)
pos_encoder = PositionalEncoding(num_freqs=10, include_input=True)
sample_pos = torch.tensor([[0.5, 0.5, 0.5]])
encoded = pos_encoder(sample_pos)
print(f"Input shape: {sample_pos.shape}")
print(f"Encoded shape: {encoded.shape}")
    print(f"Encoding dimension: 3 → {encoded.shape[1]}")
# 2. NeRF MLP
print("\n2. NeRF MLP Architecture")
print("-" * 40)
config = NeRFConfig()
nerf = NeRF(config)
    print(f"Position encoding: 3D → {nerf.pos_encoder.get_output_dim(3)}D")
    print(f"Direction encoding: 3D → {nerf.dir_encoder.get_output_dim(3)}D")
print(f"Coarse network parameters: {sum(p.numel() for p in nerf.nerf_coarse.parameters()):,}")
print(f"Fine network parameters: {sum(p.numel() for p in nerf.nerf_fine.parameters()):,}")
print(f"Total parameters: {sum(p.numel() for p in nerf.parameters()):,}")
# 3. Ray Generation
print("\n3. Ray Generation")
print("-" * 40)
H, W, focal = 100, 100, 138.88
c2w = torch.eye(4)
c2w[:3, 3] = torch.tensor([0, 0, 3]) # Camera at z=3
rays_o, rays_d = get_rays(H, W, focal, c2w)
    print(f"Image size: {H} × {W}")
print(f"Ray origins shape: {rays_o.shape}")
print(f"Ray directions shape: {rays_d.shape}")
print(f"All rays originate from: {rays_o[0, 0]}")
# 4. Sampling
print("\n4. Sampling Along Rays")
print("-" * 40)
sample_rays_o = rays_o[H//2, W//2:W//2+1] # Center pixel
sample_rays_d = rays_d[H//2, W//2:W//2+1]
pts, t_vals = sample_along_rays(
sample_rays_o, sample_rays_d,
near=2.0, far=6.0, n_samples=64, perturb=False
)
print(f"Number of samples: {t_vals.shape[1]}")
print(f"Sample range: [{t_vals.min():.2f}, {t_vals.max():.2f}]")
print(f"Sampled points shape: {pts.shape}")
# 5. Volume Rendering
print("\n5. Volume Rendering")
print("-" * 40)
# Dummy RGB and density (Gaussian blob)
dummy_rgb = torch.rand(1, 64, 3)
# Create Gaussian density centered at middle of ray
t_center = (2.0 + 6.0) / 2
dummy_density = torch.exp(-0.5 * ((t_vals - t_center) / 0.5) ** 2).unsqueeze(-1)
rgb_map, depth_map, acc_map, weights = volume_rendering(
dummy_rgb, dummy_density, t_vals, noise_std=0.0
)
print(f"Rendered RGB: {rgb_map[0]}")
print(f"Rendered depth: {depth_map[0]:.2f}")
print(f"Accumulated opacity: {acc_map[0]:.4f}")
print(f"Weights sum: {weights.sum(dim=1)[0]:.4f}")
# 6. Forward Pass
print("\n6. NeRF Forward Pass")
print("-" * 40)
batch_rays_o = rays_o.reshape(-1, 3)[:1024]
batch_rays_d = rays_d.reshape(-1, 3)[:1024]
outputs = nerf(batch_rays_o, batch_rays_d, use_fine=True)
print(f"Coarse RGB shape: {outputs['rgb_coarse'].shape}")
print(f"Fine RGB shape: {outputs['rgb_fine'].shape}")
print(f"Coarse depth range: [{outputs['depth_coarse'].min():.2f}, {outputs['depth_coarse'].max():.2f}]")
print(f"Fine depth range: [{outputs['depth_fine'].min():.2f}, {outputs['depth_fine'].max():.2f}]")
def demo_hierarchical_sampling():
"""Demonstrate hierarchical sampling."""
print("\n" + "=" * 80)
print("Hierarchical Sampling Demonstration")
print("=" * 80)
# Create simple scenario
t_vals_coarse = torch.linspace(2.0, 6.0, 65).unsqueeze(0) # (1, 65)
# Weights concentrated in middle (simulating object)
weights = torch.exp(-0.5 * ((t_vals_coarse[:, :-1] - 4.0) / 0.5) ** 2)
weights = weights / weights.sum(dim=1, keepdim=True) # Normalize
# Sample fine points
t_samples_fine = sample_pdf(
t_vals_coarse[:, :-1], # Bins
weights[:, 1:], # Weights (excluding first)
n_samples=128,
perturb=False
)
# Visualize sampling distribution
fig, axes = plt.subplots(2, 1, figsize=(12, 8))
# Coarse weights
axes[0].bar(
t_vals_coarse[0, :-1].numpy(),
weights[0].numpy(),
width=0.06,
alpha=0.7,
label='Coarse weights'
)
axes[0].set_xlabel('Ray depth (t)')
axes[0].set_ylabel('Weight')
axes[0].set_title('Coarse Sampling Weights')
axes[0].legend()
axes[0].grid(alpha=0.3)
# Fine sampling distribution
axes[1].hist(t_samples_fine[0].numpy(), bins=50, alpha=0.7, label='Fine samples')
axes[1].axvline(4.0, color='red', linestyle='--', label='Peak density')
axes[1].set_xlabel('Ray depth (t)')
axes[1].set_ylabel('Sample count')
axes[1].set_title('Hierarchical Fine Sampling (More samples where weights are high)')
axes[1].legend()
axes[1].grid(alpha=0.3)
plt.tight_layout()
plt.show()
print(f"\nCoarse samples: {t_vals_coarse.shape[1] - 1}")
print(f"Fine samples: {t_samples_fine.shape[1]}")
print(f"Total samples (combined): {t_vals_coarse.shape[1] - 1 + t_samples_fine.shape[1]}")
print(f"Fine samples concentrated around t={t_samples_fine.mean():.2f} (target: 4.0)")
def demo_positional_encoding_effect():
"""Visualize effect of positional encoding."""
print("\n" + "=" * 80)
print("Positional Encoding Effect")
print("=" * 80)
# Create 1D coordinate
x = torch.linspace(-1, 1, 200).unsqueeze(-1)
# Different frequency levels
freq_levels = [0, 2, 5, 10]
fig, axes = plt.subplots(len(freq_levels), 1, figsize=(12, 10))
for i, L in enumerate(freq_levels):
if L == 0:
# No encoding
encoded = x
title = "No Encoding (Raw Input)"
else:
encoder = PositionalEncoding(num_freqs=L, include_input=False)
encoded = encoder(x)
title = f"Positional Encoding (L={L}, {encoded.shape[1]} features)"
# Visualize first few encoding dimensions
axes[i].plot(x.numpy(), encoded[:, :min(10, encoded.shape[1])].numpy(), alpha=0.7)
axes[i].set_title(title)
axes[i].set_xlabel('Input coordinate')
axes[i].set_ylabel('Encoded value')
axes[i].grid(alpha=0.3)
axes[i].set_xlim([-1, 1])
plt.tight_layout()
plt.show()
print("\nPositional encoding projects inputs to high-frequency space.")
print("This allows MLPs to learn high-frequency details in the scene.")
# Run demonstrations
if __name__ == "__main__":
print("Starting NeRF demonstrations...\n")
# Component demonstration
demo_nerf_components()
# Hierarchical sampling
demo_hierarchical_sampling()
# Positional encoding visualization
demo_positional_encoding_effect()
print("\n" + "=" * 80)
print("NeRF Implementation Complete!")
print("=" * 80)
print("\nKey Features Implemented:")
    print("✓ Positional encoding for high-frequency details")
    print("✓ NeRF MLP with skip connections")
    print("✓ Volumetric rendering with proper transmittance")
    print("✓ Hierarchical sampling (coarse + fine networks)")
    print("✓ Ray generation and batching")
    print("✓ Training utilities")
    print("✓ Visualization functions")
print("\nThis implementation can be extended with:")
print("- Actual training on real datasets (Blender scenes, LLFF)")
print("- Mip-NeRF cone tracing")
print("- Instant NGP hash encoding")
print("- Dynamic scenes (D-NeRF)")
print("- Semantic outputs")
print("- Text-to-3D (DreamFusion SDS)")
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
1. Volume Rendering
Ray Equation
\(\mathbf{r}(t) = \mathbf{o} + t\,\mathbf{d}\)
where \(\mathbf{o}\) is the ray origin and \(\mathbf{d}\) is the ray direction.
Volume Rendering Equation
\(C(\mathbf{r}) = \int_{t_n}^{t_f} T(t)\,\sigma(\mathbf{r}(t))\,\mathbf{c}(\mathbf{r}(t), \mathbf{d})\,dt\)
where:
\(T(t) = \exp\left(-\int_{t_n}^t \sigma(\mathbf{r}(s))\,ds\right)\) (transmittance)
\(\sigma\) is volume density
\(\mathbf{c}\) is RGB color
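The continuous rendering integral can be sanity-checked with direct numerical quadrature. A toy sketch using a hand-built Gaussian density bump and a constant emitter color (all values here are illustrative, independent of the implementation below):

```python
import torch

# Numerical quadrature of C(r) = ∫ T(t) σ(t) c(t) dt over [t_n, t_f] = [2, 6]
t = torch.linspace(2.0, 6.0, 256)
dt = t[1] - t[0]
sigma = 0.5 * torch.exp(-0.5 * ((t - 4.0) / 0.3) ** 2)   # density bump at t = 4
color = torch.stack([torch.ones_like(t),                  # constant emitted color
                     0.5 * torch.ones_like(t),
                     torch.zeros_like(t)], dim=-1)

T = torch.exp(-torch.cumsum(sigma * dt, dim=0))  # transmittance T(t)
weights = T * sigma * dt                         # contribution of each sample
C = (weights[:, None] * color).sum(dim=0)        # rendered pixel color
depth = (weights * t).sum() / weights.sum()      # expected termination depth

print(C, depth)  # depth lands just before t = 4: transmittance favors earlier samples
```

The accumulated weight stays below 1 because this toy medium is semi-transparent; in a full renderer the remainder would come from the background.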
Reference Materials:
cv_3d_foundation.pdf - CV 3D Foundation
cv_3d_research.pdf - CV 3D Research
intermediate_cv_3d.pdf - Intermediate CV 3D
2. Positional Encoding
Mapping to Higher Dimensions
\(\gamma(p) = \left(\sin(2^0 \pi p), \cos(2^0 \pi p), \ldots, \sin(2^{L-1} \pi p), \cos(2^{L-1} \pi p)\right)\)
Applied to each coordinate independently, this enables high-frequency representation.
class PositionalEncoding(nn.Module):
"""Positional encoding for coordinates."""
def __init__(self, num_freqs=10, include_input=True):
super().__init__()
self.num_freqs = num_freqs
self.include_input = include_input
freq_bands = 2.0 ** torch.linspace(0, num_freqs - 1, num_freqs)
self.register_buffer('freq_bands', freq_bands)
def forward(self, x):
"""
Args:
x: [..., D] input coordinates
Returns:
[..., D * (2 * num_freqs + 1)] if include_input
"""
out = []
if self.include_input:
out.append(x)
for freq in self.freq_bands:
out.append(torch.sin(freq * np.pi * x))
out.append(torch.cos(freq * np.pi * x))
return torch.cat(out, dim=-1)
def output_dim(self, input_dim):
return input_dim * (2 * self.num_freqs + int(self.include_input))
# Test
pe = PositionalEncoding(num_freqs=6)
x = torch.tensor([[0.5, 0.3, 0.8]])
encoded = pe(x)
print(f"Input: {x.shape}, Encoded: {encoded.shape}")
NeRF Network
The NeRF network is an MLP that maps a 5D input, a 3D position \((x, y, z)\) and 2D viewing direction \((\theta, \phi)\), to a color \((r, g, b)\) and volume density \(\sigma\). Positional encoding (sinusoidal features at multiple frequencies) is applied to the input coordinates to help the network represent high-frequency spatial detail. The density depends only on position (ensuring multi-view consistency), while color depends on both position and direction (modeling view-dependent effects like specularities). This simple architecture, when trained on a set of posed images, learns an implicit 3D representation of the scene that can be queried at any continuous point in space.
class NeRF(nn.Module):
"""Neural Radiance Field."""
def __init__(self, pos_enc_dim=63, dir_enc_dim=27, hidden_dim=256):
super().__init__()
# Position encoding
self.pos_encoder = PositionalEncoding(num_freqs=10)
self.dir_encoder = PositionalEncoding(num_freqs=4)
# Position network (x, y, z) -> features + density
self.pos_net = nn.Sequential(
nn.Linear(pos_enc_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
)
# Skip connection
self.pos_net2 = nn.Sequential(
nn.Linear(hidden_dim + pos_enc_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
)
# Density head
self.density_head = nn.Sequential(
nn.Linear(hidden_dim, 1),
nn.ReLU() # Density must be non-negative
)
# Feature extraction
self.feature_head = nn.Linear(hidden_dim, hidden_dim)
# Direction network (features + direction) -> RGB
self.dir_net = nn.Sequential(
nn.Linear(hidden_dim + dir_enc_dim, hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, 3),
nn.Sigmoid() # RGB in [0, 1]
)
def forward(self, pos, direction):
"""
Args:
pos: [..., 3] 3D positions
direction: [..., 3] view directions
Returns:
rgb: [..., 3]
density: [..., 1]
"""
# Encode
pos_enc = self.pos_encoder(pos)
dir_enc = self.dir_encoder(direction)
# Position features
h = self.pos_net(pos_enc)
h = torch.cat([h, pos_enc], dim=-1)
h = self.pos_net2(h)
# Density
density = self.density_head(h)
# Features for color
features = self.feature_head(h)
# RGB (view-dependent)
h_dir = torch.cat([features, dir_enc], dim=-1)
rgb = self.dir_net(h_dir)
return rgb, density
# Test
model = NeRF().to(device)
pos = torch.randn(10, 3).to(device)
direction = F.normalize(torch.randn(10, 3), dim=-1).to(device)
rgb, density = model(pos, direction)
print(f"RGB: {rgb.shape}, Density: {density.shape}")
Volume Rendering (Discrete)
To render an image from the NeRF representation, we cast a ray through each pixel and evaluate the network at sampled points along the ray. The discrete volume rendering equation accumulates color weighted by density and transmittance: \(C(r) = \sum_{i=1}^{N} T_i (1 - \exp(-\sigma_i \delta_i)) c_i\), where \(T_i = \exp(-\sum_{j<i} \sigma_j \delta_j)\) is the accumulated transmittance. Points with high density and low prior absorption contribute most to the final pixel color. This differentiable rendering process allows end-to-end training: the rendered pixel colors are compared to ground truth via MSE loss, and gradients flow back through the rendering equation into the MLP weights.
def volume_render(rgb, density, t_vals):
"""
Discrete volume rendering.
Args:
rgb: [N_rays, N_samples, 3]
density: [N_rays, N_samples, 1]
t_vals: [N_rays, N_samples]
Returns:
rgb_map: [N_rays, 3]
depth_map: [N_rays]
weights: [N_rays, N_samples]
"""
# Delta
dists = t_vals[..., 1:] - t_vals[..., :-1]
dists = torch.cat([dists, torch.full_like(dists[..., :1], 1e10)], dim=-1)
# Alpha
alpha = 1.0 - torch.exp(-density.squeeze(-1) * dists)
# Transmittance
T = torch.cumprod(torch.cat([
torch.ones_like(alpha[..., :1]),
1.0 - alpha[..., :-1] + 1e-10
], dim=-1), dim=-1)
# Weights
weights = alpha * T
# RGB
rgb_map = torch.sum(weights[..., None] * rgb, dim=-2)
# Depth
depth_map = torch.sum(weights * t_vals, dim=-1)
return rgb_map, depth_map, weights
# Test
N_rays, N_samples = 100, 64
rgb_test = torch.rand(N_rays, N_samples, 3)
density_test = torch.rand(N_rays, N_samples, 1)
t_vals_test = torch.linspace(2, 6, N_samples).expand(N_rays, N_samples)
rgb_map, depth_map, weights = volume_render(rgb_test, density_test, t_vals_test)
print(f"RGB map: {rgb_map.shape}, Depth map: {depth_map.shape}")
Simple Scene Demo
Before tackling real photographs, we create a simple synthetic scene to verify the NeRF pipeline works end-to-end. A few geometric primitives (spheres, boxes) with known colors are rendered from multiple viewpoints to create a small training set of posed images. Training the NeRF network on this synthetic data should reproduce the scene geometry and appearance accurately, providing a controlled test bed for debugging the positional encoding, volume rendering, and optimization before moving to complex real-world scenes.
def create_synthetic_scene(model, n_iters=500):
"""Train NeRF on simple synthetic scene."""
# Define simple scene: colored sphere at origin
def scene_sdf(pos):
"""Sphere at origin with radius 0.5."""
return torch.norm(pos, dim=-1, keepdim=True) - 0.5
def scene_color(pos):
"""Color based on position."""
return (pos + 1.0) / 2.0 # Normalize to [0, 1]
# Camera rays
H, W = 50, 50
focal = 50
# Image plane coordinates
i, j = torch.meshgrid(
torch.arange(H, dtype=torch.float32),
torch.arange(W, dtype=torch.float32),
indexing='ij'
)
# Ray directions in camera coordinates
dirs = torch.stack([
(j - W / 2) / focal,
-(i - H / 2) / focal,
-torch.ones_like(i)
], dim=-1)
dirs = F.normalize(dirs, dim=-1)
rays_d = dirs.reshape(-1, 3).to(device)
rays_o = torch.zeros_like(rays_d).to(device)
rays_o[:, 2] = 3.0 # Camera at z=3
# Sample points along rays
N_samples = 64
t_vals = torch.linspace(2.0, 4.0, N_samples, device=device)
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
losses = []
for iter in range(n_iters):
# Sample batch of rays
batch_size = 256
idx = torch.randint(0, len(rays_o), (batch_size,))
rays_o_batch = rays_o[idx]
rays_d_batch = rays_d[idx]
# Points along rays
pts = rays_o_batch[:, None, :] + rays_d_batch[:, None, :] * t_vals[None, :, None]
# Query NeRF
pts_flat = pts.reshape(-1, 3)
dirs_flat = rays_d_batch[:, None, :].expand_as(pts).reshape(-1, 3)
rgb, density = model(pts_flat, dirs_flat)
rgb = rgb.reshape(batch_size, N_samples, 3)
density = density.reshape(batch_size, N_samples, 1)
# Render
rgb_map, _, _ = volume_render(rgb, density, t_vals.expand(batch_size, N_samples))
# Ground truth (sphere)
with torch.no_grad():
sdf = scene_sdf(pts_flat).reshape(batch_size, N_samples, 1)
gt_density = torch.sigmoid(-sdf * 50) # Sharp boundary
gt_rgb = scene_color(pts_flat).reshape(batch_size, N_samples, 3)
gt_rgb_map, _, _ = volume_render(gt_rgb, gt_density, t_vals.expand(batch_size, N_samples))
# Loss
loss = F.mse_loss(rgb_map, gt_rgb_map)
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.append(loss.item())
if (iter + 1) % 100 == 0:
print(f"Iter {iter+1}, Loss: {loss.item():.6f}")
return losses
# Train
model = NeRF().to(device)
losses = create_synthetic_scene(model, n_iters=500)
plt.figure(figsize=(10, 5))
plt.plot(losses)
plt.xlabel('Iteration', fontsize=11)
plt.ylabel('MSE Loss', fontsize=11)
plt.title('NeRF Training on Synthetic Sphere', fontsize=12)
plt.grid(True, alpha=0.3)
plt.yscale('log')
plt.show()
Render Novel View
The primary application of NeRF is novel view synthesis: rendering the scene from camera positions not present in the training set. By specifying a new camera pose (position and orientation), casting rays through each pixel, and evaluating the trained network, we obtain a photorealistic image of the scene from an entirely new viewpoint. The quality of novel views, especially in regions that were only partially observed during training, is the definitive test of whether the network has learned a coherent 3D representation rather than merely memorizing the training views.
def render_image(model, H=100, W=100, focal=100, cam_pos=[0, 0, 3]):
"""Render full image."""
model.eval()
i, j = torch.meshgrid(
torch.arange(H, dtype=torch.float32),
torch.arange(W, dtype=torch.float32),
indexing='ij'
)
dirs = torch.stack([
(j - W / 2) / focal,
-(i - H / 2) / focal,
-torch.ones_like(i)
], dim=-1)
dirs = F.normalize(dirs, dim=-1)
rays_d = dirs.reshape(-1, 3).to(device)
rays_o = torch.tensor(cam_pos, dtype=torch.float32, device=device).expand_as(rays_d)
N_samples = 64
t_vals = torch.linspace(2.0, 4.0, N_samples, device=device)
# Render in batches
batch_size = 512
rgb_full = []
with torch.no_grad():
for i in range(0, len(rays_o), batch_size):
rays_o_batch = rays_o[i:i+batch_size]
rays_d_batch = rays_d[i:i+batch_size]
pts = rays_o_batch[:, None, :] + rays_d_batch[:, None, :] * t_vals[None, :, None]
pts_flat = pts.reshape(-1, 3)
dirs_flat = rays_d_batch[:, None, :].expand_as(pts).reshape(-1, 3)
rgb, density = model(pts_flat, dirs_flat)
rgb = rgb.reshape(len(rays_o_batch), N_samples, 3)
density = density.reshape(len(rays_o_batch), N_samples, 1)
rgb_map, _, _ = volume_render(rgb, density, t_vals.expand(len(rays_o_batch), N_samples))
rgb_full.append(rgb_map.cpu())
rgb_full = torch.cat(rgb_full, dim=0)
return rgb_full.reshape(H, W, 3).numpy()
# Render
img = render_image(model, H=100, W=100, cam_pos=[0, 0, 3])
plt.figure(figsize=(8, 8))
plt.imshow(img)
plt.axis('off')
plt.title('NeRF Rendered Image', fontsize=12)
plt.show()
Summary
NeRF Represents:
Scene as continuous 5D function: \((x, y, z, \theta, \phi) \to (r, g, b, \sigma)\)
Key Components:
Positional encoding - High frequency details
MLP network - Maps coordinates to color + density
Volume rendering - Integrates along rays
Hierarchical sampling - Coarse-to-fine
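The hierarchical (coarse-to-fine) step relies on inverse-CDF sampling of the normalized coarse weights. Since the `sample_pdf` routine used earlier is not reproduced in this section, here is a minimal standalone sketch of the idea (the function name and its exact interface are ours, treating the weights as a piecewise-uniform pdf over bins):

```python
import torch

def inverse_cdf_sample(bins, weights, n_samples):
    """Draw samples proportional to per-bin weights.
    bins: (B, N+1) bin edges; weights: (B, N) unnormalized weights."""
    pdf = weights / (weights.sum(-1, keepdim=True) + 1e-8)
    cdf = torch.cumsum(pdf, dim=-1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1)  # (B, N+1)

    u = torch.rand(weights.shape[0], n_samples)  # uniform draws in [0, 1)
    idx = torch.searchsorted(cdf, u, right=True).clamp(1, cdf.shape[-1] - 1)
    cdf_lo, cdf_hi = cdf.gather(-1, idx - 1), cdf.gather(-1, idx)
    bin_lo, bin_hi = bins.gather(-1, idx - 1), bins.gather(-1, idx)

    frac = (u - cdf_lo) / (cdf_hi - cdf_lo + 1e-8)  # position within the bin
    return bin_lo + frac * (bin_hi - bin_lo)

# Weights peaked at t = 4 pull nearly all fine samples toward that depth
bins = torch.linspace(2.0, 6.0, 65).unsqueeze(0)
centers = 0.5 * (bins[:, :-1] + bins[:, 1:])
weights = torch.exp(-0.5 * ((centers - 4.0) / 0.3) ** 2)
torch.manual_seed(0)
samples = inverse_cdf_sample(bins, weights, 128)
print(samples.mean())  # concentrates near t = 4 where the weights peak
```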
Applications:
Novel view synthesis
3D reconstruction
Relighting
Scene editing
Extensions:
Instant-NGP (hash encoding)
Plenoxels (voxel grids)
NeRF-W (in-the-wild)
Dynamic NeRF (time-varying)
Next Steps:
Study instant-NGP for speed
Explore signed distance functions
Learn multi-view geometry