Image Classification with Deep Learning

ResNet, EfficientNet, ViT — training and fine-tuning image classifiers with PyTorch and Hugging Face.

# Install dependencies
# !pip install torch torchvision timm transformers pillow matplotlib scikit-learn

ResNet Architecture

# Example with torchvision (requires installation)
# The sample below is kept inside a string literal so this cell runs without
# torch/torchvision installed; copy it into a cell of its own to execute it.
'''
import torch
import torchvision.models as models
from torchvision import transforms
from PIL import Image

# Load pre-trained ResNet-50.
# torchvision >= 0.13 selects checkpoints through `weights`; IMAGENET1K_V1 is
# the checkpoint the deprecated `pretrained=True` flag used to load, and it
# matches the Resize(256)/CenterCrop(224) preprocessing below.
# (On torchvision < 0.13 use `models.resnet50(pretrained=True)` instead.)
weights = models.ResNet50_Weights.IMAGENET1K_V1
model = models.resnet50(weights=weights)
model.eval()

# Preprocessing
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Classify image
# Force 3-channel RGB: grayscale, palette or RGBA files would otherwise not
# match the 3-channel Normalize statistics above.
image = Image.open("path/to/image.jpg").convert("RGB")
input_tensor = preprocess(image)
input_batch = input_tensor.unsqueeze(0)  # Add batch dimension

with torch.no_grad():
    output = model(input_batch)

# Get predictions
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top5_prob, top5_catid = torch.topk(probabilities, 5)

# Map class ids to human-readable ImageNet labels shipped with the weights
categories = weights.meta["categories"]
for prob, catid in zip(top5_prob, top5_catid):
    print(f"{categories[catid.item()]}: {prob.item():.4f}")
'''

print("ResNet-50 classification example (commented - requires torch)")

Building a Custom Classifier

import numpy as np
from typing import List, Dict, Tuple
from dataclasses import dataclass

@dataclass
class Prediction:
    """A single classification result."""

    # Human-readable name of the predicted class.
    label: str
    # Probability assigned to the class.
    confidence: float
    # Index of the class in the model's output vector.
    class_id: int


class ImageClassifier:
    """Simulated wrapper around an image classification model.

    No real model is loaded: ``predict`` draws random logits, so results
    differ between calls. The class only demonstrates the surrounding
    plumbing (label lookup, softmax, top-k selection, batching).
    """

    def __init__(self, model_name: str = "resnet50", num_classes: int = 1000):
        """Store the configuration and build the label list.

        Args:
            model_name: Name of the (simulated) backbone; stored and printed.
            num_classes: Size of the output layer / number of labels.
        """
        self.model_name = model_name
        self.num_classes = num_classes
        self.class_names = self._load_class_names()
        print(f"Initialized {model_name} classifier with {num_classes} classes")

    def _load_class_names(self) -> List[str]:
        """Return placeholder labels ``class_0 .. class_{num_classes-1}``."""
        # Simplified - in production, load actual ImageNet labels
        return [f"class_{i}" for i in range(self.num_classes)]

    def preprocess(self, image_array: np.ndarray) -> np.ndarray:
        """Return ``image_array`` unchanged (placeholder for real preprocessing)."""
        # In production: resize, normalize, convert to tensor
        return image_array

    def predict(self, image: np.ndarray, top_k: int = 5) -> List[Prediction]:
        """Return the ``top_k`` most probable classes, most confident first.

        Args:
            image: Input image; ignored by the simulated inference.
            top_k: Number of predictions wanted. Values larger than
                ``num_classes`` are clamped; ``0`` yields an empty list.

        Returns:
            Up to ``top_k`` predictions sorted by descending confidence.

        Raises:
            ValueError: If ``top_k`` is negative.
        """
        if top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")

        # Guard explicitly: slicing with [-0:] would return *every* class.
        k = min(top_k, self.num_classes)
        if k == 0:
            return []

        # Simulate model inference
        # In production: actual model forward pass
        logits = np.random.randn(self.num_classes)

        # Softmax; subtracting the max keeps exp() from overflowing
        exp_logits = np.exp(logits - np.max(logits))
        probabilities = exp_logits / exp_logits.sum()

        # Indices sorted by descending probability, truncated to k
        top_k_indices = np.argsort(probabilities)[::-1][:k]

        return [
            Prediction(
                label=self.class_names[idx],
                # Convert the NumPy scalar so the field really holds a float
                confidence=float(probabilities[idx]),
                class_id=int(idx),
            )
            for idx in top_k_indices
        ]

    def batch_predict(
        self, images: List[np.ndarray], top_k: int = 5
    ) -> List[List[Prediction]]:
        """Run ``predict`` on every image and collect the results in order.

        Args:
            images: Images to classify.
            top_k: Number of predictions per image, forwarded to ``predict``.
        """
        return [self.predict(img, top_k=top_k) for img in images]

