Image Segmentation: Pixel-Level Understanding

Segmentation assigns a class to every pixel — it's what makes self-driving cars understand roads from sidewalks. This notebook covers semantic segmentation concepts, the U-Net architecture, and practical usage of pretrained segmentation models.

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from scipy import ndimage
import warnings
warnings.filterwarnings('ignore')

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    HAS_TORCH = True
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
except ImportError:
    HAS_TORCH = False

np.random.seed(42)

# Create synthetic segmentation scene
def create_segmentation_scene(size=128):
    """Build a toy street scene and its per-pixel label mask.

    Returns (image, mask): image is (size, size, 3) uint8 RGB, mask is
    (size, size) int64 with classes 0=background, 1=road, 2=car,
    3=pedestrian. Gaussian pixel noise is added last, so the result
    depends on the global NumPy RNG state at call time.
    """
    canvas = np.zeros((size, size, 3), dtype=np.uint8)
    labels = np.zeros((size, size), dtype=np.int64)

    half = size // 2

    # Class 0 (sky): flat blue channel plus a vertical red gradient on the
    # top half; mask stays 0 there.
    canvas[:half, :, 2] = 180
    canvas[:half, :, 0] = np.linspace(100, 150, half)[:, np.newaxis]

    # Class 1 (road): uniform gray across the entire bottom half.
    canvas[half:, :, :] = 100
    labels[half:, :] = 1

    # Class 2 (car): red rectangle painted over the road.
    canvas[half:3 * size // 4, size // 4:3 * size // 4, :] = [180, 50, 50]
    labels[half:3 * size // 4, size // 4:3 * size // 4] = 2

    # Class 3 (pedestrian): thin green rectangle near the left edge.
    canvas[half + 10:half + 40, 10:25, :] = [50, 170, 50]
    labels[half + 10:half + 40, 10:25] = 3

    # Single Gaussian-noise draw over the whole image, then clamp to uint8.
    noisy = canvas.astype(np.float32) + np.random.normal(0, 10, canvas.shape)
    return np.clip(noisy, 0, 255).astype(np.uint8), labels

# Class ids 0..3 map positionally onto these display names and RGB colors;
# colorize_mask and the legend below index both lists by class id.
CLASS_NAMES = ['background', 'road', 'car', 'pedestrian']
CLASS_COLORS = [(100, 150, 180), (100, 100, 100), (180, 50, 50), (50, 170, 50)]

# Synthetic RGB scene plus its ground-truth per-pixel label mask.
scene_img, gt_mask = create_segmentation_scene()

def colorize_mask(mask, colors=CLASS_COLORS):
    """Map an integer class mask to an RGB uint8 image via `colors`.

    Pixels whose class id has no entry in `colors` remain black.
    """
    rgb = np.zeros(mask.shape + (3,), dtype=np.uint8)
    for class_id in range(len(colors)):
        rgb[mask == class_id] = colors[class_id]
    return rgb

# Show the raw image, its label mask, and a 50/50 alpha blend of the two.
fig, axes = plt.subplots(1, 3, figsize=(13, 4))
axes[0].imshow(scene_img)
axes[0].set_title('Input Image')
axes[0].axis('off')

axes[1].imshow(colorize_mask(gt_mask))
axes[1].set_title('Ground Truth Mask')
axes[1].axis('off')

# Overlay: equal-weight blend of image and colorized mask. Both operands
# are cast to float32 first to avoid uint8 overflow during the addition,
# then cast back to uint8 for imshow.
overlay = scene_img.copy().astype(np.float32)
colored = colorize_mask(gt_mask).astype(np.float32)
axes[2].imshow((0.5 * overlay + 0.5 * colored).astype(np.uint8))
axes[2].set_title('Overlay')
axes[2].axis('off')

# Hand-drawn legend on the mask panel: one colored label per class.
for i, (name, color) in enumerate(zip(CLASS_NAMES, CLASS_COLORS)):
    axes[1].text(2, 10 + i*15, f'β–  {name}', color=np.array(color)/255, fontsize=8)

plt.suptitle('Semantic Segmentation: Pixel-Level Classification')
plt.tight_layout()
plt.show()

# Class pixel frequencies: demonstrates the imbalance (background/road
# dominate; pedestrian is rare) referenced in the metrics section.
print(f'Image: {scene_img.shape}, Mask: {gt_mask.shape}')
for cls in range(4):
    pct = (gt_mask == cls).mean() * 100
    print(f'  Class {cls} ({CLASS_NAMES[cls]}): {pct:.1f}% of pixels')

1. Segmentation vs Detection vs Classification

# Task-comparison table, built row-wise and transposed into columns so the
# printed DataFrame matches the column order below.
_columns = ['Task', 'Output', 'Resolution', 'Distinguishes instances?', 'Models']
_rows = [
    ['Classification', 'Single label', 'Image-level', 'No', 'ResNet, ViT'],
    ['Object Detection', 'Boxes + labels', 'Object-level', 'Yes (by box)', 'YOLO, Faster R-CNN'],
    ['Semantic Segmentation', 'Pixel class mask', 'Pixel-level', 'No', 'FCN, U-Net, DeepLab'],
    ['Instance Segmentation', 'Pixel mask per object', 'Pixel-level', 'Yes', 'Mask R-CNN, SOLOv2'],
    ['Panoptic Segmentation', 'Full scene understanding', 'Pixel-level', 'Yes', 'Panoptic FPN'],
]
comparison = {col: [row[i] for row in _rows] for i, col in enumerate(_columns)}
import pandas as pd
df = pd.DataFrame(comparison)
print(df.to_string(index=False))
print()
print('Semantic vs Instance segmentation:')
print('  Semantic: all cars are class=2 (same color on mask)')
print('  Instance: car1 and car2 are different colors/IDs')

2. U-Net Architecture

if HAS_TORCH:
    class DoubleConv(nn.Module):
        """Two (Conv3x3 -> BatchNorm -> ReLU) stages.

        The first conv maps in_ch -> out_ch; the second keeps out_ch.
        padding=1 preserves spatial size, so only channel count changes.
        """
        def __init__(self, in_ch, out_ch):
            super().__init__()
            layers = []
            for c_in, c_out in ((in_ch, out_ch), (out_ch, out_ch)):
                layers += [
                    nn.Conv2d(c_in, c_out, 3, padding=1),
                    nn.BatchNorm2d(c_out),
                    nn.ReLU(),
                ]
            # Same six submodules in the same order as a hand-written
            # Sequential, so state_dict keys are net.0 .. net.5.
            self.net = nn.Sequential(*layers)

        def forward(self, x):
            return self.net(x)
    
    class UNet(nn.Module):
        """
        U-Net: encoder-decoder segmentation network with skip connections.

        The contracting path halves H/W (MaxPool) while doubling channels;
        the expanding path mirrors it with transposed convolutions. Encoder
        feature maps are concatenated onto the decoder at the matching
        resolution — which is why each decoder DoubleConv takes twice the
        channels its upsampler produces.
        """
        def __init__(self, in_channels=3, num_classes=4, base_ch=32):
            super().__init__()

            # Contracting path (resolution constant per stage, pooled between).
            self.enc1 = DoubleConv(in_channels, base_ch)
            self.enc2 = DoubleConv(base_ch, base_ch * 2)
            self.enc3 = DoubleConv(base_ch * 2, base_ch * 4)
            self.pool = nn.MaxPool2d(2)  # shared 2x2 downsampler

            # Deepest representation, reached after three pools.
            self.bottleneck = DoubleConv(base_ch * 4, base_ch * 8)

            # Expanding path: each ConvTranspose2d(k=2, stride=2) doubles H/W;
            # the following DoubleConv fuses [upsampled, skip], hence the
            # doubled input channel count.
            self.up3 = nn.ConvTranspose2d(base_ch * 8, base_ch * 4, 2, stride=2)
            self.dec3 = DoubleConv(base_ch * 8, base_ch * 4)

            self.up2 = nn.ConvTranspose2d(base_ch * 4, base_ch * 2, 2, stride=2)
            self.dec2 = DoubleConv(base_ch * 4, base_ch * 2)

            self.up1 = nn.ConvTranspose2d(base_ch * 2, base_ch, 2, stride=2)
            self.dec1 = DoubleConv(base_ch * 2, base_ch)

            # 1x1 conv producing one logit map per class.
            self.output_conv = nn.Conv2d(base_ch, num_classes, 1)

        def forward(self, x):
            # Encoder: keep each stage's output for the skip connections.
            skip1 = self.enc1(x)
            skip2 = self.enc2(self.pool(skip1))
            skip3 = self.enc3(self.pool(skip2))

            deep = self.bottleneck(self.pool(skip3))

            # Decoder: upsample, concatenate the matching skip, fuse.
            y = self.dec3(torch.cat([self.up3(deep), skip3], dim=1))
            y = self.dec2(torch.cat([self.up2(y), skip2], dim=1))
            y = self.dec1(torch.cat([self.up1(y), skip1], dim=1))

            # Per-pixel class logits: (B, num_classes, H, W).
            return self.output_conv(y)
    
    # Instantiate the network and report its size.
    unet = UNet(in_channels=3, num_classes=4, base_ch=32)
    total_params = sum(p.numel() for p in unet.parameters())
    print(f'U-Net parameters: {total_params:,}')

    # Sanity check: a forward pass must preserve H and W and emit one
    # channel of logits per class.
    sample = torch.randn(1, 3, 128, 128)
    logits = unet(sample)
    print(f'Input shape:  {sample.shape}')
    print(f'Output shape: {logits.shape}  β†’ (batch, classes, height, width)')
    print('Predicted class per pixel: logits.argmax(dim=1) β†’ (batch, H, W)')
else:
    print('U-Net Architecture:')
    print()
    print('Input (3, 128, 128)')
    print('  ↓ Encoder 1: DoubleConv β†’ (32, 128, 128)')
    print('  ↓ MaxPool   β†’ (32, 64, 64)')
    print('  ↓ Encoder 2: DoubleConv β†’ (64, 64, 64)')
    print('  ↓ MaxPool   β†’ (64, 32, 32)')
    print('  ↓ Encoder 3: DoubleConv β†’ (128, 32, 32)')
    print('  ↓ MaxPool   β†’ (128, 16, 16)')
    print('  Bottleneck: DoubleConv β†’ (256, 16, 16)')
    print('  ↑ Upsample + concat enc3 β†’ (128, 32, 32)')
    print('  ↑ Decoder 3: DoubleConv β†’ (128, 32, 32)')
    print('  ↑ Upsample + concat enc2 β†’ (64, 64, 64)')
    print('  ↑ Decoder 2: DoubleConv β†’ (64, 64, 64)')
    print('  ↑ Upsample + concat enc1 β†’ (32, 128, 128)')
    print('  ↑ Decoder 1: DoubleConv β†’ (32, 128, 128)')
    print('  Output: Conv1Γ—1 β†’ (4, 128, 128)  ← num_classes channels')

3. Segmentation Metrics

def compute_iou_per_class(pred_mask: np.ndarray, gt_mask: np.ndarray, num_classes: int,
                          class_names=None) -> dict:
    """Compute per-class IoU between a predicted and a ground-truth mask.

    Parameters
    ----------
    pred_mask, gt_mask : integer arrays of identical shape with values in
        [0, num_classes).
    num_classes : number of classes to evaluate.
    class_names : optional sequence of at least num_classes names used as
        dict keys; defaults to the module-level CLASS_NAMES, so existing
        callers are unchanged.

    Returns
    -------
    dict mapping class name -> IoU (float). A class absent from BOTH masks
    has an empty union and gets NaN, so np.nanmean over the values yields
    the mean IoU without penalizing absent classes.

    Raises
    ------
    ValueError : if the two masks have different shapes (previously this
        silently relied on NumPy broadcasting).
    """
    if pred_mask.shape != gt_mask.shape:
        raise ValueError(f'mask shape mismatch: {pred_mask.shape} vs {gt_mask.shape}')
    names = CLASS_NAMES if class_names is None else class_names
    ious = {}
    for cls in range(num_classes):
        pred_cls = (pred_mask == cls)
        gt_cls = (gt_mask == cls)
        intersection = (pred_cls & gt_cls).sum()
        union = (pred_cls | gt_cls).sum()
        # Empty union -> class nowhere in either mask; NaN marks "undefined".
        iou = intersection / union if union > 0 else float('nan')
        ious[names[cls]] = iou
    return ious

# Simulate a prediction: keep the ground-truth label for ~85% of pixels and
# replace the remaining ~15% with a uniformly random class in [0, 4).
noise = np.random.randint(0, 4, gt_mask.shape)
pred_mask = np.where(np.random.random(gt_mask.shape) < 0.85, gt_mask, noise)

ious = compute_iou_per_class(pred_mask, gt_mask, num_classes=4)
# nanmean skips classes with NaN IoU (absent from both masks).
mean_iou = np.nanmean(list(ious.values()))

print('Segmentation Evaluation:')
print(f'{"Class":<15} {"IoU":<10} {"Pixels"}')
print('-' * 35)
for cls_name, iou in ious.items():
    # Recover the integer class id from its name to count GT pixels.
    n_pixels = (gt_mask == CLASS_NAMES.index(cls_name)).sum()
    print(f'{cls_name:<15} {iou:.4f}     {n_pixels}')
print(f'{"Mean IoU":<15} {mean_iou:.4f}')
print()
print(f'Pixel Accuracy: {(pred_mask == gt_mask).mean():.4f}')
print()
print('Metrics guide:')
print('  mIoU: primary metric, averaged IoU across classes')
print('  Pixel accuracy: misleading with class imbalance (background dominates)')
print('  Boundary F1: how well edges between segments are preserved')

Segmentation Cheat Sheet

Task                    Model              mIoU (ADE20K)
--------------------------------------------------------
Semantic segmentation   FCN                ~29%
                        DeepLabV3+         ~45%
                        SegFormer-B5       ~51%
                        Mask2Former        ~57%
Instance segmentation   Mask R-CNN         ~40% (COCO)
                        SOLOv2             ~43%
Real-time               BiSeNetV2          30ms, ~75% Cityscapes
                        PaddleSeg          varies

Loss functions for segmentation:
  CrossEntropyLoss: standard, works for balanced classes
  Focal Loss: downweight easy pixels — helps class imbalance
  Dice Loss: directly optimizes IoU metric
  Combined: Dice + CrossEntropy (common in medical imaging)

Class imbalance handling:
  - Weight loss by inverse class frequency
  - Focal loss (γ focusing parameter)
  - Oversample rare class regions
  - Use Dice loss (size-invariant)

Data annotation tools:
  - CVAT (online, free)
  - Labelme (local, polygon annotation)
  - Roboflow (cloud, team collaboration)

Exercises

  1. Train the U-Net on the synthetic dataset for 30 epochs — plot the mIoU learning curve.

  2. Replace CrossEntropyLoss with Dice Loss and compare convergence speed.

  3. Implement class-weighted loss to handle the background class imbalance.

  4. Use torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True) on an image and visualize the output.

  5. Implement a sliding window inference approach for high-resolution images (split into overlapping tiles, merge predictions).