CLIP: Connecting Text and Images

OpenAI CLIP embeddings for zero-shot classification, image search, and multimodal retrieval.

# Install dependencies
# !pip install transformers torch pillow numpy scikit-learn

CLIP ArchitectureΒΆ

# CLIP with Hugging Face (requires installation)
# NOTE: the triple-quoted block below is reference code only -- it is a bare
# string expression, so it is never executed (transformers/torch may not be
# installed in this environment).
'''
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch

# Load CLIP model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load image
image = Image.open("path/to/image.jpg")

# Prepare inputs
texts = ["a cat", "a dog", "a car"]
inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)

# Get predictions
outputs = model(**inputs)
logits_per_image = outputs.logits_per_image  # Image-text similarity
probs = logits_per_image.softmax(dim=1)  # Convert to probabilities

# Print results
for text, prob in zip(texts, probs[0]):
    print(f"{text}: {prob:.2%}")
'''

# Summarize the skipped example and list commonly used CLIP checkpoints.
print("CLIP example (commented - requires transformers & torch)")
print("\nCLIP Models:")
print("  openai/clip-vit-base-patch32 - Base model (default)")
print("  openai/clip-vit-large-patch14 - Large model (better accuracy)")
print("  laion/CLIP-ViT-H-14-laion2B-s32B-b79K - Largest, best performance")

CLIP EmbedderΒΆ

import numpy as np
from typing import List, Union
from dataclasses import dataclass

@dataclass
class Embedding:
    """A single embedding vector plus provenance information.

    Attributes:
        vector: the embedding itself (1-D float array).
        source: where it came from -- 'image' or 'text'.
        metadata: optional free-form info about the embedded input.
    """
    vector: np.ndarray
    source: str  # 'image' or 'text'
    metadata: Union[dict, None] = None

    def cosine_similarity(self, other: 'Embedding') -> float:
        """Return the cosine similarity with *other*, in [-1, 1].

        Returns 0.0 when either vector has zero norm, where the cosine
        is mathematically undefined.
        """
        denom = np.linalg.norm(self.vector) * np.linalg.norm(other.vector)
        if denom == 0:
            return 0.0
        # float() so callers get a plain Python float, not a numpy scalar.
        return float(np.dot(self.vector, other.vector) / denom)

class CLIPEmbedder:
    """CLIP embedding wrapper"""
    
    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        self.model_name = model_name
        self.embedding_dim = 512  # CLIP base dimension
    
    def embed_image(self, image: np.ndarray, normalize: bool = True) -> Embedding:
        """Generate image embedding"""
        # In production: use actual CLIP model
        # embedding = model.encode_image(image)
        
        # Simulate embedding
        embedding = np.random.randn(self.embedding_dim).astype(np.float32)
        
        if normalize:
            embedding = embedding / np.linalg.norm(embedding)
        
        return Embedding(
            vector=embedding,
            source="image",
            metadata={"shape": image.shape if hasattr(image, 'shape') else None}
        )
    
    def embed_text(self, text: str, normalize: bool = True) -> Embedding:
        """Generate text embedding"""
        # In production: use actual CLIP model
        # embedding = model.encode_text(text)
        
        # Simulate embedding (make it similar to demonstrate)
        embedding = np.random.randn(self.embedding_dim).astype(np.float32)
        
        if normalize:
            embedding = embedding / np.linalg.norm(embedding)
        
        return Embedding(
            vector=embedding,
            source="text",
            metadata={"text": text, "length": len(text)}
        )
    
    def embed_batch_images(self, images: List[np.ndarray]) -> List[Embedding]:
        """Batch image embedding"""
        return [self.embed_image(img) for img in images]
    
    def embed_batch_texts(self, texts: List[str]) -> List[Embedding]:
        """Batch text embedding"""
        return [self.embed_text(text) for text in texts]

# Test embedder
# Smoke-test: embed a blank RGB image and a short caption, then compare them.
embedder = CLIPEmbedder()

