Model Deployment with FastAPI: From Notebook to Production API
Training a model is 20% of the work. Deploying it reliably is the other 80%. This notebook covers FastAPI endpoint design, Pydantic validation, async inference, Docker packaging, and health checks — everything needed to ship a model others can call.
# !pip install fastapi uvicorn pydantic scikit-learn joblib
import numpy as np
import pandas as pd
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import joblib
import json
import time
import warnings
warnings.filterwarnings('ignore')
# --- Train a simple churn model that the API below will serve ---
np.random.seed(42)  # reproducible synthetic data
n = 3000

# Synthetic behavioral features for n customers.
X = pd.DataFrame({
    'session_count': np.random.poisson(20, n),
    'days_since_login': np.random.exponential(30, n),
    'total_spent': np.random.lognormal(4, 1.5, n),
    'support_tickets': np.random.poisson(1, n),
    'plan_encoded': np.random.choice([0, 1, 2, 3], n),
})

# Ground-truth labels from a logistic link: fewer sessions and more
# support tickets push the churn probability up.
logit = 0.5 * np.log(X['session_count'] + 1) - 0.1 * X['support_tickets']
churn_prob = 1 / (1 + np.exp(logit))
y = (churn_prob > 0.5).astype(int)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Bundle scaler + GBM in one Pipeline so preprocessing ships with the model.
model_pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('model', GradientBoostingClassifier(n_estimators=100, random_state=42)),
])
model_pipeline.fit(X_train, y_train)

print('Model trained:')
print(classification_report(y_test, model_pipeline.predict(X_test)))