# Test classifier
# Smoke test: build the classifier with its default arguments and print the
# three most confident predictions for one synthetic input.
classifier = ImageClassifier()

# Dummy image
# Random 224x224x3 array standing in for a real image.
dummy_image = np.random.rand(224, 224, 3)
predictions = classifier.predict(dummy_image, top_k=3)

print("\nTop 3 predictions:")
for pred in predictions:
    # Confidence is shown as a percentage with two decimals.
    print(f"  {pred.label}: {pred.confidence:.2%}")

Transfer Learning Pipeline

# Transfer learning example with PyTorch (commented)
# The sample below is kept inside a string literal so this cell runs without
# PyTorch installed; copy it into a cell of its own to execute it.
'''
import torch
import torch.nn as nn
import torchvision.models as models

class TransferLearningClassifier:
    """Fine-tune pre-trained model for custom task"""

    def __init__(self, num_classes: int, freeze_base: bool = True):
        # Load pre-trained ResNet. IMAGENET1K_V1 is the checkpoint the
        # deprecated `pretrained=True` flag used to load (torchvision >= 0.13).
        self.model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

        # Freeze base layers
        if freeze_base:
            for param in self.model.parameters():
                param.requires_grad = False

        # Replace final layer (a freshly created layer is trainable by default)
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, num_classes)

        self.criterion = nn.CrossEntropyLoss()
        # Optimize every trainable parameter: just the new head when the base
        # is frozen, the whole network when freeze_base=False. Passing only
        # `self.model.fc.parameters()` would leave an unfrozen base untrained.
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = torch.optim.Adam(trainable_params, lr=0.001)

    def train_epoch(self, train_loader):
        """Train for one epoch and return the mean batch loss"""
        self.model.train()
        total_loss = 0.0

        for images, labels in train_loader:
            self.optimizer.zero_grad()
            outputs = self.model(images)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()

        return total_loss / len(train_loader)

    def evaluate(self, val_loader):
        """Evaluate on validation set and return accuracy in percent"""
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in val_loader:
                outputs = self.model(images)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # An empty loader yields 0.0 instead of a ZeroDivisionError
        return 100 * correct / total if total else 0.0

# Usage:
# classifier = TransferLearningClassifier(num_classes=10)
# for epoch in range(num_epochs):
#     train_loss = classifier.train_epoch(train_loader)
#     val_acc = classifier.evaluate(val_loader)
#     print(f"Epoch {epoch}: Loss={train_loss:.4f}, Acc={val_acc:.2f}%")
'''

print("Transfer learning example (commented - requires PyTorch)")
print("\nKey steps:")
print("1. Load pre-trained model (ResNet, EfficientNet, etc.)")
print("2. Freeze base layers (optional)")
print("3. Replace final classification layer")
print("4. Train on your custom dataset")
print("5. Fine-tune (unfreeze layers gradually)")

Data Augmentation

# Data augmentation techniques
# The torchvision sample is kept inside a string literal so this cell runs
# without torchvision installed.
'''
from torchvision import transforms

# Training augmentation
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Validation preprocessing (deterministic - no random augmentation)
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
'''

# Catalogue of common augmentation techniques grouped by category; printed below.
augmentation_techniques = {
    "Geometric": [
        "Random crop",
        "Horizontal flip",
        "Rotation (5-15°)",
        "Scaling",
        "Affine transforms"
    ],
    "Color": [
        "Brightness jitter",
        "Contrast adjustment",
        "Saturation changes",
        "Hue shifts",
        "Grayscale conversion"
    ],
    "Advanced": [
        "Cutout (random patches)",
        "Mixup (blend images)",
        "CutMix",
        "AutoAugment",
        "RandAugment"
    ]
}

print("Data Augmentation Techniques:\n")
for category, techniques in augmentation_techniques.items():
    print(f"{category}:")
    for tech in techniques:
        print(f"  • {tech}")
    print()

Vision Transformers (ViT)

# Vision Transformer example with Hugging Face
# The sample below is kept inside a string literal so this cell runs without
# torch/transformers installed; copy it into a cell of its own to execute it.
'''
import torch  # needed for torch.no_grad() below
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image

# Load ViT model
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

# Process image (force 3-channel RGB so grayscale/RGBA files are accepted)
image = Image.open("path/to/image.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# Classify
with torch.no_grad():
    outputs = model(**inputs)

logits = outputs.logits
predicted_class = logits.argmax(-1).item()
print(f"Predicted class: {model.config.id2label[predicted_class]}")
'''