# Create embeddings
image = np.zeros((224, 224, 3), dtype=np.uint8)  # dummy 224x224 RGB image
img_emb = embedder.embed_image(image)
text_emb = embedder.embed_text("a photo of a cat")

# Norms should print as ~1.000 because embeddings are L2-normalized by default.
print(f"Image embedding: shape={img_emb.vector.shape}, norm={np.linalg.norm(img_emb.vector):.3f}")
print(f"Text embedding:  shape={text_emb.vector.shape}, norm={np.linalg.norm(text_emb.vector):.3f}")
print(f"\nSimilarity: {img_emb.cosine_similarity(text_emb):.3f}")

Zero-Shot Classification

from typing import Tuple

class ZeroShotClassifier:
    """Zero-shot image classification: rank text labels by CLIP similarity."""

    def __init__(self, embedder: "CLIPEmbedder"):
        self.embedder = embedder

    def _rank(self, image: np.ndarray, keys: List[str], prompts: List[str]) -> List[Tuple[str, float]]:
        """Shared core for both classify flavours.

        Embeds the image and all *prompts*, applies a softmax over the cosine
        similarities, and returns (key, probability) pairs sorted by
        probability, highest first. *keys* are the labels paired with the
        returned probabilities (they may differ from the prompts).
        """
        img_emb = self.embedder.embed_image(image)
        text_embs = self.embedder.embed_batch_texts(prompts)
        sims = np.array([img_emb.cosine_similarity(t) for t in text_embs])
        # Subtract the max before exponentiating for numerical stability.
        exp_sims = np.exp(sims - np.max(sims))
        probs = exp_sims / np.sum(exp_sims)
        return sorted(zip(keys, probs), key=lambda pair: pair[1], reverse=True)

    def classify(self, image: np.ndarray, labels: List[str]) -> List[Tuple[str, float]]:
        """Classify image with text labels (no training needed!)."""
        # Wrap each bare label in the standard CLIP prompt template.
        prompts = [f"a photo of a {label}" for label in labels]
        return self._rank(image, labels, prompts)

    def classify_with_custom_prompts(self, image: np.ndarray, prompts: List[str]) -> List[Tuple[str, float]]:
        """Classify with caller-supplied text prompts, used verbatim."""
        return self._rank(image, prompts, prompts)

# Test zero-shot classifier
classifier = ZeroShotClassifier(embedder)

# Classify with simple labels
image = np.zeros((224, 224, 3), dtype=np.uint8)  # blank dummy image
labels = ["cat", "dog", "car", "airplane", "person"]
results = classifier.classify(image, labels)

# Probabilities are softmax outputs, so they sum to 1.0 across the labels.
print("Zero-Shot Classification Results:")
for label, prob in results:
    print(f"  {label}: {prob:.2%}")

# Classify with custom prompts
# Richer prompts than bare labels; results come back sorted by probability.
print("\nCustom Prompt Classification:")
custom_prompts = [
    "a cute cat sleeping",
    "a dog running in a park",
    "a red sports car",
    "an airplane flying in the sky"
]
custom_results = classifier.classify_with_custom_prompts(image, custom_prompts)
for prompt, prob in custom_results:
    print(f"  {prompt}: {prob:.2%}")

Visual Search Engine

from typing import Optional

@dataclass
class ImageEntry:
    """One image record in the visual search index."""
    # Unique identifier within the index.
    id: str
    # Precomputed embedding of the image.
    embedding: Embedding
    # Filesystem (or logical) path to the image file.
    path: str
    # Optional caller-supplied info (e.g. category tags).
    metadata: Optional[dict] = None

