Production-Grade Streaming Systems

Overview

Getting a streaming endpoint working in a notebook is step one. Getting it to handle 10,000 concurrent users reliably, with observability and graceful failure, is an entirely different challenge.

This notebook covers the engineering required to take streaming LLM endpoints to production.

Learning Objectives

  • Understand production requirements for streaming systems

  • Implement connection limits and resource management

  • Build rate limiting per user/IP with token bucket algorithm

  • Handle backpressure from slow clients

  • Add timeout management at every layer

  • Implement circuit breakers for LLM API failures

  • Add retry logic with exponential backoff

  • Instrument Prometheus metrics (TTFT, tokens/sec, active connections)

  • Implement graceful shutdown and connection draining

  • Load test with Locust

  • Deploy with Docker Compose

Prerequisites

pip install fastapi uvicorn prometheus-client locust httpx python-dotenv

1. Production Requirements

What Makes Streaming Hard in Production

Challenge                 Description                              Solution
------------------------  ---------------------------------------  -----------------------------
Long-lived connections    SSE/WS hold sockets open for 10-60s      Connection limits + timeouts
Slow clients              Client reads slower than LLM writes      Backpressure / bounded queues
LLM API flakiness         OpenAI/Anthropic have ~99.9% uptime      Circuit breakers + retries
Memory leaks              Each connection allocates buffers        Explicit cleanup + weak refs
Load balancer issues      Default LBs close idle connections       Nginx SSE config + keepalive
No observability          Can't see TTFT or error rates            Prometheus metrics
Thundering herd           LLM returns -> all clients wake          Staggered retries

Target SLOs

TTFT (Time to First Token):  p50 < 300ms,  p95 < 1000ms,  p99 < 2000ms
Total generation time:        p50 < 5s,     p95 < 15s
Error rate:                   < 0.1%
Active connections:           < 500 per instance
Memory per connection:        < 2MB
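The SLO targets above can be checked against recorded latency samples as a quick smoke test. A minimal sketch using nearest-rank percentiles (the helper names are illustrative, not part of the service code):

```python
def percentile(samples: list[float], p: float) -> float:
    """Nearest-rank percentile; good enough for a smoke test."""
    s = sorted(samples)
    idx = min(len(s) - 1, max(0, round(p / 100 * len(s)) - 1))
    return s[idx]

def ttft_slo_ok(samples: list[float]) -> dict:
    """Compare TTFT samples (seconds) against the targets above."""
    return {
        "p50<0.3": percentile(samples, 50) < 0.3,
        "p95<1.0": percentile(samples, 95) < 1.0,
        "p99<2.0": percentile(samples, 99) < 2.0,
    }

print(ttft_slo_ok([0.12, 0.2, 0.25, 0.4, 0.9]))  # all True for this sample
```

In production you would compute these from Prometheus histograms rather than in-process lists; this is only the arithmetic.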
# Install dependencies (uncomment if needed)
# !pip install fastapi uvicorn prometheus-client locust httpx python-dotenv

import asyncio
import json
import os
import time
import signal
import weakref
import logging
import traceback
from collections import defaultdict, deque
from dataclasses import dataclass, field
from enum import Enum
from typing import AsyncGenerator, Callable, Dict, List, Optional, Set

try:
    from fastapi import FastAPI, HTTPException, Request, Response
    from fastapi.responses import StreamingResponse
    from pydantic import BaseModel
    import uvicorn
    FASTAPI_AVAILABLE = True
except ImportError:
    FASTAPI_AVAILABLE = False
    print("FastAPI not installed: pip install fastapi uvicorn")

try:
    from prometheus_client import (
        Counter, Histogram, Gauge, Summary,
        generate_latest, CONTENT_TYPE_LATEST,
        CollectorRegistry
    )
    PROMETHEUS_AVAILABLE = True
except ImportError:
    PROMETHEUS_AVAILABLE = False
    print("prometheus-client not installed: pip install prometheus-client")

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
logger = logging.getLogger("prod_streaming")

print("Imports complete.")
print(f"FastAPI:    {'OK' if FASTAPI_AVAILABLE else 'MISSING'}")
print(f"Prometheus: {'OK' if PROMETHEUS_AVAILABLE else 'MISSING'}")

2. Load Balancing for SSE and WebSockets

Why Standard Load Balancers Fail for Streaming

Standard HTTP load balancers (ALB, HAProxy defaults) are optimized for short-lived request-response cycles: they distribute requests round-robin, enforce short idle timeouts (commonly 30-60 seconds), and buffer responses for compression. SSE and WebSocket connections violate all of these assumptions – they hold sockets open for minutes, produce data incrementally, and must not be buffered. Nginx requires explicit configuration: proxy_buffering off prevents nginx from accumulating tokens before forwarding, ip_hash ensures sticky sessions so reconnecting clients hit the same backend, and proxy_read_timeout must be extended to 300+ seconds.

Why sticky sessions matter: if a client disconnects and reconnects (common on mobile networks), a non-sticky load balancer may route them to a different backend that has no context about their previous conversation. With ip_hash or cookie-based affinity, the client returns to the same server. For WebSocket upgrades, the Connection: upgrade and Upgrade headers must be explicitly proxied – nginx will not forward them by default.

NGINX_SSE_CONFIG = """
# nginx.conf -- Production SSE/WebSocket streaming config

upstream streaming_backend {
    # Sticky sessions: same client -> same backend (important for SSE)
    # Use ip_hash OR cookie-based sticky sessions
    ip_hash;

    server app1:8000 max_fails=3 fail_timeout=30s;
    server app2:8000 max_fails=3 fail_timeout=30s;
    server app3:8000 max_fails=3 fail_timeout=30s;

    # Keep connections to upstream alive
    keepalive 32;
}

server {
    listen 80;
    server_name api.example.com;

    # SSE endpoint
    location /stream {
        proxy_pass http://streaming_backend;

        # Critical for SSE: disable buffering so tokens reach client immediately
        proxy_buffering off;
        proxy_cache off;

        # Long timeouts for streaming connections
        proxy_read_timeout 300s;   # 5 minutes max stream
        proxy_send_timeout 300s;
        proxy_connect_timeout 10s;

        # SSE headers
        proxy_set_header Connection '';
        proxy_http_version 1.1;
        chunked_transfer_encoding on;

        # Pass real client IP
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;

        # Prevent nginx from closing idle SSE connections
        keepalive_timeout 300s;
    }

    # WebSocket endpoint
    location /ws {
        proxy_pass http://streaming_backend;
        proxy_http_version 1.1;

        # Required for WebSocket upgrade
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";

        proxy_read_timeout 3600s;  # 1 hour for WS
        proxy_buffering off;
    }

    # Regular API endpoints (short timeout)
    location / {
        proxy_pass http://streaming_backend;
        proxy_read_timeout 30s;
        proxy_buffering on;
    }
}
"""

print("Nginx configuration for SSE/WebSocket:")
print(NGINX_SSE_CONFIG)

print("Key settings explained:")
settings = [
    ("proxy_buffering off",     "Tokens reach client immediately, not buffered"),
    ("ip_hash",                 "Same client always goes to same backend"),
    ("proxy_read_timeout 300s", "Allow streams up to 5 minutes"),
    ("keepalive 32",            "Reuse connections to upstream (saves TCP overhead)"),
    ("Upgrade + Connection",    "Required headers for WebSocket protocol upgrade"),
]
for setting, explanation in settings:
    print(f"  {setting:<30} -> {explanation}")

3. Connection Limits and Resource Management

Preventing Resource Exhaustion

Each streaming connection consumes a file descriptor (socket), a coroutine (or thread), and potentially megabytes of buffer memory. On a typical Linux server, the default ulimit is 1024 file descriptors; a single traffic spike can exhaust this and cause the server to reject all new connections – including health checks, which triggers cascading failures in container orchestrators.
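The descriptor budget is worth verifying at process startup so a traffic spike fails loudly rather than exhausting sockets silently. A hedged sketch using the POSIX resource module (Linux/macOS only; the function name and overhead margin are illustrative assumptions):

```python
import resource  # POSIX-only: not available on Windows

def check_fd_headroom(max_connections: int, overhead: int = 100) -> bool:
    """Return True if the process may hold max_connections sockets plus
    some overhead (log files, library descriptors)."""
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    needed = max_connections + overhead
    if soft < needed:
        # Try to raise the soft limit; it can never exceed the hard limit
        resource.setrlimit(resource.RLIMIT_NOFILE, (min(needed, hard), hard))
        soft, _ = resource.getrlimit(resource.RLIMIT_NOFILE)
    return soft >= needed

print(check_fd_headroom(500))  # False means: raise ulimit -n before serving
```

Calling this before binding the server port turns a mid-traffic "Too many open files" crash into an explicit startup check.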

The ConnectionManager class implements two-tier limits: a global maximum (e.g., 500 connections) and a per-IP maximum (e.g., 10 connections) to prevent a single client from monopolizing resources. It also enforces a maximum connection age to evict zombie connections that stopped consuming data but never sent a TCP FIN. Serializing limit checks and connection registration under a single asyncio.Lock makes them atomic, preventing race conditions under concurrent load.

import asyncio
import time
import uuid
from dataclasses import dataclass, field
from typing import Dict, Optional, Set


@dataclass
class ConnectionInfo:
    id: str
    client_ip: str
    query: str
    started_at: float = field(default_factory=time.time)
    tokens_sent: int = 0
    bytes_sent: int = 0

    def age_seconds(self) -> float:
        return time.time() - self.started_at


class ConnectionManager:
    """
    Manages the lifecycle of streaming connections.
    Enforces global and per-IP connection limits.
    """

    def __init__(
        self,
        max_global: int = 500,
        max_per_ip: int = 10,
        max_connection_age_s: float = 300.0
    ):
        self.max_global = max_global
        self.max_per_ip = max_per_ip
        self.max_age = max_connection_age_s

        self._connections: Dict[str, ConnectionInfo] = {}
        self._per_ip: Dict[str, Set[str]] = defaultdict(set)
        self._lock = asyncio.Lock()

    @property
    def active_count(self) -> int:
        return len(self._connections)

    def connections_for_ip(self, ip: str) -> int:
        return len(self._per_ip.get(ip, set()))

    async def acquire(self, client_ip: str, query: str) -> Optional[ConnectionInfo]:
        """
        Try to register a new connection.
        Returns ConnectionInfo on success, None if limits are exceeded.
        """
        async with self._lock:
            # Check per-IP limit
            if self.connections_for_ip(client_ip) >= self.max_per_ip:
                logger.warning(f"Per-IP limit reached for {client_ip}")
                return None

            # Check global limit (non-blocking)
            if self.active_count >= self.max_global:
                logger.warning(f"Global connection limit reached ({self.max_global})")
                return None

            conn_id = str(uuid.uuid4())[:8]
            info = ConnectionInfo(
                id=conn_id,
                client_ip=client_ip,
                query=query
            )
            self._connections[conn_id] = info
            self._per_ip[client_ip].add(conn_id)
            logger.info(f"Connection {conn_id} opened from {client_ip} (total: {self.active_count})")
            return info

    async def release(self, conn_id: str):
        """Release a connection and free its resources."""
        async with self._lock:
            info = self._connections.pop(conn_id, None)
            if info:
                self._per_ip[info.client_ip].discard(conn_id)
                if not self._per_ip[info.client_ip]:
                    del self._per_ip[info.client_ip]
                logger.info(
                    f"Connection {conn_id} closed. "
                    f"Age={info.age_seconds():.1f}s, tokens={info.tokens_sent} "
                    f"(remaining: {self.active_count})"
                )

    async def evict_stale(self) -> int:
        """Evict connections that have exceeded max age."""
        async with self._lock:
            stale = [
                conn_id for conn_id, info in self._connections.items()
                if info.age_seconds() > self.max_age
            ]
        for conn_id in stale:
            await self.release(conn_id)
        return len(stale)

    def get_stats(self) -> Dict:
        return {
            "active_connections": self.active_count,
            "max_global": self.max_global,
            "utilization_pct": round(self.active_count / self.max_global * 100, 1),
            "unique_ips": len(self._per_ip),
            "connections_by_ip": {ip: len(ids) for ip, ids in self._per_ip.items()}
        }


# --- Demo ---
async def demo_connection_manager():
    mgr = ConnectionManager(max_global=5, max_per_ip=2)

    # Acquire some connections
    conns = []
    for i in range(4):
        ip = f"10.0.0.{i // 2 + 1}"  # 2 connections per IP
        conn = await mgr.acquire(ip, f"query {i}")
        if conn:
            conns.append(conn)
            print(f"  Opened {conn.id} from {conn.client_ip}")
        else:
            print(f"  Rejected connection from {ip} (limit reached)")

    # Try to exceed per-IP limit
    print("\nTrying to exceed per-IP limit for 10.0.0.1:")
    result = await mgr.acquire("10.0.0.1", "extra query")
    print(f"  Result: {'rejected' if result is None else 'accepted (unexpected)'}")

    print(f"\nStats: {json.dumps(mgr.get_stats(), indent=2)}")

    # Release one and try again
    if conns:
        await mgr.release(conns[0].id)
        print(f"\nAfter releasing {conns[0].id}:")
        print(f"  Active: {mgr.active_count}")

await demo_connection_manager()

4. Rate Limiting per User/IP

Token Bucket Algorithm for Streaming Endpoints

The token bucket algorithm is the standard approach for rate limiting API endpoints. Each client IP gets a bucket with a fixed capacity (burst size) that refills at a constant rate (tokens per second). Each request consumes one or more tokens; if the bucket is empty, the request is rejected with HTTP 429. The bucket allows short bursts up to its capacity while enforcing a long-term average rate.

Formally, at time \(t\) the bucket holds \(\min(C, B(t_0) + r \cdot (t - t_0))\) tokens, where \(C\) is capacity, \(r\) is the refill rate, and \(B(t_0)\) was the count at the last check. For streaming endpoints, rate limiting is especially important because each connection is expensive (long-lived, memory-holding), so even a moderate burst of abusive connections can degrade service for legitimate users. The retry_after_seconds method tells clients exactly how long to wait, enabling well-behaved clients to back off gracefully.

class TokenBucket:
    """
    Thread-safe token bucket for rate limiting.
    capacity: max tokens (burst size)
    refill_rate: tokens added per second
    """

    def __init__(self, capacity: float, refill_rate: float):
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.tokens = capacity
        self.last_refill = time.monotonic()

    def _refill(self):
        now = time.monotonic()
        elapsed = now - self.last_refill
        added = elapsed * self.refill_rate
        self.tokens = min(self.capacity, self.tokens + added)
        self.last_refill = now

    def consume(self, tokens: float = 1.0) -> bool:
        """Try to consume tokens. Returns True if allowed, False if rate limited."""
        self._refill()
        if self.tokens >= tokens:
            self.tokens -= tokens
            return True
        return False

    def tokens_remaining(self) -> float:
        self._refill()
        return round(self.tokens, 2)

    def retry_after_seconds(self, tokens_needed: float = 1.0) -> float:
        """How many seconds until enough tokens are available."""
        self._refill()
        deficit = tokens_needed - self.tokens
        if deficit <= 0:
            return 0.0
        return round(deficit / self.refill_rate, 2)


class RateLimiter:
    """
    Per-IP rate limiter using token buckets.
    Automatically cleans up buckets for inactive IPs.
    """

    def __init__(
        self,
        requests_per_minute: float = 20,
        burst_size: float = 5,
        cleanup_interval_s: float = 300
    ):
        self.refill_rate = requests_per_minute / 60.0  # tokens/sec
        self.burst_size = burst_size
        self.cleanup_interval = cleanup_interval_s

        self._buckets: Dict[str, TokenBucket] = {}
        self._last_seen: Dict[str, float] = {}
        self._lock = asyncio.Lock()

    def _get_or_create_bucket(self, key: str) -> TokenBucket:
        if key not in self._buckets:
            self._buckets[key] = TokenBucket(
                capacity=self.burst_size,
                refill_rate=self.refill_rate
            )
        self._last_seen[key] = time.monotonic()
        return self._buckets[key]

    async def check(self, client_ip: str, cost: float = 1.0) -> Dict:
        """
        Check rate limit for a client.
        Returns dict with allowed, remaining_tokens, retry_after.
        """
        async with self._lock:
            bucket = self._get_or_create_bucket(client_ip)
            allowed = bucket.consume(cost)
            return {
                "allowed": allowed,
                "remaining": bucket.tokens_remaining(),
                "limit": self.burst_size,
                "retry_after": 0.0 if allowed else bucket.retry_after_seconds(cost)
            }

    async def cleanup_stale(self) -> int:
        """Remove buckets for IPs not seen recently."""
        now = time.monotonic()
        async with self._lock:
            stale_keys = [
                k for k, last in self._last_seen.items()
                if now - last > self.cleanup_interval
            ]
            for k in stale_keys:
                del self._buckets[k]
                del self._last_seen[k]
        return len(stale_keys)


# --- Demo ---
async def demo_rate_limiter():
    # 10 req/min, burst of 3
    limiter = RateLimiter(requests_per_minute=10, burst_size=3)

    ip = "192.168.1.100"
    print(f"Rate limit: 10 req/min, burst=3")
    print(f"Sending 6 rapid requests from {ip}:")
    print("-" * 50)

    for i in range(6):
        result = await limiter.check(ip)
        status = "ALLOWED" if result["allowed"] else f"BLOCKED (retry in {result['retry_after']}s)"
        print(f"  Request {i+1}: {status} | tokens_remaining={result['remaining']}")

    # Simulate time passing -- tokens refill
    print("\nWaiting 6 seconds for tokens to refill...")
    await asyncio.sleep(6)

    result = await limiter.check(ip)
    print(f"After 6s: {result}")

await demo_rate_limiter()
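On the HTTP side, the check() result maps directly onto a 429 response. A hedged sketch of that mapping (the function is illustrative, not part of RateLimiter; the X-RateLimit-* header names follow common convention, and Retry-After is standardized as a whole number of seconds):

```python
import math

def rate_limit_response(check_result: dict) -> tuple[int, dict]:
    """Map a RateLimiter.check()-style decision to (status_code, headers)."""
    headers = {
        "X-RateLimit-Limit": str(check_result["limit"]),
        "X-RateLimit-Remaining": str(check_result["remaining"]),
    }
    if check_result["allowed"]:
        return 200, headers
    # Retry-After is delay-seconds per RFC 9110, so round fractions up
    headers["Retry-After"] = str(math.ceil(check_result["retry_after"]))
    return 429, headers

status, headers = rate_limit_response(
    {"allowed": False, "remaining": 0.0, "limit": 5, "retry_after": 3.2}
)
print(status, headers["Retry-After"])  # 429 "4"
```

In FastAPI this would live in a dependency or middleware that calls `await limiter.check(request.client.host)` before opening the stream.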

5. Backpressure Handling (Slow Clients)

Managing Producer-Consumer Speed Mismatch

When an LLM produces tokens at 50-100 tokens/second but a client on a slow mobile connection can only consume at 10 tokens/second, the server must handle the resulting backpressure. There are three strategies: (1) unbounded buffering (accumulate all tokens in memory – dangerous, causes OOM), (2) dropping tokens (fast but corrupts the response), or (3) bounded queue with disconnect (safest for production).

The bounded asyncio.Queue pattern sets a maxsize (e.g., 100 tokens). The producer puts tokens with a timeout; if the queue is full for longer than the threshold, the server assumes the client is too slow and closes the connection cleanly. This prevents memory exhaustion while giving slow clients a fair chance to catch up. The queue depth also serves as a real-time health metric: a consistently full queue indicates the client is at capacity.

class BackpressureError(Exception):
    """Raised when client is too slow to consume the stream."""


class BackpressureStream:
    """
    Wraps a source async generator with backpressure detection.

    If the internal buffer fills up (client not reading fast enough),
    raises BackpressureError to terminate the connection cleanly.
    """

    def __init__(
        self,
        source: AsyncGenerator,
        queue_maxsize: int = 50,
        slow_client_timeout_s: float = 5.0
    ):
        self.source = source
        self.queue_maxsize = queue_maxsize
        self.slow_client_timeout = slow_client_timeout_s
        self._queue: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize)
        self._producer_task: Optional[asyncio.Task] = None
        self.tokens_dropped = 0
        self.tokens_sent = 0
        self.backpressure_events = 0

    async def _producer(self):
        """Reads from source and fills the queue."""
        try:
            async for item in self.source:
                try:
                    # Wait at most slow_client_timeout for space in queue
                    await asyncio.wait_for(
                        self._queue.put(item),
                        timeout=self.slow_client_timeout
                    )
                except asyncio.TimeoutError:
                    self.backpressure_events += 1
                    logger.warning(
                        f"Backpressure: queue full after {self.slow_client_timeout}s. "
                        f"Terminating slow client. (event #{self.backpressure_events})"
                    )
                    # The queue is still full here, so this put also waits for
                    # the consumer to drain one slot before the error arrives.
                    await self._queue.put(BackpressureError("Client too slow"))
                    return
        except Exception as e:
            await self._queue.put(e)
        finally:
            await self._queue.put(None)  # Sentinel: stream done

    async def stream(self) -> AsyncGenerator:
        """Async generator that yields items with backpressure detection."""
        self._producer_task = asyncio.create_task(self._producer())
        try:
            while True:
                item = await self._queue.get()
                if item is None:
                    break
                if isinstance(item, BackpressureError):
                    raise item
                if isinstance(item, Exception):
                    raise item
                self.tokens_sent += 1
                yield item
        finally:
            if self._producer_task and not self._producer_task.done():
                self._producer_task.cancel()

    def get_stats(self) -> Dict:
        return {
            "queue_size": self._queue.qsize(),
            "queue_maxsize": self.queue_maxsize,
            "tokens_sent": self.tokens_sent,
            "tokens_dropped": self.tokens_dropped,
            "backpressure_events": self.backpressure_events
        }


# --- Demo ---
async def fast_producer() -> AsyncGenerator[str, None]:
    """Produces tokens very fast (simulates a fast LLM)."""
    for i in range(20):
        yield f"token_{i} "
        await asyncio.sleep(0.01)  # 100 tokens/sec

async def demo_backpressure():
    print("Demo 1: Fast client (no backpressure)")
    bp = BackpressureStream(fast_producer(), queue_maxsize=5, slow_client_timeout_s=1.0)
    tokens = []
    async for token in bp.stream():
        tokens.append(token)
        await asyncio.sleep(0.02)  # Client reads at 50/sec -- can keep up
    print(f"  Received {len(tokens)} tokens. Stats: {bp.get_stats()}")

    print("\nDemo 2: Slow client (backpressure triggered)")
    bp2 = BackpressureStream(fast_producer(), queue_maxsize=3, slow_client_timeout_s=0.2)
    tokens2 = []
    try:
        async for token in bp2.stream():
            tokens2.append(token)
            await asyncio.sleep(0.5)  # Client reads at 2/sec -- too slow
    except BackpressureError as e:
        print(f"  BackpressureError raised after {len(tokens2)} tokens: {e}")
    print(f"  Stats: {bp2.get_stats()}")

await demo_backpressure()
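For contrast, strategy (2) described earlier -- dropping data instead of blocking -- can be sketched with a bounded deque. This is a minimal illustration, only viable when every event is independently meaningful (live metrics, cursor positions), never for LLM text, where dropped tokens corrupt the response:

```python
from collections import deque

class DropOldestBuffer:
    """Never blocks the producer; evicts the oldest item when full."""

    def __init__(self, maxsize: int = 50):
        self._buf: deque = deque(maxlen=maxsize)  # old items fall off the left
        self.dropped = 0

    def push(self, item):
        if len(self._buf) == self._buf.maxlen:
            self.dropped += 1          # oldest item is about to be evicted
        self._buf.append(item)

    def pop_all(self):
        items = list(self._buf)
        self._buf.clear()
        return items

buf = DropOldestBuffer(maxsize=3)
for i in range(5):
    buf.push(i)
print(buf.pop_all(), buf.dropped)  # [2, 3, 4] 2
```

The `dropped` counter is the metric to alert on: a nonzero drop rate means clients are losing data, which for token streams would be a correctness bug rather than a degradation.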

6. Timeout Management at Every Layer

Defense Against Silent Hangs

Streaming systems have multiple timeout boundaries that must all be configured correctly: client connection timeout, nginx proxy_read_timeout, application-level stream timeout, and LLM API timeout. If any single layer has an infinite or missing timeout, a stalled connection hangs silently – consuming resources, holding a connection slot, and eventually triggering cascading failures when the connection pool fills up.

The correct approach is layered timeouts from outside in: nginx at 300s (outermost), application stream at 240s, LLM API call at 120s, individual chunk wait at 30s. Each inner timeout should be shorter than the outer one, so the application can send a clean error response rather than having nginx cut the connection abruptly. For TTFT specifically, a separate short timeout (e.g., 10s) catches cases where the LLM never starts generating.

@dataclass
class TimeoutConfig:
    """Timeout configuration for all layers of a streaming request."""
    # How long to wait for the LLM API to start responding
    llm_connect_s: float = 10.0
    # How long to wait for the first token after generation starts
    first_token_s: float = 30.0
    # Maximum time between any two consecutive tokens
    inter_token_s: float = 15.0
    # Maximum total generation time
    total_generation_s: float = 120.0
    # Maximum total request time (includes retrieval + generation)
    total_request_s: float = 180.0


class TimeoutManager:
    """
    Enforces timeouts at every layer of a streaming request.
    """

    def __init__(self, config: TimeoutConfig):
        self.config = config
        self.request_start = time.monotonic()
        self.first_token_received = False
        self.last_token_time = time.monotonic()
        self.token_count = 0

    def check_total_request_timeout(self):
        elapsed = time.monotonic() - self.request_start
        if elapsed > self.config.total_request_s:
            raise asyncio.TimeoutError(
                f"Total request timeout ({self.config.total_request_s}s) exceeded after {elapsed:.1f}s"
            )

    def check_first_token_timeout(self):
        if not self.first_token_received:
            elapsed = time.monotonic() - self.request_start
            if elapsed > self.config.first_token_s:
                raise asyncio.TimeoutError(
                    f"First token timeout ({self.config.first_token_s}s). No tokens after {elapsed:.1f}s"
                )

    def on_token(self):
        self.first_token_received = True
        now = time.monotonic()
        gap = now - self.last_token_time
        self.last_token_time = now
        self.token_count += 1
        # Note: this check only fires when the *next* token arrives, so a
        # permanent stall is caught by the total-request timeout instead.
        if gap > self.config.inter_token_s:
            raise asyncio.TimeoutError(
                f"Inter-token timeout: {gap:.1f}s gap (max {self.config.inter_token_s}s)"
            )

    async def wrap_stream(
        self,
        source: AsyncGenerator
    ) -> AsyncGenerator:
        """Wrap a source stream with all timeout checks."""
        try:
            # asyncio.timeout requires Python 3.11+; on older versions wrap
            # the consumption loop in asyncio.wait_for instead
            async with asyncio.timeout(self.config.total_request_s):
                async for item in source:
                    self.on_token()
                    yield item
        except asyncio.TimeoutError as e:
            logger.error(f"Stream timeout: {e}")
            yield {"type": "error", "code": "timeout", "message": str(e)}


# --- Demo ---
async def slow_llm_stream() -> AsyncGenerator[Dict, None]:
    """Simulates an LLM that pauses mid-generation."""
    for i in range(5):
        yield {"type": "token", "content": f"word{i} "}
        await asyncio.sleep(0.1)
    # Simulate a stall
    await asyncio.sleep(2.0)  # This will trigger inter-token timeout
    yield {"type": "token", "content": "never_seen "}

async def demo_timeouts():
    config = TimeoutConfig(
        first_token_s=5.0,
        inter_token_s=1.0,  # 1s max between tokens
        total_request_s=30.0
    )
    mgr = TimeoutManager(config)
    print("Streaming with inter-token timeout of 1.0s:")

    tokens_received = []
    async for event in mgr.wrap_stream(slow_llm_stream()):
        print(f"  Event: {event}")
        if event.get("type") == "token":
            tokens_received.append(event["content"])

    print(f"\nTokens received before timeout: {tokens_received}")
    print(f"Token count: {mgr.token_count}")

await demo_timeouts()
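One limitation of the on_token check above: it only fires once the next token eventually arrives, so a permanently stalled stream is caught by the outer total-request timeout rather than the inter-token one. A stricter variant (a sketch, not part of TimeoutManager) wraps every chunk fetch in asyncio.wait_for so a stall fails after exactly the per-chunk budget:

```python
import asyncio
from typing import AsyncGenerator

async def enforce_chunk_timeout(source: AsyncGenerator, per_chunk_s: float):
    """Yield items from source, raising TimeoutError the moment any
    single item takes longer than per_chunk_s to arrive."""
    it = source.__aiter__()
    while True:
        try:
            item = await asyncio.wait_for(it.__anext__(), timeout=per_chunk_s)
        except StopAsyncIteration:
            return
        yield item

async def stalling_stream():
    yield "a"
    await asyncio.sleep(0.5)   # stalls longer than the per-chunk budget
    yield "b"

async def demo():
    got = []
    try:
        async for token in enforce_chunk_timeout(stalling_stream(), per_chunk_s=0.1):
            got.append(token)
    except asyncio.TimeoutError:
        got.append("<timeout>")
    return got
```

The trade-off is that wait_for cancels the pending `__anext__` on timeout, so the source generator is torn down rather than resumed -- acceptable here, since a stalled LLM stream is being abandoned anyway.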

7. Memory Management for Long-Lived Connections

Preventing Memory Leaks in Streaming Services

Each streaming connection holds references to buffers, LLM client objects, callback handlers, and accumulated response text. In a typical 30-second stream generating 500 tokens, a single connection may hold 1-2MB of live objects. Without explicit cleanup, connections that have logically closed (client disconnected) can still hold references that prevent garbage collection – a classic memory leak pattern in async Python.

Weak references are the key defense: keep registry entries as weakref.ref objects (or use a weakref.WeakValueDictionary) so that entries are removed automatically once the connection object is garbage collected. Additionally, every async generator must have a finally block that explicitly clears buffers and closes client sessions. The combination of weak references, explicit cleanup, and periodic stale-connection eviction keeps memory usage bounded even under sustained high concurrency.

import gc
import sys
import weakref
from contextlib import asynccontextmanager


@dataclass
class StreamContext:
    """Per-connection context. Holds all resources for one streaming session."""
    conn_id: str
    client_ip: str
    buffer: bytearray = field(default_factory=lambda: bytearray(4096))
    token_history: List[str] = field(default_factory=list)
    metadata: Dict = field(default_factory=dict)

    def cleanup(self):
        """Explicitly release large resources."""
        self.buffer = bytearray(0)   # rebind; the old buffer becomes garbage
        self.token_history.clear()
        self.metadata.clear()

    def size_bytes(self) -> int:
        return sys.getsizeof(self.buffer) + sum(sys.getsizeof(t) for t in self.token_history)


class MemoryAwareConnectionRegistry:
    """
    Tracks active connections using WeakReferences.
    When a StreamContext is garbage collected, it is automatically
    removed from the registry -- no manual cleanup needed.
    """

    def __init__(self, max_memory_mb: float = 500.0):
        self.max_memory_bytes = max_memory_mb * 1024 * 1024
        self._registry: Dict[str, weakref.ref] = {}

    def register(self, ctx: StreamContext):
        """Register a context using a weak reference."""
        # Capture only the id: closing over ctx itself would give the
        # callback a strong reference and defeat the weak reference.
        conn_id = ctx.conn_id

        def on_gc(ref, conn_id=conn_id):
            self._registry.pop(conn_id, None)
            logger.debug(f"Context {conn_id} garbage collected")

        self._registry[conn_id] = weakref.ref(ctx, on_gc)

    def active_contexts(self) -> List[StreamContext]:
        """Return all live context objects."""
        live = []
        for ref in list(self._registry.values()):
            ctx = ref()
            if ctx is not None:
                live.append(ctx)
        return live

    def total_memory_bytes(self) -> int:
        return sum(c.size_bytes() for c in self.active_contexts())

    def check_memory_pressure(self) -> bool:
        """Return True if we are approaching memory limits."""
        used = self.total_memory_bytes()
        if used > self.max_memory_bytes * 0.9:
            logger.warning(
                f"Memory pressure: {used / 1024 / 1024:.1f}MB / "
                f"{self.max_memory_bytes / 1024 / 1024:.1f}MB"
            )
            return True
        return False

    def force_cleanup_oldest(self, n: int = 5):
        """Force cleanup of the N oldest contexts under memory pressure."""
        contexts = sorted(self.active_contexts(), key=lambda c: c.conn_id)
        for ctx in contexts[:n]:
            ctx.cleanup()
        gc.collect()
        logger.info(f"Force-cleaned {n} contexts. Remaining: {len(self.active_contexts())}")


@asynccontextmanager
async def managed_stream_context(
    registry: MemoryAwareConnectionRegistry,
    client_ip: str
):
    """Context manager that creates, registers, and cleans up a StreamContext."""
    conn_id = str(uuid.uuid4())[:8]
    ctx = StreamContext(conn_id=conn_id, client_ip=client_ip)
    registry.register(ctx)
    logger.info(f"Stream context {conn_id} created")
    try:
        yield ctx
    finally:
        ctx.cleanup()
        logger.info(f"Stream context {conn_id} cleaned up")


# --- Demo ---
async def demo_memory_management():
    registry = MemoryAwareConnectionRegistry(max_memory_mb=1.0)

    # Create contexts and simulate streaming. Deliberately keep no strong
    # references after each block exits, so the weak registry entries
    # can actually be garbage collected.
    for i in range(3):
        async with managed_stream_context(registry, f"10.0.0.{i+1}") as ctx:
            # Simulate accumulating tokens
            for j in range(10):
                ctx.token_history.append(f"token_{j}")
            print(f"  Active context {ctx.conn_id}: {ctx.size_bytes()} bytes")
        # After exiting the context manager, cleanup() is called

    ctx = None  # drop the last loop reference before forcing collection
    gc.collect()
    remaining = registry.active_contexts()
    print(f"\nAfter all contexts exited: {len(remaining)} active")
    print(f"Total memory: {registry.total_memory_bytes()} bytes")

await demo_memory_management()

8. Circuit Breaker for LLM API Failures

Fail Fast Instead of Fail Slow

When the LLM API is experiencing an outage (returning 500 errors or timing out), the worst thing a streaming service can do is keep sending requests – each one consumes a connection slot for the full timeout duration before failing. The circuit breaker pattern (from Michael Nygard's "Release It!") addresses this with three states: closed (normal operation), open (all requests immediately fail with a cached error), and half-open (a single test request probes whether the API has recovered).

The circuit opens after \(N\) consecutive failures (e.g., 5) and remains open for a configurable duration (e.g., 30s). During this time, incoming requests receive an instant 503 response instead of waiting 120 seconds for a timeout. This protects both the streaming service (connection pool stays available) and the LLM API (reduced load during recovery). The half-open state sends one probe request; if it succeeds, the circuit closes and normal traffic resumes.

class CircuitState(Enum):
    CLOSED = "closed"       # Normal -- requests pass through
    OPEN = "open"           # Failing -- reject requests immediately
    HALF_OPEN = "half_open" # Testing recovery -- allow one probe request


class CircuitBreakerOpen(Exception):
    """Raised when the circuit breaker is open."""


class CircuitBreaker:
    """
    Circuit breaker for LLM API calls.

    State transitions:
      CLOSED  -> OPEN:      failure_threshold failures in failure_window_s
      OPEN    -> HALF_OPEN: after reset_timeout_s
      HALF_OPEN -> CLOSED:  probe request succeeds
      HALF_OPEN -> OPEN:    probe request fails
    """

    def __init__(
        self,
        name: str = "llm_api",
        failure_threshold: int = 5,
        failure_window_s: float = 60.0,
        reset_timeout_s: float = 30.0,
        success_threshold: int = 2  # Successes needed in HALF_OPEN to CLOSE
    ):
        self.name = name
        self.failure_threshold = failure_threshold
        self.failure_window_s = failure_window_s
        self.reset_timeout_s = reset_timeout_s
        self.success_threshold = success_threshold

        self._state = CircuitState.CLOSED
        self._failure_times: deque = deque()
        self._opened_at: Optional[float] = None
        self._half_open_successes = 0
        self._total_calls = 0
        self._total_failures = 0
        self._total_rejected = 0
        self._lock = asyncio.Lock()

    @property
    def state(self) -> CircuitState:
        return self._state

    def _prune_failures(self):
        """Remove failures outside the sliding window."""
        cutoff = time.monotonic() - self.failure_window_s
        while self._failure_times and self._failure_times[0] < cutoff:
            self._failure_times.popleft()

    async def _check_state(self):
        """Check if we should transition OPEN -> HALF_OPEN."""
        if self._state == CircuitState.OPEN:
            if time.monotonic() - self._opened_at >= self.reset_timeout_s:
                self._state = CircuitState.HALF_OPEN
                self._half_open_successes = 0
                logger.info(f"CircuitBreaker '{self.name}': OPEN -> HALF_OPEN")

    async def call(self, func: Callable, *args, **kwargs):
        """
        Execute a function through the circuit breaker.
        Raises CircuitBreakerOpen if the circuit is OPEN.
        """
        async with self._lock:
            await self._check_state()
            if self._state == CircuitState.OPEN:
                self._total_rejected += 1
                raise CircuitBreakerOpen(
                    f"Circuit '{self.name}' is OPEN. "
                    f"Retry after {self.reset_timeout_s}s."
                )

        self._total_calls += 1
        try:
            result = await func(*args, **kwargs)
            async with self._lock:
                if self._state == CircuitState.HALF_OPEN:
                    self._half_open_successes += 1
                    if self._half_open_successes >= self.success_threshold:
                        self._state = CircuitState.CLOSED
                        self._failure_times.clear()
                        logger.info(f"CircuitBreaker '{self.name}': HALF_OPEN -> CLOSED")
            return result

        except Exception as e:
            self._total_failures += 1
            async with self._lock:
                now = time.monotonic()
                self._failure_times.append(now)
                self._prune_failures()

                if self._state == CircuitState.HALF_OPEN:
                    self._state = CircuitState.OPEN
                    self._opened_at = now
                    logger.error(f"CircuitBreaker '{self.name}': HALF_OPEN -> OPEN (probe failed)")
                elif len(self._failure_times) >= self.failure_threshold:
                    self._state = CircuitState.OPEN
                    self._opened_at = now
                    logger.error(
                        f"CircuitBreaker '{self.name}': CLOSED -> OPEN "
                        f"({len(self._failure_times)} failures in {self.failure_window_s}s)"
                    )
            raise

    def get_stats(self) -> Dict:
        self._prune_failures()
        return {
            "name": self.name,
            "state": self._state.value,
            "recent_failures": len(self._failure_times),
            "total_calls": self._total_calls,
            "total_failures": self._total_failures,
            "total_rejected": self._total_rejected,
            "error_rate_pct": round(
                self._total_failures / max(self._total_calls, 1) * 100, 1
            )
        }


# --- Demo ---
cb = CircuitBreaker(name="openai", failure_threshold=3, failure_window_s=10, reset_timeout_s=2)

call_count = 0

async def flaky_llm_call():
    global call_count
    call_count += 1
    if call_count <= 3:  # First 3 calls fail, so the post-reset probe succeeds
        raise Exception("OpenAI 503 Service Unavailable")
    return "LLM response text"

async def demo_circuit_breaker():
    for i in range(8):
        try:
            result = await cb.call(flaky_llm_call)
            print(f"  Call {i+1}: SUCCESS -> '{result}'")
        except CircuitBreakerOpen as e:
            print(f"  Call {i+1}: CIRCUIT OPEN (fast fail)")
        except Exception as e:
            print(f"  Call {i+1}: FAILED -> {e}")

        print(f"    State: {cb.state.value} | Stats: {cb.get_stats()}")
        await asyncio.sleep(0.3)

    # Wait for reset timeout
    print("\nWaiting 2.5s for circuit to allow probe...")
    await asyncio.sleep(2.5)
    try:
        result = await cb.call(flaky_llm_call)
        print(f"  Probe: SUCCESS -> '{result}' | State: {cb.state.value}")
    except Exception as e:
        print(f"  Probe: {e} | State: {cb.state.value}")

await demo_circuit_breaker()

9. Retry Logic with Exponential Backoff

Surviving Transient Failures

Transient failures – HTTP 429 (rate limited), 503 (service unavailable), and network timeouts – are routine in production LLM systems. The standard recovery pattern is exponential backoff with jitter: the base delay grows as \(\text{base} \cdot 2^n\) for attempt \(n\), capped at a maximum, and "full jitter" then draws the actual wait uniformly from \([0, \text{delay}]\), for up to 3-5 attempts. The jitter prevents the thundering herd problem where thousands of clients that failed simultaneously all retry at exactly the same moment.

For streaming endpoints, retry logic has an important nuance: you can only retry before the first token is sent. Once token streaming has begun, retrying would produce a duplicate or corrupted response. The implementation should track whether any tokens have been yielded and fall back to an error event (rather than retry) if the stream has already started.

import random

class RetryConfig:
    def __init__(
        self,
        max_retries: int = 3,
        base_delay_s: float = 1.0,
        max_delay_s: float = 30.0,
        exponential_base: float = 2.0,
        jitter: bool = True,
        retryable_exceptions: tuple = (Exception,)
    ):
        self.max_retries = max_retries
        self.base_delay_s = base_delay_s
        self.max_delay_s = max_delay_s
        self.exponential_base = exponential_base
        self.jitter = jitter
        self.retryable_exceptions = retryable_exceptions

    def delay_for_attempt(self, attempt: int) -> float:
        """Calculate delay for retry attempt N (0-indexed)."""
        delay = self.base_delay_s * (self.exponential_base ** attempt)
        delay = min(delay, self.max_delay_s)
        if self.jitter:
            # Full jitter: random in [0, delay]
            delay = random.uniform(0, delay)
        return delay


async def retry_with_backoff(
    func: Callable,
    config: RetryConfig,
    *args,
    **kwargs
):
    """
    Execute func with exponential backoff retry.
    Raises the last exception if all retries are exhausted.
    """
    last_exception = None

    for attempt in range(config.max_retries + 1):
        try:
            result = await func(*args, **kwargs)
            if attempt > 0:
                logger.info(f"Succeeded on attempt {attempt + 1}")
            return result

        except config.retryable_exceptions as e:
            last_exception = e

            if attempt == config.max_retries:
                logger.error(f"All {config.max_retries + 1} attempts failed. Last error: {e}")
                break

            delay = config.delay_for_attempt(attempt)
            logger.warning(
                f"Attempt {attempt + 1}/{config.max_retries + 1} failed: {e}. "
                f"Retrying in {delay:.2f}s..."
            )
            await asyncio.sleep(delay)

    raise last_exception


# --- Demo ---
attempt_counter = 0

async def unreliable_api_call():
    global attempt_counter
    attempt_counter += 1
    if attempt_counter < 3:
        raise ConnectionError(f"Network error (attempt {attempt_counter})")
    return {"response": "success", "attempt": attempt_counter}

config = RetryConfig(
    max_retries=4,
    base_delay_s=0.1,  # Short for demo
    max_delay_s=2.0,
    jitter=False  # Deterministic for demo
)

print("Retry schedule (no jitter):")
for i in range(4):
    delay = config.delay_for_attempt(i)
    print(f"  Attempt {i+1} fails -> wait {delay:.2f}s before attempt {i+2}")

print("\nRunning with unreliable API:")
result = await retry_with_backoff(unreliable_api_call, config)
print(f"Result: {result}")
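The stream-specific rule described above – retry only while no tokens have been sent – is not covered by the generic helper, since a generator can fail mid-iteration. A minimal standalone sketch of the idea (names like `retry_stream` and `flaky_stream` are illustrative, not part of this notebook's API):

```python
import asyncio
import random
from typing import AsyncGenerator, Callable

async def retry_stream(
    make_stream: Callable[[], AsyncGenerator[str, None]],
    max_retries: int = 3,
    base_delay_s: float = 0.1,
) -> AsyncGenerator[str, None]:
    """Retry a token stream with backoff, but only before the first token."""
    for attempt in range(max_retries + 1):
        started = False
        try:
            async for token in make_stream():
                started = True
                yield token
            return  # stream completed normally
        except ConnectionError as exc:
            if started or attempt == max_retries:
                # Stream already begun (retrying would duplicate output) or
                # retries exhausted: emit an error event instead of retrying
                yield f'data: {{"type": "error", "message": "{exc}"}}\n\n'
                return
            # Full-jitter backoff before the next attempt
            await asyncio.sleep(random.uniform(0, base_delay_s * 2 ** attempt))

# Demo: the first attempt fails before any token, the second succeeds
calls = {"n": 0}

async def flaky_stream() -> AsyncGenerator[str, None]:
    calls["n"] += 1
    if calls["n"] == 1:
        raise ConnectionError("503 before first token")
    for tok in ("Hello", " ", "world"):
        yield tok

async def main() -> list:
    return [t async for t in retry_stream(flaky_stream, base_delay_s=0.01)]

tokens = asyncio.run(main())
print(tokens)  # ['Hello', ' ', 'world'] after one retry
```

Once `started` flips to True, a failure falls through to the error-event branch, which is exactly the fallback behavior the prose above calls for.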

10. Prometheus Metrics Integration

Observability for Streaming Systems

Streaming services require specialized metrics beyond standard HTTP request counters. The four critical metrics are: (1) active_connections (Gauge) – how many streams are open right now, (2) ttft_seconds (Histogram) – time-to-first-token distribution for latency SLOs, (3) tokens_per_second (Summary) – generation throughput per stream, and (4) stream_errors_total (Counter) – error rate by type (timeout, disconnect, LLM failure).

Prometheus Histogram buckets should be tuned for streaming latencies: TTFT buckets at [0.1, 0.25, 0.5, 1.0, 2.0, 5.0] seconds, and total stream duration at [1, 5, 10, 30, 60, 120] seconds. These metrics feed directly into Grafana dashboards and PagerDuty alerts. A spike in active_connections without a corresponding spike in tokens_per_second indicates stalled streams – an early warning of LLM API issues before errors propagate to users.

# Use a custom registry to avoid conflicts in notebook re-runs
registry = CollectorRegistry() if PROMETHEUS_AVAILABLE else None

if PROMETHEUS_AVAILABLE:
    # --- Connection metrics ---
    ACTIVE_CONNECTIONS = Gauge(
        'streaming_active_connections',
        'Number of currently active streaming connections',
        ['endpoint'],
        registry=registry
    )

    CONNECTIONS_TOTAL = Counter(
        'streaming_connections_total',
        'Total streaming connections opened',
        ['endpoint', 'status'],  # status: success, rejected, error
        registry=registry
    )

    # --- Latency metrics ---
    TTFT_SECONDS = Histogram(
        'streaming_ttft_seconds',
        'Time to first token in seconds',
        ['endpoint', 'model'],
        buckets=[0.05, 0.1, 0.25, 0.5, 1.0, 2.0, 5.0, 10.0],
        registry=registry
    )

    GENERATION_DURATION = Histogram(
        'streaming_generation_duration_seconds',
        'Total time from first to last token',
        ['endpoint', 'model'],
        buckets=[1.0, 2.0, 5.0, 10.0, 30.0, 60.0, 120.0],
        registry=registry
    )

    REQUEST_DURATION = Histogram(
        'streaming_request_duration_seconds',
        'Total request duration including retrieval',
        ['endpoint'],
        buckets=[0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0],
        registry=registry
    )

    # --- Throughput metrics ---
    TOKENS_STREAMED_TOTAL = Counter(
        'streaming_tokens_total',
        'Total tokens streamed to clients',
        ['endpoint', 'model'],
        registry=registry
    )

    TOKENS_PER_SECOND = Histogram(
        'streaming_tokens_per_second',
        'Token generation throughput',
        ['model'],
        buckets=[5, 10, 20, 30, 50, 75, 100, 150, 200],
        registry=registry
    )

    # --- Error metrics ---
    ERRORS_TOTAL = Counter(
        'streaming_errors_total',
        'Total errors during streaming',
        ['endpoint', 'error_type'],  # error_type: timeout, backpressure, llm_error, rate_limited
        registry=registry
    )

    CIRCUIT_BREAKER_STATE = Gauge(
        'streaming_circuit_breaker_state',
        'Circuit breaker state (0=closed, 1=half_open, 2=open)',
        ['service'],
        registry=registry
    )

    print("Prometheus metrics registered:")
    for metric in [
        'streaming_active_connections',
        'streaming_connections_total',
        'streaming_ttft_seconds',
        'streaming_generation_duration_seconds',
        'streaming_request_duration_seconds',
        'streaming_tokens_total',
        'streaming_tokens_per_second',
        'streaming_errors_total',
        'streaming_circuit_breaker_state'
    ]:
        print(f"  {metric}")
else:
    print("prometheus-client not available. Install: pip install prometheus-client")
# Instrument a streaming handler with Prometheus metrics

@dataclass
class StreamStats:
    start_time: float = field(default_factory=time.perf_counter)
    first_token_time: Optional[float] = None
    last_token_time: Optional[float] = None
    token_count: int = 0
    error: Optional[str] = None

    def record_token(self):
        now = time.perf_counter()
        if self.first_token_time is None:
            self.first_token_time = now
        self.last_token_time = now
        self.token_count += 1

    @property
    def ttft_seconds(self) -> Optional[float]:
        if self.first_token_time:
            return self.first_token_time - self.start_time
        return None

    @property
    def generation_duration_seconds(self) -> Optional[float]:
        if self.first_token_time and self.last_token_time:
            return self.last_token_time - self.first_token_time
        return None

    @property
    def total_duration_seconds(self) -> float:
        return time.perf_counter() - self.start_time

    @property
    def tokens_per_second(self) -> Optional[float]:
        dur = self.generation_duration_seconds
        if dur and dur > 0:
            return self.token_count / dur
        return None


def record_stream_metrics(
    stats: StreamStats,
    endpoint: str = "/stream",
    model: str = "gpt-4o-mini"
):
    """Flush StreamStats into Prometheus metrics."""
    if not PROMETHEUS_AVAILABLE:
        return

    if stats.ttft_seconds is not None:
        TTFT_SECONDS.labels(endpoint=endpoint, model=model).observe(stats.ttft_seconds)

    if stats.generation_duration_seconds is not None:
        GENERATION_DURATION.labels(endpoint=endpoint, model=model).observe(
            stats.generation_duration_seconds
        )

    REQUEST_DURATION.labels(endpoint=endpoint).observe(stats.total_duration_seconds)
    TOKENS_STREAMED_TOTAL.labels(endpoint=endpoint, model=model).inc(stats.token_count)

    if stats.tokens_per_second is not None:
        TOKENS_PER_SECOND.labels(model=model).observe(stats.tokens_per_second)

    if stats.error:
        ERRORS_TOTAL.labels(endpoint=endpoint, error_type=stats.error).inc()


# --- Simulate several streaming requests and check metrics ---
import numpy as np

async def simulate_stream_request(ttft_ms: float, tokens: int, tps: float):
    """Simulate a streaming request with given TTFT, token count, and tokens/sec."""
    stats = StreamStats()

    if PROMETHEUS_AVAILABLE:
        ACTIVE_CONNECTIONS.labels(endpoint="/stream").inc()

    try:
        await asyncio.sleep(ttft_ms / 1000.0)  # Simulate retrieval + TTFT
        for _ in range(tokens):
            stats.record_token()
            await asyncio.sleep(1.0 / tps)
    finally:
        if PROMETHEUS_AVAILABLE:
            ACTIVE_CONNECTIONS.labels(endpoint="/stream").dec()
        record_stream_metrics(stats)

    return stats

# Simulate 5 requests with varying characteristics
print("Simulating 5 streaming requests...")
requests = [
    (150, 30, 25),  # ttft_ms, tokens, tps
    (300, 50, 20),
    (200, 20, 30),
    (800, 80, 15),
    (100, 15, 35),
]

all_stats = await asyncio.gather(*[
    simulate_stream_request(ttft, tokens, tps)
    for ttft, tokens, tps in requests
])

print("\nPer-request stats:")
for i, s in enumerate(all_stats):
    print(
        f"  Request {i+1}: "
        f"TTFT={s.ttft_seconds*1000:.0f}ms, "
        f"tokens={s.token_count}, "
        f"tps={s.tokens_per_second:.1f}, "
        f"total={s.total_duration_seconds*1000:.0f}ms"
    )

# Summary percentiles
ttfts = [s.ttft_seconds * 1000 for s in all_stats if s.ttft_seconds]
tps_vals = [s.tokens_per_second for s in all_stats if s.tokens_per_second]
print(f"\nAggregate:")
print(f"  TTFT p50={np.percentile(ttfts, 50):.0f}ms, p95={np.percentile(ttfts, 95):.0f}ms")
print(f"  Tokens/sec avg={np.mean(tps_vals):.1f}")

if PROMETHEUS_AVAILABLE:
    output = generate_latest(registry).decode("utf-8")
    # Show just the TTFT metric
    for line in output.split("\n"):
        if "ttft" in line or "tokens_per" in line:
            print(f"  {line}")
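The alerting signal described at the start of this section (connections up while token throughput is flat) can be written as Prometheus rules. Following this notebook's convention of embedding config as a Python string, here is a sketch; the metric names match this section, but the thresholds and label values are illustrative:

```python
# Sketch of Prometheus alerting rules for the streaming metrics above.
# Metric names match this notebook; thresholds are illustrative.
ALERT_RULES = """
groups:
  - name: streaming
    rules:
      - alert: HighTTFT
        expr: histogram_quantile(0.95, rate(streaming_ttft_seconds_bucket[5m])) > 1.0
        for: 5m
        labels:
          severity: page
        annotations:
          summary: "TTFT p95 above 1s for 5 minutes"
      - alert: StalledStreams
        # Connections are open but almost no tokens are flowing: streams
        # are stalled, usually an upstream LLM API problem
        expr: sum(streaming_active_connections) > 50 and sum(rate(streaming_tokens_total[2m])) < 1
        for: 2m
        labels:
          severity: page
        annotations:
          summary: "Active streams open but no tokens flowing"
"""
print("Defined", ALERT_RULES.count("alert:"), "alert rules")
```

Save these into a rule file referenced by `rule_files` in prometheus.yml; the `for:` clauses keep one slow request from paging anyone.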

11. Graceful Shutdown and Connection Draining

Zero-Downtime Deployments for Streaming

When deploying a new version, the server must stop accepting new connections while letting existing streams finish gracefully. This is called connection draining. The process follows three steps: (1) receive SIGTERM from the orchestrator, (2) stop the listener (no new connections), (3) wait up to a drain timeout (e.g., 60 seconds) for active streams to complete, then force-close any remaining connections.

Without graceful shutdown, a rolling deployment in Kubernetes kills pods immediately, severing active streams mid-response. Users see truncated answers with no error message. The drain timeout must be configured at every layer: Kubernetes terminationGracePeriodSeconds, the application shutdown handler, and nginx proxy_read_timeout. The application should also send a final SSE event ({"type": "server_shutdown"}) to connected clients so they can retry on a healthy instance.

class GracefulShutdownManager:
    """
    Manages graceful shutdown of a streaming server.

    Shutdown sequence:
    1. Receive SIGTERM
    2. Stop accepting new connections (is_shutting_down = True)
    3. Send 'draining' SSE event to all active connections
    4. Wait for active connections to finish (up to drain_timeout_s)
    5. Force-close remaining connections
    6. Exit
    """

    def __init__(self, drain_timeout_s: float = 30.0):
        self.drain_timeout_s = drain_timeout_s
        self.is_shutting_down = False
        self._active_streams: Set[str] = set()
        self._stream_done_event: Dict[str, asyncio.Event] = {}
        self._shutdown_complete = asyncio.Event()
        self._lock = asyncio.Lock()

    async def register_stream(self, stream_id: str):
        async with self._lock:
            self._active_streams.add(stream_id)
            self._stream_done_event[stream_id] = asyncio.Event()

    async def unregister_stream(self, stream_id: str):
        async with self._lock:
            self._active_streams.discard(stream_id)
            if stream_id in self._stream_done_event:
                self._stream_done_event[stream_id].set()
                del self._stream_done_event[stream_id]
        if not self._active_streams and self.is_shutting_down:
            self._shutdown_complete.set()

    def accept_new_connections(self) -> bool:
        """Check if new connections should be accepted."""
        return not self.is_shutting_down

    async def initiate_shutdown(self):
        """Begin graceful shutdown."""
        logger.info("Graceful shutdown initiated")
        self.is_shutting_down = True

        active_count = len(self._active_streams)
        if active_count == 0:
            logger.info("No active streams. Shutdown complete.")
            self._shutdown_complete.set()
            return

        logger.info(f"Draining {active_count} active stream(s) (timeout={self.drain_timeout_s}s)...")

        try:
            await asyncio.wait_for(
                self._shutdown_complete.wait(),
                timeout=self.drain_timeout_s
            )
            logger.info("All streams drained. Clean shutdown.")
        except asyncio.TimeoutError:
            remaining = len(self._active_streams)
            logger.warning(f"Drain timeout. Force-closing {remaining} stream(s).")

    def get_stats(self) -> Dict:
        return {
            "active_streams": len(self._active_streams),
            "is_shutting_down": self.is_shutting_down,
        }


# FastAPI lifespan integration
SHUTDOWN_MANAGER = GracefulShutdownManager(drain_timeout_s=30.0)

LIFESPAN_CODE = '''
from contextlib import asynccontextmanager

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    logger.info("Server starting up")
    yield
    # Shutdown (triggered by SIGTERM/SIGINT)
    logger.info("Server shutting down")
    await SHUTDOWN_MANAGER.initiate_shutdown()

app = FastAPI(lifespan=lifespan)
'''

# --- Demo graceful shutdown ---
async def demo_graceful_shutdown():
    mgr = GracefulShutdownManager(drain_timeout_s=2.0)

    # Register some active streams
    for sid in ["stream-A", "stream-B", "stream-C"]:
        await mgr.register_stream(sid)

    print(f"Active streams: {mgr.get_stats()}")
    print(f"Accepting connections: {mgr.accept_new_connections()}")

    # Simulate streams finishing
    async def finish_stream(sid, delay):
        await asyncio.sleep(delay)
        await mgr.unregister_stream(sid)
        print(f"  Stream {sid} finished")

    # Start shutdown and simultaneously finish streams
    print("\nInitiating graceful shutdown...")
    print(f"Accepting new connections: {mgr.accept_new_connections()}")

    shutdown_task = asyncio.create_task(mgr.initiate_shutdown())
    finish_tasks = [
        asyncio.create_task(finish_stream("stream-A", 0.3)),
        asyncio.create_task(finish_stream("stream-B", 0.5)),
        asyncio.create_task(finish_stream("stream-C", 0.8)),
    ]

    await asyncio.gather(shutdown_task, *finish_tasks)
    print(f"Shutdown complete. Active streams: {mgr.get_stats()}")

await demo_graceful_shutdown()
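The final `server_shutdown` event mentioned in the section intro is not something `GracefulShutdownManager` emits itself; the stream generator has to append it. A standalone sketch (the helper name and event shape are illustrative):

```python
import asyncio
import json
from typing import AsyncGenerator

async def drain_aware_stream(
    tokens: list,
    shutting_down: asyncio.Event,
    delay_s: float = 0.0,
) -> AsyncGenerator[str, None]:
    """Yield SSE token events; if the server is draining, finish the current
    response and append a server_shutdown event so the client reconnects to
    a healthy instance instead of silently losing its stream."""
    for tok in tokens:
        yield f"data: {json.dumps({'type': 'token', 'content': tok})}\n\n"
        await asyncio.sleep(delay_s)
    final = {"type": "server_shutdown"} if shutting_down.is_set() else {"type": "done"}
    yield f"data: {json.dumps(final)}\n\n"

async def run(draining: bool) -> list:
    shutting_down = asyncio.Event()
    if draining:
        shutting_down.set()  # pretend SIGTERM arrived during this stream
    return [ev async for ev in drain_aware_stream(["a", "b"], shutting_down)]

normal = asyncio.run(run(False))
drained = asyncio.run(run(True))
print(normal[-1].strip())   # data: {"type": "done"}
print(drained[-1].strip())  # data: {"type": "server_shutdown"}
```

A client that sees `server_shutdown` should reconnect immediately (ideally with a small random delay, for the same thundering-herd reason as retry jitter).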

12. Complete Production FastAPI App

All Components Integrated

This section assembles every production component – connection management, rate limiting, backpressure handling, circuit breaker, retry logic, Prometheus metrics, and graceful shutdown – into a single FastAPI application. The app demonstrates how these components compose: each incoming request passes through rate limiting, then connection limit checks, then the streaming handler with backpressure and timeout management, with metrics recorded at every stage.

Key integration patterns: rate limiting and connection admission run at the top of the endpoint handler (in a larger app these belong in middleware or FastAPI dependencies), the circuit breaker wraps the LLM client, and metrics are recorded in the generator's finally block. The /metrics endpoint exposes Prometheus-format data, /health returns system status including circuit breaker state and active connection count, and the streaming endpoint (/stream) ties everything together.

from contextlib import asynccontextmanager

# --- Singletons ---
conn_manager = ConnectionManager(max_global=500, max_per_ip=10)
rate_limiter = RateLimiter(requests_per_minute=20, burst_size=5)
circuit_breaker = CircuitBreaker(name="llm_api", failure_threshold=5, reset_timeout_s=30)
shutdown_manager = GracefulShutdownManager(drain_timeout_s=30)
timeout_config = TimeoutConfig(
    first_token_s=30.0,
    inter_token_s=15.0,
    total_request_s=120.0
)
retry_config = RetryConfig(max_retries=2, base_delay_s=1.0, max_delay_s=10.0)


@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info("Production streaming server starting")
    yield
    logger.info("Initiating graceful shutdown")
    await shutdown_manager.initiate_shutdown()


prod_app = FastAPI(title="Production Streaming API", version="1.0.0", lifespan=lifespan)


class StreamRequest(BaseModel):
    prompt: str
    model: str = "gpt-4o-mini"
    max_tokens: int = 512


async def mock_llm_stream(prompt: str, model: str) -> AsyncGenerator[str, None]:
    """Mock LLM -- replace with real OpenAI/Anthropic call."""
    words = f"This is a streaming response for: {prompt[:40]}".split()
    for word in words:
        yield word + " "
        await asyncio.sleep(0.05)


async def production_stream_generator(
    request: StreamRequest,
    conn_info: ConnectionInfo
) -> AsyncGenerator[str, None]:
    """Full production streaming pipeline with all safety mechanisms."""
    stats = StreamStats()
    stream_id = conn_info.id

    if PROMETHEUS_AVAILABLE:
        ACTIVE_CONNECTIONS.labels(endpoint="/stream").inc()

    await shutdown_manager.register_stream(stream_id)

    try:
        # Check circuit breaker
        if circuit_breaker.state == CircuitState.OPEN:
            yield f"data: {json.dumps({'type': 'error', 'code': 'circuit_open', 'message': 'LLM API unavailable, retry later'})}\n\n"
            return

        yield f"data: {json.dumps({'type': 'status', 'message': 'Generating...'})}\n\n"

        # Timeout-wrapped source
        timeout_mgr = TimeoutManager(timeout_config)

        async def call_llm():
            return mock_llm_stream(request.prompt, request.model)

        # Note: the breaker only guards *creating* the stream here; exceptions
        # raised later, while iterating tokens, bypass circuit_breaker.call
        source = await circuit_breaker.call(call_llm)

        # Backpressure wrapper
        bp_stream = BackpressureStream(source, queue_maxsize=100, slow_client_timeout_s=5.0)

        async for token in bp_stream.stream():
            stats.record_token()
            conn_info.tokens_sent += 1
            token_bytes = token.encode()
            conn_info.bytes_sent += len(token_bytes)
            yield f"data: {json.dumps({'type': 'token', 'content': token})}\n\n"

        yield f"data: {json.dumps({'type': 'done', 'tokens': stats.token_count})}\n\n"

    except BackpressureError:
        if PROMETHEUS_AVAILABLE:
            ERRORS_TOTAL.labels(endpoint="/stream", error_type="backpressure").inc()
        yield f"data: {json.dumps({'type': 'error', 'code': 'backpressure'})}\n\n"

    except CircuitBreakerOpen as e:
        if PROMETHEUS_AVAILABLE:
            ERRORS_TOTAL.labels(endpoint="/stream", error_type="circuit_open").inc()
        yield f"data: {json.dumps({'type': 'error', 'code': 'circuit_open', 'message': str(e)})}\n\n"

    except asyncio.CancelledError:
        logger.info(f"Stream {stream_id} cancelled (client disconnected)")

    except Exception as e:
        logger.error(f"Stream {stream_id} error: {e}")
        if PROMETHEUS_AVAILABLE:
            ERRORS_TOTAL.labels(endpoint="/stream", error_type="unhandled").inc()
        yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"

    finally:
        await conn_manager.release(conn_info.id)
        await shutdown_manager.unregister_stream(stream_id)
        record_stream_metrics(stats, endpoint="/stream", model=request.model)
        if PROMETHEUS_AVAILABLE:
            ACTIVE_CONNECTIONS.labels(endpoint="/stream").dec()


@prod_app.post("/stream")
async def stream_endpoint(req: StreamRequest, http_req: Request):
    client_ip = http_req.client.host if http_req.client else "unknown"

    # Reject if shutting down
    if not shutdown_manager.accept_new_connections():
        raise HTTPException(status_code=503, detail="Server is shutting down")

    # Rate limit check
    rate_result = await rate_limiter.check(client_ip)
    if not rate_result["allowed"]:
        raise HTTPException(
            status_code=429,
            detail=f"Rate limit exceeded. Retry after {rate_result['retry_after']}s",
            headers={"Retry-After": str(rate_result["retry_after"])}
        )

    # Connection limit check
    conn_info = await conn_manager.acquire(client_ip, req.prompt[:100])
    if conn_info is None:
        raise HTTPException(status_code=503, detail="Too many connections")

    if PROMETHEUS_AVAILABLE:
        CONNECTIONS_TOTAL.labels(endpoint="/stream", status="success").inc()

    return StreamingResponse(
        production_stream_generator(req, conn_info),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        }
    )


@prod_app.get("/metrics")
async def metrics_endpoint():
    if not PROMETHEUS_AVAILABLE:
        return {"error": "prometheus-client not installed"}
    return Response(
        content=generate_latest(registry),
        media_type=CONTENT_TYPE_LATEST
    )


@prod_app.get("/health")
async def health_check():
    return {
        "status": "degraded" if shutdown_manager.is_shutting_down else "healthy",
        "active_connections": conn_manager.active_count,
        "circuit_breaker": circuit_breaker.state.value,
        "shutting_down": shutdown_manager.is_shutting_down
    }


print("Production FastAPI app configured.")
print("\nRoutes:")
for route in prod_app.routes:
    if hasattr(route, 'methods'):
        print(f"  {list(route.methods)} {route.path}")

print("\nTo run:")
print("  uvicorn notebook_prod_app:prod_app --workers 4 --port 8001")

13. Load Testing with Locust

Validating Production Readiness

Locust is a Python-based load testing framework that can simulate thousands of concurrent users. For SSE streaming endpoints, standard HTTP benchmarks (wrk, ab) do not work because they expect a complete response – they cannot measure TTFT or handle the event stream protocol. Locust lets you write custom Python code that opens an SSE connection, reads events one by one, and measures TTFT and total generation time.

The load test should validate three SLOs: (1) TTFT p95 under 1 second at target concurrency, (2) error rate below 0.1% under sustained load, (3) no memory leaks over extended runs (monitor RSS via /metrics). Start with 10 users and ramp to your target (e.g., 500) over 60 seconds. Watch for the inflection point where TTFT degrades sharply – this is your practical concurrency limit per instance.

# Save to locustfile.py and run with: locust -f locustfile.py --host http://localhost:8001

LOCUSTFILE = '''
import time
import json
from locust import HttpUser, task, between, events
from locust.exception import RescheduleTask


class StreamingUser(HttpUser):
    """
    Simulates a user making streaming LLM requests.
    Measures TTFT separately from total response time.
    """
    wait_time = between(1, 3)  # Wait 1-3s between requests

    @task(3)
    def stream_short_query(self):
        """Short query -- should have low TTFT."""
        self._do_stream_request(
            prompt="What is RAG?",
            name="/stream [short]"
        )

    @task(1)
    def stream_long_query(self):
        """Long query -- higher TTFT due to retrieval."""
        self._do_stream_request(
            prompt="Explain in detail the differences between RAG and fine-tuning for LLMs, "
                   "including use cases, costs, and tradeoffs.",
            name="/stream [long]"
        )

    def _do_stream_request(self, prompt: str, name: str):
        start = time.perf_counter()
        first_token_time = None
        token_count = 0
        total_bytes = 0

        with self.client.post(
            "/stream",
            json={"prompt": prompt},
            stream=True,
            name=name,
            catch_response=True
        ) as response:
            if response.status_code == 429:
                response.failure(f"Rate limited: {response.text}")
                return
            if response.status_code == 503:
                response.failure(f"Service unavailable: {response.text}")
                raise RescheduleTask()
            if response.status_code != 200:
                response.failure(f"HTTP {response.status_code}")
                return

            for line in response.iter_lines():
                if not line:
                    continue
                if isinstance(line, bytes):
                    line = line.decode("utf-8")
                if not line.startswith("data: "):
                    continue

                total_bytes += len(line)
                try:
                    data = json.loads(line[6:])
                except json.JSONDecodeError:
                    continue

                if data.get("type") == "token":
                    if first_token_time is None:
                        first_token_time = time.perf_counter()
                        ttft_ms = (first_token_time - start) * 1000
                        # Report TTFT as a custom metric
                        events.request.fire(
                            request_type="TTFT",
                            name=name,
                            response_time=ttft_ms,
                            response_length=0,
                            exception=None,
                            context={}
                        )
                    token_count += 1

                elif data.get("type") == "error":
                    response.failure(f"Stream error: {data}")
                    return

            total_ms = (time.perf_counter() - start) * 1000
            if token_count == 0:
                response.failure(f"No tokens received after {total_ms:.0f}ms")
            else:
                response.success()


# Run with:
#   locust -f locustfile.py --host http://localhost:8001 --users 50 --spawn-rate 5
# Or headless:
#   locust -f locustfile.py --host http://localhost:8001 \\
#     --users 50 --spawn-rate 5 --run-time 60s --headless
'''

# Save the locustfile
with open("/tmp/locustfile_streaming.py", "w") as f:
    f.write(LOCUSTFILE)

print("Locust load test file saved to /tmp/locustfile_streaming.py")
print()
print("Run load test with:")
print("  locust -f /tmp/locustfile_streaming.py \\")
print("    --host http://localhost:8001 \\")
print("    --users 50 --spawn-rate 5 --run-time 60s --headless")
print()
print("Key metrics to watch:")
metrics_to_watch = [
    ("TTFT p50",          "< 300ms target"),
    ("TTFT p95",          "< 1000ms target"),
    ("Request failures",  "< 0.1% target"),
    ("Active users",      "Ramp to 50 over 10s"),
    ("RPS",               "Track peak sustainable RPS"),
]
for metric, target in metrics_to_watch:
    print(f"  {metric:<25} {target}")

14. Docker Compose Production Setup

Infrastructure as Code for Streaming Services

A production streaming deployment requires multiple coordinated services: the FastAPI application (the streaming server itself), Nginx (reverse proxy with SSE/WebSocket configuration), Prometheus (metrics collection and alerting), and Grafana (dashboards and visualization). Docker Compose orchestrates these services with proper networking, volume mounts for configuration files, and health checks.

Critical configuration details: the Nginx container must mount the custom SSE configuration (proxy_buffering off, extended timeouts). Prometheus scrapes the /metrics endpoint at 15-second intervals. Grafana provisions pre-built dashboards for streaming metrics (active connections, TTFT heatmap, error rate). The application container should set resource limits (memory, CPU) to prevent a single instance from consuming all host resources during traffic spikes.
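The compose file below mounts ./nginx.conf, but the SSE-specific directives appear only by name in this section. Here is a hedged sketch of that file, following the notebook's convention of embedding configs as printable strings. The upstream name, paths, and timeout values are illustrative assumptions, not settings from a verified deployment.

```python
# Hypothetical nginx.conf sketch for SSE streaming. All names and timeout
# values are assumptions chosen to match the checklist in this notebook.
NGINX_SSE_CONFIG = '''
worker_processes auto;

events { worker_connections 4096; }

http {
    upstream streaming_app {
        ip_hash;              # sticky sessions so SSE reconnects hit the same replica
        server app:8000;
    }

    server {
        listen 80;

        location /stream {
            proxy_pass http://streaming_app;
            proxy_http_version 1.1;
            proxy_set_header Connection "";

            proxy_buffering off;        # critical: never buffer the event stream
            proxy_cache off;
            proxy_read_timeout 300s;    # outlive the longest generation
            proxy_send_timeout 300s;
        }

        location / {
            proxy_pass http://streaming_app;
        }
    }
}
'''

print("=== nginx.conf (sketch) ===")
print(NGINX_SSE_CONFIG)
```

Without `proxy_buffering off`, Nginx accumulates the response before forwarding it, which silently destroys streaming even when the application emits tokens correctly.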

DOCKER_COMPOSE = '''
# docker-compose.yml -- Production streaming setup

version: "3.9"

services:

  # ---- Application (3 replicas) ----
  app:
    build:
      context: .
      dockerfile: Dockerfile
    environment:
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - MAX_CONNECTIONS=200
      - MAX_CONNECTIONS_PER_IP=5
      - RATE_LIMIT_RPM=20
    deploy:
      replicas: 3
      resources:
        limits:
          memory: 1G
          cpus: "1.0"
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 10s
      timeout: 5s
      retries: 3
      start_period: 15s
    restart: unless-stopped
    networks:
      - streaming_net

  # ---- Nginx reverse proxy ----
  nginx:
    image: nginx:1.25-alpine
    ports:
      - "80:80"
      - "443:443"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf:ro
      - ./certs:/etc/nginx/certs:ro
    depends_on:
      - app
    restart: unless-stopped
    networks:
      - streaming_net

  # ---- Prometheus ----
  prometheus:
    image: prom/prometheus:v2.48.0
    ports:
      - "9090:9090"
    volumes:
      - ./prometheus.yml:/etc/prometheus/prometheus.yml:ro
      - ./alerts.yml:/etc/prometheus/alerts.yml:ro  # referenced by rule_files
      - prometheus_data:/prometheus
    command:
      - "--config.file=/etc/prometheus/prometheus.yml"
      - "--storage.tsdb.retention.time=7d"
    restart: unless-stopped
    networks:
      - streaming_net

  # ---- Grafana ----
  grafana:
    image: grafana/grafana:10.2.0
    ports:
      - "3000:3000"
    environment:
      - GF_SECURITY_ADMIN_PASSWORD=admin
      - GF_USERS_ALLOW_SIGN_UP=false
    volumes:
      - grafana_data:/var/lib/grafana
      - ./grafana/dashboards:/var/lib/grafana/dashboards:ro
    depends_on:
      - prometheus
    restart: unless-stopped
    networks:
      - streaming_net

networks:
  streaming_net:
    driver: bridge

volumes:
  prometheus_data:
  grafana_data:
'''

DOCKERFILE = '''
# Dockerfile
FROM python:3.11-slim

WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application
COPY app/ ./app/

# Non-root user for security
RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app
USER appuser

# Expose port
EXPOSE 8000

# Graceful shutdown: uvicorn handles SIGTERM by stopping new connections
# and waiting for in-flight requests before exiting
CMD ["uvicorn", "app.main:app", \
     "--host", "0.0.0.0", \
     "--port", "8000", \
     "--workers", "1", \
     "--timeout-graceful-shutdown", "30", \
     "--access-log", \
     "--log-level", "info"]
'''

PROMETHEUS_CONFIG = '''
# prometheus.yml
global:
  scrape_interval: 15s
  evaluation_interval: 15s

scrape_configs:
  - job_name: streaming_api
    metrics_path: /metrics
    scrape_interval: 10s
    dns_sd_configs:
      # Resolve all app replicas via Docker's internal DNS so every replica
      # is scraped; a static app:8000 target would hit only one per scrape.
      - names: ["app"]
        type: A
        port: 8000

alerting:
  alertmanagers:
    - static_configs:
        - targets: []

rule_files:
  - "alerts.yml"
'''

ALERTS_CONFIG = '''
# alerts.yml -- Prometheus alerting rules
groups:
  - name: streaming_alerts
    rules:

      - alert: HighTTFT
        expr: histogram_quantile(0.95, rate(streaming_ttft_seconds_bucket[5m])) > 2.0
        for: 2m
        labels:
          severity: warning
        annotations:
          summary: "P95 TTFT > 2s"

      - alert: TooManyConnections
        expr: streaming_active_connections > 400
        for: 1m
        labels:
          severity: critical
        annotations:
          summary: "Active connections > 400 (limit 500)"

      - alert: HighErrorRate
        expr: rate(streaming_errors_total[5m]) / rate(streaming_connections_total[5m]) > 0.01
        for: 2m
        labels:
          severity: warning
        annotations:
          summary: "Errors > 1% of connections"

      - alert: CircuitBreakerOpen
        expr: streaming_circuit_breaker_state == 2
        for: 0m
        labels:
          severity: critical
        annotations:
          summary: "Circuit breaker OPEN -- LLM API down"
'''

print("=== docker-compose.yml ===")
print(DOCKER_COMPOSE)
print("=== Dockerfile ===")
print(DOCKERFILE)
print("=== prometheus.yml ===")
print(PROMETHEUS_CONFIG)
print("=== alerts.yml ===")
print(ALERTS_CONFIG)
# Useful Grafana dashboard queries (PromQL)

GRAFANA_QUERIES = {
    "Active Connections": (
        "streaming_active_connections",
        "Current active SSE connections"
    ),
    "TTFT p50 (ms)": (
        "histogram_quantile(0.5, rate(streaming_ttft_seconds_bucket[5m])) * 1000",
        "Median time to first token in ms"
    ),
    "TTFT p95 (ms)": (
        "histogram_quantile(0.95, rate(streaming_ttft_seconds_bucket[5m])) * 1000",
        "95th percentile TTFT in ms"
    ),
    "TTFT p99 (ms)": (
        "histogram_quantile(0.99, rate(streaming_ttft_seconds_bucket[5m])) * 1000",
        "99th percentile TTFT in ms"
    ),
    "Tokens per second (avg)": (
        "histogram_quantile(0.5, rate(streaming_tokens_per_second_bucket[5m]))",
        "Median token throughput"
    ),
    "Error rate (%)": (
        "100 * rate(streaming_errors_total[5m]) / rate(streaming_connections_total[5m])",
        "Percentage of connections with errors"
    ),
    "Connection rate (per min)": (
        "rate(streaming_connections_total[1m]) * 60",
        "New connections per minute"
    ),
    "Circuit breaker state": (
        "streaming_circuit_breaker_state",
        "0=closed, 1=half_open, 2=open"
    ),
}

print("Grafana dashboard PromQL queries:")
print("=" * 70)
for panel_name, (query, description) in GRAFANA_QUERIES.items():
    print(f"\nPanel: {panel_name}")
    print(f"  Query: {query}")
    print(f"  Note:  {description}")

Summary

Production Streaming Checklist

Infrastructure

  • Nginx configured with proxy_buffering off and long proxy_read_timeout

  • Sticky sessions (ip_hash) for SSE

  • Docker health checks on /health endpoint

  • Graceful shutdown with 30s drain timeout

Reliability

  • Connection limits: global (500) and per-IP (10)

  • Rate limiting: token bucket, 20 req/min with burst of 5

  • Backpressure: bounded queue, disconnect slow clients after 5s

  • Timeouts: first-token (30s), inter-token (15s), total (120s)

  • Circuit breaker: open after 5 failures in 60s, reset after 30s

  • Retry with exponential backoff + full jitter

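The rate-limiting bullet above can be made concrete with a token bucket sketch. The 20 requests/minute rate and burst of 5 come from the checklist; the class itself is an illustrative standalone version, not the limiter wired into the server earlier.

```python
import time

class TokenBucket:
    """Token bucket: refills `rate` tokens per second, holds at most `burst`."""

    def __init__(self, rate: float, burst: int):
        self.rate = rate
        self.burst = burst
        self.tokens = float(burst)      # start full so bursts are allowed
        self.last = time.monotonic()

    def allow(self) -> bool:
        now = time.monotonic()
        # Refill proportionally to elapsed time, capped at burst capacity
        self.tokens = min(self.burst, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1.0:
            self.tokens -= 1.0
            return True
        return False   # caller should respond with HTTP 429

# 20 requests/minute with a burst of 5, as in the checklist above
bucket = TokenBucket(rate=20 / 60, burst=5)
results = [bucket.allow() for _ in range(7)]
print(results)  # first 5 allowed (burst), then throttled at ~0.33 tokens/sec
```

In the server, one bucket per user or IP (e.g., in a dict keyed by client address) gives the per-client limits described above; expired buckets should be evicted periodically to bound memory.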
Observability

  • Prometheus metrics: active connections, TTFT histogram, tokens/sec, error rates

  • Grafana dashboards for all key metrics

  • Alerts: TTFT p95 > 2s, connections > 400, error rate > 1%, circuit open

  • Structured logging with connection IDs

Load Testing

  • Locust test measuring TTFT as a custom metric

  • Soak test: 50 users x 60 minutes

  • Spike test: ramp to 200 users in 30s

Performance Targets

| Metric                       | Target      |
|------------------------------|-------------|
| TTFT p50                     | < 300ms     |
| TTFT p95                     | < 1000ms    |
| TTFT p99                     | < 2000ms    |
| Error rate                   | < 0.1%      |
| Max connections per instance | 500         |
| Memory per connection        | < 2MB       |
| Graceful shutdown            | < 30s drain |

Key Takeaways

  1. Fail fast – circuit breakers and rate limits protect the LLM API from overload.

  2. Bound everything – queues, connections, timeouts. Unbounded resources cause outages.

  3. Measure TTFT – it is the most important user-facing metric for streaming.

  4. Drain before shutdown – never kill active streams; let them complete or time out gracefully.

  5. Test under load – Locust reveals backpressure and memory issues that unit tests miss.