CLIP: Connecting Text and Images
OpenAI CLIP embeddings for zero-shot classification, image search, and multimodal retrieval.
# Install dependencies
# !pip install transformers torch pillow numpy scikit-learn
CLIP Architecture
# CLIP with Hugging Face (requires installation)
# NOTE: the triple-quoted block below is real example code kept as an unused
# string literal so this cell runs even without transformers/torch installed.
'''
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch
# Load CLIP model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Load image
image = Image.open("path/to/image.jpg")
# Prepare inputs
texts = ["a cat", "a dog", "a car"]
inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
# Get predictions
outputs = model(**inputs)
logits_per_image = outputs.logits_per_image # Image-text similarity
probs = logits_per_image.softmax(dim=1) # Convert to probabilities
# Print results
for text, prob in zip(texts, probs[0]):
print(f"{text}: {prob:.2%}")
'''
# Quick reference: commonly-used CLIP checkpoints, smallest to largest.
print("CLIP example (commented - requires transformers & torch)")
print("\nCLIP Models:")
print(" openai/clip-vit-base-patch32 - Base model (default)")
print(" openai/clip-vit-large-patch14 - Large model (better accuracy)")
print(" laion/CLIP-ViT-H-14-laion2B-s32B-b79K - Largest, best performance")
CLIP Embedder
import numpy as np
from typing import List, Union
from dataclasses import dataclass
@dataclass
class Embedding:
    """A CLIP-style embedding vector plus provenance information."""
    vector: np.ndarray  # the embedding itself (1-D float array)
    source: str  # 'image' or 'text'
    metadata: Union[dict, None] = None  # optional extra info (text, image shape, ...)

    def cosine_similarity(self, other: 'Embedding') -> float:
        """Return the cosine similarity between this embedding and *other*.

        Returns 0.0 when either vector has zero norm, avoiding a
        division-by-zero.
        """
        dot_product = float(np.dot(self.vector, other.vector))
        norm_a = np.linalg.norm(self.vector)
        norm_b = np.linalg.norm(other.vector)
        denom = norm_a * norm_b  # hoisted: was computed twice
        return dot_product / denom if denom > 0 else 0.0
class CLIPEmbedder:
    """Produce CLIP-style embeddings for images and text.

    NOTE: vectors are simulated with random draws; a real deployment would
    call a CLIP model's encode_image / encode_text instead.
    """

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        self.model_name = model_name
        self.embedding_dim = 512  # CLIP base dimension

    def embed_image(self, image: np.ndarray, normalize: bool = True) -> Embedding:
        """Embed one image; L2-normalizes the vector unless told otherwise."""
        # Placeholder for: embedding = model.encode_image(image)
        vec = np.random.randn(self.embedding_dim).astype(np.float32)
        if normalize:
            vec = vec / np.linalg.norm(vec)
        meta = {"shape": getattr(image, "shape", None)}
        return Embedding(vector=vec, source="image", metadata=meta)

    def embed_text(self, text: str, normalize: bool = True) -> Embedding:
        """Embed one text string; L2-normalizes the vector unless told otherwise."""
        # Placeholder for: embedding = model.encode_text(text)
        vec = np.random.randn(self.embedding_dim).astype(np.float32)
        if normalize:
            vec = vec / np.linalg.norm(vec)
        meta = {"text": text, "length": len(text)}
        return Embedding(vector=vec, source="text", metadata=meta)

    def embed_batch_images(self, images: List[np.ndarray]) -> List[Embedding]:
        """Embed each image in *images*, preserving order."""
        return list(map(self.embed_image, images))

    def embed_batch_texts(self, texts: List[str]) -> List[Embedding]:
        """Embed each string in *texts*, preserving order."""
        return list(map(self.embed_text, texts))
