Transfer Learning: Leveraging Pretrained Models for Custom Image Tasks

Training a CNN from scratch typically requires a large labeled dataset (ImageNet-1k alone has about 1.28 million training images) and days of GPU compute. Transfer learning reuses features learned on ImageNet, so you can often reach strong accuracy on your own dataset in hours rather than days. This notebook covers feature extraction, fine-tuning, and model selection.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
import warnings
warnings.filterwarnings('ignore')

try:
    import torch
    import torch.nn as nn
    from torch.utils.data import Dataset, DataLoader, random_split
    from torchvision import models, transforms
    HAS_TORCH = True
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'PyTorch available. Device: {device}')
except ImportError:
    HAS_TORCH = False
    print('PyTorch not available: showing patterns with conceptual explanations')

np.random.seed(42)

# Create synthetic task: classify 4 types of textures (simulating a real dataset)
def make_texture(texture_type: str, size: int = 64) -> np.ndarray:
    img = np.zeros((size, size, 3), dtype=np.float32)
    if texture_type == 'stripes':
        for i in range(0, size, 8):
            img[i:i+4, :, 0] = 1.0
    elif texture_type == 'grid':
        for i in range(0, size, 8):
            img[i:i+2, :, 1] = 1.0
            img[:, i:i+2, 1] = 1.0
    elif texture_type == 'dots':
        for i in range(4, size, 10):
            for j in range(4, size, 10):
                y, x = np.ogrid[:size, :size]
                mask = (x-j)**2 + (y-i)**2 <= 9
                img[mask, 2] = 1.0
    elif texture_type == 'noise':
        img[:, :, :] = np.random.uniform(0.3, 0.7, (size, size, 3))
    img += np.random.normal(0, 0.05, img.shape)
    return np.clip(img, 0, 1)

CLASSES = ['stripes', 'grid', 'dots', 'noise']
images = np.array([make_texture(c) for c in CLASSES for _ in range(100)])
labels = np.array([i for i in range(4) for _ in range(100)])
print(f'Dataset: {len(images)} images, {len(CLASSES)} classes')

1. Three Transfer Learning Strategies

print('Transfer Learning Strategy Comparison')
print('=' * 55)
print()
print('STRATEGY 1: Feature Extraction (Frozen backbone)')
print('  - Freeze all pretrained layers')
print('  - Train only the new classification head')
print('  - Best when: small dataset (< 1000 images per class)')
print('  - Training time: very fast (only head trains)')
print('  - Risk: features may not match your domain')
print()
print('STRATEGY 2: Fine-Tuning (Unfreeze all or last N layers)')
print('  - Start with pretrained weights')
print('  - Train all layers with small learning rate (1e-5 to 1e-4)')
print('  - Best when: enough data (> 1000 images per class)')
print('  - Training time: slower but usually better accuracy')
print('  - Risk: catastrophic forgetting if LR too high')
print()
print('STRATEGY 3: Progressive Unfreezing (Recommended)')
print('  Phase 1: Freeze backbone, train head (5-10 epochs, LR=1e-3)')
print('  Phase 2: Unfreeze last 2-3 layers, reduce LR (1e-4)')
print('  Phase 3: Unfreeze all, very small LR (1e-5)')
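The three-phase recipe in Strategy 3 can be written down as a schedule that maps an epoch to the layer groups that should be trainable and their learning rates. The sketch below is a pure-Python illustration under stated assumptions: the group names (`conv1`, `bn1`, `layer1`...`layer4`, `fc`) are modeled on the top-level modules of torchvision's ResNet, the phase boundaries (5 and 10 epochs) are arbitrary illustrative defaults, and `unfreeze_plan` is a hypothetical helper, not a library function.

```python
# Hypothetical group names, modeled on the top-level modules of torchvision's ResNet
BACKBONE_GROUPS = ['conv1', 'bn1', 'layer1', 'layer2', 'layer3', 'layer4']
HEAD = 'fc'


def unfreeze_plan(epoch: int, phase1_end: int = 5, phase2_end: int = 10) -> dict:
    """Return {group_name: learning_rate} for the groups trainable at `epoch`.

    Groups missing from the dict stay frozen. The phase boundaries are
    illustrative defaults, not tuned values.
    """
    if epoch < phase1_end:
        # Phase 1: train the new head only
        return {HEAD: 1e-3}
    if epoch < phase2_end:
        # Phase 2: head plus the last two backbone stages, reduced LR
        plan = {group: 1e-4 for group in BACKBONE_GROUPS[-2:]}
        plan[HEAD] = 1e-4
        return plan
    # Phase 3: every group, very small LR
    return {group: 1e-5 for group in BACKBONE_GROUPS + [HEAD]}


for epoch in (0, 7, 12):
    print(f'epoch {epoch:2d}: {unfreeze_plan(epoch)}')
```

With PyTorch, one way to apply a plan is to set `requires_grad` on each parameter according to whether `name.split('.')[0]` appears in the plan, then rebuild the optimizer with one param group per entry whenever the phase changes.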

2. Feature Extraction with Pretrained Model

if HAS_TORCH:
    # Load pretrained ResNet18 (smallest ResNet)
    backbone = models.resnet18(weights='IMAGENET1K_V1')
    
    # Remove the final classification layer
    feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
    feature_extractor = feature_extractor.to(device)
    feature_extractor.eval()  # Freeze batch norm statistics
    
    # Freeze all parameters
    for param in feature_extractor.parameters():
        param.requires_grad = False
    
    print(f'ResNet18 backbone: {sum(p.numel() for p in backbone.parameters()):,} total params')
    print(f'Feature extractor output: 512-dim vector per image')
    
    # Extract features for all images
    preprocess = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    
    features_list = []
    with torch.no_grad():
        for img in images:
            img_uint8 = (img * 255).astype(np.uint8)
            tensor = preprocess(img_uint8).unsqueeze(0).to(device)
            feat = feature_extractor(tensor).squeeze().cpu().numpy()
            features_list.append(feat)
    
    X_features = np.array(features_list)  # (400, 512)
    print(f'\nExtracted features shape: {X_features.shape}')
    
    # Train a simple classifier on top of features
    from sklearn.model_selection import train_test_split
    X_train, X_test, y_train, y_test = train_test_split(
        X_features, labels, test_size=0.2, stratify=labels, random_state=42
    )
    
    clf = Pipeline([('scaler', StandardScaler()), ('lr', LogisticRegression(C=1.0, max_iter=1000))])
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    
    print('Feature Extraction + LogReg classifier:')
    print(classification_report(y_test, y_pred, target_names=CLASSES))
else:
    print('Feature extraction pattern (ResNet18):')
    print()
    print('backbone = models.resnet18(weights="IMAGENET1K_V1")')
    print('feature_extractor = nn.Sequential(*list(backbone.children())[:-1])')
    print('for param in feature_extractor.parameters():')
    print('    param.requires_grad = False')
    print()
    print('Extracted 512-dim features → LogisticRegression classifier')
    print('Expected accuracy on texture task: ~85-90% (features from ImageNet may not perfectly match)')
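The `preprocess` pipeline above ends with `ToTensor()` and `Normalize(mean, std)`. `ToTensor()` scales a uint8 image to [0, 1] and reorders it from height × width × channels to channels × height × width; `Normalize` then computes `(x - mean[c]) / std[c]` for each channel `c`. The pretrained weights were trained on inputs scaled this way, which is why the ImageNet statistics matter. A minimal numpy sketch of those two steps, assuming the input image is already a float array in [0, 1] (as the synthetic textures are):

```python
import numpy as np

# Per-channel ImageNet statistics, the same values passed to transforms.Normalize above
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def imagenet_normalize(img_hwc: np.ndarray) -> np.ndarray:
    """Normalize an (H, W, 3) float image in [0, 1] and return it as (3, H, W)."""
    # HWC -> CHW, the layout ToTensor() produces
    chw = np.transpose(img_hwc.astype(np.float32), (2, 0, 1))
    # Broadcast the per-channel mean/std over the spatial dimensions
    return (chw - IMAGENET_MEAN[:, None, None]) / IMAGENET_STD[:, None, None]


img = np.full((4, 4, 3), 0.5, dtype=np.float32)
out = imagenet_normalize(img)
print(out.shape, out[:, 0, 0])
```

A mid-gray pixel (0.5) maps to small positive values in each channel, because 0.5 sits slightly above every ImageNet channel mean.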

3. Fine-Tuning: Replacing the Classifier Head

if HAS_TORCH:
    class ImageDataset(Dataset):
        def __init__(self, images, labels, transform):
            self.images = images
            self.labels = labels
            self.transform = transform
        def __len__(self): return len(self.images)
        def __getitem__(self, idx):
            img = (self.images[idx] * 255).astype(np.uint8)
            return self.transform(img), self.labels[idx]
    
    train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
    ])
    test_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
    ])
    
    from sklearn.model_selection import train_test_split
    X_tr_idx, X_te_idx = train_test_split(range(len(images)), test_size=0.2, stratify=labels, random_state=42)
    
    train_ds = ImageDataset(images[X_tr_idx], labels[X_tr_idx], train_transform)
    test_ds  = ImageDataset(images[X_te_idx], labels[X_te_idx], test_transform)
    
    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
    test_loader  = DataLoader(test_ds,  batch_size=32, shuffle=False)
    
    # Build fine-tuning model
    model_ft = models.resnet18(weights='IMAGENET1K_V1')
    
    # PHASE 1: Freeze backbone, replace head
    for param in model_ft.parameters():
        param.requires_grad = False
    
    # Replace the final FC layer
    num_features = model_ft.fc.in_features
    model_ft.fc = nn.Sequential(
        nn.Linear(num_features, 128),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(128, len(CLASSES)),
    )
    model_ft = model_ft.to(device)
    
    optimizer = torch.optim.Adam(model_ft.fc.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    
    # Train head only
    for epoch in range(10):
        model_ft.train()
        for X_b, y_b in train_loader:
            optimizer.zero_grad()
            criterion(model_ft(X_b.to(device)), y_b.to(device)).backward()
            optimizer.step()
    
    # PHASE 2: Unfreeze all, small LR
    for param in model_ft.parameters():
        param.requires_grad = True
    
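    # Select backbone parameters by name so the split does not depend on how many tensors the head has
    backbone_params = [p for n, p in model_ft.named_parameters() if not n.startswith('fc.')]
    optimizer = torch.optim.Adam([
        {'params': backbone_params,          'lr': 1e-5},  # Backbone: tiny LR
        {'params': model_ft.fc.parameters(), 'lr': 1e-4},  # Head: 10x larger LR
    ])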
    
    for epoch in range(10):
        model_ft.train()
        for X_b, y_b in train_loader:
            optimizer.zero_grad()
            criterion(model_ft(X_b.to(device)), y_b.to(device)).backward()
            optimizer.step()
    
    # Evaluate
    model_ft.eval()
    preds, actuals = [], []
    with torch.no_grad():
        for X_b, y_b in test_loader:
            p = model_ft(X_b.to(device)).argmax(1).cpu().numpy()
            preds.extend(p)
            actuals.extend(y_b.numpy())
    print('Fine-tuned ResNet18:')
    print(classification_report(actuals, preds, target_names=CLASSES))
else:
    print('Fine-tuning pattern:')
    print('  Phase 1: freeze backbone, train only new head (10 epochs, lr=1e-3)')
    print('  Phase 2: unfreeze all, use differential LR:')
    print('    backbone: lr=1e-5  (tiny, to avoid catastrophic forgetting)')
    print('    head:     lr=1e-4  (10x larger)')
    print()
    print('Expected result: fine-tuning beats feature extraction by ~5-10% accuracy')
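Differential learning rates rely on the optimizer's parameter groups: each group is a dict with its own `params` and `lr`. Selecting parameters by name rather than by position in `parameters()` keeps the split correct if the head gains or loses layers. The helper below is a hypothetical sketch, not a PyTorch API; it accepts any iterable of `(name, parameter)` pairs shaped like `nn.Module.named_parameters()`, assumes the head's parameter names start with `fc.` (true for torchvision ResNets), and skips parameters whose `requires_grad` is `False`.

```python
from typing import Any, Dict, Iterable, List, Tuple


def build_param_groups(named_params: Iterable[Tuple[str, Any]],
                       head_prefix: str = 'fc.',
                       backbone_lr: float = 1e-5,
                       head_lr: float = 1e-4) -> List[Dict[str, Any]]:
    """Split (name, parameter) pairs into backbone and head optimizer groups by name."""
    backbone, head = [], []
    for name, param in named_params:
        if not getattr(param, 'requires_grad', True):
            continue  # frozen parameters stay out of the optimizer
        target = head if name.startswith(head_prefix) else backbone
        target.append(param)
    groups = [
        {'params': backbone, 'lr': backbone_lr},
        {'params': head, 'lr': head_lr},
    ]
    # Drop empty groups (e.g. a fully frozen backbone during phase 1)
    return [g for g in groups if g['params']]


# Stand-in (name, param) pairs; with PyTorch you would pass model_ft.named_parameters()
fake_named_params = [
    ('conv1.weight', 'w_conv1'),
    ('layer4.1.bn2.bias', 'b_layer4'),
    ('fc.0.weight', 'w_fc0'),
    ('fc.3.bias', 'b_fc3'),
]
groups = build_param_groups(fake_named_params)
for g in groups:
    print(g['lr'], g['params'])
```

With a real model, the result could be passed straight to an optimizer, e.g. `torch.optim.Adam(build_param_groups(model_ft.named_parameters()))`.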

4. Model Zoo: Which Architecture to Choose

model_comparison = {
    'Model': ['ResNet18', 'ResNet50', 'EfficientNet-B0', 'EfficientNet-B4', 
               'MobileNetV3-Small', 'ViT-B/16', 'ConvNeXt-Tiny'],
    'Params (M)': [11.7, 25.6, 5.3, 19.3, 2.5, 86.6, 28.6],
    'Top-1 Acc (%)': [69.8, 76.1, 77.7, 83.4, 67.7, 81.1, 82.1],
    'Speed (fps, V100)': [2800, 1200, 2700, 800, 5000, 290, 1000],
    'Best for': [
        'Fast prototyping, feature extraction',
        'Good balance accuracy/speed',
        'Mobile/edge deployment',
        'High accuracy, moderate speed',
        'Phones and embedded devices',
        'Large datasets, SOTA accuracy',
        'Modern CNN, competes with ViT',
    ]
}

import pandas as pd

df = pd.DataFrame(model_comparison)
print('Model Comparison (ImageNet):')
print(df.to_string(index=False))
print()
print('Selection guide:')
print('  Prototyping:    ResNet18 (fast, well-understood)')
print('  Deployment:     EfficientNet-B0 or MobileNetV3')
print('  Best accuracy:  EfficientNet-B4, ViT-B, or ConvNeXt')
print('  < 1000 images:  Feature extraction (frozen backbone)')
print('  > 5000 images:  Fine-tune all layers')

fig, ax = plt.subplots(figsize=(10, 5))
# Reuse the table columns instead of retyping the numbers
params = df['Params (M)'].tolist()
accuracies = df['Top-1 Acc (%)'].tolist()
names = ['ResNet18', 'ResNet50', 'EffNet-B0', 'EffNet-B4', 'MobileNetV3', 'ViT-B/16', 'ConvNeXt-T']
scatter = ax.scatter(params, accuracies, s=200, c=range(len(params)), cmap='viridis', zorder=3)
for name, p, a in zip(names, params, accuracies):
    ax.annotate(name, (p, a), xytext=(5, 5), textcoords='offset points', fontsize=9)
ax.set_xlabel('Parameters (Millions)')
ax.set_ylabel('Top-1 Accuracy (ImageNet)')
ax.set_title('Accuracy vs Model Size Trade-off')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
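The scatter plot suggests a simple way to shortlist backbones: keep only models that no other model beats on both size and accuracy (the Pareto front), or pick the most accurate model that fits a parameter budget. The sketch below uses the numbers from the comparison table above; `best_under_budget` and `pareto_front` are hypothetical helpers. Parameter count is only a proxy for cost, and ImageNet accuracy does not guarantee the same ranking after transfer to your data, so treat the result as a starting shortlist.

```python
# (name, params in millions, ImageNet top-1 %) taken from the comparison table above
MODELS = [
    ('ResNet18', 11.7, 69.8),
    ('ResNet50', 25.6, 76.1),
    ('EfficientNet-B0', 5.3, 77.7),
    ('EfficientNet-B4', 19.3, 83.4),
    ('MobileNetV3-Small', 2.5, 67.7),
    ('ViT-B/16', 86.6, 81.1),
    ('ConvNeXt-Tiny', 28.6, 82.1),
]


def best_under_budget(max_params_m: float):
    """Most accurate model whose parameter count fits the budget, or None."""
    fitting = [m for m in MODELS if m[1] <= max_params_m]
    return max(fitting, key=lambda m: m[2])[0] if fitting else None


def pareto_front(models):
    """Models that no other model beats on both size (smaller) and accuracy (higher)."""
    front = []
    for name, size, acc in models:
        dominated = any(
            s <= size and a >= acc and (s < size or a > acc)
            for n, s, a in models if n != name
        )
        if not dominated:
            front.append((size, name))
    # Sort the surviving models from smallest to largest
    return [name for _, name in sorted(front)]


print(best_under_budget(10))
print(pareto_front(MODELS))
```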

Transfer Learning Cheat Sheet

Scenario                        Recommended Strategy
────────────────────────────────────────────────────────────────
Small dataset (< 1K/class)      Feature extraction only
Medium dataset (1-10K/class)    Fine-tune last 2-3 layers
Large dataset (> 10K/class)     Fine-tune all layers
Similar to ImageNet domain      Any strategy works well
Very different domain           Full fine-tuning or train from scratch

Learning Rate Rules:
  Backbone (frozen → unfrozen): use 1/10 of the head LR
  Head LR: start at 1e-3, reduce by 10x if plateau
  Use ReduceLROnPlateau or CosineAnnealingLR
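Both schedulers named above can be sketched in a few lines. `cosine_lr` uses the closed-form cosine annealing rule, `eta_min + (base_lr - eta_min) * (1 + cos(pi * epoch / t_max)) / 2`, which is the formula documented for `CosineAnnealingLR`. `PlateauReducer` is a deliberately simplified, hypothetical stand-in for `ReduceLROnPlateau(mode='min')`: it multiplies the LR by `factor` once the monitored loss has failed to improve for more than `patience` consecutive epochs, and it omits the real scheduler's threshold, cooldown, and `min_lr` options.

```python
import math


def cosine_lr(epoch: int, base_lr: float = 1e-3, t_max: int = 10, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing from base_lr (epoch 0) down to eta_min (epoch t_max)."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


class PlateauReducer:
    """Simplified plateau rule: cut the LR by `factor` after a run of non-improving epochs."""

    def __init__(self, lr: float, factor: float = 0.1, patience: int = 2):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = math.inf
        self.bad_epochs = 0

    def step(self, val_loss: float) -> float:
        if val_loss < self.best:
            # Improvement: remember it and reset the counter
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                # Plateau detected: reduce by 10x (factor=0.1) and start counting again
                self.lr *= self.factor
                self.bad_epochs = 0
        return self.lr


print([round(cosine_lr(e), 6) for e in range(0, 11, 2)])
reducer = PlateauReducer(lr=1e-3, factor=0.1, patience=2)
lr_history = [reducer.step(loss) for loss in [1.0, 0.9, 0.95, 0.95, 0.95]]
print(lr_history)
```

In the loss sequence above, the third, fourth, and fifth epochs fail to beat 0.9, so the LR drops from 1e-3 to 1e-4 on the fifth call.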

Common Mistakes:
  ❌ Using the same LR for backbone and head
  ❌ Forgetting to call model.eval() so batch norm uses its running statistics during feature extraction
  ❌ Not normalizing with the ImageNet mean/std (pretrained weights expect it)
  ❌ Applying the random training augmentations in the test transform
  ✅ Always validate on a held-out set before selecting a strategy
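The `model.eval()` point is easy to miss because batch norm's running mean and variance are buffers, not parameters: setting `requires_grad = False` does not stop them from changing. In train mode a batch norm layer normalizes with the current batch's statistics and nudges its running statistics toward them; in eval mode it uses the stored running statistics and updates nothing. The numpy sketch below imitates that behavior for a single layer, assuming PyTorch's default `momentum=0.1` and `eps=1e-5`, and leaves out the learnable scale and shift; `TinyBatchNorm` is a hypothetical illustration, not a library class.

```python
import numpy as np


class TinyBatchNorm:
    """Batch norm over axis 0 without the learnable scale/shift, to contrast train vs eval mode."""

    def __init__(self, num_features: int, momentum: float = 0.1, eps: float = 1e-5):
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        self.momentum = momentum
        self.eps = eps

    def __call__(self, x: np.ndarray, training: bool) -> np.ndarray:
        if training:
            # Train mode: normalize with this batch's statistics...
            mean, var = x.mean(axis=0), x.var(axis=0)
            n = x.shape[0]
            # ...and update the running buffers (PyTorch stores the unbiased variance)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var * n / (n - 1)
        else:
            # Eval mode: use the stored running statistics; nothing is updated
            mean, var = self.running_mean, self.running_var
        return (x - mean) / np.sqrt(var + self.eps)


bn = TinyBatchNorm(num_features=2)
batch = np.random.default_rng(0).normal(loc=5.0, scale=2.0, size=(16, 2))

bn(batch, training=False)
mean_after_eval = bn.running_mean.copy()   # unchanged: still zeros
bn(batch, training=True)
mean_after_train = bn.running_mean.copy()  # moved 10% of the way toward batch.mean(axis=0)
print(mean_after_eval, mean_after_train)
```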

Exercises

  1. Compare ResNet18 vs EfficientNet-B0 feature extraction on your dataset: which gives better features?

  2. Implement learning rate warmup for fine-tuning (linearly increase LR for first 3 epochs).

  3. Use torch.nn.utils.clip_grad_norm_ during fine-tuning and observe training stability.

  4. Apply test-time augmentation (TTA): average predictions across 5 augmented versions of each test image.

  5. Visualize t-SNE of the 512-dim features before and after fine-tuning: do class clusters improve?
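As starting points for exercises 2 and 3, the core arithmetic of linear warmup and global-norm gradient clipping fits in a few lines of numpy. Both functions are hypothetical sketches: `warmup_factor` returns an LR multiplier and is written so it could be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`, and `clip_by_global_norm` follows the rule `torch.nn.utils.clip_grad_norm_` applies for the default 2-norm (scale every gradient by `max_norm / (total_norm + 1e-6)` when that factor is below 1, and report the norm measured before clipping).

```python
import numpy as np


def warmup_factor(epoch: int, warmup_epochs: int = 3) -> float:
    """LR multiplier that rises linearly to 1.0 over the first `warmup_epochs` epochs."""
    return min(1.0, (epoch + 1) / warmup_epochs)


def clip_by_global_norm(grads, max_norm: float):
    """Scale a list of gradient arrays so their combined L2 norm is at most `max_norm`.

    Returns (clipped_grads, total_norm_before_clipping).
    """
    total_norm = float(np.sqrt(sum(np.sum(g ** 2) for g in grads)))
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        # Every gradient is scaled by the same factor, so the update direction is preserved
        grads = [g * clip_coef for g in grads]
    return grads, total_norm


print([round(warmup_factor(e), 3) for e in range(5)])
clipped, norm_before = clip_by_global_norm([np.array([3.0, 0.0]), np.array([4.0])], max_norm=1.0)
print(norm_before, clipped)
```

Because clipping rescales all gradients together rather than capping each element, it limits the step size without changing which direction the optimizer moves.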