Policy Improvement and Policy Iteration in a Nutshell

Building on our previous post on Iterative Policy Evaluation in a Nutshell, we now turn to the next two critical components of dynamic programming in reinforcement learning:

  1. Policy Improvement
  2. Policy Iteration

  • Once we can accurately evaluate a policy (i.e., compute its state-value function v_π), the next step is to improve that policy.
  • By iteratively evaluating and improving, we converge to an optimal policy that maximizes the expected return in every state.

Policy Improvement Definition

At its core, policy improvement upgrades a policy by making it greedy with respect to its own action-value function q_π(s, a).

\pi'(s) = \arg\max_a q_\pi(s, a) = \arg\max_a \sum_{s', r} p(s', r \mid s, a)\,\bigl[r + \gamma \, v_\pi(s')\bigr]
  • q_π(s, a) is the action-value function under policy π.
    • It is the expected return when starting in state s, taking action a, and thereafter following π.
  • Greedy improvement chooses, in each state, the action with the highest expected return; the policy improvement theorem below guarantees this never makes the policy worse.
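
Why does the greedy step help? The policy improvement theorem (the classic dynamic-programming result presented in Sutton and Barto) guarantees that acting greedily with respect to q_π can never make the policy worse:

\text{if } q_\pi(s, \pi'(s)) \ge v_\pi(s) \;\; \forall s, \text{ then } v_{\pi'}(s) \ge v_\pi(s) \;\; \forall s

Since the greedy choice satisfies q_π(s, π'(s)) = max_a q_π(s, a) ≥ q_π(s, π(s)) = v_π(s), each improvement step produces a policy that is at least as good as the old one, and strictly better in some state unless π is already optimal.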

Policy Improvement Algorithm

We will reuse our maze example.

class Agent:
    def __init__(self, env: Env):
        # Initialize all states to take action LEFT by default
        self.ag_policy = {state.name: LEFT for state in env.states}
        
    def policy(self, state: State, action: int) -> float:
        # Deterministic policy: 1 for the chosen action, 0 otherwise
        return 1.0 if self.ag_policy[state.name] == action else 0.0

    def update_policy(self, policy) -> None:
        self.ag_policy = policy
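
As a quick sanity check, the deterministic policy places all probability mass on the stored action. This is a hypothetical usage snippet; Env, LEFT, and RIGHT come from the maze environment of the previous post:

env = Env()
agent = Agent(env)

state = env.states[0]
print(agent.policy(state, LEFT))   # 1.0 — LEFT is the default action everywhere
print(agent.policy(state, RIGHT))  # 0.0 — every other action has zero probability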

The iterative_policy_eval(agent, env, gamma, threshold) function is implemented exactly as in the previous post:

def iterative_policy_eval(agent, env, gamma, threshold=1e-4):
    # …same as before…
    # (a default threshold lets policy_iteration below call it with three arguments)
    pass
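
For readers who do not have the previous post at hand, here is a minimal sketch consistent with the interfaces used throughout this post: env.states with a name attribute, env.actions, env.step(state, action) yielding (next_state, reward, prob) triples, and agent.policy(state, action) returning the probability of taking that action. The previous post's exact implementation may differ in detail.

def iterative_policy_eval(agent: Agent, env: Env,
                          gamma: float = 0.9, threshold: float = 1e-4) -> dict:
    # Value function keyed by state name, initialized to zero
    values = {state.name: 0.0 for state in env.states}

    while True:
        delta = 0.0
        for state in env.states:
            old_value = values[state.name]
            new_value = 0.0
            # Expected update: v(s) = Σ_a π(a|s) Σ_{s',r} p(s',r|s,a) [r + γ v(s')]
            for action in env.actions:
                action_prob = agent.policy(state, action)
                if action_prob == 0.0:
                    continue  # skip actions the deterministic policy never takes
                for next_state, reward, prob in env.step(state, action):
                    new_value += action_prob * prob * (reward + gamma * values[next_state.name])
            values[state.name] = new_value
            delta = max(delta, abs(new_value - old_value))
        # Converged once no state value changed by more than the threshold
        if delta < threshold:
            return values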

Policy Improvement Function

import numpy as np

def policy_improv(values, agent: Agent, env: Env, gamma: float = 0.9):
    is_policy_stable = True

    for state in env.states:
        old_action = agent.ag_policy[state.name]
        action_values = []

        # Compute q_π(s,a) for each action
        for action in env.actions:
            total = 0.0
            for next_state, reward, prob in env.step(state, action):
                total += prob * (reward + gamma * values[next_state.name])
            action_values.append(total)

        # Greedy action selection: index back into env.actions so the stored
        # policy holds actual action constants, not argmax positions
        best_action = env.actions[int(np.argmax(action_values))]
        agent.ag_policy[state.name] = best_action

        # Check if policy changed
        if best_action != old_action:
            is_policy_stable = False

    return is_policy_stable, agent.ag_policy

Key Points:

  • In each state, we greedily pick the action with the highest one-step lookahead value under the current value function.
  • If no action changes in any state, the policy is stable and policy iteration can stop. A more robust stability check, sketched below, compares action values rather than action identities, which avoids cycling between equally good actions.
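
Comparing actions directly can, in principle, make policy iteration oscillate forever between policies that are exactly tied. One common fix is to declare the policy stable whenever the old action is already (numerically) as good as the best one. The variant below is a hypothetical alternative to policy_improv, assuming env.actions is an ordinary Python list of action constants:

def policy_improv_stable_by_value(values, agent: Agent, env: Env,
                                  gamma: float = 0.9, tol: float = 1e-9):
    is_policy_stable = True

    for state in env.states:
        old_action = agent.ag_policy[state.name]
        action_values = []

        # One-step lookahead q_π(s, a) for every action, exactly as above
        for action in env.actions:
            total = 0.0
            for next_state, reward, prob in env.step(state, action):
                total += prob * (reward + gamma * values[next_state.name])
            action_values.append(total)

        best_index = int(np.argmax(action_values))
        best_value = action_values[best_index]
        old_value = action_values[env.actions.index(old_action)]

        # Only switch (and flag instability) if the old action is strictly worse
        if best_value - old_value > tol:
            agent.ag_policy[state.name] = env.actions[best_index]
            is_policy_stable = False

    return is_policy_stable, agent.ag_policy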

Policy Iteration Algorithm

Policy Iteration combines evaluation and improvement until convergence:

  1. Initialize a random (or uniform) policy and a zeroed value function.
  2. Repeat:
    • Policy Evaluation: compute v_π for the current policy π.
    • Policy Improvement: set π' = greedy w.r.t. v_π.
    • If π' == π, stop; otherwise set π ← π' and repeat.

from tqdm import tqdm

MAX_ITER = 1_000

def policy_iteration(agent: Agent, env: Env, gamma: float = 0.9):
    # Start with an initial evaluation
    values = iterative_policy_eval(agent, env, gamma)

    for _ in tqdm(range(MAX_ITER)):
        # 1) Policy Improvement
        is_stable, new_policy = policy_improv(values, agent, env, gamma)
        agent.update_policy(new_policy)

        # 2) Policy Evaluation 
        values = iterative_policy_eval(agent, env, gamma)

        # 3) Check for convergence
        if is_stable:
            break

    return values, agent.ag_policy

Highlights:

  • Alternates between evaluation and greedy improvement.
  • Guaranteed to converge to an optimal policy in a finite MDP, since there are only finitely many deterministic policies and each improvement step yields a strictly better policy unless the current one is already optimal.

Results on the Maze

The function below displays a policy on the maze grid:

def display_policy(env: Env, policy: dict):
    # MAZE_ROWS, MAZE_COLS, WALLS, GOAL and the action constants (UP, DOWN,
    # LEFT, RIGHT) come from the maze environment module; the policy dict is
    # keyed by (row, col) state names.
    arrow_map = {UP: 'U', DOWN: 'D', LEFT: 'L', RIGHT: 'R'}
    for r in range(MAZE_ROWS):
        row = []
        for c in range(MAZE_COLS):
            if (r, c) in WALLS:
                row.append('#')
            elif (r, c) == GOAL:
                row.append('G')
            else:
                row.append(arrow_map[policy.get((r, c), LEFT)])
        print(' '.join(row))
    print()

Sample Run:

if __name__ == '__main__':
    env   = Env()
    agent = Agent(env)

    print('Initial Policy:')
    display_policy(env, agent.ag_policy)

    values, policy = policy_iteration(agent, env)
    
    print('Optimal Policy:')
    display_policy(env, policy)

    for state in sorted(env.states, key=lambda s: s.name):
        print(f"State {state.name}: V = {values[state.name]:.2f}")

Output:

Initial Policy:
L L L # G
L # L # L
L # L L L
L L L # L
# # L L L

Optimal Policy:
R R D # G
U # D # U
D # R R U
R R U # U
# # U R U

State (0, 0): V = -5.22
State (0, 1): V = -4.69
State (0, 2): V = -4.10
State (0, 4): V = 0.00
State (1, 0): V = -5.70
State (1, 2): V = -3.44
State (1, 4): V = 0.00
State (2, 0): V = -5.22
State (2, 2): V = -2.71
State (2, 3): V = -1.90
State (2, 4): V = -1.00
State (3, 0): V = -4.69
State (3, 1): V = -4.10
State (3, 2): V = -3.44
State (3, 4): V = -1.90
State (4, 2): V = -4.10
State (4, 3): V = -3.44
State (4, 4): V = -2.71

Conclusion

  • Policy Improvement makes a policy greedy with respect to its current state-value function.
  • Policy Iteration repeatedly evaluates and improves until no further gains are possible.