# Test embedder
embedder = CLIPEmbedder()
# Create embeddings
# A dummy black 224x224 RGB image (CLIP's usual input resolution).
image = np.zeros((224, 224, 3), dtype=np.uint8)
img_emb = embedder.embed_image(image)
text_emb = embedder.embed_text("a photo of a cat")
# Both vectors are L2-normalized, so the printed norms should be ~1.000.
print(f"Image embedding: shape={img_emb.vector.shape}, norm={np.linalg.norm(img_emb.vector):.3f}")
print(f"Text embedding: shape={text_emb.vector.shape}, norm={np.linalg.norm(text_emb.vector):.3f}")
# NOTE: embeddings are simulated (random), so this similarity is arbitrary.
print(f"\nSimilarity: {img_emb.cosine_similarity(text_emb):.3f}")
Zero-Shot Classification
from typing import Tuple
class ZeroShotClassifier:
    """Zero-shot image classification with CLIP embeddings.

    Labels are scored by cosine similarity between the image embedding and
    each label's text embedding, then softmax-normalized into probabilities.
    No task-specific training is required.
    """

    def __init__(self, embedder: "CLIPEmbedder"):
        # Forward-reference annotation so this class can be defined without
        # CLIPEmbedder already being in scope.
        self.embedder = embedder

    def _rank(self, image: np.ndarray, prompts: List[str],
              display: List[str] = None) -> List[Tuple[str, float]]:
        """Shared scoring path for both public classify methods.

        Scores *prompts* against *image* and returns (name, probability)
        pairs sorted by probability, descending. *display* supplies the
        names to report (defaults to the prompts themselves).
        """
        if display is None:
            display = prompts
        img_emb = self.embedder.embed_image(image)
        text_embs = self.embedder.embed_batch_texts(prompts)
        similarities = np.array(
            [img_emb.cosine_similarity(text_emb) for text_emb in text_embs]
        )
        # Softmax with the max subtracted for numerical stability.
        exp_sims = np.exp(similarities - np.max(similarities))
        probs = exp_sims / np.sum(exp_sims)
        return sorted(zip(display, probs), key=lambda x: x[1], reverse=True)

    def classify(self, image: np.ndarray, labels: List[str]) -> List[Tuple[str, float]]:
        """Classify image with text labels (no training needed!).

        Each label is wrapped in the standard "a photo of a {label}" prompt
        template; results report the original labels.
        """
        prompts = [f"a photo of a {label}" for label in labels]
        return self._rank(image, prompts, display=labels)

    def classify_with_custom_prompts(self, image: np.ndarray,
                                     prompts: List[str]) -> List[Tuple[str, float]]:
        """Classify with caller-supplied text prompts, reported verbatim."""
        return self._rank(image, prompts)
# Test zero-shot classifier
classifier = ZeroShotClassifier(embedder)
# Classify with simple labels
# Dummy black image; with simulated (random) embeddings the ranking is
# arbitrary -- this exercises the API, not model quality.
image = np.zeros((224, 224, 3), dtype=np.uint8)
labels = ["cat", "dog", "car", "airplane", "person"]
results = classifier.classify(image, labels)
print("Zero-Shot Classification Results:")
for label, prob in results:
    print(f" {label}: {prob:.2%}")
# Classify with custom prompts
# Richer prompts usually help real CLIP; these are reported verbatim.
print("\nCustom Prompt Classification:")
custom_prompts = [
    "a cute cat sleeping",
    "a dog running in a park",
    "a red sports car",
    "an airplane flying in the sky"
]
custom_results = classifier.classify_with_custom_prompts(image, custom_prompts)
for prompt, prob in custom_results:
    print(f" {prompt}: {prob:.2%}")
Visual Search Engine
from typing import Optional
@dataclass
class ImageEntry:
    """A single indexed image: id, precomputed CLIP embedding, and path."""
    id: str  # unique identifier within the index
    embedding: Embedding  # precomputed CLIP embedding of the image
    path: str  # filesystem path (or URI) of the original image
    metadata: Optional[dict] = None  # free-form extra info (category, tags, ...)
