Monitoring ML Models in Production¶
🎯 Learning Objectives¶
Implement logging and metrics
Monitor model performance
Detect data drift
Set up alerts
Build monitoring dashboards
Why Monitor?¶
What can go wrong:
Model degradation over time
Data distribution changes (drift)
Infrastructure issues
Latency problems
Silent failures
Monitoring helps:
Detect issues early
Maintain model quality
Track business metrics
Optimize resources
Types of Monitoring¶
1. System Monitoring¶
Latency: Response time (p50, p95, p99)
Throughput: Requests per second
Errors: Error rate, types
Resources: CPU, memory, disk
2. Model Monitoring¶
Prediction distribution: Are predictions changing?
Confidence scores: Is model uncertain?
Feature distribution: Input drift
Performance metrics: Accuracy, precision, recall
3. Business Monitoring¶
User engagement: Click-through rate
Conversions: Sales, sign-ups
Revenue impact: Direct business value
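To make the latency percentiles above concrete (p50, p95, p99), here is a minimal nearest-rank percentile sketch; the `percentile` helper and the sample latencies are illustrative, not from any library:

```python
# Hypothetical helper: nearest-rank percentile over a list of latencies (ms)
def percentile(values, pct):
    """Return the value at the given percentile using the nearest-rank method."""
    ordered = sorted(values)
    # Nearest-rank index: ceil(pct/100 * n), at least 1
    rank = max(1, -(-len(ordered) * pct // 100))  # ceiling division
    return ordered[int(rank) - 1]

latencies_ms = [12, 15, 11, 250, 14, 13, 16, 12, 900, 15]

for pct in (50, 95, 99):
    print(f"p{pct}: {percentile(latencies_ms, pct)} ms")
```

Note how two slow requests dominate the tail: the p50 is 14 ms while the p95 and p99 are 900 ms, which is exactly why tail percentiles are monitored alongside averages.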
# Install dependencies
# !pip install prometheus-client psutil fastapi scipy pandas
Basic Logging¶
Structured logging is the foundation of observability for ML systems. Rather than scattering print() statements, Python's logging module provides severity levels (DEBUG, INFO, WARNING, ERROR), configurable output destinations (file and console via handlers), and a consistent format with timestamps. The log_prediction() function below records every prediction as a JSON object containing the model version, input features, output, and latency. These logs become the raw material for debugging production issues: when a user reports a bad prediction, you can search the log by timestamp to see exactly what the model received and returned.
import logging
from datetime import datetime
import json

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('ml_api.log'),
        logging.StreamHandler()
    ]
)

logger = logging.getLogger(__name__)

def log_prediction(model_version, input_data, prediction, latency_ms):
    """Log prediction details"""
    log_entry = {
        "timestamp": datetime.now().isoformat(),
        "model_version": model_version,
        "input": input_data,
        "prediction": prediction,
        "latency_ms": latency_ms
    }
    logger.info(f"Prediction: {json.dumps(log_entry)}")

# Example
log_prediction(
    model_version="v1.0",
    input_data={"features": [1, 2, 3, 4]},
    prediction={"class": 0, "confidence": 0.95},
    latency_ms=12.5
)

print("✅ Logging configured")
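Searching the log later means recovering the JSON payload from each line. A minimal sketch of that, assuming the log format configured above (the sample line and its timestamp are made up for illustration):

```python
import json

# A line as the basicConfig format above would emit it (example values assumed)
line = ("2024-01-15 10:32:07,123 - __main__ - INFO - "
        'Prediction: {"timestamp": "2024-01-15T10:32:07.123456", '
        '"model_version": "v1.0", "prediction": {"class": 0, "confidence": 0.95}, '
        '"latency_ms": 12.5}')

def parse_prediction_line(log_line):
    """Recover the structured payload from a 'Prediction: {...}' log line."""
    marker = "Prediction: "
    if marker not in log_line:
        return None  # not a prediction line
    return json.loads(log_line.split(marker, 1)[1])

entry = parse_prediction_line(line)
print(entry["model_version"], entry["latency_ms"])
```

Because the payload is JSON rather than free text, filtering by model version or latency becomes a one-liner over the parsed entries.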
Metrics with Prometheus¶
Prometheus is the industry-standard metrics system for monitoring ML services. It works by scraping a /metrics endpoint that your application exposes, collecting time-series data for analysis and alerting. The four metric types serve different purposes: Counter tracks cumulative totals (e.g., total predictions), Histogram captures distributions (e.g., latency percentiles), Gauge records current values that can go up or down (e.g., active requests), and Summary computes quantiles. Labels like model_version and status let you slice metrics by dimension; for example, comparing error rates between model v1.0 and v2.0. The histogram buckets for confidence scores are tuned to the range most relevant for monitoring model quality.
from prometheus_client import Counter, Histogram, Gauge, generate_latest
from fastapi import FastAPI, Response
import time

app = FastAPI()

# Define metrics
PREDICTION_COUNT = Counter(
    'predictions_total',
    'Total number of predictions',
    ['model_version', 'status']
)

PREDICTION_LATENCY = Histogram(
    'prediction_latency_seconds',
    'Prediction latency in seconds',
    ['model_version']
)

MODEL_CONFIDENCE = Histogram(
    'model_confidence',
    'Model confidence scores',
    ['model_version'],
    buckets=[0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 1.0]
)

ACTIVE_REQUESTS = Gauge(
    'active_requests',
    'Number of active requests'
)

@app.post("/predict")
async def predict(features: list):
    ACTIVE_REQUESTS.inc()
    start_time = time.time()
    model_version = "v1.0"

    try:
        # Mock prediction
        prediction = {"class": 0, "confidence": 0.95}

        # Record metrics
        latency = time.time() - start_time
        PREDICTION_LATENCY.labels(model_version=model_version).observe(latency)
        MODEL_CONFIDENCE.labels(model_version=model_version).observe(prediction["confidence"])
        PREDICTION_COUNT.labels(model_version=model_version, status="success").inc()

        return prediction
    except Exception:
        PREDICTION_COUNT.labels(model_version=model_version, status="error").inc()
        raise
    finally:
        ACTIVE_REQUESTS.dec()

@app.get("/metrics")
async def metrics():
    """Prometheus metrics endpoint"""
    return Response(content=generate_latest(), media_type="text/plain")

print("✅ Metrics configured")
Data Drift Detection¶
Data drift occurs when the statistical properties of incoming data diverge from the training distribution, causing model accuracy to degrade silently. The DriftDetector class uses the Kolmogorov-Smirnov (KS) test to compare feature distributions between a reference dataset (training data) and a sliding window of recent predictions. The KS test measures the maximum distance between two empirical cumulative distribution functions, returning a p-value that indicates whether the two samples likely come from the same distribution. If \(p < 0.05\) for any feature, drift is flagged. In practice, drift detection triggers alerts that prompt retraining, catching performance degradation before it impacts users.
import numpy as np
from scipy import stats
from collections import deque

class DriftDetector:
    """Simple drift detector using statistical tests"""

    def __init__(self, reference_data, window_size=1000):
        self.reference_data = np.array(reference_data)
        self.current_window = deque(maxlen=window_size)
        self.drift_threshold = 0.05  # p-value threshold

    def add_sample(self, sample):
        """Add new sample to current window"""
        self.current_window.append(sample)

    def detect_drift(self):
        """Detect drift using Kolmogorov-Smirnov test"""
        if len(self.current_window) < 30:  # Need minimum samples
            return False, 1.0

        current_data = np.array(list(self.current_window))

        # KS test for each feature
        p_values = []
        for i in range(self.reference_data.shape[1]):
            statistic, p_value = stats.ks_2samp(
                self.reference_data[:, i],
                current_data[:, i]
            )
            p_values.append(p_value)

        # Drift if any feature has a low p-value
        min_p_value = min(p_values)
        drift_detected = min_p_value < self.drift_threshold

        return drift_detected, min_p_value

    def get_drift_report(self):
        """Get detailed drift report"""
        drift_detected, p_value = self.detect_drift()
        return {
            "drift_detected": drift_detected,
            "p_value": p_value,
            "threshold": self.drift_threshold,
            "window_size": len(self.current_window)
        }

# Example usage
reference_data = np.random.randn(1000, 4)  # Training data
detector = DriftDetector(reference_data)

# Add new samples
for _ in range(100):
    new_sample = np.random.randn(4)
    detector.add_sample(new_sample)

# Check for drift
report = detector.get_drift_report()
print("Drift Report:", report)
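In the example above, the window is drawn from the same distribution as the reference, so drift should (usually) not fire. For contrast, here is the same KS test applied to a deliberately shifted feature; the sample sizes and the 1.5-standard-deviation mean shift are arbitrary choices for illustration:

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
reference = rng.normal(loc=0.0, scale=1.0, size=2000)  # training-time feature
shifted = rng.normal(loc=1.5, scale=1.0, size=500)     # production window, mean shifted

# Two-sample KS test: large statistic / tiny p-value flags the shift
statistic, p_value = stats.ks_2samp(reference, shifted)
drift_detected = p_value < 0.05
print(f"KS statistic={statistic:.3f}, p={p_value:.2e}, drift={drift_detected}")
```

A shift this large yields a p-value far below 0.05, so the detector would flag drift on the first check after the window fills.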
Alert System¶
An alerting system converts raw metrics into actionable notifications. The AlertManager below implements a publish-subscribe pattern: when a metric exceeds its threshold, an Alert object is created with a severity level (INFO, WARNING, CRITICAL) and dispatched to all registered handlers: console output, email, Slack, PagerDuty, or any custom integration. The key design principle is separation of concerns: the code that checks metrics does not need to know how alerts are delivered. In production, severity levels map to different response procedures; a WARNING might create a Jira ticket, while a CRITICAL pages the on-call engineer.
from dataclasses import dataclass
from typing import Callable, List
from enum import Enum

class AlertSeverity(Enum):
    INFO = "info"
    WARNING = "warning"
    CRITICAL = "critical"

@dataclass
class Alert:
    severity: AlertSeverity
    message: str
    metric_name: str
    metric_value: float
    threshold: float

class AlertManager:
    """Manage alerts and notifications"""

    def __init__(self):
        self.handlers: List[Callable] = []
        self.alerts: List[Alert] = []

    def add_handler(self, handler: Callable):
        """Add alert handler (e.g., email, Slack)"""
        self.handlers.append(handler)

    def check_metric(self, metric_name, value, threshold, severity=AlertSeverity.WARNING):
        """Check if metric exceeds threshold"""
        if value > threshold:
            alert = Alert(
                severity=severity,
                message=f"{metric_name} exceeded threshold: {value:.2f} > {threshold:.2f}",
                metric_name=metric_name,
                metric_value=value,
                threshold=threshold
            )
            self.trigger_alert(alert)

    def trigger_alert(self, alert: Alert):
        """Trigger alert to all handlers"""
        self.alerts.append(alert)
        for handler in self.handlers:
            handler(alert)

    def get_recent_alerts(self, n=10):
        """Get recent alerts"""
        return self.alerts[-n:]

# Example handlers
def console_handler(alert: Alert):
    icon = {"info": "ℹ️", "warning": "⚠️", "critical": "🚨"}[alert.severity.value]
    print(f"{icon} [{alert.severity.value.upper()}] {alert.message}")

def email_handler(alert: Alert):
    # In practice, send an actual email
    print(f"📧 Email sent: {alert.message}")

# Set up alert manager
alert_mgr = AlertManager()
alert_mgr.add_handler(console_handler)
alert_mgr.add_handler(email_handler)

# Check metrics
alert_mgr.check_metric("latency_ms", 150, 100, AlertSeverity.WARNING)
alert_mgr.check_metric("error_rate", 0.15, 0.05, AlertSeverity.CRITICAL)

print(f"\nTotal alerts: {len(alert_mgr.get_recent_alerts())}")
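One practical refinement not implemented above: without a cooldown, a metric hovering near its threshold re-alerts on every check, which causes alert fatigue. A minimal sketch of a suppression gate; the `CooldownGate` class is hypothetical, and timestamps are passed explicitly here to keep the example deterministic:

```python
class CooldownGate:
    """Suppress repeat alerts for the same metric within a cooldown window."""

    def __init__(self, cooldown_seconds=300):
        self.cooldown_seconds = cooldown_seconds
        self.last_fired = {}  # metric_name -> timestamp of last alert

    def should_fire(self, metric_name, now):
        """Return True if enough time has passed since the last alert for this metric."""
        last = self.last_fired.get(metric_name)
        if last is not None and now - last < self.cooldown_seconds:
            return False  # still inside the cooldown window
        self.last_fired[metric_name] = now
        return True

gate = CooldownGate(cooldown_seconds=300)
print(gate.should_fire("latency_ms", now=0))    # True: first alert fires
print(gate.should_fire("latency_ms", now=120))  # False: suppressed, within cooldown
print(gate.should_fire("latency_ms", now=400))  # True: cooldown has elapsed
```

Wiring this into check_metric (consult the gate before calling trigger_alert) keeps delivery handlers unchanged, consistent with the separation-of-concerns design above.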
Performance Dashboard (Mock)¶
A performance dashboard aggregates request-level metrics into summary statistics that give you an at-a-glance view of system health. The PerformanceDashboard class computes key indicators over a configurable time window: total requests, success rate, and latency percentiles (p95, p99). Latency percentiles are more informative than averages because a few slow requests can hide behind a healthy mean; the p99 tells you what the slowest 1% of users experience. The simulated data uses a gamma distribution for latencies, which closely approximates real-world API response time distributions with their characteristic right skew.
import pandas as pd
import numpy as np
from datetime import datetime, timedelta

class PerformanceDashboard:
    """Track and display performance metrics"""

    def __init__(self):
        self.metrics = []

    def log_request(self, latency_ms, success, model_version):
        """Log request metrics"""
        self.metrics.append({
            "timestamp": datetime.now(),
            "latency_ms": latency_ms,
            "success": success,
            "model_version": model_version
        })

    def get_summary(self, last_n_minutes=60):
        """Get performance summary"""
        if not self.metrics:
            return "No data available"

        df = pd.DataFrame(self.metrics)
        cutoff = datetime.now() - timedelta(minutes=last_n_minutes)
        recent_df = df[df['timestamp'] > cutoff]

        if recent_df.empty:
            return "No recent data"

        summary = {
            "total_requests": len(recent_df),
            "success_rate": recent_df['success'].mean() * 100,
            "avg_latency_ms": recent_df['latency_ms'].mean(),
            "p95_latency_ms": recent_df['latency_ms'].quantile(0.95),
            "p99_latency_ms": recent_df['latency_ms'].quantile(0.99),
            "max_latency_ms": recent_df['latency_ms'].max()
        }
        return summary

    def display_summary(self):
        """Display formatted summary"""
        summary = self.get_summary()
        if isinstance(summary, str):
            print(summary)
            return

        print("\n=== Performance Summary (Last 60 min) ===")
        print(f"Total Requests: {summary['total_requests']}")
        print(f"Success Rate: {summary['success_rate']:.2f}%")
        print(f"Avg Latency: {summary['avg_latency_ms']:.2f}ms")
        print(f"P95 Latency: {summary['p95_latency_ms']:.2f}ms")
        print(f"P99 Latency: {summary['p99_latency_ms']:.2f}ms")
        print(f"Max Latency: {summary['max_latency_ms']:.2f}ms")
        print("=" * 40)

# Example usage
dashboard = PerformanceDashboard()

# Simulate requests
for _ in range(1000):
    latency = np.random.gamma(2, 15)  # Right-skewed, realistic latency distribution
    success = np.random.random() > 0.02  # ~98% success rate
    dashboard.log_request(latency, success, "v1.0")

dashboard.display_summary()
Best Practices¶
Monitor Everything
System metrics (CPU, memory, latency)
Model metrics (confidence, predictions)
Business metrics (conversions, revenue)
Set Meaningful Alerts
Not too sensitive (alert fatigue)
Not too relaxed (miss issues)
Actionable thresholds
Track Trends
Don't just monitor current values
Look for degradation over time
Use moving averages
Automate Responses
Auto-scaling on high load
Automatic rollback on errors
Self-healing systems
Regular Reviews
Weekly performance reviews
Monthly model evaluations
Quarterly business impact analysis
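The "Track Trends" practice above can be made concrete with a trailing moving average, which smooths day-to-day noise so slow degradation stands out. A minimal sketch; the window size and the accuracy series are illustrative:

```python
from collections import deque

def moving_average(values, window=5):
    """Trailing moving average over a metric series."""
    buf = deque(maxlen=window)
    out = []
    for v in values:
        buf.append(v)
        out.append(sum(buf) / len(buf))  # average of the last `window` values
    return out

# Daily accuracy: noisy, but slowly trending down
accuracy = [0.95, 0.96, 0.94, 0.95, 0.93, 0.94, 0.92, 0.93, 0.91, 0.90]
smoothed = moving_average(accuracy, window=5)
print(f"smoothed accuracy: {smoothed[0]:.3f} -> {smoothed[-1]:.3f}")
```

Alerting on the smoothed series (e.g., "5-day average accuracy dropped below 0.93") is far less flappy than alerting on each daily value.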
Key Takeaways¶
✅ Monitor system, model, and business metrics
✅ Detect data drift proactively
✅ Set up meaningful alerts
✅ Build dashboards for visibility
✅ Automate responses when possible