print("Vision Transformer (ViT) Overview:\n")
print("Architecture:")
print("  1. Split image into patches (16x16 or 32x32)")
print("  2. Flatten patches and add position embeddings")
print("  3. Pass through transformer encoder")
print("  4. Classification head on [CLS] token")
print("\nAdvantages:")
print("  ✓ Better for large datasets (100M+ images)")
print("  ✓ Captures long-range dependencies")
print("  ✓ Less inductive bias than CNNs")
print("\nDisadvantages:")
print("  ✗ Requires more data than CNNs")
print("  ✗ Higher computational cost")
print("  ✗ Needs careful training")

Model Evaluation

from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import numpy as np

class ModelEvaluator:
    """Summarise how well predicted labels match the ground truth."""

    def __init__(self, class_names: List[str]):
        """Keep the class names supplied by the caller."""
        self.class_names = class_names

    def compute_metrics(self, y_true: np.ndarray, y_pred: np.ndarray) -> Dict:
        """Return accuracy plus weighted-average precision, recall and F1.

        Args:
            y_true: Ground-truth labels.
            y_pred: Predicted labels.

        Returns:
            Dict with the keys ``accuracy``, ``precision``, ``recall`` and
            ``f1_score``.
        """
        overall_accuracy = accuracy_score(y_true, y_pred)
        # The fourth element of the result (support) is not reported.
        weighted = precision_recall_fscore_support(
            y_true, y_pred, average='weighted'
        )
        return dict(
            accuracy=overall_accuracy,
            precision=weighted[0],
            recall=weighted[1],
            f1_score=weighted[2],
        )

    def confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
        """Return the confusion matrix for the given labels."""
        # Inside the method body this name resolves to the module-level
        # function, not to this method.
        matrix = confusion_matrix(y_true, y_pred)
        return matrix

    def print_report(self, y_true: np.ndarray, y_pred: np.ndarray):
        """Print the metrics and the confusion matrix between two rulers."""
        scores = self.compute_metrics(y_true, y_pred)
        matrix = self.confusion_matrix(y_true, y_pred)
        ruler = "=" * 60

        print("\n" + ruler)
        print("MODEL EVALUATION REPORT")
        print(ruler)

        # (line prefix, metrics key) in display order
        rows = (
            ("\nAccuracy:  ", "accuracy"),
            ("Precision: ", "precision"),
            ("Recall:    ", "recall"),
            ("F1 Score:  ", "f1_score"),
        )
        for prefix, key in rows:
            print(f"{prefix}{scores[key]:.2%}")

        print("\nConfusion Matrix:")
        print(matrix)
        print(ruler + "\n")

# Test evaluator
# Demo: evaluate a small hand-written set of labels and print the report.
evaluator = ModelEvaluator(class_names=["cat", "dog", "bird"])

# Simulated predictions
# Ten samples over the label ids 0-2; y_pred differs from y_true at
# positions 5 and 7, so 8 of the 10 predictions are correct.
y_true = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2, 0])
y_pred = np.array([0, 1, 2, 0, 1, 1, 0, 2, 2, 0])

evaluator.print_report(y_true, y_pred)

Production Deployment

# Production classifier with caching and batching
import hashlib
import time
from typing import Optional