class VisualSearchEngine:
    """Search an in-memory image index by text query or by example image."""

    def __init__(self, embedder: "CLIPEmbedder"):
        self.embedder = embedder
        self.image_index: List["ImageEntry"] = []

    def add_image(self, image: np.ndarray, image_id: str, path: str, metadata: Optional[dict] = None):
        """Embed *image* and append it to the search index."""
        embedding = self.embedder.embed_image(image)
        self.image_index.append(ImageEntry(
            id=image_id,
            embedding=embedding,
            path=path,
            metadata=metadata or {}
        ))

    def _rank(self, query_emb, top_k: int):
        """Score every indexed entry against *query_emb* and return the
        top_k (entry, similarity) pairs, best first. Shared by both
        search methods (previously duplicated in each)."""
        scored = [
            (entry, query_emb.cosine_similarity(entry.embedding))
            for entry in self.image_index
        ]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return scored[:top_k]

    def search_by_text(self, query: str, top_k: int = 5) -> List[Tuple["ImageEntry", float]]:
        """Search indexed images using a natural-language text query."""
        return self._rank(self.embedder.embed_text(query), top_k)

    def search_by_image(self, query_image: np.ndarray, top_k: int = 5) -> List[Tuple["ImageEntry", float]]:
        """Search indexed images using another image as the query."""
        return self._rank(self.embedder.embed_image(query_image), top_k)

    def get_stats(self) -> dict:
        """Index statistics: number of images and embedding dimensionality."""
        return {
            "total_images": len(self.image_index),
            "embedding_dim": self.embedder.embedding_dim
        }

# Test visual search engine
search_engine = VisualSearchEngine(embedder)

# Index some images
# Ten random RGB images, cycling through three fake categories for metadata.
for i in range(10):
    image = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
    search_engine.add_image(
        image,
        image_id=f"img_{i:03d}",
        path=f"/images/sample_{i:03d}.jpg",
        metadata={"category": ["cats", "dogs", "cars"][i % 3]}
    )

# Search by text
# Scores come from simulated embeddings, so the ranking is illustrative only.
print("Search by text: 'a cute cat'")
results = search_engine.search_by_text("a cute cat", top_k=3)
for entry, score in results:
    print(f"  {entry.id} ({entry.metadata['category']}): {score:.3f}")

# Search by image
print("\nSearch by image:")
query_img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
results = search_engine.search_by_image(query_img, top_k=3)
for entry, score in results:
    print(f"  {entry.id} ({entry.metadata['category']}): {score:.3f}")

# Stats
print(f"\nIndex stats: {search_engine.get_stats()}")

Production Search System

import hashlib
import json
import time
from pathlib import Path

class ProductionSearchEngine:
    """Production-ready visual search with on-disk result caching."""

    def __init__(self, embedder: "CLIPEmbedder", cache_dir: str = "./cache"):
        self.embedder = embedder
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(exist_ok=True)
        self.search_engine = VisualSearchEngine(embedder)
        # avg_search_time is a running mean over ALL searches (hits + misses).
        self.stats = {
            "total_searches": 0,
            "cache_hits": 0,
            "avg_search_time": 0
        }

    def _get_cache_key(self, query: str) -> str:
        """Stable cache key for a query (md5 is fine: not security-sensitive)."""
        return hashlib.md5(query.encode()).hexdigest()

    def _load_from_cache(self, cache_key: str) -> Optional[List]:
        """Return cached results for *cache_key*, or None on a miss."""
        cache_file = self.cache_dir / f"{cache_key}.json"
        if cache_file.exists():
            with open(cache_file, 'r', encoding="utf-8") as f:
                return json.load(f)
        return None

    def _save_to_cache(self, cache_key: str, results: List):
        """Persist (entry, score) results as JSON in the cache directory."""
        cache_file = self.cache_dir / f"{cache_key}.json"
        # Convert results to a serializable format.
        serializable = [
            {"id": entry.id, "score": float(score), "path": entry.path}
            for entry, score in results
        ]
        with open(cache_file, 'w', encoding="utf-8") as f:
            json.dump(serializable, f)

    def _record_search(self, elapsed: float):
        """Fold one search's wall-clock time into the running statistics."""
        self.stats["total_searches"] += 1
        n = self.stats["total_searches"]
        # Incremental running mean: avg += (x - avg) / n.
        self.stats["avg_search_time"] += (elapsed - self.stats["avg_search_time"]) / n

    def search(self, query: str, top_k: int = 5, use_cache: bool = True) -> List:
        """Search with caching; returns a list of {id, score, path} dicts.

        BUGFIX: cache hits are now timed and recorded through the same
        _record_search path as misses. Previously hits incremented
        total_searches but skipped the average update, so later misses
        folded hit counts into the mean as if each hit had taken the
        then-current average time.
        """
        start = time.time()
        cache_key = self._get_cache_key(query)

        # Try cache first.
        if use_cache:
            cached = self._load_from_cache(cache_key)
            if cached is not None:
                self.stats["cache_hits"] += 1
                self._record_search(time.time() - start)
                return cached[:top_k]

        # Perform the actual search.
        results = self.search_engine.search_by_text(query, top_k)

        # Cache results for subsequent identical queries.
        if use_cache:
            self._save_to_cache(cache_key, results)

        self._record_search(time.time() - start)

        # Convert for serialization (same shape as the cached payload).
        return [
            {"id": entry.id, "score": float(score), "path": entry.path}
            for entry, score in results
        ]

    def get_performance_stats(self) -> dict:
        """Summarize search counts, cache hit rate, and mean latency (ms)."""
        cache_hit_rate = (
            self.stats["cache_hits"] / max(self.stats["total_searches"], 1)
        )
        return {
            "total_searches": self.stats["total_searches"],
            "cache_hits": self.stats["cache_hits"],
            "cache_hit_rate": cache_hit_rate,
            "avg_search_time_ms": self.stats["avg_search_time"] * 1000
        }