# Persist the fitted pipeline so the API process can load it at startup.
joblib.dump(model_pipeline, '/tmp/churn_model.pkl')
print('Model saved to /tmp/churn_model.pkl')
1. FastAPI Application Structure
# The complete FastAPI application
# (Save this as app.py to run with: uvicorn app:app --reload)
# The FastAPI application, written for Pydantic v2 (requirements pin
# pydantic==2.7.0), so:
#   * `@validator` -> `@field_validator` (the v1 decorator is deprecated in v2)
#   * `.dict()` -> `.model_dump()`
#   * optional response fields get an explicit default (required in v2)
#   * `model_version` needs the `model_` protected namespace disabled
# The batch endpoint also gains a model-loaded guard and an ids/customers
# length check that the single-prediction endpoint already implied.
APP_CODE = '''
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, ConfigDict, Field, field_validator
from typing import Optional, List
import numpy as np
import pandas as pd
import joblib
import time
import logging

# --- Setup ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Churn Prediction API",
    description="Predict customer churn probability from behavioral features.",
    version="1.0.0",
)

# --- Load model at startup (not per request!) ---
model = None

@app.on_event("startup")
async def load_model():
    global model
    model = joblib.load("/app/models/churn_model.pkl")
    logger.info("Model loaded successfully")

# --- Request/Response schemas ---
class CustomerFeatures(BaseModel):
    session_count: int = Field(..., ge=0, description="Total number of sessions")
    days_since_login: float = Field(..., ge=0, description="Days since last login")
    total_spent: float = Field(..., ge=0, description="Total amount spent ($)")
    support_tickets: int = Field(default=0, ge=0, description="Number of support tickets")
    plan_encoded: int = Field(..., ge=0, le=3, description="Plan: 0=free, 1=basic, 2=pro, 3=enterprise")

    @field_validator("days_since_login")
    @classmethod
    def validate_recent_login(cls, v):
        if v > 3650:  # 10 years
            raise ValueError("days_since_login seems unreasonably large")
        return v

class PredictionResponse(BaseModel):
    # Pydantic v2 reserves the `model_` prefix; allow `model_version` here.
    model_config = ConfigDict(protected_namespaces=())

    customer_id: Optional[str] = None
    churn_probability: float
    prediction: str  # "churn" or "retain"
    confidence: str  # "high" / "medium" / "low"
    model_version: str = "1.0.0"
    latency_ms: float

class BatchRequest(BaseModel):
    customers: List[CustomerFeatures]
    ids: Optional[List[str]] = None

# --- Endpoints ---
@app.get("/health")
async def health_check():
    """Health check for load balancer / Kubernetes liveness probe."""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "timestamp": time.time(),
    }

@app.get("/metrics")
async def metrics():
    """Prometheus-compatible metrics endpoint."""
    return {"requests_total": 0, "errors_total": 0}  # In prod: use Prometheus client

@app.post("/predict", response_model=PredictionResponse)
async def predict(features: CustomerFeatures, customer_id: Optional[str] = None):
    """Single customer churn prediction."""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    start = time.time()
    X = pd.DataFrame([features.model_dump()])
    prob = float(model.predict_proba(X)[0, 1])
    prediction = "churn" if prob > 0.5 else "retain"
    # Confidence tiers: far from the 0.5 boundary => high confidence.
    if prob > 0.8 or prob < 0.2:
        confidence = "high"
    elif prob > 0.65 or prob < 0.35:
        confidence = "medium"
    else:
        confidence = "low"
    latency_ms = (time.time() - start) * 1000
    logger.info("predict: prob=%.3f, latency=%.1fms", prob, latency_ms)
    return PredictionResponse(
        customer_id=customer_id,
        churn_probability=round(prob, 4),
        prediction=prediction,
        confidence=confidence,
        latency_ms=round(latency_ms, 2),
    )

@app.post("/predict/batch")
async def predict_batch(request: BatchRequest):
    """Batch prediction for up to 1000 customers."""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if len(request.customers) > 1000:
        raise HTTPException(status_code=400, detail="Max batch size is 1000")
    if request.ids is not None and len(request.ids) != len(request.customers):
        raise HTTPException(status_code=400, detail="ids must have the same length as customers")
    start = time.time()
    X = pd.DataFrame([c.model_dump() for c in request.customers])
    probs = model.predict_proba(X)[:, 1]
    results = [
        {
            "id": request.ids[i] if request.ids else str(i),
            "churn_probability": round(float(p), 4),
            "prediction": "churn" if p > 0.5 else "retain",
        }
        for i, p in enumerate(probs)
    ]
    return {
        "results": results,
        "count": len(results),
        "latency_ms": round((time.time() - start) * 1000, 2),
    }
'''
# Persist the generated application source so uvicorn can serve it.
with open('/tmp/app.py', 'w') as out_file:
    out_file.write(APP_CODE)

# Usage instructions for the freshly written app.
for line in (
    'app.py written to /tmp/app.py',
    '',
    'To run:',
    ' cd /tmp && uvicorn app:app --host 0.0.0.0 --port 8000 --reload',
    '',
    'Swagger UI auto-generated at: http://localhost:8000/docs',
):
    print(line)
2. Calling the API — Client Patterns
# Simulate the API call pattern (without a running server)
import json

# Payload a client would POST to /predict.
single_request = dict(
    session_count=5,
    days_since_login=45.0,
    total_spent=89.99,
    support_tickets=3,
    plan_encoded=0,
)

# Payload for /predict/batch: customer feature dicts plus a parallel id list.
batch_request = {
    'customers': [
        dict(session_count=5, days_since_login=45.0, total_spent=89.99, support_tickets=3, plan_encoded=0),
        dict(session_count=45, days_since_login=2.0, total_spent=499.00, support_tickets=0, plan_encoded=2),
        dict(session_count=2, days_since_login=120.0, total_spent=19.99, support_tickets=5, plan_encoded=0),
    ],
    'ids': ['cust_001', 'cust_002', 'cust_003'],
}
# Local prediction (simulating what the API would return)
def local_predict(features: dict):
    """Run the /predict logic against the in-memory pipeline.

    Mirrors the API's probability -> label -> confidence-tier mapping so
    the notebook can demo responses without a running server.
    """
    frame = pd.DataFrame([features])
    prob = model_pipeline.predict_proba(frame)[0, 1]
    verdict = 'churn' if prob > 0.5 else 'retain'
    if prob > 0.8 or prob < 0.2:
        tier = 'high'
    elif prob > 0.65 or prob < 0.35:
        tier = 'medium'
    else:
        tier = 'low'
    return {'churn_probability': round(prob, 4), 'prediction': verdict, 'confidence': tier}

