Transfer Learning: Leveraging Pretrained Models for Custom Image Tasks
Training a CNN from scratch requires millions of images and days of compute. Transfer learning lets you achieve state-of-the-art accuracy on your own dataset in hours, by reusing features learned from ImageNet. This notebook covers feature extraction, fine-tuning, and model selection.
# Core scientific stack: data handling, plotting, and the baseline classifier.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
import warnings

# Silence library deprecation chatter so notebook output stays readable.
warnings.filterwarnings('ignore')

# PyTorch is optional: when it is missing, the notebook prints the
# transfer-learning patterns instead of executing them.
try:
    import torch
    import torch.nn as nn
    from torch.utils.data import Dataset, DataLoader, random_split
    from torchvision import models, transforms
    HAS_TORCH = True
    # Prefer the GPU when CUDA reports one is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'PyTorch available. Device: {device}')
except ImportError:
    HAS_TORCH = False
    print('PyTorch not available β showing patterns with conceptual explanations')

# Fix the NumPy RNG so the synthetic dataset is reproducible.
np.random.seed(42)
# Create synthetic task: classify 4 types of textures (simulating a real dataset)
def make_texture(texture_type: str, size: int = 64) -> np.ndarray:
    """Generate one synthetic RGB texture image.

    Parameters
    ----------
    texture_type : str
        One of 'stripes', 'grid', 'dots', or 'noise'.
    size : int, default 64
        Height and width of the square image in pixels.

    Returns
    -------
    np.ndarray
        Float32 array of shape (size, size, 3), values clipped to [0, 1].

    Raises
    ------
    ValueError
        If ``texture_type`` is not one of the four known textures.
    """
    img = np.zeros((size, size, 3), dtype=np.float32)
    if texture_type == 'stripes':
        # Horizontal red stripes: 4 px on, 4 px off.
        for i in range(0, size, 8):
            img[i:i+4, :, 0] = 1.0
    elif texture_type == 'grid':
        # Green grid: 2 px lines every 8 px, along both axes.
        for i in range(0, size, 8):
            img[i:i+2, :, 1] = 1.0
            img[:, i:i+2, 1] = 1.0
    elif texture_type == 'dots':
        # Blue dots of radius 3 on a 10 px lattice.
        # Hoisted out of the loops: the coordinate grids do not depend on
        # the dot position, so rebuilding them per dot was wasted work.
        y, x = np.ogrid[:size, :size]
        for i in range(4, size, 10):
            for j in range(4, size, 10):
                mask = (x-j)**2 + (y-i)**2 <= 9
                img[mask, 2] = 1.0
    elif texture_type == 'noise':
        # Mid-gray uniform base plus a small Gaussian perturbation.
        img[:, :, :] = np.random.uniform(0.3, 0.7, (size, size, 3))
        img += np.random.normal(0, 0.05, img.shape)
    else:
        # Fail loudly instead of silently returning an all-black image.
        raise ValueError(f'Unknown texture_type: {texture_type!r}')
    return np.clip(img, 0, 1)
# Build the synthetic dataset: 100 samples per class, grouped class by class
# so labels[k] == class index of images[k].
CLASSES = ['stripes', 'grid', 'dots', 'noise']
_samples = []
_targets = []
for class_idx, class_name in enumerate(CLASSES):
    for _ in range(100):
        _samples.append(make_texture(class_name))
        _targets.append(class_idx)
