Multimodal RAG: Text + Images¶
Combining visual and textual embeddings for retrieval, ColPali, and production multimodal search pipelines.
# Install dependencies
# !pip install transformers torch pillow numpy sentence-transformers
Multimodal Document Store¶
import numpy as np
from typing import Any, Dict, List, Optional, Tuple
from dataclasses import dataclass, field
@dataclass
class MultimodalDocument:
"""Document with text and images"""
id: str
text: str
image: Optional[np.ndarray] = None
image_embedding: Optional[np.ndarray] = None
text_embedding: Optional[np.ndarray] = None
metadata: Dict = field(default_factory=dict)
def has_image(self) -> bool:
return self.image is not None
def has_embeddings(self) -> bool:
return self.image_embedding is not None or self.text_embedding is not None
class MultimodalDocumentStore:
"""Store and retrieve multimodal documents"""
def __init__(self):
self.documents: List[MultimodalDocument] = []
self.id_to_index: Dict[str, int] = {}
def add_document(self, doc: MultimodalDocument):
"""Add document to store"""
if doc.id in self.id_to_index:
# Update existing
idx = self.id_to_index[doc.id]
self.documents[idx] = doc
else:
# Add new
self.id_to_index[doc.id] = len(self.documents)
self.documents.append(doc)
def get_document(self, doc_id: str) -> Optional[MultimodalDocument]:
"""Retrieve document by ID"""
idx = self.id_to_index.get(doc_id)
return self.documents[idx] if idx is not None else None
def get_all_documents(self) -> List[MultimodalDocument]:
"""Get all documents"""
return self.documents
def filter_by_metadata(self, key: str, value: Any) -> List[MultimodalDocument]:
"""Filter documents by metadata"""
return [doc for doc in self.documents if doc.metadata.get(key) == value]
def get_stats(self) -> Dict:
"""Get store statistics"""
return {
"total_documents": len(self.documents),
"with_images": sum(1 for doc in self.documents if doc.has_image()),
"with_embeddings": sum(1 for doc in self.documents if doc.has_embeddings())
}
# Test document store
store = MultimodalDocumentStore()
# Add documents
for i in range(5):
doc = MultimodalDocument(
id=f"doc_{i}",
text=f"This is document {i} about {'medical' if i % 2 == 0 else 'nature'} imaging.",
image=np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) if i < 3 else None,
metadata={"category": "medical" if i % 2 == 0 else "nature"}
)
store.add_document(doc)
print(f"Document Store Stats: {store.get_stats()}")
print(f"\nMedical documents: {len(store.filter_by_metadata('category', 'medical'))}")
Multimodal Retriever¶
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
"""Compute cosine similarity"""
dot_product = np.dot(a, b)
norm_a = np.linalg.norm(a)
norm_b = np.linalg.norm(b)
return dot_product / (norm_a * norm_b) if norm_a * norm_b > 0 else 0.0
class MultimodalRetriever:
"""Retrieve relevant documents using text and/or images"""
def __init__(self, document_store: MultimodalDocumentStore):
self.document_store = document_store
self.embedding_dim = 512
def _embed_text(self, text: str) -> np.ndarray:
"""Generate text embedding"""
# In production: use sentence-transformers or CLIP text encoder
embedding = np.random.randn(self.embedding_dim).astype(np.float32)
return embedding / np.linalg.norm(embedding)
def _embed_image(self, image: np.ndarray) -> np.ndarray:
"""Generate image embedding"""
# In production: use CLIP image encoder
embedding = np.random.randn(self.embedding_dim).astype(np.float32)
return embedding / np.linalg.norm(embedding)
def retrieve_by_text(self, query: str, top_k: int = 5) -> List[Tuple[MultimodalDocument, float]]:
"""Retrieve documents by text query"""
query_embedding = self._embed_text(query)
# Ensure all docs have text embeddings
for doc in self.document_store.documents:
if doc.text_embedding is None:
doc.text_embedding = self._embed_text(doc.text)
# Compute similarities
results = []
for doc in self.document_store.documents:
similarity = cosine_similarity(query_embedding, doc.text_embedding)
results.append((doc, similarity))
# Sort and return top-k
results.sort(key=lambda x: x[1], reverse=True)
return results[:top_k]
def retrieve_by_image(self, query_image: np.ndarray, top_k: int = 5) -> List[Tuple[MultimodalDocument, float]]:
"""Retrieve documents by image query"""
query_embedding = self._embed_image(query_image)
# Ensure all docs with images have embeddings
for doc in self.document_store.documents:
if doc.has_image() and doc.image_embedding is None:
doc.image_embedding = self._embed_image(doc.image)
# Compute similarities (only for docs with images)
results = []
for doc in self.document_store.documents:
if doc.image_embedding is not None:
similarity = cosine_similarity(query_embedding, doc.image_embedding)
results.append((doc, similarity))
results.sort(key=lambda x: x[1], reverse=True)
return results[:top_k]
def hybrid_retrieve(self,
text_query: Optional[str] = None,
image_query: Optional[np.ndarray] = None,
text_weight: float = 0.5,
top_k: int = 5) -> List[Tuple[MultimodalDocument, float]]:
"""Retrieve using both text and image (hybrid)"""
if text_query is None and image_query is None:
raise ValueError("Must provide at least one query type")
# Get results from each modality
text_results = {} if text_query is None else {
doc.id: score for doc, score in self.retrieve_by_text(text_query, top_k=len(self.document_store.documents))
}
image_results = {} if image_query is None else {
doc.id: score for doc, score in self.retrieve_by_image(image_query, top_k=len(self.document_store.documents))
}
# Combine scores
combined_scores = {}
all_doc_ids = set(text_results.keys()) | set(image_results.keys())
for doc_id in all_doc_ids:
text_score = text_results.get(doc_id, 0)
image_score = image_results.get(doc_id, 0)
combined_scores[doc_id] = text_weight * text_score + (1 - text_weight) * image_score
# Sort and return
sorted_ids = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
results = [
(self.document_store.get_document(doc_id), score)
for doc_id, score in sorted_ids
]
return results
# Test retriever
retriever = MultimodalRetriever(store)
# Text retrieval
print("Text Retrieval: 'medical imaging'")
results = retriever.retrieve_by_text("medical imaging", top_k=3)
for doc, score in results:
print(f" {doc.id} ({doc.metadata['category']}): {score:.3f}")
# Image retrieval
print("\nImage Retrieval:")
query_img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
results = retriever.retrieve_by_image(query_img, top_k=3)
for doc, score in results:
print(f" {doc.id} ({doc.metadata['category']}): {score:.3f}")
# Hybrid retrieval
print("\nHybrid Retrieval (text + image):")
results = retriever.hybrid_retrieve("medical", query_img, text_weight=0.7, top_k=3)
for doc, score in results:
print(f" {doc.id} ({doc.metadata['category']}): {score:.3f}")
Visual Question Answering (VQA)¶
# VQA with Vision-Language Models (requires installation)
'''
from transformers import ViltProcessor, ViltForQuestionAnswering
from PIL import Image
# Load ViLT model (Vision-and-Language Transformer)
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
# Load image
image = Image.open("image.jpg")
question = "What is in this image?"
# Prepare inputs
inputs = processor(image, question, return_tensors="pt")
# Generate answer
outputs = model(**inputs)
logits = outputs.logits
idx = logits.argmax(-1).item()
answer = model.config.id2label[idx]
print(f"Q: {question}")
print(f"A: {answer}")
# Alternative: BLIP-2 for open-ended VQA
from transformers import Blip2Processor, Blip2ForConditionalGeneration
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
inputs = processor(image, question, return_tensors="pt")
outputs = model.generate(**inputs)
answer = processor.decode(outputs[0], skip_special_tokens=True)
print(f"A (BLIP-2): {answer}")
'''
class VisualQA:
"""Visual Question Answering system"""
def __init__(self, model_name: str = "vilt-vqa"):
self.model_name = model_name
def answer(self, image: np.ndarray, question: str) -> str:
"""Answer question about image"""
# In production: use actual VQA model (ViLT, BLIP-2, LLaVA)
# Simulate answer
answers = [
"This image shows a medical scan.",
"There is a cat in the image.",
"The image contains natural scenery.",
"It's a photograph taken outdoors."
]
return np.random.choice(answers)
def batch_answer(self, images: List[np.ndarray], questions: List[str]) -> List[str]:
"""Answer multiple questions"""
return [self.answer(img, q) for img, q in zip(images, questions)]
# Test VQA
vqa = VisualQA()
image = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
questions = [
"What is in this image?",
"What color is the main object?",
"Is this an indoor or outdoor scene?"
]
for question in questions:
answer = vqa.answer(image, question)
print(f"Q: {question}")
print(f"A: {answer}\n")
Multimodal RAG Pipeline¶
@dataclass
class RAGResponse:
"""Multimodal RAG response"""
answer: str
retrieved_docs: List[MultimodalDocument]
confidence: float
metadata: Dict = field(default_factory=dict)
class MultimodalRAG:
"""Complete multimodal RAG system"""
def __init__(self,
document_store: MultimodalDocumentStore,
retriever: MultimodalRetriever,
vqa: VisualQA):
self.document_store = document_store
self.retriever = retriever
self.vqa = vqa
def query(self,
question: str,
image: Optional[np.ndarray] = None,
top_k: int = 3) -> RAGResponse:
"""Answer question with retrieval augmentation"""
# 1. Retrieve relevant documents
if image is not None:
retrieved = self.retriever.hybrid_retrieve(
text_query=question,
image_query=image,
top_k=top_k
)
else:
retrieved = self.retriever.retrieve_by_text(question, top_k=top_k)
docs = [doc for doc, _ in retrieved]
scores = [score for _, score in retrieved]
# 2. Build context from retrieved docs
context_parts = []
for doc in docs:
context_parts.append(f"Document {doc.id}: {doc.text}")
context = "\n".join(context_parts)
# 3. Generate answer (using VQA if image provided)
if image is not None and any(doc.has_image() for doc in docs):
# Use VQA on most relevant image
image_doc = next((doc for doc in docs if doc.has_image()), None)
if image_doc:
answer = self.vqa.answer(image_doc.image, question)
else:
answer = self._text_only_answer(question, context)
else:
answer = self._text_only_answer(question, context)
# 4. Compute confidence (avg retrieval score)
confidence = np.mean(scores) if scores else 0.0
return RAGResponse(
answer=answer,
retrieved_docs=docs,
confidence=confidence,
metadata={
"num_retrieved": len(docs),
"retrieval_scores": scores
}
)
def _text_only_answer(self, question: str, context: str) -> str:
"""Generate text-only answer"""
# In production: use LLM with context
# answer = llm(f"Context: {context}\n\nQuestion: {question}\n\nAnswer:")
return f"Based on the retrieved documents, the answer relates to: {context[:100]}..."
# Test multimodal RAG
rag = MultimodalRAG(store, retriever, vqa)
# Text-only query
print("Text Query:")
response = rag.query("Tell me about medical imaging", top_k=2)
print(f"Answer: {response.answer}")
print(f"Confidence: {response.confidence:.3f}")
print(f"Retrieved {len(response.retrieved_docs)} documents")
# Multimodal query
print("\nMultimodal Query (text + image):")
query_img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
response = rag.query("What is this image showing?", image=query_img, top_k=2)
print(f"Answer: {response.answer}")
print(f"Confidence: {response.confidence:.3f}")
Production Deployment¶
import time
from collections import deque
class ProductionMultimodalRAG:
"""Production-ready multimodal RAG with caching and monitoring"""
def __init__(self, rag: MultimodalRAG):
self.rag = rag
self.query_cache = {}
self.stats = {
"total_queries": 0,
"cache_hits": 0,
"avg_latency": 0,
"latencies": deque(maxlen=100),
"error_count": 0
}
def query(self,
question: str,
image: Optional[np.ndarray] = None,
use_cache: bool = True) -> RAGResponse:
"""Query with caching and monitoring"""
start = time.time()
try:
# Check cache (text queries only)
if use_cache and image is None:
if question in self.query_cache:
self.stats["cache_hits"] += 1
self.stats["total_queries"] += 1
return self.query_cache[question]
# Execute query
response = self.rag.query(question, image)
# Cache result (text only)
if use_cache and image is None:
self.query_cache[question] = response
return response
except Exception as e:
self.stats["error_count"] += 1
raise e
finally:
# Update stats
latency = time.time() - start
self.stats["latencies"].append(latency)
self.stats["total_queries"] += 1
self.stats["avg_latency"] = np.mean(self.stats["latencies"])
def get_performance_stats(self) -> Dict:
"""Get performance statistics"""
total = max(self.stats["total_queries"], 1)
return {
"total_queries": self.stats["total_queries"],
"cache_hits": self.stats["cache_hits"],
"cache_hit_rate": self.stats["cache_hits"] / total,
"avg_latency_ms": self.stats["avg_latency"] * 1000,
"error_count": self.stats["error_count"],
"error_rate": self.stats["error_count"] / total
}
def clear_cache(self):
"""Clear query cache"""
self.query_cache.clear()
# Test production RAG
prod_rag = ProductionMultimodalRAG(rag)
# Run queries
queries = [
"What is medical imaging?",
"Explain nature photography",
"What is medical imaging?", # Duplicate (cache hit)
]
for query in queries:
response = prod_rag.query(query)
print(f"Q: {query}")
print(f"A: {response.answer[:80]}...\n")
# Print stats
stats = prod_rag.get_performance_stats()
print("Performance Stats:")
print(f" Total queries: {stats['total_queries']}")
print(f" Cache hits: {stats['cache_hits']}")
print(f" Cache hit rate: {stats['cache_hit_rate']:.1%}")
print(f" Avg latency: {stats['avg_latency_ms']:.2f}ms")
print(f" Error rate: {stats['error_rate']:.1%}")
Best Practices¶
1. Model Selection¶
VQA: BLIP-2, LLaVA, GPT-4 Vision
Image Encoder: CLIP, DINOv2, SAM
Text Encoder: sentence-transformers, CLIP text
LLM: GPT-4, Claude, Llama 2
2. Retrieval Strategy¶
Use hybrid retrieval (text + image) when available
Tune text/image weight based on your use case
Implement re-ranking for better relevance (see the cross-encoder sketch below)
Use metadata filtering to narrow search
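A minimal re-ranking sketch, assuming sentence-transformers is installed and the cross-encoder/ms-marco-MiniLM-L-6-v2 checkpoint is acceptable for your domain; it re-scores the text of already-retrieved candidates:
from sentence_transformers import CrossEncoder

def rerank(query: str,
           candidates: List[Tuple[MultimodalDocument, float]],
           top_k: int = 3) -> List[Tuple[MultimodalDocument, float]]:
    """Re-score retrieved documents with a cross-encoder and keep the best top_k"""
    model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    scores = model.predict([(query, doc.text) for doc, _ in candidates])
    reranked = sorted(zip([doc for doc, _ in candidates], scores),
                      key=lambda pair: pair[1], reverse=True)
    return [(doc, float(score)) for doc, score in reranked[:top_k]]

# Example: over-retrieve with the bi-encoder, then re-rank the candidates
# candidates = retriever.retrieve_by_text("medical imaging", top_k=10)
# top = rerank("medical imaging", candidates, top_k=3)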
3. Performance Optimization¶
Cache embeddings (don't recompute them)
Use vector databases (Pinecone, Weaviate, Milvus)
Implement query caching for common questions
Batch process when possible
Use approximate nearest neighbor search (see the FAISS sketch below)
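The linear scan in MultimodalRetriever is fine for a handful of documents, but at scale the similarity search should go through an index. A minimal sketch, assuming faiss-cpu is installed; a flat inner-product index is exact, and you would swap in an HNSW or IVF index for true approximate search on large corpora:
import faiss

def build_index(embeddings: np.ndarray) -> faiss.Index:
    """Inner-product index over L2-normalized embeddings (equivalent to cosine similarity)"""
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings.astype(np.float32))
    return index

# Example with the retriever's text embeddings (computed by retrieve_by_text above)
# embeddings = np.stack([doc.text_embedding for doc in store.documents])
# index = build_index(embeddings)
# query = retriever._embed_text("medical imaging").reshape(1, -1)
# scores, indices = index.search(query, 3)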
4. Quality Improvements¶
Fine-tune VQA model on domain data
Use larger context windows
Implement confidence thresholds (see the fallback sketch below)
Add fallback mechanisms
Monitor and log errors
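One way to combine confidence thresholds and fallbacks, sketched on top of the RAGResponse.confidence field defined above (the 0.3 threshold is an arbitrary placeholder and should be tuned on real retrieval scores):
def answer_with_fallback(rag: MultimodalRAG, question: str,
                         min_confidence: float = 0.3) -> str:
    """Return the RAG answer only when retrieval confidence clears a threshold"""
    response = rag.query(question)
    if response.confidence < min_confidence:
        # Fall back instead of answering from weakly related documents
        return "I could not find enough relevant material to answer that reliably."
    return response.answer

# print(answer_with_fallback(rag, "Tell me about medical imaging"))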
Use Cases¶
Medical: Radiology report generation, case retrieval
E-commerce: Visual product search, recommendations
Education: Interactive learning with diagrams
Legal: Document analysis with images
Manufacturing: Equipment manuals with photos
Real estate: Property search with images
Key Takeaways¶
Multimodal RAG combines text and vision for richer context
Hybrid retrieval (text + image) improves relevance
VQA models answer questions about images
Vision-Language Models (VLMs) understand both modalities
Cache embeddings and queries for production performance
Monitor latency, accuracy, and errors in production
Computer Vision Series Complete!¶
You've learned:
Image classification (CNNs, ViT)
Object detection (YOLO)
Multimodal embeddings (CLIP)
Image generation (Stable Diffusion)
Multimodal RAG (VQA, retrieval)
Next Steps:
Explore projects in the /projects directory
Build your own CV applications
Fine-tune models on your data
Deploy to production with monitoring