print('Single prediction:')
result = local_predict(single_request)
print(json.dumps(result, indent=2))

print('\nBatch predictions:')
for cid, customer in zip(batch_request['ids'], batch_request['customers']):
    result = local_predict(customer)
    print(f' {cid}: {result["prediction"]} ({result["churn_probability"]:.1%}) [{result["confidence"]}]')
# Python client code pattern. Fixed relative to the earlier draft: the async
# helper imports asyncio and takes a concrete `url` (both were previously
# missing/undefined, so the snippet could not run as written).
CLIENT_CODE = '''
import asyncio
import httpx  # or requests

BASE_URL = "http://localhost:8000"

# Single prediction
response = httpx.post(f"{BASE_URL}/predict", json=single_request)
print(response.json())

# Batch prediction
response = httpx.post(f"{BASE_URL}/predict/batch", json=batch_request)

# Async client for high throughput: fan out many requests concurrently.
async def predict_async(features_list, url=f"{BASE_URL}/predict"):
    async with httpx.AsyncClient() as client:
        tasks = [client.post(url, json=f) for f in features_list]
        results = await asyncio.gather(*tasks)
        return [r.json() for r in results]
'''

print('\nClient code pattern:')
print(CLIENT_CODE)
3. Docker — Containerizing the API
DOCKERFILE = '''
# Multi-stage build: keeps final image small
FROM python:3.11-slim AS builder
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir --user -r requirements.txt
# Production image
FROM python:3.11-slim
# Security: don't run as root
RUN useradd --create-home appuser
WORKDIR /home/appuser/app
COPY --from=builder /root/.local /home/appuser/.local
COPY app.py .
COPY models/ ./models/
USER appuser
ENV PATH=/home/appuser/.local/bin:$PATH
# Health check: Docker will restart container if this fails
HEALTHCHECK --interval=30s --timeout=10s --start-period=15s --retries=3 \\
CMD curl -f http://localhost:8000/health || exit 1
EXPOSE 8000
# Use Gunicorn + Uvicorn workers for production
CMD ["gunicorn", "app:app", "-w", "4", "-k", "uvicorn.workers.UvicornWorker", "-b", "0.0.0.0:8000"]
'''
# Pinned dependency versions for reproducible image builds.
# NOTE: pydantic is 2.x here, so the app code must use the Pydantic v2 API.
REQUIREMENTS = '''
fastapi==0.111.0
uvicorn[standard]==0.30.0
gunicorn==22.0.0
pydantic==2.7.0
scikit-learn==1.4.2
pandas==2.2.2
numpy==1.26.4
joblib==1.4.2
'''
# Display the deployment artifacts and how to build/run them.
for header, body in (
    ('=== Dockerfile ===', DOCKERFILE),
    ('=== requirements.txt ===', REQUIREMENTS),
):
    print(header)
    print(body)

print('=== Build & run commands ===')
for cmd in (
    'docker build -t churn-api:1.0.0 .',
    'docker run -p 8000:8000 --name churn-api churn-api:1.0.0',
):
    print(cmd)
print()
print('=== Docker Compose for local dev with auto-reload ===')
# Compose file for local development. The top-level `version:` key is
# obsolete under the Compose Specification (docker compose v2 warns and
# ignores it), so it is omitted here.
COMPOSE = '''
services:
  api:
    build: .
    ports:
      - "8000:8000"
    volumes:
      - ./app.py:/home/appuser/app/app.py  # hot reload
    environment:
      - MODEL_PATH=/home/appuser/app/models/churn_model.pkl
    command: uvicorn app:app --host 0.0.0.0 --port 8000 --reload
'''
print(COMPOSE)
4. Performance Testing — Latency & Throughput
# --- Benchmark local model inference (without HTTP overhead) ---
import time
# Bug fix: `plt` is used for the chart below but matplotlib was never
# imported anywhere in the notebook, which raised a NameError.
import matplotlib.pyplot as plt

