Multimodal RAG: Text + Images

Combining visual and textual embeddings for retrieval, ColPali, and production multimodal search pipelines.

# Install dependencies
# !pip install transformers torch pillow numpy sentence-transformers

Multimodal Document Store

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

@dataclass
class MultimodalDocument:
    """A document that may carry text, an image, and their embeddings."""
    id: str
    text: str
    image: Optional[np.ndarray] = None
    image_embedding: Optional[np.ndarray] = None
    text_embedding: Optional[np.ndarray] = None
    metadata: Dict = field(default_factory=dict)

    def has_image(self) -> bool:
        """True when an image payload is attached."""
        return self.image is not None

    def has_embeddings(self) -> bool:
        """True when at least one embedding (text or image) has been computed."""
        return self.text_embedding is not None or self.image_embedding is not None

class MultimodalDocumentStore:
    """In-memory store for multimodal documents with O(1) lookup by id."""

    def __init__(self):
        self.documents: List[MultimodalDocument] = []
        # Maps document id -> index in self.documents for constant-time access.
        self.id_to_index: Dict[str, int] = {}

    def add_document(self, doc: MultimodalDocument):
        """Add a document, replacing any existing document with the same id."""
        idx = self.id_to_index.get(doc.id)
        if idx is not None:
            # Replace in place so indices of other documents stay valid.
            self.documents[idx] = doc
        else:
            self.id_to_index[doc.id] = len(self.documents)
            self.documents.append(doc)

    def get_document(self, doc_id: str) -> Optional[MultimodalDocument]:
        """Return the document with the given id, or None if absent."""
        idx = self.id_to_index.get(doc_id)
        return self.documents[idx] if idx is not None else None

    def get_all_documents(self) -> List[MultimodalDocument]:
        """Return the underlying document list (not a copy)."""
        return self.documents

    def filter_by_metadata(self, key: str, value: Any) -> List[MultimodalDocument]:
        """Return documents whose metadata[key] equals value.

        Note: the annotation is typing.Any — the original used the builtin
        `any` function, which is not a type.
        """
        return [doc for doc in self.documents if doc.metadata.get(key) == value]

    def get_stats(self) -> Dict:
        """Summarize contents: document count, how many have images/embeddings."""
        return {
            "total_documents": len(self.documents),
            "with_images": sum(1 for doc in self.documents if doc.has_image()),
            "with_embeddings": sum(1 for doc in self.documents if doc.has_embeddings())
        }

# Exercise the document store with a small synthetic corpus:
# 5 docs alternating medical/nature, the first 3 carrying random images.
store = MultimodalDocumentStore()

categories = ["medical", "nature"]
for i in range(5):
    category = categories[i % 2]
    store.add_document(MultimodalDocument(
        id=f"doc_{i}",
        text=f"This is document {i} about {category} imaging.",
        image=np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) if i < 3 else None,
        metadata={"category": category},
    ))

print(f"Document Store Stats: {store.get_stats()}")
print(f"\nMedical documents: {len(store.filter_by_metadata('category', 'medical'))}")

