Implementing an RL algorithm (PPO) from a whitepaper
How difficult is it really?
Deep Learning whitepapers are not straightforward to implement. They require multiple readings, reading related literature, multiple iterations and lots of trial and error.
This makes implementing and engineering research breakthroughs an arduous task, especially for less experienced engineers.
I would like to show what that looks like on a paper I selected.
Proximal Policy Optimization Algorithms (PPO)
…is an on-policy, policy gradient reinforcement learning algorithm (arXiv:1707.06347).
I chose this one for a few reasons: it has stood the test of time, it is widely used and generic, and several baseline implementations exist (e.g. CleanRL's PPO). These qualities make it a very worthwhile piece of code to implement and understand.
But most important of all: despite being a high-quality paper, I found the experience of reading it very similar to that of many other Deep Learning whitepapers. So it's not an outlier, but a good representative of the class of problems I want to focus on.
First read of the paper
I want to take you, my dear reader, on a journey. That means you need to get your hands dirty to fully appreciate the experience.
So please follow the link below and read the original whitepaper. Approach it with curiosity, knowing that in a moment you will need to implement it, and note down your observations.
You may or may not be familiar with Reinforcement Learning. Whatever your experience, I will assume this is your first encounter with it. I will, however, assume that you are a software engineer with a working knowledge of a programming language (we'll be using Python here), basic data structures and algorithms, code complexity management, and testing your code.
Go ahead - arXiv:1707.06347 - come back here when you’re done.
Engineering approach to implementation
The first-order questions I will attempt to answer are:
- what exactly am I implementing?
- how should I validate my implementation?
Note that we are translating a paper into an implementation rather than designing our own algorithm or system, so we don't need to worry about defining and fulfilling functional and non-functional requirements.
Evaluation
For evaluation, the Experiments section and Appendix B reference the RL environments the original paper was evaluated on:
- environments that have since moved to Gymnasium, the Farama Foundation's maintained successor to OpenAI Gym
- Roboschool environments; Roboschool has been deprecated, and the link to PyBullet mentioned in its deprecation note no longer works
Gymnasium is well documented and has a standardized environment API, which lets us interact with any environment using the same code.
The cumulative score (the total reward collected in an episode) in each environment is used as the metric. Notice that the scores are specific to each environment (different Y axis scales):
Figure 1. Example PPO evaluation plots
The X axis on the plots is the number of training timesteps, and the Y axis is the evaluation score obtained after that many steps.
Based on this information, we can develop the following framework:
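import logging
from typing import Callable
import gymnasium as gym
import numpy as np
# Open question 1: what exactly is the evaluated algorithm? For now, assume only
# that it is a callable mapping an observation to an action
EvaluatedAlgo = Callable[[np.ndarray], np.ndarray]
def evaluate(evaluated_algo: EvaluatedAlgo, training_step: int) -> None:
    # Current Gymnasium versions of the MuJoCo environments (need the mujoco extra)
    eval_environment_ids = ["HalfCheetah-v5", "Hopper-v5", "Swimmer-v5"]  # ... add others
    for env_id in eval_environment_ids:
        env = gym.make(env_id)
        # Use the standardized Gymnasium API to run the environments
        observation, _ = env.reset()
        episode_finished = False
        score = 0.0
        while not episode_finished:
            # Open question 2: the algorithm turns observations into actions
            action = evaluated_algo(observation)
            observation, reward, terminated, truncated, _ = env.step(action)
            episode_finished = terminated or truncated
            score += float(reward)
        env.close()
        logging.info("Score for %s after training step %d is %f", env_id, training_step, score)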
I highlighted the two open questions:
- it is still unclear what exactly the implemented algorithm is
- we can figure out that the algorithm is supposed to convert observations returned by the environment into actions that are fed back to the environment; however, this information is NOT mentioned in the paper
Arcane knowledge
The second point specifically requires the reader to climb the ladder of references and build an understanding of the essence of Reinforcement Learning algorithms.
Today, with ample resources, RL is not the mystery it used to be in 2017, when this paper was originally published. Back then, however, this was arcane knowledge, with few learning materials and even fewer available example implementations. Worse still, the existing material introduced such a vast terminology that it was very difficult to piece together a cohesive understanding of these algorithms.
This led (in my case, and perhaps in yours too) to losing sight of the forest for the trees: instead of being able to implement e.g. PPO, one first had to understand “on-policy”, “off-policy”, “policy gradient”, “policy functions”, “value functions”, “Bellman equation”, “losses”, “rewards - sparse and dense” …
Just take a look for yourself:
Figure 2. Taxonomy of Reinforcement learning models, Springer
And here’s a very good overview paper - arXiv:2209.14940v
This reliance on prior literature is shared by 99% of the whitepapers I’ve read during my career. It makes each of them a delta that builds on top of other material rather than a self-contained piece.
Arcane knowledge of Reinforcement Learning demystified
Figure 3. Reinforcement learning system, simplified view
Reinforcement Learning describes a system of two entities, an environment and a policy, that exchange three key pieces of data: observations (sometimes referred to as states), actions and rewards.
Both can be thought of as functions:
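from typing import Any
Observation, Action, Reward = Any, Any, float  # placeholder types
def environment(action: Action) -> tuple[Observation, Reward]:
    ...
def policy(observation: Observation, reward: Reward) -> Action:
    ...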
The “learning” part of a Reinforcement Learning system pertains to the reward signal we pass to the policy: the policy learns, changing its parameters based on that signal. The environment, on the other hand, never changes its definition and stays constant.
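To make the “learning” part concrete, here is a toy sketch in plain Python using only the standard library. It is not PPO. The environment is a hypothetical two-armed bandit whose payout probabilities never change. The policy is a softmax over per-action preferences, updated with a simple gradient-bandit rule. Only the policy changes in response to the reward signal.

```python
import math
import random


class TwoArmedBandit:
    """Hypothetical environment: its definition never changes."""

    payout_probs = (0.2, 0.8)  # chance that each action yields a reward of 1.0

    def __init__(self, rng: random.Random) -> None:
        self.rng = rng

    def step(self, action: int) -> tuple[int, float]:
        reward = 1.0 if self.rng.random() < self.payout_probs[action] else 0.0
        return 0, reward  # a single dummy observation, plus the reward


class SoftmaxPolicy:
    """Hypothetical learning policy: one preference per action."""

    def __init__(self, num_actions: int, lr: float, rng: random.Random) -> None:
        self.prefs = [0.0] * num_actions
        self.lr = lr
        self.rng = rng
        self.avg_reward = 0.0  # running baseline
        self.num_updates = 0

    def probs(self) -> list[float]:
        top = max(self.prefs)
        exps = [math.exp(p - top) for p in self.prefs]
        total = sum(exps)
        return [e / total for e in exps]

    def act(self, observation: int) -> int:
        # The bandit has a single observation, so it is ignored here
        return self.rng.choices(range(len(self.prefs)), weights=self.probs())[0]

    def learn(self, action: int, reward: float) -> None:
        # Learning = changing the policy's parameters based on the reward signal
        self.num_updates += 1
        self.avg_reward += (reward - self.avg_reward) / self.num_updates
        probs = self.probs()
        for a in range(len(self.prefs)):
            indicator = 1.0 if a == action else 0.0
            self.prefs[a] += self.lr * (reward - self.avg_reward) * (indicator - probs[a])


rng = random.Random(0)
env = TwoArmedBandit(rng)
policy = SoftmaxPolicy(num_actions=2, lr=0.1, rng=rng)
observation = 0
for _ in range(5000):
    action = policy.act(observation)
    observation, reward = env.step(action)
    policy.learn(action, reward)
print(policy.probs())  # the better arm (index 1) should dominate
```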
How the policy learns is the real reason behind the diversity of methods shown in Figure 2: some methods learn from each action, some require massive amounts of accumulated data, and some use auxiliary entities.
Enter PPO.
PPO demystified
Figure 4. PPO overlaid on top of the training system. System from Figure 3 colored purple.
I arrived at the diagram in Figure 4 by stitching the information scattered across chapters 2, 3, 5 and 6. That information has to be grafted onto a framework that the paper does not mention but assumes familiarity with.
Step 1 overall PPO training algorithm
Figure 5. Relationship between the PPO training algorithm and the RL system design that uses PPO.
We can find it in section 5 of the paper. In Figure 5 I colored the relevant components from Figure 4 to highlight the role they play in the algorithm.
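Section 5 presents this as Algorithm 1. Each iteration has three parts. First, N actors each run π_θold in the environment for T timesteps and compute advantage estimates. Next, the surrogate L is optimized with respect to θ for K epochs, with minibatch size M ≤ NT. Finally, θ_old ← θ.

The skeleton below is a hedged sketch of only that control flow. Every callable is a hypothetical placeholder for a piece implemented in the following steps.

```python
from typing import Any, Callable


def ppo_outer_loop(
    num_iterations: int,
    num_actors: int,  # N in the paper
    num_epochs: int,  # K in the paper
    run_policy: Callable[[int], Any],  # hypothetical: run pi_theta_old for T steps in actor i
    compute_advantages: Callable[[Any], Any],  # hypothetical: attach A_1..A_T to a rollout
    optimize_surrogate: Callable[[list[Any]], None],  # hypothetical: one epoch over minibatches
    update_old_policy: Callable[[], None],  # hypothetical: theta_old <- theta
) -> None:
    for _ in range(num_iterations):
        rollouts = []
        for actor in range(num_actors):
            rollout = run_policy(actor)
            rollouts.append(compute_advantages(rollout))
        for _ in range(num_epochs):
            optimize_surrogate(rollouts)
        update_old_policy()


# Usage: record the order of calls with stand-in callables
calls: list[str] = []
ppo_outer_loop(
    num_iterations=2,
    num_actors=3,
    num_epochs=4,
    run_policy=lambda actor: calls.append("run") or actor,
    compute_advantages=lambda rollout: calls.append("advantages") or rollout,
    optimize_surrogate=lambda rollouts: calls.append("optimize"),
    update_old_policy=lambda: calls.append("sync"),
)
print(calls.count("run"), calls.count("optimize"), calls.count("sync"))  # 6 8 2
```

The complete code at the end makes some simplifications. It uses a single actor and full-batch updates instead of minibatches. It also has no explicit θ_old copy: the log probabilities that rollout_episode stores in each Trajectory play the role of π_θold.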
Step 2 Computing advantages
References to advantages show up throughout Sections 2-5, with the definition of generalized advantage estimation (GAE) presented in equations 11 and 12.
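For reference, below is the truncated GAE estimator over a segment of T timesteps, together with the equivalent recursive form that is convenient to implement:

$$
\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)
$$

$$
\hat{A}_t = \delta_t + (\gamma\lambda)\,\delta_{t+1} + \cdots + (\gamma\lambda)^{T-t-1}\,\delta_{T-1} = \sum_{l=0}^{T-t-1} (\gamma\lambda)^{l}\,\delta_{t+l}
$$

$$
\hat{A}_t = \delta_t + \gamma\lambda\,\hat{A}_{t+1}, \qquad \hat{A}_T = 0
$$

Iterating t backwards from T−1 evaluates the recursion in a single pass. At an episode boundary, both the γV(s_{t+1}) term and the γλÂ_{t+1} term are dropped.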
Notice the use of V(s_{t+1}) and V(s_t): these are values obtained from the value function, an auxiliary neural network that PPO trains alongside the policy.
The code, implemented in PyTorch, looks like this:
import torch
def compute_gae(
    rewards: torch.Tensor,
    values: torch.Tensor,
    dones: torch.Tensor,
    discount: float,
    gae_lambda: float = 0.95,
) -> torch.Tensor:
    """Compute Generalized Advantage Estimation.
    The result is built as a new tensor, so it is a constant (it does not track gradients).
    """
    advantages: list[torch.Tensor] = []
    gae = torch.tensor(0.0)
    for t in reversed(range(len(rewards))):
        if t == len(rewards) - 1:
            # Bootstrap value (or get V(s_T) if not terminal)
            next_value = torch.tensor(0.0)
        else:
            next_value = values[t + 1]
        # Mask next_value if episode ended
        next_value = next_value * (1 - dones[t])
        # δ_t = r_t + γV(s_{t+1}) - V(s_t)
        delta = rewards[t] + discount * next_value - values[t]
        # A_t = δ_t + (γλ)δ_{t+1} + ... + (γλ)^{T-t-1}δ_{T-1}, i.e. A_t = δ_t + γλA_{t+1}
        gae = delta + discount * gae_lambda * (1 - dones[t]) * gae
        advantages.insert(0, gae)
    return torch.FloatTensor(advantages)
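One of the first questions was how to validate the implementation. One option is to check compute_gae against a plain-Python reference on tiny, hand-computed inputs. The sketch below mirrors the same backward recursion with lists instead of tensors. gae_reference is a hypothetical helper, not something from the paper.

```python
def gae_reference(
    rewards: list[float],
    values: list[float],
    dones: list[int],
    discount: float,
    gae_lambda: float = 0.95,
) -> list[float]:
    """Plain-Python GAE with the same conventions as compute_gae.

    dones[t] == 1 means the episode ended after step t, so neither V(s_{t+1})
    nor the advantage of step t + 1 is carried back into step t.
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        # Like compute_gae, the value after the last step is taken to be 0
        next_value = values[t + 1] if t + 1 < len(rewards) else 0.0
        not_done = 1.0 - dones[t]
        delta = rewards[t] + discount * next_value * not_done - values[t]
        gae = delta + discount * gae_lambda * not_done * gae
        advantages[t] = gae
    return advantages


# With lambda = 1 and V = 0, advantages reduce to discounted rewards-to-go:
# A_2 = 1, A_1 = 1 + 0.5 * 1 = 1.5, A_0 = 1 + 0.5 * 1.5 = 1.75
print(gae_reference([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0, 0, 1], discount=0.5, gae_lambda=1.0))
# An episode boundary after step 0 stops step 1 from leaking into step 0
print(gae_reference([1.0, 1.0], [0.5, 0.5], [1, 0], discount=0.9))  # [0.5, 0.5]
```

A unit test could build torch tensors from the same lists, call compute_gae, and compare the two results with torch.allclose.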
Step 3 Optimize surrogate L (loss)
The loss is defined in chapters 3 and 5.
Clipped policy loss (chapter 3):
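Written out, with the probability ratio defined at the beginning of chapter 3 and ε as the clipping hyperparameter (clip_eps in the code below):

$$
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{old}}(a_t \mid s_t)}
$$

$$
L^{CLIP}(\theta) = \hat{\mathbb{E}}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\left(r_t(\theta), 1-\epsilon, 1+\epsilon\right)\hat{A}_t\right)\right]
$$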
Full loss equation that combines the clipped policy loss with value function loss and entropy bonus (chapter 5):
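The combined objective, which is maximized, is shown below. Here c_1 and c_2 are coefficients, S is an entropy bonus, and the value target V_t^targ corresponds to `returns` (advantages plus values) in the code:

$$
L_t^{CLIP+VF+S}(\theta) = \hat{\mathbb{E}}_t\left[L_t^{CLIP}(\theta) - c_1 L_t^{VF}(\theta) + c_2\, S[\pi_\theta](s_t)\right], \qquad L_t^{VF}(\theta) = \left(V_\theta(s_t) - V_t^{targ}\right)^2
$$

Optimizers minimize, so the code below returns the negation of this objective. policy_loss corresponds to −L^CLIP. value_loss is 0.5·L^VF, i.e. c_1 = 0.5. entropy_loss is c_2·S with c_2 = 0.01.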
PyTorch implementation of the loss looks like this:
import typing
import torch
class Trajectory(typing.NamedTuple):
    states: torch.Tensor
    actions: torch.Tensor
    rewards: torch.Tensor
    log_probs: torch.Tensor
    dones: torch.Tensor
def ppo_loss(
    agent: "_PPOAgent", trajectory: Trajectory, discount: float, clip_eps: float
) -> torch.Tensor:
    # We maximize L^{CLIP+VF+S}_t(θ) = E_t[L^CLIP_t(θ) − c_1 L^VF_t(θ) + c_2 S[π_θ](s_t)],
    # so the returned loss is its negation: policy_loss + value_loss - entropy_loss
    log_probs, entropies = agent.policy.evaluate_actions(
        trajectory.states, trajectory.actions
    )
    values = agent.value_fn(trajectory.states)
    # Treat advantages and value targets as constants. compute_gae returns a constant
    # tensor, so with `returns = advantages + values` the difference `returns - values`
    # would carry no gradient and the value function would never be trained.
    advantages = compute_gae(
        trajectory.rewards, values.detach(), trajectory.dones, discount
    )
    returns = advantages + values.detach()
    # Normalize advantages to stabilize training
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    assert log_probs.shape == (trajectory.states.shape[0],)
    assert trajectory.log_probs.shape == (trajectory.states.shape[0],)
    # r_t(θ) = π_θ(a_t|s_t) / π_θold(a_t|s_t), computed from log probabilities
    ratio = torch.exp(log_probs - trajectory.log_probs)
    policy_loss = -torch.min(
        ratio * advantages,
        torch.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages,
    ).mean()
    value_loss = 0.5 * (returns - values).pow(2).mean()
    c_2 = 0.01
    entropy_loss = c_2 * entropies.mean()
    loss = policy_loss + value_loss - entropy_loss
    return loss
I am deliberately omitting the implementations of PPOPolicy and _PPOValue for the time being, so as not to distract from the main subject of this section: the loss function.
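To build intuition for the min/clip combination in policy_loss, here is a per-sample version in plain Python. clipped_surrogate is a hypothetical helper for illustration. Moving the ratio beyond 1 ± ε in the direction the advantage favors earns nothing extra. Moving it in the unfavorable direction is never clipped.

```python
def clipped_surrogate(ratio: float, advantage: float, clip_eps: float = 0.2) -> float:
    """Per-sample L^CLIP term: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    return min(ratio * advantage, clipped_ratio * advantage)


# Positive advantage: increasing the action's probability pays off only up to 1 + eps
print(clipped_surrogate(1.5, 1.0))   # 1.2 (clipped)
print(clipped_surrogate(0.5, 1.0))   # 0.5 (unclipped, pessimistic)
# Negative advantage: decreasing the action's probability pays off only down to 1 - eps
print(clipped_surrogate(0.5, -1.0))  # -0.8 (clipped)
print(clipped_surrogate(1.5, -1.0))  # -1.5 (unclipped)
```

Where the clipped branch is selected, the term is constant in θ, so that sample contributes no gradient. This is what keeps the new policy close to the one that collected the data.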
Note that the loss function expects several values to be provided either as input or by the policy, among them log probabilities and entropies.
Also notice that there are two sets of log probabilities: those calculated by policy.evaluate_actions, and those provided as part of the trajectory input.
Picking up on this implementation nuance requires jumping back to the beginning of chapter 3 of the paper and looking at the definition of the probability ratio used in the clipped policy loss.
π_θ and π_θ_old refer to the probabilities (not the log probabilities) of actions under the new policy (in this context, the one being trained) and the old policy (in this context, the one used to collect the trajectory).
An algebraic trick, rewriting the ratio as the exponent of a difference of logarithms, π_θ / π_θ_old = exp(log π_θ − log π_θ_old), allows us to use log probabilities. Packages such as PyTorch readily provide these from their distribution implementations.
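Below is a quick plain-Python check of that identity with hypothetical numbers. The second part shows why working in log space is also numerically safer. For a many-dimensional continuous action, the joint density is a product of per-dimension densities, and that product can underflow to zero. The sum of log densities stays well within range.

```python
import math

# Discrete case: hypothetical probabilities of the same action under both policies
p_new, p_old = 0.3, 0.2
ratio_direct = p_new / p_old
ratio_from_logs = math.exp(math.log(p_new) - math.log(p_old))
print(ratio_direct, ratio_from_logs)  # both approximately 1.5

# Continuous case: 100 independent action dimensions with tiny (hypothetical) densities
dens_new = [1.001e-5] * 100
dens_old = [1e-5] * 100
print(math.prod(dens_new), math.prod(dens_old))  # 0.0 0.0: both products underflow
log_ratio = sum(math.log(n) for n in dens_new) - sum(math.log(o) for o in dens_old)
print(math.exp(log_ratio))  # approximately 1.001 ** 100, about 1.105
```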
Step 4 Policy and Value function network architecture
This detail is only described much later, in chapter 6.
Step 5 Implementing the Policy function
At this point we can go ahead and implement the networks themselves.
To make a long story short: depending on the type of actions an environment works with (continuous or discrete), we need to interpret the outputs of the model differently and slightly change the architecture of the network.
Policy base class
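import abc
import torch
import torch.distributions as dist
import torch.nn as nn
class PPOPolicy(nn.Module, abc.ABC):
    def __init__(self, state_dim: int, action_dim: int) -> None:
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
    @abc.abstractmethod
    def get_action(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Samples actions for a (T, state_dim) batch; returns (actions, log_probs)."""
    @abc.abstractmethod
    def evaluate_actions(
        self, states: torch.Tensor, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Returns (log_probs, entropies) of the given actions, each of shape (T,)."""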
The implementation for continuous action spaces trains the policy model to return the parameters of a Gaussian distribution: a mean and a log standard deviation per action dimension. It then samples from that distribution to obtain action values. An action is a vector of floating point numbers; the values may, for example, represent the speeds of the motors rotating a robot’s arm:
class PPOPolicyContinuous(PPOPolicy):
    def __init__(self, state_dim: int, action_dim: int) -> None:
        super().__init__(state_dim, action_dim)
        self.policy = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 2 * action_dim),
        )
    def _get_action_distribution(self, state: torch.Tensor) -> dist.Distribution:
        action_gaussian_params = self.policy(state)
        means, log_stds = torch.chunk(action_gaussian_params, 2, dim=-1)
        stds = torch.exp(log_stds.clamp(-20, 2))
        return dist.Normal(loc=means, scale=stds)
    def get_action(self, states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        if len(states.shape) != 2 or states.shape[1] != self.state_dim:
            raise ValueError(
                f"state should have shape (T, {self.state_dim}) != {states.shape}"
            )
        actions_dist = self._get_action_distribution(states)
        action = actions_dist.sample()
        # aggregate across the last dimension - we want the probability
        # of the joint distribution of actions for a given timestep
        log_prob = actions_dist.log_prob(action).sum(-1)
        assert action.shape == (states.shape[0], self.action_dim)
        assert log_prob.shape == (states.shape[0],)
        return action, log_prob
    def evaluate_actions(
        self, states: torch.Tensor, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if len(states.shape) != 2 or states.shape[1] != self.state_dim:
            raise ValueError(
                f"states should have shape (T, {self.state_dim}) != {states.shape}"
            )
        if len(actions.shape) != 2 or actions.shape[1] != self.action_dim:
            raise ValueError(
                f"actions should have shape (T, {self.action_dim}) != {actions.shape}"
            )
        actions_dist = self._get_action_distribution(states)
        log_probs = actions_dist.log_prob(actions)
        entropy = actions_dist.entropy()
        # aggregate across the last dimension - we want the probability
        # and the entropy of the joint distribution of actions for a given timestep
        log_probs = log_probs.sum(-1)
        entropy = entropy.sum(-1)
        assert log_probs.shape == (states.shape[0],)
        assert entropy.shape == (states.shape[0],)
        return log_probs, entropy
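The .sum(-1) calls rely on the action dimensions being independent Gaussians. The log density of the joint (diagonal) Gaussian is then the sum of the per-dimension log densities, and the same holds for the entropy. Below is a plain-Python check of the log-density part, with hypothetical numbers.

```python
import math


def normal_log_prob(x: float, mean: float, std: float) -> float:
    """Log density of a one-dimensional Gaussian."""
    return -((x - mean) ** 2) / (2 * std**2) - math.log(std) - 0.5 * math.log(2 * math.pi)


def diag_gaussian_log_prob(x: list[float], means: list[float], stds: list[float]) -> float:
    """Log density of a multivariate Gaussian with a diagonal covariance matrix."""
    k = len(x)
    mahalanobis_sq = sum(((xi - m) / s) ** 2 for xi, m, s in zip(x, means, stds))
    log_det_cov = sum(2 * math.log(s) for s in stds)
    return -0.5 * (mahalanobis_sq + log_det_cov + k * math.log(2 * math.pi))


action = [0.3, -1.2, 0.7]
means = [0.0, -1.0, 1.0]
stds = [0.5, 1.0, 2.0]
per_dim_sum = sum(normal_log_prob(a, m, s) for a, m, s in zip(action, means, stds))
joint = diag_gaussian_log_prob(action, means, stds)
print(per_dim_sum, joint)  # equal up to floating point rounding
```

In PyTorch, wrapping the Normal in torch.distributions.Independent(normal, 1) performs the same reinterpretation. log_prob and entropy then return one value per timestep without an explicit sum.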
The policy for discrete action spaces, on the other hand, samples a single integer that represents the discrete action to be taken in the environment. An example of such an action is the index of an arrow key on a keyboard when playing one of the Atari games.
class PPOPolicyDiscrete(PPOPolicy):
    def __init__(self, state_dim: int, action_dim: int) -> None:
        super().__init__(state_dim, action_dim)
        self.policy = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, action_dim),
        )
    def _get_action_distribution(self, state: torch.Tensor) -> dist.Distribution:
        logits = self.policy(state)
        # Pass the raw outputs as logits (Categorical normalizes them with a log-softmax);
        # passed positionally, they would be interpreted as probabilities instead
        return dist.Categorical(logits=logits)
    def get_action(self, states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        if len(states.shape) != 2 or states.shape[1] != self.state_dim:
            raise ValueError(
                f"State should have shape (T, {self.state_dim}) != {states.shape}"
            )
        actions_dist = self._get_action_distribution(states)
        action = actions_dist.sample()
        log_prob = actions_dist.log_prob(action)
        assert action.shape == (states.shape[0],)
        assert log_prob.shape == (states.shape[0],)
        return action, log_prob
    def evaluate_actions(
        self, states: torch.Tensor, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if len(states.shape) != 2 or states.shape[1] != self.state_dim:
            raise ValueError(
                f"states should have shape (T, {self.state_dim}) != {states.shape}"
            )
        if len(actions.shape) != 1:
            raise ValueError(f"actions should have shape (T,) != {actions.shape}")
        actions_dist = self._get_action_distribution(states)
        log_probs = actions_dist.log_prob(actions)
        entropy = actions_dist.entropy()
        assert log_probs.shape == (states.shape[0],)
        assert entropy.shape == (states.shape[0],)
        return log_probs, entropy
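When torch.distributions.Categorical is given logits, it normalizes them with a log-softmax, so the raw network outputs can be any real numbers. The sketch below shows, in plain Python, what log_prob and entropy then compute. The helper names are hypothetical.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    # Subtract the max before exponentiating for numerical stability
    top = max(logits)
    log_norm = top + math.log(sum(math.exp(x - top) for x in logits))
    return [x - log_norm for x in logits]


def categorical_log_prob(logits: list[float], action: int) -> float:
    return log_softmax(logits)[action]


def categorical_entropy(logits: list[float]) -> float:
    log_probs = log_softmax(logits)
    return -sum(math.exp(lp) * lp for lp in log_probs)


# Four equal logits (e.g. four arrow keys) give a uniform distribution
print(categorical_log_prob([0.0, 0.0, 0.0, 0.0], action=2))  # about -1.386, i.e. -log(4)
print(categorical_entropy([0.0, 0.0, 0.0, 0.0]))  # about 1.386, the maximum for 4 actions
# Adding the same constant to every logit does not change the distribution
print(categorical_log_prob([5.0, 5.0, 5.0, 5.0], action=2))  # about -1.386 again
```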
Step 6 Implementing the Value function
The value function uses the same network architecture, except that it returns a single floating point value: the value the function assigns to the state (observation) of an environment.
class _PPOValue(nn.Module):
    def __init__(self, state_dim: int):
        super().__init__()
        self.value = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1),
        )
    def forward(self, states: torch.Tensor) -> torch.Tensor:
        values = self.value(states)
        # squeeze returns a tensor that's shaped just like the rewards tensor
        return values.squeeze(-1)
Complete code
Here’s the complete solution:
import abc
import logging
import typing
import gymnasium as gym
from gymnasium.wrappers import RecordVideo
import mlflow
import torch
import torch.distributions as dist
import torch.nn as nn
from tqdm import trange  # type: ignore
def compute_gae(
    rewards: torch.Tensor,
    values: torch.Tensor,
    dones: torch.Tensor,
    discount: float,
    gae_lambda: float = 0.95,
) -> torch.Tensor:
    """Compute Generalized Advantage Estimation.
    The result is built as a new tensor, so it is a constant (it does not track gradients).
    """
    advantages: list[torch.Tensor] = []
    gae = torch.tensor(0.0)
    for t in reversed(range(len(rewards))):
        if t == len(rewards) - 1:
            # Bootstrap value (or get V(s_T) if not terminal)
            next_value = torch.tensor(0.0)
        else:
            next_value = values[t + 1]
        # Mask next_value if episode ended
        next_value = next_value * (1 - dones[t])
        # δ_t = r_t + γV(s_{t+1}) - V(s_t)
        delta = rewards[t] + discount * next_value - values[t]
        # A_t = δ_t + (γλ)δ_{t+1} + ... + (γλ)^{T-t-1}δ_{T-1}, i.e. A_t = δ_t + γλA_{t+1}
        gae = delta + discount * gae_lambda * (1 - dones[t]) * gae
        advantages.insert(0, gae)
    return torch.FloatTensor(advantages)
class PPOPolicy(nn.Module, abc.ABC):
    def __init__(self, state_dim: int, action_dim: int) -> None:
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
    @abc.abstractmethod
    def get_action(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        pass
    @abc.abstractmethod
    def evaluate_actions(
        self, states: torch.Tensor, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        pass
class PPOPolicyContinuous(PPOPolicy):
    def __init__(self, state_dim: int, action_dim: int) -> None:
        super().__init__(state_dim, action_dim)
        self.policy = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 2 * action_dim),
        )
    def _get_action_distribution(self, state: torch.Tensor) -> dist.Distribution:
        action_gaussian_params = self.policy(state)
        means, log_stds = torch.chunk(action_gaussian_params, 2, dim=-1)
        stds = torch.exp(log_stds.clamp(-20, 2))
        return dist.Normal(loc=means, scale=stds)
    def get_action(self, states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        if len(states.shape) != 2 or states.shape[1] != self.state_dim:
            raise ValueError(
                f"state should have shape (T, {self.state_dim}) != {states.shape}"
            )
        actions_dist = self._get_action_distribution(states)
        action = actions_dist.sample()
        # aggregate across the last dimension - we want the probability
        # of the joint distribution of actions for a given timestep
        log_prob = actions_dist.log_prob(action).sum(-1)
        assert action.shape == (states.shape[0], self.action_dim)
        assert log_prob.shape == (states.shape[0],)
        return action, log_prob
    def evaluate_actions(
        self, states: torch.Tensor, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if len(states.shape) != 2 or states.shape[1] != self.state_dim:
            raise ValueError(
                f"states should have shape (T, {self.state_dim}) != {states.shape}"
            )
        if len(actions.shape) != 2 or actions.shape[1] != self.action_dim:
            raise ValueError(
                f"actions should have shape (T, {self.action_dim}) != {actions.shape}"
            )
        actions_dist = self._get_action_distribution(states)
        log_probs = actions_dist.log_prob(actions)
        entropy = actions_dist.entropy()
        # aggregate across the last dimension - we want the probability
        # and the entropy of the joint distribution of actions for a given timestep
        log_probs = log_probs.sum(-1)
        entropy = entropy.sum(-1)
        assert log_probs.shape == (states.shape[0],)
        assert entropy.shape == (states.shape[0],)
        return log_probs, entropy
class PPOPolicyDiscrete(PPOPolicy):
    def __init__(self, state_dim: int, action_dim: int) -> None:
        super().__init__(state_dim, action_dim)
        self.policy = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, action_dim),
        )
    def _get_action_distribution(self, state: torch.Tensor) -> dist.Distribution:
        logits = self.policy(state)
        # Pass the raw outputs as logits (Categorical normalizes them with a log-softmax);
        # passed positionally, they would be interpreted as probabilities instead
        return dist.Categorical(logits=logits)
    def get_action(self, states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        if len(states.shape) != 2 or states.shape[1] != self.state_dim:
            raise ValueError(
                f"State should have shape (T, {self.state_dim}) != {states.shape}"
            )
        actions_dist = self._get_action_distribution(states)
        action = actions_dist.sample()
        log_prob = actions_dist.log_prob(action)
        assert action.shape == (states.shape[0],)
        assert log_prob.shape == (states.shape[0],)
        return action, log_prob
    def evaluate_actions(
        self, states: torch.Tensor, actions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if len(states.shape) != 2 or states.shape[1] != self.state_dim:
            raise ValueError(
                f"states should have shape (T, {self.state_dim}) != {states.shape}"
            )
        if len(actions.shape) != 1:
            raise ValueError(f"actions should have shape (T,) != {actions.shape}")
        actions_dist = self._get_action_distribution(states)
        log_probs = actions_dist.log_prob(actions)
        entropy = actions_dist.entropy()
        assert log_probs.shape == (states.shape[0],)
        assert entropy.shape == (states.shape[0],)
        return log_probs, entropy
class _PPOValue(nn.Module):
    def __init__(self, state_dim: int):
        super().__init__()
        self.value = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 1),
        )
    def forward(self, states: torch.Tensor) -> torch.Tensor:
        values = self.value(states)
        # squeeze returns a tensor that's shaped just like the rewards tensor
        return values.squeeze(-1)
class Trajectory(typing.NamedTuple):
    states: torch.Tensor
    actions: torch.Tensor
    rewards: torch.Tensor
    log_probs: torch.Tensor
    dones: torch.Tensor
    def enable_grad(self) -> None:
        self.states.requires_grad = self.states.dtype.is_floating_point
        self.actions.requires_grad = self.actions.dtype.is_floating_point
        self.log_probs.requires_grad = self.log_probs.dtype.is_floating_point
        self.rewards.requires_grad = self.rewards.dtype.is_floating_point
        self.dones.requires_grad = self.dones.dtype.is_floating_point
    def __len__(self) -> int:
        if self.states is not None and len(self.states.shape) >= 1:
            return self.states.shape[0]
        else:
            return 0
    @staticmethod
    def concat(lhs: typing.Optional["Trajectory"], rhs: "Trajectory") -> "Trajectory":
        if lhs is None:
            return rhs
        return Trajectory(
            torch.concat([lhs.states, rhs.states]),
            torch.concat([lhs.actions, rhs.actions]),
            torch.concat([lhs.rewards, rhs.rewards]),
            torch.concat([lhs.log_probs, rhs.log_probs]),
            torch.concat([lhs.dones, rhs.dones]),
        )
class _PPOAgent(nn.Module):
    def __init__(self, policy: PPOPolicy, value_fn: _PPOValue) -> None:
        super().__init__()
        self.policy = policy
        self.value_fn = value_fn
def ppo_loss(
    agent: _PPOAgent, trajectory: Trajectory, discount: float, clip_eps: float
) -> torch.Tensor:
    # We maximize L^{CLIP+VF+S}_t(θ) = E_t[L^CLIP_t(θ) − c_1 L^VF_t(θ) + c_2 S[π_θ](s_t)],
    # so the returned loss is its negation: policy_loss + value_loss - entropy_loss
    log_probs, entropies = agent.policy.evaluate_actions(
        trajectory.states, trajectory.actions
    )
    values = agent.value_fn(trajectory.states)
    # Treat advantages and value targets as constants. compute_gae returns a constant
    # tensor, so with `returns = advantages + values` the difference `returns - values`
    # would carry no gradient and the value function would never be trained.
    advantages = compute_gae(
        trajectory.rewards, values.detach(), trajectory.dones, discount
    )
    returns = advantages + values.detach()
    # Normalize advantages to stabilize training
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    assert log_probs.shape == (trajectory.states.shape[0],)
    assert trajectory.log_probs.shape == (trajectory.states.shape[0],)
    # r_t(θ) = π_θ(a_t|s_t) / π_θold(a_t|s_t), computed from log probabilities
    ratio = torch.exp(log_probs - trajectory.log_probs)
    policy_loss = -torch.min(
        ratio * advantages,
        torch.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages,
    ).mean()
    value_loss = 0.5 * (returns - values).pow(2).mean()
    c_2 = 0.01
    entropy_loss = c_2 * entropies.mean()
    loss = policy_loss + value_loss - entropy_loss
    return loss
def train_one_trajectory(
    agent: _PPOAgent,
    optimizer: torch.optim.Optimizer,
    trajectory: Trajectory,
    num_updates: int,
    discount: float = 0.95,
    clip_eps: float = 0.2,
) -> list[torch.Tensor]:
    """Trains the PPO policy and value networks on a single trajectory.
    Multiple steps of training are performed, the number defined by the `num_updates` parameter.
    After calling this function, `trajectory` should be discarded and a new trajectory should be
    sampled.
    """
    trajectory.enable_grad()
    losses: list[torch.Tensor] = []
    for _ in trange(num_updates, desc="PPO update step", leave=False):
        optimizer.zero_grad()
        loss = ppo_loss(agent, trajectory, discount, clip_eps)
        # Keep only the value for logging, not the autograd graph
        losses.append(loss.detach())
        loss.backward()
        # TODO: does gradient clipping fix training?
        # nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
        optimizer.step()
    return losses
def rollout_episode(env, policy: PPOPolicy, max_trajectory_len: int) -> Trajectory:
    states, rewards, actions, log_probs = [], [], [], []
    dones: list[int] = []
    state = env.reset()[0]
    done = False
    num_steps = 0
    with torch.no_grad():
        while not done and num_steps < max_trajectory_len:
            state_t = torch.tensor(state)
            one_state_in_batch_t = state_t.unsqueeze(0)
            action_t, log_prob_t = policy.get_action(one_state_in_batch_t)
            # Extract the only action and log probability from the batch tensor
            log_prob_f = log_prob_t.detach()[0]
            action_f = action_t.detach()[0]
            # Gymnasium reports termination and time-limit truncation separately
            next_state, reward, terminated, truncated, _ = env.step(action_f.numpy())
            done = terminated or truncated
            states.append(state_t)
            actions.append(action_f)
            rewards.append(reward)
            log_probs.append(log_prob_f)
            dones.append(1 if done else 0)
            state = next_state
            num_steps += 1
    if not states:
        raise RuntimeError("No trajectory rolled out")
    # Mark the end of the rollout as an episode boundary, even if it stopped at
    # max_trajectory_len, so GAE does not leak into the next concatenated episode
    # (a simplification: V(s_T) is not bootstrapped for truncated episodes)
    dones[-1] = 1
    states_t = torch.stack(states)
    actions_t = torch.stack(actions)
    rewards_t = torch.tensor(rewards)
    log_probs_t = torch.stack(log_probs)
    dones_t = torch.tensor(dones)
    assert states_t.shape == (num_steps, policy.state_dim)
    assert rewards_t.shape == (num_steps,)
    assert log_probs_t.shape == (num_steps,)
    assert dones_t.shape == (num_steps,)
    if isinstance(policy, PPOPolicyContinuous):
        assert actions_t.shape == (num_steps, policy.action_dim)
    elif isinstance(policy, PPOPolicyDiscrete):
        assert actions_t.shape == (num_steps,)
    else:
        raise TypeError(f"Unsupported policy type {type(policy)}")
    return Trajectory(states_t, actions_t, rewards_t, log_probs_t, dones_t)
def train(
    env,
    policy: PPOPolicy,
    num_updates_per_epoch: int,
    num_episodes_per_epoch: int,
    num_epochs: int,
    max_trajectory_len: int,
    discount: float = 0.95,
    clip_eps: float = 0.2,
) -> None:
    # TODO: how to implement learning rate decay?
    value_fn = _PPOValue(policy.state_dim)
    agent = _PPOAgent(policy, value_fn)
    optimizer = torch.optim.Adam(agent.parameters(), lr=3e-4)
    experiment = mlflow.set_experiment("PPO training")
    with mlflow.start_run(experiment_id=experiment.experiment_id):
        mlflow.log_param("torch_version", torch.__version__)
        mlflow.log_param("cuda_available", torch.cuda.is_available())
        mlflow.log_param("num_updates_per_epoch", num_updates_per_epoch)
        mlflow.log_param("num_episodes_per_epoch", num_episodes_per_epoch)
        mlflow.log_param("num_epochs", num_epochs)
        mlflow.log_param("max_trajectory_len", max_trajectory_len)
        mlflow.log_param("discount", discount)
        mlflow.log_param("clip_eps", clip_eps)
        if torch.cuda.is_available():
            mlflow.log_param("cuda_device_count", torch.cuda.device_count())
            try:
                mlflow.log_param("cuda_device_name", torch.cuda.get_device_name(0))
            except Exception:
                pass
        total_params = sum(p.numel() for p in agent.parameters())
        trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)
        mlflow.log_param("total_parameters", total_params)
        mlflow.log_param("trainable_parameters", trainable_params)
        mlflow.log_param(
            "trainable_percentage", 100.0 * trainable_params / total_params
        )
        for epoch in trange(num_epochs, desc="train", leave=False):
            trajectory = None
            for _ in trange(
                num_episodes_per_epoch, desc="Collecting trajectory", leave=False
            ):
                traj_step = rollout_episode(env, policy, max_trajectory_len)
                trajectory = Trajectory.concat(trajectory, traj_step)
                if len(trajectory) > max_trajectory_len:
                    break
            if trajectory is None:
                logging.warning("Trajectory not collected")
                return
            mlflow.log_metric("traj_len", trajectory.states.shape[0], step=epoch)
            losses = train_one_trajectory(
                agent, optimizer, trajectory, num_updates_per_epoch, discount, clip_eps
            )
            losses_t = torch.stack(losses)
            mlflow.log_metric("loss", float(losses_t.mean()), step=epoch)
        # evaluation
        # TODO: extract to a separate function - but it will require extracting mlflow initialization
        # to log artifacts to the same run
        env_eval: RecordVideo = RecordVideo(
            env,
            video_folder="./videos/",
            episode_trigger=lambda x: (x % 10 == 0),  # record every 10th episode
            name_prefix="ppo-lunarlander",
        )
        with torch.no_grad():
            for eval_epoch in trange(num_epochs, desc="eval", leave=False):
                trajectory = rollout_episode(env_eval, policy, max_trajectory_len)
                loss = ppo_loss(agent, trajectory, discount, clip_eps)
                mlflow.log_metric("eval loss", float(loss), step=eval_epoch)
def main():
    env = gym.make("LunarLander-v3", render_mode="rgb_array")
    action_dim = int(env.action_space.n)
    state_dim = env.reset()[0].shape[0]
    policy = PPOPolicyDiscrete(state_dim=state_dim, action_dim=action_dim)
    train(
        env,
        policy,
        num_updates_per_epoch=20,
        num_episodes_per_epoch=10,
        num_epochs=100,
        max_trajectory_len=1000,
    )
if __name__ == "__main__":
    main()
It was written and tested with Python 3.14 and uses the following dependencies:
gymnasium[box2d,other]>=1.2.3
mlflow>=3.3.1
pydantic>=2.11.7
swig>=4.4.1
torch>=2.9.0
tqdm>=4.67.1
Conclusion
I’d love to hear your comments and thoughts on the topic.
The point here was to show that implementing a Deep Learning algorithm from a whitepaper is not a straightforward challenge, and to encourage you to take it on.









