04: Policy-Based Methods (REINFORCE)

“The best way to predict the future is to invent it.” - Alan Kay

Welcome to the world of policy-based reinforcement learning! While value-based methods like Q-learning learn to estimate how good actions are, policy-based methods learn policies directly - they learn what actions to take in different situations.

🎯 Learning Objectives

By the end of this notebook, you'll understand:

  • The difference between value-based and policy-based methods

  • How policy gradients work and why theyโ€™re useful

  • The REINFORCE algorithm and its variants

  • Advantages of policy-based methods for continuous action spaces

  • Practical implementation and training tips

🧠 Value-Based vs Policy-Based Learning

Value-Based Methods (Q-Learning, DQN)

  • Learn value functions: How good are different states/actions?

  • Derive policy from values: π(s) = argmax_a Q(s,a)

  • Pros: Sample efficient, stable, good for discrete actions

  • Cons: Hard with continuous actions, can be unstable

Policy-Based Methods (REINFORCE, PPO)

  • Learn policies directly: What action to take in each state?

  • Parameterize policy: π_θ(a|s) - probability of action given state

  • Pros: Work with continuous actions, can learn stochastic policies

  • Cons: Usually less sample efficient, can have high variance
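The contrast is easy to make concrete. Below, the Q-values and logits are made-up numbers standing in for learned ones: a value-based agent acts greedily on Q, while a policy-based agent samples from an explicit distribution.

```python
import torch
from torch.distributions import Categorical

# Value-based: the policy is implicit -- act greedily w.r.t. Q(s, a)
q_values = torch.tensor([0.1, 2.0, -0.5, 0.3])   # hypothetical Q(s, ·)
greedy_action = q_values.argmax().item()          # always action 1

# Policy-based: the policy is explicit -- a distribution we sample from
logits = torch.tensor([0.1, 2.0, -0.5, 0.3])      # hypothetical policy logits
dist = Categorical(logits=logits)
sampled_action = dist.sample().item()             # usually 1, but not always
```

Sampling gives exploration for free; the greedy policy needs an external mechanism like epsilon-greedy.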

📈 Policy Gradients

The Objective

We want to maximize the expected return: J(θ) = E[∑_t γ^t r_t | π_θ]

Policy Gradient Theorem

∇_θ J(θ) = E[∑_t ∇_θ log π_θ(a_t|s_t) * Q^π(s_t,a_t)]

This tells us how to update policy parameters to improve expected returns!

Intuition

  • If an action led to high returns, increase probability of that action

  • If an action led to low returns, decrease probability of that action

  • The gradient shows us the direction to adjust parameters
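The theorem can be sanity-checked numerically on a tiny made-up problem: for a two-action softmax policy, the Monte Carlo score-function estimate E[f(a) ∇_θ log π_θ(a)] should match the exact gradient of J(θ) = ∑_a π_θ(a) f(a).

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)
theta = torch.zeros(2, requires_grad=True)   # logits of a 2-action policy
f = torch.tensor([1.0, 3.0])                 # made-up per-action returns

# Exact gradient of J(theta) = sum_a pi_theta(a) * f(a), via autograd
J = (torch.softmax(theta, dim=0) * f).sum()
exact_grad = torch.autograd.grad(J, theta)[0]

# Score-function (REINFORCE) estimate: mean of f(a) * grad log pi_theta(a)
dist = Categorical(logits=theta)
actions = dist.sample((50_000,))
surrogate = (f[actions] * dist.log_prob(actions)).mean()
mc_grad = torch.autograd.grad(surrogate, theta)[0]

print(exact_grad)  # tensor([-0.5000,  0.5000])
print(mc_grad)     # close to the exact gradient, up to sampling noise
```

The estimate is unbiased but noisy; that noise is exactly the high variance REINFORCE is known for.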

🎮 REINFORCE Algorithm

The Basic Idea

  1. Generate episode using current policy π_θ

  2. For each step, compute: ∇_θ log π_θ(a_t|s_t) * G_t

  3. Update policy: θ ← θ + α * ∇_θ log π_θ(a_t|s_t) * G_t

Where G_t is the return from time step t.

Key Components

  • Policy Network: Parameterized policy π_θ(a|s)

  • Returns: G_t = ∑_{k=t}^T γ^{k-t} r_k (discounted sum of future rewards)

  • Log Probability: log π_θ(a_t|s_t) - log likelihood of chosen action

  • Gradient: ∇_θ [log π_θ(a_t|s_t) * G_t]
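The returns G_t can be computed with a single backward pass over the rewards, which is worth checking by hand on a three-step episode with γ = 0.5:

```python
def discounted_returns(rewards, gamma):
    """G_t = sum_k gamma^k * r_{t+k}, computed in one backward pass (O(T))."""
    returns = []
    G = 0.0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.insert(0, G)
    return returns

# Three-step episode, gamma = 0.5:
# G_2 = 1, G_1 = 0 + 0.5 * 1 = 0.5, G_0 = 0 + 0.5 * 0.5 = 0.25
print(discounted_returns([0.0, 0.0, 1.0], gamma=0.5))  # → [0.25, 0.5, 1.0]
```

This is the same recursion the agent's compute_returns method uses below.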

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
from typing import List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

# Set up plotting style
plt.style.use('default')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = [12, 8]

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

Policy Network

Unlike value-based methods where the network outputs Q-values, a policy network outputs a probability distribution over actions. The final Softmax layer ensures outputs are non-negative and sum to one, making them valid probabilities. To select an action, we sample from this distribution using PyTorch's Categorical class, which also computes the log-probability \(\log \pi_\theta(a|s)\) needed for the policy gradient. Stochastic sampling is essential here – it provides built-in exploration (unlike the epsilon-greedy hack in DQN) and allows the agent to naturally represent uncertainty about which action is best. This probabilistic formulation also extends seamlessly to continuous action spaces by replacing Softmax + Categorical with a Gaussian distribution parameterized by learned mean and standard deviation.

class PolicyNetwork(nn.Module):
    """Neural network policy for grid world"""
    
    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super(PolicyNetwork, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            nn.Softmax(dim=-1)  # Convert to probabilities
        )
    
    def forward(self, x):
        return self.network(x)
    
    def get_action(self, state):
        """Sample action from policy"""
        probs = self.forward(state)
        dist = Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action)

# Test the policy network
input_size = 2  # (row, col) position
hidden_size = 64
output_size = 4  # 4 actions: up, down, left, right

policy_net = PolicyNetwork(input_size, hidden_size, output_size)
print(f"Policy Network:\n{policy_net}")

# Test forward pass
test_state = torch.tensor([[0, 0]], dtype=torch.float32)  # Start position
action_probs = policy_net(test_state)
print(f"\nAction probabilities for start state: {action_probs.detach().numpy()}")

# Sample action
action, log_prob = policy_net.get_action(test_state)
action_names = ['up', 'down', 'left', 'right']
print(f"Sampled action: {action_names[action]} (log prob: {log_prob.item():.3f})")
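The continuous-action variant mentioned above can be sketched in the same style. This GaussianPolicy class is illustrative only (it is not used elsewhere in this notebook); the state-independent learned log-std is one common, simple parameterization.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

class GaussianPolicy(nn.Module):
    """Continuous-action policy: the network outputs the Gaussian mean;
    the log standard deviation is a free learned parameter."""

    def __init__(self, input_size: int, hidden_size: int, action_dim: int):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(input_size, hidden_size), nn.ReLU())
        self.mean_head = nn.Linear(hidden_size, action_dim)
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def get_action(self, state):
        mean = self.mean_head(self.body(state))
        dist = Normal(mean, self.log_std.exp())
        action = dist.sample()
        # Sum log-probs over action dimensions (independent Gaussians)
        return action, dist.log_prob(action).sum(dim=-1)

policy = GaussianPolicy(input_size=3, hidden_size=32, action_dim=2)
a, logp = policy.get_action(torch.zeros(1, 3))
```

Everything downstream (log-prob times return) is unchanged; only the distribution differs.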

REINFORCE Agent Implementation

The REINFORCE algorithm is the simplest policy gradient method. After collecting a complete episode trajectory, it computes the discounted return \(G_t = \sum_{k=0}^{T-t} \gamma^k r_{t+k}\) for each time step, then updates the policy parameters by ascending the gradient:

\[\theta \leftarrow \theta + \alpha \sum_t \nabla_\theta \log \pi_\theta(a_t | s_t) \, G_t\]

The key implementation detail is return normalization – subtracting the mean and dividing by the standard deviation of the returns within an episode. Subtracting the mean acts as a simple baseline, which reduces variance without biasing the gradient; dividing by the standard deviation is a further rescaling heuristic that keeps update magnitudes consistent across episodes. Without normalization, REINFORCE can be extremely noisy because the raw returns may vary by orders of magnitude across episodes. The compute_returns method walks backward through the reward sequence, accumulating discounted sums efficiently in \(O(T)\) time.

class REINFORCEAgent:
    """REINFORCE Policy Gradient Agent"""
    
    def __init__(self, env, hidden_size: int = 64, learning_rate: float = 1e-3, gamma: float = 0.99):
        self.env = env
        self.gamma = gamma
        
        # Policy network
        input_size = 2  # (row, col)
        output_size = len(env.actions)
        self.policy_net = PolicyNetwork(input_size, hidden_size, output_size)
        
        # Optimizer
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
        
        # Action mapping
        self.action_to_idx = {action: idx for idx, action in enumerate(env.actions)}
        self.idx_to_action = {idx: action for action, idx in self.action_to_idx.items()}
        
        # Training stats
        self.episode_rewards = []
        self.episode_lengths = []
        self.losses = []
    
    def state_to_tensor(self, state):
        """Convert state to tensor"""
        return torch.tensor(state, dtype=torch.float32).unsqueeze(0)
    
    def get_action(self, state):
        """Get action from policy"""
        state_tensor = self.state_to_tensor(state)
        action_idx, log_prob = self.policy_net.get_action(state_tensor)
        return self.idx_to_action[action_idx], log_prob
    
    def compute_returns(self, rewards: List[float]) -> List[float]:
        """Compute discounted returns for an episode"""
        returns = []
        G = 0
        for reward in reversed(rewards):
            G = reward + self.gamma * G
            returns.insert(0, G)
        return returns
    
    def train_episode(self):
        """Run one training episode and update policy"""
        # Generate episode
        states = []
        actions = []
        rewards = []
        log_probs = []
        
        state = self.env.start
        done = False
        steps = 0
        
        while not done and steps < 100:
            # Get action from policy
            action, log_prob = self.get_action(state)
            
            # Take action
            next_state = self.env.get_next_state(state, action)
            reward = self.env.get_reward(state, action, next_state)
            done = self.env.is_terminal(next_state)
            
            # Store experience
            states.append(state)
            actions.append(self.action_to_idx[action])
            rewards.append(reward)
            log_probs.append(log_prob)
            
            state = next_state
            steps += 1
        
        # Compute returns
        returns = self.compute_returns(rewards)
        returns = torch.tensor(returns, dtype=torch.float32)
        
        # Normalize returns (optional, helps with training stability)
        if len(returns) > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)
        
        # Compute policy loss
        policy_loss = []
        for log_prob, G in zip(log_probs, returns):
            policy_loss.append(-log_prob * G)  # Negative because we want to maximize
        
        policy_loss = torch.stack(policy_loss).sum()
        
        # Update policy
        self.optimizer.zero_grad()
        policy_loss.backward()
        self.optimizer.step()
        
        total_reward = sum(rewards)
        self.episode_rewards.append(total_reward)
        self.episode_lengths.append(steps)
        self.losses.append(policy_loss.item())
        
        return total_reward, steps
    
    def train(self, num_episodes: int = 1000):
        """Train the agent"""
        for episode in range(num_episodes):
            reward, length = self.train_episode()
            
            if (episode + 1) % 100 == 0:
                avg_reward = np.mean(self.episode_rewards[-100:])
                print(f"Episode {episode+1}/{num_episodes}, Avg Reward: {avg_reward:.2f}")
    
    def get_policy(self):
        """Extract deterministic policy from stochastic policy network"""
        policy = {}
        for state in self.env.states:
            if not self.env.is_terminal(state):
                with torch.no_grad():
                    state_tensor = self.state_to_tensor(state)
                    action_probs = self.policy_net(state_tensor).squeeze(0)
                    action_idx = torch.argmax(action_probs).item()
                    policy[state] = self.idx_to_action[action_idx]
        return policy
    
    def get_action_probabilities(self, state):
        """Get action probabilities for a state"""
        with torch.no_grad():
            state_tensor = self.state_to_tensor(state)
            probs = self.policy_net(state_tensor).squeeze(0).numpy()
        return dict(zip(self.env.actions, probs))

Training the REINFORCE Agent

Because REINFORCE is an on-policy algorithm, it can only learn from trajectories generated by the current policy – old experience cannot be replayed as in DQN. This makes it less sample-efficient but keeps the gradient estimates unbiased. The training curves to watch are: episode rewards (expect high variance, especially early on), episode lengths (shorter episodes generally indicate the agent found a more direct path to the goal), and the policy loss (noisy and only loosely interpretable – with normalized returns it hovers around zero even as the policy improves). Comparing these curves to the DQN training from the previous notebook illustrates why policy gradients often need more episodes to converge but can handle problems DQN cannot, such as continuous or high-dimensional action spaces.

# Create and train REINFORCE agent
reinforce_agent = REINFORCEAgent(env, hidden_size=64, learning_rate=1e-3, gamma=0.9)

print("Training REINFORCE agent...")
reinforce_agent.train(num_episodes=1000)

# Plot training progress
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

# Episode rewards
ax1.plot(reinforce_agent.episode_rewards, alpha=0.7)
ax1.set_xlabel('Episode')
ax1.set_ylabel('Total Reward')
ax1.set_title('REINFORCE Training: Episode Rewards')
ax1.grid(True, alpha=0.3)

# Moving average rewards
window_size = 50
if len(reinforce_agent.episode_rewards) >= window_size:
    moving_avg = np.convolve(reinforce_agent.episode_rewards, np.ones(window_size)/window_size, mode='valid')
    ax2.plot(moving_avg, color='red', linewidth=2)
ax2.set_xlabel('Episode')
ax2.set_ylabel('Average Reward (50 episodes)')
ax2.set_title('Smoothed Learning Progress')
ax2.grid(True, alpha=0.3)

# Training losses
if reinforce_agent.losses:
    ax3.plot(reinforce_agent.losses, alpha=0.7)
ax3.set_xlabel('Episode')
ax3.set_ylabel('Policy Loss')
ax3.set_title('Policy Loss')
ax3.grid(True, alpha=0.3)

# Episode lengths
ax4.plot(reinforce_agent.episode_lengths, alpha=0.7, color='green')
ax4.set_xlabel('Episode')
ax4.set_ylabel('Episode Length')
ax4.set_title('Episode Lengths')
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Final average reward (last 100 episodes): {np.mean(reinforce_agent.episode_rewards[-100:]):.2f}")
print(f"Final average episode length: {np.mean(reinforce_agent.episode_lengths[-100:]):.2f}")

Analyzing Policy Learning

Examining the learned action probabilities – rather than just the argmax policy – gives a richer picture of what REINFORCE has learned. A well-trained stochastic policy assigns high probability to good actions while maintaining small but nonzero probability on alternatives, reflecting residual uncertainty. Near the goal, probabilities should be sharply peaked; farther away, the distribution may be flatter. This probabilistic view also helps diagnose training issues: if probabilities are nearly uniform everywhere, the policy has not learned; if they are overly concentrated too early (entropy collapse), the agent may be stuck in a suboptimal policy because it stopped exploring prematurely.
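Entropy makes the "uniform vs collapsed" diagnosis quantitative: a categorical distribution over 4 actions has at most log(4) ≈ 1.386 nats of entropy, and values near zero signal collapse. A small self-contained check (the probability vectors here are illustrative, not taken from the trained agent):

```python
import math
import torch
from torch.distributions import Categorical

def policy_entropy(probs: torch.Tensor) -> float:
    """Entropy (in nats) of a categorical action distribution."""
    return Categorical(probs=probs).entropy().item()

uniform = torch.tensor([0.25, 0.25, 0.25, 0.25])  # untrained-looking policy
peaked = torch.tensor([0.97, 0.01, 0.01, 0.01])   # near entropy collapse

print(f"max possible: {math.log(4):.3f}")              # ≈ 1.386
print(f"uniform:      {policy_entropy(uniform):.3f}")  # ≈ 1.386
print(f"peaked:       {policy_entropy(peaked):.3f}")   # ≈ 0.168
```

Tracking the mean entropy across states over training is a cheap way to spot premature convergence.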

def analyze_reinforce_policy(agent: REINFORCEAgent):
    """Analyze the learned policy"""
    
    print("REINFORCE Policy Analysis:")
    print("=" * 50)
    
    # Get action probabilities for key states
    key_states = [agent.env.start, agent.env.goal, agent.env.obstacle, (1, 0), (2, 2)]
    state_names = ['Start', 'Goal', 'Obstacle', 'Near Start', 'Center']
    
    for state, name in zip(key_states, state_names):
        if agent.env.is_terminal(state) and state != agent.env.start:
            print(f"\n{name} State {state}: Terminal")
            continue
        
        probs = agent.get_action_probabilities(state)
        best_action = max(probs, key=probs.get)
        
        print(f"\n{name} State {state}:")
        for action, prob in probs.items():
            marker = " ← BEST" if action == best_action else ""
            print(f"  {action}: {prob:.3f}{marker}")
    
    # Visualize policy
    policy = agent.get_policy()
    
    fig, ax = plt.subplots(figsize=(8, 8))
    
    # Create empty grid
    grid = np.zeros((agent.env.grid_size, agent.env.grid_size))
    ax.imshow(grid, cmap='Blues', alpha=0.1)
    
    # Add grid lines
    for i in range(agent.env.grid_size + 1):
        ax.axhline(i - 0.5, color='black', linewidth=2)
        ax.axvline(i - 0.5, color='black', linewidth=2)
    
    # Add policy arrows
    action_arrows = {'up': '↑', 'down': '↓', 'left': '←', 'right': '→'}
    
    for i in range(agent.env.grid_size):
        for j in range(agent.env.grid_size):
            state = (i, j)
            
            if state == agent.env.start:
                action = policy.get(state, '')
                arrow = action_arrows.get(action, '')
                ax.text(j, i, f'START\n{arrow}', ha='center', va='center', 
                       fontsize=10, fontweight='bold')
            elif state == agent.env.goal:
                ax.text(j, i, 'GOAL', ha='center', va='center', fontsize=12, fontweight='bold')
            elif state == agent.env.obstacle:
                ax.text(j, i, 'OBSTACLE', ha='center', va='center', fontsize=8, fontweight='bold', color='red')
            elif not agent.env.is_terminal(state):
                action = policy.get(state, '')
                arrow = action_arrows.get(action, '')
                ax.text(j, i, arrow, ha='center', va='center', fontsize=20, fontweight='bold')
    
    ax.set_xlim(-0.5, agent.env.grid_size - 0.5)
    ax.set_ylim(-0.5, agent.env.grid_size - 0.5)
    ax.set_xticks(range(agent.env.grid_size))
    ax.set_yticks(range(agent.env.grid_size))
    ax.set_title('REINFORCE Learned Policy', fontsize=16)
    plt.gca().invert_yaxis()
    plt.show()

# Analyze the learned policy
analyze_reinforce_policy(reinforce_agent)

🎯 Actor-Critic Methods

The Problem with REINFORCE

  • High variance: Uses full episode returns

  • Inefficient: Waits until episode end to learn

  • Slow convergence: Requires many samples

Actor-Critic Solution

  • Actor: Policy network (chooses actions)

  • Critic: Value network (estimates state values)

  • Advantage: A(s,a) = Q(s,a) - V(s) - reduces variance

  • TD Learning: Learn from partial episodes

Advantage Actor-Critic (A2C)

  • Actor update: ∇_θ J(θ) ∝ ∇_θ log π_θ(a|s) * A(s,a)

  • Critic update: Minimize ||V_φ(s) - (r + γ V_φ(s'))||²
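A quick worked instance of the one-step TD advantage, using made-up values (γ = 0.9, reward 1, critic estimates V(s) = 2.0 and V(s') = 2.5):

```python
gamma = 0.9
r, v_s, v_next = 1.0, 2.0, 2.5  # made-up transition: reward and critic values

# One-step TD advantage: A(s, a) ≈ r + gamma * V(s') - V(s)
advantage = r + gamma * v_next - v_s
print(advantage)  # → 1.25: positive, so this action beat the critic's estimate
```

A positive advantage increases the action's probability; a negative one decreases it, regardless of the absolute scale of the returns.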

class ActorCriticNetwork(nn.Module):
    """Combined Actor-Critic network"""
    
    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super(ActorCriticNetwork, self).__init__()
        
        # Shared layers
        self.shared = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU()
        )
        
        # Actor head (policy)
        self.actor = nn.Sequential(
            nn.Linear(hidden_size, output_size),
            nn.Softmax(dim=-1)
        )
        
        # Critic head (value)
        self.critic = nn.Linear(hidden_size, 1)
    
    def forward(self, x):
        shared_features = self.shared(x)
        policy = self.actor(shared_features)
        value = self.critic(shared_features)
        return policy, value
    
    def get_action(self, state):
        """Sample action from policy"""
        policy, _ = self.forward(state)
        dist = Categorical(policy)
        action = dist.sample()
        return action.item(), dist.log_prob(action)
    
    def get_value(self, state):
        """Get state value"""
        _, value = self.forward(state)
        return value.item()

class A2CAgent:
    """Advantage Actor-Critic Agent"""
    
    def __init__(self, env, hidden_size: int = 64, learning_rate: float = 1e-3, gamma: float = 0.99):
        self.env = env
        self.gamma = gamma
        
        # Actor-Critic network
        input_size = 2
        output_size = len(env.actions)
        self.network = ActorCriticNetwork(input_size, hidden_size, output_size)
        
        # Optimizer
        self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate)
        
        # Action mapping
        self.action_to_idx = {action: idx for idx, action in enumerate(env.actions)}
        self.idx_to_action = {idx: action for action, idx in self.action_to_idx.items()}
        
        # Training stats
        self.episode_rewards = []
        self.episode_lengths = []
        self.actor_losses = []
        self.critic_losses = []
    
    def state_to_tensor(self, state):
        return torch.tensor(state, dtype=torch.float32).unsqueeze(0)
    
    def train_episode(self):
        """Run one A2C training episode"""
        states = []
        actions = []
        rewards = []
        log_probs = []
        values = []
        
        state = self.env.start
        done = False
        steps = 0
        
        # Generate episode
        while not done and steps < 100:
            state_tensor = self.state_to_tensor(state)
            
            # Get action and value in one forward pass; keep the value
            # as a tensor so gradients can flow into the critic head
            probs, value = self.network(state_tensor)
            dist = Categorical(probs)
            action_tensor = dist.sample()
            action_idx = action_tensor.item()
            log_prob = dist.log_prob(action_tensor)
            action = self.idx_to_action[action_idx]
            
            # Take action
            next_state = self.env.get_next_state(state, action)
            reward = self.env.get_reward(state, action, next_state)
            done = self.env.is_terminal(next_state)
            
            # Store experience
            states.append(state)
            actions.append(action_idx)
            rewards.append(reward)
            log_probs.append(log_prob)
            values.append(value.squeeze())
            
            state = next_state
            steps += 1
        
        # Compute returns and one-step TD advantages
        returns = []
        advantages = []
        G = 0
        next_value = 0.0  # Terminal state value is 0
        
        for t in reversed(range(len(rewards))):
            # Monte Carlo return (critic regression target)
            G = rewards[t] + self.gamma * G
            returns.insert(0, G)
            
            # TD advantage: r_t + gamma * V(s_{t+1}) - V(s_t), detached via .item()
            if t == len(rewards) - 1:
                advantage = rewards[t] + self.gamma * next_value - values[t].item()
            else:
                advantage = rewards[t] + self.gamma * values[t + 1].item() - values[t].item()
            advantages.insert(0, advantage)
        
        # Convert to tensors; values keeps its graph so the critic learns
        returns = torch.tensor(returns, dtype=torch.float32)
        advantages = torch.tensor(advantages, dtype=torch.float32)
        log_probs = torch.cat(log_probs)   # shape (T,) to match advantages
        values = torch.stack(values)
        
        # Compute losses (advantages are already detached floats)
        actor_loss = -(log_probs * advantages).mean()
        critic_loss = F.mse_loss(values, returns)
        total_loss = actor_loss + critic_loss
        
        # Update network
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()
        
        total_reward = sum(rewards)
        self.episode_rewards.append(total_reward)
        self.episode_lengths.append(steps)
        self.actor_losses.append(actor_loss.item())
        self.critic_losses.append(critic_loss.item())
        
        return total_reward, steps
    
    def train(self, num_episodes: int = 1000):
        """Train the agent"""
        for episode in range(num_episodes):
            reward, length = self.train_episode()
            
            if (episode + 1) % 100 == 0:
                avg_reward = np.mean(self.episode_rewards[-100:])
                print(f"Episode {episode+1}/{num_episodes}, Avg Reward: {avg_reward:.2f}")
    
    def get_policy(self):
        """Extract greedy policy"""
        policy = {}
        for state in self.env.states:
            if not self.env.is_terminal(state):
                with torch.no_grad():
                    state_tensor = self.state_to_tensor(state)
                    policy_probs, _ = self.network(state_tensor)
                    action_idx = torch.argmax(policy_probs).item()
                    policy[state] = self.idx_to_action[action_idx]
        return policy

# Compare REINFORCE vs A2C
print("Comparing REINFORCE vs A2C...")

a2c_agent = A2CAgent(env, hidden_size=64, learning_rate=1e-3, gamma=0.9)
a2c_agent.train(num_episodes=1000)

# Plot comparison
plt.figure(figsize=(12, 6))
plt.plot(reinforce_agent.episode_rewards, label=f'REINFORCE (final: {np.mean(reinforce_agent.episode_rewards[-100:]):.1f})', alpha=0.7)
plt.plot(a2c_agent.episode_rewards, label=f'A2C (final: {np.mean(a2c_agent.episode_rewards[-100:]):.1f})', alpha=0.7)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('REINFORCE vs A2C Comparison')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

🧠 Key Takeaways

  1. Policy gradients learn policies directly: No need to learn values first

  2. REINFORCE is the foundation: Simple but high variance

  3. Actor-Critic reduces variance: Critic provides baseline for advantages

  4. Policy methods handle continuous actions: Natural for continuous control

  5. Stochastic policies are powerful: Can represent uncertainty and exploration

🚀 What's Next?

Now that you understand policy-based methods, you're ready for:

  • Proximal Policy Optimization (PPO): State-of-the-art policy optimization

  • Trust Region Policy Optimization (TRPO): Conservative policy updates

  • Soft Actor-Critic (SAC): Maximum entropy RL

  • Multi-Agent Reinforcement Learning: Multiple agents learning together

📚 Further Reading

๐Ÿ‹๏ธ Exercisesยถ

  1. Implement REINFORCE with baseline to reduce variance

  2. Add Generalized Advantage Estimation (GAE) to A2C

  3. Implement PPO (clipped surrogate objective)

  4. Solve continuous control tasks (Pendulum, MountainCar)

  5. Add entropy regularization for better exploration

💡 Discussion Questions

  • Why do policy gradients have high variance?

  • How does the critic help in actor-critic methods?

  • When should you use policy-based vs value-based methods?

  • What are the advantages of stochastic policies?

🔧 Practical Tips

  • Normalize advantages: Subtract mean and divide by std

  • Use entropy bonus: Encourage exploration in stochastic policies

  • Clip gradients: Prevent exploding gradients

  • Monitor KL divergence: Ensure policy updates aren't too large

  • Use multiple workers: Parallel environments for faster training
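Several of these tips are one-liners in PyTorch. A sketch of how return normalization, an entropy bonus, and gradient clipping would slot into a single policy-gradient update – the 0.01 entropy coefficient and 0.5 clip norm are typical starting values, not tuned for this problem, and the batch here is dummy data standing in for one episode:

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(2, 32), nn.ReLU(),
                    nn.Linear(32, 4), nn.Softmax(dim=-1))
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# Dummy batch standing in for one episode's states, actions, and returns
states = torch.randn(16, 2)
actions = torch.randint(0, 4, (16,))
returns = torch.randn(16)

# Normalize returns: subtract mean, divide by std
returns = (returns - returns.mean()) / (returns.std() + 1e-8)

dist = Categorical(probs=net(states))
policy_loss = -(dist.log_prob(actions) * returns).mean()
loss = policy_loss - 0.01 * dist.entropy().mean()  # entropy bonus

optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.5)  # clip gradients
optimizer.step()
```

Monitoring KL divergence between successive policies and running parallel workers need more scaffolding, but the three tricks above cost almost nothing and usually stabilize training noticeably.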