class VisualSearchEngine:
    """Search an in-memory image index by text or by example image.

    Both query modes embed the query and rank indexed images by cosine
    similarity in the shared CLIP embedding space.
    """

    def __init__(self, embedder: CLIPEmbedder):
        self.embedder = embedder
        self.image_index: List[ImageEntry] = []  # flat list; linear scan per query

    def add_image(self, image: np.ndarray, image_id: str, path: str,
                  metadata: Optional[dict] = None):
        """Embed *image* and append it to the search index."""
        embedding = self.embedder.embed_image(image)
        self.image_index.append(ImageEntry(
            id=image_id,
            embedding=embedding,
            path=path,
            metadata=metadata or {},
        ))

    def _rank(self, query_emb: Embedding, top_k: int) -> List[Tuple[ImageEntry, float]]:
        """Shared ranking path: score every entry against *query_emb*,
        best first, truncated to *top_k*. (Was duplicated in both search
        methods.)"""
        scored = [
            (entry, query_emb.cosine_similarity(entry.embedding))
            for entry in self.image_index
        ]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return scored[:top_k]

    def search_by_text(self, query: str, top_k: int = 5) -> List[Tuple[ImageEntry, float]]:
        """Search images using a natural-language text query."""
        return self._rank(self.embedder.embed_text(query), top_k)

    def search_by_image(self, query_image: np.ndarray, top_k: int = 5) -> List[Tuple[ImageEntry, float]]:
        """Search images using another image as the query."""
        return self._rank(self.embedder.embed_image(query_image), top_k)

    def get_stats(self) -> dict:
        """Return basic index statistics."""
        return {
            "total_images": len(self.image_index),
            "embedding_dim": self.embedder.embedding_dim,
        }
# Test visual search engine
search_engine = VisualSearchEngine(embedder)
# Index some images
# Ten random images, cycling through three placeholder categories.
for i in range(10):
    image = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
    search_engine.add_image(
        image,
        image_id=f"img_{i:03d}",
        path=f"/images/sample_{i:03d}.jpg",
        metadata={"category": ["cats", "dogs", "cars"][i % 3]}
    )
# Search by text
print("Search by text: 'a cute cat'")
results = search_engine.search_by_text("a cute cat", top_k=3)
for entry, score in results:
    print(f" {entry.id} ({entry.metadata['category']}): {score:.3f}")
# Search by image
print("\nSearch by image:")
query_img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
results = search_engine.search_by_image(query_img, top_k=3)
for entry, score in results:
    print(f" {entry.id} ({entry.metadata['category']}): {score:.3f}")
# Stats
print(f"\nIndex stats: {search_engine.get_stats()}")
Production Search System
import hashlib
import json
import time
from pathlib import Path
class ProductionSearchEngine:
    """Production-oriented text-to-image search with on-disk result caching.

    Wraps a VisualSearchEngine and memoizes serialized search results as
    JSON files keyed by (query, top_k), so repeated queries skip the
    embedding and ranking work.
    """

    def __init__(self, embedder: CLIPEmbedder, cache_dir: str = "./cache"):
        self.embedder = embedder
        self.cache_dir = Path(cache_dir)
        # parents=True so nested cache paths (e.g. "var/clip/cache") work too.
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.search_engine = VisualSearchEngine(embedder)
        self.stats = {
            "total_searches": 0,     # every search() call, cached or not
            "cache_hits": 0,         # calls answered from the disk cache
            "uncached_searches": 0,  # calls that actually ran a search
            "avg_search_time": 0,    # running mean (seconds) over uncached calls
        }

    def _get_cache_key(self, query: str, top_k: int = 5) -> str:
        """Build the cache key for (query, top_k).

        top_k must be part of the key: results cached for a small top_k
        would otherwise be silently truncated when a later caller asks
        for more results.
        """
        # MD5 is fine here: the hash is only a cache filename, not a
        # security boundary.
        return hashlib.md5(f"{top_k}:{query}".encode()).hexdigest()

    def _load_from_cache(self, cache_key: str) -> Optional[List]:
        """Return cached results for *cache_key*, or None on a cache miss."""
        cache_file = self.cache_dir / f"{cache_key}.json"
        if cache_file.exists():
            with open(cache_file, 'r') as f:
                return json.load(f)
        return None

    def _save_to_cache(self, cache_key: str, results: List):
        """Serialize (entry, score) pairs to JSON under *cache_key*."""
        cache_file = self.cache_dir / f"{cache_key}.json"
        # Convert results to a JSON-serializable format.
        serializable = [
            {"id": entry.id, "score": float(score), "path": entry.path}
            for entry, score in results
        ]
        with open(cache_file, 'w') as f:
            json.dump(serializable, f)

    def search(self, query: str, top_k: int = 5, use_cache: bool = True) -> List:
        """Text search returning [{"id", "score", "path"}, ...] dicts.

        Cache hits count toward total_searches but not toward the average
        search time, which tracks real (uncached) searches only.
        """
        start = time.time()
        cache_key = self._get_cache_key(query, top_k)
        # Try cache first.
        if use_cache:
            cached = self._load_from_cache(cache_key)
            if cached is not None:
                self.stats["cache_hits"] += 1
                self.stats["total_searches"] += 1
                return cached[:top_k]
        # Perform the actual search and cache the outcome.
        results = self.search_engine.search_by_text(query, top_k)
        if use_cache:
            self._save_to_cache(cache_key, results)
        # Update stats: incremental mean over uncached searches only;
        # counting cache hits in the denominator would drag it toward zero.
        search_time = time.time() - start
        self.stats["total_searches"] += 1
        self.stats["uncached_searches"] += 1
        n = self.stats["uncached_searches"]
        self.stats["avg_search_time"] += (search_time - self.stats["avg_search_time"]) / n
        # Convert for serialization.
        return [
            {"id": entry.id, "score": float(score), "path": entry.path}
            for entry, score in results
        ]

    def get_performance_stats(self) -> dict:
        """Return aggregate search/caching statistics."""
        cache_hit_rate = (
            self.stats["cache_hits"] / max(self.stats["total_searches"], 1)
        )
        return {
            "total_searches": self.stats["total_searches"],
            "cache_hits": self.stats["cache_hits"],
            "cache_hit_rate": cache_hit_rate,
            "avg_search_time_ms": self.stats["avg_search_time"] * 1000,
        }
