Neural Radiance Fields (NeRF): Comprehensive TheoryΒΆ

IntroductionΒΆ

Neural Radiance Fields (NeRF) revolutionized 3D scene representation and novel view synthesis by representing scenes as continuous volumetric radiance fields encoded by deep neural networks. Instead of explicit 3D representations (meshes, point clouds, voxels), NeRF learns an implicit function that maps 5D coordinates (3D location + 2D viewing direction) to volume density and view-dependent emitted radiance.

Key Innovation: Representing complex 3D scenes as weights of a neural network, achieving photorealistic novel view synthesis from sparse input views.

Mathematical FoundationΒΆ

Scene RepresentationΒΆ

NeRF represents a static scene as a continuous function:

\[F_\Theta: (\mathbf{x}, \mathbf{d}) \rightarrow (\mathbf{c}, \sigma)\]

Where:

  • \(\mathbf{x} = (x, y, z) \in \mathbb{R}^3\) is a 3D location

  • \(\mathbf{d} = (\theta, \phi) \in \mathbb{S}^2\) is a viewing direction (unit vector)

  • \(\mathbf{c} = (r, g, b) \in [0,1]^3\) is emitted RGB color

  • \(\sigma \in \mathbb{R}^+ \) is volume density (opacity)

  • \(\Theta\) represents network parameters

Key Property: Volume density \(\sigma(\mathbf{x})\) is view-independent (geometric property), while color \(\mathbf{c}(\mathbf{x}, \mathbf{d})\) is view-dependent (captures specular reflections, viewing angle effects).

Volume RenderingΒΆ

To render a pixel, NeRF performs volumetric rendering along camera rays. For a camera ray \(\mathbf{r}(t) = \mathbf{o} + t\mathbf{d}\) (origin \(\mathbf{o}\), direction \(\mathbf{d}\), parameter \(t\)), the expected color is:

\[C(\mathbf{r}) = \int_{t_n}^{t_f} T(t) \cdot \sigma(\mathbf{r}(t)) \cdot \mathbf{c}(\mathbf{r}(t), \mathbf{d}) \, dt\]

Where transmittance \(T(t)\) is the accumulated transparency from \(t_n\) to \(t\):

\[T(t) = \exp\left(-\int_{t_n}^{t} \sigma(\mathbf{r}(s)) \, ds\right)\]

Interpretation:

  • \(T(t)\) is the probability that the ray travels from \(t_n\) to \(t\) without hitting any particle

  • \(\sigma(\mathbf{r}(t)) \cdot dt\) is the probability of ray termination in infinitesimal segment

  • The integral accumulates color contributions weighted by opacity and visibility

Discrete ApproximationΒΆ

In practice, we discretize the continuous integral using quadrature:

\[\hat{C}(\mathbf{r}) = \sum_{i=1}^{N} T_i \cdot \alpha_i \cdot \mathbf{c}_i\]

Where:

  • \(\alpha_i = 1 - \exp(-\sigma_i \delta_i)\) is alpha compositing opacity

  • \(\delta_i = t_{i+1} - t_i\) is the distance between adjacent samples

  • \(T_i = \exp\left(-\sum_{j=1}^{i-1} \sigma_j \delta_j\right) = \prod_{j=1}^{i-1}(1 - \alpha_j)\) is discrete transmittance

Sampling Strategy: Sample points along ray using stratified sampling to ensure continuous coverage:

\[t_i \sim \mathcal{U}\left[t_n + \frac{i-1}{N}(t_f - t_n), \quad t_n + \frac{i}{N}(t_f - t_n)\right]\]

Network ArchitectureΒΆ

Two-Stage MLPΒΆ

NeRF uses an 8-layer MLP with a skip connection:

  1. Density Branch (layers 1-8):

    • Input: 3D position \(\mathbf{x}\) (after positional encoding)

    • Hidden: 256 units per layer, ReLU activations

    • Skip connection: Concatenate input at layer 5

    • Output: 256D feature vector + density \(\sigma\)

  2. Color Branch (layer 9):

    • Input: 256D features + viewing direction \(\mathbf{d}\) (after positional encoding)

    • Hidden: 128 units, ReLU

    • Output: RGB color \(\mathbf{c}\)

Architecture Insight: Separating density and color allows the network to learn view-independent geometry while modeling view-dependent appearance.

Positional EncodingΒΆ

Critical innovation: Map input coordinates to higher-dimensional space using sinusoidal functions:

\[\gamma(p) = \left(\sin(2^0 \pi p), \cos(2^0 \pi p), \ldots, \sin(2^{L-1} \pi p), \cos(2^{L-1} \pi p)\right)\]

For 3D position: \(L = 10\) (60D output)
For 2D direction: \(L = 4\) (24D output)

Why It Works:

  • Neural networks have spectral bias toward low-frequency functions

  • Positional encoding projects inputs to high-frequency space

  • Enables learning of high-frequency variations in color and geometry

  • Related to Fourier feature mapping and coordinate-based MLPs

Mathematical Intuition: This is similar to a Fourier basis expansion, allowing the MLP to learn arbitrary functions via the universal approximation property in this expanded feature space.

Hierarchical Volume SamplingΒΆ

NeRF uses a coarse-to-fine sampling strategy to allocate network capacity efficiently:

Coarse Network (\(F_c\))ΒΆ

  1. Sample \(N_c\) points uniformly (stratified) along each ray

  2. Query coarse network \(F_c\) at these positions

  3. Compute coarse rendering \(\hat{C}_c(\mathbf{r})\)

Fine Network (\(F_f\))ΒΆ

  1. Use coarse network’s density predictions to compute importance sampling weights:

\[w_i = T_i \cdot \alpha_i\]
  1. Sample additional \(N_f\) points from piecewise-constant PDF based on \(w_i\)

  2. Combine coarse and fine samples (total \(N_c + N_f\) points)

  3. Query fine network \(F_f\) and compute fine rendering \(\hat{C}_f(\mathbf{r})\)

Benefit: Focus samples in regions with high expected contribution to final rendering, improving efficiency and quality.

Loss FunctionΒΆ

\[\mathcal{L} = \sum_{\mathbf{r} \in \mathcal{R}} \left[\|\hat{C}_c(\mathbf{r}) - C(\mathbf{r})\|_2^2 + \|\hat{C}_f(\mathbf{r}) - C(\mathbf{r})\|_2^2\right]\]

Where \(\mathcal{R}\) is a batch of camera rays, and \(C(\mathbf{r})\) is the ground truth pixel color.

Training StrategiesΒΆ

Ray BatchingΒΆ

  • Sample random batch of camera rays from training images

  • Typical: 4096 rays per batch

  • Each ray requires \(N_c + N_f\) network queries (e.g., 64 + 128 = 192)

  • Total: ~786,432 network evaluations per batch

OptimizationΒΆ

  • Optimizer: Adam

  • Learning rate: \(5 \times 10^{-4}\) with exponential decay

  • Training: 100,000-300,000 iterations

  • Hardware: Single NVIDIA V100 GPU (~1-2 days per scene)

RegularizationΒΆ

Weight Decay: L2 regularization on network weights prevents overfitting on sparse views.

Ray Jittering: Add small random perturbations to ray directions for augmentation.

Advanced NeRF VariantsΒΆ

1. Mip-NeRF (Multiscale Representation)ΒΆ

Problem: Original NeRF samples single points along rays, causing aliasing and ambiguity at different scales.

Solution: Instead of querying discrete points, query 3D conical frustums that account for pixel footprint.

Integrated Positional Encoding (IPE): $\(\mathbb{E}_{\mathbf{x} \sim p(\mathbf{x})}[\gamma(\mathbf{x})] \approx \gamma(\boldsymbol{\mu}, \boldsymbol{\Sigma})\)$

Where \(\boldsymbol{\mu}, \boldsymbol{\Sigma}\) are mean and covariance of Gaussian approximation to conical frustum.

Benefit: Anti-aliasing, improved rendering at different resolutions, faster convergence.

2. Instant NGP (Neural Graphics Primitives)ΒΆ

Problem: NeRF training is slow (~days per scene).

Solution: Multi-resolution hash encoding with small MLP.

Hash Encoding: $\(\mathbf{h} = \bigoplus_{l=1}^{L} \text{Interp}(\mathbf{x}; H_l)\)$

Where \(H_l\) is a hash table at resolution level \(l\), and Interp performs trilinear interpolation.

Performance: 1000Γ— faster training (~5 seconds per scene), interactive rendering.

3. NeRF in the Wild (NeRF-W)ΒΆ

Problem: Original NeRF assumes static scenes with consistent lighting.

Solution: Model appearance variations and transient objects.

