Image Classification with Deep Learning
ResNet, EfficientNet, ViT — training and fine-tuning image classifiers with PyTorch and Hugging Face.
# Install dependencies
# !pip install torch torchvision timm transformers pillow matplotlib scikit-learn
ResNet Architecture
# Example with torchvision (requires installation)
# NOTE: the example below is a bare string literal, not executed code, so
# this cell runs even in an environment without torch/torchvision.
# NOTE(review): torchvision >= 0.13 deprecates pretrained=True in favor of
# models.resnet50(weights=ResNet50_Weights.DEFAULT) -- confirm before running live.
'''
import torch
import torchvision.models as models
from torchvision import transforms
from PIL import Image
# Load pre-trained ResNet-50
model = models.resnet50(pretrained=True)
model.eval()
# Preprocessing
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Classify image
image = Image.open("path/to/image.jpg")
input_tensor = preprocess(image)
input_batch = input_tensor.unsqueeze(0) # Add batch dimension
with torch.no_grad():
output = model(input_batch)
# Get predictions
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):
print(f"{top5_catid[i]}: {top5_prob[i].item():.4f}")
'''
print("ResNet-50 classification example (commented - requires torch)")
Building a Custom Classifier
import numpy as np
from typing import List, Dict, Tuple
from dataclasses import dataclass
@dataclass
class Prediction:
    """A single classification result for one image."""
    label: str  # human-readable class name (from the model's class list)
    confidence: float  # softmax probability assigned to this class
    class_id: int  # index of the class in the model's output vector
class ImageClassifier:
    """Wrapper for image classification models.

    This is a simulation: ``predict`` draws random logits instead of running
    a real model, so individual predictions are not deterministic.
    """

    def __init__(self, model_name: str = "resnet50", num_classes: int = 1000):
        self.model_name = model_name
        self.num_classes = num_classes
        self.class_names = self._load_class_names()
        print(f"Initialized {model_name} classifier with {num_classes} classes")

    def _load_class_names(self) -> List[str]:
        """Load ImageNet class names (placeholder labels here)."""
        # Simplified - in production, load actual ImageNet labels
        return [f"class_{i}" for i in range(self.num_classes)]

    def preprocess(self, image_array: np.ndarray) -> np.ndarray:
        """Preprocess image for the model (identity pass-through here).

        In production: resize, normalize, convert to tensor.
        """
        return image_array

    def predict(self, image: np.ndarray, top_k: int = 5) -> List[Prediction]:
        """Return the top-k predictions, highest confidence first.

        ``top_k`` is clamped to ``num_classes`` so oversized values are safe.
        """
        top_k = min(top_k, self.num_classes)
        # Simulated inference - in production this is the model forward pass.
        logits = np.random.randn(self.num_classes)
        # Numerically stable softmax (shift by the max before exponentiating).
        exp_logits = np.exp(logits - np.max(logits))
        probabilities = exp_logits / exp_logits.sum()
        # Indices of the k largest probabilities, in descending order.
        top_k_indices = np.argsort(probabilities)[-top_k:][::-1]
        return [
            Prediction(
                label=self.class_names[idx],
                # Cast np.float64 -> builtin float to match the dataclass
                # annotation (Prediction.confidence: float).
                confidence=float(probabilities[idx]),
                class_id=int(idx),
            )
            for idx in top_k_indices
        ]

    def batch_predict(self, images: List[np.ndarray]) -> List[List[Prediction]]:
        """Predict on a batch of images (sequentially, with default top_k)."""
        return [self.predict(img) for img in images]
# Smoke-test the classifier on a random array shaped like a 224x224 RGB image.
classifier = ImageClassifier()
dummy_image = np.random.rand(224, 224, 3)
predictions = classifier.predict(dummy_image, top_k=3)
print("\nTop 3 predictions:")
for prediction in predictions:
    print(f" {prediction.label}: {prediction.confidence:.2%}")