Multimodal Retriever

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two vectors; 0.0 when either has zero norm."""
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom > 0:
        return np.dot(a, b) / denom
    return 0.0

class MultimodalRetriever:
    """Retrieve relevant documents via text queries, image queries, or both."""

    def __init__(self, document_store: MultimodalDocumentStore):
        self.document_store = document_store
        self.embedding_dim = 512  # must match the (placeholder) encoder output size

    def _embed_text(self, text: str) -> np.ndarray:
        """Unit-norm text embedding (random stand-in for a real encoder)."""
        # In production: use sentence-transformers or a CLIP text encoder.
        vec = np.random.randn(self.embedding_dim).astype(np.float32)
        return vec / np.linalg.norm(vec)

    def _embed_image(self, image: np.ndarray) -> np.ndarray:
        """Unit-norm image embedding (random stand-in for a real encoder)."""
        # In production: use a CLIP image encoder.
        vec = np.random.randn(self.embedding_dim).astype(np.float32)
        return vec / np.linalg.norm(vec)

    def retrieve_by_text(self, query: str, top_k: int = 5) -> List[Tuple[MultimodalDocument, float]]:
        """Rank all documents against a text query; return the best top_k."""
        query_vec = self._embed_text(query)
        docs = self.document_store.documents

        # Lazily fill in any missing text embeddings before scoring.
        for doc in docs:
            if doc.text_embedding is None:
                doc.text_embedding = self._embed_text(doc.text)

        scored = [(doc, cosine_similarity(query_vec, doc.text_embedding)) for doc in docs]
        return sorted(scored, key=lambda pair: pair[1], reverse=True)[:top_k]

    def retrieve_by_image(self, query_image: np.ndarray, top_k: int = 5) -> List[Tuple[MultimodalDocument, float]]:
        """Rank image-bearing documents against an image query; return top_k."""
        query_vec = self._embed_image(query_image)
        docs = self.document_store.documents

        # Lazily fill in missing image embeddings (only for docs with images).
        for doc in docs:
            if doc.has_image() and doc.image_embedding is None:
                doc.image_embedding = self._embed_image(doc.image)

        # Only documents that have an image embedding participate.
        scored = [
            (doc, cosine_similarity(query_vec, doc.image_embedding))
            for doc in docs
            if doc.image_embedding is not None
        ]
        return sorted(scored, key=lambda pair: pair[1], reverse=True)[:top_k]

    def hybrid_retrieve(self,
                        text_query: Optional[str] = None,
                        image_query: Optional[np.ndarray] = None,
                        text_weight: float = 0.5,
                        top_k: int = 5) -> List[Tuple[MultimodalDocument, float]]:
        """Fuse text and image retrieval scores with a weighted sum.

        Raises:
            ValueError: when neither query modality is supplied.
        """
        if text_query is None and image_query is None:
            raise ValueError("Must provide at least one query type")

        n_docs = len(self.document_store.documents)

        # Score every document under each available modality.
        text_scores: Dict[str, float] = {}
        if text_query is not None:
            for doc, score in self.retrieve_by_text(text_query, top_k=n_docs):
                text_scores[doc.id] = score

        image_scores: Dict[str, float] = {}
        if image_query is not None:
            for doc, score in self.retrieve_by_image(image_query, top_k=n_docs):
                image_scores[doc.id] = score

        # Weighted fusion: text_weight for text, the remainder for image.
        # A document missing from one modality contributes 0 for that side.
        image_weight = 1 - text_weight
        fused = {
            doc_id: text_weight * text_scores.get(doc_id, 0) + image_weight * image_scores.get(doc_id, 0)
            for doc_id in set(text_scores) | set(image_scores)
        }

        best = sorted(fused.items(), key=lambda item: item[1], reverse=True)[:top_k]
        return [(self.document_store.get_document(doc_id), score) for doc_id, score in best]

# Exercise the retriever across all three query modes.
retriever = MultimodalRetriever(store)

print("Text Retrieval: 'medical imaging'")
for doc, score in retriever.retrieve_by_text("medical imaging", top_k=3):
    print(f"  {doc.id} ({doc.metadata['category']}): {score:.3f}")

print("\nImage Retrieval:")
query_img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
for doc, score in retriever.retrieve_by_image(query_img, top_k=3):
    print(f"  {doc.id} ({doc.metadata['category']}): {score:.3f}")

print("\nHybrid Retrieval (text + image):")
for doc, score in retriever.hybrid_retrieve("medical", query_img, text_weight=0.7, top_k=3):
    print(f"  {doc.id} ({doc.metadata['category']}): {score:.3f}")

Visual Question Answering (VQA)

# VQA with Vision-Language Models (requires installation)
# NOTE: kept as a module-level string (never executed) so this file runs
# without the heavy transformers/PIL dependencies; paste into a cell to try it.
'''
from transformers import ViltProcessor, ViltForQuestionAnswering
from PIL import Image

# Load ViLT model (Vision-and-Language Transformer)
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

# Load image
image = Image.open("image.jpg")
question = "What is in this image?"

# Prepare inputs
inputs = processor(image, question, return_tensors="pt")

# Generate answer
outputs = model(**inputs)
logits = outputs.logits
idx = logits.argmax(-1).item()
answer = model.config.id2label[idx]

print(f"Q: {question}")
print(f"A: {answer}")

# Alternative: BLIP-2 for open-ended VQA
from transformers import Blip2Processor, Blip2ForConditionalGeneration

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

inputs = processor(image, question, return_tensors="pt")
outputs = model.generate(**inputs)
answer = processor.decode(outputs[0], skip_special_tokens=True)

print(f"A (BLIP-2): {answer}")
'''

class VisualQA:
    """Visual Question Answering system (simulated backend)."""

    def __init__(self, model_name: str = "vilt-vqa"):
        self.model_name = model_name

    def answer(self, image: np.ndarray, question: str) -> str:
        """Return an answer about the image.

        Placeholder: a real deployment would call a VQA model such as
        ViLT, BLIP-2, or LLaVA; here a canned reply is chosen at random.
        """
        canned = [
            "This image shows a medical scan.",
            "There is a cat in the image.",
            "The image contains natural scenery.",
            "It's a photograph taken outdoors."
        ]
        return np.random.choice(canned)

    def batch_answer(self, images: List[np.ndarray], questions: List[str]) -> List[str]:
        """Answer one question per image, paired positionally."""
        answers = []
        for img, q in zip(images, questions):
            answers.append(self.answer(img, q))
        return answers

# Try the simulated VQA pipeline on a random image.
vqa = VisualQA()

image = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
questions = [
    "What is in this image?",
    "What color is the main object?",
    "Is this an indoor or outdoor scene?"
]

for question in questions:
    reply = vqa.answer(image, question)
    print(f"Q: {question}")
    print(f"A: {reply}\n")

Multimodal RAG Pipeline

@dataclass
class RAGResponse:
    """Response bundle produced by MultimodalRAG.query()."""
    # Generated answer text.
    answer: str
    # Documents used as context, ordered by retrieval score (best first).
    retrieved_docs: List[MultimodalDocument]
    # Mean retrieval similarity of the returned documents (0.0 when none).
    confidence: float
    # Extra info, e.g. num_retrieved and per-document retrieval_scores.
    metadata: Dict = field(default_factory=dict)

class MultimodalRAG:
    """Complete multimodal RAG system: retrieve, build context, answer."""

    def __init__(self,
                 document_store: MultimodalDocumentStore,
                 retriever: MultimodalRetriever,
                 vqa: VisualQA):
        self.document_store = document_store
        self.retriever = retriever
        self.vqa = vqa

    def query(self,
              question: str,
              image: Optional[np.ndarray] = None,
              top_k: int = 3) -> RAGResponse:
        """Answer `question` with retrieval augmentation.

        Uses hybrid (text + image) retrieval when an image is supplied,
        otherwise text-only retrieval. Confidence is the mean retrieval score.
        """
        # 1. Retrieve relevant documents (hybrid when an image query exists).
        if image is not None:
            retrieved = self.retriever.hybrid_retrieve(
                text_query=question,
                image_query=image,
                top_k=top_k
            )
        else:
            retrieved = self.retriever.retrieve_by_text(question, top_k=top_k)

        docs = [doc for doc, _ in retrieved]
        scores = [score for _, score in retrieved]

        # 2. Build a text context from the retrieved documents.
        context = "\n".join(f"Document {doc.id}: {doc.text}" for doc in docs)

        # 3. Generate the answer: run VQA on the most relevant retrieved image
        #    when the query included an image; otherwise answer from text only.
        #    (The original guarded with any(has_image) and then re-searched with
        #    next(), leaving a dead else-branch; next() alone covers both cases.)
        image_doc = next((doc for doc in docs if doc.has_image()), None) if image is not None else None
        if image_doc is not None:
            answer = self.vqa.answer(image_doc.image, question)
        else:
            answer = self._text_only_answer(question, context)

        # 4. Confidence = mean retrieval score. Cast to a plain float: np.mean
        #    returns a numpy scalar, but RAGResponse.confidence is typed float.
        confidence = float(np.mean(scores)) if scores else 0.0

        return RAGResponse(
            answer=answer,
            retrieved_docs=docs,
            confidence=confidence,
            metadata={
                "num_retrieved": len(docs),
                "retrieval_scores": scores
            }
        )

    def _text_only_answer(self, question: str, context: str) -> str:
        """Generate a text-only answer (LLM placeholder)."""
        # In production: feed the context to an LLM, e.g.
        # answer = llm(f"Context: {context}\n\nQuestion: {question}\n\nAnswer:")
        return f"Based on the retrieved documents, the answer relates to: {context[:100]}..."

# End-to-end check of the multimodal RAG system.
rag = MultimodalRAG(store, retriever, vqa)

print("Text Query:")
text_response = rag.query("Tell me about medical imaging", top_k=2)
print(f"Answer: {text_response.answer}")
print(f"Confidence: {text_response.confidence:.3f}")
print(f"Retrieved {len(text_response.retrieved_docs)} documents")

print("\nMultimodal Query (text + image):")
query_img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
mm_response = rag.query("What is this image showing?", image=query_img, top_k=2)
print(f"Answer: {mm_response.answer}")
print(f"Confidence: {mm_response.confidence:.3f}")

Production Deployment

import time
from collections import deque

class ProductionMultimodalRAG:
    """Production-ready multimodal RAG wrapper with caching and monitoring.

    Adds a simple text-query response cache and rolling latency/error
    statistics on top of a MultimodalRAG instance.
    NOTE(review): the cache is unbounded — bound it (e.g. with an LRU)
    before deploying against an open-ended query stream.
    """

    def __init__(self, rag: MultimodalRAG):
        self.rag = rag
        self.query_cache = {}  # question -> RAGResponse (text-only queries)
        self.stats = {
            "total_queries": 0,
            "cache_hits": 0,
            "avg_latency": 0,
            "latencies": deque(maxlen=100),  # rolling window for avg latency
            "error_count": 0
        }

    def query(self,
              question: str,
              image: Optional[np.ndarray] = None,
              use_cache: bool = True) -> RAGResponse:
        """Query with caching and monitoring.

        Image queries bypass the cache (only text questions are used as keys).
        """
        start = time.time()

        try:
            # Cache lookup — text-only queries.
            # BUG FIX: the original also bumped total_queries here, but the
            # finally block counts every call, so cache hits were counted twice.
            if use_cache and image is None and question in self.query_cache:
                self.stats["cache_hits"] += 1
                return self.query_cache[question]

            # Execute the underlying RAG query.
            response = self.rag.query(question, image)

            # Cache the result (text-only queries).
            if use_cache and image is None:
                self.query_cache[question] = response

            return response

        except Exception:
            self.stats["error_count"] += 1
            raise  # bare re-raise preserves the original traceback

        finally:
            # Runs on every path (hit, miss, error): record latency and count.
            latency = time.time() - start
            self.stats["latencies"].append(latency)
            self.stats["total_queries"] += 1
            self.stats["avg_latency"] = np.mean(self.stats["latencies"])

    def get_performance_stats(self) -> Dict:
        """Return aggregate performance metrics (rates guarded against /0)."""
        total = max(self.stats["total_queries"], 1)
        return {
            "total_queries": self.stats["total_queries"],
            "cache_hits": self.stats["cache_hits"],
            "cache_hit_rate": self.stats["cache_hits"] / total,
            "avg_latency_ms": self.stats["avg_latency"] * 1000,
            "error_count": self.stats["error_count"],
            "error_rate": self.stats["error_count"] / total
        }

    def clear_cache(self):
        """Drop all cached responses."""
        self.query_cache.clear()

# Smoke-test the production wrapper: three queries, one duplicate to hit cache.
prod_rag = ProductionMultimodalRAG(rag)

test_queries = [
    "What is medical imaging?",
    "Explain nature photography",
    "What is medical imaging?",  # Duplicate (cache hit)
]

for question in test_queries:
    result = prod_rag.query(question)
    print(f"Q: {question}")
    print(f"A: {result.answer[:80]}...\n")

stats = prod_rag.get_performance_stats()
print("Performance Stats:")
print(f"  Total queries: {stats['total_queries']}")
print(f"  Cache hits: {stats['cache_hits']}")
print(f"  Cache hit rate: {stats['cache_hit_rate']:.1%}")
print(f"  Avg latency: {stats['avg_latency_ms']:.2f}ms")
print(f"  Error rate: {stats['error_rate']:.1%}")

Best Practices

1. Model Selection

  • VQA: BLIP-2, LLaVA, GPT-4 Vision

  • Image Encoder: CLIP, DINOv2, SAM

  • Text Encoder: sentence-transformers, CLIP text

  • LLM: GPT-4, Claude, Llama 2

2. Retrieval Strategy

  • Use hybrid retrieval (text + image) when available

  • Tune text/image weight based on your use case

  • Implement re-ranking for better relevance

  • Use metadata filtering to narrow search

3. Performance Optimization

  • Cache embeddings (don't recompute)

  • Use vector databases (Pinecone, Weaviate, Milvus)

  • Implement query caching for common questions

  • Batch process when possible

  • Use approximate nearest neighbor search

4. Quality Improvements

  • Fine-tune VQA model on domain data

  • Use larger context windows

  • Implement confidence thresholds

  • Add fallback mechanisms

  • Monitor and log errors

Use Cases

  • Medical: Radiology report generation, case retrieval

  • E-commerce: Visual product search, recommendations

  • Education: Interactive learning with diagrams

  • Legal: Document analysis with images

  • Manufacturing: Equipment manuals with photos

  • Real estate: Property search with images

Key Takeaways

✅ Multimodal RAG combines text and vision for richer context

✅ Hybrid retrieval (text+image) improves relevance

✅ VQA models answer questions about images

✅ Vision-Language Models (VLMs) understand both modalities

✅ Cache embeddings and queries for production performance

✅ Monitor latency, accuracy, and errors in production

🎉 Computer Vision Series Complete!

You've learned:

  • ✅ Image classification (CNNs, ViT)

  • ✅ Object detection (YOLO)

  • ✅ Multimodal embeddings (CLIP)

  • ✅ Image generation (Stable Diffusion)

  • ✅ Multimodal RAG (VQA, retrieval)

Next Steps:

  • Explore projects in /projects directory

  • Build your own CV applications

  • Fine-tune models on your data

  • Deploy to production with monitoring