Extended Representation: $\(F_\Theta: (\mathbf{x}, \mathbf{d}, \mathbf{l}_i, \boldsymbol{\tau}_i) \rightarrow (\mathbf{c}, \sigma, \beta)\)$

  • \(\mathbf{l}_i\): Per-image appearance embedding

  • \(\boldsymbol{\tau}_i\): Per-image transient embedding

  • \(\beta\): Transient density (for moving objects, lighting changes)

4. PlenoxelsΒΆ

Problem: MLP inference is slow for real-time rendering.

Solution: Replace MLP with sparse voxel grid (spherical harmonics for view-dependence).

Rendering: Direct interpolation from voxel grid (no network inference).

Trade-off: Memory vs. speed (faster rendering but larger storage).

5. TensoRFΒΆ

Problem: Voxel grids scale poorly to high resolutions (memory \(O(N^3)\)).

Solution: Tensor decomposition of radiance field.

CP Decomposition: $\(\mathcal{F}(x,y,z) = \sum_{r=1}^{R} \mathbf{v}^X_r[x] \cdot \mathbf{v}^Y_r[y] \cdot \mathbf{v}^Z_r[z]\)$

VM Decomposition (Vector-Matrix): $\(\mathcal{F}(x,y,z) = \sum_{r=1}^{R} \mathbf{v}^Z_r[z] \cdot \mathbf{M}^{XY}_r[x,y]\)$

Benefit: Compact representation, fast rendering, improved quality.

6. Nerfies / HyperNeRFΒΆ

Problem: NeRF assumes rigid static scenes.

Solution: Model non-rigid deformations for dynamic/deformable scenes.