Transfer Learning Pipeline
# Transfer learning example with PyTorch (commented)
# NOTE: kept as a bare (unassigned) string literal so this cell still runs
# in an environment without PyTorch installed.
'''
import torch
import torch.nn as nn
import torchvision.models as models
class TransferLearningClassifier:
"""Fine-tune pre-trained model for custom task"""
def __init__(self, num_classes: int, freeze_base: bool = True):
# Load pre-trained ResNet
self.model = models.resnet50(pretrained=True)
# Freeze base layers
if freeze_base:
for param in self.model.parameters():
param.requires_grad = False
# Replace final layer
num_ftrs = self.model.fc.in_features
self.model.fc = nn.Linear(num_ftrs, num_classes)
self.criterion = nn.CrossEntropyLoss()
self.optimizer = torch.optim.Adam(self.model.fc.parameters(), lr=0.001)
def train_epoch(self, train_loader):
"""Train for one epoch"""
self.model.train()
total_loss = 0
for images, labels in train_loader:
self.optimizer.zero_grad()
outputs = self.model(images)
loss = self.criterion(outputs, labels)
loss.backward()
self.optimizer.step()
total_loss += loss.item()
return total_loss / len(train_loader)
def evaluate(self, val_loader):
"""Evaluate on validation set"""
self.model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in val_loader:
outputs = self.model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
return accuracy
# Usage:
# classifier = TransferLearningClassifier(num_classes=10)
# for epoch in range(num_epochs):
# train_loss = classifier.train_epoch(train_loader)
# val_acc = classifier.evaluate(val_loader)
# print(f"Epoch {epoch}: Loss={train_loss:.4f}, Acc={val_acc:.2f}%")
'''
# Summary printed for the reader; mirrors the commented example above.
print("Transfer learning example (commented - requires PyTorch)")
print("\nKey steps:")
print("1. Load pre-trained model (ResNet, EfficientNet, etc.)")
print("2. Freeze base layers (optional)")
print("3. Replace final classification layer")
print("4. Train on your custom dataset")
print("5. Fine-tune (unfreeze layers gradually)")
Data Augmentation
# Data augmentation techniques
# NOTE: the torchvision transform pipelines below are kept as a bare string
# literal so this cell runs without torchvision installed.
'''
from torchvision import transforms
# Training augmentation
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.RandomRotation(15),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Validation augmentation (less aggressive)
val_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
'''
# Catalogue of common augmentation techniques, grouped by category.
augmentation_techniques = {
    "Geometric": [
        "Random crop",
        "Horizontal flip",
        "Rotation (5-15°)",  # encoding fixed: degree sign was mojibake ("Β°")
        "Scaling",
        "Affine transforms",
    ],
    "Color": [
        "Brightness jitter",
        "Contrast adjustment",
        "Saturation changes",
        "Hue shifts",
        "Grayscale conversion",
    ],
    "Advanced": [
        "Cutout (random patches)",
        "Mixup (blend images)",
        "CutMix",
        "AutoAugment",
        "RandAugment",
    ],
}
print("Data Augmentation Techniques:\n")
for category, techniques in augmentation_techniques.items():
    print(f"{category}:")
    for tech in techniques:
        # Bullet character was mis-encoded ("β’"); restored to "•".
        print(f" • {tech}")
    print()
Vision Transformers (ViT)
# Vision Transformer example with Hugging Face
# NOTE: kept as a bare string literal so this cell runs without transformers.
'''
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
# Load ViT model
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
# Process image
image = Image.open("path/to/image.jpg")
inputs = processor(images=image, return_tensors="pt")
# Classify
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
predicted_class = logits.argmax(-1).item()
print(f"Predicted class: {model.config.id2label[predicted_class]}")
'''
# Mis-encoded bullet characters ("β") in the printed summary are restored
# to check/cross marks below.
print("Vision Transformer (ViT) Overview:\n")
print("Architecture:")
print(" 1. Split image into patches (16x16 or 32x32)")
print(" 2. Flatten patches and add position embeddings")
print(" 3. Pass through transformer encoder")
print(" 4. Classification head on [CLS] token")
print("\nAdvantages:")
print(" ✓ Better for large datasets (100M+ images)")
print(" ✓ Captures long-range dependencies")
print(" ✓ Less inductive bias than CNNs")
print("\nDisadvantages:")
print(" ✗ Requires more data than CNNs")
print(" ✗ Higher computational cost")
print(" ✗ Needs careful training")
Model Evaluation
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import numpy as np
class ModelEvaluator:
    """Compute and report classification metrics for predicted labels."""

    def __init__(self, class_names: List[str]):
        self.class_names = class_names

    def compute_metrics(self, y_true: np.ndarray, y_pred: np.ndarray) -> Dict:
        """Return accuracy plus class-weighted precision, recall, and F1."""
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average='weighted'
        )
        return {
            "accuracy": accuracy_score(y_true, y_pred),
            "precision": precision,
            "recall": recall,
            "f1_score": f1,
        }

    def confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
        """Delegate to scikit-learn's confusion_matrix."""
        return confusion_matrix(y_true, y_pred)

    def print_report(self, y_true: np.ndarray, y_pred: np.ndarray):
        """Print a banner-formatted evaluation report to stdout."""
        scores = self.compute_metrics(y_true, y_pred)
        matrix = self.confusion_matrix(y_true, y_pred)
        banner = "=" * 60
        print("\n" + banner)
        print("MODEL EVALUATION REPORT")
        print(banner)
        print(f"\nAccuracy: {scores['accuracy']:.2%}")
        print(f"Precision: {scores['precision']:.2%}")
        print(f"Recall: {scores['recall']:.2%}")
        print(f"F1 Score: {scores['f1_score']:.2%}")
        print("\nConfusion Matrix:")
        print(matrix)
        print(banner + "\n")
