Image Segmentation: Pixel-Level Understanding
Segmentation assigns a class to every pixel — it's what makes self-driving cars understand roads from sidewalks. This notebook covers semantic segmentation concepts, the U-Net architecture, and practical usage of pretrained segmentation models.
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from scipy import ndimage
import warnings
warnings.filterwarnings('ignore')
# Optional torch dependency: the notebook degrades to text-only explanations
# when PyTorch is not installed (later cells branch on HAS_TORCH).
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    HAS_TORCH = True
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
except ImportError:
    HAS_TORCH = False
    # Define `device` on this path too, so any later reference to it
    # fails with an explicit check rather than a NameError.
    device = None
np.random.seed(42)  # reproducible synthetic data and simulated predictions
# Create synthetic segmentation scene
def create_segmentation_scene(size=128):
    """Build a synthetic RGB scene plus its per-pixel label mask.

    Classes: 0 = background (sky), 1 = road, 2 = car, 3 = pedestrian.
    Returns (image uint8 of shape (size, size, 3), mask int64 of shape (size, size)).
    """
    half, quarter = size // 2, size // 4
    image = np.zeros((size, size, 3), dtype=np.uint8)
    labels = np.zeros((size, size), dtype=np.int64)

    # Sky occupies the top half: strong blue plus a vertical red gradient.
    image[:half, :, 2] = 180
    image[:half, :, 0] = np.linspace(100, 150, half)[:, np.newaxis]

    # Road fills the bottom half with uniform gray.
    image[half:, :, :] = 100
    labels[half:, :] = 1

    # Car: red rectangle sitting on the road.
    image[half:half + quarter, quarter:3 * quarter, :] = [180, 50, 50]
    labels[half:half + quarter, quarter:3 * quarter] = 2

    # Pedestrian: narrow green rectangle (drawn last, so it overwrites road).
    image[half + 10:half + 40, 10:25, :] = [50, 170, 50]
    labels[half + 10:half + 40, 10:25] = 3

    # Gaussian sensor noise, clipped back into the valid uint8 range.
    noisy = image.astype(np.float32) + np.random.normal(0, 10, image.shape)
    return np.clip(noisy, 0, 255).astype(np.uint8), labels
# Class metadata: list index matches the integer labels in the mask.
_CLASSES = [
    ('background', (100, 150, 180)),
    ('road', (100, 100, 100)),
    ('car', (180, 50, 50)),
    ('pedestrian', (50, 170, 50)),
]
CLASS_NAMES = [name for name, _ in _CLASSES]
CLASS_COLORS = [rgb for _, rgb in _CLASSES]
scene_img, gt_mask = create_segmentation_scene()
def colorize_mask(mask, colors=CLASS_COLORS):
    """Map an integer label mask to an RGB image for visualization."""
    painted = np.zeros(mask.shape + (3,), dtype=np.uint8)
    for label, rgb in enumerate(colors):
        painted[mask == label] = rgb
    return painted
# Visualize: input image, ground-truth mask, and a 50/50 alpha blend of both.
fig, axes = plt.subplots(1, 3, figsize=(13, 4))
axes[0].imshow(scene_img)
axes[0].set_title('Input Image')
axes[0].axis('off')
axes[1].imshow(colorize_mask(gt_mask))
axes[1].set_title('Ground Truth Mask')
axes[1].axis('off')
# Overlay: equal-weight blend done in float to avoid uint8 overflow.
overlay = scene_img.copy().astype(np.float32)
colored = colorize_mask(gt_mask).astype(np.float32)
axes[2].imshow((0.5 * overlay + 0.5 * colored).astype(np.uint8))
axes[2].set_title('Overlay')
axes[2].axis('off')
# Legend: one colored swatch per class drawn on the mask panel.
# (Fixed mojibake: the swatch glyph was garbled to 'β' in the source.)
for i, (name, color) in enumerate(zip(CLASS_NAMES, CLASS_COLORS)):
    axes[1].text(2, 10 + i*15, f'■ {name}', color=np.array(color)/255, fontsize=8)
plt.suptitle('Semantic Segmentation: Pixel-Level Classification')
plt.tight_layout()
plt.show()
print(f'Image: {scene_img.shape}, Mask: {gt_mask.shape}')
# Per-class pixel frequency: shows the class imbalance (sky/road dominate).
for cls in range(4):
    pct = (gt_mask == cls).mean() * 100
    print(f' Class {cls} ({CLASS_NAMES[cls]}): {pct:.1f}% of pixels')
1. Segmentation vs Detection vs Classification
import pandas as pd