Deformation Field: $\(\mathbf{x}' = \mathbf{x} + \Delta\mathbf{x}(\mathbf{x}, t)\)$

Query canonical NeRF at deformed position \(\mathbf{x}'\).

Applications: Face reenactment, human performance capture, dynamic object modeling.

7. NeRF++ (Unbounded Scenes)ΒΆ

Problem: Original NeRF assumes bounded scenes.

Solution: Separate foreground and background modeling with inverted sphere parameterization.

Background Representation: $\(\mathbf{x}_\text{bg} = \frac{\mathbf{d}}{|\mathbf{d}|^2}\)$

Map infinite rays to bounded space using inverse radius.

Dynamic NeRF ExtensionsΒΆ

D-NeRF (Dynamic NeRF)ΒΆ

Representation: $\(F_\Theta: (\mathbf{x}, t) \rightarrow (\Delta\mathbf{x}, \mathbf{c}, \sigma)\)$

  • Canonical space: Static NeRF

  • Deformation network: Maps \((\mathbf{x}, t)\) to canonical coordinates

  • Separate networks for deformation and canonical radiance

Neural Scene Flow FieldsΒΆ

Represent 4D space-time as continuous flow field: $\(F: (\mathbf{x}, t_1, t_2) \rightarrow \Delta\mathbf{x}_{t_1 \rightarrow t_2}\)$

Track 3D points across time for dynamic reconstruction.

Generalization and Few-Shot NeRFΒΆ

pixelNeRFΒΆ

Problem: NeRF requires many views per scene (50-100 images).

Solution: Condition NeRF on input image features (meta-learning across scenes).

Architecture: $\(F_\Theta: (\mathbf{x}, \mathbf{d}, \mathbf{W}(\mathbf{x})) \rightarrow (\mathbf{c}, \sigma)\)$

Where \(\mathbf{W}(\mathbf{x})\) is image feature extracted from CNN encoder at projected position.

Benefit: Single or few-view novel view synthesis.

IBRNet / MVSNeRFΒΆ

Combine multi-view stereo with NeRF for few-shot generalization:

  • Extract features from source views

  • Aggregate features along epipolar lines

  • Condition NeRF decoder on aggregated features

Semantic and Controllable NeRFΒΆ

Semantic-NeRFΒΆ

Extended Output: $\(F_\Theta: (\mathbf{x}, \mathbf{d}) \rightarrow (\mathbf{c}, \sigma, \mathbf{s})\)$

Where \(\mathbf{s}\) is semantic logits for object classes.

Rendering: Composite semantics using same volume rendering equation.

Editable NeRFΒΆ

Object Decomposition: Represent scene as composition of object NeRFs.

Control: Manipulate object poses, appearances, and properties independently.

Text-to-3D and Diffusion NeRFΒΆ

DreamFusionΒΆ

Problem: No 3D training data, only 2D text-to-image models.

Solution: Score Distillation Sampling (SDS) loss.

SDS Loss: $\(\mathcal{L}_\text{SDS} = \mathbb{E}_{t,\epsilon}\left[w(t) \left\|\epsilon_\phi(\mathbf{z}_t; t, y) - \epsilon\right\|_2^2\right]\)$

Where:

  • \(\epsilon_\phi\): Pre-trained diffusion model (e.g., Stable Diffusion)

  • \(\mathbf{z}_t\): Noised rendering from NeRF

  • \(y\): Text prompt

  • \(\epsilon\): Random noise

Interpretation: Optimize NeRF to produce renderings that the diffusion model considers high-probability samples for the text prompt.

Magic3DΒΆ

Improvements over DreamFusion:

  • Coarse-to-fine optimization (low-res β†’ high-res)

  • Explicit mesh extraction

  • 2Γ— faster, higher quality

NeRF for RelightingΒΆ

NeRF-W (Lighting Variation)ΒΆ

Problem: Model scenes with varying illumination.

Solution: Per-image illumination embedding.

NeRD (Neural Reflectance Decomposition)ΒΆ

Physics-Based Decomposition: $\(\mathbf{c} = \mathbf{k}_d \cdot L_d + \mathbf{k}_s \cdot L_s\)$

  • \(\mathbf{k}_d\): Diffuse albedo

  • \(\mathbf{k}_s\): Specular albedo

  • \(L_d, L_s\): Diffuse and specular lighting

Benefit: Relight scenes under novel illumination conditions.

Acceleration TechniquesΒΆ

1. Baking / DistillationΒΆ

Distill NeRF into faster representations:

  • SNeRG: Distill into sparse voxel grid with view-dependent features

  • MobileNeRF: Distill into textured polygonal mesh

  • Trade-off: Quality vs. speed

2. Early Ray TerminationΒΆ

Stop marching along ray when accumulated opacity exceeds threshold: $\(T_i < \epsilon \implies \text{terminate}\)$

Benefit: Skip sampling in empty or fully opaque regions.

3. Occupancy GridsΒΆ

Maintain coarse 3D grid indicating occupied regions:

  • Skip sampling in empty voxels

  • Update grid periodically during training

  • Speed-up: 2-10Γ— faster training

4. Neural Sparse Voxel Fields (NSVF)ΒΆ

Structure: Octree of voxels with learned features.

Rendering: Only evaluate MLP in occupied voxels.

Benefit: 10Γ— faster than NeRF with comparable quality.

NeRF for Novel ApplicationsΒΆ

1. Medical ImagingΒΆ

  • CT/MRI Reconstruction: Reconstruct 3D volumes from sparse 2D slices

  • Surgical Planning: Novel viewpoint generation for pre-operative visualization

2. RoboticsΒΆ

  • Scene Understanding: Implicit 3D maps for navigation

  • Manipulation: Object pose estimation and grasp planning

3. AR/VRΒΆ

  • Telepresence: Real-time novel view synthesis for immersive communication

  • Content Creation: Simplified 3D asset generation from photos

4. Autonomous DrivingΒΆ

  • Simulation: Generate realistic training data with controllable viewpoints

  • Mapping: Compact scene representations for localization

Theoretical AnalysisΒΆ

Optimization LandscapeΒΆ

Theorem (Universal Approximation): A sufficiently large MLP with positional encoding can approximate any continuous radiance field to arbitrary precision.

Proof Sketch:

  1. Positional encoding spans high-frequency Fourier basis

  2. Universal approximation theorem for neural networks

  3. Composition yields arbitrary precision

Challenges:

  • Non-convex optimization landscape

  • Local minima (floaters, artifacts)

  • Requires good initialization and careful hyperparameter tuning

Sample ComplexityΒΆ

Question: How many views are needed for accurate reconstruction?

Analysis:

  • Bounded scenes: \(O(N^2)\) views for Nyquist sampling

  • With positional encoding: \(O(N)\) views sufficient (smooth interpolation)

  • Generalization methods: Single-view possible with meta-learning

Frequency BiasΒΆ

Spectral Bias: MLPs learn low frequencies first.

Positional Encoding Effect: Shifts learning to high frequencies, but can cause overfitting.

Coarse-to-Fine Strategy: Gradually increase positional encoding frequencies during training.

Limitations and Open ProblemsΒΆ

Current LimitationsΒΆ

  1. Training Time: Original NeRF takes hours to days per scene

  2. Static Scenes: Difficult to model complex dynamics

  3. View Dependence: Specular surfaces and transparency remain challenging

  4. Memory: Full-resolution NeRF requires significant GPU memory

  5. Generalization: Limited ability to generalize across scenes without meta-learning

Open Research DirectionsΒΆ

  1. Real-Time Rendering: Achieve interactive frame rates on consumer hardware

  2. Physical Accuracy: Incorporate physically-based rendering, global illumination

  3. Compositional Scenes: Disentangle objects, materials, lighting automatically

  4. Multi-Modal Fusion: Combine NeRF with LiDAR, depth sensors, semantic segmentation

  5. Uncertainty Quantification: Estimate confidence in novel view predictions

  6. Scalability: Handle large-scale environments (city-scale, planet-scale)

Mathematical ConnectionsΒΆ

Relation to Signed Distance Functions (SDF)ΒΆ

NeRF density can be related to SDF gradient: $\(\sigma(\mathbf{x}) = \alpha \cdot \|\nabla f(\mathbf{x})\|\)$

Where \(f(\mathbf{x})\) is signed distance to surface.

Benefit: Easier surface extraction (isosurface at \(f = 0\)).

Connection to Light FieldsΒΆ

NeRF is a compressed, continuous representation of the 4D light field: $\(L(x, y, z, \theta, \phi)\)$

Advantage: Neural representation is far more compact than discrete light field grids.

Volume Rendering as IntegrationΒΆ

The rendering equation is a special case of radiative transfer: $\(\frac{d L}{ds} = -\sigma L + \sigma L_e\)$

Where \(L\) is radiance, \(s\) is distance along ray, \(L_e\) is emission.

Practical ConsiderationsΒΆ

Dataset PreparationΒΆ

  1. COLMAP: Structure-from-Motion to estimate camera poses

  2. Intrinsics: Camera calibration (focal length, principal point)

  3. Image Preprocessing: Undistortion, white balancing

  4. Bounds: Estimate near and far planes for ray marching

Hyperparameter TuningΒΆ

  • Network Size: 256 units typical, 128 for faster training

  • Positional Encoding Frequencies: \(L=10\) for position, \(L=4\) for direction

  • Sampling Rates: \(N_c = 64\), \(N_f = 128\) balances quality and speed

  • Learning Rate: \(5 \times 10^{-4}\) with exponential decay to \(5 \times 10^{-5}\)

Debugging TipsΒΆ

  1. Density Check: Visualize density field (should be concentrated near surfaces)

  2. Color Check: Render from training views (should match ground truth)

  3. Transmittance: Check accumulated opacity (should reach ~1.0 for opaque scenes)

  4. Loss Curves: Monitor coarse and fine losses (should decrease smoothly)

Evaluation MetricsΒΆ

Image QualityΒΆ

  1. PSNR (Peak Signal-to-Noise Ratio): $\(\text{PSNR} = 10 \log_{10} \frac{\text{MAX}^2}{\text{MSE}}\)$

  2. SSIM (Structural Similarity): $\(\text{SSIM}(x,y) = \frac{(2\mu_x\mu_y + c_1)(2\sigma_{xy} + c_2)}{(\mu_x^2 + \mu_y^2 + c_1)(\sigma_x^2 + \sigma_y^2 + c_2)}\)$

  3. LPIPS (Learned Perceptual Image Patch Similarity): Uses deep features from VGG network to measure perceptual distance.

Geometry QualityΒΆ

  1. Chamfer Distance: Measures distance between predicted and ground truth point clouds

  2. Mesh Accuracy: After surface extraction (e.g., via marching cubes)

ConclusionΒΆ

Neural Radiance Fields represent a paradigm shift in 3D scene representation, replacing explicit geometric structures with implicit neural functions. The core innovationβ€”volumetric rendering with positional encodingβ€”enables photorealistic novel view synthesis from sparse inputs.

Key Takeaways:

  1. Continuous Representation: NeRF models scenes as continuous 5D functions

  2. Differentiable Rendering: End-to-end optimization via gradient descent

  3. Implicit Geometry: Density field encodes 3D structure without explicit meshes

  4. View Synthesis: State-of-the-art quality for novel viewpoint generation

Future Impact: NeRF and its variants are transforming computer vision, graphics, and AR/VR, with applications ranging from content creation to robotics and medical imaging. The field continues to evolve rapidly with improvements in speed, generalization, and controllability.

"""
Neural Radiance Fields (NeRF) - Complete Implementation

This implementation includes:
1. Positional encoding for high-frequency details
2. NeRF MLP architecture with density and color branches
3. Volumetric rendering with hierarchical sampling
4. Coarse and fine networks
5. Training pipeline with ray batching
6. Visualization utilities
7. Extensions: Mip-NeRF components, instant NGP concepts
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional, List, Dict
from dataclasses import dataclass
import math


# ============================================================================
# 1. POSITIONAL ENCODING
# ============================================================================

class PositionalEncoding(nn.Module):
    """
    Positional encoding using sinusoidal functions.
    
    Ξ³(p) = (sin(2^0 Ο€ p), cos(2^0 Ο€ p), ..., sin(2^(L-1) Ο€ p), cos(2^(L-1) Ο€ p))
    
    Args:
        num_freqs: Number of frequency bands (L)
        include_input: Whether to concatenate original input
    """
    def __init__(self, num_freqs: int, include_input: bool = True):
        super().__init__()
        self.num_freqs = num_freqs
        self.include_input = include_input
        
        # Frequency bands: [2^0, 2^1, ..., 2^(L-1)]
        freq_bands = 2.0 ** torch.linspace(0, num_freqs - 1, num_freqs)
        self.register_buffer('freq_bands', freq_bands)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor (..., D)
        Returns:
            Encoded tensor (..., D * (2 * num_freqs + include_input))
        """
        # x: (..., D)
        # freq_bands: (L,)
        # x[..., None, :]: (..., 1, D)
        # freq_bands[..., None]: (L, 1)
        # Product: (..., L, D)
        
        encoded = []
        
        if self.include_input:
            encoded.append(x)
        
        # Compute sin and cos for each frequency
        x_freq = x[..., None, :] * self.freq_bands[..., None]  # (..., L, D)
        
        encoded.append(torch.sin(math.pi * x_freq).flatten(-2))  # (..., L*D)
        encoded.append(torch.cos(math.pi * x_freq).flatten(-2))  # (..., L*D)
        
        return torch.cat(encoded, dim=-1)
    
    def get_output_dim(self, input_dim: int) -> int:
        """Calculate output dimension after encoding."""
        return input_dim * (2 * self.num_freqs + int(self.include_input))


# ============================================================================
# 2. NERF MLP ARCHITECTURE
# ============================================================================

class NeRFMLP(nn.Module):
    """
    NeRF MLP architecture with skip connections.
    
    Architecture:
    - 8 layers with 256 units (ReLU activation)
    - Skip connection at layer 5
    - Density head: 1 output (Οƒ)
    - Feature vector: 256D
    - Color head: Takes features + viewing direction β†’ RGB
    
    Args:
        pos_enc_dim: Dimension after positional encoding for position (60 for L=10)
        dir_enc_dim: Dimension after positional encoding for direction (24 for L=4)
        hidden_dim: Hidden layer dimension (default 256)
        num_layers: Number of layers (default 8)
    """
    def __init__(
        self,
        pos_enc_dim: int = 60,
        dir_enc_dim: int = 24,
        hidden_dim: int = 256,
        num_layers: int = 8
    ):
        super().__init__()
        
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        
        # Density branch (8 layers with skip connection)
        self.density_layers = nn.ModuleList()
        
        # First layer: position encoding β†’ hidden
        self.density_layers.append(nn.Linear(pos_enc_dim, hidden_dim))
        
        # Intermediate layers
        for i in range(1, num_layers):
            if i == 4:  # Skip connection at layer 5 (index 4)
                self.density_layers.append(nn.Linear(hidden_dim + pos_enc_dim, hidden_dim))
            else:
                self.density_layers.append(nn.Linear(hidden_dim, hidden_dim))
        
        # Density output (Οƒ β‰₯ 0)
        self.density_head = nn.Linear(hidden_dim, 1)
        
        # Feature vector for color computation
        self.feature_linear = nn.Linear(hidden_dim, hidden_dim)
        
        # Color branch (takes features + direction encoding)
        self.color_layer = nn.Linear(hidden_dim + dir_enc_dim, hidden_dim // 2)
        self.color_head = nn.Linear(hidden_dim // 2, 3)
        
    def forward(
        self,
        pos_encoded: torch.Tensor,
        dir_encoded: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            pos_encoded: Positional encoding (..., pos_enc_dim)
            dir_encoded: Direction encoding (..., dir_enc_dim)
            
        Returns:
            rgb: Color (..., 3)
            density: Volume density (..., 1)
        """
        # Density branch
        h = pos_encoded
        for i, layer in enumerate(self.density_layers):
            if i == 4:  # Skip connection
                h = torch.cat([h, pos_encoded], dim=-1)
            h = F.relu(layer(h))
        
        # Density output (ReLU to ensure non-negative)
        density = F.relu(self.density_head(h))
        
        # Feature vector for color
        features = self.feature_linear(h)
        
        # Color branch (view-dependent)
        h = torch.cat([features, dir_encoded], dim=-1)
        h = F.relu(self.color_layer(h))
        rgb = torch.sigmoid(self.color_head(h))  # RGB in [0, 1]
        
        return rgb, density


# ============================================================================
# 3. VOLUME RENDERING
# ============================================================================

def volume_rendering(
    rgb: torch.Tensor,
    density: torch.Tensor,
    t_vals: torch.Tensor,
    noise_std: float = 0.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Perform volume rendering along rays.
    
    C(r) = Ξ£ T_i Β· Ξ±_i Β· c_i
    
    Where:
    - Ξ±_i = 1 - exp(-Οƒ_i Β· Ξ΄_i)
    - T_i = exp(-Ξ£_{j<i} Οƒ_j Β· Ξ΄_j) = Ξ _{j<i} (1 - Ξ±_j)
    - Ξ΄_i = t_{i+1} - t_i
    
    Args:
        rgb: Color values (N_rays, N_samples, 3)
        density: Volume density (N_rays, N_samples, 1)
        t_vals: Sample positions along rays (N_rays, N_samples)
        noise_std: Add noise to density for regularization
        
    Returns:
        rgb_map: Rendered color (N_rays, 3)
        depth_map: Expected depth (N_rays,)
        acc_map: Accumulated opacity (N_rays,)
    """
    # Add noise to density during training (regularization)
    if noise_std > 0.0:
        noise = torch.randn_like(density) * noise_std
        density = density + noise
    
    # Compute distances between samples (Ξ΄_i)
    dists = t_vals[..., 1:] - t_vals[..., :-1]  # (N_rays, N_samples-1)
    dists = torch.cat([
        dists,
        torch.full_like(dists[..., :1], 1e10)  # Last distance is infinity
    ], dim=-1)  # (N_rays, N_samples)
    
    # Compute alpha (opacity): Ξ±_i = 1 - exp(-Οƒ_i Β· Ξ΄_i)
    alpha = 1.0 - torch.exp(-F.relu(density[..., 0]) * dists)  # (N_rays, N_samples)
    
    # Compute transmittance: T_i = Ξ _{j<i} (1 - Ξ±_j)
    # Use cumulative product: T_i = exp(-Ξ£_{j<i} Οƒ_j Β· Ξ΄_j)
    transmittance = torch.cumprod(
        torch.cat([
            torch.ones_like(alpha[..., :1]),
            1.0 - alpha[..., :-1] + 1e-10  # Avoid log(0)
        ], dim=-1),
        dim=-1
    )  # (N_rays, N_samples)
    
    # Compute weights: w_i = T_i Β· Ξ±_i
    weights = transmittance * alpha  # (N_rays, N_samples)
    
    # Render RGB: C(r) = Ξ£ w_i Β· c_i
    rgb_map = torch.sum(weights[..., None] * rgb, dim=-2)  # (N_rays, 3)
    
    # Compute depth map: E[t] = Ξ£ w_i Β· t_i
    depth_map = torch.sum(weights * t_vals, dim=-1)  # (N_rays,)
    
    # Accumulated opacity (for background composition)
    acc_map = torch.sum(weights, dim=-1)  # (N_rays,)
    
    return rgb_map, depth_map, acc_map, weights


# ============================================================================
# 4. RAY GENERATION
# ============================================================================

def get_rays(
    H: int,
    W: int,
    focal: float,
    c2w: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate camera rays for all pixels.
    
    Args:
        H: Image height
        W: Image width
        focal: Focal length
        c2w: Camera-to-world transformation matrix (3, 4) or (4, 4)
        
    Returns:
        rays_o: Ray origins (H, W, 3)
        rays_d: Ray directions (H, W, 3)
    """
    # Pixel coordinates
    i, j = torch.meshgrid(
        torch.arange(W, dtype=torch.float32),
        torch.arange(H, dtype=torch.float32),
        indexing='xy'
    )
    
    # Normalized image coordinates (assuming principal point at center)
    dirs = torch.stack([
        (i - W * 0.5) / focal,
        -(j - H * 0.5) / focal,  # Flip y-axis
        -torch.ones_like(i)
    ], dim=-1)  # (H, W, 3)
    
    # Transform ray directions from camera to world coordinates
    rays_d = torch.sum(dirs[..., None, :] * c2w[:3, :3], dim=-1)  # (H, W, 3)
    
    # Ray origin is camera center
    rays_o = c2w[:3, -1].expand(rays_d.shape)  # (H, W, 3)
    
    return rays_o, rays_d


# ============================================================================
# 5. SAMPLING STRATEGIES
# ============================================================================

def sample_along_rays(
    rays_o: torch.Tensor,
    rays_d: torch.Tensor,
    near: float,
    far: float,
    n_samples: int,
    perturb: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample points along rays using stratified sampling.
    
    Args:
        rays_o: Ray origins (N_rays, 3)
        rays_d: Ray directions (N_rays, 3)
        near: Near plane distance
        far: Far plane distance
        n_samples: Number of samples per ray
        perturb: Whether to add random perturbations
        
    Returns:
        pts: Sampled 3D points (N_rays, N_samples, 3)
        t_vals: Sample positions along rays (N_rays, N_samples)
    """
    N_rays = rays_o.shape[0]
    device = rays_o.device
    
    # Linearly spaced samples in [near, far]
    t_vals = torch.linspace(0.0, 1.0, n_samples, device=device)
    t_vals = near + (far - near) * t_vals  # (N_samples,)
    
    # Stratified sampling (add random jitter to each bin)
    if perturb:
        mids = 0.5 * (t_vals[:-1] + t_vals[1:])
        upper = torch.cat([mids, t_vals[-1:]], dim=0)
        lower = torch.cat([t_vals[:1], mids], dim=0)
        t_rand = torch.rand(N_rays, n_samples, device=device)
        t_vals = lower + (upper - lower) * t_rand  # (N_rays, N_samples)
    else:
        t_vals = t_vals.expand(N_rays, n_samples)
    
    # Compute 3D points: p = o + t * d
    pts = rays_o[..., None, :] + rays_d[..., None, :] * t_vals[..., :, None]
    # (N_rays, N_samples, 3)
    
    return pts, t_vals


def sample_pdf(
    bins: torch.Tensor,
    weights: torch.Tensor,
    n_samples: int,
    perturb: bool = True
) -> torch.Tensor:
    """
    Sample from piecewise-constant PDF (for hierarchical sampling).
    
    Args:
        bins: Bin edges (N_rays, N_bins+1)
        weights: Weights for each bin (N_rays, N_bins)
        n_samples: Number of samples to draw
        perturb: Whether to add random perturbations
        
    Returns:
        samples: Sampled positions (N_rays, n_samples)
    """
    # Normalize weights to get PDF
    weights = weights + 1e-5  # Prevent division by zero
    pdf = weights / torch.sum(weights, dim=-1, keepdim=True)  # (N_rays, N_bins)
    
    # Compute CDF
    cdf = torch.cumsum(pdf, dim=-1)  # (N_rays, N_bins)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1)  # (N_rays, N_bins+1)
    
    # Sample uniformly from [0, 1]
    if perturb:
        u = torch.rand(cdf.shape[0], n_samples, device=cdf.device)
    else:
        u = torch.linspace(0.0, 1.0, n_samples, device=cdf.device)
        u = u.expand(cdf.shape[0], n_samples)
    
    # Invert CDF to get samples
    indices = torch.searchsorted(cdf, u, right=True)  # (N_rays, n_samples)
    below = torch.clamp(indices - 1, min=0)
    above = torch.clamp(indices, max=cdf.shape[-1] - 1)
    
    # Gather CDF values
    cdf_below = torch.gather(cdf, 1, below)
    cdf_above = torch.gather(cdf, 1, above)
    
    # Gather bin edges
    bins_below = torch.gather(bins, 1, below)
    bins_above = torch.gather(bins, 1, above)
    
    # Linear interpolation
    denom = cdf_above - cdf_below
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_below) / denom
    samples = bins_below + t * (bins_above - bins_below)
    
    return samples


# ============================================================================
# 6. COMPLETE NERF MODEL
# ============================================================================

@dataclass
class NeRFConfig:
    """Configuration for NeRF model."""
    # Network architecture
    hidden_dim: int = 256
    num_layers: int = 8
    
    # Positional encoding
    pos_freq: int = 10  # L for position (60D output)
    dir_freq: int = 4   # L for direction (24D output)
    
    # Sampling
    n_samples_coarse: int = 64
    n_samples_fine: int = 128
    
    # Scene bounds
    near: float = 2.0
    far: float = 6.0
    
    # Training
    noise_std: float = 1.0  # Density noise during training
    perturb: bool = True    # Stratified sampling


class NeRF(nn.Module):
    """
    Complete NeRF model with coarse and fine networks.
    
    Args:
        config: NeRF configuration
    """
    def __init__(self, config: NeRFConfig):
        super().__init__()
        self.config = config
        
        # Positional encoders
        self.pos_encoder = PositionalEncoding(config.pos_freq, include_input=True)
        self.dir_encoder = PositionalEncoding(config.dir_freq, include_input=True)
        
        pos_enc_dim = self.pos_encoder.get_output_dim(3)  # 3D position
        dir_enc_dim = self.dir_encoder.get_output_dim(3)  # 3D direction
        
        # Coarse and fine networks
        self.nerf_coarse = NeRFMLP(pos_enc_dim, dir_enc_dim, config.hidden_dim, config.num_layers)
        self.nerf_fine = NeRFMLP(pos_enc_dim, dir_enc_dim, config.hidden_dim, config.num_layers)
        
    def forward(
        self,
        rays_o: torch.Tensor,
        rays_d: torch.Tensor,
        use_fine: bool = True
    ) -> Dict[str, torch.Tensor]:
        """
        Render rays through the NeRF model.
        
        Args:
            rays_o: Ray origins (N_rays, 3)
            rays_d: Ray directions (N_rays, 3)
            use_fine: Whether to use hierarchical sampling and fine network
            
        Returns:
            Dictionary containing:
            - rgb_coarse: Coarse RGB (N_rays, 3)
            - rgb_fine: Fine RGB (N_rays, 3) if use_fine=True
            - depth_coarse: Coarse depth (N_rays,)
            - depth_fine: Fine depth (N_rays,) if use_fine=True
        """
        # Normalize ray directions
        rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
        
        # === Coarse Network ===
        # Sample points along rays
        pts_coarse, t_vals_coarse = sample_along_rays(
            rays_o, rays_d,
            self.config.near, self.config.far,
            self.config.n_samples_coarse,
            perturb=self.config.perturb and self.training
        )
        
        # Encode positions and directions
        N_rays, N_samples = pts_coarse.shape[:2]
        pts_flat = pts_coarse.reshape(-1, 3)
        dirs_flat = rays_d[:, None, :].expand(N_rays, N_samples, 3).reshape(-1, 3)
        
        pts_encoded = self.pos_encoder(pts_flat)
        dirs_encoded = self.dir_encoder(dirs_flat)
        
        # Query coarse network
        rgb_coarse, density_coarse = self.nerf_coarse(pts_encoded, dirs_encoded)
        rgb_coarse = rgb_coarse.reshape(N_rays, N_samples, 3)
        density_coarse = density_coarse.reshape(N_rays, N_samples, 1)
        
        # Volume rendering
        noise_std = self.config.noise_std if self.training else 0.0
        rgb_map_coarse, depth_map_coarse, acc_map_coarse, weights_coarse = volume_rendering(
            rgb_coarse, density_coarse, t_vals_coarse, noise_std
        )
        
        outputs = {
            'rgb_coarse': rgb_map_coarse,
            'depth_coarse': depth_map_coarse,
            'acc_coarse': acc_map_coarse
        }
        
        if not use_fine:
            return outputs
        
        # === Fine Network (Hierarchical Sampling) ===
        # Sample additional points based on coarse weights
        with torch.no_grad():
            t_vals_mid = 0.5 * (t_vals_coarse[..., :-1] + t_vals_coarse[..., 1:])
            t_samples_fine = sample_pdf(
                t_vals_mid,
                weights_coarse[..., 1:-1],  # Exclude first and last weights
                self.config.n_samples_fine,
                perturb=self.config.perturb and self.training
            )
        
        # Combine coarse and fine samples, sort by depth
        t_vals_fine, _ = torch.sort(torch.cat([t_vals_coarse, t_samples_fine], dim=-1), dim=-1)
        
        # Compute 3D points for fine samples
        pts_fine = rays_o[..., None, :] + rays_d[..., None, :] * t_vals_fine[..., :, None]
        
        # Encode and query fine network
        N_samples_fine = pts_fine.shape[1]
        pts_flat = pts_fine.reshape(-1, 3)
        dirs_flat = rays_d[:, None, :].expand(N_rays, N_samples_fine, 3).reshape(-1, 3)
        
        pts_encoded = self.pos_encoder(pts_flat)
        dirs_encoded = self.dir_encoder(dirs_flat)
        
        rgb_fine, density_fine = self.nerf_fine(pts_encoded, dirs_encoded)
        rgb_fine = rgb_fine.reshape(N_rays, N_samples_fine, 3)
        density_fine = density_fine.reshape(N_rays, N_samples_fine, 1)
        
        # Volume rendering for fine network
        rgb_map_fine, depth_map_fine, acc_map_fine, _ = volume_rendering(
            rgb_fine, density_fine, t_vals_fine, noise_std
        )
        
        outputs.update({
            'rgb_fine': rgb_map_fine,
            'depth_fine': depth_map_fine,
            'acc_fine': acc_map_fine
        })
        
        return outputs


# ============================================================================
# 7. TRAINING UTILITIES
# ============================================================================

def nerf_loss(outputs: Dict[str, torch.Tensor], target_rgb: torch.Tensor) -> torch.Tensor:
    """
    NeRF loss function (MSE on coarse and fine renderings).
    
    Args:
        outputs: Model outputs containing rgb_coarse and optionally rgb_fine
        target_rgb: Ground truth RGB (N_rays, 3)
        
    Returns:
        total_loss: Combined loss
    """
    loss_coarse = F.mse_loss(outputs['rgb_coarse'], target_rgb)
    
    if 'rgb_fine' in outputs:
        loss_fine = F.mse_loss(outputs['rgb_fine'], target_rgb)
        total_loss = loss_coarse + loss_fine
    else:
        total_loss = loss_coarse
    
    return total_loss


def train_step(
    model: NeRF,
    optimizer: torch.optim.Optimizer,
    rays_o: torch.Tensor,
    rays_d: torch.Tensor,
    target_rgb: torch.Tensor,
    chunk_size: int = 1024
) -> float:
    """
    Single training step with ray batching.
    
    Args:
        model: NeRF model
        optimizer: Optimizer
        rays_o: Ray origins (N_rays, 3)
        rays_d: Ray directions (N_rays, 3)
        target_rgb: Target RGB (N_rays, 3)
        chunk_size: Process rays in chunks to save memory
        
    Returns:
        loss: Training loss
    """
    model.train()
    optimizer.zero_grad()
    
    N_rays = rays_o.shape[0]
    all_outputs = []
    
    # Process rays in chunks
    for i in range(0, N_rays, chunk_size):
        chunk_rays_o = rays_o[i:i+chunk_size]
        chunk_rays_d = rays_d[i:i+chunk_size]
        
        outputs = model(chunk_rays_o, chunk_rays_d, use_fine=True)
        all_outputs.append(outputs)
    
    # Combine outputs
    combined_outputs = {
        'rgb_coarse': torch.cat([o['rgb_coarse'] for o in all_outputs], dim=0),
        'rgb_fine': torch.cat([o['rgb_fine'] for o in all_outputs], dim=0)
    }
    
    # Compute loss
    loss = nerf_loss(combined_outputs, target_rgb)
    
    # Backward pass
    loss.backward()
    optimizer.step()
    
    return loss.item()


# ============================================================================
# 8. VISUALIZATION UTILITIES
# ============================================================================

def render_full_image(
    model: NeRF,
    H: int,
    W: int,
    focal: float,
    c2w: torch.Tensor,
    chunk_size: int = 1024,
    use_fine: bool = True
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Render a full image from a camera pose.
    
    Args:
        model: NeRF model
        H: Image height
        W: Image width
        focal: Focal length
        c2w: Camera-to-world matrix (3, 4) or (4, 4)
        chunk_size: Rays per chunk
        use_fine: Use fine network
        
    Returns:
        rgb: Rendered RGB image (H, W, 3)
        depth: Depth map (H, W)
    """
    model.eval()
    
    # Generate rays
    rays_o, rays_d = get_rays(H, W, focal, c2w)
    rays_o = rays_o.reshape(-1, 3)
    rays_d = rays_d.reshape(-1, 3)
    
    all_rgb = []
    all_depth = []
    
    with torch.no_grad():
        for i in range(0, rays_o.shape[0], chunk_size):
            chunk_rays_o = rays_o[i:i+chunk_size]
            chunk_rays_d = rays_d[i:i+chunk_size]
            
            outputs = model(chunk_rays_o, chunk_rays_d, use_fine=use_fine)
            
            if use_fine:
                all_rgb.append(outputs['rgb_fine'].cpu())
                all_depth.append(outputs['depth_fine'].cpu())
            else:
                all_rgb.append(outputs['rgb_coarse'].cpu())
                all_depth.append(outputs['depth_coarse'].cpu())
    
    rgb = torch.cat(all_rgb, dim=0).reshape(H, W, 3).numpy()
    depth = torch.cat(all_depth, dim=0).reshape(H, W).numpy()
    
    return rgb, depth


def visualize_nerf_rendering(
    rgb: np.ndarray,
    depth: np.ndarray,
    title: str = "NeRF Rendering"
):
    """
    Visualize NeRF rendering (RGB and depth).
    
    Args:
        rgb: RGB image (H, W, 3)
        depth: Depth map (H, W)
        title: Plot title
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    
    # RGB
    axes[0].imshow(rgb)
    axes[0].set_title('RGB Rendering')
    axes[0].axis('off')
    
    # Depth
    im = axes[1].imshow(depth, cmap='turbo')
    axes[1].set_title('Depth Map')
    axes[1].axis('off')
    plt.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)
    
    plt.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


def create_spiral_path(
    n_poses: int = 40,
    radius: float = 3.0,
    height: float = 0.0,
    focal_point: np.ndarray = np.array([0, 0, 0])
) -> List[torch.Tensor]:
    """
    Create spiral camera path for novel view synthesis.
    
    Args:
        n_poses: Number of poses
        radius: Spiral radius
        height: Height variation
        focal_point: Point to look at
        
    Returns:
        List of camera-to-world matrices
    """
    poses = []
    
    for i in range(n_poses):
        theta = 2 * np.pi * i / n_poses
        
        # Camera position (spiral)
        x = radius * np.cos(theta)
        y = radius * np.sin(theta)
        z = height * np.sin(2 * theta)
        
        cam_pos = np.array([x, y, z])
        
        # Look-at matrix
        forward = focal_point - cam_pos
        forward = forward / np.linalg.norm(forward)
        
        right = np.cross(np.array([0, 0, 1]), forward)
        right = right / np.linalg.norm(right)
        
        up = np.cross(forward, right)
        
        # Camera-to-world matrix
        c2w = np.eye(4)
        c2w[:3, 0] = right
        c2w[:3, 1] = up
        c2w[:3, 2] = -forward
        c2w[:3, 3] = cam_pos
        
        poses.append(torch.tensor(c2w, dtype=torch.float32))
    
    return poses


# ============================================================================
# 9. DEMONSTRATION
# ============================================================================

def demo_nerf_components():
    """Demonstrate NeRF components."""
    
    print("=" * 80)
    print("NeRF Components Demonstration")
    print("=" * 80)
    
    # 1. Positional Encoding
    print("\n1. Positional Encoding")
    print("-" * 40)
    
    pos_encoder = PositionalEncoding(num_freqs=10, include_input=True)
    sample_pos = torch.tensor([[0.5, 0.5, 0.5]])
    
    encoded = pos_encoder(sample_pos)
    print(f"Input shape: {sample_pos.shape}")
    print(f"Encoded shape: {encoded.shape}")
    print(f"Encoding dimension: 3 β†’ {encoded.shape[1]}")
    
    # 2. NeRF MLP
    print("\n2. NeRF MLP Architecture")
    print("-" * 40)
    
    config = NeRFConfig()
    nerf = NeRF(config)
    
    print(f"Position encoding: 3D β†’ {nerf.pos_encoder.get_output_dim(3)}D")
    print(f"Direction encoding: 3D β†’ {nerf.dir_encoder.get_output_dim(3)}D")
    print(f"Coarse network parameters: {sum(p.numel() for p in nerf.nerf_coarse.parameters()):,}")
    print(f"Fine network parameters: {sum(p.numel() for p in nerf.nerf_fine.parameters()):,}")
    print(f"Total parameters: {sum(p.numel() for p in nerf.parameters()):,}")
    
    # 3. Ray Generation
    print("\n3. Ray Generation")
    print("-" * 40)
    
    H, W, focal = 100, 100, 138.88
    c2w = torch.eye(4)
    c2w[:3, 3] = torch.tensor([0, 0, 3])  # Camera at z=3
    
    rays_o, rays_d = get_rays(H, W, focal, c2w)
    print(f"Image size: {H} Γ— {W}")
    print(f"Ray origins shape: {rays_o.shape}")
    print(f"Ray directions shape: {rays_d.shape}")
    print(f"All rays originate from: {rays_o[0, 0]}")
    
    # 4. Sampling
    print("\n4. Sampling Along Rays")
    print("-" * 40)
    
    sample_rays_o = rays_o[H//2, W//2:W//2+1]  # Center pixel
    sample_rays_d = rays_d[H//2, W//2:W//2+1]
    
    pts, t_vals = sample_along_rays(
        sample_rays_o, sample_rays_d,
        near=2.0, far=6.0, n_samples=64, perturb=False
    )
    
    print(f"Number of samples: {t_vals.shape[1]}")
    print(f"Sample range: [{t_vals.min():.2f}, {t_vals.max():.2f}]")
    print(f"Sampled points shape: {pts.shape}")
    
    # 5. Volume Rendering
    print("\n5. Volume Rendering")
    print("-" * 40)
    
    # Dummy RGB and density (Gaussian blob)
    dummy_rgb = torch.rand(1, 64, 3)
    
    # Create Gaussian density centered at middle of ray
    t_center = (2.0 + 6.0) / 2
    dummy_density = torch.exp(-0.5 * ((t_vals - t_center) / 0.5) ** 2).unsqueeze(-1)
    
    rgb_map, depth_map, acc_map, weights = volume_rendering(
        dummy_rgb, dummy_density, t_vals, noise_std=0.0
    )
    
    print(f"Rendered RGB: {rgb_map[0]}")
    print(f"Rendered depth: {depth_map[0]:.2f}")
    print(f"Accumulated opacity: {acc_map[0]:.4f}")
    print(f"Weights sum: {weights.sum(dim=1)[0]:.4f}")
    
    # 6. Forward Pass
    print("\n6. NeRF Forward Pass")
    print("-" * 40)
    
    batch_rays_o = rays_o.reshape(-1, 3)[:1024]
    batch_rays_d = rays_d.reshape(-1, 3)[:1024]
    
    outputs = nerf(batch_rays_o, batch_rays_d, use_fine=True)
    
    print(f"Coarse RGB shape: {outputs['rgb_coarse'].shape}")
    print(f"Fine RGB shape: {outputs['rgb_fine'].shape}")
    print(f"Coarse depth range: [{outputs['depth_coarse'].min():.2f}, {outputs['depth_coarse'].max():.2f}]")
    print(f"Fine depth range: [{outputs['depth_fine'].min():.2f}, {outputs['depth_fine'].max():.2f}]")


def demo_hierarchical_sampling():
    """Demonstrate hierarchical sampling."""
    
    print("\n" + "=" * 80)
    print("Hierarchical Sampling Demonstration")
    print("=" * 80)
    
    # Create simple scenario
    t_vals_coarse = torch.linspace(2.0, 6.0, 65).unsqueeze(0)  # (1, 65)
    
    # Weights concentrated in middle (simulating object)
    weights = torch.exp(-0.5 * ((t_vals_coarse[:, :-1] - 4.0) / 0.5) ** 2)
    weights = weights / weights.sum(dim=1, keepdim=True)  # Normalize
    
    # Sample fine points
    t_samples_fine = sample_pdf(
        t_vals_coarse[:, :-1],  # Bins
        weights[:, 1:],          # Weights (excluding first)
        n_samples=128,
        perturb=False
    )
    
    # Visualize sampling distribution
    fig, axes = plt.subplots(2, 1, figsize=(12, 8))
    
    # Coarse weights
    axes[0].bar(
        t_vals_coarse[0, :-1].numpy(),
        weights[0].numpy(),
        width=0.06,
        alpha=0.7,
        label='Coarse weights'
    )
    axes[0].set_xlabel('Ray depth (t)')
    axes[0].set_ylabel('Weight')
    axes[0].set_title('Coarse Sampling Weights')
    axes[0].legend()
    axes[0].grid(alpha=0.3)
    
    # Fine sampling distribution
    axes[1].hist(t_samples_fine[0].numpy(), bins=50, alpha=0.7, label='Fine samples')
    axes[1].axvline(4.0, color='red', linestyle='--', label='Peak density')
    axes[1].set_xlabel('Ray depth (t)')
    axes[1].set_ylabel('Sample count')
    axes[1].set_title('Hierarchical Fine Sampling (More samples where weights are high)')
    axes[1].legend()
    axes[1].grid(alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\nCoarse samples: {t_vals_coarse.shape[1] - 1}")
    print(f"Fine samples: {t_samples_fine.shape[1]}")
    print(f"Total samples (combined): {t_vals_coarse.shape[1] - 1 + t_samples_fine.shape[1]}")
    print(f"Fine samples concentrated around t={t_samples_fine.mean():.2f} (target: 4.0)")


def demo_positional_encoding_effect():
    """Visualize effect of positional encoding."""
    
    print("\n" + "=" * 80)
    print("Positional Encoding Effect")
    print("=" * 80)
    
    # Create 1D coordinate
    x = torch.linspace(-1, 1, 200).unsqueeze(-1)
    
    # Different frequency levels
    freq_levels = [0, 2, 5, 10]
    
    fig, axes = plt.subplots(len(freq_levels), 1, figsize=(12, 10))
    
    for i, L in enumerate(freq_levels):
        if L == 0:
            # No encoding
            encoded = x
            title = "No Encoding (Raw Input)"
        else:
            encoder = PositionalEncoding(num_freqs=L, include_input=False)
            encoded = encoder(x)
            title = f"Positional Encoding (L={L}, {encoded.shape[1]} features)"
        
        # Visualize first few encoding dimensions
        axes[i].plot(x.numpy(), encoded[:, :min(10, encoded.shape[1])].numpy(), alpha=0.7)
        axes[i].set_title(title)
        axes[i].set_xlabel('Input coordinate')
        axes[i].set_ylabel('Encoded value')
        axes[i].grid(alpha=0.3)
        axes[i].set_xlim([-1, 1])
    
    plt.tight_layout()
    plt.show()
    
    print("\nPositional encoding projects inputs to high-frequency space.")
    print("This allows MLPs to learn high-frequency details in the scene.")


# Run demonstrations
if __name__ == "__main__":
    print("Starting NeRF demonstrations...\n")
    
    # Component demonstration
    demo_nerf_components()
    
    # Hierarchical sampling
    demo_hierarchical_sampling()
    
    # Positional encoding visualization
    demo_positional_encoding_effect()
    
    print("\n" + "=" * 80)
    print("NeRF Implementation Complete!")
    print("=" * 80)
    print("\nKey Features Implemented:")
    print("βœ“ Positional encoding for high-frequency details")
    print("βœ“ NeRF MLP with skip connections")
    print("βœ“ Volumetric rendering with proper transmittance")
    print("βœ“ Hierarchical sampling (coarse + fine networks)")
    print("βœ“ Ray generation and batching")
    print("βœ“ Training utilities")
    print("βœ“ Visualization functions")
    print("\nThis implementation can be extended with:")
    print("- Actual training on real datasets (Blender scenes, LLFF)")
    print("- Mip-NeRF cone tracing")
    print("- Instant NGP hash encoding")
    print("- Dynamic scenes (D-NeRF)")
    print("- Semantic outputs")
    print("- Text-to-3D (DreamFusion SDS)")
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

1. Volume RenderingΒΆ

Ray EquationΒΆ

\[\mathbf{r}(t) = \mathbf{o} + t\mathbf{d}\]

where \(\mathbf{o}\) is origin, \(\mathbf{d}\) is direction.

Volume Rendering EquationΒΆ

\[C(\mathbf{r}) = \int_{t_n}^{t_f} T(t) \cdot \sigma(\mathbf{r}(t)) \cdot \mathbf{c}(\mathbf{r}(t), \mathbf{d}) dt\]

where:

  • \(T(t) = \exp\left(-\int_{t_n}^t \sigma(\mathbf{r}(s)) ds\right)\) (transmittance)

  • \(\sigma\) is volume density

  • \(\mathbf{c}\) is RGB color

πŸ“š Reference Materials:

2. Positional EncodingΒΆ

Mapping to Higher DimensionsΒΆ

\[\gamma(p) = \left[\sin(2^0\pi p), \cos(2^0\pi p), \ldots, \sin(2^{L-1}\pi p), \cos(2^{L-1}\pi p)\right]\]

Enables high-frequency representation.

class PositionalEncoding(nn.Module):
    """Positional encoding for coordinates."""
    
    def __init__(self, num_freqs=10, include_input=True):
        super().__init__()
        self.num_freqs = num_freqs
        self.include_input = include_input
        
        freq_bands = 2.0 ** torch.linspace(0, num_freqs - 1, num_freqs)
        self.register_buffer('freq_bands', freq_bands)
    
    def forward(self, x):
        """
        Args:
            x: [..., D] input coordinates
        Returns:
            [..., D * (2 * num_freqs + 1)] if include_input
        """
        out = []
        if self.include_input:
            out.append(x)
        
        for freq in self.freq_bands:
            out.append(torch.sin(freq * np.pi * x))
            out.append(torch.cos(freq * np.pi * x))
        
        return torch.cat(out, dim=-1)
    
    def output_dim(self, input_dim):
        return input_dim * (2 * self.num_freqs + int(self.include_input))

# Test
pe = PositionalEncoding(num_freqs=6)
x = torch.tensor([[0.5, 0.3, 0.8]])
encoded = pe(x)
print(f"Input: {x.shape}, Encoded: {encoded.shape}")

NeRF NetworkΒΆ

The NeRF network is an MLP that maps a 5D input – 3D position \((x, y, z)\) and 2D viewing direction \((\theta, \phi)\) – to a color \((r, g, b)\) and volume density \(\sigma\). Positional encoding (sinusoidal features at multiple frequencies) is applied to the input coordinates to help the network represent high-frequency spatial detail. The density depends only on position (ensuring multi-view consistency), while color depends on both position and direction (modeling view-dependent effects like specularities). This simple architecture, when trained on a set of posed images, learns an implicit 3D representation of the scene that can be queried at any continuous point in space.

class NeRF(nn.Module):
    """Neural Radiance Field."""
    
    def __init__(self, pos_enc_dim=63, dir_enc_dim=27, hidden_dim=256):
        super().__init__()
        
        # Position encoding
        self.pos_encoder = PositionalEncoding(num_freqs=10)
        self.dir_encoder = PositionalEncoding(num_freqs=4)
        
        # Position network (x, y, z) -> features + density
        self.pos_net = nn.Sequential(
            nn.Linear(pos_enc_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        
        # Skip connection
        self.pos_net2 = nn.Sequential(
            nn.Linear(hidden_dim + pos_enc_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        
        # Density head
        self.density_head = nn.Sequential(
            nn.Linear(hidden_dim, 1),
            nn.ReLU()  # Density must be non-negative
        )
        
        # Feature extraction
        self.feature_head = nn.Linear(hidden_dim, hidden_dim)
        
        # Direction network (features + direction) -> RGB
        self.dir_net = nn.Sequential(
            nn.Linear(hidden_dim + dir_enc_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 3),
            nn.Sigmoid()  # RGB in [0, 1]
        )
    
    def forward(self, pos, direction):
        """
        Args:
            pos: [..., 3] 3D positions
            direction: [..., 3] view directions
        Returns:
            rgb: [..., 3]
            density: [..., 1]
        """
        # Encode
        pos_enc = self.pos_encoder(pos)
        dir_enc = self.dir_encoder(direction)
        
        # Position features
        h = self.pos_net(pos_enc)
        h = torch.cat([h, pos_enc], dim=-1)
        h = self.pos_net2(h)
        
        # Density
        density = self.density_head(h)
        
        # Features for color
        features = self.feature_head(h)
        
        # RGB (view-dependent)
        h_dir = torch.cat([features, dir_enc], dim=-1)
        rgb = self.dir_net(h_dir)
        
        return rgb, density

# Test
model = NeRF().to(device)
pos = torch.randn(10, 3).to(device)
direction = F.normalize(torch.randn(10, 3), dim=-1).to(device)
rgb, density = model(pos, direction)
print(f"RGB: {rgb.shape}, Density: {density.shape}")

Volume Rendering (Discrete)ΒΆ

To render an image from the NeRF representation, we cast a ray through each pixel and evaluate the network at sampled points along the ray. The discrete volume rendering equation accumulates color weighted by density and transmittance: \(C(r) = \sum_{i=1}^{N} T_i (1 - \exp(-\sigma_i \delta_i)) c_i\), where \(T_i = \exp(-\sum_{j<i} \sigma_j \delta_j)\) is the accumulated transmittance. Points with high density and low prior absorption contribute most to the final pixel color. This differentiable rendering process allows end-to-end training: the rendered pixel colors are compared to ground truth via MSE loss, and gradients flow back through the rendering equation into the MLP weights.

def volume_render(rgb, density, t_vals):
    """
    Discrete volume rendering.
    
    Args:
        rgb: [N_rays, N_samples, 3]
        density: [N_rays, N_samples, 1]
        t_vals: [N_rays, N_samples]
    Returns:
        rgb_map: [N_rays, 3]
        depth_map: [N_rays]
        weights: [N_rays, N_samples]
    """
    # Delta
    dists = t_vals[..., 1:] - t_vals[..., :-1]
    dists = torch.cat([dists, torch.full_like(dists[..., :1], 1e10)], dim=-1)
    
    # Alpha
    alpha = 1.0 - torch.exp(-density.squeeze(-1) * dists)
    
    # Transmittance
    T = torch.cumprod(torch.cat([
        torch.ones_like(alpha[..., :1]),
        1.0 - alpha[..., :-1] + 1e-10
    ], dim=-1), dim=-1)
    
    # Weights
    weights = alpha * T
    
    # RGB
    rgb_map = torch.sum(weights[..., None] * rgb, dim=-2)
    
    # Depth
    depth_map = torch.sum(weights * t_vals, dim=-1)
    
    return rgb_map, depth_map, weights

# Test
N_rays, N_samples = 100, 64
rgb_test = torch.rand(N_rays, N_samples, 3)
density_test = torch.rand(N_rays, N_samples, 1)
t_vals_test = torch.linspace(2, 6, N_samples).expand(N_rays, N_samples)

rgb_map, depth_map, weights = volume_render(rgb_test, density_test, t_vals_test)
print(f"RGB map: {rgb_map.shape}, Depth map: {depth_map.shape}")

Simple Scene DemoΒΆ

Before tackling real photographs, we create a simple synthetic scene to verify the NeRF pipeline works end-to-end. A few geometric primitives (spheres, boxes) with known colors are rendered from multiple viewpoints to create a small training set of posed images. Training the NeRF network on this synthetic data should reproduce the scene geometry and appearance accurately, providing a controlled test bed for debugging the positional encoding, volume rendering, and optimization before moving to complex real-world scenes.

def create_synthetic_scene(model, n_iters=500):
    """Train NeRF on simple synthetic scene."""
    # Define simple scene: colored sphere at origin
    def scene_sdf(pos):
        """Sphere at origin with radius 0.5."""
        return torch.norm(pos, dim=-1, keepdim=True) - 0.5
    
    def scene_color(pos):
        """Color based on position."""
        return (pos + 1.0) / 2.0  # Normalize to [0, 1]
    
    # Camera rays
    H, W = 50, 50
    focal = 50
    
    # Image plane coordinates
    i, j = torch.meshgrid(
        torch.arange(H, dtype=torch.float32),
        torch.arange(W, dtype=torch.float32),
        indexing='ij'
    )
    
    # Ray directions in camera coordinates
    dirs = torch.stack([
        (j - W / 2) / focal,
        -(i - H / 2) / focal,
        -torch.ones_like(i)
    ], dim=-1)
    
    dirs = F.normalize(dirs, dim=-1)
    rays_d = dirs.reshape(-1, 3).to(device)
    rays_o = torch.zeros_like(rays_d).to(device)
    rays_o[:, 2] = 3.0  # Camera at z=3
    
    # Sample points along rays
    N_samples = 64
    t_vals = torch.linspace(2.0, 4.0, N_samples, device=device)
    
    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
    losses = []
    
    for iter in range(n_iters):
        # Sample batch of rays
        batch_size = 256
        idx = torch.randint(0, len(rays_o), (batch_size,))
        
        rays_o_batch = rays_o[idx]
        rays_d_batch = rays_d[idx]
        
        # Points along rays
        pts = rays_o_batch[:, None, :] + rays_d_batch[:, None, :] * t_vals[None, :, None]
        
        # Query NeRF
        pts_flat = pts.reshape(-1, 3)
        dirs_flat = rays_d_batch[:, None, :].expand_as(pts).reshape(-1, 3)
        
        rgb, density = model(pts_flat, dirs_flat)
        rgb = rgb.reshape(batch_size, N_samples, 3)
        density = density.reshape(batch_size, N_samples, 1)
        
        # Render
        rgb_map, _, _ = volume_render(rgb, density, t_vals.expand(batch_size, N_samples))
        
        # Ground truth (sphere)
        with torch.no_grad():
            sdf = scene_sdf(pts_flat).reshape(batch_size, N_samples, 1)
            gt_density = torch.sigmoid(-sdf * 50)  # Sharp boundary
            gt_rgb = scene_color(pts_flat).reshape(batch_size, N_samples, 3)
            gt_rgb_map, _, _ = volume_render(gt_rgb, gt_density, t_vals.expand(batch_size, N_samples))
        
        # Loss
        loss = F.mse_loss(rgb_map, gt_rgb_map)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        losses.append(loss.item())
        
        if (iter + 1) % 100 == 0:
            print(f"Iter {iter+1}, Loss: {loss.item():.6f}")
    
    return losses

# Train
model = NeRF().to(device)
losses = create_synthetic_scene(model, n_iters=500)

plt.figure(figsize=(10, 5))
plt.plot(losses)
plt.xlabel('Iteration', fontsize=11)
plt.ylabel('MSE Loss', fontsize=11)
plt.title('NeRF Training on Synthetic Sphere', fontsize=12)
plt.grid(True, alpha=0.3)
plt.yscale('log')
plt.show()

Render Novel ViewΒΆ

The primary application of NeRF is novel view synthesis: rendering the scene from camera positions not present in the training set. By specifying a new camera pose (position and orientation), casting rays through each pixel, and evaluating the trained network, we obtain a photorealistic image of the scene from an entirely new viewpoint. The quality of novel views – especially in regions that were only partially observed during training – is the definitive test of whether the network has learned a coherent 3D representation rather than merely memorizing the training views.

def render_image(model, H=100, W=100, focal=100, cam_pos=[0, 0, 3]):
    """Render full image."""
    model.eval()
    
    i, j = torch.meshgrid(
        torch.arange(H, dtype=torch.float32),
        torch.arange(W, dtype=torch.float32),
        indexing='ij'
    )
    
    dirs = torch.stack([
        (j - W / 2) / focal,
        -(i - H / 2) / focal,
        -torch.ones_like(i)
    ], dim=-1)
    
    dirs = F.normalize(dirs, dim=-1)
    rays_d = dirs.reshape(-1, 3).to(device)
    rays_o = torch.tensor(cam_pos, dtype=torch.float32, device=device).expand_as(rays_d)
    
    N_samples = 64
    t_vals = torch.linspace(2.0, 4.0, N_samples, device=device)
    
    # Render in batches
    batch_size = 512
    rgb_full = []
    
    with torch.no_grad():
        for i in range(0, len(rays_o), batch_size):
            rays_o_batch = rays_o[i:i+batch_size]
            rays_d_batch = rays_d[i:i+batch_size]
            
            pts = rays_o_batch[:, None, :] + rays_d_batch[:, None, :] * t_vals[None, :, None]
            pts_flat = pts.reshape(-1, 3)
            dirs_flat = rays_d_batch[:, None, :].expand_as(pts).reshape(-1, 3)
            
            rgb, density = model(pts_flat, dirs_flat)
            rgb = rgb.reshape(len(rays_o_batch), N_samples, 3)
            density = density.reshape(len(rays_o_batch), N_samples, 1)
            
            rgb_map, _, _ = volume_render(rgb, density, t_vals.expand(len(rays_o_batch), N_samples))
            rgb_full.append(rgb_map.cpu())
    
    rgb_full = torch.cat(rgb_full, dim=0)
    return rgb_full.reshape(H, W, 3).numpy()

# Render
img = render_image(model, H=100, W=100, cam_pos=[0, 0, 3])

plt.figure(figsize=(8, 8))
plt.imshow(img)
plt.axis('off')
plt.title('NeRF Rendered Image', fontsize=12)
plt.show()

SummaryΒΆ

NeRF Represents:ΒΆ

Scene as continuous 5D function: \((x, y, z, \theta, \phi) \to (r, g, b, \sigma)\)

Key Components:ΒΆ

  1. Positional encoding - High frequency details

  2. MLP network - Maps coordinates to color + density

  3. Volume rendering - Integrates along rays

  4. Hierarchical sampling - Coarse-to-fine

Applications:ΒΆ

  • Novel view synthesis

  • 3D reconstruction

  • Relighting

  • Scene editing

Extensions:ΒΆ

  • Instant-NGP (hash encoding)

  • Plenoxels (voxel grids)

  • NeRF-W (in-the-wild)

  • Dynamic NeRF (time-varying)

Next Steps:ΒΆ

  • Study instant-NGP for speed

  • Explore signed distance functions

  • Learn multi-view geometry