Production-Grade Streaming Systems
Overview
Getting a streaming endpoint working in a notebook is step one. Getting it to handle 10,000 concurrent users reliably, with observability and graceful failure, is an entirely different challenge.
This notebook covers the engineering required to take streaming LLM endpoints to production.
Learning Objectives
Understand production requirements for streaming systems
Implement connection limits and resource management
Build rate limiting per user/IP with token bucket algorithm
Handle backpressure from slow clients
Add timeout management at every layer
Implement circuit breakers for LLM API failures
Add retry logic with exponential backoff
Instrument Prometheus metrics (TTFT, tokens/sec, active connections)
Implement graceful shutdown and connection draining
Load test with Locust
Deploy with Docker Compose
Prerequisites
pip install fastapi uvicorn prometheus-client locust httpx python-dotenv
1. Production Requirements
What Makes Streaming Hard in Production
| Challenge | Description | Solution |
|---|---|---|
| Long-lived connections | SSE/WS hold sockets open for 10-60s | Connection limits + timeouts |
| Slow clients | Client reads slower than LLM writes | Backpressure / bounded queues |
| LLM API flakiness | OpenAI/Anthropic have ~99.9% uptime | Circuit breakers + retries |
| Memory leaks | Each connection allocates buffers | Explicit cleanup + weak refs |
| Load balancer issues | Default LBs close idle connections | Nginx SSE config + keepalive |
| No observability | Can't see TTFT or error rates | Prometheus metrics |
| Thundering herd | LLM returns -> all clients wake | Staggered retries |
Target SLOs
TTFT (Time to First Token): p50 < 300ms, p95 < 1000ms, p99 < 2000ms
Total generation time: p50 < 5s, p95 < 15s
Error rate: < 0.1%
Active connections: < 500 per instance
Memory per connection: < 2MB
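These targets can be sanity-checked offline against recorded latencies. Below is a minimal sketch using nearest-rank percentiles over TTFT samples in milliseconds; the function names and sample data are illustrative assumptions, not part of the notebook's tooling:

```python
# Check recorded TTFT samples against the SLO targets above.
# Nearest-rank percentile; thresholds mirror the SLO list (in ms).

def percentile(samples: list, pct: float) -> float:
    """Nearest-rank percentile of a list of latency samples."""
    ordered = sorted(samples)
    idx = max(0, min(len(ordered) - 1, round(pct / 100 * len(ordered)) - 1))
    return ordered[idx]

def check_ttft_slo(ttft_ms: list) -> dict:
    """Compare observed p50/p95/p99 TTFT against the targets above."""
    targets = {50: 300, 95: 1000, 99: 2000}
    return {
        f"p{p}": {
            "observed_ms": percentile(ttft_ms, p),
            "target_ms": t,
            "ok": percentile(ttft_ms, p) <= t,
        }
        for p, t in targets.items()
    }

samples = [100, 150, 200, 250, 280, 290, 600, 800, 950, 1800]
report = check_ttft_slo(samples)
```

In production these percentiles would come from Prometheus histogram quantiles rather than raw samples, but the offline version is useful for load-test analysis.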
# Install dependencies (uncomment if needed)
# !pip install fastapi uvicorn prometheus-client locust httpx python-dotenv
import asyncio
import json
import os
import time
import signal
import weakref
import logging
import traceback
from collections import defaultdict, deque
from dataclasses import dataclass, field
from enum import Enum
from typing import AsyncGenerator, Callable, Dict, List, Optional, Set
try:
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import uvicorn
FASTAPI_AVAILABLE = True
except ImportError:
FASTAPI_AVAILABLE = False
print("FastAPI not installed: pip install fastapi uvicorn")
try:
from prometheus_client import (
Counter, Histogram, Gauge, Summary,
generate_latest, CONTENT_TYPE_LATEST,
CollectorRegistry
)
PROMETHEUS_AVAILABLE = True
except ImportError:
PROMETHEUS_AVAILABLE = False
print("prometheus-client not installed: pip install prometheus-client")
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
logger = logging.getLogger("prod_streaming")
print("Imports complete.")
print(f"FastAPI: {'OK' if FASTAPI_AVAILABLE else 'MISSING'}")
print(f"Prometheus: {'OK' if PROMETHEUS_AVAILABLE else 'MISSING'}")
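With the guarded import above, the metrics named in the learning objectives (TTFT, tokens/sec, active connections) could be declared along these lines. The metric names and bucket boundaries here are illustrative assumptions, not fixed by this notebook:

```python
# Illustrative metric declarations; guarded like the imports above so the
# cell degrades gracefully when prometheus-client is missing.
try:
    from prometheus_client import Counter, Gauge, Histogram, CollectorRegistry
    PROM_OK = True
except ImportError:
    PROM_OK = False

if PROM_OK:
    REGISTRY = CollectorRegistry()  # fresh registry: safe to re-run the cell
    TTFT_SECONDS = Histogram(
        "stream_ttft_seconds", "Time to first token",
        buckets=(0.1, 0.3, 0.5, 1.0, 2.0, 5.0),  # roughly aligned with the SLOs
        registry=REGISTRY,
    )
    TOKENS_SENT = Counter(
        "stream_tokens_sent_total", "Tokens sent to clients", registry=REGISTRY
    )
    ACTIVE_CONNECTIONS = Gauge(
        "stream_active_connections", "Open streaming connections",
        registry=REGISTRY,
    )
    # Typical usage around one streaming request:
    ACTIVE_CONNECTIONS.inc()
    TTFT_SECONDS.observe(0.25)
    TOKENS_SENT.inc(42)
    ACTIVE_CONNECTIONS.dec()
```

Using a dedicated CollectorRegistry avoids the duplicate-registration error that the default global registry raises when a notebook cell is executed twice.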
2. Load Balancing for SSE and WebSockets
Why Standard Load Balancers Fail for Streaming
Standard HTTP load balancers (ALB, HAProxy defaults) are optimized for short-lived request-response cycles: they distribute requests round-robin, enforce 30-second timeouts, and buffer responses for compression. SSE and WebSocket connections violate all of these assumptions: they hold sockets open for minutes, produce data incrementally, and must not be buffered. Nginx requires explicit configuration: proxy_buffering off prevents nginx from accumulating tokens before forwarding, ip_hash ensures sticky sessions so reconnecting clients hit the same backend, and proxy_read_timeout must be extended to 300 seconds or more.
Why sticky sessions matter: if a client disconnects and reconnects (common on mobile networks), a non-sticky load balancer may route them to a different backend that has no context about their previous conversation. With ip_hash or cookie-based affinity, the client returns to the same server. For WebSocket upgrades, the Connection: upgrade and Upgrade headers must be explicitly proxied; nginx will not forward them by default.
NGINX_SSE_CONFIG = """
# nginx.conf -- Production SSE/WebSocket streaming config
upstream streaming_backend {
# Sticky sessions: same client -> same backend (important for SSE)
# Use ip_hash OR cookie-based sticky sessions
ip_hash;
server app1:8000 max_fails=3 fail_timeout=30s;
server app2:8000 max_fails=3 fail_timeout=30s;
server app3:8000 max_fails=3 fail_timeout=30s;
# Keep connections to upstream alive
keepalive 32;
}
server {
listen 80;
server_name api.example.com;
# SSE endpoint
location /stream {
proxy_pass http://streaming_backend;
# Critical for SSE: disable buffering so tokens reach client immediately
proxy_buffering off;
proxy_cache off;
# Long timeouts for streaming connections
proxy_read_timeout 300s; # 5 minutes max stream
proxy_send_timeout 300s;
proxy_connect_timeout 10s;
# SSE headers
proxy_set_header Connection '';
proxy_http_version 1.1;
chunked_transfer_encoding on;
# Pass real client IP
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
# Prevent nginx from closing idle SSE connections
keepalive_timeout 300s;
}
# WebSocket endpoint
location /ws {
proxy_pass http://streaming_backend;
proxy_http_version 1.1;
# Required for WebSocket upgrade
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_read_timeout 3600s; # 1 hour for WS
proxy_buffering off;
}
# Regular API endpoints (short timeout)
location / {
proxy_pass http://streaming_backend;
proxy_read_timeout 30s;
proxy_buffering on;
}
}
"""
print("Nginx configuration for SSE/WebSocket:")
print(NGINX_SSE_CONFIG)
print("Key settings explained:")
settings = [
("proxy_buffering off", "Tokens reach client immediately, not buffered"),
("ip_hash", "Same client always goes to same backend"),
("proxy_read_timeout 300s", "Allow streams up to 5 minutes"),
("keepalive 32", "Reuse connections to upstream (saves TCP overhead)"),
("Upgrade + Connection", "Required headers for WebSocket protocol upgrade"),
]
for setting, explanation in settings:
print(f" {setting:<30} -> {explanation}")
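For local testing, the three-backend upstream above could be stood up with Docker Compose. This sketch follows the notebook's convention of keeping configs in Python strings; the image tags, build context, and uvicorn command are assumptions for illustration, not the notebook's actual deployment:

```python
# Hypothetical docker-compose.yml matching the nginx upstream above
# (service names app1-app3 mirror the upstream block).
DOCKER_COMPOSE_CONFIG = """
version: "3.8"
services:
  nginx:
    image: nginx:1.25
    ports:
      - "80:80"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf:ro
    depends_on: [app1, app2, app3]
  app1: &app
    build: .
    command: uvicorn main:app --host 0.0.0.0 --port 8000
    expose: ["8000"]
  app2: *app
  app3: *app
"""
print(DOCKER_COMPOSE_CONFIG)
```

The YAML anchor (`&app` / `*app`) keeps the three identical backend definitions in sync; only nginx is published on the host, so all client traffic flows through the streaming-aware proxy config.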
3. Connection Limits and Resource Management
Preventing Resource Exhaustion
Each streaming connection consumes a file descriptor (socket), a coroutine (or thread), and potentially megabytes of buffer memory. On a typical Linux server, the default ulimit is 1024 file descriptors; a single traffic spike can exhaust this and cause the server to reject all new connections, including health checks, which triggers cascading failures in container orchestrators.
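The process's descriptor budget can be inspected, and raised up to the hard limit, from Python via the stdlib resource module. A minimal sketch (Unix-only; the function names are illustrative):

```python
import resource

def describe_fd_limit() -> dict:
    """Report the soft/hard file-descriptor limits for this process."""
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    return {"soft": soft, "hard": hard}

def raise_fd_limit(target: int) -> int:
    """Raise the soft limit toward `target`, capped at the hard limit."""
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    # RLIM_INFINITY means the hard limit is unbounded
    new_soft = target if hard == resource.RLIM_INFINITY else min(target, hard)
    if new_soft > soft:
        resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))
    return resource.getrlimit(resource.RLIMIT_NOFILE)[0]

limits = describe_fd_limit()
```

Sizing ConnectionManager's max_global well below the soft limit (each connection needs at least one descriptor, plus descriptors for upstream LLM calls) keeps health checks reachable even at capacity.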
The ConnectionManager class implements two-tier limits: a global maximum (e.g., 500 connections) and a per-IP maximum (e.g., 10 connections) to prevent a single client from monopolizing resources. It also enforces a maximum connection age to evict zombie connections that stopped consuming data but never sent a TCP FIN. Guarding limit checks and connection registration with a single asyncio.Lock makes them atomic, preventing race conditions under concurrent load.
import asyncio
import time
import uuid
from dataclasses import dataclass, field
from typing import Dict, Optional, Set
@dataclass
class ConnectionInfo:
id: str
client_ip: str
query: str
started_at: float = field(default_factory=time.time)
tokens_sent: int = 0
bytes_sent: int = 0
def age_seconds(self) -> float:
return time.time() - self.started_at
class ConnectionManager:
"""
Manages the lifecycle of streaming connections.
Enforces global and per-IP connection limits.
"""
def __init__(
self,
max_global: int = 500,
max_per_ip: int = 10,
max_connection_age_s: float = 300.0
):
self.max_global = max_global
self.max_per_ip = max_per_ip
self.max_age = max_connection_age_s
self._connections: Dict[str, ConnectionInfo] = {}
self._per_ip: Dict[str, Set[str]] = defaultdict(set)
self._lock = asyncio.Lock()
@property
def active_count(self) -> int:
return len(self._connections)
def connections_for_ip(self, ip: str) -> int:
return len(self._per_ip.get(ip, set()))
async def acquire(self, client_ip: str, query: str) -> Optional[ConnectionInfo]:
"""
Try to register a new connection.
Returns ConnectionInfo on success, None if limits are exceeded.
"""
async with self._lock:
# Check per-IP limit
if self.connections_for_ip(client_ip) >= self.max_per_ip:
logger.warning(f"Per-IP limit reached for {client_ip}")
return None
# Check global limit (non-blocking)
if self.active_count >= self.max_global:
logger.warning(f"Global connection limit reached ({self.max_global})")
return None
conn_id = str(uuid.uuid4())[:8]
info = ConnectionInfo(
id=conn_id,
client_ip=client_ip,
query=query
)
self._connections[conn_id] = info
self._per_ip[client_ip].add(conn_id)
logger.info(f"Connection {conn_id} opened from {client_ip} (total: {self.active_count})")
return info
async def release(self, conn_id: str):
"""Release a connection and free its resources."""
async with self._lock:
info = self._connections.pop(conn_id, None)
if info:
self._per_ip[info.client_ip].discard(conn_id)
if not self._per_ip[info.client_ip]:
del self._per_ip[info.client_ip]
logger.info(
f"Connection {conn_id} closed. "
f"Age={info.age_seconds():.1f}s, tokens={info.tokens_sent} "
f"(remaining: {self.active_count})"
)
    async def evict_stale(self) -> int:
        """Evict connections that have exceeded max age."""
        async with self._lock:
            stale = [
                conn_id for conn_id, info in self._connections.items()
                if info.age_seconds() > self.max_age
            ]
        # release() acquires the same (non-reentrant) asyncio.Lock, so it
        # must run after the lock above is dropped to avoid deadlock.
        for conn_id in stale:
            await self.release(conn_id)
        return len(stale)
def get_stats(self) -> Dict:
return {
"active_connections": self.active_count,
"max_global": self.max_global,
"utilization_pct": round(self.active_count / self.max_global * 100, 1),
"unique_ips": len(self._per_ip),
"connections_by_ip": {ip: len(ids) for ip, ids in self._per_ip.items()}
}
# --- Demo ---
async def demo_connection_manager():
mgr = ConnectionManager(max_global=5, max_per_ip=2)
# Acquire some connections
conns = []
for i in range(4):
ip = f"10.0.0.{i // 2 + 1}" # 2 connections per IP
conn = await mgr.acquire(ip, f"query {i}")
if conn:
conns.append(conn)
print(f" Opened {conn.id} from {conn.client_ip}")
else:
print(f" Rejected connection from {ip} (limit reached)")
# Try to exceed per-IP limit
print("\nTrying to exceed per-IP limit for 10.0.0.1:")
result = await mgr.acquire("10.0.0.1", "extra query")
print(f" Result: {'rejected' if result is None else 'accepted (unexpected)'}")
print(f"\nStats: {json.dumps(mgr.get_stats(), indent=2)}")
# Release one and try again
if conns:
await mgr.release(conns[0].id)
print(f"\nAfter releasing {conns[0].id}:")
print(f" Active: {mgr.active_count}")
await demo_connection_manager()
4. Rate Limiting per User/IP
Token Bucket Algorithm for Streaming Endpoints
The token bucket algorithm is the standard approach for rate limiting API endpoints. Each client IP gets a bucket with a fixed capacity (burst size) that refills at a constant rate (tokens per second). Each request consumes one or more tokens; if the bucket is empty, the request is rejected with HTTP 429. The bucket allows short bursts up to its capacity while enforcing a long-term average rate.
Formally, at time \(t\) the bucket holds \(\min(C, B(t_0) + r \cdot (t - t_0))\) tokens, where \(C\) is capacity, \(r\) is the refill rate, and \(B(t_0)\) was the count at the last check. For streaming endpoints, rate limiting is especially important because each connection is expensive (long-lived, memory-holding), so even a moderate burst of abusive connections can degrade service for legitimate users. The retry_after_seconds method tells clients exactly how long to wait, enabling well-behaved clients to back off gracefully.
class TokenBucket:
"""
Thread-safe token bucket for rate limiting.
capacity: max tokens (burst size)
refill_rate: tokens added per second
"""
def __init__(self, capacity: float, refill_rate: float):
self.capacity = capacity
self.refill_rate = refill_rate
self.tokens = capacity
self.last_refill = time.monotonic()
def _refill(self):
now = time.monotonic()
elapsed = now - self.last_refill
added = elapsed * self.refill_rate
self.tokens = min(self.capacity, self.tokens + added)
self.last_refill = now
def consume(self, tokens: float = 1.0) -> bool:
"""Try to consume tokens. Returns True if allowed, False if rate limited."""
self._refill()
if self.tokens >= tokens:
self.tokens -= tokens
return True
return False
def tokens_remaining(self) -> float:
self._refill()
return round(self.tokens, 2)
def retry_after_seconds(self, tokens_needed: float = 1.0) -> float:
"""How many seconds until enough tokens are available."""
self._refill()
deficit = tokens_needed - self.tokens
if deficit <= 0:
return 0.0
return round(deficit / self.refill_rate, 2)
class RateLimiter:
"""
Per-IP rate limiter using token buckets.
Automatically cleans up buckets for inactive IPs.
"""
def __init__(
self,
requests_per_minute: float = 20,
burst_size: float = 5,
cleanup_interval_s: float = 300
):
self.refill_rate = requests_per_minute / 60.0 # tokens/sec
self.burst_size = burst_size
self.cleanup_interval = cleanup_interval_s
self._buckets: Dict[str, TokenBucket] = {}
self._last_seen: Dict[str, float] = {}
self._lock = asyncio.Lock()
    def _get_or_create_bucket(self, key: str) -> TokenBucket:
        if key not in self._buckets:
            self._buckets[key] = TokenBucket(
                capacity=self.burst_size,
                refill_rate=self.refill_rate
            )
        # Refresh on every access, not just creation, so active IPs
        # are not evicted by cleanup_stale().
        self._last_seen[key] = time.monotonic()
        return self._buckets[key]
async def check(self, client_ip: str, cost: float = 1.0) -> Dict:
"""
Check rate limit for a client.
Returns dict with allowed, remaining_tokens, retry_after.
"""
async with self._lock:
bucket = self._get_or_create_bucket(client_ip)
allowed = bucket.consume(cost)
return {
"allowed": allowed,
"remaining": bucket.tokens_remaining(),
"limit": self.burst_size,
"retry_after": 0.0 if allowed else bucket.retry_after_seconds(cost)
}
async def cleanup_stale(self) -> int:
"""Remove buckets for IPs not seen recently."""
now = time.monotonic()
async with self._lock:
stale_keys = [
k for k, last in self._last_seen.items()
if now - last > self.cleanup_interval
]
for k in stale_keys:
del self._buckets[k]
del self._last_seen[k]
return len(stale_keys)
# --- Demo ---
async def demo_rate_limiter():
# 10 req/min, burst of 3
limiter = RateLimiter(requests_per_minute=10, burst_size=3)
ip = "192.168.1.100"
print(f"Rate limit: 10 req/min, burst=3")
print(f"Sending 6 rapid requests from {ip}:")
print("-" * 50)
for i in range(6):
result = await limiter.check(ip)
status = "ALLOWED" if result["allowed"] else f"BLOCKED (retry in {result['retry_after']}s)"
print(f" Request {i+1}: {status} | tokens_remaining={result['remaining']}")
# Simulate time passing -- tokens refill
print("\nWaiting 6 seconds for tokens to refill...")
await asyncio.sleep(6)
result = await limiter.check(ip)
print(f"After 6s: {result}")
await demo_rate_limiter()
5. Backpressure Handling (Slow Clients)
Managing Producer-Consumer Speed Mismatch
When an LLM produces tokens at 50-100 tokens/second but a client on a slow mobile connection can only consume at 10 tokens/second, the server must handle the resulting backpressure. There are three strategies: (1) unbounded buffering (accumulate all tokens in memory, which risks OOM), (2) dropping tokens (fast but corrupts the response), or (3) bounded queue with disconnect (safest for production).
The bounded asyncio.Queue pattern sets a maxsize (e.g., 100 tokens). The producer puts tokens with a timeout; if the queue is full for longer than the threshold, the server assumes the client is too slow and closes the connection cleanly. This prevents memory exhaustion while giving slow clients a fair chance to catch up. The queue depth also serves as a real-time health metric: a consistently full queue indicates the client is at capacity.
class BackpressureError(Exception):
"""Raised when client is too slow to consume the stream."""
class BackpressureStream:
"""
Wraps a source async generator with backpressure detection.
If the internal buffer fills up (client not reading fast enough),
raises BackpressureError to terminate the connection cleanly.
"""
def __init__(
self,
source: AsyncGenerator,
queue_maxsize: int = 50,
slow_client_timeout_s: float = 5.0
):
self.source = source
self.queue_maxsize = queue_maxsize
self.slow_client_timeout = slow_client_timeout_s
self._queue: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize)
self._producer_task: Optional[asyncio.Task] = None
self.tokens_dropped = 0
self.tokens_sent = 0
self.backpressure_events = 0
async def _producer(self):
"""Reads from source and fills the queue."""
try:
async for item in self.source:
try:
# Wait at most slow_client_timeout for space in queue
await asyncio.wait_for(
self._queue.put(item),
timeout=self.slow_client_timeout
)
except asyncio.TimeoutError:
self.backpressure_events += 1
logger.warning(
f"Backpressure: queue full after {self.slow_client_timeout}s. "
f"Terminating slow client. (event #{self.backpressure_events})"
)
await self._queue.put(BackpressureError("Client too slow"))
return
except Exception as e:
await self._queue.put(e)
finally:
await self._queue.put(None) # Sentinel: stream done
async def stream(self) -> AsyncGenerator:
"""Async generator that yields items with backpressure detection."""
self._producer_task = asyncio.create_task(self._producer())
try:
while True:
item = await self._queue.get()
if item is None:
break
if isinstance(item, BackpressureError):
raise item
if isinstance(item, Exception):
raise item
self.tokens_sent += 1
yield item
finally:
if self._producer_task and not self._producer_task.done():
self._producer_task.cancel()
def get_stats(self) -> Dict:
return {
"queue_size": self._queue.qsize(),
"queue_maxsize": self.queue_maxsize,
"tokens_sent": self.tokens_sent,
"tokens_dropped": self.tokens_dropped,
"backpressure_events": self.backpressure_events
}
# --- Demo ---
async def fast_producer() -> AsyncGenerator[str, None]:
"""Produces tokens very fast (simulates a fast LLM)."""
for i in range(20):
yield f"token_{i} "
await asyncio.sleep(0.01) # 100 tokens/sec
async def demo_backpressure():
print("Demo 1: Fast client (no backpressure)")
bp = BackpressureStream(fast_producer(), queue_maxsize=5, slow_client_timeout_s=1.0)
tokens = []
async for token in bp.stream():
tokens.append(token)
await asyncio.sleep(0.02) # Client reads at 50/sec -- can keep up
print(f" Received {len(tokens)} tokens. Stats: {bp.get_stats()}")
print("\nDemo 2: Slow client (backpressure triggered)")
bp2 = BackpressureStream(fast_producer(), queue_maxsize=3, slow_client_timeout_s=0.2)
tokens2 = []
try:
async for token in bp2.stream():
tokens2.append(token)
await asyncio.sleep(0.5) # Client reads at 2/sec -- too slow
except BackpressureError as e:
print(f" BackpressureError raised after {len(tokens2)} tokens: {e}")
print(f" Stats: {bp2.get_stats()}")
await demo_backpressure()
6. Timeout Management at Every Layer
Defense Against Silent Hangs
Streaming systems have multiple timeout boundaries that must all be configured correctly: client connection timeout, nginx proxy_read_timeout, application-level stream timeout, and LLM API timeout. If any single layer has an infinite or missing timeout, a stalled connection hangs silently, consuming resources, holding a connection slot, and eventually triggering cascading failures when the connection pool fills up.
The correct approach is layered timeouts from outside in: nginx at 300s (outermost), application stream at 240s, LLM API call at 120s, individual chunk wait at 30s. Each inner timeout should be shorter than the outer one, so the application can send a clean error response rather than having nginx cut the connection abruptly. For TTFT specifically, a separate short timeout (e.g., 10s) catches cases where the LLM never starts generating.
@dataclass
class TimeoutConfig:
"""Timeout configuration for all layers of a streaming request."""
# How long to wait for the LLM API to start responding
llm_connect_s: float = 10.0
# How long to wait for the first token after generation starts
first_token_s: float = 30.0
# Maximum time between any two consecutive tokens
inter_token_s: float = 15.0
# Maximum total generation time
total_generation_s: float = 120.0
# Maximum total request time (includes retrieval + generation)
total_request_s: float = 180.0
class TimeoutManager:
"""
Enforces timeouts at every layer of a streaming request.
"""
def __init__(self, config: TimeoutConfig):
self.config = config
self.request_start = time.monotonic()
self.first_token_received = False
self.last_token_time = time.monotonic()
self.token_count = 0
def check_total_request_timeout(self):
elapsed = time.monotonic() - self.request_start
if elapsed > self.config.total_request_s:
raise asyncio.TimeoutError(
f"Total request timeout ({self.config.total_request_s}s) exceeded after {elapsed:.1f}s"
)
def check_first_token_timeout(self):
if not self.first_token_received:
elapsed = time.monotonic() - self.request_start
if elapsed > self.config.first_token_s:
raise asyncio.TimeoutError(
f"First token timeout ({self.config.first_token_s}s). No tokens after {elapsed:.1f}s"
)
def on_token(self):
self.first_token_received = True
now = time.monotonic()
gap = now - self.last_token_time
self.last_token_time = now
self.token_count += 1
if gap > self.config.inter_token_s:
raise asyncio.TimeoutError(
f"Inter-token timeout: {gap:.1f}s gap (max {self.config.inter_token_s}s)"
)
    async def wrap_stream(
        self,
        source: AsyncGenerator
    ) -> AsyncGenerator:
        """Wrap a source stream with inter-token and total-time checks.

        Note: asyncio.timeout() requires Python 3.11+; on older versions,
        wrap each read in asyncio.wait_for() instead.
        """
        try:
            async with asyncio.timeout(self.config.total_request_s):
                async for item in source:
                    self.on_token()
                    yield item
        except asyncio.TimeoutError as e:
            logger.error(f"Stream timeout: {e}")
            yield {"type": "error", "code": "timeout", "message": str(e)}
# --- Demo ---
async def slow_llm_stream() -> AsyncGenerator[Dict, None]:
"""Simulates an LLM that pauses mid-generation."""
for i in range(5):
yield {"type": "token", "content": f"word{i} "}
await asyncio.sleep(0.1)
# Simulate a stall
await asyncio.sleep(2.0) # This will trigger inter-token timeout
yield {"type": "token", "content": "never_seen "}
async def demo_timeouts():
config = TimeoutConfig(
first_token_s=5.0,
inter_token_s=1.0, # 1s max between tokens
total_request_s=30.0
)
mgr = TimeoutManager(config)
print("Streaming with inter-token timeout of 1.0s:")
tokens_received = []
async for event in mgr.wrap_stream(slow_llm_stream()):
print(f" Event: {event}")
if event.get("type") == "token":
tokens_received.append(event["content"])
print(f"\nTokens received before timeout: {tokens_received}")
print(f"Token count: {mgr.token_count}")
await demo_timeouts()
7. Memory Management for Long-Lived Connections
Preventing Memory Leaks in Streaming Services
Each streaming connection holds references to buffers, LLM client objects, callback handlers, and accumulated response text. In a typical 30-second stream generating 500 tokens, a single connection may hold 1-2MB of live objects. Without explicit cleanup, connections that have logically closed (client disconnected) can still hold references that prevent garbage collection, a classic memory leak pattern in async Python.
Weak references are the key defense: register connection metadata via weakref.ref callbacks (or a WeakValueDictionary), so entries are automatically removed when the connection object is garbage collected. Additionally, every async generator must have a finally block that explicitly clears buffers and closes client sessions. The combination of weak references, explicit cleanup, and periodic stale-connection eviction keeps memory usage bounded even under sustained high concurrency.
import gc
import sys
import weakref
from contextlib import asynccontextmanager
@dataclass
class StreamContext:
"""Per-connection context. Holds all resources for one streaming session."""
conn_id: str
client_ip: str
buffer: bytearray = field(default_factory=lambda: bytearray(4096))
token_history: List[str] = field(default_factory=list)
metadata: Dict = field(default_factory=dict)
    def cleanup(self):
        """Explicitly release large resources."""
        # Rebinding drops the old 4KB buffer; no `del` needed first.
        self.buffer = bytearray(0)
        self.token_history.clear()
        self.metadata.clear()
def size_bytes(self) -> int:
return sys.getsizeof(self.buffer) + sum(sys.getsizeof(t) for t in self.token_history)
class MemoryAwareConnectionRegistry:
"""
Tracks active connections using WeakReferences.
When a StreamContext is garbage collected, it is automatically
removed from the registry -- no manual cleanup needed.
"""
def __init__(self, max_memory_mb: float = 500.0):
self.max_memory_bytes = max_memory_mb * 1024 * 1024
self._registry: Dict[str, weakref.ref] = {}
def register(self, ctx: StreamContext):
"""Register a context using a weak reference."""
def on_gc(ref):
self._registry.pop(ctx.conn_id, None)
logger.debug(f"Context {ctx.conn_id} garbage collected")
self._registry[ctx.conn_id] = weakref.ref(ctx, on_gc)
def active_contexts(self) -> List[StreamContext]:
"""Return all live context objects."""
live = []
for ref in list(self._registry.values()):
ctx = ref()
if ctx is not None:
live.append(ctx)
return live
def total_memory_bytes(self) -> int:
return sum(c.size_bytes() for c in self.active_contexts())
def check_memory_pressure(self) -> bool:
"""Return True if we are approaching memory limits."""
used = self.total_memory_bytes()
if used > self.max_memory_bytes * 0.9:
logger.warning(
f"Memory pressure: {used / 1024 / 1024:.1f}MB / "
f"{self.max_memory_bytes / 1024 / 1024:.1f}MB"
)
return True
return False
def force_cleanup_oldest(self, n: int = 5):
"""Force cleanup of the N oldest contexts under memory pressure."""
contexts = sorted(self.active_contexts(), key=lambda c: c.conn_id)
for ctx in contexts[:n]:
ctx.cleanup()
gc.collect()
logger.info(f"Force-cleaned {n} contexts. Remaining: {len(self.active_contexts())}")
@asynccontextmanager
async def managed_stream_context(
registry: MemoryAwareConnectionRegistry,
client_ip: str
):
"""Context manager that creates, registers, and cleans up a StreamContext."""
conn_id = str(uuid.uuid4())[:8]
ctx = StreamContext(conn_id=conn_id, client_ip=client_ip)
registry.register(ctx)
logger.info(f"Stream context {conn_id} created")
try:
yield ctx
finally:
ctx.cleanup()
logger.info(f"Stream context {conn_id} cleaned up")
# --- Demo ---
async def demo_memory_management():
    registry = MemoryAwareConnectionRegistry(max_memory_mb=1.0)
    # Create contexts and simulate streaming. Note: we deliberately keep no
    # strong references to the contexts, otherwise gc could never reclaim
    # them and the registry would still report them as active.
    for i in range(3):
        async with managed_stream_context(registry, f"10.0.0.{i+1}") as ctx:
            # Simulate accumulating tokens
            for j in range(10):
                ctx.token_history.append(f"token_{j}")
            print(f"  Active context {ctx.conn_id}: {ctx.size_bytes()} bytes")
    # After exiting the context manager, cleanup() has run and the last
    # strong reference is gone; gc triggers the weakref callbacks.
    gc.collect()
    remaining = registry.active_contexts()
    print(f"\nAfter all contexts exited: {len(remaining)} active")
    print(f"Total memory: {registry.total_memory_bytes()} bytes")
await demo_memory_management()
8. Circuit Breaker for LLM API Failures
Fail Fast Instead of Fail Slow
When the LLM API is experiencing an outage (returning 500 errors or timing out), the worst thing a streaming service can do is keep sending requests: each one consumes a connection slot for the full timeout duration before failing. The circuit breaker pattern (from Michael Nygard's "Release It!") addresses this with three states: closed (normal operation), open (all requests immediately fail with a cached error), and half-open (a single test request probes whether the API has recovered).
The circuit opens after \(N\) consecutive failures (e.g., 5) and remains open for a configurable duration (e.g., 30s). During this time, incoming requests receive an instant 503 response instead of waiting 120 seconds for a timeout. This protects both the streaming service (connection pool stays available) and the LLM API (reduced load during recovery). The half-open state sends one probe request; if it succeeds, the circuit closes and normal traffic resumes.
class CircuitState(Enum):
CLOSED = "closed" # Normal -- requests pass through
OPEN = "open" # Failing -- reject requests immediately
HALF_OPEN = "half_open" # Testing recovery -- allow one probe request
class CircuitBreakerOpen(Exception):
"""Raised when the circuit breaker is open."""
class CircuitBreaker:
"""
Circuit breaker for LLM API calls.
State transitions:
CLOSED -> OPEN: failure_threshold failures in failure_window_s
OPEN -> HALF_OPEN: after reset_timeout_s
HALF_OPEN -> CLOSED: probe request succeeds
HALF_OPEN -> OPEN: probe request fails
"""
def __init__(
self,
name: str = "llm_api",
failure_threshold: int = 5,
failure_window_s: float = 60.0,
reset_timeout_s: float = 30.0,
success_threshold: int = 2 # Successes needed in HALF_OPEN to CLOSE
):
self.name = name
self.failure_threshold = failure_threshold
self.failure_window_s = failure_window_s
self.reset_timeout_s = reset_timeout_s
self.success_threshold = success_threshold
self._state = CircuitState.CLOSED
self._failure_times: deque = deque()
self._opened_at: Optional[float] = None
self._half_open_successes = 0
self._total_calls = 0
self._total_failures = 0
self._total_rejected = 0
self._lock = asyncio.Lock()
@property
def state(self) -> CircuitState:
return self._state
def _prune_failures(self):
"""Remove failures outside the sliding window."""
cutoff = time.monotonic() - self.failure_window_s
while self._failure_times and self._failure_times[0] < cutoff:
self._failure_times.popleft()
async def _check_state(self):
"""Check if we should transition OPEN -> HALF_OPEN."""
if self._state == CircuitState.OPEN:
if time.monotonic() - self._opened_at >= self.reset_timeout_s:
self._state = CircuitState.HALF_OPEN
self._half_open_successes = 0
logger.info(f"CircuitBreaker '{self.name}': OPEN -> HALF_OPEN")
async def call(self, func: Callable, *args, **kwargs):
"""
Execute a function through the circuit breaker.
Raises CircuitBreakerOpen if the circuit is OPEN.
"""
async with self._lock:
await self._check_state()
if self._state == CircuitState.OPEN:
self._total_rejected += 1
raise CircuitBreakerOpen(
f"Circuit '{self.name}' is OPEN. "
f"Retry after {self.reset_timeout_s}s."
)
self._total_calls += 1
try:
result = await func(*args, **kwargs)
async with self._lock:
if self._state == CircuitState.HALF_OPEN:
self._half_open_successes += 1
if self._half_open_successes >= self.success_threshold:
self._state = CircuitState.CLOSED
self._failure_times.clear()
logger.info(f"CircuitBreaker '{self.name}': HALF_OPEN -> CLOSED")
return result
        except Exception:
            async with self._lock:
                # Count the failure under the lock to avoid a lost update
                # when multiple calls fail concurrently.
                self._total_failures += 1
                now = time.monotonic()
                self._failure_times.append(now)
self._prune_failures()
if self._state == CircuitState.HALF_OPEN:
self._state = CircuitState.OPEN
self._opened_at = now
logger.error(f"CircuitBreaker '{self.name}': HALF_OPEN -> OPEN (probe failed)")
elif len(self._failure_times) >= self.failure_threshold:
self._state = CircuitState.OPEN
self._opened_at = now
logger.error(
f"CircuitBreaker '{self.name}': CLOSED -> OPEN "
f"({len(self._failure_times)} failures in {self.failure_window_s}s)"
)
raise
def get_stats(self) -> Dict:
self._prune_failures()
return {
"name": self.name,
"state": self._state.value,
"recent_failures": len(self._failure_times),
"total_calls": self._total_calls,
"total_failures": self._total_failures,
"total_rejected": self._total_rejected,
"error_rate_pct": round(
self._total_failures / max(self._total_calls, 1) * 100, 1
)
}
# --- Demo ---
cb = CircuitBreaker(name="openai", failure_threshold=3, failure_window_s=10, reset_timeout_s=2)
call_count = 0
async def flaky_llm_call():
global call_count
call_count += 1
if call_count <= 4: # First 4 calls fail
raise Exception("OpenAI 503 Service Unavailable")
return "LLM response text"
async def demo_circuit_breaker():
for i in range(8):
try:
result = await cb.call(flaky_llm_call)
print(f" Call {i+1}: SUCCESS -> '{result}'")
except CircuitBreakerOpen as e:
print(f" Call {i+1}: CIRCUIT OPEN (fast fail)")
except Exception as e:
print(f" Call {i+1}: FAILED -> {e}")
print(f" State: {cb.state.value} | Stats: {cb.get_stats()}")
await asyncio.sleep(0.3)
# Wait for reset timeout
print("\nWaiting 2.5s for circuit to allow probe...")
await asyncio.sleep(2.5)
try:
result = await cb.call(flaky_llm_call)
print(f" Probe: SUCCESS -> '{result}' | State: {cb.state.value}")
except Exception as e:
print(f" Probe: {e} | State: {cb.state.value}")
await demo_circuit_breaker()
9. Retry Logic with Exponential Backoff¶
Surviving Transient Failures¶
Transient failures, such as HTTP 429 (rate limited), 503 (service unavailable), and network timeouts, are routine in production LLM systems. The standard recovery pattern is exponential backoff with jitter: wait \(2^n + \text{random}(0, 1)\) seconds before retry \(n\), up to a maximum of 3-5 attempts. The jitter term prevents the thundering herd problem, where thousands of clients that failed simultaneously all retry at exactly the same moment.
For streaming endpoints, retry logic has an important nuance: you can only retry before the first token is sent. Once token streaming has begun, retrying would produce a duplicate or corrupted response. The implementation should track whether any tokens have been yielded and fall back to an error event (rather than retry) if the stream has already started.
import random
class RetryConfig:
def __init__(
self,
max_retries: int = 3,
base_delay_s: float = 1.0,
max_delay_s: float = 30.0,
exponential_base: float = 2.0,
jitter: bool = True,
retryable_exceptions: tuple = (Exception,)
):
self.max_retries = max_retries
self.base_delay_s = base_delay_s
self.max_delay_s = max_delay_s
self.exponential_base = exponential_base
self.jitter = jitter
self.retryable_exceptions = retryable_exceptions
def delay_for_attempt(self, attempt: int) -> float:
"""Calculate delay for retry attempt N (0-indexed)."""
delay = self.base_delay_s * (self.exponential_base ** attempt)
delay = min(delay, self.max_delay_s)
if self.jitter:
# Full jitter: random in [0, delay]
delay = random.uniform(0, delay)
return delay
async def retry_with_backoff(
func: Callable,
config: RetryConfig,
*args,
**kwargs
):
"""
Execute func with exponential backoff retry.
Raises the last exception if all retries are exhausted.
"""
last_exception = None
for attempt in range(config.max_retries + 1):
try:
result = await func(*args, **kwargs)
if attempt > 0:
logger.info(f"Succeeded on attempt {attempt + 1}")
return result
except config.retryable_exceptions as e:
last_exception = e
if attempt == config.max_retries:
logger.error(f"All {config.max_retries + 1} attempts failed. Last error: {e}")
break
delay = config.delay_for_attempt(attempt)
logger.warning(
f"Attempt {attempt + 1}/{config.max_retries + 1} failed: {e}. "
f"Retrying in {delay:.2f}s..."
)
await asyncio.sleep(delay)
raise last_exception
# --- Demo ---
attempt_counter = 0
async def unreliable_api_call():
global attempt_counter
attempt_counter += 1
if attempt_counter < 3:
raise ConnectionError(f"Network error (attempt {attempt_counter})")
return {"response": "success", "attempt": attempt_counter}
config = RetryConfig(
max_retries=4,
base_delay_s=0.1, # Short for demo
max_delay_s=2.0,
jitter=False # Deterministic for demo
)
print("Retry schedule (no jitter):")
for i in range(4):
delay = config.delay_for_attempt(i)
print(f" Attempt {i+1} fails -> wait {delay:.2f}s before attempt {i+2}")
print("\nRunning with unreliable API:")
result = await retry_with_backoff(unreliable_api_call, config)
print(f"Result: {result}")
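The pre-first-token rule described above can be sketched as a generator wrapper. This is a minimal illustration, not the production pipeline of this notebook: `stream_with_retry` and `flaky_source` are hypothetical names, and the wrapper simply re-raises once any token has been yielded.

```python
import asyncio
import random

async def stream_with_retry(make_stream, max_retries=3, base_delay_s=0.1):
    """Retry a streaming source, but only while no token has been yielded.

    make_stream: zero-arg callable returning a fresh async generator.
    Once the first token is out, failures propagate unchanged, because
    retrying mid-stream would duplicate output.
    """
    for attempt in range(max_retries + 1):
        started = False
        try:
            async for token in make_stream():
                started = True
                yield token
            return  # stream completed cleanly
        except Exception:
            if started or attempt == max_retries:
                raise  # unsafe (or pointless) to retry
            # Exponential backoff with a little jitter before the next attempt
            await asyncio.sleep(base_delay_s * (2 ** attempt) + random.uniform(0, 0.05))

# Demo: the source fails once before emitting anything, then succeeds.
calls = 0

async def flaky_source():
    global calls
    calls += 1
    if calls == 1:
        raise ConnectionError("503 before first token")
    for word in ["hello", "world"]:
        yield word

async def main():
    return [t async for t in stream_with_retry(flaky_source)]

tokens = asyncio.run(main())
print(tokens)  # ['hello', 'world'] -- one silent retry happened
```

If the failure had arrived after "hello" was yielded, the wrapper would re-raise instead of retrying, matching the fall-back-to-error-event behavior described above.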
10. Prometheus Metrics Integration¶
Observability for Streaming Systems¶
Streaming services require specialized metrics beyond standard HTTP request counters. The four critical metrics are: (1) active_connections (Gauge): how many streams are open right now; (2) ttft_seconds (Histogram): the time-to-first-token distribution for latency SLOs; (3) tokens_per_second (Histogram): generation throughput per stream; and (4) errors_total (Counter): error counts by type (timeout, backpressure, LLM failure).
Prometheus Histogram buckets should be tuned for streaming latencies: TTFT buckets at [0.1, 0.25, 0.5, 1.0, 2.0, 5.0] seconds, and total stream duration at [1, 5, 10, 30, 60, 120] seconds. These metrics feed directly into Grafana dashboards and PagerDuty alerts. A spike in active_connections without a corresponding spike in tokens_per_second indicates stalled streams, an early warning of LLM API issues before errors propagate to users.
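Bucket placement matters because Prometheus estimates quantiles (histogram_quantile) by linear interpolation inside the bucket where the target rank falls, so resolution near your SLO boundary depends on having a bucket edge there. A toy re-implementation of that interpolation (illustrative only, not Prometheus's actual code; function name and sample data are made up):

```python
import bisect

def histogram_quantile(q, buckets, counts):
    """Estimate quantile q from cumulative histogram counts.

    buckets: upper bounds, e.g. [0.1, 0.25, 0.5, 1.0, 2.0, 5.0]
    counts:  cumulative observation counts per bucket (last = total),
             mirroring Prometheus's le-labelled series.
    """
    total = counts[-1]
    rank = q * total
    i = bisect.bisect_left(counts, rank)  # first bucket whose count reaches rank
    if i == 0:
        lo, lo_count = 0.0, 0
    else:
        lo, lo_count = buckets[i - 1], counts[i - 1]
    hi, hi_count = buckets[i], counts[i]
    # Linear interpolation within the winning bucket
    return lo + (hi - lo) * (rank - lo_count) / (hi_count - lo_count)

# 100 simulated TTFT observations: 10 under 0.1s, 50 under 0.25s,
# 90 under 0.5s, all under 1.0s
ttft_buckets = [0.1, 0.25, 0.5, 1.0, 2.0, 5.0]
cumulative = [10, 50, 90, 100, 100, 100]
print(round(histogram_quantile(0.95, ttft_buckets, cumulative), 3))  # 0.75
```

With a bucket edge at 1.0s, a p95 SLO of "under 1 second" can be checked precisely; without it, the estimate would be smeared across a much wider bucket.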
# Use a custom registry to avoid conflicts in notebook re-runs
registry = CollectorRegistry() if PROMETHEUS_AVAILABLE else None
if PROMETHEUS_AVAILABLE:
# --- Connection metrics ---
ACTIVE_CONNECTIONS = Gauge(
'streaming_active_connections',
'Number of currently active streaming connections',
['endpoint'],
registry=registry
)
CONNECTIONS_TOTAL = Counter(
'streaming_connections_total',
'Total streaming connections opened',
['endpoint', 'status'], # status: success, rejected, error
registry=registry
)
# --- Latency metrics ---
TTFT_SECONDS = Histogram(
'streaming_ttft_seconds',
'Time to first token in seconds',
['endpoint', 'model'],
buckets=[0.05, 0.1, 0.25, 0.5, 1.0, 2.0, 5.0, 10.0],
registry=registry
)
GENERATION_DURATION = Histogram(
'streaming_generation_duration_seconds',
'Total time from first to last token',
['endpoint', 'model'],
buckets=[1.0, 2.0, 5.0, 10.0, 30.0, 60.0, 120.0],
registry=registry
)
REQUEST_DURATION = Histogram(
'streaming_request_duration_seconds',
'Total request duration including retrieval',
['endpoint'],
buckets=[0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0],
registry=registry
)
# --- Throughput metrics ---
TOKENS_STREAMED_TOTAL = Counter(
'streaming_tokens_total',
'Total tokens streamed to clients',
['endpoint', 'model'],
registry=registry
)
TOKENS_PER_SECOND = Histogram(
'streaming_tokens_per_second',
'Token generation throughput',
['model'],
buckets=[5, 10, 20, 30, 50, 75, 100, 150, 200],
registry=registry
)
# --- Error metrics ---
ERRORS_TOTAL = Counter(
'streaming_errors_total',
'Total errors during streaming',
['endpoint', 'error_type'], # error_type: timeout, backpressure, llm_error, rate_limited
registry=registry
)
CIRCUIT_BREAKER_STATE = Gauge(
'streaming_circuit_breaker_state',
'Circuit breaker state (0=closed, 1=half_open, 2=open)',
['service'],
registry=registry
)
print("Prometheus metrics registered:")
for metric in [
'streaming_active_connections',
'streaming_connections_total',
'streaming_ttft_seconds',
'streaming_generation_duration_seconds',
'streaming_request_duration_seconds',
'streaming_tokens_total',
'streaming_tokens_per_second',
'streaming_errors_total',
'streaming_circuit_breaker_state'
]:
print(f" {metric}")
else:
print("prometheus-client not available. Install: pip install prometheus-client")
# Instrument a streaming handler with Prometheus metrics
@dataclass
class StreamStats:
start_time: float = field(default_factory=time.perf_counter)
first_token_time: Optional[float] = None
last_token_time: Optional[float] = None
token_count: int = 0
error: Optional[str] = None
def record_token(self):
now = time.perf_counter()
if self.first_token_time is None:
self.first_token_time = now
self.last_token_time = now
self.token_count += 1
@property
def ttft_seconds(self) -> Optional[float]:
if self.first_token_time:
return self.first_token_time - self.start_time
return None
@property
def generation_duration_seconds(self) -> Optional[float]:
if self.first_token_time and self.last_token_time:
return self.last_token_time - self.first_token_time
return None
@property
def total_duration_seconds(self) -> float:
return time.perf_counter() - self.start_time
@property
def tokens_per_second(self) -> Optional[float]:
dur = self.generation_duration_seconds
if dur and dur > 0:
return self.token_count / dur
return None
def record_stream_metrics(
stats: StreamStats,
endpoint: str = "/stream",
model: str = "gpt-4o-mini"
):
"""Flush StreamStats into Prometheus metrics."""
if not PROMETHEUS_AVAILABLE:
return
if stats.ttft_seconds is not None:
TTFT_SECONDS.labels(endpoint=endpoint, model=model).observe(stats.ttft_seconds)
if stats.generation_duration_seconds is not None:
GENERATION_DURATION.labels(endpoint=endpoint, model=model).observe(
stats.generation_duration_seconds
)
REQUEST_DURATION.labels(endpoint=endpoint).observe(stats.total_duration_seconds)
TOKENS_STREAMED_TOTAL.labels(endpoint=endpoint, model=model).inc(stats.token_count)
if stats.tokens_per_second is not None:
TOKENS_PER_SECOND.labels(model=model).observe(stats.tokens_per_second)
if stats.error:
ERRORS_TOTAL.labels(endpoint=endpoint, error_type=stats.error).inc()
# --- Simulate several streaming requests and check metrics ---
import numpy as np
async def simulate_stream_request(ttft_ms: float, tokens: int, tps: float):
"""Simulate a streaming request with given TTFT, token count, and tokens/sec."""
stats = StreamStats()
if PROMETHEUS_AVAILABLE:
ACTIVE_CONNECTIONS.labels(endpoint="/stream").inc()
try:
await asyncio.sleep(ttft_ms / 1000.0) # Simulate retrieval + TTFT
for _ in range(tokens):
stats.record_token()
await asyncio.sleep(1.0 / tps)
finally:
if PROMETHEUS_AVAILABLE:
ACTIVE_CONNECTIONS.labels(endpoint="/stream").dec()
record_stream_metrics(stats)
return stats
# Simulate 5 requests with varying characteristics
print("Simulating 5 streaming requests...")
requests = [
(150, 30, 25), # ttft_ms, tokens, tps
(300, 50, 20),
(200, 20, 30),
(800, 80, 15),
(100, 15, 35),
]
all_stats = await asyncio.gather(*[
simulate_stream_request(ttft, tokens, tps)
for ttft, tokens, tps in requests
])
print("\nPer-request stats:")
for i, s in enumerate(all_stats):
print(
f" Request {i+1}: "
f"TTFT={s.ttft_seconds*1000:.0f}ms, "
f"tokens={s.token_count}, "
f"tps={s.tokens_per_second:.1f}, "
f"total={s.total_duration_seconds*1000:.0f}ms"
)
# Summary percentiles
ttfts = [s.ttft_seconds * 1000 for s in all_stats if s.ttft_seconds]
tps_vals = [s.tokens_per_second for s in all_stats if s.tokens_per_second]
print(f"\nAggregate:")
print(f" TTFT p50={np.percentile(ttfts, 50):.0f}ms, p95={np.percentile(ttfts, 95):.0f}ms")
print(f" Tokens/sec avg={np.mean(tps_vals):.1f}")
if PROMETHEUS_AVAILABLE:
output = generate_latest(registry).decode("utf-8")
# Show just the TTFT metric
for line in output.split("\n"):
if "ttft" in line or "tokens_per" in line:
print(f" {line}")
11. Graceful Shutdown and Connection Draining¶
Zero-Downtime Deployments for Streaming¶
When deploying a new version, the server must stop accepting new connections while letting existing streams finish gracefully. This is called connection draining. The process follows three steps: (1) receive SIGTERM from the orchestrator, (2) stop the listener (no new connections), (3) wait up to a drain timeout (e.g., 60 seconds) for active streams to complete, then force-close any remaining connections.
Without graceful shutdown, a rolling deployment in Kubernetes kills pods immediately, severing active streams mid-response. Users see truncated answers with no error message. The drain timeout must be configured at every layer: Kubernetes terminationGracePeriodSeconds, the application shutdown handler, and nginx proxy_read_timeout. The application should also send a final SSE event ({"type": "server_shutdown"}) to connected clients so they can retry on a healthy instance.
class GracefulShutdownManager:
"""
Manages graceful shutdown of a streaming server.
Shutdown sequence:
1. Receive SIGTERM
2. Stop accepting new connections (is_shutting_down = True)
3. Send 'draining' SSE event to all active connections
4. Wait for active connections to finish (up to drain_timeout_s)
5. Force-close remaining connections
6. Exit
"""
def __init__(self, drain_timeout_s: float = 30.0):
self.drain_timeout_s = drain_timeout_s
self.is_shutting_down = False
self._active_streams: Set[str] = set()
self._stream_done_event: Dict[str, asyncio.Event] = {}
self._shutdown_complete = asyncio.Event()
self._lock = asyncio.Lock()
async def register_stream(self, stream_id: str):
async with self._lock:
self._active_streams.add(stream_id)
self._stream_done_event[stream_id] = asyncio.Event()
async def unregister_stream(self, stream_id: str):
async with self._lock:
self._active_streams.discard(stream_id)
if stream_id in self._stream_done_event:
self._stream_done_event[stream_id].set()
del self._stream_done_event[stream_id]
if not self._active_streams and self.is_shutting_down:
self._shutdown_complete.set()
def accept_new_connections(self) -> bool:
"""Check if new connections should be accepted."""
return not self.is_shutting_down
async def initiate_shutdown(self):
"""Begin graceful shutdown."""
logger.info("Graceful shutdown initiated")
self.is_shutting_down = True
active_count = len(self._active_streams)
if active_count == 0:
logger.info("No active streams. Shutdown complete.")
self._shutdown_complete.set()
return
logger.info(f"Draining {active_count} active stream(s) (timeout={self.drain_timeout_s}s)...")
try:
await asyncio.wait_for(
self._shutdown_complete.wait(),
timeout=self.drain_timeout_s
)
logger.info("All streams drained. Clean shutdown.")
except asyncio.TimeoutError:
remaining = len(self._active_streams)
logger.warning(f"Drain timeout. Force-closing {remaining} stream(s).")
def get_stats(self) -> Dict:
return {
"active_streams": len(self._active_streams),
"is_shutting_down": self.is_shutting_down,
}
# FastAPI lifespan integration
SHUTDOWN_MANAGER = GracefulShutdownManager(drain_timeout_s=30.0)
LIFESPAN_CODE = '''
from contextlib import asynccontextmanager
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
logger.info("Server starting up")
yield
# Shutdown (triggered by SIGTERM/SIGINT)
logger.info("Server shutting down")
await SHUTDOWN_MANAGER.initiate_shutdown()
app = FastAPI(lifespan=lifespan)
'''
# --- Demo graceful shutdown ---
async def demo_graceful_shutdown():
mgr = GracefulShutdownManager(drain_timeout_s=2.0)
# Register some active streams
for sid in ["stream-A", "stream-B", "stream-C"]:
await mgr.register_stream(sid)
print(f"Active streams: {mgr.get_stats()}")
print(f"Accepting connections: {mgr.accept_new_connections()}")
# Simulate streams finishing
async def finish_stream(sid, delay):
await asyncio.sleep(delay)
await mgr.unregister_stream(sid)
print(f" Stream {sid} finished")
# Start shutdown and simultaneously finish streams
print("\nInitiating graceful shutdown...")
print(f"Accepting new connections: {mgr.accept_new_connections()}")
shutdown_task = asyncio.create_task(mgr.initiate_shutdown())
finish_tasks = [
asyncio.create_task(finish_stream("stream-A", 0.3)),
asyncio.create_task(finish_stream("stream-B", 0.5)),
asyncio.create_task(finish_stream("stream-C", 0.8)),
]
await asyncio.gather(shutdown_task, *finish_tasks)
print(f"Shutdown complete. Active streams: {mgr.get_stats()}")
await demo_graceful_shutdown()
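The final server_shutdown event mentioned earlier can be layered onto any token stream. A minimal sketch using a stub in place of the full manager (`_StubManager`, `drain_aware_stream`, and `sse` are illustrative names, not part of the app below):

```python
import asyncio
import json

class _StubManager:
    """Minimal stand-in for GracefulShutdownManager (demo only)."""
    def __init__(self):
        self.is_shutting_down = False
        self.active = set()
    async def register_stream(self, sid):
        self.active.add(sid)
    async def unregister_stream(self, sid):
        self.active.discard(sid)

def sse(payload: dict) -> str:
    """Format one SSE data frame."""
    return f"data: {json.dumps(payload)}\n\n"

async def drain_aware_stream(mgr, stream_id, source):
    # Emit tokens normally; if shutdown begins mid-stream, send one final
    # server_shutdown event so the client can reconnect to a healthy
    # instance, then stop cleanly.
    await mgr.register_stream(stream_id)
    try:
        async for token in source:
            yield sse({"type": "token", "content": token})
            if mgr.is_shutting_down:
                yield sse({"type": "server_shutdown", "retry_ms": 1000})
                return
        yield sse({"type": "done"})
    finally:
        await mgr.unregister_stream(stream_id)

async def demo():
    mgr = _StubManager()
    async def tokens():
        for i, t in enumerate(["a", "b", "c", "d"]):
            if i == 2:
                mgr.is_shutting_down = True  # SIGTERM arrives mid-stream
            yield t
    return [e async for e in drain_aware_stream(mgr, "s1", tokens())]

events = asyncio.run(demo())
print(events[-1])  # the server_shutdown frame
```

The client sees the partial answer plus an explicit shutdown signal, rather than a silently severed socket.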
12. Complete Production FastAPI App¶
All Components Integrated¶
This section assembles every production component (connection management, rate limiting, backpressure handling, circuit breaker, retry logic, Prometheus metrics, and graceful shutdown) into a single FastAPI application. The app demonstrates how these components compose: each incoming request passes through rate limiting, then connection limit checks, then the streaming handler with backpressure and timeout management, with metrics recorded at every stage.
Key integration patterns: middleware handles cross-cutting concerns (rate limiting, metrics), the connection manager is a FastAPI dependency, and the circuit breaker wraps the LLM client. The /metrics endpoint exposes Prometheus-format data, /health returns detailed system status including circuit breaker state and connection pool utilization, and the streaming endpoint (/stream) ties everything together.
from contextlib import asynccontextmanager
# --- Singletons ---
conn_manager = ConnectionManager(max_global=500, max_per_ip=10)
rate_limiter = RateLimiter(requests_per_minute=20, burst_size=5)
circuit_breaker = CircuitBreaker(name="llm_api", failure_threshold=5, reset_timeout_s=30)
shutdown_manager = GracefulShutdownManager(drain_timeout_s=30)
timeout_config = TimeoutConfig(
first_token_s=30.0,
inter_token_s=15.0,
total_request_s=120.0
)
retry_config = RetryConfig(max_retries=2, base_delay_s=1.0, max_delay_s=10.0)
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Production streaming server starting")
yield
logger.info("Initiating graceful shutdown")
await shutdown_manager.initiate_shutdown()
prod_app = FastAPI(title="Production Streaming API", version="1.0.0", lifespan=lifespan)
class StreamRequest(BaseModel):
prompt: str
model: str = "gpt-4o-mini"
max_tokens: int = 512
async def mock_llm_stream(prompt: str, model: str) -> AsyncGenerator[str, None]:
"""Mock LLM -- replace with real OpenAI/Anthropic call."""
words = f"This is a streaming response for: {prompt[:40]}".split()
for word in words:
yield word + " "
await asyncio.sleep(0.05)
async def production_stream_generator(
request: StreamRequest,
conn_info: ConnectionInfo
) -> AsyncGenerator[str, None]:
"""Full production streaming pipeline with all safety mechanisms."""
stats = StreamStats()
stream_id = conn_info.id
if PROMETHEUS_AVAILABLE:
ACTIVE_CONNECTIONS.labels(endpoint="/stream").inc()
await shutdown_manager.register_stream(stream_id)
try:
# Check circuit breaker
if circuit_breaker.state == CircuitState.OPEN:
yield f"data: {json.dumps({'type': 'error', 'code': 'circuit_open', 'message': 'LLM API unavailable, retry later'})}\n\n"
return
yield f"data: {json.dumps({'type': 'status', 'message': 'Generating...'})}\n\n"
# Timeout-wrapped source
timeout_mgr = TimeoutManager(timeout_config)
async def call_llm():
return mock_llm_stream(request.prompt, request.model)
source = await circuit_breaker.call(call_llm)
# Backpressure wrapper
bp_stream = BackpressureStream(source, queue_maxsize=100, slow_client_timeout_s=5.0)
async for token in bp_stream.stream():
stats.record_token()
conn_info.tokens_sent += 1
token_bytes = token.encode()
conn_info.bytes_sent += len(token_bytes)
yield f"data: {json.dumps({'type': 'token', 'content': token})}\n\n"
yield f"data: {json.dumps({'type': 'done', 'tokens': stats.token_count})}\n\n"
except BackpressureError:
if PROMETHEUS_AVAILABLE:
ERRORS_TOTAL.labels(endpoint="/stream", error_type="backpressure").inc()
yield f"data: {json.dumps({'type': 'error', 'code': 'backpressure'})}\n\n"
except CircuitBreakerOpen as e:
if PROMETHEUS_AVAILABLE:
ERRORS_TOTAL.labels(endpoint="/stream", error_type="circuit_open").inc()
yield f"data: {json.dumps({'type': 'error', 'code': 'circuit_open', 'message': str(e)})}\n\n"
except asyncio.CancelledError:
logger.info(f"Stream {stream_id} cancelled (client disconnected)")
except Exception as e:
logger.error(f"Stream {stream_id} error: {e}")
if PROMETHEUS_AVAILABLE:
ERRORS_TOTAL.labels(endpoint="/stream", error_type="unhandled").inc()
yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
finally:
await conn_manager.release(conn_info.id)
await shutdown_manager.unregister_stream(stream_id)
record_stream_metrics(stats, endpoint="/stream", model=request.model)
if PROMETHEUS_AVAILABLE:
ACTIVE_CONNECTIONS.labels(endpoint="/stream").dec()
@prod_app.post("/stream")
async def stream_endpoint(req: StreamRequest, http_req: Request):
client_ip = http_req.client.host if http_req.client else "unknown"
# Reject if shutting down
if not shutdown_manager.accept_new_connections():
raise HTTPException(status_code=503, detail="Server is shutting down")
# Rate limit check
rate_result = await rate_limiter.check(client_ip)
if not rate_result["allowed"]:
raise HTTPException(
status_code=429,
detail=f"Rate limit exceeded. Retry after {rate_result['retry_after']}s",
headers={"Retry-After": str(rate_result["retry_after"])}
)
# Connection limit check
conn_info = await conn_manager.acquire(client_ip, req.prompt[:100])
if conn_info is None:
raise HTTPException(status_code=503, detail="Too many connections")
if PROMETHEUS_AVAILABLE:
CONNECTIONS_TOTAL.labels(endpoint="/stream", status="success").inc()
return StreamingResponse(
production_stream_generator(req, conn_info),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
}
)
@prod_app.get("/metrics")
async def metrics_endpoint():
if not PROMETHEUS_AVAILABLE:
return {"error": "prometheus-client not installed"}
return Response(
content=generate_latest(registry),
media_type=CONTENT_TYPE_LATEST
)
@prod_app.get("/health")
async def health_check():
return {
"status": "degraded" if shutdown_manager.is_shutting_down else "healthy",
"active_connections": conn_manager.active_count,
"circuit_breaker": circuit_breaker.state.value,
"shutting_down": shutdown_manager.is_shutting_down
}
print("Production FastAPI app configured.")
print("\nRoutes:")
for route in prod_app.routes:
if hasattr(route, 'methods'):
print(f" {list(route.methods)} {route.path}")
print("\nTo run:")
print(" uvicorn notebook_prod_app:prod_app --workers 4 --port 8001")
13. Load Testing with Locust¶
Validating Production Readiness¶
Locust is a Python-based load testing framework that can simulate thousands of concurrent users. For SSE streaming endpoints, standard HTTP benchmarks (wrk, ab) do not work because they expect a complete response; they cannot measure TTFT or handle the event stream protocol. Locust lets you write custom Python code that opens an SSE connection, reads events one by one, and measures TTFT and total generation time.
The load test should validate three SLOs: (1) TTFT p95 under 1 second at target concurrency, (2) error rate below 0.1% under sustained load, (3) no memory leaks over extended runs (monitor RSS via /metrics). Start with 10 users and ramp to your target (e.g., 500) over 60 seconds. Watch for the inflection point where TTFT degrades sharply: this is your practical concurrency limit per instance.
# Save to locustfile.py and run with: locust -f locustfile.py --host http://localhost:8001
LOCUSTFILE = '''
import time
import json
from locust import HttpUser, task, between, events
from locust.exception import RescheduleTask
class StreamingUser(HttpUser):
"""
Simulates a user making streaming LLM requests.
Measures TTFT separately from total response time.
"""
wait_time = between(1, 3) # Wait 1-3s between requests
@task(3)
def stream_short_query(self):
"""Short query -- should have low TTFT."""
self._do_stream_request(
prompt="What is RAG?",
name="/stream [short]"
)
@task(1)
def stream_long_query(self):
"""Long query -- higher TTFT due to retrieval."""
self._do_stream_request(
prompt="Explain in detail the differences between RAG and fine-tuning for LLMs, "
"including use cases, costs, and tradeoffs.",
name="/stream [long]"
)
def _do_stream_request(self, prompt: str, name: str):
start = time.perf_counter()
first_token_time = None
token_count = 0
total_bytes = 0
with self.client.post(
"/stream",
json={"prompt": prompt},
stream=True,
name=name,
catch_response=True
) as response:
if response.status_code == 429:
response.failure(f"Rate limited: {response.text}")
return
if response.status_code == 503:
response.failure(f"Service unavailable: {response.text}")
raise RescheduleTask()
if response.status_code != 200:
response.failure(f"HTTP {response.status_code}")
return
for line in response.iter_lines():
if not line:
continue
if isinstance(line, bytes):
line = line.decode("utf-8")
if not line.startswith("data: "):
continue
total_bytes += len(line)
try:
data = json.loads(line[6:])
except json.JSONDecodeError:
continue
if data.get("type") == "token":
if first_token_time is None:
first_token_time = time.perf_counter()
ttft_ms = (first_token_time - start) * 1000
# Report TTFT as a custom metric
events.request.fire(
request_type="TTFT",
name=name,
response_time=ttft_ms,
response_length=0,
exception=None,
context={}
)
token_count += 1
elif data.get("type") == "error":
response.failure(f"Stream error: {data}")
return
total_ms = (time.perf_counter() - start) * 1000
if token_count == 0:
response.failure("No tokens received")
else:
response.success()
# Run with:
# locust -f locustfile.py --host http://localhost:8001 --users 50 --spawn-rate 5
# Or headless:
# locust -f locustfile.py --host http://localhost:8001 \\
# --users 50 --spawn-rate 5 --run-time 60s --headless
'''
# Save the locustfile
with open("/tmp/locustfile_streaming.py", "w") as f:
f.write(LOCUSTFILE)
print("Locust load test file saved to /tmp/locustfile_streaming.py")
print()
print("Run load test with:")
print(" locust -f /tmp/locustfile_streaming.py \\")
print(" --host http://localhost:8001 \\")
print(" --users 50 --spawn-rate 5 --run-time 60s --headless")
print()
print("Key metrics to watch:")
metrics_to_watch = [
("TTFT p50", "< 300ms target"),
("TTFT p95", "< 1000ms target"),
("Request failures", "< 0.1% target"),
("Active users", "Ramp to 50 over 10s"),
("RPS", "Track peak sustainable RPS"),
]
for metric, target in metrics_to_watch:
print(f" {metric:<25} {target}")
14. Docker Compose Production Setup¶
Infrastructure as Code for Streaming Services¶
A production streaming deployment requires multiple coordinated services: the FastAPI application (the streaming server itself), Nginx (reverse proxy with SSE/WebSocket configuration), Prometheus (metrics collection and alerting), and Grafana (dashboards and visualization). Docker Compose orchestrates these services with proper networking, volume mounts for configuration files, and health checks.
Critical configuration details: the Nginx container must mount the custom SSE configuration (proxy_buffering off, extended timeouts). Prometheus scrapes the /metrics endpoint at 15-second intervals. Grafana provisions pre-built dashboards for streaming metrics (active connections, TTFT heatmap, error rate). The application container should set resource limits (memory, CPU) to prevent a single instance from consuming all host resources during traffic spikes.
DOCKER_COMPOSE = '''
# docker-compose.yml -- Production streaming setup
version: "3.9"
services:
# ---- Application (3 replicas) ----
app:
build:
context: .
dockerfile: Dockerfile
environment:
- OPENAI_API_KEY=${OPENAI_API_KEY}
- MAX_CONNECTIONS=200
- MAX_CONNECTIONS_PER_IP=5
- RATE_LIMIT_RPM=20
deploy:
replicas: 3
resources:
limits:
memory: 1G
cpus: "1.0"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 10s
timeout: 5s
retries: 3
start_period: 15s
restart: unless-stopped
networks:
- streaming_net
# ---- Nginx reverse proxy ----
nginx:
image: nginx:1.25-alpine
ports:
- "80:80"
- "443:443"
volumes:
- ./nginx.conf:/etc/nginx/nginx.conf:ro
- ./certs:/etc/nginx/certs:ro
depends_on:
- app
restart: unless-stopped
networks:
- streaming_net
# ---- Prometheus ----
prometheus:
image: prom/prometheus:v2.48.0
ports:
- "9090:9090"
volumes:
- ./prometheus.yml:/etc/prometheus/prometheus.yml:ro
- prometheus_data:/prometheus
command:
- "--config.file=/etc/prometheus/prometheus.yml"
- "--storage.tsdb.retention.time=7d"
restart: unless-stopped
networks:
- streaming_net
# ---- Grafana ----
grafana:
image: grafana/grafana:10.2.0
ports:
- "3000:3000"
environment:
- GF_SECURITY_ADMIN_PASSWORD=admin
- GF_USERS_ALLOW_SIGN_UP=false
volumes:
- grafana_data:/var/lib/grafana
- ./grafana/dashboards:/var/lib/grafana/dashboards:ro
depends_on:
- prometheus
restart: unless-stopped
networks:
- streaming_net
networks:
streaming_net:
driver: bridge
volumes:
prometheus_data:
grafana_data:
'''
DOCKERFILE = '''
# Dockerfile
FROM python:3.11-slim
WORKDIR /app
# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application
COPY app/ ./app/
# Non-root user for security
RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app
USER appuser
# Expose port
EXPOSE 8000
# Graceful shutdown: uvicorn handles SIGTERM by stopping new connections
# and waiting for in-flight requests before exiting
CMD ["uvicorn", "app.main:app", \
     "--host", "0.0.0.0", \
     "--port", "8000", \
     "--workers", "1", \
     "--timeout-graceful-shutdown", "30", \
     "--access-log", \
     "--log-level", "info"]
'''
PROMETHEUS_CONFIG = '''
# prometheus.yml
global:
scrape_interval: 15s
evaluation_interval: 15s
scrape_configs:
- job_name: streaming_api
static_configs:
- targets:
- app:8000 # Scraped via Docker service DNS
metrics_path: /metrics
scrape_interval: 10s
alerting:
alertmanagers:
- static_configs:
- targets: []
rule_files:
- "alerts.yml"
'''
ALERTS_CONFIG = '''
# alerts.yml -- Prometheus alerting rules
groups:
- name: streaming_alerts
rules:
- alert: HighTTFT
expr: histogram_quantile(0.95, rate(streaming_ttft_seconds_bucket[5m])) > 2.0
for: 2m
labels:
severity: warning
annotations:
summary: "P95 TTFT > 2s"
- alert: TooManyConnections
expr: streaming_active_connections > 400
for: 1m
labels:
severity: critical
annotations:
summary: "Active connections > 400 (limit 500)"
- alert: HighErrorRate
        expr: sum(rate(streaming_errors_total[5m])) / sum(rate(streaming_connections_total[5m])) > 0.01
for: 2m
labels:
severity: warning
annotations:
summary: "Error rate > 1%"
- alert: CircuitBreakerOpen
expr: streaming_circuit_breaker_state == 2
for: 0m
labels:
severity: critical
annotations:
summary: "Circuit breaker OPEN -- LLM API down"
'''
print("=== docker-compose.yml ===")
print(DOCKER_COMPOSE)
print("=== Dockerfile ===")
print(DOCKERFILE)
print("=== prometheus.yml ===")
print(PROMETHEUS_CONFIG)
print("=== alerts.yml ===")
print(ALERTS_CONFIG)
# Useful Grafana dashboard queries (PromQL)
GRAFANA_QUERIES = {
"Active Connections": (
"streaming_active_connections",
"Current active SSE connections"
),
"TTFT p50 (ms)": (
"histogram_quantile(0.5, rate(streaming_ttft_seconds_bucket[5m])) * 1000",
"Median time to first token in ms"
),
"TTFT p95 (ms)": (
"histogram_quantile(0.95, rate(streaming_ttft_seconds_bucket[5m])) * 1000",
"95th percentile TTFT in ms"
),
"TTFT p99 (ms)": (
"histogram_quantile(0.99, rate(streaming_ttft_seconds_bucket[5m])) * 1000",
"99th percentile TTFT in ms"
),
    "Tokens per second (p50)": (
        "histogram_quantile(0.5, rate(streaming_tokens_per_second_bucket[5m]))",
        "Median token throughput"
),
    "Error rate (%)": (
        "100 * sum(rate(streaming_errors_total[5m])) / sum(rate(streaming_connections_total[5m]))",
        "Percentage of connections with errors"
),
"Connection rate (per min)": (
"rate(streaming_connections_total[1m]) * 60",
"New connections per minute"
),
"Circuit breaker state": (
"streaming_circuit_breaker_state",
"0=closed, 1=half_open, 2=open"
),
}
print("Grafana dashboard PromQL queries:")
print("=" * 70)
for panel_name, (query, description) in GRAFANA_QUERIES.items():
print(f"\nPanel: {panel_name}")
print(f" Query: {query}")
print(f" Note: {description}")
Summary¶
Production Streaming Checklist¶
Infrastructure¶
Nginx configured with proxy_buffering off and long proxy_read_timeout
Sticky sessions (ip_hash) for SSE
Docker health checks on /health endpoint
Graceful shutdown with 30s drain timeout
Reliability¶
Connection limits: global (500) and per-IP (10)
Rate limiting: token bucket, 20 req/min with burst of 5
Backpressure: bounded queue, disconnect slow clients after 5s
Timeouts: first-token (30s), inter-token (15s), total (120s)
Circuit breaker: open after 5 failures in 60s, reset after 30s
Retry with exponential backoff + full jitter
Observability¶
Prometheus metrics: active connections, TTFT histogram, tokens/sec, error rates
Grafana dashboards for all key metrics
Alerts: TTFT p95 > 2s, connections > 400, error rate > 1%, circuit open
Structured logging with connection IDs
Load Testing¶
Locust test measuring TTFT as a custom metric
Soak test: 50 users x 60 minutes
Spike test: ramp to 200 users in 30s
Performance Targets¶
| Metric | Target |
|---|---|
| TTFT p50 | < 300ms |
| TTFT p95 | < 1000ms |
| TTFT p99 | < 2000ms |
| Error rate | < 0.1% |
| Max connections per instance | 500 |
| Memory per connection | < 2MB |
| Graceful shutdown | < 30s drain |
Key Takeaways¶
Fail fast: circuit breakers and rate limits protect the LLM API from overload.
Bound everything: queues, connections, timeouts. Unbounded resources cause outages.
Measure TTFT: it is the most important user-facing metric for streaming.
Drain before shutdown: never kill active streams; let them complete or time out gracefully.
Test under load: Locust reveals backpressure and memory issues that unit tests miss.