# Test production search
prod_search = ProductionSearchEngine(embedder)

# Add images to index
for i in range(10):
    image = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
    prod_search.search_engine.add_image(
        image, f"img_{i:03d}", f"/images/img_{i:03d}.jpg"
    )

# Search multiple times (test caching)
# The first search misses and populates the cache; the repeats should hit it.
query = "a beautiful sunset"
for i in range(3):
    results = prod_search.search(query, top_k=3)
    print(f"Search {i+1}: Found {len(results)} results")

# Print stats
stats = prod_search.get_performance_stats()
print(f"\nPerformance Stats:")
print(f"  Total searches: {stats['total_searches']}")
print(f"  Cache hits: {stats['cache_hits']}")
print(f"  Cache hit rate: {stats['cache_hit_rate']:.1%}")
print(f"  Avg search time: {stats['avg_search_time_ms']:.2f}ms")

Best PracticesΒΆ

1. Prompt EngineeringΒΆ

  • Use templates: β€œa photo of a {object}”

  • Be specific: β€œa close-up photo of a red car” > β€œcar”

  • Try multiple prompts and average results

  • Include context: β€œa cat in a living room” vs β€œa cat”

2. Model Selection

  • ViT-B/32: Fastest, good for most tasks

  • ViT-L/14: Better accuracy, slower

  • ViT-H/14: Best performance, requires more resources

  • Fine-tune on domain-specific data for best results

3. OptimizationΒΆ

  • Normalize embeddings (L2 norm)

  • Use batch processing for multiple images/texts

  • Cache embeddings (they don’t change)

  • Use approximate nearest neighbor search (FAISS, Annoy)

4. Performance Tips

  • Precompute and store image embeddings

  • Use smaller models for real-time applications

  • Convert to ONNX for faster inference

  • Implement vector database (Pinecone, Weaviate, Milvus)

Use CasesΒΆ

  • E-commerce: Search products with text (β€œred shoes”) or images

  • Content moderation: Detect inappropriate content with text queries

  • Medical imaging: Find similar cases with descriptions

  • Fashion: β€œShop the look” - find similar clothing

  • Art: Search artwork by style, mood, or description

Key TakeawaysΒΆ

βœ… CLIP learns vision-language joint embedding space

βœ… Zero-shot classification works without training

βœ… Search images with text or images (multimodal retrieval)

βœ… Cosine similarity measures embedding closeness

βœ… Normalize embeddings for better similarity scores

βœ… Cache embeddings for production performance

Next: 04_stable_diffusion.ipynb - Generate images from text