# Single-prediction latency. Warm up first so one-time costs (lazy
# allocations, caches) don't pollute the measurement.
n_warmup = 10
n_bench = 1000
sample = pd.DataFrame([{
    'session_count': 20, 'days_since_login': 15.0,
    'total_spent': 150.0, 'support_tickets': 1, 'plan_encoded': 1
}])

for _ in range(n_warmup):
    _ = model_pipeline.predict_proba(sample)

start = time.perf_counter()
for _ in range(n_bench):
    _ = model_pipeline.predict_proba(sample)
single_latency = (time.perf_counter() - start) / n_bench * 1000  # ms per call

# Batch inference timing: mean of 100 runs per batch size.
batch_sizes = [1, 10, 100, 1000]
latencies = {}
for bs in batch_sizes:
    batch = pd.DataFrame([
        {'session_count': np.random.poisson(20), 'days_since_login': np.random.exponential(30),
         'total_spent': np.random.lognormal(4, 1), 'support_tickets': np.random.poisson(1),
         'plan_encoded': np.random.randint(0, 4)}
        for _ in range(bs)
    ])
    times = []
    for _ in range(100):
        t0 = time.perf_counter()
        _ = model_pipeline.predict_proba(batch)
        times.append((time.perf_counter() - t0) * 1000)
    latencies[bs] = np.mean(times)

print('Inference Performance Benchmarks:')
print(f'{"Batch Size":<12} {"Latency (ms)":<16} {"Throughput (req/s)"}')
print('-' * 45)
for bs, lat in latencies.items():
    throughput = bs / (lat / 1000)  # requests/second at this batch size
    print(f'{bs:<12} {lat:<16.2f} {throughput:.0f}')
print()
print('Key insight: batch processing is much more efficient than N sequential single calls')

# Visualize latency vs batch size: sub-linear growth is why batching wins.
fig, ax = plt.subplots(figsize=(9, 5))
ax.bar([str(bs) for bs in batch_sizes], [latencies[bs] for bs in batch_sizes], color='steelblue', alpha=0.8)
ax.set_xlabel('Batch size')
ax.set_ylabel('Latency (ms)')
ax.set_title('Inference Latency by Batch Size (GBM model)')
plt.tight_layout()
plt.show()
Model Deployment Cheat Sheet
Component            Best Practice
------------------------------------------------------------
Model loading        Load once at startup, not per request
Input validation     Pydantic models with type + range checks
Error handling       Return 422 for validation, 503 for model errors
Latency target       < 50ms p99 for real-time, < 500ms for batch jobs
Health check         GET /health for liveness + readiness probes
Logging              Log prediction + features + latency (not PII)
Versioning           /v1/predict — allows backward-compat rollouts
Docker               Non-root user, multi-stage build, health check
Production server    Gunicorn + Uvicorn workers (not uvicorn alone)
Load testing         Locust or k6 before production launch
Scaling Strategies:
Horizontal: Docker + Kubernetes — add more pods
Vertical: Larger instance for memory-heavy models
Caching: Redis for identical repeated inputs
Async: FastAPI async endpoints + thread pool for CPU-bound models
Exercises
1. Add request ID tracking (UUID) to every response for distributed tracing.
2. Implement an /explain endpoint using SHAP values for a single prediction.
3. Add Prometheus metrics (request count, latency histogram) using prometheus-fastapi-instrumentator.
4. Write a Locust load test script that simulates 100 concurrent users hitting /predict.
5. Add model version routing: /v1/predict and /v2/predict serving different model versions simultaneously.