class ProductionClassifier:
    """Simulated production classifier with result caching and batching.

    Inference is stubbed: every cache miss sleeps briefly and returns a fixed
    ``Prediction``. The class demonstrates request accounting, content-based
    caching and chunked batch processing.
    """

    def __init__(self, model_name: str, batch_size: int = 32,
                 max_cache_size: Optional[int] = None):
        """Create an empty cache and zeroed counters.

        Args:
            model_name: Identifier of the model being served (stored only).
            batch_size: Number of images per chunk in ``predict_batch``.
            max_cache_size: Maximum number of cached predictions; once the
                limit is reached the oldest entry is evicted. ``None`` (the
                default) leaves the cache unbounded; ``0`` stores nothing.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1 or
                ``max_cache_size`` is negative.
        """
        # Fail early: a zero step breaks range() later, a negative one would
        # make predict_batch silently return nothing.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if max_cache_size is not None and max_cache_size < 0:
            raise ValueError(
                f"max_cache_size must be >= 0 or None, got {max_cache_size}"
            )

        self.model_name = model_name
        self.batch_size = batch_size
        self.max_cache_size = max_cache_size
        self.cache = {}
        self.stats = {"requests": 0, "cache_hits": 0, "total_time": 0.0}

    @staticmethod
    def _cache_key(image: np.ndarray) -> tuple:
        """Build a cache key from the array's shape, dtype and content digest."""
        # Shape and dtype are part of the key because arrays with different
        # layouts can share the same raw bytes.
        digest = hashlib.sha256(image.tobytes()).hexdigest()
        return (image.shape, image.dtype.str, digest)

    def _store(self, key: tuple, prediction: Prediction) -> None:
        """Insert a prediction, evicting the oldest entries if the cache is full."""
        if self.max_cache_size is not None:
            if self.max_cache_size == 0:
                return
            # dicts keep insertion order, so the first key is the oldest
            while len(self.cache) >= self.max_cache_size:
                del self.cache[next(iter(self.cache))]
        self.cache[key] = prediction

    def predict_single(self, image: np.ndarray, use_cache: bool = True) -> Prediction:
        """Classify one image, serving repeated inputs from the cache.

        Args:
            image: Image to classify.
            use_cache: When False the cache is neither consulted nor updated.

        Returns:
            The cached prediction on a hit, otherwise a new (simulated) one.
        """
        # perf_counter is monotonic, unlike time.time()
        start_time = time.perf_counter()
        self.stats["requests"] += 1

        # Hash the image only when the cache is actually in use
        cache_key = self._cache_key(image) if use_cache else None
        if cache_key is not None and cache_key in self.cache:
            self.stats["cache_hits"] += 1
            return self.cache[cache_key]

        # Simulate prediction
        time.sleep(0.01)  # Simulate inference time
        prediction = Prediction(
            label="cat",
            confidence=0.95,
            class_id=0
        )

        if cache_key is not None:
            self._store(cache_key, prediction)

        # Only cache misses contribute to total_time
        self.stats["total_time"] += time.perf_counter() - start_time
        return prediction

    def predict_batch(self, images: List[np.ndarray]) -> List[Prediction]:
        """Classify ``images`` in chunks of ``batch_size``, preserving order."""
        predictions = []

        # Process in batches
        for start in range(0, len(images), self.batch_size):
            batch = images[start:start + self.batch_size]
            predictions.extend(self.predict_single(img) for img in batch)

        return predictions

    def get_stats(self) -> Dict:
        """Return request counters and derived rates.

        Returns:
            Dict with ``total_requests``, ``cache_hits``, ``cache_hit_rate``
            (hits / requests) and ``avg_inference_time`` (seconds per cache
            miss, i.e. per inference that actually ran).
        """
        requests = self.stats["requests"]
        cache_hits = self.stats["cache_hits"]
        # total_time is accumulated on misses only, so average over misses;
        # dividing by all requests would understate the inference latency.
        inferences = requests - cache_hits

        return {
            "total_requests": requests,
            "cache_hits": cache_hits,
            "cache_hit_rate": cache_hits / max(requests, 1),
            "avg_inference_time": self.stats["total_time"] / max(inferences, 1),
        }

# Test production classifier
# Demo: push 20 synthetic images through the batched path, then print the
# counters collected by the classifier.
prod_classifier = ProductionClassifier("resnet50", batch_size=8)

# Simulate requests
# Each image is an independent random 224x224x3 array.
dummy_images = [np.random.rand(224, 224, 3) for _ in range(20)]
predictions = prod_classifier.predict_batch(dummy_images)

# Print stats
stats = prod_classifier.get_stats()
print("\nProduction Statistics:")
print(f"  Total Requests: {stats['total_requests']}")
print(f"  Cache Hit Rate: {stats['cache_hit_rate']:.1%}")
# avg_inference_time is in seconds; multiply by 1000 to show milliseconds.
print(f"  Avg Time: {stats['avg_inference_time']*1000:.2f}ms")

Best Practices

1. Model Selection

  • Small datasets (<10K images): Use transfer learning with ResNet-50

  • Medium datasets (10K-100K): EfficientNet or ResNet-101

  • Large datasets (>100K): Vision Transformers (ViT)

  • Mobile/edge: MobileNet or EfficientNet-Lite

2. Training Tips

  • Start with frozen base, then fine-tune

  • Use data augmentation heavily

  • Monitor validation accuracy, not just training

  • Try different learning rates (1e-4 to 1e-3)

  • Use early stopping

3. Performance Optimization

  • Batch predictions when possible

  • Cache common predictions

  • Use ONNX/TensorRT for faster inference

  • Quantize models for edge deployment

  • Profile and optimize bottlenecks

4. Common Pitfalls

  • ❌ Not using data augmentation

  • ❌ Training from scratch with small datasets

  • ❌ Wrong normalization statistics

  • ❌ Class imbalance (use weighted loss)

  • ❌ Overfitting (add dropout, regularization)

Key Takeaways

✅ Transfer learning is almost always better than training from scratch

✅ Data augmentation is crucial for good generalization

✅ ResNet is still a great default choice for most tasks

✅ Vision Transformers excel with large datasets

✅ Always evaluate on held-out test data

✅ Production deployment requires batching and caching

Next: 02_object_detection.ipynb - Detect and locate objects