images = np.array(_samples)
labels = np.array(_targets)
print(f'Dataset: {len(images)} images, {len(CLASSES)} classes')
1. Two Transfer Learning Strategies
# Print an overview of the three common transfer-learning strategies so the
# reader can choose one before running the code cells below.
_strategy_notes = (
    'Transfer Learning Strategy Comparison',
    '=' * 55,
    '',
    'STRATEGY 1: Feature Extraction (Frozen backbone)',
    ' - Freeze all pretrained layers',
    ' - Train only the new classification head',
    ' - Best when: small dataset (< 1000 images per class)',
    ' - Training time: very fast (only head trains)',
    ' - Risk: features may not match your domain',
    '',
    'STRATEGY 2: Fine-Tuning (Unfreeze all or last N layers)',
    ' - Start with pretrained weights',
    ' - Train all layers with small learning rate (1e-5 to 1e-4)',
    ' - Best when: enough data (> 1000 images per class)',
    ' - Training time: slower but usually better accuracy',
    ' - Risk: catastrophic forgetting if LR too high',
    '',
    'STRATEGY 3: Progressive Unfreezing (Recommended)',
    ' Phase 1: Freeze backbone, train head (5-10 epochs, LR=1e-3)',
    ' Phase 2: Unfreeze last 2-3 layers, reduce LR (1e-4)',
    ' Phase 3: Unfreeze all, very small LR (1e-5)',
)
for _note in _strategy_notes:
    print(_note)
2. Feature Extraction with Pretrained Model
# Strategy 1 in practice: use a frozen ImageNet backbone as a fixed feature
# extractor and train a lightweight scikit-learn classifier on its outputs.
if HAS_TORCH:
    # Load pretrained ResNet18 (smallest ResNet)
    backbone = models.resnet18(weights='IMAGENET1K_V1')
    # Remove the final classification layer
    # children()[:-1] drops only the last fc layer, keeping conv stages + avgpool.
    feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
    feature_extractor = feature_extractor.to(device)
    feature_extractor.eval() # Freeze batch norm statistics
    # Freeze all parameters
    for param in feature_extractor.parameters():
        param.requires_grad = False
    print(f'ResNet18 backbone: {sum(p.numel() for p in backbone.parameters()):,} total params')
    print(f'Feature extractor output: 512-dim vector per image')
    # Extract features for all images
    # Pretrained ImageNet models expect 224x224 inputs normalized with the
    # ImageNet channel statistics, regardless of the downstream task.
    preprocess = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    features_list = []
    # Inference only: disable autograd bookkeeping for speed and memory.
    with torch.no_grad():
        for img in images:
            # ToPILImage expects uint8 HWC input, not float in [0, 1].
            img_uint8 = (img * 255).astype(np.uint8)
            tensor = preprocess(img_uint8).unsqueeze(0).to(device)
            # squeeze() collapses the pooled (1, 512, 1, 1) output to (512,).
            feat = feature_extractor(tensor).squeeze().cpu().numpy()
            features_list.append(feat)
    X_features = np.array(features_list) # (400, 512)
    print(f'\nExtracted features shape: {X_features.shape}')
    # Train a simple classifier on top of features
    from sklearn.model_selection import train_test_split
    # Stratify so every texture class appears in both splits.
    X_train, X_test, y_train, y_test = train_test_split(
        X_features, labels, test_size=0.2, stratify=labels, random_state=42
    )
    clf = Pipeline([('scaler', StandardScaler()), ('lr', LogisticRegression(C=1.0, max_iter=1000))])
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print('Feature Extraction + LogReg classifier:')
    print(classification_report(y_test, y_pred, target_names=CLASSES))
else:
    # Without PyTorch installed, show the code pattern the reader would run.
    print('Feature extraction pattern (ResNet18):')
    print()
    print('backbone = models.resnet18(weights="IMAGENET1K_V1")')
    print('feature_extractor = nn.Sequential(*list(backbone.children())[:-1])')
    print('for param in feature_extractor.parameters():')
    print(' param.requires_grad = False')
    print()
    print('Extracted 512-dim features β LogisticRegression classifier')
    print('Expected accuracy on texture task: ~85-90% (features from ImageNet may not perfectly match)')