# Test production search
prod_search = ProductionSearchEngine(embedder)
# Add images to index
for i in range(10):
    image = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
    prod_search.search_engine.add_image(
        image, f"img_{i:03d}", f"/images/img_{i:03d}.jpg"
    )
# Search multiple times (test caching)
# The first search is a cache miss; subsequent identical queries should hit.
query = "a beautiful sunset"
for i in range(3):
    results = prod_search.search(query, top_k=3)
    print(f"Search {i+1}: Found {len(results)} results")
# Print stats
stats = prod_search.get_performance_stats()
print(f"\nPerformance Stats:")
print(f" Total searches: {stats['total_searches']}")
print(f" Cache hits: {stats['cache_hits']}")
print(f" Cache hit rate: {stats['cache_hit_rate']:.1%}")
print(f" Avg search time: {stats['avg_search_time_ms']:.2f}ms")
Best Practices
1. Prompt Engineering
Use templates: "a photo of a {object}"
Be specific: "a close-up photo of a red car" > "car"
Try multiple prompts and average the results
Include context: "a cat in a living room" vs. "a cat"
2. Model Selection
ViT-B/32: Fastest, good for most tasks
ViT-L/14: Better accuracy, slower
ViT-H/14: Best performance, requires more resources
Fine-tune on domain-specific data for best results
3. Optimization
Normalize embeddings (L2 norm)
Use batch processing for multiple images/texts
Cache embeddings (they don't change)
Use approximate nearest neighbor search (FAISS, Annoy)
4. Performance Tips
Precompute and store image embeddings
Use smaller models for real-time applications
Convert to ONNX for faster inference
Use a vector database (Pinecone, Weaviate, Milvus)
Use Cases
E-commerce: Search products with text ("red shoes") or images
Content moderation: Detect inappropriate content with text queries
Medical imaging: Find similar cases from textual descriptions
Fashion: "Shop the look" - find similar clothing
Art: Search artwork by style, mood, or description
Key Takeaways
✓ CLIP learns a joint vision-language embedding space
✓ Zero-shot classification works without task-specific training
✓ Search images with text or images (multimodal retrieval)
✓ Cosine similarity measures embedding closeness
✓ Normalize embeddings for better similarity scores
✓ Cache embeddings for production performance
Next: 04_stable_diffusion.ipynb - Generate images from text