Monitoring ML Models in Production

🎯 Learning Objectives

  • Implement logging and metrics

  • Monitor model performance

  • Detect data drift

  • Set up alerts

  • Build monitoring dashboards

Why Monitor?

What can go wrong:

  • Model degradation over time

  • Data distribution changes (drift)

  • Infrastructure issues

  • Latency problems

  • Silent failures

Monitoring helps:

  • Detect issues early

  • Maintain model quality

  • Track business metrics

  • Optimize resources

Types of Monitoring

1. System Monitoring

  • Latency: Response time (p50, p95, p99)

  • Throughput: Requests per second

  • Errors: Error rate and error types

  • Resources: CPU, memory, disk
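The p50/p95/p99 notation means "the latency that 50%, 95%, or 99% of requests stay at or below." The minimal sketch below uses only Python's standard library. `latency_percentiles()` is a hypothetical helper, and the sample latencies are made up for illustration. It shows how a few slow requests can leave the median and even p95 untouched while pushing p99 far above the mean.

```python
import statistics

def latency_percentiles(latencies_ms):
    """Return the p50, p95, and p99 of a sequence of latencies (in ms)."""
    # quantiles(n=100) returns 99 cut points; cut point k-1 is the k-th percentile.
    # method="inclusive" interpolates linearly between the sorted values.
    cuts = statistics.quantiles(latencies_ms, n=100, method="inclusive")
    return {"p50": cuts[49], "p95": cuts[94], "p99": cuts[98]}

# Illustrative data: 98 fast requests and 2 very slow ones
latencies = [10] * 98 + [500] * 2
print(statistics.mean(latencies))      # 19.8
print(latency_percentiles(latencies))  # {'p50': 10.0, 'p95': 10.0, 'p99': 500.0}
```

The mean of 19.8 ms suggests a healthy service. The p99 reveals that some users wait half a second.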

2. Model Monitoring

  • Prediction distribution: Are predictions changing?

  • Confidence scores: Is the model uncertain?

  • Feature distribution: Input drift

  • Performance metrics: Accuracy, precision, recall
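The first two model signals need no ground-truth labels, so you can track them from the first day of deployment. The minimal sketch below assumes that predictions arrive as class labels and confidences as floats in [0, 1]. Both `class_share_shift()` and `low_confidence_rate()` are hypothetical helpers, and the 0.6 cutoff is an arbitrary example value.

```python
from collections import Counter

def class_share_shift(baseline_preds, current_preds):
    """Return (class, change) for the predicted class whose share moved the most."""
    if not baseline_preds or not current_preds:
        raise ValueError("both windows need at least one prediction")
    base, cur = Counter(baseline_preds), Counter(current_preds)
    n_base, n_cur = len(baseline_preds), len(current_preds)
    # Share of each class in each window; a class missing from a window counts as 0
    shifts = {c: cur[c] / n_cur - base[c] / n_base for c in set(base) | set(cur)}
    worst = max(shifts, key=lambda c: abs(shifts[c]))
    return worst, shifts[worst]

def low_confidence_rate(confidences, cutoff=0.6):
    """Fraction of predictions whose confidence falls below `cutoff` (0.0 if empty)."""
    if not confidences:
        return 0.0
    return sum(1 for c in confidences if c < cutoff) / len(confidences)

baseline = [0] * 70 + [1] * 20 + [2] * 10
current = [0] * 40 + [1] * 30 + [2] * 30
print(class_share_shift(baseline, current))         # class 0's share fell by about 0.3
print(low_confidence_rate([0.95, 0.55, 0.8, 0.4]))  # 0.5
```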

3. Business Monitoring

  • User engagement: Click-through rate

  • Conversions: Sales, sign-ups

  • Revenue impact: Direct business value

# Install dependencies
# !pip install prometheus-client psutil fastapi numpy scipy pandas

Basic Logging

Structured logging is the foundation of observability for ML systems. Rather than scattering print() statements through your code, use Python’s logging module. It provides severity levels (DEBUG, INFO, WARNING, ERROR, CRITICAL), configurable output destinations through handlers (here, a file and the console), and a consistent format with timestamps. The log_prediction() function below records every prediction as a JSON payload inside a log line. Each payload contains the model version, input features, output, and latency. These logs become the raw material for debugging production issues. When a user reports a bad prediction, you can search the log by timestamp to see exactly what the model received and returned.

import logging
from datetime import datetime
import json

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('ml_api.log'),
        logging.StreamHandler()
    ]
)

logger = logging.getLogger(__name__)

def log_prediction(model_version, input_data, prediction, latency_ms):
    """Log prediction details"""
    log_entry = {
        "timestamp": datetime.now().isoformat(),
        "model_version": model_version,
        "input": input_data,
        "prediction": prediction,
        "latency_ms": latency_ms
    }
    logger.info(f"Prediction: {json.dumps(log_entry)}")

# Example
log_prediction(
    model_version="v1.0",
    input_data={"features": [1, 2, 3, 4]},
    prediction={"class": 0, "confidence": 0.95},
    latency_ms=12.5
)
print("✓ Logging configured")
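log_prediction() still writes a text line with JSON embedded after a "Prediction: " prefix. Log tools must strip that prefix before they can parse the payload. A common alternative is a formatter that renders the whole record as one JSON object per line. The sketch below assumes a hypothetical `JsonFormatter` class and an `event` field passed through logging's standard `extra` argument. Neither comes from the code above.

```python
import io
import json
import logging

class JsonFormatter(logging.Formatter):
    """Render each log record as one JSON object per line."""

    def format(self, record):
        payload = {
            "time": self.formatTime(record),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        # Keys passed via logger.info(..., extra={"event": {...}}) become record attributes
        event = getattr(record, "event", None)
        if event is not None:
            payload["event"] = event
        return json.dumps(payload)

def make_json_logger(name, stream):
    """Create a logger that writes JSON lines to `stream` (e.g. a file or sys.stdout)."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler(stream)
    handler.setFormatter(JsonFormatter())
    logger.addHandler(handler)
    logger.propagate = False  # keep the root logger from emitting a second copy
    return logger

buffer = io.StringIO()
json_logger = make_json_logger("ml_api.json", buffer)
json_logger.info("prediction", extra={"event": {"model_version": "v1.0", "latency_ms": 12.5}})

# Each line parses directly, so filtering by time or field needs no string splitting
record = json.loads(buffer.getvalue().splitlines()[0])
print(record["event"]["latency_ms"])  # 12.5
```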

Metrics with Prometheus

Prometheus is a widely used open-source metrics system for monitoring services, including ML APIs. It periodically scrapes a /metrics endpoint that your application exposes and stores the results as time series for querying and alerting. The Python client offers four metric types, each for a different purpose:

  • Counter tracks cumulative totals that only go up (e.g., total predictions).

  • Histogram counts observations into configurable buckets, so that percentiles such as p95 latency can be estimated at query time.

  • Gauge records a current value that can go up or down (e.g., active requests).

  • Summary tracks the count and sum of observations. The Python client does not compute quantiles for Summary.

Labels like model_version and status let you slice metrics by dimension, for example to compare error rates between model v1.0 and v2.0. The confidence-score histogram concentrates its buckets between 0.5 and 1.0, the range most relevant for monitoring model quality. As a result, scores below 0.5 are not distinguished from one another.

from prometheus_client import Counter, Histogram, Gauge, generate_latest, CONTENT_TYPE_LATEST
from fastapi import FastAPI, Response
import time

app = FastAPI()

# Define metrics
PREDICTION_COUNT = Counter(
    'predictions_total',
    'Total number of predictions',
    ['model_version', 'status']
)

PREDICTION_LATENCY = Histogram(
    'prediction_latency_seconds',
    'Prediction latency in seconds',
    ['model_version']
)

MODEL_CONFIDENCE = Histogram(
    'model_confidence',
    'Model confidence scores',
    ['model_version'],
    buckets=[0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 1.0]
)

ACTIVE_REQUESTS = Gauge(
    'active_requests',
    'Number of active requests'
)

@app.post("/predict")
async def predict(features: list[float]):
    # FastAPI reads a list-typed parameter from the JSON request body, e.g. [1.0, 2.0, 3.0, 4.0]
    ACTIVE_REQUESTS.inc()
    
    start_time = time.perf_counter()  # monotonic clock, suited to measuring durations
    model_version = "v1.0"
    
    try:
        # Mock prediction
        prediction = {"class": 0, "confidence": 0.95}
        
        # Record metrics
        latency = time.perf_counter() - start_time
        PREDICTION_LATENCY.labels(model_version=model_version).observe(latency)
        MODEL_CONFIDENCE.labels(model_version=model_version).observe(prediction["confidence"])
        PREDICTION_COUNT.labels(model_version=model_version, status="success").inc()
        
        return prediction
    
    except Exception:
        PREDICTION_COUNT.labels(model_version=model_version, status="error").inc()
        raise
    
    finally:
        ACTIVE_REQUESTS.dec()

@app.get("/metrics")
async def metrics():
    """Prometheus metrics endpoint (text exposition format)"""
    return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST)

print("✓ Metrics configured")
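A Prometheus Histogram stores only cumulative bucket counts. Any percentile computed from it is therefore an estimate, and its precision depends on where the bucket boundaries sit. The pure-Python sketch below illustrates the idea and is not the Prometheus implementation. It builds cumulative `le`-style counts for the confidence buckets used above. It then estimates a quantile by linear interpolation inside the bucket that contains the target rank, similar in spirit to PromQL's histogram_quantile(). Both helper names are hypothetical.

```python
import math
import statistics

def cumulative_buckets(observations, upper_bounds):
    """Prometheus-style cumulative counts: observations <= each bound, plus +Inf."""
    bounds = sorted(upper_bounds) + [math.inf]
    return [(b, sum(1 for x in observations if x <= b)) for b in bounds]

def estimate_quantile(q, buckets):
    """Estimate quantile q (0..1) from cumulative (upper_bound, count) pairs."""
    total = buckets[-1][1]
    if total == 0:
        return math.nan
    rank = q * total
    lower, below = 0.0, 0  # assume the first bucket starts at 0
    for upper, count in buckets:
        if count >= rank:
            if math.isinf(upper):
                return lower  # cannot interpolate into +Inf; use the highest finite bound
            in_bucket = count - below
            if in_bucket == 0:
                return upper
            # Assume observations are spread uniformly within the bucket
            return lower + (upper - lower) * (rank - below) / in_bucket
        lower, below = upper, count
    return lower

# Illustrative confidences that all land in the (0.6, 0.7] bucket
tight = [0.61, 0.62, 0.63, 0.64, 0.65, 0.66]
buckets = cumulative_buckets(tight, [0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 1.0])
print(estimate_quantile(0.5, buckets))  # roughly 0.65
print(statistics.median(tight))         # the exact median, 0.635
```

The gap between the estimate and the exact value is the price of storing buckets instead of raw observations. Narrower buckets in the range you care about shrink that gap.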

Data Drift Detection

Data drift occurs when the statistical properties of incoming data diverge from the training distribution. It can degrade model accuracy without raising any errors. The DriftDetector class uses the two-sample Kolmogorov-Smirnov (KS) test to compare each feature's distribution in a reference dataset (training data) against a sliding window of recent production inputs. The KS statistic is the maximum distance between the two empirical cumulative distribution functions. Its p-value is the probability of seeing a distance at least that large if both samples came from the same distribution, so a small p-value is evidence of drift. If \(p < 0.05\) for any feature, drift is flagged. Because the test runs once per feature, flagging on the smallest p-value raises the false-alarm rate as you add features. Dividing the threshold by the number of features (a Bonferroni correction) is a common safeguard. In practice, drift alerts prompt investigation or retraining, ideally catching degradation before it affects users.

import numpy as np
from scipy import stats
from collections import deque

class DriftDetector:
    """Simple drift detector using statistical tests"""
    
    def __init__(self, reference_data, window_size=1000):
        self.reference_data = np.array(reference_data)
        self.current_window = deque(maxlen=window_size)
        self.drift_threshold = 0.05  # p-value threshold
    
    def add_sample(self, sample):
        """Add new sample to current window"""
        self.current_window.append(sample)
    
    def detect_drift(self):
        """Detect drift using Kolmogorov-Smirnov test"""
        if len(self.current_window) < 30:  # Need minimum samples
            return False, 1.0
        
        current_data = np.array(list(self.current_window))
        
        # KS test for each feature
        p_values = []
        for i in range(self.reference_data.shape[1]):
            statistic, p_value = stats.ks_2samp(
                self.reference_data[:, i],
                current_data[:, i]
            )
            p_values.append(p_value)
        
        # Drift if any feature has low p-value
        min_p_value = min(p_values)
        drift_detected = min_p_value < self.drift_threshold
        
        return drift_detected, min_p_value
    
    def get_drift_report(self):
        """Get detailed drift report"""
        drift_detected, p_value = self.detect_drift()
        
        return {
            "drift_detected": drift_detected,
            "p_value": p_value,
            "threshold": self.drift_threshold,
            "window_size": len(self.current_window)
        }

# Example usage
reference_data = np.random.randn(1000, 4)  # Training data
detector = DriftDetector(reference_data)

# Add new samples
for _ in range(100):
    new_sample = np.random.randn(4)
    detector.add_sample(new_sample)

# Check for drift
report = detector.get_drift_report()
print("Drift Report:", report)
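DriftDetector reports only the smallest p-value and compares it against an uncorrected 0.05. The sketch below applies the Bonferroni correction mentioned above and reports which features drifted. It assumes 2-D arrays with one column per feature. `per_feature_drift()` is a hypothetical helper, and the shifted data is simulated.

```python
import numpy as np
from scipy import stats

def per_feature_drift(reference, current, alpha=0.05):
    """Run a two-sample KS test per feature with a Bonferroni-corrected threshold."""
    reference = np.asarray(reference)
    current = np.asarray(current)
    n_features = reference.shape[1]
    threshold = alpha / n_features  # Bonferroni: caps the chance of any false alarm at alpha
    p_values = [
        stats.ks_2samp(reference[:, i], current[:, i]).pvalue
        for i in range(n_features)
    ]
    drifted = [i for i, p in enumerate(p_values) if p < threshold]
    return {"p_values": p_values, "threshold": threshold, "drifted_features": drifted}

rng = np.random.default_rng(0)
ref = rng.normal(size=(1000, 4))  # stand-in for training data
cur = rng.normal(size=(500, 4))   # stand-in for recent production inputs
cur[:, 2] += 1.0                  # simulate a mean shift in feature 2 only
shift_report = per_feature_drift(ref, cur)
print(shift_report["drifted_features"])  # feature 2 should appear
```

Listing the drifted features, rather than a single minimum p-value, tells whoever receives the alert which input pipeline to inspect first.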

Alert System

An alerting system converts raw metrics into actionable notifications. The AlertManager below follows an observer (publish-subscribe) pattern. When a metric exceeds its threshold, AlertManager creates an Alert object with a severity level (INFO, WARNING, CRITICAL) and dispatches it to every registered handler. The example registers a console handler and a mock email handler; Slack, PagerDuty, or any custom integration would plug in the same way. The key design principle is separation of concerns: the code that checks metrics does not need to know how alerts are delivered. In production, severity levels usually map to different response procedures. For example, a WARNING might create a Jira ticket, while a CRITICAL alert pages the on-call engineer. Note that check_metric() fires only when a value rises above its threshold. This suits latency and error rate, but not metrics where lower is worse, such as accuracy.

from dataclasses import dataclass
from typing import Callable, List
from enum import Enum

class AlertSeverity(Enum):
    INFO = "info"
    WARNING = "warning"
    CRITICAL = "critical"

@dataclass
class Alert:
    severity: AlertSeverity
    message: str
    metric_name: str
    metric_value: float
    threshold: float

class AlertManager:
    """Manage alerts and notifications"""
    
    def __init__(self):
        self.handlers: List[Callable] = []
        self.alerts: List[Alert] = []
    
    def add_handler(self, handler: Callable):
        """Add alert handler (e.g., email, Slack)"""
        self.handlers.append(handler)
    
    def check_metric(self, metric_name, value, threshold, severity=AlertSeverity.WARNING):
        """Check if metric exceeds threshold"""
        if value > threshold:
            alert = Alert(
                severity=severity,
                message=f"{metric_name} exceeded threshold: {value:.2f} > {threshold:.2f}",
                metric_name=metric_name,
                metric_value=value,
                threshold=threshold
            )
            self.trigger_alert(alert)
    
    def trigger_alert(self, alert: Alert):
        """Trigger alert to all handlers"""
        self.alerts.append(alert)
        
        for handler in self.handlers:
            handler(alert)
    
    def get_recent_alerts(self, n=10):
        """Get recent alerts"""
        return self.alerts[-n:]

# Example handlers
def console_handler(alert: Alert):
    icon = {"info": "ℹ️", "warning": "⚠️", "critical": "🚨"}[alert.severity.value]
    print(f"{icon} [{alert.severity.value.upper()}] {alert.message}")

def email_handler(alert: Alert):
    # In practice, send actual email
    print(f"📧 Email sent: {alert.message}")

# Setup alert manager
alert_mgr = AlertManager()
alert_mgr.add_handler(console_handler)
alert_mgr.add_handler(email_handler)

# Check metrics
alert_mgr.check_metric("latency_ms", 150, 100, AlertSeverity.WARNING)
alert_mgr.check_metric("error_rate", 0.15, 0.05, AlertSeverity.CRITICAL)

print(f"\nTotal alerts: {len(alert_mgr.alerts)}")
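If you re-check a metric every few seconds, AlertManager sends a fresh notification on every breach, which quickly leads to alert fatigue. The sketch below shows one mitigation: a per-metric cooldown. `CooldownAlerter` is a hypothetical standalone class, not an extension of the AlertManager API above. The 300-second default is arbitrary, and the injectable clock exists only so the behavior can be tested without waiting. The class also accepts a direction, so you can check metrics where lower is worse, like accuracy.

```python
import time
from typing import Callable, Dict, List, Optional

class CooldownAlerter:
    """Threshold alerts with a per-metric cooldown to limit repeated notifications."""

    def __init__(self, cooldown_s: float = 300.0, clock: Callable[[], float] = time.monotonic):
        self.cooldown_s = cooldown_s
        self.clock = clock
        self._last_fired: Dict[str, float] = {}
        self.fired: List[str] = []

    def check(self, metric_name: str, value: float, threshold: float,
              direction: str = "above") -> Optional[str]:
        """Return an alert message if the threshold is breached and the metric is not cooling down."""
        if direction not in ("above", "below"):
            raise ValueError("direction must be 'above' or 'below'")
        breached = value > threshold if direction == "above" else value < threshold
        if not breached:
            return None
        now = self.clock()
        last = self._last_fired.get(metric_name)
        if last is not None and now - last < self.cooldown_s:
            return None  # the same metric fired recently: suppress the repeat
        self._last_fired[metric_name] = now
        op = ">" if direction == "above" else "<"
        message = f"{metric_name} breached threshold: {value:.2f} {op} {threshold:.2f}"
        self.fired.append(message)
        return message

fake_now = [0.0]
alerter = CooldownAlerter(cooldown_s=300, clock=lambda: fake_now[0])
alerter.check("latency_ms", 150, 100)                     # fires
alerter.check("latency_ms", 160, 100)                     # suppressed: within the cooldown
fake_now[0] = 301.0
alerter.check("latency_ms", 170, 100)                     # fires again
alerter.check("accuracy", 0.82, 0.90, direction="below")  # fires: lower is worse
print(len(alerter.fired))  # 3
```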

Performance Dashboard (Mock)

A performance dashboard aggregates request-level metrics into summary statistics that give an at-a-glance view of system health. The PerformanceDashboard class computes key indicators over a configurable time window: total requests, success rate, average and maximum latency, and latency percentiles (p95, p99). Percentiles are more informative than averages because a few very slow requests can hide behind a healthy mean. The p99, for instance, is the latency that the slowest 1% of requests exceed. The simulated latencies are drawn from a gamma distribution with shape 2 and scale 15 ms, giving a mean of 30 ms. The gamma distribution is a common choice for mimicking the right-skewed shape of real API response times.

import pandas as pd
import numpy as np
from datetime import datetime, timedelta

class PerformanceDashboard:
    """Track and display performance metrics"""
    
    def __init__(self):
        self.metrics = []
    
    def log_request(self, latency_ms, success, model_version):
        """Log request metrics"""
        self.metrics.append({
            "timestamp": datetime.now(),
            "latency_ms": latency_ms,
            "success": success,
            "model_version": model_version
        })
    
    def get_summary(self, last_n_minutes=60):
        """Get performance summary"""
        if not self.metrics:
            return "No data available"
        
        df = pd.DataFrame(self.metrics)
        cutoff = datetime.now() - timedelta(minutes=last_n_minutes)
        recent_df = df[df['timestamp'] > cutoff]
        
        if recent_df.empty:
            return "No recent data"
        
        summary = {
            "total_requests": len(recent_df),
            "success_rate": recent_df['success'].mean() * 100,
            "avg_latency_ms": recent_df['latency_ms'].mean(),
            "p95_latency_ms": recent_df['latency_ms'].quantile(0.95),
            "p99_latency_ms": recent_df['latency_ms'].quantile(0.99),
            "max_latency_ms": recent_df['latency_ms'].max()
        }
        
        return summary
    
    def display_summary(self, last_n_minutes=60):
        """Display formatted summary"""
        summary = self.get_summary(last_n_minutes)
        if isinstance(summary, str):
            print(summary)
            return
        
        print(f"\n=== Performance Summary (Last {last_n_minutes} min) ===")
        print(f"Total Requests: {summary['total_requests']}")
        print(f"Success Rate:   {summary['success_rate']:.2f}%")
        print(f"Avg Latency:    {summary['avg_latency_ms']:.2f}ms")
        print(f"P95 Latency:    {summary['p95_latency_ms']:.2f}ms")
        print(f"P99 Latency:    {summary['p99_latency_ms']:.2f}ms")
        print(f"Max Latency:    {summary['max_latency_ms']:.2f}ms")
        print("=" * 40)

# Example usage
dashboard = PerformanceDashboard()

# Simulate requests
for _ in range(1000):
    latency = np.random.gamma(2, 15)  # Realistic latency distribution
    success = np.random.random() > 0.02  # 98% success rate
    dashboard.log_request(latency, success, "v1.0")

dashboard.display_summary()
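PerformanceDashboard records model_version with every request but never uses it. Splitting the same statistics by version is the natural next step, for example during a canary rollout. The sketch below assumes a DataFrame with the same columns that pd.DataFrame(dashboard.metrics) would have. `summarize_by_version()` is a hypothetical helper, and the slower v2.0 latencies are simulated purely for illustration.

```python
import numpy as np
import pandas as pd

def summarize_by_version(df):
    """Per-version request count, success rate (%), and p95/p99 latency."""
    grouped = df.groupby("model_version")
    return pd.DataFrame({
        "requests": grouped.size(),
        "success_rate": grouped["success"].mean() * 100,
        "p95_latency_ms": grouped["latency_ms"].quantile(0.95),
        "p99_latency_ms": grouped["latency_ms"].quantile(0.99),
    })

# Simulated traffic: v2.0 is given a larger gamma scale, so it runs slower
rng = np.random.default_rng(42)
n = 500
traffic = pd.DataFrame({
    "model_version": ["v1.0"] * n + ["v2.0"] * n,
    "latency_ms": np.concatenate([rng.gamma(2, 15, n), rng.gamma(2, 20, n)]),
    "success": rng.random(2 * n) > 0.02,
})
print(summarize_by_version(traffic).round(2))
```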

Best Practices

  1. Monitor Everything

    • System metrics (CPU, memory, latency)

    • Model metrics (confidence, predictions)

    • Business metrics (conversions, revenue)

  2. Set Meaningful Alerts

    • Not too sensitive (alert fatigue)

    • Not too relaxed (miss issues)

    • Actionable thresholds

  3. Track Trends

    • Don’t just monitor current values

    • Look for degradation over time

    • Use moving averages

  4. Automate Responses

    • Auto-scaling on high load

    • Automatic rollback on errors

    • Self-healing systems

  5. Regular Reviews

    • Weekly performance reviews

    • Monthly model evaluations

    • Quarterly business impact analysis
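The "Track Trends" advice can be sketched with a rolling mean. The sketch below flags the point where a 7-day moving average of a quality metric drops more than a tolerance below a baseline. It assumes that the first window is healthy and can serve as the baseline. `degradation_flags()` is a hypothetical helper, the 0.02 tolerance is an arbitrary example, and the accuracy values are made up.

```python
import pandas as pd

def degradation_flags(values, window=7, baseline=None, tolerance=0.02):
    """Flag points where the rolling mean falls more than `tolerance` below baseline."""
    s = pd.Series(values, dtype=float)
    if baseline is None:
        baseline = s.iloc[:window].mean()  # assume the first window is healthy
    # min_periods=window: no flag until a full window of data exists (NaN compares False)
    rolling = s.rolling(window=window, min_periods=window).mean()
    return rolling < (baseline - tolerance)

daily_accuracy = [0.90] * 7 + [0.89, 0.88, 0.87, 0.86, 0.85, 0.84, 0.83]
flags = degradation_flags(daily_accuracy)
print(list(flags).index(True))  # 11: first day whose 7-day mean drops below 0.88
```

No single day drops sharply here, so a point-in-time threshold of 0.85 would fire only on day 12. The moving average surfaces the gradual decline one day earlier and is less sensitive to single noisy days.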

Key Takeaways

✅ Monitor system, model, and business metrics

✅ Detect data drift proactively

✅ Set up meaningful alerts

✅ Build dashboards for visibility

✅ Automate responses when possible