Building on our previous post, Iterative Policy Evaluation in a Nutshell, we now turn to the next two critical components of dynamic programming in reinforcement learning: policy improvement and policy iteration.
At its core, policy improvement upgrades a policy by making it greedy with respect to its own action-value function $q_\pi(s, a)$.
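In symbols, using the standard dynamic-programming notation, the improved policy picks in every state the action with the highest one-step lookahead value under the current value estimates; this is exactly what policy_improv below computes:

$$
\pi'(s) \;=\; \arg\max_{a}\, q_\pi(s, a) \;=\; \arg\max_{a} \sum_{s', r} p(s', r \mid s, a)\,\bigl[\, r + \gamma\, v_\pi(s') \,\bigr]
$$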
We will reuse our maze example; the Env, State, and action constants (UP, DOWN, LEFT, RIGHT) are the same as in the previous post.
class Agent:
    def __init__(self, env: Env):
        # Initialize all states to take action LEFT by default
        self.ag_policy = {state.name: LEFT for state in env.states}

    def policy(self, state: State, action: int) -> float:
        # Deterministic policy: 1 for the chosen action, 0 otherwise
        return 1.0 if self.ag_policy[state.name] == action else 0.0

    def update_policy(self, policy) -> None:
        self.ag_policy = policy
The iterative_policy_eval(agent, env, gamma, threshold) function is implemented exactly as in the previous post:
def iterative_policy_eval(agent, env, gamma, threshold=1e-6):
    # …same as before… (threshold defaults to 1e-6 so callers may omit it)
    pass
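For readers who don't have the previous post at hand, here is a minimal sketch of that evaluation sweep. It assumes the same interfaces used throughout this post: env.states and env.actions, env.step(state, action) yielding (next_state, reward, prob) triples, and agent.policy(state, action) returning π(a|s). Treat it as a stand-in, not the reference implementation:

def iterative_policy_eval(agent, env, gamma: float = 0.9, threshold: float = 1e-6):
    # Value table keyed by state name, initialised to zero
    values = {state.name: 0.0 for state in env.states}
    while True:
        delta = 0.0
        for state in env.states:
            new_value = 0.0
            for action in env.actions:
                pi = agent.policy(state, action)  # π(a|s)
                if pi == 0.0:
                    continue
                # Expected one-step return under the environment model
                for next_state, reward, prob in env.step(state, action):
                    new_value += pi * prob * (reward + gamma * values[next_state.name])
            delta = max(delta, abs(new_value - values[state.name]))
            values[state.name] = new_value
        # Stop once the largest update in a full sweep falls below the threshold
        if delta < threshold:
            return values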
import numpy as np

def policy_improv(values, agent: Agent, env: Env, gamma: float = 0.9):
    is_policy_stable = True
    for state in env.states:
        old_action = agent.ag_policy[state.name]
        action_values = []
        # Compute q_π(s,a) for each action
        for action in env.actions:
            total = 0.0
            for next_state, reward, prob in env.step(state, action):
                total += prob * (reward + gamma * values[next_state.name])
            action_values.append(total)
        # Greedy action selection: map the argmax index back to the action id
        best_action = env.actions[int(np.argmax(action_values))]
        agent.ag_policy[state.name] = best_action
        # Check if policy changed
        if best_action != old_action:
            is_policy_stable = False
    return is_policy_stable, agent.ag_policy
Key Points:
- policy_improv performs a one-step lookahead: for each state it computes q_π(s, a) from the current value estimates and the environment model, then switches to the greedy action.
- The is_policy_stable flag records whether any state changed its action; policy iteration below uses it as the stopping condition.
- Greedy improvement never makes the policy worse (see the policy improvement theorem right after this list).
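For completeness, the policy improvement theorem (a standard dynamic-programming result, stated here without proof) is what guarantees the greedy step is safe:

$$
q_\pi\bigl(s, \pi'(s)\bigr) \;\ge\; v_\pi(s)\ \ \text{for all } s
\quad\Longrightarrow\quad
v_{\pi'}(s) \;\ge\; v_\pi(s)\ \ \text{for all } s.
$$

The greedy policy returned by policy_improv satisfies the premise by construction, so each improvement step yields a policy at least as good as the previous one; once the policy stops changing, it is greedy with respect to its own value function and therefore optimal.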
Policy Iteration combines evaluation and improvement until convergence:
$$
\pi_0 \;\xrightarrow{\;E\;}\; v_{\pi_0} \;\xrightarrow{\;I\;}\; \pi_1 \;\xrightarrow{\;E\;}\; v_{\pi_1} \;\xrightarrow{\;I\;}\; \pi_2 \;\xrightarrow{\;E\;}\; \cdots \;\xrightarrow{\;I\;}\; \pi_* \;\xrightarrow{\;E\;}\; v_*
$$

where E denotes a full policy evaluation and I denotes improvement, i.e. making the policy greedy with respect to the value function just computed.

from tqdm import tqdm
MAX_ITER = 1_000
def policy_iteration(agent: Agent, env: Env, gamma: float = 0.9):
    # Start with an initial evaluation of the current policy
    values = iterative_policy_eval(agent, env, gamma)
    for _ in tqdm(range(MAX_ITER)):
        # 1) Policy Improvement
        is_stable, new_policy = policy_improv(values, agent, env, gamma)
        agent.update_policy(new_policy)
        # 2) Policy Evaluation
        values = iterative_policy_eval(agent, env, gamma)
        # 3) Check for convergence
        if is_stable:
            break
    return values, agent.ag_policy
Highlights:
- The loop alternates improvement and evaluation and stops as soon as is_stable is True, i.e. no state changed its greedy action.
- MAX_ITER is only a safety cap; on a maze this small the policy typically stabilises after just a few iterations.
- tqdm only provides a progress bar and can be dropped without affecting the algorithm.
The function below renders a policy on the maze grid, printing an arrow letter for each state, '#' for walls, and 'G' for the goal:
def display_policy(env: Env, policy: dict):
    arrow_map = {UP: 'U', DOWN: 'D', LEFT: 'L', RIGHT: 'R'}
    for r in range(MAZE_ROWS):
        row = []
        for c in range(MAZE_COLS):
            if (r, c) in WALLS:
                row.append('#')
            elif (r, c) == GOAL:
                row.append('G')
            else:
                row.append(arrow_map[policy.get((r, c), LEFT)])
        print(' '.join(row))
    print()
Sample Run:
if __name__ == '__main__':
    env = Env()
    agent = Agent(env)

    print('Initial Policy:')
    display_policy(env, agent.ag_policy)

    values, policy = policy_iteration(agent, env)

    print('Optimal Policy:')
    display_policy(env, policy)

    for state in sorted(env.states, key=lambda s: s.name):
        print(f"State {state.name}: V = {values[state.name]:.2f}")
Output:
Initial Policy:
L L L # G
L # L # L
L # L L L
L L L # L
# # L L L
Optimal Policy:
R R D # G
U # D # U
D # R R U
R R U # U
# # U R U
State (0, 0): V = -5.22
State (0, 1): V = -4.69
State (0, 2): V = -4.10
State (0, 4): V = 0.00
State (1, 0): V = -5.70
State (1, 2): V = -3.44
State (1, 4): V = 0.00
State (2, 0): V = -5.22
State (2, 2): V = -2.71
State (2, 3): V = -1.90
State (2, 4): V = -1.00
State (3, 0): V = -4.69
State (3, 1): V = -4.10
State (3, 2): V = -3.44
State (3, 4): V = -1.90
State (4, 2): V = -4.10
State (4, 3): V = -3.44
State (4, 4): V = -2.71