# Test evaluator on a small simulated 3-class problem.
evaluator = ModelEvaluator(class_names=["cat", "dog", "bird"])
# Simulated labels: class indices 0-2; y_pred differs from y_true in
# two positions (indices 5 and 7).
y_true = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2, 0])
y_pred = np.array([0, 1, 2, 0, 1, 1, 0, 2, 2, 0])
evaluator.print_report(y_true, y_pred)
Production Deployment
# Production classifier with caching and batching
from typing import Optional
import time
class ProductionClassifier:
    """Production-ready image classifier with response caching and batching.

    NOTE(review): the cache is unbounded; a real service should bound it
    (e.g. LRU eviction) so memory cannot grow without limit.
    """

    def __init__(self, model_name: str, batch_size: int = 32):
        self.model_name = model_name
        self.batch_size = batch_size
        self.cache = {}  # hash of image bytes -> cached Prediction
        self.stats = {"requests": 0, "cache_hits": 0, "total_time": 0}

    def predict_single(self, image: np.ndarray, use_cache: bool = True) -> Prediction:
        """Predict one image, serving repeated images from the cache."""
        start_time = time.time()
        self.stats["requests"] += 1
        # Key on the raw pixel bytes so identical arrays share a cache slot.
        image_hash = hash(image.tobytes())
        if use_cache and image_hash in self.cache:
            self.stats["cache_hits"] += 1
            # Bug fix: accrue the (fast) cache-hit latency too. Previously
            # only misses updated total_time while get_stats divides by ALL
            # requests, which skewed avg_inference_time.
            self.stats["total_time"] += time.time() - start_time
            return self.cache[image_hash]
        # Simulate prediction - stands in for a real model forward pass.
        time.sleep(0.01)  # simulated inference latency
        prediction = Prediction(label="cat", confidence=0.95, class_id=0)
        if use_cache:
            self.cache[image_hash] = prediction
        self.stats["total_time"] += time.time() - start_time
        return prediction

    def predict_batch(self, images: List[np.ndarray]) -> List[Prediction]:
        """Predict a list of images, processed in batch_size chunks."""
        predictions = []
        for i in range(0, len(images), self.batch_size):
            batch = images[i:i + self.batch_size]
            predictions.extend(self.predict_single(img) for img in batch)
        return predictions

    def get_stats(self) -> Dict:
        """Return request counts, cache hit rate, and mean request latency."""
        # max(..., 1) guards against division by zero before any requests.
        denominator = max(self.stats["requests"], 1)
        return {
            "total_requests": self.stats["requests"],
            "cache_hits": self.stats["cache_hits"],
            "cache_hit_rate": self.stats["cache_hits"] / denominator,
            "avg_inference_time": self.stats["total_time"] / denominator,
        }
# Test production classifier
prod_classifier = ProductionClassifier("resnet50", batch_size=8)
# Simulate requests: 20 distinct random images, so no cache hits are expected.
dummy_images = [np.random.rand(224, 224, 3) for _ in range(20)]
predictions = prod_classifier.predict_batch(dummy_images)
# Print the aggregate stats the classifier collected during the run.
stats = prod_classifier.get_stats()
print("\nProduction Statistics:")
print(f" Total Requests: {stats['total_requests']}")
print(f" Cache Hit Rate: {stats['cache_hit_rate']:.1%}")
print(f" Avg Time: {stats['avg_inference_time']*1000:.2f}ms")
Best Practices
1. Model Selection
Small datasets (<10K images): Use transfer learning with ResNet-50
Medium datasets (10K-100K): EfficientNet or ResNet-101
Large datasets (>100K): Vision Transformers (ViT)
Mobile/edge: MobileNet or EfficientNet-Lite
2. Training Tips
Start with frozen base, then fine-tune
Use data augmentation heavily
Monitor validation accuracy, not just training
Try different learning rates (1e-4 to 1e-3)
Use early stopping
3. Performance Optimization
Batch predictions when possible
Cache common predictions
Use ONNX/TensorRT for faster inference
Quantize models for edge deployment
Profile and optimize bottlenecks
4. Common Pitfalls
❌ Not using data augmentation
❌ Training from scratch with small datasets
❌ Wrong normalization statistics
❌ Class imbalance (use weighted loss)
❌ Overfitting (add dropout, regularization)
Key Takeaways
✅ Transfer learning is almost always better than training from scratch
✅ Data augmentation is crucial for good generalization
✅ ResNet is still a great default choice for most tasks
✅ Vision Transformers excel with large datasets
✅ Always evaluate on held-out test data
✅ Production deployment requires batching and caching
Next: 02_object_detection.ipynb - Detect and locate objects