03: Deep Q-Networks (DQN)¶
"The question of whether a computer can think is no more interesting than the question of whether a submarine can swim." - Edsger W. Dijkstra
Welcome to the revolutionary world of Deep Q-Networks (DQN)! This notebook introduces the groundbreaking combination of Q-learning with deep neural networks, enabling RL to solve complex problems with high-dimensional state spaces.
Learning Objectives¶
By the end of this notebook, you'll understand:
Why we need function approximation for complex problems
The core innovations that made DQN work
Experience replay and target networks
How to implement DQN from scratch
Practical tips for training stable DQNs
The Problem with Tabular Q-Learning¶
Limitations of Tabular Methods¶
State space explosion: Can't handle large or continuous state spaces
Memory inefficiency: Need to store Q-values for every state-action pair
Generalization: No ability to generalize across similar states
The Solution: Function Approximation¶
Use a function approximator to estimate Q-values
Neural networks are powerful function approximators
Can handle high-dimensional inputs (images, sensor data)
Generalize across similar states
Deep Q-Networks¶
Core Idea¶
Replace the Q-table with a deep neural network that takes states as input and outputs Q-values for all actions.
Architecture¶
State → Neural Network → Q-values for each action
Training Objective¶
Minimize the temporal difference error:
L(θ) = E[(y_target - Q(s,a;θ))^2]
Where:
y_target = r + γ max_a' Q(s', a'; θ_target)   (for non-terminal states)
y_target = r                                  (for terminal states)
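The piecewise target above can be written as a single vectorized expression, with the done flag switching off the bootstrap term for terminal states. A minimal sketch with toy numbers (not from the notebook's environment):

```python
import torch

def td_target(reward, next_q_max, done, gamma=0.9):
    """y = r + gamma * max_a' Q_target(s', a') for non-terminal s'; y = r when done."""
    return reward + gamma * next_q_max * (1.0 - done)

# Toy batch: the second transition is terminal, so its target is just the reward
reward = torch.tensor([1.0, 5.0])
next_q_max = torch.tensor([2.0, 7.0])  # max over actions, from the target network
done = torch.tensor([0.0, 1.0])
y = td_target(reward, next_q_max, done)
print(y)  # tensor([2.8000, 5.0000])
```

This is exactly the `(1 - done_batch)` trick used later in the agent's optimize_model method.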
Key Innovations in DQN¶
1. Experience Replay¶
Store experiences in a replay buffer
Sample mini-batches randomly for training
Benefits: Reduces correlation, improves sample efficiency, enables off-policy learning
2. Target Network¶
Maintain a separate target network for computing target Q-values
Update target network periodically (soft updates or hard updates)
Benefits: Stabilizes training, reduces oscillations
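Both update styles mentioned above fit in a couple of lines. A sketch, using tau=0.5 only to make the effect visible (typical values are around 0.005):

```python
import copy
import torch
import torch.nn as nn

def hard_update(target: nn.Module, source: nn.Module):
    """Hard update: copy the policy weights into the target network every N steps."""
    target.load_state_dict(source.state_dict())

def soft_update(target: nn.Module, source: nn.Module, tau: float = 0.005):
    """Soft (Polyak) update: target <- tau * source + (1 - tau) * target, every step."""
    with torch.no_grad():
        for t, s in zip(target.parameters(), source.parameters()):
            t.mul_(1.0 - tau).add_(tau * s)

policy = nn.Linear(2, 4)
target = copy.deepcopy(policy)
with torch.no_grad():
    policy.weight.add_(1.0)  # stand-in for a gradient step on the policy net
soft_update(target, policy, tau=0.5)  # target moves halfway toward the policy weights
```

The agent implemented later in this notebook uses the hard-update variant via target_update_freq.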
3. Frame Stacking¶
Stack multiple consecutive frames as input
Captures temporal information (motion, velocity)
Benefits: Better representation of dynamic environments
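A frame stack is just a fixed-length deque presented as one array. A sketch, assuming an 84x84 observation as in Atari preprocessing (the class and its interface are illustrative, not part of this notebook's agent):

```python
from collections import deque
import numpy as np

class FrameStack:
    """Keep the last k observations and present them as a single stacked input."""
    def __init__(self, k: int = 4):
        self.k = k
        self.frames = deque(maxlen=k)

    def reset(self, obs: np.ndarray) -> np.ndarray:
        # At episode start, repeat the first frame k times to fill the stack
        for _ in range(self.k):
            self.frames.append(obs)
        return np.stack(self.frames)

    def step(self, obs: np.ndarray) -> np.ndarray:
        self.frames.append(obs)  # the oldest frame is evicted automatically
        return np.stack(self.frames)

stack = FrameStack(k=4)
first = stack.reset(np.zeros((84, 84)))
nxt = stack.step(np.ones((84, 84)))
print(first.shape, nxt.shape)  # (4, 84, 84) (4, 84, 84)
```

The grid world below uses a plain (row, col) state, so no stacking is needed there.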
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque, namedtuple
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
from typing import List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')
# Set up plotting style
plt.style.use('default')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = [12, 8]
# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
Building the DQN Architecture¶
The neural network at the heart of DQN takes a state representation as input and produces a Q-value estimate for every possible action. For our grid world the input is a simple 2-element vector \((row, col)\), but the same architecture scales to image pixels or sensor readings. Two hidden layers with ReLU activations give the network enough capacity to learn nonlinear relationships between position and value. The output layer has one neuron per action, so a single forward pass yields all Q-values needed to select the greedy action via argmax. This "all actions at once" design is far more efficient than running a separate forward pass per action, and it is the standard pattern used in Atari-playing DQNs.
class DQN(nn.Module):
"""Deep Q-Network for grid world"""
def __init__(self, input_size: int, hidden_size: int, output_size: int):
super(DQN, self).__init__()
self.network = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, output_size)
)
def forward(self, x):
return self.network(x)
# Test the network
input_size = 2 # (row, col) position
hidden_size = 64
output_size = 4 # 4 actions: up, down, left, right
dqn = DQN(input_size, hidden_size, output_size)
print(f"DQN Architecture:\n{dqn}")
# Test forward pass
test_state = torch.tensor([[0, 0]], dtype=torch.float32) # Start position
q_values = dqn(test_state)
print(f"\nQ-values for start state: {q_values.detach().numpy()}")
print(f"Best action: {['up', 'down', 'left', 'right'][torch.argmax(q_values).item()]}")
Experience Replay Buffer¶
Experience replay is one of the two critical innovations that made DQN stable. Without it, consecutive training samples are highly correlated (the agent visits similar states in sequence), which violates the i.i.d. assumption of stochastic gradient descent and causes the network to oscillate or diverge. The replay buffer stores transitions \((s, a, r, s', \text{done})\) in a fixed-capacity deque and samples random mini-batches for each gradient step. This decorrelates the data, smooths out learning, and allows each experience to be reused many times, dramatically improving sample efficiency. In production systems, replay buffers can hold millions of transitions and may use prioritized sampling to focus on the most informative experiences.
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done'])
class ReplayBuffer:
"""Experience replay buffer"""
def __init__(self, capacity: int = 10000):
self.buffer = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
"""Add experience to buffer"""
experience = Experience(state, action, reward, next_state, done)
self.buffer.append(experience)
def sample(self, batch_size: int) -> List[Experience]:
"""Sample batch of experiences"""
return random.sample(self.buffer, min(batch_size, len(self.buffer)))
def __len__(self):
return len(self.buffer)
# Test the replay buffer
buffer = ReplayBuffer(capacity=100)
# Add some experiences
for i in range(10):
state = np.array([i % 4, i // 4])
action = i % 4
reward = -1 if i % 3 != 0 else 10
next_state = np.array([(i+1) % 4, (i+1) // 4])
done = (i == 9)
buffer.push(state, action, reward, next_state, done)
print(f"Buffer size: {len(buffer)}")
# Sample a batch
batch = buffer.sample(5)
print(f"Sampled batch size: {len(batch)}")
print(f"First experience: {batch[0]}")
DQN Agent Implementation¶
The DQNAgent class brings together the three core DQN components: (1) a policy network that is updated every step, (2) a target network whose weights are periodically copied from the policy network to provide stable regression targets, and (3) the replay buffer from above. The training loop follows: select an action with epsilon-greedy, store the transition, sample a mini-batch, compute the loss \(L(\theta) = \mathbb{E}\bigl[(r + \gamma \max_{a'} Q(s', a'; \theta^{-}) - Q(s, a; \theta))^2\bigr]\) where \(\theta^{-}\) are the frozen target-network weights, and take a gradient step on the policy network. This separation of "who selects the target" from "who gets updated" prevents the moving-target problem that destabilizes naive neural Q-learning.
class DQNAgent:
"""Deep Q-Network Agent"""
def __init__(self, env, hidden_size: int = 64, learning_rate: float = 1e-3,
gamma: float = 0.99, epsilon_start: float = 1.0, epsilon_end: float = 0.01,
epsilon_decay: float = 0.995, buffer_size: int = 10000, batch_size: int = 32,
target_update_freq: int = 10):
self.env = env
self.gamma = gamma
self.epsilon = epsilon_start
self.epsilon_end = epsilon_end
self.epsilon_decay = epsilon_decay
self.batch_size = batch_size
self.target_update_freq = target_update_freq
# Networks
input_size = 2 # (row, col)
output_size = len(env.actions)
self.policy_net = DQN(input_size, hidden_size, output_size)
self.target_net = DQN(input_size, hidden_size, output_size)
self.target_net.load_state_dict(self.policy_net.state_dict())
self.target_net.eval()
# Optimizer
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
# Replay buffer
self.replay_buffer = ReplayBuffer(buffer_size)
# Action mapping
self.action_to_idx = {action: idx for idx, action in enumerate(env.actions)}
self.idx_to_action = {idx: action for action, idx in self.action_to_idx.items()}
# Training stats
self.episode_rewards = []
self.episode_lengths = []
self.losses = []
self.training_step = 0
def state_to_tensor(self, state):
"""Convert state to tensor"""
return torch.tensor(state, dtype=torch.float32).unsqueeze(0)
def get_action(self, state, explore: bool = True):
"""Select action using epsilon-greedy policy"""
if explore and random.random() < self.epsilon:
# Explore
action_idx = random.randint(0, len(self.env.actions) - 1)
else:
# Exploit
with torch.no_grad():
state_tensor = self.state_to_tensor(state)
q_values = self.policy_net(state_tensor)
action_idx = torch.argmax(q_values).item()
return self.idx_to_action[action_idx]
def optimize_model(self):
"""Perform one optimization step"""
if len(self.replay_buffer) < self.batch_size:
return
# Sample batch
experiences = self.replay_buffer.sample(self.batch_size)
batch = Experience(*zip(*experiences))
# Convert to tensors
state_batch = torch.tensor(np.array(batch.state), dtype=torch.float32)
action_batch = torch.tensor([self.action_to_idx[a] for a in batch.action], dtype=torch.long)
reward_batch = torch.tensor(batch.reward, dtype=torch.float32)
next_state_batch = torch.tensor(np.array(batch.next_state), dtype=torch.float32)
done_batch = torch.tensor(batch.done, dtype=torch.float32)
# Compute Q(s,a)
q_values = self.policy_net(state_batch).gather(1, action_batch.unsqueeze(1)).squeeze(1)
# Compute target Q-values
with torch.no_grad():
next_q_values = self.target_net(next_state_batch).max(1)[0]
target_q_values = reward_batch + self.gamma * next_q_values * (1 - done_batch)
# Compute loss
loss = F.mse_loss(q_values, target_q_values)
# Optimize
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self.losses.append(loss.item())
return loss.item()
def update_target_network(self):
"""Update target network with policy network weights"""
self.target_net.load_state_dict(self.policy_net.state_dict())
def train_episode(self):
"""Run one training episode"""
state = self.env.start
total_reward = 0.0
steps = 0
while not self.env.is_terminal(state) and steps < 100:
# Select action
action = self.get_action(state, explore=True)
# Take action
next_state = self.env.get_next_state(state, action)
reward = self.env.get_reward(state, action, next_state)
done = self.env.is_terminal(next_state)
# Store experience
self.replay_buffer.push(state, action, reward, next_state, done)
# Optimize model
loss = self.optimize_model()
# Update target network
self.training_step += 1
if self.training_step % self.target_update_freq == 0:
self.update_target_network()
# Update state
state = next_state
total_reward += reward
steps += 1
# Decay epsilon
self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)
return total_reward, steps
def train(self, num_episodes: int = 1000):
"""Train the agent"""
for episode in range(num_episodes):
reward, length = self.train_episode()
self.episode_rewards.append(reward)
self.episode_lengths.append(length)
if (episode + 1) % 100 == 0:
avg_reward = np.mean(self.episode_rewards[-100:])
print(f"Episode {episode+1}/{num_episodes}, Avg Reward: {avg_reward:.2f}, Epsilon: {self.epsilon:.3f}")
def get_policy(self):
"""Extract greedy policy from Q-network"""
policy = {}
for state in self.env.states:
if not self.env.is_terminal(state):
with torch.no_grad():
state_tensor = self.state_to_tensor(state)
q_values = self.policy_net(state_tensor)
action_idx = torch.argmax(q_values).item()
policy[state] = self.idx_to_action[action_idx]
return policy
Training the DQN Agent¶
During training we track four diagnostic signals: raw episode rewards (noisy but informative), smoothed rewards (the moving average reveals the true trend), training loss (TD-error magnitude, which should decrease over time but may spike during exploration), and epsilon decay (shows the scheduled transition from exploration to exploitation). A well-trained DQN typically shows an initial plateau while the replay buffer fills, followed by a steady reward climb as the network begins to generalize, and finally convergence once epsilon is low and the target network is closely aligned with the policy network.
# Create and train DQN agent
dqn_agent = DQNAgent(env, hidden_size=64, learning_rate=1e-3, gamma=0.9,
epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995,
buffer_size=10000, batch_size=32, target_update_freq=10)
print("Training DQN agent...")
dqn_agent.train(num_episodes=1000)
# Plot training progress
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
# Episode rewards
ax1.plot(dqn_agent.episode_rewards, alpha=0.7)
ax1.set_xlabel('Episode')
ax1.set_ylabel('Total Reward')
ax1.set_title('DQN Training: Episode Rewards')
ax1.grid(True, alpha=0.3)
# Moving average rewards
window_size = 50
if len(dqn_agent.episode_rewards) >= window_size:
moving_avg = np.convolve(dqn_agent.episode_rewards, np.ones(window_size)/window_size, mode='valid')
ax2.plot(moving_avg, color='red', linewidth=2)
ax2.set_xlabel('Episode')
ax2.set_ylabel('Average Reward (50 episodes)')
ax2.set_title('Smoothed Learning Progress')
ax2.grid(True, alpha=0.3)
# Training losses
if dqn_agent.losses:
ax3.plot(dqn_agent.losses, alpha=0.7)
ax3.set_xlabel('Training Step')
ax3.set_ylabel('Loss')
ax3.set_title('Training Loss')
ax3.grid(True, alpha=0.3)
# Epsilon decay
epsilons = [1.0]
eps = 1.0
for _ in range(len(dqn_agent.episode_rewards)):
eps = max(0.01, eps * 0.995)
epsilons.append(eps)
ax4.plot(epsilons[:-1], color='green', linewidth=2)
ax4.set_xlabel('Episode')
ax4.set_ylabel('Epsilon')
ax4.set_title('Exploration Rate Decay')
ax4.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print(f"Final average reward (last 100 episodes): {np.mean(dqn_agent.episode_rewards[-100:]):.2f}")
print(f"Final epsilon: {dqn_agent.epsilon:.3f}")
Visualizing DQN Performance¶
Plotting the learned policy on the grid provides an immediate sanity check: arrows should point coherently toward the goal while routing around the obstacle. Unlike the tabular Q-table, the DQN's policy emerges from continuous function approximation, so it can generalize to states that were rarely visited during training. Comparing this policy with the tabular Q-learning result from the previous notebook highlights a key trade-off: tabular methods guarantee convergence to exact optimal values, while DQN trades exactness for the ability to scale to vastly larger state spaces.
def visualize_dqn_policy(agent: DQNAgent):
"""Visualize DQN learned policy"""
policy = agent.get_policy()
fig, ax = plt.subplots(figsize=(8, 8))
# Create empty grid
grid = np.zeros((agent.env.grid_size, agent.env.grid_size))
ax.imshow(grid, cmap='Blues', alpha=0.1)
# Add grid lines
for i in range(agent.env.grid_size + 1):
ax.axhline(i - 0.5, color='black', linewidth=2)
ax.axvline(i - 0.5, color='black', linewidth=2)
# Add policy arrows
action_arrows = {'up': '↑', 'down': '↓', 'left': '←', 'right': '→'}
for i in range(agent.env.grid_size):
for j in range(agent.env.grid_size):
state = (i, j)
if state == agent.env.start:
action = policy.get(state, '')
arrow = action_arrows.get(action, '')
ax.text(j, i, f'START\n{arrow}', ha='center', va='center',
fontsize=10, fontweight='bold')
elif state == agent.env.goal:
ax.text(j, i, 'GOAL', ha='center', va='center', fontsize=12, fontweight='bold')
elif state == agent.env.obstacle:
ax.text(j, i, 'OBSTACLE', ha='center', va='center', fontsize=8, fontweight='bold', color='red')
elif not agent.env.is_terminal(state):
action = policy.get(state, '')
arrow = action_arrows.get(action, '')
ax.text(j, i, arrow, ha='center', va='center', fontsize=20, fontweight='bold')
ax.set_xlim(-0.5, agent.env.grid_size - 0.5)
ax.set_ylim(-0.5, agent.env.grid_size - 0.5)
ax.set_xticks(range(agent.env.grid_size))
ax.set_yticks(range(agent.env.grid_size))
ax.set_title('DQN Learned Policy', fontsize=16)
plt.gca().invert_yaxis()
plt.show()
# Visualize DQN policy
visualize_dqn_policy(dqn_agent)
Understanding DQN Training Dynamics¶
Inspecting raw Q-value predictions for specific states reveals whether the network has learned meaningful structure. At the start state, the best actionโs Q-value should be noticeably higher than alternatives; near the goal, Q-values should be large and positive; near the obstacle, actions leading into it should carry low or negative values. We also trace the greedy path from start to goal โ if the path reaches the goal in a reasonable number of steps without looping, the training was successful. When debugging DQN in practice, this kind of per-state Q-value analysis is often more informative than aggregate reward curves alone.
def analyze_dqn_q_values(agent: DQNAgent):
"""Analyze Q-values learned by DQN"""
print("DQN Q-Value Analysis:")
print("=" * 50)
# Get Q-values for key states
key_states = [agent.env.start, agent.env.goal, agent.env.obstacle, (1, 0), (2, 2)]
state_names = ['Start', 'Goal', 'Obstacle', 'Near Start', 'Center']
for state, name in zip(key_states, state_names):
if agent.env.is_terminal(state) and state != agent.env.start:
print(f"\n{name} State {state}: Terminal (Q=0)")
continue
with torch.no_grad():
state_tensor = agent.state_to_tensor(state)
q_values = agent.policy_net(state_tensor).squeeze(0).numpy()
action_names = agent.env.actions
best_action_idx = np.argmax(q_values)
best_action = action_names[best_action_idx]
print(f"\n{name} State {state}:")
for action, q_val in zip(action_names, q_values):
marker = " โ BEST" if action == best_action else ""
print(f" {action}: {q_val:.3f}{marker}")
# Check if policy makes sense
print("\nPolicy Analysis:")
policy = agent.get_policy()
# Check path from start to goal
current = agent.env.start
path = [current]
visited = set()
while not agent.env.is_terminal(current) and current not in visited and len(path) < 10:
visited.add(current)
action = policy.get(current)
if action:
next_state = agent.env.get_next_state(current, action)
path.append(next_state)
current = next_state
else:
break
print(f"Learned path from start: {path}")
if path[-1] == agent.env.goal:
print("โ Policy successfully reaches goal!")
else:
print("โ Policy does not reach goal")
# Analyze DQN learning
analyze_dqn_q_values(dqn_agent)
Advanced DQN Techniques¶
Double DQN¶
Addresses overestimation bias in Q-learning
Uses policy network to select action, target network to evaluate
Reduces overestimation of Q-values
Dueling DQN¶
Separates value and advantage functions
Architecture: State → Value Stream + Advantage Stream → Q-values
Better value function learning
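The dueling decomposition Q(s,a) = V(s) + A(s,a) - mean_a A(s,a) can be sketched for this notebook's 2-input, 4-action grid world (layer sizes here are illustrative assumptions, not tuned values):

```python
import torch
import torch.nn as nn

class DuelingDQN(nn.Module):
    """Dueling head: Q(s,a) = V(s) + A(s,a) - mean_a A(s,a)."""
    def __init__(self, input_size: int = 2, hidden_size: int = 64, output_size: int = 4):
        super().__init__()
        self.feature = nn.Sequential(nn.Linear(input_size, hidden_size), nn.ReLU())
        self.value = nn.Linear(hidden_size, 1)                # V(s), one scalar per state
        self.advantage = nn.Linear(hidden_size, output_size)  # A(s, a), one per action

    def forward(self, x):
        h = self.feature(x)
        v = self.value(h)      # shape (batch, 1), broadcasts over actions
        a = self.advantage(h)  # shape (batch, actions)
        # Subtracting the mean advantage makes V and A identifiable
        return v + a - a.mean(dim=1, keepdim=True)

net = DuelingDQN()
q = net(torch.zeros(1, 2))
print(q.shape)  # torch.Size([1, 4])
```

Because this class has the same forward signature as the DQN above, it could be dropped into DQNAgent unchanged.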
Prioritized Experience Replay¶
Samples important experiences more frequently
Uses TD-error as priority metric
Improves sample efficiency
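A list-based sketch of proportional prioritization, kept deliberately simple for clarity (real implementations use a sum-tree for O(log n) sampling, and add importance-sampling weights to correct the bias):

```python
import numpy as np

class SimplePrioritizedBuffer:
    """Proportional prioritized replay: P(i) proportional to (|TD error| + eps)^alpha."""
    def __init__(self, capacity: int = 10000, alpha: float = 0.6):
        self.capacity, self.alpha = capacity, alpha
        self.data, self.priorities = [], []

    def push(self, transition, td_error: float = 1.0):
        if len(self.data) >= self.capacity:
            self.data.pop(0)
            self.priorities.pop(0)
        self.data.append(transition)
        self.priorities.append((abs(td_error) + 1e-6) ** self.alpha)

    def sample(self, batch_size: int):
        p = np.array(self.priorities)
        p = p / p.sum()
        idx = np.random.choice(len(self.data), size=batch_size, p=p)
        return [self.data[i] for i in idx], idx

    def update_priorities(self, idx, td_errors):
        # Called after each gradient step with the freshly computed TD errors
        for i, e in zip(idx, td_errors):
            self.priorities[i] = (abs(e) + 1e-6) ** self.alpha

buf = SimplePrioritizedBuffer()
for t in range(5):
    buf.push(("transition", t), td_error=t + 1)  # larger TD error => sampled more often
batch, idx = buf.sample(3)
```

Swapping this in for ReplayBuffer would also require updating priorities inside optimize_model, which is left as an exercise below.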
Noisy Networks¶
Adds noise to network parameters instead of ε-greedy
Better exploration in complex environments
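A sketch of a noisy linear layer, using independent Gaussian noise for simplicity (the NoisyNet paper uses factorized noise, and the initialization constants here are assumptions):

```python
import math
import torch
import torch.nn as nn

class NoisyLinear(nn.Module):
    """Linear layer with learnable per-weight noise scales; noise replaces epsilon-greedy."""
    def __init__(self, in_features: int, out_features: int, sigma0: float = 0.5):
        super().__init__()
        bound = 1.0 / math.sqrt(in_features)
        self.mu_w = nn.Parameter(torch.empty(out_features, in_features).uniform_(-bound, bound))
        self.sigma_w = nn.Parameter(torch.full((out_features, in_features), sigma0 * bound))
        self.mu_b = nn.Parameter(torch.zeros(out_features))
        self.sigma_b = nn.Parameter(torch.full((out_features,), sigma0 * bound))

    def forward(self, x):
        if self.training:  # fresh noise on every forward pass drives exploration
            w = self.mu_w + self.sigma_w * torch.randn_like(self.sigma_w)
            b = self.mu_b + self.sigma_b * torch.randn_like(self.sigma_b)
        else:              # deterministic (mean weights) at evaluation time
            w, b = self.mu_w, self.mu_b
        return torch.nn.functional.linear(x, w, b)

layer = NoisyLinear(2, 4)
out1 = layer(torch.zeros(1, 2))
out2 = layer(torch.zeros(1, 2))  # differs from out1 in training mode (noise resampled)
```

Because the sigma parameters are learned, the network can anneal its own exploration where the environment no longer rewards it.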
Implementing Double DQN¶
Standard DQN tends to overestimate Q-values because the same network both selects and evaluates the best next action, creating an upward bias whenever estimation noise is present. Double DQN decouples these two steps: the policy network selects the best action \(a^* = \arg\max_{a'} Q(s', a'; \theta)\), but the target network evaluates it \(Q(s', a^*; \theta^{-})\). This simple change, selecting with one network and evaluating with another, significantly reduces overestimation and often leads to more stable training and better final policies. The modification requires only a few lines of code in the optimize_model method but can make a meaningful difference on harder environments.
class DoubleDQNAgent(DQNAgent):
"""Double DQN Agent"""
def optimize_model(self):
"""Double DQN optimization step"""
if len(self.replay_buffer) < self.batch_size:
return
# Sample batch
experiences = self.replay_buffer.sample(self.batch_size)
batch = Experience(*zip(*experiences))
# Convert to tensors
state_batch = torch.tensor(np.array(batch.state), dtype=torch.float32)
action_batch = torch.tensor([self.action_to_idx[a] for a in batch.action], dtype=torch.long)
reward_batch = torch.tensor(batch.reward, dtype=torch.float32)
next_state_batch = torch.tensor(np.array(batch.next_state), dtype=torch.float32)
done_batch = torch.tensor(batch.done, dtype=torch.float32)
# Compute Q(s,a) using policy network
q_values = self.policy_net(state_batch).gather(1, action_batch.unsqueeze(1)).squeeze(1)
# Double DQN: Use policy network to select action, target network to evaluate
with torch.no_grad():
# Select best actions using policy network
next_actions = self.policy_net(next_state_batch).argmax(1).unsqueeze(1)
# Evaluate using target network
next_q_values = self.target_net(next_state_batch).gather(1, next_actions).squeeze(1)
target_q_values = reward_batch + self.gamma * next_q_values * (1 - done_batch)
# Compute loss
loss = F.mse_loss(q_values, target_q_values)
# Optimize
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self.losses.append(loss.item())
return loss.item()
# Compare DQN vs Double DQN
print("Comparing DQN vs Double DQN...")
# Train Double DQN
double_dqn_agent = DoubleDQNAgent(env, hidden_size=64, learning_rate=1e-3, gamma=0.9,
epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995,
buffer_size=10000, batch_size=32, target_update_freq=10)
double_dqn_agent.train(num_episodes=1000)
# Plot comparison
plt.figure(figsize=(12, 6))
plt.plot(dqn_agent.episode_rewards, label=f'DQN (final: {np.mean(dqn_agent.episode_rewards[-100:]):.1f})', alpha=0.7)
plt.plot(double_dqn_agent.episode_rewards, label=f'Double DQN (final: {np.mean(double_dqn_agent.episode_rewards[-100:]):.1f})', alpha=0.7)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('DQN vs Double DQN Comparison')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
Key Takeaways¶
Function approximation enables scaling: Neural networks can handle complex state spaces
Experience replay breaks correlation: Random sampling improves learning stability
Target networks reduce oscillations: Separate networks for targets stabilize training
Double DQN reduces overestimation: Better action selection reduces Q-value bias
Hyperparameter tuning is crucial: Learning rate, batch size, and network architecture matter
What's Next?¶
Now that you understand DQN, you're ready for:
Policy Gradient Methods: Learning policies directly (REINFORCE, PPO)
Actor-Critic Methods: Combining value and policy learning (A2C, A3C)
Advanced Environments: Atari games, continuous control, multi-agent systems
Real-world Applications: Robotics, game playing, recommendation systems
Further Reading¶
Playing Atari with Deep Reinforcement Learning - Original DQN paper
Deep Reinforcement Learning with Double Q-learning - Double DQN
Dueling Network Architectures for Deep Reinforcement Learning - Dueling DQN
Exercises¶
Implement Dueling DQN from scratch
Add Prioritized Experience Replay to your DQN
Experiment with different network architectures (CNN for images)
Solve a Gymnasium environment (CartPole, LunarLander)
Implement Rainbow DQN (combines multiple improvements)
Discussion Questions¶
Why do we need experience replay and target networks?
How does Double DQN address the overestimation problem?
What are the limitations of value-based methods?
How might you apply DQN to real-world problems?
Practical Tips¶
Start simple: Use small networks and simple environments first
Monitor training: Track rewards, losses, and Q-values
Tune hyperparameters: Learning rate, batch size, and epsilon decay are critical
Use GPU: Deep networks train faster on GPU
Save models: Checkpoint your trained models for later use
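The checkpointing tip can be sketched in a few lines; the filename and the choice of what to store are assumptions to adapt to your project:

```python
import torch
import torch.nn as nn

# Build a network matching the notebook's DQN shape and save a checkpoint
net = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 4))
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
torch.save({"model": net.state_dict(), "optim": opt.state_dict()}, "dqn_checkpoint.pt")

# Later: rebuild the same architecture and restore the weights
restored = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 4))
ckpt = torch.load("dqn_checkpoint.pt")
restored.load_state_dict(ckpt["model"])
```

Saving the optimizer state alongside the weights lets you resume training exactly where you left off, not just run inference.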