# Taxonomy of vision tasks, rendered side by side as a table.
comparison = {
    'Task': ['Classification', 'Object Detection', 'Semantic Segmentation', 'Instance Segmentation', 'Panoptic Segmentation'],
    'Output': ['Single label', 'Boxes + labels', 'Pixel class mask', 'Pixel mask per object', 'Full scene understanding'],
    'Resolution': ['Image-level', 'Object-level', 'Pixel-level', 'Pixel-level', 'Pixel-level'],
    'Distinguishes instances?': ['No', 'Yes (by box)', 'No', 'Yes', 'Yes'],
    'Models': ['ResNet, ViT', 'YOLO, Faster R-CNN', 'FCN, U-Net, DeepLab', 'Mask R-CNN, SOLOv2', 'Panoptic FPN'],
}
df = pd.DataFrame(comparison)
print(df.to_string(index=False))
print()
print('Semantic vs Instance segmentation:')
print(' Semantic: all cars are class=2 (same color on mask)')
print(' Instance: car1 and car2 are different colors/IDs')
2. U-Net Architecture
if HAS_TORCH:
class DoubleConv(nn.Module):
    """Conv → BN → ReLU → Conv → BN → ReLU.

    Two 3×3 convolutions with padding=1, so spatial size is preserved;
    channels go in_ch → out_ch on the first conv.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)


class UNet(nn.Module):
    """
    U-Net: Encoder-Decoder with skip connections.
    Encoder: progressively reduce spatial size, increase channels
    Bottleneck: deepest representation
    Decoder: progressively increase spatial size, decrease channels
    Skip connections: concatenate encoder features to decoder
    """

    def __init__(self, in_channels=3, num_classes=4, base_ch=32):
        super().__init__()
        # Encoder (contracting path); spatial sizes noted for a 128×128 input.
        self.enc1 = DoubleConv(in_channels, base_ch)      # 128→128
        self.enc2 = DoubleConv(base_ch, base_ch * 2)      # 64→64
        self.enc3 = DoubleConv(base_ch * 2, base_ch * 4)  # 32→32
        self.pool = nn.MaxPool2d(2)  # halves H and W
        # Bottleneck: deepest, most abstract features.
        self.bottleneck = DoubleConv(base_ch * 4, base_ch * 8)  # 16→16
        # Decoder (expanding path): transpose convs double H and W;
        # each DoubleConv input is upsampled features concatenated with
        # the matching encoder stage, hence the doubled in-channels.
        self.up3 = nn.ConvTranspose2d(base_ch * 8, base_ch * 4, 2, stride=2)
        self.dec3 = DoubleConv(base_ch * 8, base_ch * 4)  # cat with enc3
        self.up2 = nn.ConvTranspose2d(base_ch * 4, base_ch * 2, 2, stride=2)
        self.dec2 = DoubleConv(base_ch * 4, base_ch * 2)
        self.up1 = nn.ConvTranspose2d(base_ch * 2, base_ch, 2, stride=2)
        self.dec1 = DoubleConv(base_ch * 2, base_ch)
        # Output: 1×1 conv maps features to per-pixel class logits.
        self.output_conv = nn.Conv2d(base_ch, num_classes, 1)

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        # Bottleneck
        b = self.bottleneck(self.pool(e3))
        # Decoder with skip connections (concatenate encoder features on
        # the channel dimension).
        d3 = self.dec3(torch.cat([self.up3(b), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.output_conv(d1)  # (B, num_classes, H, W)


unet = UNet(in_channels=3, num_classes=4, base_ch=32)
total_params = sum(p.numel() for p in unet.parameters())
print(f'U-Net parameters: {total_params:,}')
# Verify output shape with a dummy forward pass.
dummy = torch.randn(1, 3, 128, 128)
out = unet(dummy)
print(f'Input shape: {dummy.shape}')
# Fixed mojibake: the arrows in these messages were garbled to 'β'.
print(f'Output shape: {out.shape} → (batch, classes, height, width)')
print('Predicted class per pixel: logits.argmax(dim=1) → (batch, H, W)')
else:
print('U-Net Architecture:')
print()
print('Input (3, 128, 128)')
print(' β Encoder 1: DoubleConv β (32, 128, 128)')
print(' β MaxPool β (32, 64, 64)')
print(' β Encoder 2: DoubleConv β (64, 64, 64)')
print(' β MaxPool β (64, 32, 32)')
print(' β Encoder 3: DoubleConv β (128, 32, 32)')
print(' β MaxPool β (128, 16, 16)')
print(' Bottleneck: DoubleConv β (256, 16, 16)')
print(' β Upsample + concat enc3 β (128, 32, 32)')
print(' β Decoder 3: DoubleConv β (128, 32, 32)')
print(' β Upsample + concat enc2 β (64, 64, 64)')
print(' β Decoder 2: DoubleConv β (64, 64, 64)')
print(' β Upsample + concat enc1 β (32, 128, 128)')
print(' β Decoder 1: DoubleConv β (32, 128, 128)')
print(' Output: Conv1Γ1 β (4, 128, 128) β num_classes channels')
3. Segmentation Metrics
def compute_iou_per_class(pred_mask: np.ndarray, gt_mask: np.ndarray, num_classes: int, class_names=None) -> dict:
    """Compute per-class IoU (Jaccard index) between two label masks.

    Args:
        pred_mask: predicted integer label map.
        gt_mask: ground-truth integer label map, same shape as pred_mask.
        num_classes: evaluate labels 0 .. num_classes-1.
        class_names: optional sequence of names used as result keys;
            defaults to the module-level CLASS_NAMES (generalized so the
            function no longer hard-codes this notebook's classes).

    Returns:
        dict mapping class name -> IoU. A class absent from both masks
        gets NaN (undefined IoU), so callers can average with np.nanmean.
    """
    # Lazy default keeps backward compatibility with the old global lookup.
    names = CLASS_NAMES if class_names is None else class_names
    ious = {}
    for cls in range(num_classes):
        pred_cls = (pred_mask == cls)
        gt_cls = (gt_mask == cls)
        intersection = (pred_cls & gt_cls).sum()
        union = (pred_cls | gt_cls).sum()
        # Empty union means the class appears nowhere: IoU is undefined.
        ious[names[cls]] = intersection / union if union > 0 else float('nan')
    return ious
# Simulate a noisy prediction: each pixel keeps its true label with
# probability 0.85, otherwise it receives a uniformly random class.
# (RNG call order matches the original: randint first, then random.)
random_labels = np.random.randint(0, 4, gt_mask.shape)
keep = np.random.random(gt_mask.shape) < 0.85
pred_mask = np.where(keep, gt_mask, random_labels)

ious = compute_iou_per_class(pred_mask, gt_mask, num_classes=4)
mean_iou = np.nanmean(list(ious.values()))

# Report per-class IoU alongside each class's ground-truth pixel count.
print('Segmentation Evaluation:')
print(f'{"Class":<15} {"IoU":<10} {"Pixels"}')
print('-' * 35)
for cls_name, iou in ious.items():
    n_pixels = (gt_mask == CLASS_NAMES.index(cls_name)).sum()
    print(f'{cls_name:<15} {iou:.4f} {n_pixels}')
print(f'{"Mean IoU":<15} {mean_iou:.4f}')
print()
print(f'Pixel Accuracy: {(pred_mask == gt_mask).mean():.4f}')
print()
print('Metrics guide:')
print(' mIoU: primary metric, averaged IoU across classes')
print(' Pixel accuracy: misleading with class imbalance (background dominates)')
print(' Boundary F1: how well edges between segments are preserved')
Segmentation Cheat Sheet
Task Model mIoU (ADE20K)
────────────────────────────────────────────────────────
Semantic segmentation FCN ~29%
DeepLabV3+ ~45%
SegFormer-B5 ~51%
Mask2Former ~57%
Instance segmentation Mask R-CNN ~40% (COCO)
SOLOv2 ~43%
Real-time BiSeNetV2 30ms, ~75% CityScapes
PaddleSeg varies
Loss functions for segmentation:
CrossEntropyLoss: standard, works for balanced classes
Focal Loss: downweight easy pixels — helps class imbalance
Dice Loss: directly optimizes IoU metric
Combined: Dice + CrossEntropy (common in medical imaging)
Class imbalance handling:
- Weight loss by inverse class frequency
- Focal loss (γ parameter)
- Oversample rare class regions
- Use Dice loss (size-invariant)
Data annotation tools:
- CVAT (online, free)
- Labelme (local, polygon annotation)
- Roboflow (cloud, team collaboration)
Exercises
Train the U-Net on the synthetic dataset for 30 epochs and plot the mIoU learning curve.
Replace CrossEntropyLoss with Dice Loss and compare convergence speed.
Implement class-weighted loss to handle the background class imbalance.
Use torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True) on an image and visualize the output.
Implement a sliding window inference approach for high-resolution images (split into overlapping tiles, merge predictions).