WebSocket Connections for Real-Time AI Chat
Phase 20 - Notebook 2
What you will learn:
SSE vs WebSockets: when to use which
WebSocket protocol basics and lifecycle
FastAPI WebSocket endpoint implementation
Bidirectional communication patterns and message envelopes
Connection lifecycle management (connect, authenticate, message, disconnect)
Heartbeat/ping-pong for connection health
Handling multiple concurrent connections
Broadcasting to multiple clients (rooms/channels)
Authentication with WebSocket headers and first-message auth
Client-side WebSocket with Python's websockets library
Complete demo chat application (server + client code)
Error handling and reconnection with exponential backoff
Connection pooling patterns
Prerequisites: Notebook 1 (SSE Streaming), Python async/await, FastAPI basics
# Install required packages
!pip install fastapi websockets uvicorn openai python-dotenv httpx -q
import os
import asyncio
import json
import time
import random
import uuid
import logging
from datetime import datetime
from typing import Dict, Set, Optional, Any
from dataclasses import dataclass, asdict
from dotenv import load_dotenv
import openai
load_dotenv()
# Configure logging for connection events
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%H:%M:%S"
)
logger = logging.getLogger("ws_demo")
openai_key = os.getenv("OPENAI_API_KEY", "")
print("Environment setup:")
print(f" OpenAI key: {'set' if openai_key else 'NOT SET'}")
print("\nAll imports successful.")
Part 1: SSE vs WebSockets
Side-by-Side Comparison

| Feature | Server-Sent Events (SSE) | WebSockets |
|---|---|---|
| Direction | Server → Client only | Full duplex (both ways) |
| Protocol | HTTP/1.1 or HTTP/2 | Upgraded HTTP → ws:// |
| Browser auto-reconnect | Yes (built-in) | No (must implement) |
| Firewalls/Proxies | Usually passes through | May be blocked |
| Data types | Text only | Text and binary |
| Max connections | Limited by browser (6 per domain in HTTP/1.1) | Not bound by the 6-connection rule (browsers still cap totals) |
| Complexity | Simple | Moderate |
| Use cases | LLM token streaming, logs, notifications | Chat, games, collaboration, trading |
Decision Flowchart

Does the CLIENT need to send messages after the initial request?
         │
   YES ──┴── NO
    │         └──► Use SSE (EventSource)
    ▼
Is low-latency bidirectional messaging required?
         │
   YES ──┴── NO
    │         └──► Consider polling or SSE with a separate POST endpoint
    ▼
Use WebSockets
Real-World Examples

Use SSE: ChatGPT's streaming response (server streams tokens, you only click send once)
Use WebSockets: Claude's web app (bidirectional with stop button, multi-turn conversation state)
Use WebSockets: GitHub Copilot in VS Code (editor sends context continuously, server streams code)
Part 2: WebSocket Protocol Basics
The HTTP Upgrade Handshake

A WebSocket connection starts as a normal HTTP request with special headers:
Client → Server:
GET /ws HTTP/1.1
Host: localhost:8001
Upgrade: websocket                               ← Request upgrade
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==      ← Random 16-byte value, base64-encoded
Sec-WebSocket-Version: 13
Server → Client:
HTTP/1.1 101 Switching Protocols                 ← Accepted
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=   ← Derived from key
After the 101 response, the TCP connection is kept open and both sides can send at any time.
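The Sec-WebSocket-Accept value is deterministic: SHA-1 of the client's key concatenated with a fixed GUID, then base64-encoded (RFC 6455). A few lines of Python reproduce the RFC's example handshake:

```python
import base64
import hashlib

# Fixed GUID defined by RFC 6455 - identical for every WebSocket server
WS_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

def ws_accept(sec_websocket_key: str) -> str:
    """Derive the Sec-WebSocket-Accept header from the client's key."""
    digest = hashlib.sha1((sec_websocket_key + WS_GUID).encode("ascii")).digest()
    return base64.b64encode(digest).decode("ascii")

# The RFC 6455 example key produces the accept value shown in the handshake
print(ws_accept("dGhlIHNhbXBsZSBub25jZQ=="))  # s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
```

The server never stores the key; it only proves it actually speaks the WebSocket protocol rather than being a naive HTTP cache or proxy.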
Frame Types

| Opcode | Frame Type | Description |
|---|---|---|
| 0x1 | Text | UTF-8 text payload |
| 0x2 | Binary | Raw bytes |
| 0x8 | Close | Initiate close handshake |
| 0x9 | Ping | Health check (either side may send) |
| 0xA | Pong | Health check response |
Connection Lifecycle

CONNECTING ──► OPEN ──► CLOSING ──► CLOSED
     │           │          │
TCP connect   messages   close(code)
HTTP upgrade  ping/pong
Close Codes

| Code | Meaning |
|---|---|
| 1000 | Normal closure |
| 1001 | Going away (server restart) |
| 1008 | Policy violation (auth failure) |
| 1011 | Server error |
| 4000-4999 | Application-defined |
# Basic FastAPI WebSocket server - simplest possible echo server
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
app_basic = FastAPI(title="Basic WebSocket Echo")
app_basic.add_middleware(
CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
)
@app_basic.websocket("/ws")
async def websocket_echo(websocket: WebSocket):
"""
Minimal echo server. Accepts a connection and echoes every message back.
Lifecycle:
1. Client connects → server calls websocket.accept()
2. Client sends text → server echoes it back
3. Client disconnects → WebSocketDisconnect is raised
"""
# STEP 1: Accept the connection (completes the HTTP upgrade handshake)
await websocket.accept()
client_host = websocket.client.host if websocket.client else "unknown"
print(f"[CONNECTED] {client_host}")
try:
while True:
# STEP 2: Wait for a message from the client
# receive_text() blocks until a text frame arrives
data = await websocket.receive_text()
print(f"[RECEIVED] {data!r}")
# STEP 3: Send a response back
response = json.dumps({
"type": "echo",
"original": data,
"timestamp": datetime.now().isoformat()
})
await websocket.send_text(response)
except WebSocketDisconnect as e:
# STEP 4: Handle clean or unclean disconnections
print(f"[DISCONNECTED] {client_host} (code={e.code})")
except Exception as e:
print(f"[ERROR] {type(e).__name__}: {e}")
await websocket.close(code=1011) # Server error
@app_basic.get("/health")
async def health():
return {"status": "ok"}
print("Basic echo WebSocket server defined.")
print()
print("To run: save to ws_echo.py, then:")
print(" uvicorn ws_echo:app_basic --reload --port 8001")
print()
print("Test with Python websockets client (next cell) or:")
print(" wscat -c ws://localhost:8001/ws")
# Python WebSocket client using the 'websockets' library
import websockets
import asyncio
import json
async def simple_ws_client(
uri: str = "ws://localhost:8001/ws",
messages: Optional[list] = None
):
"""
Connect to a WebSocket server, send messages, and receive responses.
"""
if messages is None:
messages = ["Hello, WebSocket!", "How are you?", "Goodbye!"]
print(f"Connecting to {uri}...")
async with websockets.connect(
uri,
ping_interval=20, # Send ping every 20s to keep connection alive
ping_timeout=10 # Disconnect if no pong within 10s
) as ws:
print(f"Connected! (state={ws.state.name})")
for msg in messages:
# Send message
await ws.send(msg)
print(f" Sent: {msg!r}")
# Wait for response
raw = await ws.recv()
data = json.loads(raw)
print(f" Received: type={data.get('type')!r}, ts={data.get('timestamp', '')[:19]}")
await asyncio.sleep(0.2)
print("Connection closed cleanly.")
# Usage (requires server running on port 8001)
print("WebSocket client defined.")
print()
print("Usage (requires echo server on port 8001):")
print(" await simple_ws_client('ws://localhost:8001/ws')")
print()
print("# Uncomment to run if server is active:")
# await simple_ws_client("ws://localhost:8001/ws")
Part 3: Bidirectional Communication Patterns
Message Envelope Pattern

Instead of sending raw strings, wrap every message in a typed envelope:
{
"type": "chat",
"data": {"text": "Hello!", "user": "alice"},
"message_id": "a1b2c3d4",
"timestamp": "2025-01-01T12:00:00.000Z"
}
Benefits:
Extensibility: Add fields without breaking clients
Routing: Server/client can route by type without parsing data
Debugging: Every message is self-describing
Deduplication: Use message_id to detect duplicate delivery
Common Message Types

| Type | Direction | Purpose |
|---|---|---|
| auth | C → S | Send token for authentication |
| auth_success | S → C | Confirm authentication |
| chat | C → S | User sends a message |
| token | S → C | AI streaming token |
| message_complete | S → C | Full AI response done |
| system | S → C | Server announcement |
| ping | S → C | Heartbeat check |
| pong | C → S | Heartbeat response |
| error | S → C | Error notification |
| typing_start | S → C | AI started generating |
# Message envelope dataclass with serialization
@dataclass
class WSMessage:
"""
Structured WebSocket message envelope.
All WebSocket messages should use this format.
"""
type: str # Message category (auth, chat, token, ping, error, ...)
data: Any # Payload - can be str, dict, list, or None
message_id: Optional[str] = None
timestamp: Optional[str] = None
def __post_init__(self):
if self.message_id is None:
self.message_id = uuid.uuid4().hex[:8]
if self.timestamp is None:
self.timestamp = datetime.utcnow().isoformat() + "Z"
def to_json(self) -> str:
return json.dumps(asdict(self))
@classmethod
def from_json(cls, raw: str) -> "WSMessage":
try:
d = json.loads(raw)
return cls(
type=d.get("type", "unknown"),
data=d.get("data"),
message_id=d.get("message_id"),
timestamp=d.get("timestamp")
)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in WebSocket message: {e}")
@classmethod
def chat(cls, text: str, user: str = "user") -> "WSMessage":
return cls(type="chat", data={"text": text, "user": user})
@classmethod
def token(cls, text: str) -> "WSMessage":
return cls(type="token", data={"text": text})
@classmethod
def system(cls, text: str) -> "WSMessage":
return cls(type="system", data={"text": text})
@classmethod
def error(cls, message: str, code: int = 0) -> "WSMessage":
return cls(type="error", data={"message": message, "code": code})
@classmethod
def ping(cls) -> "WSMessage":
return cls(type="ping", data=None)
@classmethod
def pong(cls) -> "WSMessage":
return cls(type="pong", data=None)
# Demonstrate envelope usage
print("WSMessage examples:\n")
msg1 = WSMessage.chat("Hello, AI!", user="alice")
print(f"Chat message: {msg1.to_json()}")
msg2 = WSMessage.token("Hello")
print(f"Token message: {msg2.to_json()}")
msg3 = WSMessage.error("Rate limit exceeded", code=429)
print(f"Error message: {msg3.to_json()}")
# Round-trip: serialize then deserialize
raw = msg1.to_json()
parsed = WSMessage.from_json(raw)
print(f"\nRound-trip:")
print(f" Original: type={msg1.type}, data={msg1.data}")
print(f" Parsed: type={parsed.type}, data={parsed.data}")
print(f" IDs match: {msg1.message_id == parsed.message_id}")
Part 4: Connection Lifecycle Management
Connection States

CONNECTING
    │
    ▼ (websocket.accept() called)
OPEN ◄──── All normal messaging happens here
    │
    ▼ (websocket.close() or WebSocketDisconnect raised)
CLOSING
    │
    ▼ (TCP connection torn down)
CLOSED
Server-Side Responsibilities Per State

| State | Server Action |
|---|---|
| CONNECTING | Validate origin, check rate limits, call accept() |
| OPEN | Process messages, send responses, run heartbeat |
| CLOSING | Save state, clean up resources, remove from connection registry |
| CLOSED | Remove references (prevent memory leaks) |
Common Mistakes

Not removing from registry on disconnect → memory leak, stale connections
Not catching WebSocketDisconnect → unhandled exceptions fill the logs with noise
Blocking the event loop → all connections freeze (use await asyncio.sleep, not time.sleep)
Not calling accept() → connection never opens
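The blocking mistake is easy to demonstrate without any sockets. In this sketch (handler names are made up), three handlers that await overlap, while three that call time.sleep run one after another because each one freezes the event loop:

```python
import asyncio
import time

async def good_handler():
    await asyncio.sleep(0.2)   # yields control - other coroutines keep running

async def blocking_handler():
    time.sleep(0.2)            # WRONG: freezes the entire event loop

async def compare():
    start = time.perf_counter()
    await asyncio.gather(good_handler(), good_handler(), good_handler())
    concurrent = time.perf_counter() - start

    start = time.perf_counter()
    await asyncio.gather(blocking_handler(), blocking_handler(), blocking_handler())
    serialized = time.perf_counter() - start
    return concurrent, serialized

concurrent, serialized = asyncio.run(compare())
print(f"await asyncio.sleep: {concurrent:.2f}s")   # ~0.2s (overlapped)
print(f"time.sleep:          {serialized:.2f}s")   # ~0.6s (one at a time)
```

In a real server, one handler calling a blocking function stalls every connected client for that duration.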
# ConnectionManager: tracks all active connections
from fastapi import WebSocket, WebSocketDisconnect
class ConnectionManager:
"""
Central registry for all active WebSocket connections.
Thread-safe for use within a single asyncio event loop.
"""
def __init__(self):
# client_id -> WebSocket
self._connections: Dict[str, WebSocket] = {}
# client_id -> metadata dict
self._metadata: Dict[str, dict] = {}
self._total_connections = 0 # Monotonically increasing counter
async def connect(self, websocket: WebSocket, client_id: str, **metadata):
"""Accept and register a new connection."""
await websocket.accept()
self._connections[client_id] = websocket
self._metadata[client_id] = {
"connected_at": datetime.utcnow().isoformat(),
"message_count": 0,
**metadata
}
self._total_connections += 1
logger.info(f"[CONNECT] {client_id} | active={self.active_count}")
def disconnect(self, client_id: str):
"""Remove a connection from the registry (does NOT close the socket)."""
self._connections.pop(client_id, None)
self._metadata.pop(client_id, None)
logger.info(f"[DISCONNECT] {client_id} | active={self.active_count}")
async def send(self, client_id: str, message: WSMessage) -> bool:
"""
Send a WSMessage to a specific client.
Returns True on success, False if client is gone.
"""
ws = self._connections.get(client_id)
if not ws:
return False
try:
await ws.send_text(message.to_json())
if client_id in self._metadata:
self._metadata[client_id]["message_count"] += 1
return True
except Exception as e:
logger.warning(f"[SEND FAILED] {client_id}: {e}")
self.disconnect(client_id)
return False
async def broadcast(
self,
message: WSMessage,
exclude: Set[str] = None
) -> int:
"""
Send a message to all connected clients.
Returns the number of clients successfully reached.
"""
exclude = exclude or set()
client_ids = list(self._connections.keys())
successes = 0
for cid in client_ids:
if cid not in exclude:
if await self.send(cid, message):
successes += 1
return successes
def is_connected(self, client_id: str) -> bool:
return client_id in self._connections
@property
def active_count(self) -> int:
return len(self._connections)
def get_stats(self) -> dict:
return {
"active_connections": self.active_count,
"total_ever_connected": self._total_connections,
"clients": [
{"id": cid, **meta}
for cid, meta in self._metadata.items()
]
}
# Demo
manager = ConnectionManager()
print("ConnectionManager defined.")
print(f"Initial stats: {manager.get_stats()}")
Part 5: Heartbeat / Ping-Pong
Why Heartbeats Are Necessary

TCP connections can become silently dead due to:
NAT timeouts: Many routers drop idle connections after 60-300 seconds
Mobile network switching: WiFi → 4G → WiFi transitions
Proxy/load balancer timeouts: Nginx default is 60s idle timeout
Crashed clients: No close frame sent, so the server doesn't know
Without heartbeats, the server holds resources for dead connections indefinitely.
Two Levels of Ping/Pong

1. WebSocket protocol ping/pong (opcodes 0x9/0xA): Handled transparently by the library. The websockets library sends these automatically with ping_interval.
2. Application-level ping/pong (JSON messages): Your own heartbeat using the message envelope. Useful for detecting logic-level hangs, not just TCP-level drops.
Recommended Configuration
# Server-side (protocol level)
@app.websocket("/ws")
async def endpoint(ws: WebSocket):
# FastAPI/Starlette handles protocol ping automatically
...
# Client-side (protocol level via websockets library)
async with websockets.connect(
uri,
ping_interval=30, # Ping every 30s
ping_timeout=10 # Give up if no pong in 10s
) as ws:
...
# Application-level heartbeat implementation
async def run_heartbeat(
websocket: WebSocket,
client_id: str,
interval: float = 30.0,
timeout: float = 10.0
):
"""
Send periodic application-level pings.
Disconnect if pong is not received within timeout.
This runs as a background asyncio task alongside the message handler.
Cancel this task when the connection closes.
"""
missed_pongs = 0
max_missed = 2 # Disconnect after 2 missed pongs
while True:
await asyncio.sleep(interval)
ping = WSMessage.ping()
try:
await websocket.send_text(ping.to_json())
logger.debug(f"[PING] → {client_id}")
except Exception:
logger.warning(f"[PING FAILED] Cannot reach {client_id}")
break
# Wait for pong response
# (In practice, the main message loop should update a 'last_pong' timestamp
# and we check that timestamp here. This simplified version just counts.)
# A full implementation would use an asyncio.Event:
# pong_received = asyncio.Event()
# try: await asyncio.wait_for(pong_received.wait(), timeout)
# except asyncio.TimeoutError: missed_pongs += 1
# For demo purposes we just log that we sent a ping
logger.debug(f"[HEARTBEAT] Ping sent to {client_id}")
async def handle_connection_with_heartbeat(
websocket: WebSocket,
client_id: str
):
"""
WebSocket handler that runs heartbeat as a concurrent background task.
The heartbeat task is cancelled when the main loop exits.
"""
# Start heartbeat as background task
heartbeat_task = asyncio.create_task(
run_heartbeat(websocket, client_id, interval=15.0)
)
try:
while True:
raw = await websocket.receive_text()
msg = WSMessage.from_json(raw)
if msg.type == "pong":
# Client responded to our ping
logger.debug(f"[PONG] ← {client_id}")
continue
# Handle other message types
response = WSMessage.system(f"Received your '{msg.type}' message")
await websocket.send_text(response.to_json())
except WebSocketDisconnect:
logger.info(f"[DISCONNECT] {client_id}")
except asyncio.CancelledError:
pass
finally:
# Always cancel the heartbeat task when connection ends
heartbeat_task.cancel()
try:
await heartbeat_task
except asyncio.CancelledError:
pass
print("Heartbeat functions defined.")
print()
print("Usage in a FastAPI app:")
print()
print(" @app.websocket('/ws/{client_id}')")
print(" async def ws_endpoint(ws: WebSocket, client_id: str):")
print(" await ws.accept()")
print(" await handle_connection_with_heartbeat(ws, client_id)")
Part 6: Multiple Concurrent Connections
asyncio Concurrency Model

FastAPI's WebSocket handling uses asyncio. Each connection runs in the same event loop via cooperative multitasking:
Event Loop
    │
    ├─ Coroutine: handle_client(alice)   ← awaiting receive_text()
    ├─ Coroutine: handle_client(bob)     ← awaiting receive_text()
    ├─ Coroutine: handle_client(carol)   ← streaming tokens to carol
    └─ Coroutine: heartbeat_task         ← sleeping for 30s
When any coroutine hits an await, Python switches to another coroutine. This allows
hundreds of concurrent connections with a single process.
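The claim is easy to verify with simulated clients: a hundred coroutines all "parked" waiting for input complete in roughly the time of one, because a suspended coroutine costs almost nothing:

```python
import asyncio
import time

async def fake_client(client_id: int) -> int:
    # Stands in for a handler parked on `await websocket.receive_text()`
    await asyncio.sleep(0.1)
    return client_id

async def main():
    start = time.perf_counter()
    results = await asyncio.gather(*(fake_client(i) for i in range(100)))
    elapsed = time.perf_counter() - start
    print(f"{len(results)} clients served in {elapsed:.2f}s")  # ~0.1s, not 10s
    return len(results), elapsed

count, elapsed = asyncio.run(main())
```

If each client were handled by a blocking thread-per-connection model instead, the same workload would need 100 threads (or 10 seconds sequentially).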
Scaling Beyond One Process

For thousands of concurrent connections:
Run multiple Uvicorn workers: uvicorn app:app --workers 4
Use Redis Pub/Sub to broadcast across workers
Use a sticky-session load balancer (connections must reach the same worker)
# Full multi-client AI chat FastAPI application
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
app_chat = FastAPI(title="Multi-Client AI Chat")
app_chat.add_middleware(
CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
)
# Global connection registry
chat_manager = ConnectionManager()
@app_chat.websocket("/ws/chat/{client_id}")
async def chat_endpoint(websocket: WebSocket, client_id: str):
"""
Multi-client AI chat WebSocket endpoint.
Each client_id gets its own persistent connection.
AI responses are streamed token by token over the WebSocket.
"""
# Register the connection
await chat_manager.connect(websocket, client_id, username=client_id)
# Welcome the new client
await chat_manager.send(
client_id,
WSMessage.system(f"Welcome, {client_id}! {chat_manager.active_count} user(s) online.")
)
# Start heartbeat in background
hb_task = asyncio.create_task(
run_heartbeat(websocket, client_id, interval=30.0)
)
# OpenAI async client
oai = openai.AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
try:
while True:
# Wait for a message from this client
raw = await websocket.receive_text()
# Parse using our envelope
try:
msg = WSMessage.from_json(raw)
except ValueError as e:
await chat_manager.send(
client_id, WSMessage.error(f"Invalid message format: {e}")
)
continue
# Handle pong (heartbeat response)
if msg.type == "pong":
continue
# Handle chat message
if msg.type == "chat":
user_text = msg.data.get("text", "") if isinstance(msg.data, dict) else str(msg.data)
if not user_text.strip():
continue
# Signal that AI is generating
await chat_manager.send(
client_id,
WSMessage(type="typing_start", data={"model": "gpt-4o-mini"})
)
# Stream AI response token by token
full_response_parts = []
token_count = 0
stream_start = time.time()
try:
stream = await oai.chat.completions.create(
model="gpt-4o-mini",
messages=[{"role": "user", "content": user_text}],
max_tokens=300,
stream=True
)
async for chunk in stream:
token = chunk.choices[0].delta.content
if token:
full_response_parts.append(token)
token_count += 1
# Send each token as a separate WebSocket message
sent = await chat_manager.send(
client_id,
WSMessage.token(token)
)
if not sent:
break # Client disconnected mid-stream
except openai.OpenAIError as e:
await chat_manager.send(
client_id,
WSMessage.error(f"AI error: {str(e)[:100]}")
)
continue
# Signal completion with full text and stats
elapsed = time.time() - stream_start
await chat_manager.send(
client_id,
WSMessage(
type="message_complete",
data={
"full_text": "".join(full_response_parts),
"token_count": token_count,
"elapsed_s": round(elapsed, 3),
"tps": round(token_count / elapsed, 1) if elapsed > 0 else 0
}
)
)
# Handle unknown message types gracefully
elif msg.type not in ("ping", "pong", "auth"):
await chat_manager.send(
client_id,
WSMessage.error(f"Unknown message type: {msg.type!r}")
)
except WebSocketDisconnect as e:
logger.info(f"[DISCONNECT] {client_id} (code={e.code})")
except Exception as e:
logger.error(f"[ERROR] {client_id}: {type(e).__name__}: {e}")
finally:
hb_task.cancel()
chat_manager.disconnect(client_id)
@app_chat.get("/stats")
async def get_stats():
"""REST endpoint to check connection stats."""
return chat_manager.get_stats()
print("Multi-client AI chat app defined.")
print()
print("To run: save this to ws_chat_app.py then:")
print(" uvicorn ws_chat_app:app_chat --reload --port 8001")
print()
print("Endpoints:")
print(" WS ws://localhost:8001/ws/chat/{client_id}")
print(" GET http://localhost:8001/stats")
Part 7: Broadcasting to Multiple Clients
Room/Channel Pattern

Group clients into named rooms. Messages sent to a room are broadcast to all members:
Room "general": [alice, bob, carol]
Room "python": [alice, dave]
Room "js": [bob, evan]
Message to "general" → alice, bob, carol receive it
Message to "python" → alice, dave receive it
Fan-Out Considerations

Sequential fan-out: Send to each client one at a time (safe, slower)
Concurrent fan-out: Use asyncio.gather (faster, but errors can interfere)
Dead connection cleanup: Remove failed sends from the registry
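Concurrent fan-out can be sketched with asyncio.gather and a stand-in socket class (FakeSocket is hypothetical, used so the example runs without a server). return_exceptions=True is the key detail: it keeps one dead client from cancelling the sends to everyone else:

```python
import asyncio

class FakeSocket:
    """Stand-in for a WebSocket; 'dead' sockets raise on send."""
    def __init__(self, name: str, dead: bool = False):
        self.name, self.dead = name, dead

    async def send_text(self, payload: str):
        if self.dead:
            raise ConnectionError(f"{self.name} is gone")

async def concurrent_fanout(sockets: dict, payload: str) -> list:
    """Send to all clients at once; return the ids whose sends failed."""
    ids = list(sockets)
    results = await asyncio.gather(
        *(sockets[cid].send_text(payload) for cid in ids),
        return_exceptions=True,  # collect failures instead of raising
    )
    return [cid for cid, r in zip(ids, results) if isinstance(r, Exception)]

sockets = {
    "alice": FakeSocket("alice"),
    "bob": FakeSocket("bob", dead=True),
    "carol": FakeSocket("carol"),
}
dead = asyncio.run(concurrent_fanout(sockets, '{"type": "system"}'))
print(f"Failed sends to clean up: {dead}")  # ['bob']
```

The returned list of failed ids feeds directly into the dead-connection cleanup step: disconnect each one from the registry after the broadcast completes.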
# Room-based broadcast manager
class RoomManager:
"""
Extends ConnectionManager with room/channel support.
Clients can join multiple rooms and receive broadcasts to those rooms.
"""
def __init__(self):
self._connections: Dict[str, WebSocket] = {}
self._client_rooms: Dict[str, Set[str]] = {} # client_id -> set of room_ids
self._room_members: Dict[str, Set[str]] = {} # room_id -> set of client_ids
async def connect(self, websocket: WebSocket, client_id: str):
"""Accept and register a connection (not in any room yet)."""
await websocket.accept()
self._connections[client_id] = websocket
self._client_rooms[client_id] = set()
logger.info(f"[ROOM] {client_id} connected")
def disconnect(self, client_id: str):
"""Remove client from all rooms and the connection registry."""
for room_id in list(self._client_rooms.get(client_id, [])):
self._leave_room_internal(client_id, room_id)
self._connections.pop(client_id, None)
self._client_rooms.pop(client_id, None)
logger.info(f"[ROOM] {client_id} disconnected")
async def join_room(self, client_id: str, room_id: str):
"""Add a client to a room and notify other room members."""
if room_id not in self._room_members:
self._room_members[room_id] = set()
self._room_members[room_id].add(client_id)
self._client_rooms[client_id].add(room_id)
# Notify existing members
await self.broadcast_to_room(
room_id,
WSMessage.system(f"{client_id} joined #{room_id}"),
exclude={client_id}
)
# Tell the new member how many people are in the room
member_count = len(self._room_members[room_id])
ws = self._connections.get(client_id)
if ws:
msg = WSMessage.system(f"You joined #{room_id} ({member_count} member(s))")
await ws.send_text(msg.to_json())
logger.info(f"[ROOM] {client_id} joined #{room_id} (members: {member_count})")
async def leave_room(self, client_id: str, room_id: str):
"""Remove a client from a room."""
self._leave_room_internal(client_id, room_id)
await self.broadcast_to_room(
room_id,
WSMessage.system(f"{client_id} left #{room_id}"),
)
def _leave_room_internal(self, client_id: str, room_id: str):
if room_id in self._room_members:
self._room_members[room_id].discard(client_id)
if not self._room_members[room_id]: # Clean up empty rooms
del self._room_members[room_id]
if client_id in self._client_rooms:
self._client_rooms[client_id].discard(room_id)
async def broadcast_to_room(
self,
room_id: str,
message: WSMessage,
exclude: Set[str] = None
) -> int:
"""
Send a message to all members of a room.
Returns the number of clients reached.
"""
exclude = exclude or set()
members = list(self._room_members.get(room_id, set()))
dead = []
sent_count = 0
for cid in members:
if cid in exclude:
continue
ws = self._connections.get(cid)
if not ws:
dead.append(cid)
continue
try:
await ws.send_text(message.to_json())
sent_count += 1
except Exception:
dead.append(cid)
# Clean up dead connections
for cid in dead:
self.disconnect(cid)
return sent_count
def room_info(self, room_id: str) -> dict:
members = list(self._room_members.get(room_id, set()))
return {
"room_id": room_id,
"member_count": len(members),
"members": members
}
def list_rooms(self) -> list:
return [
self.room_info(rid)
for rid in self._room_members
]
room_manager = RoomManager()
print("RoomManager defined.")
print()
print("Usage in FastAPI:")
print(" @app.websocket('/ws/{room_id}/{client_id}')")
print(" async def room_endpoint(ws, room_id, client_id):")
print(" await room_manager.connect(ws, client_id)")
print(" await room_manager.join_room(client_id, room_id)")
print(" ...")
Part 8: Authentication with WebSockets
The Challenge

HTTP headers are sent during the WebSocket upgrade handshake, but the browser's EventSource
and WebSocket APIs do not support custom headers. This creates a challenge for authentication.
Authentication Options

| Method | Security | Complexity | Browser Support |
|---|---|---|---|
| Token in query param | Low (logged in URLs) | Simple | Full |
| Cookie-based | Good (HttpOnly cookie) | Medium | Full |
| First-message auth | Good | Medium | Full |
| Header (non-browser) | Good | Simple | Python/Node only |
Recommended Pattern: First-Message Authentication

1. Server accepts the WebSocket connection (websocket.accept())
2. Server waits, with a timeout, for the client's auth message
3. Client sends {"type": "auth", "data": {"token": "..."}}
4. Server validates the token; calls close(1008) if invalid, proceeds if valid
# WebSocket authentication implementation
from fastapi import status as http_status
# Simulated token store (in production: validate JWT or query database)
VALID_TOKENS: Dict[str, dict] = {
"token_alice_secret_123": {"user_id": "alice", "role": "user"},
"token_bob_secret_456": {"user_id": "bob", "role": "user"},
"token_admin_secret_789": {"user_id": "admin", "role": "admin"},
}
AUTH_TIMEOUT_SECONDS = 5.0
def verify_token(token: str) -> Optional[dict]:
"""
Validate a token and return user info, or None if invalid.
In production: decode and verify a JWT here.
"""
return VALID_TOKENS.get(token)
async def authenticate_websocket(
websocket: WebSocket
) -> Optional[dict]:
"""
Perform first-message authentication on a WebSocket connection.
Returns the user info dict on success, None on failure.
Closes the WebSocket with an appropriate code on failure.
Pattern:
1. Accept the connection
2. Wait up to AUTH_TIMEOUT_SECONDS for an auth message
3. Validate token
4. Return user info or close with 1008 Policy Violation
"""
await websocket.accept()
# Wait for auth message with timeout
try:
raw = await asyncio.wait_for(
websocket.receive_text(),
timeout=AUTH_TIMEOUT_SECONDS
)
except asyncio.TimeoutError:
await websocket.send_text(
WSMessage.error(f"Authentication timeout ({AUTH_TIMEOUT_SECONDS}s)").to_json()
)
await websocket.close(code=1008) # Policy Violation
logger.warning("[AUTH] Timeout waiting for auth message")
return None
# Parse the auth message
try:
msg = WSMessage.from_json(raw)
except ValueError:
await websocket.send_text(
WSMessage.error("First message must be valid JSON").to_json()
)
await websocket.close(code=1008)
return None
if msg.type != "auth":
await websocket.send_text(
WSMessage.error(f"Expected 'auth' message, got '{msg.type}'").to_json()
)
await websocket.close(code=1008)
logger.warning(f"[AUTH] Wrong first message type: {msg.type!r}")
return None
# Extract and validate token
token = msg.data.get("token", "") if isinstance(msg.data, dict) else ""
user_info = verify_token(token)
if not user_info:
await websocket.send_text(
WSMessage.error("Invalid or expired token", code=401).to_json()
)
await websocket.close(code=1008)
logger.warning(f"[AUTH] Invalid token attempt: {token[:20]!r}...")
return None
# Auth success
await websocket.send_text(
WSMessage(
type="auth_success",
data={"user_id": user_info["user_id"], "role": user_info["role"]}
).to_json()
)
logger.info(f"[AUTH] {user_info['user_id']} authenticated successfully")
return user_info
# Example authenticated endpoint
app_auth = FastAPI(title="Authenticated WebSocket")
@app_auth.websocket("/ws/secure")
async def secure_ws_endpoint(websocket: WebSocket):
"""WebSocket endpoint requiring token authentication."""
user = await authenticate_websocket(websocket)
if not user:
return # Connection closed in authenticate_websocket
# User is authenticated - proceed normally
try:
while True:
raw = await websocket.receive_text()
msg = WSMessage.from_json(raw)
# Echo with user identity
response = WSMessage(
type="response",
data={"echo": msg.data, "from": user["user_id"]}
)
await websocket.send_text(response.to_json())
except WebSocketDisconnect:
logger.info(f"[DISCONNECT] {user['user_id']}")
print("Authentication functions defined.")
print()
print("Valid test tokens:")
for token, info in VALID_TOKENS.items():
print(f" {token!r} → {info}")
# Authenticated WebSocket client
import websockets
import asyncio
async def authenticated_ws_client(
uri: str,
token: str,
messages_to_send: list
):
"""
WebSocket client that authenticates via first-message auth pattern.
"""
print(f"Connecting to {uri}...")
try:
async with websockets.connect(uri, ping_interval=20, ping_timeout=10) as ws:
print("Connected. Sending auth token...")
# Step 1: Send auth message
auth_msg = WSMessage(type="auth", data={"token": token})
await ws.send(auth_msg.to_json())
# Step 2: Wait for auth response
raw_auth = await asyncio.wait_for(ws.recv(), timeout=10.0)
auth_response = WSMessage.from_json(raw_auth)
if auth_response.type == "error":
print(f"Auth failed: {auth_response.data}")
return
if auth_response.type != "auth_success":
print(f"Unexpected auth response: {auth_response.type}")
return
user_info = auth_response.data
print(f"Authenticated as: {user_info.get('user_id')} (role={user_info.get('role')})")
# Step 3: Send and receive messages
for text in messages_to_send:
msg = WSMessage.chat(text, user=user_info.get("user_id", "user"))
await ws.send(msg.to_json())
print(f"\nSent: {text!r}")
# For the AI chat app, collect streaming tokens
collected_tokens = []
while True:
raw = await asyncio.wait_for(ws.recv(), timeout=15.0)
response = WSMessage.from_json(raw)
if response.type == "typing_start":
print("AI: ", end="", flush=True)
continue
if response.type == "token":
token_text = response.data.get("text", "")
print(token_text, end="", flush=True)
collected_tokens.append(token_text)
continue
if response.type == "message_complete":
stats = response.data
print(f"\n[{stats.get('token_count')} tokens, {stats.get('tps')} TPS]")
break
if response.type == "ping":
pong = WSMessage.pong()
await ws.send(pong.to_json())
continue
if response.type == "error":
print(f"\n[ERROR] {response.data}")
break
print("\nAll messages sent. Closing connection.")
except websockets.ConnectionClosedError as e:
print(f"Connection closed: code={e.code}, reason={e.reason!r}")
except asyncio.TimeoutError:
print("Timeout waiting for server response")
print("Authenticated client function defined.")
print()
print("Usage (requires authenticated server on port 8001):")
print(" await authenticated_ws_client(")
print(" uri='ws://localhost:8001/ws/chat/alice',")
print(" token='token_alice_secret_123',")
print(" messages_to_send=['What is Python?']")
print(" )")
Part 9: Error Handling and Reconnection with Exponential Backoff
Why Reconnection Logic Is Essential

Unlike SSE (where the browser's EventSource auto-reconnects), WebSockets require you to implement reconnection logic yourself. Connections drop due to:
Network blips (mobile switching WiFi to 4G)
Server restarts (deploys, crashes)
Proxy/load balancer timeouts
Client-side network changes
Exponential Backoff Formula
delay = min(base * 2^(attempt - 1) + jitter, max_delay)
With base=1s, jitter=U(0, 1):
Attempt 1: 1s + jitter
Attempt 2: 2s + jitter
Attempt 3: 4s + jitter
Attempt 4: 8s + jitter
Attempt 5: 16s + jitter
Attempt 6: 32s + jitter
Attempt 7+: 60s (capped)
Jitter prevents thundering herd: if 1000 clients all disconnect simultaneously and retry at exactly the same time, they can overwhelm the server. Random jitter spreads out retries.
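The effect of jitter is easy to simulate. The toy model below (illustrative numbers, not a benchmark; helper names like `peak_bucket` are my own) counts how many of 1000 simultaneously disconnected clients land their first retry in the same 100 ms window, with and without 1 s of jitter:

```python
import random
from collections import Counter

def first_retry_times(n_clients: int, jitter: float) -> list[float]:
    """First-retry timestamps for n clients that all disconnect at t=0.
    Each waits the 1s base delay plus uniform jitter before reconnecting."""
    return [1.0 + random.uniform(0, jitter) for _ in range(n_clients)]

def peak_bucket(times: list[float], bucket_ms: float = 100.0) -> int:
    """Largest number of retries landing in any single bucket_ms window."""
    buckets = Counter(int(t * 1000 // bucket_ms) for t in times)
    return max(buckets.values())

random.seed(42)  # deterministic for demonstration
no_jitter = peak_bucket(first_retry_times(1000, jitter=0.0))
with_jitter = peak_bucket(first_retry_times(1000, jitter=1.0))
print(f"Peak retries per 100ms window: no jitter={no_jitter}, with 1s jitter={with_jitter}")
```

Without jitter every client retries in the same instant; with U(0, 1) jitter the same load is spread over ten windows, cutting the peak by roughly an order of magnitude.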
Circuit Breaker
After N consecutive failures, stop retrying and surface an error to the user. This prevents endlessly burning retries against a server that is permanently down.
# Resilient WebSocket client with exponential backoff and circuit breaker
import websockets
import asyncio
import random
from typing import Callable, Awaitable
class CircuitBreaker:
"""
Simple circuit breaker: opens after max_failures, resets after reset_timeout.
States: CLOSED (normal) → OPEN (blocking) → HALF_OPEN (testing) → CLOSED
"""
def __init__(self, max_failures: int = 5, reset_timeout: float = 60.0):
self.max_failures = max_failures
self.reset_timeout = reset_timeout
self.failure_count = 0
self.last_failure_time: Optional[float] = None
self.state = "CLOSED" # CLOSED, OPEN, HALF_OPEN
def record_success(self):
self.failure_count = 0
self.state = "CLOSED"
def record_failure(self):
self.failure_count += 1
self.last_failure_time = time.time()
if self.failure_count >= self.max_failures:
self.state = "OPEN"
logger.warning(f"[CIRCUIT] OPEN after {self.failure_count} failures")
def is_open(self) -> bool:
if self.state == "CLOSED":
return False
if self.state == "OPEN":
# Check if reset timeout has passed
if self.last_failure_time and (time.time() - self.last_failure_time) > self.reset_timeout:
self.state = "HALF_OPEN"
logger.info("[CIRCUIT] HALF_OPEN - will try one request")
return False
return True
return False # HALF_OPEN: allow one through
async def resilient_ws_client(
uri: str,
message_handler: Callable,
max_retries: int = 6,
base_delay: float = 1.0,
max_delay: float = 60.0,
jitter: float = 1.0
):
"""
WebSocket client with automatic reconnection using exponential backoff.
Args:
uri: WebSocket server URI
message_handler: async function(ws, message_str) -> bool
Return True to continue, False to stop gracefully.
max_retries: Give up after this many consecutive failures
base_delay: Initial backoff delay in seconds
max_delay: Maximum delay between retries
jitter: Maximum random jitter to add (prevents thundering herd)
"""
circuit = CircuitBreaker(max_failures=max_retries)
attempt = 0
while attempt <= max_retries:
if circuit.is_open():
logger.error(f"[CIRCUIT] Open - not attempting connection")
break
try:
logger.info(f"[CONNECT] Attempt {attempt + 1}/{max_retries + 1} → {uri}")
async with websockets.connect(
uri,
ping_interval=30,
ping_timeout=10,
open_timeout=5.0
) as ws:
logger.info(f"[CONNECTED] Successfully connected")
circuit.record_success()
attempt = 0 # Reset attempt counter on success
# Process incoming messages
async for raw_message in ws:
should_continue = await message_handler(ws, raw_message)
if not should_continue:
logger.info("[STOP] Handler requested graceful shutdown")
return # Clean exit
except websockets.ConnectionClosedOK:
logger.info("[CLOSED] Connection closed normally (code 1000)")
return # Don't retry on clean close
except websockets.ConnectionClosedError as e:
logger.warning(f"[CLOSED ERROR] code={e.code}, reason={e.reason!r}")
circuit.record_failure()
# Don't retry on auth errors
if e.code == 1008:
logger.error("[AUTH] Policy violation - not retrying")
break
except (OSError, websockets.InvalidURI) as e:
logger.warning(f"[CONNECTION ERROR] {type(e).__name__}: {e}")
circuit.record_failure()
except Exception as e:
logger.error(f"[UNEXPECTED] {type(e).__name__}: {e}")
circuit.record_failure()
# Calculate backoff delay
attempt += 1
if attempt > max_retries:
logger.error(f"[FAILED] Exhausted {max_retries} retries. Giving up.")
break
delay = min(base_delay * (2 ** (attempt - 1)), max_delay)
delay += random.uniform(0, jitter) # Add jitter
logger.info(f"[RETRY] Waiting {delay:.2f}s before attempt {attempt + 1}...")
await asyncio.sleep(delay)
# Visualize the backoff schedule
print("Exponential backoff schedule (base=1s, max=60s, jitter=1s):")
print(f"{'Attempt':>8} {'Min delay':>12} {'Max delay':>12}")
print("-" * 35)
base, max_d, jitter = 1.0, 60.0, 1.0
for i in range(1, 8):
delay = min(base * (2 ** (i - 1)), max_d)
print(f"{i:>8} {delay:>11.1f}s {delay + jitter:>11.1f}s")
Part 10: Connection Pooling
Why Pool WebSocket Connections?
In server-to-server scenarios (e.g., your API server calling a WebSocket-based AI service), creating a new WebSocket connection for each request adds overhead:
TCP handshake: ~20-100ms
TLS handshake: ~50-200ms
WebSocket upgrade: ~10-30ms
Total per-connection overhead: 80–330 ms, which defeats the purpose of streaming.
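To see why this matters, a back-of-envelope amortization (using a hypothetical 150 ms midpoint of the range above; the helper name is my own) shows the per-request setup cost with and without connection reuse:

```python
def per_request_overhead_ms(setup_ms: float, requests_per_connection: int) -> float:
    """Connection-setup cost amortized over the requests that reuse one connection."""
    return setup_ms / requests_per_connection

setup_ms = 150.0  # assumed midpoint of the 80-330ms range above
for reuse in (1, 10, 100, 1000):
    cost = per_request_overhead_ms(setup_ms, reuse)
    print(f"{reuse:>5} requests/conn -> {cost:8.2f} ms overhead per request")
```

One connection per request pays the full 150 ms every time; a pooled connection reused across 100 requests pays 1.5 ms each, which is negligible next to token streaming latency.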
Pool Design
Request 1 ──▶ acquire() ──▶ WS conn A (from pool) ──▶ send/recv ──▶ release() ──▶ back to pool
Request 2 ──▶ acquire() ──▶ WS conn B (from pool) ──▶ send/recv ──▶ release() ──▶ back to pool
Request 3 ──▶ acquire() ──▶ WS conn C (new)       ──▶ send/recv ──▶ release() ──▶ back to pool
Request 4 ──▶ acquire() ──▶ WAIT (pool full)...
# WebSocket connection pool for server-to-server use
import asyncio
import websockets
from contextlib import asynccontextmanager
class WebSocketPool:
"""
Async pool of WebSocket connections for reuse across requests.
Useful when your server needs to make many WebSocket calls to another service.
Usage:
pool = WebSocketPool("wss://api.example.com/ws", pool_size=5)
await pool.initialize()
async with pool.connection() as ws:
await ws.send(json.dumps({"type": "query", "data": "..."}))
response = await ws.recv()
"""
def __init__(self, uri: str, pool_size: int = 5, connect_timeout: float = 5.0):
self.uri = uri
self.pool_size = pool_size
self.connect_timeout = connect_timeout
self._pool: asyncio.Queue = asyncio.Queue(maxsize=pool_size)
self._total_created = 0
self._initialized = False
async def initialize(self):
"""Pre-create all connections in the pool."""
logger.info(f"[POOL] Initializing {self.pool_size} connections to {self.uri}")
for _ in range(self.pool_size):
try:
ws = await asyncio.wait_for(
websockets.connect(self.uri, ping_interval=30),
timeout=self.connect_timeout
)
await self._pool.put(ws)
self._total_created += 1
except Exception as e:
logger.warning(f"[POOL] Failed to pre-create connection: {e}")
self._initialized = True
logger.info(f"[POOL] Ready with {self._pool.qsize()}/{self.pool_size} connections")
async def _create_connection(self) -> websockets.WebSocketClientProtocol:
"""Create a new connection."""
ws = await asyncio.wait_for(
websockets.connect(self.uri, ping_interval=30),
timeout=self.connect_timeout
)
self._total_created += 1
return ws
async def acquire(self) -> websockets.WebSocketClientProtocol:
"""
Get a connection from the pool.
Creates a new one if pool is empty and below pool_size.
Blocks if pool is at capacity.
"""
try:
ws = self._pool.get_nowait()
# Check if connection is still alive
if ws.closed:
logger.debug("[POOL] Dead connection found, creating new one")
ws = await self._create_connection()
return ws
except asyncio.QueueEmpty:
if self._total_created < self.pool_size:
return await self._create_connection()
# Wait for a connection to be released
logger.debug("[POOL] All connections in use, waiting...")
ws = await self._pool.get()
if ws.closed:
ws = await self._create_connection()
return ws
async def release(self, ws: websockets.WebSocketClientProtocol):
"""Return a connection to the pool."""
if not ws.closed:
try:
self._pool.put_nowait(ws)
except asyncio.QueueFull:
# Pool is full (shouldn't happen with proper usage)
await ws.close()
self._total_created -= 1
else:
# Don't return dead connections to the pool
self._total_created -= 1
@asynccontextmanager
async def connection(self):
"""Context manager for pool connections (auto-acquire and release)."""
ws = await self.acquire()
try:
yield ws
except Exception:
# On error, close the connection rather than returning to pool
await ws.close()
self._total_created -= 1
raise
else:
await self.release(ws)
async def close_all(self):
"""Close all connections in the pool."""
while not self._pool.empty():
ws = self._pool.get_nowait()
await ws.close()
self._total_created = 0
logger.info("[POOL] All connections closed")
def stats(self) -> dict:
return {
"pool_size": self.pool_size,
"available": self._pool.qsize(),
"in_use": self._total_created - self._pool.qsize(),
"total_created": self._total_created
}
print("WebSocketPool defined.")
print()
print("Usage example:")
print("""
pool = WebSocketPool("wss://api.example.com/ws", pool_size=5)
await pool.initialize()
# In request handler:
async with pool.connection() as ws:
await ws.send(json.dumps({"query": "Hello"}))
response = await ws.recv()
print(response)
# Connection is automatically returned to the pool
# Cleanup on shutdown:
await pool.close_all()
""")
Part 11: Complete Demo Chat Application
Architecture
┌────────────────────────────────────────────────────────────────────┐
│                          Chat Application                          │
│                                                                    │
│  Python Terminal Clients             FastAPI Server                │
│  ┌──────────────────┐               ┌───────────────────────┐      │
│  │ alice_client.py  │ ─ WebSocket ─▶│  /ws/{room}/{user}    │      │
│  └──────────────────┘               │                       │      │
│  ┌──────────────────┐               │  RoomManager          │      │
│  │ bob_client.py    │ ─ WebSocket ─▶│  (broadcast)          │      │
│  └──────────────────┘               │                       │      │
│                                     │  OpenAI AsyncClient   │      │
│                                     │  (streaming tokens)   │      │
│                                     └───────────────────────┘      │
└────────────────────────────────────────────────────────────────────┘
Message flow:
1. alice sends {type: chat, data: {text: "Hello AI"}}
2. Server broadcasts the message to all room members
3. Server calls OpenAI API with stream=True
4. Each token → {type: token, data: {text: "H"}}
5. Broadcast tokens to all room members in real-time
6. {type: message_complete} signals end of generation
# Save complete chat server to a runnable file
server_code = '''
"""Complete WebSocket-based AI chat server."""
import os, json, asyncio, uuid, time, logging
from datetime import datetime
from typing import Dict, Set, Optional, Any
from dataclasses import dataclass, asdict
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
import openai
from dotenv import load_dotenv
load_dotenv()
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("chat_server")
app = FastAPI(title="AI Chat Server")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
oai = openai.AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
@dataclass
class WSMessage:
type: str
data: Any
message_id: str = None
timestamp: str = None
def __post_init__(self):
if not self.message_id: self.message_id = uuid.uuid4().hex[:8]
if not self.timestamp: self.timestamp = datetime.utcnow().isoformat() + "Z"
def to_json(self): return json.dumps(asdict(self))
@classmethod
def from_json(cls, raw):
d = json.loads(raw)
return cls(type=d["type"], data=d.get("data"), message_id=d.get("message_id"), timestamp=d.get("timestamp"))
class RoomManager:
def __init__(self):
self.connections: Dict[str, WebSocket] = {}
self.rooms: Dict[str, Set[str]] = {}
async def connect(self, ws: WebSocket, client_id: str, room_id: str):
await ws.accept()
self.connections[client_id] = ws
self.rooms.setdefault(room_id, set()).add(client_id)
await self.broadcast(room_id, WSMessage("system", f"{client_id} joined #{room_id}"), exclude={client_id})
logger.info(f"[JOIN] {client_id} → #{room_id} (members: {len(self.rooms[room_id])})")
def disconnect(self, client_id: str, room_id: str):
self.connections.pop(client_id, None)
if room_id in self.rooms:
self.rooms[room_id].discard(client_id)
async def broadcast(self, room_id: str, msg: WSMessage, exclude: Set[str] = None):
exclude = exclude or set()
dead = []
for cid in list(self.rooms.get(room_id, [])):
if cid in exclude: continue
ws = self.connections.get(cid)
if ws:
try: await ws.send_text(msg.to_json())
except Exception: dead.append(cid)
for cid in dead: self.disconnect(cid, room_id)
rooms = RoomManager()
@app.websocket("/ws/{room_id}/{client_id}")
async def chat_endpoint(ws: WebSocket, room_id: str, client_id: str):
await rooms.connect(ws, client_id, room_id)
try:
while True:
raw = await ws.receive_text()
msg = WSMessage.from_json(raw)
if msg.type == "pong": continue
if msg.type != "chat": continue
text = msg.data.get("text", "") if isinstance(msg.data, dict) else str(msg.data)
if not text.strip(): continue
# Echo user message to room
await rooms.broadcast(room_id, WSMessage("chat", {"from": client_id, "text": text}))
# Stream AI response
await rooms.broadcast(room_id, WSMessage("typing_start", {"user": "AI"}))
parts, count, t0 = [], 0, time.time()
stream = await oai.chat.completions.create(
model="gpt-4o-mini", messages=[{"role":"user","content":text}], max_tokens=300, stream=True
)
async for chunk in stream:
tok = chunk.choices[0].delta.content
if tok:
parts.append(tok); count += 1
await rooms.broadcast(room_id, WSMessage("token", {"text": tok}))
elapsed = time.time() - t0
await rooms.broadcast(room_id, WSMessage("message_complete", {
"full_text": "".join(parts),
"tokens": count,
"tps": round(count/elapsed, 1) if elapsed else 0
}))
except WebSocketDisconnect:
rooms.disconnect(client_id, room_id)
await rooms.broadcast(room_id, WSMessage("system", f"{client_id} left #{room_id}"))
@app.get("/rooms")
def list_rooms():
return [{"room": r, "members": list(m)} for r, m in rooms.rooms.items()]
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8001)
'''
server_path = "/tmp/ws_chat_server.py"
with open(server_path, "w") as f:
f.write(server_code)
print(f"Server written to {server_path}")
print("Start: python /tmp/ws_chat_server.py")
# Save complete chat client to a runnable file
client_code = '''
"""Interactive WebSocket chat client."""
import asyncio, json, sys, uuid
from datetime import datetime
try:
import websockets
except ImportError:
print("Install: pip install websockets")
sys.exit(1)
SERVER_URL = "ws://localhost:8001"
async def chat_client(room_id: str, username: str):
uri = f"{SERVER_URL}/ws/{room_id}/{username}"
print(f"Connecting to #{room_id} as {username}...")
print("Type messages and press Enter. Type /quit to exit.\\n")
async with websockets.connect(uri, ping_interval=30, ping_timeout=10) as ws:
print(f"Connected!")
async def receive_loop():
"""Background task: continuously receive and display messages."""
ai_streaming = False
async for raw in ws:
try:
data = json.loads(raw)
except json.JSONDecodeError:
continue
t = data.get("type")
if t == "system":
print(f"\\n[System] {data.get(\'data\')}")
elif t == "chat":
d = data.get("data", {})
if d.get("from") != username:
print(f"\\n[{d.get(\'from\')}]: {d.get(\'text\')}")
elif t == "typing_start":
print("\\nAI: ", end="", flush=True)
ai_streaming = True
elif t == "token":
print(data["data"]["text"], end="", flush=True)
elif t == "message_complete":
d = data.get("data", {})
print(f"\\n [{d.get(\'tokens\')} tokens @ {d.get(\'tps\')} TPS]")
ai_streaming = False
elif t == "error":
print(f"\\n[ERROR] {data.get(\'data\')}")
elif t == "ping":
await ws.send(json.dumps({"type": "pong", "data": None}))
recv_task = asyncio.create_task(receive_loop())
loop = asyncio.get_running_loop()
try:
while True:
line = await loop.run_in_executor(None, input, f"[{username}]: ")
if line.strip().lower() in ("/quit", "/exit", "quit", "exit"):
break
if not line.strip():
continue
payload = json.dumps({"type": "chat", "data": {"text": line}})
await ws.send(payload)
except (KeyboardInterrupt, EOFError):
pass
finally:
recv_task.cancel()
print("\\nDisconnecting...")
if __name__ == "__main__":
room = sys.argv[1] if len(sys.argv) > 1 else "general"
user = sys.argv[2] if len(sys.argv) > 2 else f"user_{uuid.uuid4().hex[:4]}"
asyncio.run(chat_client(room, user))
'''
client_path = "/tmp/ws_chat_client.py"
with open(client_path, "w") as f:
f.write(client_code)
print(f"Client written to {client_path}")
print()
print("To run a multi-user demo:")
print()
print(" Terminal 1 (server):")
print(" python /tmp/ws_chat_server.py")
print()
print(" Terminal 2 (alice):")
print(" python /tmp/ws_chat_client.py general alice")
print()
print(" Terminal 3 (bob):")
print(" python /tmp/ws_chat_client.py general bob")
print()
print("Both alice and bob will see the AI response streamed in real-time!")
# Quick self-test: verify all classes and functions are importable and usable
print("Self-test: verifying all components...\n")
# 1. WSMessage round-trip
msg = WSMessage.chat("Hello!", user="test")
parsed = WSMessage.from_json(msg.to_json())
assert parsed.type == "chat"
assert parsed.data["text"] == "Hello!"
print("[OK] WSMessage serialization/deserialization")
# 2. ConnectionManager instantiation
cm = ConnectionManager()
assert cm.active_count == 0
stats = cm.get_stats()
assert stats["active_connections"] == 0
print("[OK] ConnectionManager instantiation")
# 3. RoomManager instantiation
rm = RoomManager()
assert rm.list_rooms() == []
print("[OK] RoomManager instantiation")
# 4. CircuitBreaker logic
cb = CircuitBreaker(max_failures=3, reset_timeout=60.0)
assert cb.state == "CLOSED"
assert not cb.is_open()
cb.record_failure()
cb.record_failure()
cb.record_failure()
assert cb.state == "OPEN"
assert cb.is_open()
cb.record_success()
assert cb.state == "CLOSED"
print("[OK] CircuitBreaker state transitions")
# 5. WebSocketPool instantiation
pool = WebSocketPool("ws://localhost:9999", pool_size=3)
s = pool.stats()
assert s["pool_size"] == 3
assert s["available"] == 0
print("[OK] WebSocketPool instantiation")
# 6. Verify server and client files were created
import os
assert os.path.exists("/tmp/ws_chat_server.py")
assert os.path.exists("/tmp/ws_chat_client.py")
print("[OK] Server and client scripts written to /tmp/")
# 7. Verify VALID_TOKENS dictionary
assert len(VALID_TOKENS) == 3
assert verify_token("token_alice_secret_123") is not None
assert verify_token("bad_token") is None
print("[OK] Authentication token verification")
# 8. Backoff schedule sanity check
delays = [min(1.0 * (2 ** i), 60.0) for i in range(7)]
assert delays[0] == 1.0
assert delays[5] == 32.0
assert delays[6] == 60.0
print("[OK] Exponential backoff schedule")
print("\nAll self-tests passed!")
Summary
What We Covered
SSE vs WebSockets: Use SSE for server-to-client streaming (LLM tokens, logs). Use WebSockets for bidirectional communication (chat, collaboration, games).
WebSocket protocol: HTTP upgrade handshake, frame types (text/binary/ping/pong/close), connection states (CONNECTING/OPEN/CLOSING/CLOSED), close codes.
FastAPI WebSocket: @app.websocket(), await websocket.accept(), receive_text()/send_text()/send_json(), WebSocketDisconnect.
Message envelope pattern: Always use typed JSON envelopes. Include type, data, message_id, and timestamp in every message.
ConnectionManager: Central registry for all active connections. Clean up on disconnect to prevent memory leaks.
Heartbeat: Application-level ping/pong to detect dead connections. Run as a background asyncio.Task, cancel on disconnect.
Multiple concurrent connections: asyncio handles hundreds of connections cooperatively. Never block the event loop with synchronous calls.
RoomManager: Group clients into named rooms for targeted broadcasts. Clean up empty rooms and dead connections on every broadcast.
Authentication: First-message auth is the safest browser-compatible approach. Validate within AUTH_TIMEOUT_SECONDS or close with code 1008.
Exponential backoff: delay = min(base * 2^(attempt - 1) + jitter, max_delay). Jitter prevents thundering herd. Circuit breaker prevents endless retries.
Connection pooling: Reuse connections in server-to-server scenarios. Use the context manager pattern (async with pool.connection() as ws:) for safety.
Decision Guide
What do you need?
├── Stream AI tokens to browser?
│   └── Use SSE (Notebook 1) - simpler, auto-reconnect
│
├── Interactive AI chat with stop/modify?
│   └── Use WebSockets - bidirectional control
│
├── Multi-user real-time chat?
│   └── WebSockets + RoomManager + broadcast
│
├── Server calling another WebSocket API many times?
│   └── WebSocketPool - reuse connections
│
└── Mobile/unreliable network client?
    └── WebSockets + resilient_ws_client with backoff
Security Checklist
Authenticate before allowing any application messages
Rate limit messages per connection (not just per IP)
Validate and sanitize all message data fields
Set max_size on websockets.connect() to prevent memory exhaustion
Use wss:// (TLS) in production, never ws://
Close with appropriate codes on auth failure (1008)
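The max_size and wss:// items can be enforced at the call site. Below is a hypothetical helper (the name, defaults, and localhost exemption are my own, not part of the websockets API) that builds safe kwargs for websockets.connect():

```python
from urllib.parse import urlparse

def secure_connect_kwargs(uri: str, max_size: int = 2**20) -> dict:
    """Build kwargs for websockets.connect(), enforcing wss:// and a frame-size cap.

    Raises ValueError for plaintext ws:// URIs unless the host is localhost
    (convenient for local development).
    """
    parsed = urlparse(uri)
    if parsed.scheme == "ws" and parsed.hostname not in ("localhost", "127.0.0.1"):
        raise ValueError(f"Refusing plaintext WebSocket to {parsed.hostname}; use wss://")
    if parsed.scheme not in ("ws", "wss"):
        raise ValueError(f"Not a WebSocket URI: {uri}")
    return {"uri": uri, "max_size": max_size, "ping_interval": 30, "ping_timeout": 10}

# Usage sketch: websockets.connect(**secure_connect_kwargs("wss://api.example.com/ws"))
print(secure_connect_kwargs("ws://localhost:8001/ws/alice"))
```

Centralizing the policy in one helper means a reviewer only has to audit one function, not every connect() call.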
Quick Reference
# Server
@app.websocket("/ws/{client_id}")
async def endpoint(ws: WebSocket, client_id: str):
await ws.accept()
try:
while True:
raw = await ws.receive_text()
msg = WSMessage.from_json(raw)
await ws.send_text(WSMessage.system("Got it").to_json())
except WebSocketDisconnect:
pass # Clean up here
# Client
async with websockets.connect("ws://localhost:8001/ws/alice") as ws:
await ws.send(WSMessage.chat("Hello AI").to_json())
response = WSMessage.from_json(await ws.recv())
print(response.data)
# Authenticated endpoint pattern
user = await authenticate_websocket(websocket) # First-message auth
if not user:
return # Already closed with 1008
Files created during this notebook:
/tmp/ws_chat_server.py – Full room-based AI chat server
/tmp/ws_chat_client.py – Interactive terminal client
Next: Notebook 3 covers chunked generation and progressive rendering patterns.