3. Fine-Tuning: Replacing the Classifier Head
# Strategies 2+3 in practice: two-phase fine-tuning of ResNet18 — first
# train a fresh head on a frozen backbone, then unfreeze everything and
# continue with differential learning rates.
if HAS_TORCH:
    class ImageDataset(Dataset):
        # Thin Dataset wrapper around the in-memory numpy images/labels.
        def __init__(self, images, labels, transform):
            self.images = images
            self.labels = labels
            self.transform = transform
        def __len__(self): return len(self.images)
        def __getitem__(self, idx):
            # Transform pipeline starts with ToPILImage, which needs uint8 HWC.
            img = (self.images[idx] * 255).astype(np.uint8)
            return self.transform(img), self.labels[idx]
    # Augment only the training split; the test split gets deterministic
    # preprocessing (see "Common Mistakes": never augment test data).
    train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
    ])
    test_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
    ])
    from sklearn.model_selection import train_test_split
    # Split index lists (not arrays) so one split drives both datasets.
    X_tr_idx, X_te_idx = train_test_split(range(len(images)), test_size=0.2, stratify=labels, random_state=42)
    train_ds = ImageDataset(images[X_tr_idx], labels[X_tr_idx], train_transform)
    test_ds = ImageDataset(images[X_te_idx], labels[X_te_idx], test_transform)
    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)
    # Build fine-tuning model
    model_ft = models.resnet18(weights='IMAGENET1K_V1')
    # PHASE 1: Freeze backbone, replace head
    for param in model_ft.parameters():
        param.requires_grad = False
    # Replace the final FC layer
    # The freshly created head is trainable by default, so during phase 1
    # only its parameters receive gradients.
    num_features = model_ft.fc.in_features
    model_ft.fc = nn.Sequential(
        nn.Linear(num_features, 128),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(128, len(CLASSES)),
    )
    model_ft = model_ft.to(device)
    # Optimize only the head during phase 1.
    optimizer = torch.optim.Adam(model_ft.fc.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    # Train head only
    for epoch in range(10):
        model_ft.train()
        for X_b, y_b in train_loader:
            optimizer.zero_grad()
            criterion(model_ft(X_b.to(device)), y_b.to(device)).backward()
            optimizer.step()
    # PHASE 2: Unfreeze all, small LR
    for param in model_ft.parameters():
        param.requires_grad = True
    # Differential learning rates: the head's 4 tensors (2 Linear layers,
    # weight+bias each) are the last entries of parameters(), so [:-4]
    # selects exactly the backbone.
    optimizer = torch.optim.Adam([
        {'params': list(model_ft.parameters())[:-4], 'lr': 1e-5}, # Backbone: tiny LR
        {'params': model_ft.fc.parameters(), 'lr': 1e-4}, # Head: larger LR
    ])
    for epoch in range(10):
        model_ft.train()
        for X_b, y_b in train_loader:
            optimizer.zero_grad()
            criterion(model_ft(X_b.to(device)), y_b.to(device)).backward()
            optimizer.step()
    # Evaluate
    model_ft.eval()  # disable dropout, use running batch-norm statistics
    preds, actuals = [], []
    with torch.no_grad():
        for X_b, y_b in test_loader:
            p = model_ft(X_b.to(device)).argmax(1).cpu().numpy()
            preds.extend(p)
            actuals.extend(y_b.numpy())
    print('Fine-tuned ResNet18:')
    print(classification_report(actuals, preds, target_names=CLASSES))
else:
    # Without PyTorch installed, describe the recipe instead of running it.
    print('Fine-tuning pattern:')
    print(' Phase 1: freeze backbone, train only new head (10 epochs, lr=1e-3)')
    print(' Phase 2: unfreeze all, use differential LR:')
    print(' backbone: lr=1e-5 (tiny β avoid catastrophic forgetting)')
    print(' head: lr=1e-4 (10x larger)')
    print()
    print('Expected result: fine-tuning beats feature extraction by ~5-10% accuracy')
4. Model Zoo — Which Architecture to Choose
# Reference table of popular ImageNet backbones: size/accuracy/speed
# trade-offs plus a quick selection guide.
# BUG FIX: pandas was imported AFTER pd.DataFrame was used, raising
# NameError — the import must precede its first use.
import pandas as pd

model_comparison = {
    'Model': ['ResNet18', 'ResNet50', 'EfficientNet-B0', 'EfficientNet-B4',
              'MobileNetV3-Small', 'ViT-B/16', 'ConvNeXt-Tiny'],
    'Params (M)': [11.7, 25.6, 5.3, 19.3, 2.5, 86.6, 28.6],
    'Top-1 Acc (%)': [69.8, 76.1, 77.7, 83.4, 67.7, 81.1, 82.1],
    'Speed (fps, V100)': [2800, 1200, 2700, 800, 5000, 290, 1000],
    'Best for': [
        'Fast prototyping, feature extraction',
        'Good balance accuracy/speed',
        'Mobile/edge deployment',
        'High accuracy, moderate speed',
        'Phones and embedded devices',
        'Large datasets, SOTA accuracy',
        'Modern CNN, competes with ViT',
    ]
}
df = pd.DataFrame(model_comparison)
print('Model Comparison (ImageNet):')
print(df.to_string(index=False))
print()
print('Selection guide:')
print(' Prototyping: ResNet18 (fast, well-understood)')
print(' Deployment: EfficientNet-B0 or MobileNetV3')
print(' Best accuracy: EfficientNet-B4, ViT-B, or ConvNeXt')
print(' < 1000 images: Feature extraction (frozen backbone)')
print(' > 5000 images: Fine-tune all layers')
# Visualize the accuracy-vs-size trade-off: one labeled scatter point per
# backbone, params (M) on x, ImageNet top-1 accuracy on y.
model_points = [
    ('ResNet18', 11.7, 69.8),
    ('ResNet50', 25.6, 76.1),
    ('EffNet-B0', 5.3, 77.7),
    ('EffNet-B4', 19.3, 83.4),
    ('MobileNetV3', 2.5, 67.7),
    ('ViT-B/16', 86.6, 81.1),
    ('ConvNeXt-T', 28.6, 82.1),
]
names = [entry[0] for entry in model_points]
params = [entry[1] for entry in model_points]
accuracies = [entry[2] for entry in model_points]
fig, ax = plt.subplots(figsize=(10, 5))
scatter = ax.scatter(params, accuracies, s=200, c=range(len(params)), cmap='viridis', zorder=3)
for name, p, a in zip(names, params, accuracies):
    # Nudge each label 5pt up-right so it does not cover its marker.
    ax.annotate(name, (p, a), xytext=(5, 5), textcoords='offset points', fontsize=9)
ax.set_xlabel('Parameters (Millions)')
ax.set_ylabel('Top-1 Accuracy (ImageNet)')
ax.set_title('Accuracy vs Model Size Trade-off')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
Transfer Learning Cheat Sheet
Scenario Recommended Strategy
----------------------------------------------------------------
Small dataset (< 1K/class) Feature extraction only
Medium dataset (1-10K/class) Fine-tune last 2-3 layers
Large dataset (> 10K/class) Fine-tune all layers
Similar to ImageNet domain Any strategy works well
Very different domain Full fine-tuning or train from scratch
Learning Rate Rules:
Backbone (frozen → unfrozen): use 1/10 of head LR
Head LR: start at 1e-3, reduce by 10x if plateau
Use ReduceLROnPlateau or CosineAnnealingLR
Common Mistakes:
✗ Using same LR for backbone and head
✗ Forgetting to set model.eval() for batch norm in feature extraction
✗ Not normalizing with ImageNet mean/std (required for pretrained models)
✗ Augmenting test data
✓ Always validate on a held-out set before selecting strategy
Exercises
Compare ResNet18 vs EfficientNet-B0 feature extraction on your dataset β which gives better features?
Implement learning rate warmup for fine-tuning (linearly increase LR for first 3 epochs).
Use `torch.nn.utils.clip_grad_norm_` during fine-tuning and observe training stability.
Apply test-time augmentation (TTA): average predictions across 5 augmented versions of each test image.
Visualize t-SNE of the 512-dim features before and after fine-tuning β do class clusters improve?