WebSocket Connections for Real-Time AI Chat¶

Phase 20 - Notebook 2¶

What you will learn:

  • SSE vs WebSockets: when to use which

  • WebSocket protocol basics and lifecycle

  • FastAPI WebSocket endpoint implementation

  • Bidirectional communication patterns and message envelopes

  • Connection lifecycle management (connect, authenticate, message, disconnect)

  • Heartbeat/ping-pong for connection health

  • Handling multiple concurrent connections

  • Broadcasting to multiple clients (rooms/channels)

  • Authentication with WebSocket headers and first-message auth

  • Client-side WebSocket with Python websockets library

  • Complete demo chat application (server + client code)

  • Error handling and reconnection with exponential backoff

  • Connection pooling patterns

Prerequisites: Notebook 1 (SSE Streaming), Python async/await, FastAPI basics

# Install required packages
!pip install fastapi websockets uvicorn openai python-dotenv httpx -q

import os
import asyncio
import json
import time
import random
import uuid
import logging
from datetime import datetime
from typing import Dict, Set, Optional, Any
from dataclasses import dataclass, asdict
from dotenv import load_dotenv

import openai

load_dotenv()

# Configure logging for connection events
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S"
)
logger = logging.getLogger("ws_demo")

openai_key = os.getenv("OPENAI_API_KEY", "")
print("Environment setup:")
print(f"  OpenAI key: {'set' if openai_key else 'NOT SET'}")
print("\nAll imports successful.")

Part 1: SSE vs WebSockets¶

Side-by-Side Comparison¶

Feature                 Server-Sent Events (SSE)                       WebSockets
----------------------  ---------------------------------------------  -----------------------------------
Direction               Server → Client only                           Full duplex (both ways)
Protocol                HTTP/1.1 or HTTP/2                             Upgraded HTTP → ws://
Browser auto-reconnect  Yes (built-in)                                 No (must implement)
Firewalls/Proxies       Usually passes through                         May be blocked
Data types              Text only                                      Text and binary
Max connections         Limited by browser (6 per domain in HTTP/1.1)  No practical limit
Complexity              Simple                                         Moderate
Use cases               LLM token streaming, logs, notifications       Chat, games, collaboration, trading

Decision Flowchart¶

Does the CLIENT need to send messages after the initial request?
         │
    YES  │  NO
         │   └─► Use SSE (EventSource)
         │
         ▼
Is low latency bidirectional messaging required?
         │
    YES  │  NO
         │   └─► Consider polling or SSE with a separate POST endpoint
         │
         ▼
Use WebSockets
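The flowchart condenses into a small helper function (illustrative only; the name and return strings are invented for this sketch):

```python
def choose_transport(client_sends_after_connect: bool,
                     needs_low_latency_bidi: bool) -> str:
    """Encode the decision flowchart as a function.

    client_sends_after_connect: does the client send messages after the
        initial request?
    needs_low_latency_bidi: is low-latency bidirectional messaging required?
    """
    if not client_sends_after_connect:
        return "SSE"                  # server-push only: EventSource is enough
    if not needs_low_latency_bidi:
        return "SSE + separate POST"  # occasional client messages suffice
    return "WebSockets"               # full duplex, low latency


print(choose_transport(False, False))  # SSE
print(choose_transport(True, True))    # WebSockets
```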

Real-World Examples¶

  • Use SSE: ChatGPT's streaming responses (the server streams tokens; you only click send once)

  • Use WebSockets: Claude's web app (bidirectional with stop button, multi-turn conversation state)

  • Use WebSockets: GitHub Copilot in VS Code (the editor sends context continuously, the server streams code)

Part 2: WebSocket Protocol Basics¶

The HTTP Upgrade Handshake¶

A WebSocket connection starts as a normal HTTP request with special headers:

Client → Server:
  GET /ws HTTP/1.1
  Host: localhost:8001
  Upgrade: websocket                           ← Request upgrade
  Connection: Upgrade
  Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==  ← Random 16-byte value, base64
  Sec-WebSocket-Version: 13

Server → Client:
  HTTP/1.1 101 Switching Protocols        ← Accepted
  Upgrade: websocket
  Connection: Upgrade
  Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=  ← Derived from key

After the 101 response, the TCP connection is kept open and both sides can send at any time.
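The Sec-WebSocket-Accept value is not arbitrary: RFC 6455 defines it as the base64-encoded SHA-1 of the client's key concatenated with a fixed GUID. As a check, the example key/accept pair from RFC 6455 round-trips correctly (compute_accept is a name invented here):

```python
import base64
import hashlib

# Fixed GUID defined by RFC 6455 for the WebSocket handshake
WS_MAGIC_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


def compute_accept(sec_websocket_key: str) -> str:
    """Derive the Sec-WebSocket-Accept header from the client's key."""
    digest = hashlib.sha1((sec_websocket_key + WS_MAGIC_GUID).encode("ascii")).digest()
    return base64.b64encode(digest).decode("ascii")


# The example key from RFC 6455 ("the sample nonce", 16 bytes, base64-encoded)
print(compute_accept("dGhlIHNhbXBsZSBub25jZQ=="))  # s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
```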

Frame Types¶

Opcode  Frame Type  Description
------  ----------  -----------------------------------
0x1     Text        UTF-8 text payload
0x2     Binary      Raw bytes
0x8     Close       Initiate close handshake
0x9     Ping        Health check (either side may send)
0xA     Pong        Health check response

Connection Lifecycle¶

CONNECTING ──► OPEN ──► CLOSING ──► CLOSED
     │           │          │
     │       messages    close(code)
     │       ping/pong
  TCP connect
  HTTP upgrade

Close Codes¶

Code       Meaning
---------  -------------------------------
1000       Normal closure
1001       Going away (server restart)
1008       Policy violation (auth failure)
1011       Server error
4000-4999  Application-defined

# Basic FastAPI WebSocket server - simplest possible echo server
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware

app_basic = FastAPI(title="Basic WebSocket Echo")
app_basic.add_middleware(
    CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
)


@app_basic.websocket("/ws")
async def websocket_echo(websocket: WebSocket):
    """
    Minimal echo server. Accepts a connection and echoes every message back.
    
    Lifecycle:
      1. Client connects β†’ server calls websocket.accept()
      2. Client sends text β†’ server echoes it back
      3. Client disconnects β†’ WebSocketDisconnect is raised
    """
    # STEP 1: Accept the connection (completes the HTTP upgrade handshake)
    await websocket.accept()
    client_host = websocket.client.host if websocket.client else "unknown"
    print(f"[CONNECTED] {client_host}")

    try:
        while True:
            # STEP 2: Wait for a message from the client
            # receive_text() blocks until a text frame arrives
            data = await websocket.receive_text()
            print(f"[RECEIVED] {data!r}")

            # STEP 3: Send a response back
            response = json.dumps({
                "type": "echo",
                "original": data,
                "timestamp": datetime.now().isoformat()
            })
            await websocket.send_text(response)

    except WebSocketDisconnect as e:
        # STEP 4: Handle clean or unclean disconnections
        print(f"[DISCONNECTED] {client_host} (code={e.code})")

    except Exception as e:
        print(f"[ERROR] {type(e).__name__}: {e}")
        await websocket.close(code=1011)  # Server error


@app_basic.get("/health")
async def health():
    return {"status": "ok"}


print("Basic echo WebSocket server defined.")
print()
print("To run: save to ws_echo.py, then:")
print("  uvicorn ws_echo:app_basic --reload --port 8001")
print()
print("Test with Python websockets client (next cell) or:")
print("  wscat -c ws://localhost:8001/ws")

# Python WebSocket client using the 'websockets' library
import websockets
import asyncio
import json


async def simple_ws_client(
    uri: str = "ws://localhost:8001/ws",
    messages: Optional[list] = None
):
    """
    Connect to a WebSocket server, send messages, and receive responses.
    """
    if messages is None:
        messages = ["Hello, WebSocket!", "How are you?", "Goodbye!"]

    print(f"Connecting to {uri}...")

    async with websockets.connect(
        uri,
        ping_interval=20,   # Send ping every 20s to keep connection alive
        ping_timeout=10     # Disconnect if no pong within 10s
    ) as ws:
        print(f"Connected! (state={ws.state.name})")

        for msg in messages:
            # Send message
            await ws.send(msg)
            print(f"  Sent: {msg!r}")

            # Wait for response
            raw = await ws.recv()
            data = json.loads(raw)
            print(f"  Received: type={data.get('type')!r}, ts={data.get('timestamp', '')[:19]}")

            await asyncio.sleep(0.2)

    print("Connection closed cleanly.")


# Usage (requires server running on port 8001)
print("WebSocket client defined.")
print()
print("Usage (requires echo server on port 8001):")
print("  await simple_ws_client('ws://localhost:8001/ws')")
print()
print("# Uncomment to run if server is active:")
# await simple_ws_client("ws://localhost:8001/ws")

Part 3: Bidirectional Communication Patterns¶

Message Envelope Pattern¶

Instead of sending raw strings, wrap every message in a typed envelope:

{
  "type": "chat",
  "data": {"text": "Hello!", "user": "alice"},
  "message_id": "a1b2c3d4",
  "timestamp": "2025-01-01T12:00:00.000Z"
}

Benefits:

  • Extensibility: Add fields without breaking clients

  • Routing: Server/client can route by type without parsing data

  • Debugging: Every message is self-describing

  • Deduplication: Use message_id to detect duplicate delivery

Common Message Types¶

Type              Direction  Purpose
----------------  ---------  -----------------------------
auth              C → S      Send token for authentication
auth_success      S → C      Confirm authentication
chat              C → S      User sends a message
token             S → C      AI streaming token
message_complete  S → C      Full AI response done
system            S → C      Server announcement
ping              S → C      Heartbeat check
pong              C → S      Heartbeat response
error             S → C      Error notification
typing_start      S → C      AI started generating

# Message envelope dataclass with serialization

@dataclass
class WSMessage:
    """
    Structured WebSocket message envelope.
    All WebSocket messages should use this format.
    """
    type: str           # Message category (auth, chat, token, ping, error, ...)
    data: Any           # Payload - can be str, dict, list, or None
    message_id: Optional[str] = None
    timestamp: Optional[str] = None

    def __post_init__(self):
        if self.message_id is None:
            self.message_id = uuid.uuid4().hex[:8]
        if self.timestamp is None:
            self.timestamp = datetime.utcnow().isoformat() + "Z"

    def to_json(self) -> str:
        return json.dumps(asdict(self))

    @classmethod
    def from_json(cls, raw: str) -> "WSMessage":
        try:
            d = json.loads(raw)
            return cls(
                type=d.get("type", "unknown"),
                data=d.get("data"),
                message_id=d.get("message_id"),
                timestamp=d.get("timestamp")
            )
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in WebSocket message: {e}")

    @classmethod
    def chat(cls, text: str, user: str = "user") -> "WSMessage":
        return cls(type="chat", data={"text": text, "user": user})

    @classmethod
    def token(cls, text: str) -> "WSMessage":
        return cls(type="token", data={"text": text})

    @classmethod
    def system(cls, text: str) -> "WSMessage":
        return cls(type="system", data={"text": text})

    @classmethod
    def error(cls, message: str, code: int = 0) -> "WSMessage":
        return cls(type="error", data={"message": message, "code": code})

    @classmethod
    def ping(cls) -> "WSMessage":
        return cls(type="ping", data=None)

    @classmethod
    def pong(cls) -> "WSMessage":
        return cls(type="pong", data=None)


# Demonstrate envelope usage
print("WSMessage examples:\n")

msg1 = WSMessage.chat("Hello, AI!", user="alice")
print(f"Chat message:   {msg1.to_json()}")

msg2 = WSMessage.token("Hello")
print(f"Token message:  {msg2.to_json()}")

msg3 = WSMessage.error("Rate limit exceeded", code=429)
print(f"Error message:  {msg3.to_json()}")

# Round-trip: serialize then deserialize
raw = msg1.to_json()
parsed = WSMessage.from_json(raw)
print(f"\nRound-trip:")
print(f"  Original: type={msg1.type}, data={msg1.data}")
print(f"  Parsed:   type={parsed.type}, data={parsed.data}")
print(f"  IDs match: {msg1.message_id == parsed.message_id}")

Part 4: Connection Lifecycle Management¶

Connection States¶

CONNECTING
    │
    ▼  (websocket.accept() called)
  OPEN  ◄──── All normal messaging happens here
    │
    ▼  (websocket.close() or WebSocketDisconnect raised)
 CLOSING
    │
    ▼  (TCP connection torn down)
  CLOSED

Server-Side Responsibilities Per State¶

State       Server Action
----------  ---------------------------------------------------------------
CONNECTING  Validate origin, check rate limits, call accept() or close()
OPEN        Process messages, send responses, run heartbeat
CLOSING     Save state, clean up resources, remove from connection registry
CLOSED      Remove references (prevent memory leaks)

Common Mistakes¶

  1. Not removing from registry on disconnect → memory leak, stale connections

  2. Not catching WebSocketDisconnect → unhandled exception noise in the logs

  3. Blocking the event loop → all connections freeze (use await asyncio.sleep, not time.sleep)

  4. Not calling accept() → connection never opens

# ConnectionManager: tracks all active connections
from fastapi import WebSocket, WebSocketDisconnect


class ConnectionManager:
    """
    Central registry for all active WebSocket connections.
    Thread-safe for use within a single asyncio event loop.
    """

    def __init__(self):
        # client_id -> WebSocket
        self._connections: Dict[str, WebSocket] = {}
        # client_id -> metadata dict
        self._metadata: Dict[str, dict] = {}
        self._total_connections = 0  # Monotonically increasing counter

    async def connect(self, websocket: WebSocket, client_id: str, **metadata):
        """Accept and register a new connection."""
        await websocket.accept()
        self._connections[client_id] = websocket
        self._metadata[client_id] = {
            "connected_at": datetime.utcnow().isoformat(),
            "message_count": 0,
            **metadata
        }
        self._total_connections += 1
        logger.info(f"[CONNECT] {client_id} | active={self.active_count}")

    def disconnect(self, client_id: str):
        """Remove a connection from the registry (does NOT close the socket)."""
        self._connections.pop(client_id, None)
        self._metadata.pop(client_id, None)
        logger.info(f"[DISCONNECT] {client_id} | active={self.active_count}")

    async def send(self, client_id: str, message: WSMessage) -> bool:
        """
        Send a WSMessage to a specific client.
        Returns True on success, False if client is gone.
        """
        ws = self._connections.get(client_id)
        if not ws:
            return False
        try:
            await ws.send_text(message.to_json())
            if client_id in self._metadata:
                self._metadata[client_id]["message_count"] += 1
            return True
        except Exception as e:
            logger.warning(f"[SEND FAILED] {client_id}: {e}")
            self.disconnect(client_id)
            return False

    async def broadcast(
        self,
        message: WSMessage,
        exclude: Set[str] = None
    ) -> int:
        """
        Send a message to all connected clients.
        Returns the number of clients successfully reached.
        """
        exclude = exclude or set()
        client_ids = list(self._connections.keys())
        successes = 0

        for cid in client_ids:
            if cid not in exclude:
                if await self.send(cid, message):
                    successes += 1

        return successes

    def is_connected(self, client_id: str) -> bool:
        return client_id in self._connections

    @property
    def active_count(self) -> int:
        return len(self._connections)

    def get_stats(self) -> dict:
        return {
            "active_connections": self.active_count,
            "total_ever_connected": self._total_connections,
            "clients": [
                {"id": cid, **meta}
                for cid, meta in self._metadata.items()
            ]
        }


# Demo
manager = ConnectionManager()
print("ConnectionManager defined.")
print(f"Initial stats: {manager.get_stats()}")

Part 5: Heartbeat / Ping-Pong¶

Why Heartbeats Are Necessary¶

TCP connections can become silently dead due to:

  • NAT timeouts: Many routers drop idle connections after 60-300 seconds

  • Mobile network switching: WiFi → 4G → WiFi transitions

  • Proxy/load balancer timeouts: Nginx's default idle timeout is 60s

  • Crashed clients: No close frame is sent, so the server doesn't know

Without heartbeats, the server holds resources for dead connections indefinitely.

Two Levels of Ping/Pong¶

  1. WebSocket protocol ping/pong (opcode 0x9/0xA): Handled transparently by the library. The websockets library sends these automatically with ping_interval.

  2. Application-level ping/pong (JSON messages): Your own heartbeat using the message envelope. Useful for detecting logic-level hangs, not just TCP-level drops.
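The chat endpoint in Part 6 starts a run_heartbeat(websocket, client_id, interval) background task; a minimal application-level sketch of it might look like this (assumes a FastAPI-style async send_text method on the socket):

```python
import asyncio
import json
from datetime import datetime


async def run_heartbeat(websocket, client_id: str, interval: float = 30.0):
    """Send an application-level ping envelope every `interval` seconds.

    The endpoint cancels this task in its `finally` block on disconnect.
    If a send fails, the socket is already dead, so the loop simply stops
    and lets the main handler clean up the registry.
    """
    try:
        while True:
            await asyncio.sleep(interval)
            await websocket.send_text(json.dumps({
                "type": "ping",
                "data": None,
                "timestamp": datetime.utcnow().isoformat() + "Z",
            }))
    except asyncio.CancelledError:
        pass  # normal shutdown path when the client disconnects
    except Exception:
        pass  # send failed: connection already closed
```

A client that receives a {"type": "ping"} envelope should reply with a pong envelope; the chat endpoint below simply ignores incoming pongs.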

Part 6: Multiple Concurrent Connections¶

asyncio Concurrency Model¶

FastAPI's WebSocket handling uses asyncio. Each connection runs in the same event loop via cooperative multitasking:

Event Loop
  │
  ├─ Coroutine: handle_client(alice)  ← awaiting receive_text()
  ├─ Coroutine: handle_client(bob)    ← awaiting receive_text()
  ├─ Coroutine: handle_client(carol)  ← streaming tokens to carol
  └─ Coroutine: heartbeat_task        ← sleeping for 30s

When any coroutine hits an await, Python switches to another coroutine. This allows hundreds of concurrent connections within a single process.
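The interleaving is easy to observe: three coroutines that each "wait for a client" for 0.1s finish in roughly 0.1s total, not 0.3s, because the waits overlap on one event loop:

```python
import asyncio
import time


async def fake_client(name: str, delay: float) -> str:
    # Stands in for `await websocket.receive_text()` - while this coroutine
    # waits, the event loop is free to run the others.
    await asyncio.sleep(delay)
    return name


async def main():
    start = time.perf_counter()
    names = await asyncio.gather(
        fake_client("alice", 0.1),
        fake_client("bob", 0.1),
        fake_client("carol", 0.1),
    )
    return names, time.perf_counter() - start


names, elapsed = asyncio.run(main())
print(names, f"{elapsed:.2f}s")  # the three 0.1s waits overlap
```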

Scaling Beyond One Process¶

For thousands of concurrent connections:

  1. Run multiple Uvicorn workers: uvicorn app:app --workers 4

  2. Use Redis Pub/Sub to broadcast across workers

  3. Use a sticky session load balancer (connections must reach the same worker)

# Full multi-client AI chat FastAPI application
from fastapi import FastAPI, WebSocket, WebSocketDisconnect

app_chat = FastAPI(title="Multi-Client AI Chat")
app_chat.add_middleware(
    CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
)

# Global connection registry
chat_manager = ConnectionManager()


@app_chat.websocket("/ws/chat/{client_id}")
async def chat_endpoint(websocket: WebSocket, client_id: str):
    """
    Multi-client AI chat WebSocket endpoint.
    Each client_id gets its own persistent connection.
    AI responses are streamed token by token over the WebSocket.
    """
    # Register the connection
    await chat_manager.connect(websocket, client_id, username=client_id)

    # Welcome the new client
    await chat_manager.send(
        client_id,
        WSMessage.system(f"Welcome, {client_id}! {chat_manager.active_count} user(s) online.")
    )

    # Start heartbeat in background
    hb_task = asyncio.create_task(
        run_heartbeat(websocket, client_id, interval=30.0)
    )

    # OpenAI async client
    oai = openai.AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    try:
        while True:
            # Wait for a message from this client
            raw = await websocket.receive_text()

            # Parse using our envelope
            try:
                msg = WSMessage.from_json(raw)
            except ValueError as e:
                await chat_manager.send(
                    client_id, WSMessage.error(f"Invalid message format: {e}")
                )
                continue

            # Handle pong (heartbeat response)
            if msg.type == "pong":
                continue

            # Handle chat message
            if msg.type == "chat":
                user_text = msg.data.get("text", "") if isinstance(msg.data, dict) else str(msg.data)

                if not user_text.strip():
                    continue

                # Signal that AI is generating
                await chat_manager.send(
                    client_id,
                    WSMessage(type="typing_start", data={"model": "gpt-4o-mini"})
                )

                # Stream AI response token by token
                full_response_parts = []
                token_count = 0
                stream_start = time.time()

                try:
                    # stream=True makes create() return an async iterator of chunks
                    stream = await oai.chat.completions.create(
                        model="gpt-4o-mini",
                        messages=[{"role": "user", "content": user_text}],
                        max_tokens=300,
                        stream=True
                    )
                    async for chunk in stream:
                        token = chunk.choices[0].delta.content if chunk.choices else None
                        if token:
                            full_response_parts.append(token)
                            token_count += 1
                            # Send each token as a separate WebSocket message
                            sent = await chat_manager.send(
                                client_id,
                                WSMessage.token(token)
                            )
                            if not sent:
                                break  # Client disconnected mid-stream

                except openai.OpenAIError as e:
                    await chat_manager.send(
                        client_id,
                        WSMessage.error(f"AI error: {str(e)[:100]}")
                    )
                    continue

                # Signal completion with full text and stats
                elapsed = time.time() - stream_start
                await chat_manager.send(
                    client_id,
                    WSMessage(
                        type="message_complete",
                        data={
                            "full_text": "".join(full_response_parts),
                            "token_count": token_count,
                            "elapsed_s": round(elapsed, 3),
                            "tps": round(token_count / elapsed, 1) if elapsed > 0 else 0
                        }
                    )
                )

            # Handle unknown message types gracefully
            elif msg.type not in ("ping", "pong", "auth"):
                await chat_manager.send(
                    client_id,
                    WSMessage.error(f"Unknown message type: {msg.type!r}")
                )

    except WebSocketDisconnect as e:
        logger.info(f"[DISCONNECT] {client_id} (code={e.code})")

    except Exception as e:
        logger.error(f"[ERROR] {client_id}: {type(e).__name__}: {e}")

    finally:
        hb_task.cancel()
        chat_manager.disconnect(client_id)


@app_chat.get("/stats")
async def get_stats():
    """REST endpoint to check connection stats."""
    return chat_manager.get_stats()


print("Multi-client AI chat app defined.")
print()
print("To run: save this to ws_chat_app.py then:")
print("  uvicorn ws_chat_app:app_chat --reload --port 8001")
print()
print("Endpoints:")
print("  WS  ws://localhost:8001/ws/chat/{client_id}")
print("  GET http://localhost:8001/stats")

Part 7: Broadcasting to Multiple Clients¶

Room/Channel Pattern¶

Group clients into named rooms. Messages sent to a room are broadcast to all members:

Room "general":  [alice, bob, carol]
Room "python":   [alice, dave]
Room "js":       [bob, evan]

Message to "general" → alice, bob, carol receive it
Message to "python"  → alice, dave receive it

Fan-Out Considerations¶

  • Sequential fan-out: Send to each client one at a time (safe, slower)

  • Concurrent fan-out: Use asyncio.gather (faster, but one failed send must not abort the others)

  • Dead connection cleanup: Remove failed sends from the registry
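A concurrent fan-out can be sketched with asyncio.gather (broadcast_concurrent is a hypothetical helper, not part of RoomManager below; return_exceptions=True turns each failed send into a returned value instead of an exception that cancels the rest):

```python
import asyncio


async def broadcast_concurrent(sockets: dict, payload: str) -> list:
    """Send `payload` to every socket at once; return ids whose send failed.

    `sockets` maps client_id -> object with an async send_text() method
    (FastAPI's WebSocket fits this shape).
    """
    ids = list(sockets)
    results = await asyncio.gather(
        *(sockets[cid].send_text(payload) for cid in ids),
        return_exceptions=True,  # a failed send becomes a result, not a crash
    )
    # Anything that raised marks a dead connection the caller should prune
    return [cid for cid, res in zip(ids, results) if isinstance(res, Exception)]
```

The returned list plays the same role as the `dead` list in the sequential version: the caller disconnects those clients after the fan-out completes.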

# Room-based broadcast manager

class RoomManager:
    """
    Extends ConnectionManager with room/channel support.
    Clients can join multiple rooms and receive broadcasts to those rooms.
    """

    def __init__(self):
        self._connections: Dict[str, WebSocket] = {}
        self._client_rooms: Dict[str, Set[str]] = {}  # client_id -> set of room_ids
        self._room_members: Dict[str, Set[str]] = {}  # room_id -> set of client_ids

    async def connect(self, websocket: WebSocket, client_id: str):
        """Accept and register a connection (not in any room yet)."""
        await websocket.accept()
        self._connections[client_id] = websocket
        self._client_rooms[client_id] = set()
        logger.info(f"[ROOM] {client_id} connected")

    def disconnect(self, client_id: str):
        """Remove client from all rooms and the connection registry."""
        for room_id in list(self._client_rooms.get(client_id, [])):
            self._leave_room_internal(client_id, room_id)
        self._connections.pop(client_id, None)
        self._client_rooms.pop(client_id, None)
        logger.info(f"[ROOM] {client_id} disconnected")

    async def join_room(self, client_id: str, room_id: str):
        """Add a client to a room and notify other room members."""
        if room_id not in self._room_members:
            self._room_members[room_id] = set()

        self._room_members[room_id].add(client_id)
        self._client_rooms[client_id].add(room_id)

        # Notify existing members
        await self.broadcast_to_room(
            room_id,
            WSMessage.system(f"{client_id} joined #{room_id}"),
            exclude={client_id}
        )

        # Tell the new member how many people are in the room
        member_count = len(self._room_members[room_id])
        ws = self._connections.get(client_id)
        if ws:
            msg = WSMessage.system(f"You joined #{room_id} ({member_count} member(s))")
            await ws.send_text(msg.to_json())

        logger.info(f"[ROOM] {client_id} joined #{room_id} (members: {member_count})")

    async def leave_room(self, client_id: str, room_id: str):
        """Remove a client from a room."""
        self._leave_room_internal(client_id, room_id)
        await self.broadcast_to_room(
            room_id,
            WSMessage.system(f"{client_id} left #{room_id}"),
        )

    def _leave_room_internal(self, client_id: str, room_id: str):
        if room_id in self._room_members:
            self._room_members[room_id].discard(client_id)
            if not self._room_members[room_id]:  # Clean up empty rooms
                del self._room_members[room_id]
        if client_id in self._client_rooms:
            self._client_rooms[client_id].discard(room_id)

    async def broadcast_to_room(
        self,
        room_id: str,
        message: WSMessage,
        exclude: Set[str] = None
    ) -> int:
        """
        Send a message to all members of a room.
        Returns the number of clients reached.
        """
        exclude = exclude or set()
        members = list(self._room_members.get(room_id, set()))
        dead = []
        sent_count = 0

        for cid in members:
            if cid in exclude:
                continue
            ws = self._connections.get(cid)
            if not ws:
                dead.append(cid)
                continue
            try:
                await ws.send_text(message.to_json())
                sent_count += 1
            except Exception:
                dead.append(cid)

        # Clean up dead connections
        for cid in dead:
            self.disconnect(cid)

        return sent_count

    def room_info(self, room_id: str) -> dict:
        members = list(self._room_members.get(room_id, set()))
        return {
            "room_id": room_id,
            "member_count": len(members),
            "members": members
        }

    def list_rooms(self) -> list:
        return [
            self.room_info(rid)
            for rid in self._room_members
        ]


room_manager = RoomManager()
print("RoomManager defined.")
print()
print("Usage in FastAPI:")
print("  @app.websocket('/ws/{room_id}/{client_id}')")
print("  async def room_endpoint(ws, room_id, client_id):")
print("      await room_manager.connect(ws, client_id)")
print("      await room_manager.join_room(client_id, room_id)")
print("      ...")

Part 8: Authentication with WebSockets¶

The Challenge¶

HTTP headers can be sent during the WebSocket upgrade handshake, but the browser's WebSocket API (like EventSource) provides no way to set custom headers. This makes standard header-based token authentication a challenge in browsers.

Authentication Options¶

Method                Security                Complexity  Browser Support
--------------------  ----------------------  ----------  ----------------
Token in query param  Low (logged in URLs)    Simple      Full
Cookie-based          Good (HttpOnly cookie)  Medium      Full
First-message auth    Good                    Medium      Full
Header (non-browser)  Good                    Simple      Python/Node only
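A server-side sketch of first-message auth (the token table and names here are invented for illustration; a real implementation would verify a JWT or session token). The endpoint calls this right after accept() and returns early if it yields None:

```python
import asyncio
import json

AUTH_TIMEOUT_S = 5.0
VALID_TOKENS = {"secret-token": "alice"}  # stand-in for a real token store


async def authenticate_first_message(websocket):
    """Require an {"type": "auth"} envelope as the first message after
    accept(); close with 1008 (policy violation) on timeout or bad token.

    Returns the username on success, None on failure.
    """
    try:
        raw = await asyncio.wait_for(websocket.receive_text(), timeout=AUTH_TIMEOUT_S)
        msg = json.loads(raw)
    except (asyncio.TimeoutError, json.JSONDecodeError):
        await websocket.close(code=1008)
        return None

    data = msg.get("data") if isinstance(msg, dict) else None
    token = data.get("token", "") if isinstance(data, dict) else ""
    user = VALID_TOKENS.get(token)
    if not isinstance(msg, dict) or msg.get("type") != "auth" or user is None:
        await websocket.close(code=1008)
        return None

    await websocket.send_text(json.dumps({"type": "auth_success", "data": {"user": user}}))
    return user
```

Closing with 1008 (from the close-code table in Part 2) tells a well-behaved client not to retry with the same credentials.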

Part 9: Error Handling and Reconnection with Exponential Backoff¶

Why Reconnection Logic Is Essential¶

Unlike SSE (which auto-reconnects in browsers), WebSockets require you to implement reconnection logic yourself. Connections drop due to:

  • Network blips (mobile switching WiFi to 4G)

  • Server restarts (deploys, crashes)

  • Proxy/load balancer timeouts

  • Client-side network changes

Exponential Backoff Formula¶

delay = min(base * 2^attempt + jitter, max_delay)

With base=1s, jitter=U(0, 1):
  Attempt 1:  1s  + jitter
  Attempt 2:  2s  + jitter
  Attempt 3:  4s  + jitter
  Attempt 4:  8s  + jitter
  Attempt 5:  16s + jitter
  Attempt 6:  32s + jitter
  Attempt 7+: 60s (capped)

Jitter prevents thundering herd: if 1000 clients all disconnect simultaneously and retry at exactly the same time, they can overwhelm the server. Random jitter spreads out retries.

Circuit Breaker¶

After N consecutive failures, stop retrying and surface an error to the user. This prevents endlessly burning retries against a server that is permanently down.

# Resilient WebSocket client with exponential backoff and circuit breaker
import websockets
import asyncio
import random
from typing import Callable, Awaitable


class CircuitBreaker:
    """
    Simple circuit breaker: opens after max_failures, resets after reset_timeout.
    States: CLOSED (normal) β†’ OPEN (blocking) β†’ HALF_OPEN (testing) β†’ CLOSED
    """

    def __init__(self, max_failures: int = 5, reset_timeout: float = 60.0):
        self.max_failures = max_failures
        self.reset_timeout = reset_timeout
        self.failure_count = 0
        self.last_failure_time: Optional[float] = None
        self.state = "CLOSED"  # CLOSED, OPEN, HALF_OPEN

    def record_success(self):
        self.failure_count = 0
        self.state = "CLOSED"

    def record_failure(self):
        self.failure_count += 1
        self.last_failure_time = time.time()
        if self.failure_count >= self.max_failures:
            self.state = "OPEN"
            logger.warning(f"[CIRCUIT] OPEN after {self.failure_count} failures")

    def is_open(self) -> bool:
        if self.state == "CLOSED":
            return False
        if self.state == "OPEN":
            # Check if reset timeout has passed
            if self.last_failure_time and (time.time() - self.last_failure_time) > self.reset_timeout:
                self.state = "HALF_OPEN"
                logger.info("[CIRCUIT] HALF_OPEN - will try one request")
                return False
            return True
        return False  # HALF_OPEN: allow one through


async def resilient_ws_client(
    uri: str,
    message_handler: Callable,
    max_retries: int = 6,
    base_delay: float = 1.0,
    max_delay: float = 60.0,
    jitter: float = 1.0
):
    """
    WebSocket client with automatic reconnection using exponential backoff.

    Args:
        uri: WebSocket server URI
        message_handler: async function(ws, message_str) -> bool
            Return True to continue, False to stop gracefully.
        max_retries: Give up after this many consecutive failures
        base_delay: Initial backoff delay in seconds
        max_delay: Maximum delay between retries
        jitter: Maximum random jitter to add (prevents thundering herd)
    """
    circuit = CircuitBreaker(max_failures=max_retries)
    attempt = 0

    while attempt <= max_retries:
        if circuit.is_open():
            logger.error(f"[CIRCUIT] Open - not attempting connection")
            break

        try:
            logger.info(f"[CONNECT] Attempt {attempt + 1}/{max_retries + 1} β†’ {uri}")

            async with websockets.connect(
                uri,
                ping_interval=30,
                ping_timeout=10,
                open_timeout=5.0
            ) as ws:
                logger.info(f"[CONNECTED] Successfully connected")
                circuit.record_success()
                attempt = 0  # Reset attempt counter on success

                # Process incoming messages
                async for raw_message in ws:
                    should_continue = await message_handler(ws, raw_message)
                    if not should_continue:
                        logger.info("[STOP] Handler requested graceful shutdown")
                        return  # Clean exit

        except websockets.ConnectionClosedOK:
            logger.info("[CLOSED] Connection closed normally (code 1000)")
            return  # Don't retry on clean close

        except websockets.ConnectionClosedError as e:
            logger.warning(f"[CLOSED ERROR] code={e.code}, reason={e.reason!r}")
            circuit.record_failure()

            # Don't retry on auth errors
            if e.code == 1008:
                logger.error("[AUTH] Policy violation - not retrying")
                break

        except (OSError, websockets.InvalidURI) as e:
            logger.warning(f"[CONNECTION ERROR] {type(e).__name__}: {e}")
            circuit.record_failure()

        except Exception as e:
            logger.error(f"[UNEXPECTED] {type(e).__name__}: {e}")
            circuit.record_failure()

        # Calculate backoff delay
        attempt += 1
        if attempt > max_retries:
            logger.error(f"[FAILED] Exhausted {max_retries} retries. Giving up.")
            break

        delay = min(base_delay * (2 ** (attempt - 1)), max_delay)
        delay += random.uniform(0, jitter)  # Add jitter
        logger.info(f"[RETRY] Waiting {delay:.2f}s before attempt {attempt + 1}...")
        await asyncio.sleep(delay)


# Visualize the backoff schedule
print("Exponential backoff schedule (base=1s, max=60s, jitter=1s):")
print(f"{'Attempt':>8} {'Min delay':>12} {'Max delay':>12}")
print("-" * 35)
base, max_d, jitter = 1.0, 60.0, 1.0
for i in range(1, 8):
    delay = min(base * (2 ** (i - 1)), max_d)
    print(f"{i:>8} {delay:>11.1f}s {delay + jitter:>11.1f}s")
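To make the `message_handler` contract concrete, here is a minimal sketch. `on_message` and `StubWS` are illustrative names (not part of the websockets library); the stub implements only the `send()` coroutine the handler touches, so the contract can be exercised without a live server.

```python
import asyncio
import json

# Hypothetical handler matching the message_handler contract:
# reply to application-level pings, stop after a final envelope.
async def on_message(ws, raw: str) -> bool:
    msg = json.loads(raw)
    if msg.get("type") == "ping":
        await ws.send(json.dumps({"type": "pong", "data": None}))
    return msg.get("type") != "message_complete"

# Test double with just the send() coroutine the handler uses.
class StubWS:
    def __init__(self):
        self.sent = []
    async def send(self, data):
        self.sent.append(data)

async def demo():
    ws = StubWS()
    keep_going = await on_message(ws, json.dumps({"type": "ping"}))
    stop = await on_message(ws, json.dumps({"type": "message_complete"}))
    return keep_going, stop, len(ws.sent)

print(asyncio.run(demo()))  # (True, False, 1)
```

The same handler would then be passed to `resilient_ws_client(uri, on_message)` against a real server.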

Part 10: Connection Pooling

Why Pool WebSocket Connections?

In server-to-server scenarios (e.g., your API server calling a WebSocket-based AI service), creating a new WebSocket connection for each request adds overhead:

  • TCP handshake: ~20-100ms

  • TLS handshake: ~50-200ms

  • WebSocket upgrade: ~10-30ms

Total per-connection overhead: 80-330ms β€” which defeats the purpose of streaming.
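A back-of-envelope calculation shows why pooling pays off. The figures below are the midpoints of the ranges above, used purely for illustration, not measurements:

```python
# Midpoints of the illustrative ranges: TCP 20-100ms, TLS 50-200ms, upgrade 10-30ms
tcp_ms, tls_ms, upgrade_ms = 60, 125, 20
setup_ms = tcp_ms + tls_ms + upgrade_ms      # ~205 ms per new connection

requests = 100
no_pool_ms = requests * setup_ms             # every request pays full setup
pooled_ms = 5 * setup_ms                     # pool of 5, connected once at startup

print(no_pool_ms, pooled_ms)  # 20500 1025
```

With a pool, setup cost is paid once per pooled connection instead of once per request; after warm-up, acquiring a connection is just a queue operation.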

Pool Design

Request 1 ──────► acquire() ──► WS conn A (from pool) ──► send/recv ──► release() ──► back to pool
Request 2 ──────► acquire() ──► WS conn B (from pool) ──► send/recv ──► release() ──► back to pool
Request 3 ──────► acquire() ──► WS conn C (new)       ──► send/recv ──► release() ──► back to pool
Request 4 ──────► acquire() ──► WAIT (pool full)...
# WebSocket connection pool for server-to-server use
import asyncio
import websockets
from contextlib import asynccontextmanager


class WebSocketPool:
    """
    Async pool of WebSocket connections for reuse across requests.
    Useful when your server needs to make many WebSocket calls to another service.

    Usage:
        pool = WebSocketPool("wss://api.example.com/ws", pool_size=5)
        await pool.initialize()

        async with pool.connection() as ws:
            await ws.send(json.dumps({"type": "query", "data": "..."}))  
            response = await ws.recv()
    """

    def __init__(self, uri: str, pool_size: int = 5, connect_timeout: float = 5.0):
        self.uri = uri
        self.pool_size = pool_size
        self.connect_timeout = connect_timeout
        self._pool: asyncio.Queue = asyncio.Queue(maxsize=pool_size)
        self._total_created = 0
        self._initialized = False

    async def initialize(self):
        """Pre-create all connections in the pool."""
        logger.info(f"[POOL] Initializing {self.pool_size} connections to {self.uri}")
        for _ in range(self.pool_size):
            try:
                ws = await asyncio.wait_for(
                    websockets.connect(self.uri, ping_interval=30),
                    timeout=self.connect_timeout
                )
                await self._pool.put(ws)
                self._total_created += 1
            except Exception as e:
                logger.warning(f"[POOL] Failed to pre-create connection: {e}")
        self._initialized = True
        logger.info(f"[POOL] Ready with {self._pool.qsize()}/{self.pool_size} connections")

    async def _create_connection(self) -> websockets.WebSocketClientProtocol:
        """Create a new connection."""
        ws = await asyncio.wait_for(
            websockets.connect(self.uri, ping_interval=30),
            timeout=self.connect_timeout
        )
        self._total_created += 1
        return ws

    async def acquire(self) -> websockets.WebSocketClientProtocol:
        """
        Get a connection from the pool.
        Creates a new one if pool is empty and below pool_size.
        Blocks if pool is at capacity.
        """
        try:
            ws = self._pool.get_nowait()
            # Check if connection is still alive
            if ws.closed:
                logger.debug("[POOL] Dead connection found, creating new one")
                ws = await self._create_connection()
            return ws
        except asyncio.QueueEmpty:
            if self._total_created < self.pool_size:
                return await self._create_connection()
            # Wait for a connection to be released
            logger.debug("[POOL] All connections in use, waiting...")
            ws = await self._pool.get()
            if ws.closed:
                ws = await self._create_connection()
            return ws

    async def release(self, ws: websockets.WebSocketClientProtocol):
        """Return a connection to the pool."""
        if not ws.closed:
            try:
                self._pool.put_nowait(ws)
            except asyncio.QueueFull:
                # Pool is full (shouldn't happen with proper usage)
                await ws.close()
        else:
            # Don't return dead connections to the pool
            self._total_created -= 1

    @asynccontextmanager
    async def connection(self):
        """Context manager for pool connections (auto-acquire and release)."""
        ws = await self.acquire()
        try:
            yield ws
        except Exception:
            # On error, close the connection rather than returning to pool
            await ws.close()
            self._total_created -= 1
            raise
        else:
            await self.release(ws)

    async def close_all(self):
        """Close all connections in the pool."""
        while not self._pool.empty():
            ws = self._pool.get_nowait()
            await ws.close()
        self._total_created = 0
        logger.info("[POOL] All connections closed")

    def stats(self) -> dict:
        return {
            "pool_size": self.pool_size,
            "available": self._pool.qsize(),
            "in_use": self._total_created - self._pool.qsize(),
            "total_created": self._total_created
        }


print("WebSocketPool defined.")
print()
print("Usage example:")
print("""
pool = WebSocketPool("wss://api.example.com/ws", pool_size=5)
await pool.initialize()

# In request handler:
async with pool.connection() as ws:
    await ws.send(json.dumps({"query": "Hello"}))
    response = await ws.recv()
    print(response)
# Connection is automatically returned to the pool

# Cleanup on shutdown:
await pool.close_all()
""")

Part 11: Complete Demo Chat Application

Architecture

┌─────────────────────────────────────────────────────────────────────┐
│                         Chat Application                            │
│                                                                     │
│  Python Terminal Clients                 FastAPI Server             │
│  ┌──────────────────┐                   ┌───────────────────────┐   │
│  │  alice_client.py │ ←── WebSocket ──► │  /ws/{room}/{user}    │   │
│  └──────────────────┘                   │                       │   │
│  ┌──────────────────┐                   │  RoomManager          │   │
│  │  bob_client.py   │ ←── WebSocket ──► │  (broadcast)          │   │
│  └──────────────────┘                   │                       │   │
│                                         │  OpenAI AsyncClient   │   │
│                                         │  (streaming tokens)   │   │
│                                         └───────────────────────┘   │
└─────────────────────────────────────────────────────────────────────┘

Message flow:
  1. alice sends {type: chat, data: {text: "Hello AI"}}
  2. Server broadcasts the message to all room members
  3. Server calls OpenAI API with stream=True
  4. Each token → {type: token, data: {text: "H"}}
  5. Broadcast tokens to all room members in real-time
  6. {type: message_complete} signals end of generation
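The wire format for these envelopes can be built by hand to see what travels over the socket. The `envelope` helper below is illustrative, mirroring the fields of the WSMessage dataclass (type, data, message_id, timestamp):

```python
import json
import uuid
from datetime import datetime, timezone

# Illustrative helper mirroring the WSMessage envelope fields.
def envelope(msg_type: str, data) -> str:
    return json.dumps({
        "type": msg_type,
        "data": data,
        "message_id": uuid.uuid4().hex[:8],
        "timestamp": datetime.now(timezone.utc).isoformat(),
    })

# Wire format for steps 1, 4, and 6 of the flow above:
print(envelope("chat", {"text": "Hello AI"}))
print(envelope("token", {"text": "H"}))
print(envelope("message_complete", {"full_text": "Hi!", "tokens": 2, "tps": 40.0}))
```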
# Save complete chat server to a runnable file

server_code = '''
"""Complete WebSocket-based AI chat server."""
import os, json, asyncio, uuid, time, logging
from datetime import datetime
from typing import Dict, Set, Optional, Any
from dataclasses import dataclass, asdict
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
import openai
from dotenv import load_dotenv

load_dotenv()
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("chat_server")

app = FastAPI(title="AI Chat Server")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
oai = openai.AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))

@dataclass
class WSMessage:
    type: str
    data: Any
    message_id: Optional[str] = None
    timestamp: Optional[str] = None
    def __post_init__(self):
        if not self.message_id: self.message_id = uuid.uuid4().hex[:8]
        if not self.timestamp: self.timestamp = datetime.utcnow().isoformat() + "Z"
    def to_json(self): return json.dumps(asdict(self))
    @classmethod
    def from_json(cls, raw):
        d = json.loads(raw)
        return cls(type=d["type"], data=d.get("data"), message_id=d.get("message_id"), timestamp=d.get("timestamp"))

class RoomManager:
    def __init__(self):
        self.connections: Dict[str, WebSocket] = {}
        self.rooms: Dict[str, Set[str]] = {}

    async def connect(self, ws: WebSocket, client_id: str, room_id: str):
        await ws.accept()
        self.connections[client_id] = ws
        self.rooms.setdefault(room_id, set()).add(client_id)
        await self.broadcast(room_id, WSMessage("system", f"{client_id} joined #{room_id}"), exclude={client_id})
        logger.info(f"[JOIN] {client_id} -> #{room_id} (members: {len(self.rooms[room_id])})")

    def disconnect(self, client_id: str, room_id: str):
        self.connections.pop(client_id, None)
        if room_id in self.rooms:
            self.rooms[room_id].discard(client_id)

    async def broadcast(self, room_id: str, msg: WSMessage, exclude: Set[str] = None):
        exclude = exclude or set()
        dead = []
        for cid in list(self.rooms.get(room_id, [])):
            if cid in exclude: continue
            ws = self.connections.get(cid)
            if ws:
                try: await ws.send_text(msg.to_json())
                except Exception: dead.append(cid)
        for cid in dead: self.disconnect(cid, room_id)

rooms = RoomManager()

@app.websocket("/ws/{room_id}/{client_id}")
async def chat_endpoint(ws: WebSocket, room_id: str, client_id: str):
    await rooms.connect(ws, client_id, room_id)
    try:
        while True:
            raw = await ws.receive_text()
            msg = WSMessage.from_json(raw)
            if msg.type == "pong": continue
            if msg.type != "chat": continue
            text = msg.data.get("text", "") if isinstance(msg.data, dict) else str(msg.data)
            if not text.strip(): continue

            # Echo user message to room
            await rooms.broadcast(room_id, WSMessage("chat", {"from": client_id, "text": text}))

            # Stream AI response
            await rooms.broadcast(room_id, WSMessage("typing_start", {"user": "AI"}))
            parts, count, t0 = [], 0, time.time()
            stream = await oai.chat.completions.create(
                model="gpt-4o-mini", messages=[{"role": "user", "content": text}],
                max_tokens=300, stream=True
            )
            async for chunk in stream:
                tok = chunk.choices[0].delta.content
                if tok:
                    parts.append(tok); count += 1
                    await rooms.broadcast(room_id, WSMessage("token", {"text": tok}))

            elapsed = time.time() - t0
            await rooms.broadcast(room_id, WSMessage("message_complete", {
                "full_text": "".join(parts),
                "tokens": count,
                "tps": round(count/elapsed, 1) if elapsed else 0
            }))

    except WebSocketDisconnect:
        rooms.disconnect(client_id, room_id)
        await rooms.broadcast(room_id, WSMessage("system", f"{client_id} left #{room_id}"))

@app.get("/rooms")
def list_rooms():
    return [{"room": r, "members": list(m)} for r, m in rooms.rooms.items()]

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)
'''

server_path = "/tmp/ws_chat_server.py"
with open(server_path, "w") as f:
    f.write(server_code)

print(f"Server written to {server_path}")
print("Start: python /tmp/ws_chat_server.py")
# Save complete chat client to a runnable file

client_code = '''
"""Interactive WebSocket chat client."""
import asyncio, json, sys, uuid
from datetime import datetime
try:
    import websockets
except ImportError:
    print("Install: pip install websockets")
    sys.exit(1)

SERVER_URL = "ws://localhost:8001"

async def chat_client(room_id: str, username: str):
    uri = f"{SERVER_URL}/ws/{room_id}/{username}"
    print(f"Connecting to #{room_id} as {username}...")
    print("Type messages and press Enter. Type /quit to exit.\\n")

    async with websockets.connect(uri, ping_interval=30, ping_timeout=10) as ws:
        print(f"Connected!")

        async def receive_loop():
            """Background task: continuously receive and display messages."""
            ai_streaming = False
            async for raw in ws:
                try:
                    data = json.loads(raw)
                except json.JSONDecodeError:
                    continue

                t = data.get("type")

                if t == "system":
                    print(f"\\n[System] {data.get(\'data\')}")
                elif t == "chat":
                    d = data.get("data", {})
                    if d.get("from") != username:
                        print(f"\\n[{d.get(\'from\')}]: {d.get(\'text\')}")
                elif t == "typing_start":
                    print("\\nAI: ", end="", flush=True)
                    ai_streaming = True
                elif t == "token":
                    print(data["data"]["text"], end="", flush=True)
                elif t == "message_complete":
                    d = data.get("data", {})
                    print(f"\\n  [{d.get(\'tokens\')} tokens @ {d.get(\'tps\')} TPS]")
                    ai_streaming = False
                elif t == "error":
                    print(f"\\n[ERROR] {data.get(\'data\')}")
                elif t == "ping":
                    await ws.send(json.dumps({"type": "pong", "data": None}))

        recv_task = asyncio.create_task(receive_loop())

        loop = asyncio.get_running_loop()
        try:
            while True:
                line = await loop.run_in_executor(None, input, f"[{username}]: ")
                if line.strip().lower() in ("/quit", "/exit", "quit", "exit"):
                    break
                if not line.strip():
                    continue
                payload = json.dumps({"type": "chat", "data": {"text": line}})
                await ws.send(payload)
        except (KeyboardInterrupt, EOFError):
            pass
        finally:
            recv_task.cancel()
            print("\\nDisconnecting...")

if __name__ == "__main__":
    room = sys.argv[1] if len(sys.argv) > 1 else "general"
    user = sys.argv[2] if len(sys.argv) > 2 else f"user_{uuid.uuid4().hex[:4]}"
    asyncio.run(chat_client(room, user))
'''

client_path = "/tmp/ws_chat_client.py"
with open(client_path, "w") as f:
    f.write(client_code)

print(f"Client written to {client_path}")
print()
print("To run a multi-user demo:")
print()
print("  Terminal 1 (server):")
print("    python /tmp/ws_chat_server.py")
print()
print("  Terminal 2 (alice):")
print("    python /tmp/ws_chat_client.py general alice")
print()
print("  Terminal 3 (bob):")
print("    python /tmp/ws_chat_client.py general bob")
print()
print("Both alice and bob will see the AI response streamed in real-time!")
# Quick self-test: verify all classes and functions are importable and usable

print("Self-test: verifying all components...\n")

# 1. WSMessage round-trip
msg = WSMessage.chat("Hello!", user="test")
parsed = WSMessage.from_json(msg.to_json())
assert parsed.type == "chat"
assert parsed.data["text"] == "Hello!"
print("[OK] WSMessage serialization/deserialization")

# 2. ConnectionManager instantiation
cm = ConnectionManager()
assert cm.active_count == 0
stats = cm.get_stats()
assert stats["active_connections"] == 0
print("[OK] ConnectionManager instantiation")

# 3. RoomManager instantiation
rm = RoomManager()
assert rm.list_rooms() == []
print("[OK] RoomManager instantiation")

# 4. CircuitBreaker logic
cb = CircuitBreaker(max_failures=3, reset_timeout=60.0)
assert cb.state == "CLOSED"
assert not cb.is_open()
cb.record_failure()
cb.record_failure()
cb.record_failure()
assert cb.state == "OPEN"
assert cb.is_open()
cb.record_success()
assert cb.state == "CLOSED"
print("[OK] CircuitBreaker state transitions")

# 5. WebSocketPool instantiation
pool = WebSocketPool("ws://localhost:9999", pool_size=3)
s = pool.stats()
assert s["pool_size"] == 3
assert s["available"] == 0
print("[OK] WebSocketPool instantiation")

# 6. Verify server and client files were created
import os
assert os.path.exists("/tmp/ws_chat_server.py")
assert os.path.exists("/tmp/ws_chat_client.py")
print("[OK] Server and client scripts written to /tmp/")

# 7. Verify VALID_TOKENS dictionary
assert len(VALID_TOKENS) == 3
assert verify_token("token_alice_secret_123") is not None
assert verify_token("bad_token") is None
print("[OK] Authentication token verification")

# 8. Backoff schedule sanity check
delays = [min(1.0 * (2 ** i), 60.0) for i in range(7)]
assert delays[0] == 1.0
assert delays[5] == 32.0
assert delays[6] == 60.0
print("[OK] Exponential backoff schedule")

print("\nAll self-tests passed!")

Summary

What We Covered

  1. SSE vs WebSockets: Use SSE for server-to-client streaming (LLM tokens, logs). Use WebSockets for bidirectional communication (chat, collaboration, games).

  2. WebSocket protocol: HTTP upgrade handshake, frame types (text/binary/ping/pong/close), connection states (CONNECTING/OPEN/CLOSING/CLOSED), close codes.

  3. FastAPI WebSocket: @app.websocket(), await websocket.accept(), receive_text() / send_text() / send_json(), WebSocketDisconnect.

  4. Message envelope pattern: Always use typed JSON envelopes. Include type, data, message_id, timestamp in every message.

  5. ConnectionManager: Central registry for all active connections. Clean up on disconnect to prevent memory leaks.

  6. Heartbeat: Application-level ping/pong to detect dead connections. Run as a background asyncio.Task, cancel on disconnect.

  7. Multiple concurrent connections: asyncio handles hundreds of connections cooperatively. Never block the event loop with synchronous calls.

  8. RoomManager: Group clients into named rooms for targeted broadcasts. Clean up empty rooms and dead connections on every broadcast.

  9. Authentication: First-message auth is the safest browser-compatible approach. Validate within AUTH_TIMEOUT_SECONDS or close with code 1008.

  10. Exponential backoff: delay = min(base * 2^(attempt-1), max_delay) + jitter. Jitter prevents thundering herd. Circuit breaker prevents endless retries.

  11. Connection pooling: Reuse connections in server-to-server scenarios. Use context manager pattern (async with pool.connection() as ws:) for safety.

Decision Guide

What do you need?
├── Stream AI tokens to browser?
│     └── Use SSE (Notebook 1) - simpler, auto-reconnect
│
├── Interactive AI chat with stop/modify?
│     └── Use WebSockets - bidirectional control
│
├── Multi-user real-time chat?
│     └── WebSockets + RoomManager + broadcast
│
├── Server calling another WebSocket API many times?
│     └── WebSocketPool - reuse connections
│
└── Mobile/unreliable network client?
      └── WebSockets + resilient_ws_client with backoff

Security Checklist

  • Authenticate before allowing any application messages

  • Rate limit messages per connection (not just per IP)

  • Validate and sanitize all message data fields

  • Set max_size on websockets.connect() to prevent memory exhaustion

  • Use wss:// (TLS) in production, never ws://

  • Close with appropriate codes on auth failure (1008)

Quick Reference

# Server
@app.websocket("/ws/{client_id}")
async def endpoint(ws: WebSocket, client_id: str):
    await ws.accept()
    try:
        while True:
            raw = await ws.receive_text()
            msg = WSMessage.from_json(raw)
            await ws.send_text(WSMessage.system("Got it").to_json())
    except WebSocketDisconnect:
        pass  # Clean up here

# Client
async with websockets.connect("ws://localhost:8001/ws/alice") as ws:
    await ws.send(WSMessage.chat("Hello AI").to_json())
    response = WSMessage.from_json(await ws.recv())
    print(response.data)

# Authenticated endpoint pattern
user = await authenticate_websocket(websocket)  # First-message auth
if not user:
    return  # Already closed with 1008

Files created during this notebook:

  • /tmp/ws_chat_server.py β€” Full room-based AI chat server

  • /tmp/ws_chat_client.py β€” Interactive terminal client

Next: Notebook 3 covers chunked generation and progressive rendering patterns.