Saving and Loading your RL Algorithms and Policies

You can use Checkpoint objects to store and load the current state of your Algorithm or Policy and the neural networks (weights) within these structures. In the following, we will cover how you can create these checkpoints (and hence save your Algos and Policies) to disk, where you can find them, and how you can recover (load) your Algorithm or Policy from such a given checkpoint.

What’s a checkpoint?

A checkpoint is a set of information, located inside a directory (which may contain further subdirectories) and used to restore either an Algorithm or a single Policy instance. The Algorithm or Policy instances that were used to create the checkpoint in the first place may or may not have been trained prior to this.

RLlib uses the new Ray AIR Checkpoint class to create checkpoints and restore objects from them.

Algorithm checkpoints

An Algorithm checkpoint contains all of the Algorithm’s state, including its configuration, its actual Algorithm subclass, all of its Policies’ weights, its current counters, etc.

Restoring a new Algorithm from such a Checkpoint leaves you in a state where you can continue working with that new Algorithm exactly as you would have continued working with the old Algorithm (from which the checkpoint was taken).

How do I create an Algorithm checkpoint?

The Algorithm save() method creates a new checkpoint (directory with files in it) and returns the path to that directory.

Let’s take a look at a simple example on how to create such an Algorithm checkpoint:

# Create a PPO algorithm object using a config object ..
from ray.rllib.algorithms.ppo import PPOConfig

my_ppo_config = PPOConfig().environment("CartPole-v1")
my_ppo = my_ppo_config.build()

# .. train one iteration ..
my_ppo.train()
# .. and call `save()` to create a checkpoint.
path_to_checkpoint = my_ppo.save()
print(
    "An Algorithm checkpoint has been created inside directory: "
    f"'{path_to_checkpoint}'."
)

# Let's terminate the algo for demonstration purposes.
my_ppo.stop()
# Doing this will lead to an error.
# my_ppo.train()

If you take a look at the directory returned by the save() call, you should see something like this:

$ ls -la
  .
  ..
  .is_checkpoint
  .tune_metadata
  policies/
  algorithm_state.pkl
  rllib_checkpoint.json

As you can see, there is a policies sub-directory created for us (more on that later), an algorithm_state.pkl file, and an rllib_checkpoint.json file. The algorithm_state.pkl file contains all state information of the Algorithm that is not Policy-specific, such as the algo’s counters and other important variables to persistently keep track of. The rllib_checkpoint.json file contains the checkpoint version and exists for the user’s convenience. From Ray RLlib 2.0 and up, all checkpoint versions will be backward compatible, meaning an RLlib version V will be able to handle any checkpoints created with Ray 2.0 or any version up to V.

$ more rllib_checkpoint.json
{"type": "Algorithm", "checkpoint_version": "1.0"}
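The same information can also be read programmatically. The following is a minimal sketch and not part of RLlib: `read_checkpoint_info` is a hypothetical helper, and the temporary directory only mimics the file shown above so that the snippet runs without a real checkpoint. In practice, you would pass the path returned by save().

```python
import json
import tempfile
from pathlib import Path


def read_checkpoint_info(checkpoint_dir):
    """Return the contents of rllib_checkpoint.json as a dict.

    Hypothetical helper (not an RLlib API); it only parses the JSON file.
    """
    info_file = Path(checkpoint_dir) / "rllib_checkpoint.json"
    with open(info_file) as f:
        return json.load(f)


# Stand-in for a real checkpoint directory (e.g. the path returned by `save()`).
with tempfile.TemporaryDirectory() as tmp_dir:
    (Path(tmp_dir) / "rllib_checkpoint.json").write_text(
        '{"type": "Algorithm", "checkpoint_version": "1.0"}'
    )
    info = read_checkpoint_info(tmp_dir)

print(f"type={info['type']}, checkpoint_version={info['checkpoint_version']}")
```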

Now, let’s check out the policies/ sub-directory:

$ cd policies
$ ls -la
  .
  ..
  default_policy/

We can see yet another sub-directory, called default_policy. RLlib creates exactly one sub-directory inside the policies/ dir per Policy instance that the Algorithm uses. In the standard single-agent case, this will be the “default_policy”. Note that “default_policy” is the so-called PolicyID. In the multi-agent case, depending on your particular setup and environment, you might see multiple sub-directories here with different names (the PolicyIDs of the different policies trained). For example, if you are training 2 Policies with the IDs “policy_1” and “policy_2”, you should see the sub-directories:

$ ls -la
  .
  ..
  policy_1/
  policy_2/
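
Because each Policy gets its own sub-directory, the PolicyIDs stored in a checkpoint can be recovered by listing policies/. The sketch below assumes only the directory layout described above. `list_policy_ids` is a hypothetical helper (not an RLlib API), and the temporary directory stands in for a real multi-agent checkpoint.

```python
import tempfile
from pathlib import Path


def list_policy_ids(checkpoint_dir):
    """Return the sorted names of all sub-directories under `policies/`.

    Hypothetical helper (not an RLlib API).
    """
    policies_dir = Path(checkpoint_dir) / "policies"
    return sorted(p.name for p in policies_dir.iterdir() if p.is_dir())


# Mimic the layout of a multi-agent checkpoint in a temporary directory.
with tempfile.TemporaryDirectory() as tmp_dir:
    for policy_id in ("policy_1", "policy_2"):
        (Path(tmp_dir) / "policies" / policy_id).mkdir(parents=True)
    policy_ids = list_policy_ids(tmp_dir)

print(policy_ids)
```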

Lastly, let’s quickly take a look at our default_policy sub-directory:

$ cd default_policy
$ ls -la
  .
  ..
  rllib_checkpoint.json
  policy_state.pkl

Similar to the algorithm’s state (saved within algorithm_state.pkl), a Policy’s state is stored under the policy_state.pkl file. We’ll cover more details on the contents of this file when talking about Policy checkpoints below. Note that Policy checkpoints also have an info file (rllib_checkpoint.json), whose checkpoint version is always identical to that of the enclosing Algorithm checkpoint.

How do I restore an Algorithm from a checkpoint?

Given our checkpoint path (returned by Algorithm.save()), we can now create a completely new Algorithm instance and make it the exact same as the one we had stopped (and could thus no longer use) in the example above:

from ray.rllib.algorithms.algorithm import Algorithm

# Use the Algorithm's `from_checkpoint` utility to get a new algo instance
# that has the exact same state as the old one, from which the checkpoint was
# created in the first place:
my_new_ppo = Algorithm.from_checkpoint(path_to_checkpoint)

# Continue training.
my_new_ppo.train()

Alternatively, you could also first create a new Algorithm instance using the same config that you used for the original algo, and only then call the new Algorithm’s restore() method, passing it the checkpoint directory:

# Re-build a fresh algorithm.
my_new_ppo = my_ppo_config.build()

# Restore the old (checkpointed) state.
my_new_ppo.restore(path_to_checkpoint)

# Continue training.
my_new_ppo.train()

The above procedure used to be the only way of restoring an algo. However, it is more tedious than using the from_checkpoint() utility, as it requires an extra step and you have to keep your original config stored somewhere.

Which Algorithm checkpoint versions can I use?

RLlib uses simple checkpoint versions (for example v0.1 or v1.0) to figure out how to restore an Algorithm (or a Policy; see below) from a given checkpoint directory.

From Ray 2.1 on, you can find the checkpoint version written in the rllib_checkpoint.json file at the top-level of your checkpoint directory. RLlib does not use this file or the information therein; it exists solely for the user’s convenience.

From Ray RLlib 2.0 and up, all checkpoint versions will be backward compatible, meaning some RLlib version 2.x will be able to handle any checkpoints created by RLlib 2.0 or any version up to 2.x.
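
The compatibility rule above can be written down as a small check. This is only an illustration of the stated guarantee: `covered_by_compat_guarantee` and `parse_version` are hypothetical helpers, RLlib does not expose such functions, and a result of False merely means that the statement above makes no promise for that combination of versions.

```python
def parse_version(version):
    """Turn a string such as "2.3" or "2.3.1" into a tuple of ints."""
    return tuple(int(part) for part in version.split("."))


def covered_by_compat_guarantee(checkpoint_ray_version, installed_ray_version):
    """Check the rule: created with Ray 2.0 or later, and not newer than
    the installed version. Hypothetical helper (not an RLlib API).
    """
    created = parse_version(checkpoint_ray_version)
    installed = parse_version(installed_ray_version)
    return (2, 0) <= created <= installed


same_major = covered_by_compat_guarantee("2.0", "2.3")
too_new = covered_by_compat_guarantee("2.4", "2.3")
too_old = covered_by_compat_guarantee("1.13", "2.3")
print(same_major, too_new, too_old)
```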

Multi-agent Algorithm checkpoints

In case you are working with a multi-agent setup and have more than one Policy to train inside your Algorithm, you can create an Algorithm checkpoint in the exact same way as described above and will find your individual Policy checkpoints inside the sub-directory policies/.

For example:

import os

# Use our example multi-agent CartPole environment to train in.
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole

# Set up a multi-agent Algorithm, training two policies independently.
my_ma_config = PPOConfig().multi_agent(
    # Which policies should RLlib create and train?
    policies={"pol1", "pol2"},
    # Let RLlib know, which agents in the environment (we'll have "agent1"
    # and "agent2") map to which policies.
    policy_mapping_fn=(
        lambda agent_id, episode, worker, **kw: (
            "pol1" if agent_id == "agent1" else "pol2"
        )
    ),
    # Setting these is not necessary. All policies will always be trained by default.
    # However, since we do provide a list of IDs here, we need to remain in charge of
    # changing this `policies_to_train` list, should we ever alter the Algorithm
    # (e.g. remove one of the policies or add a new one).
    policies_to_train=["pol1", "pol2"],  # Again, `None` would be totally fine here.
)

# Add the MultiAgentCartPole env to our config and build our Algorithm.
my_ma_config.environment(
    MultiAgentCartPole,
    env_config={
        "num_agents": 2,
    },
)

my_ma_algo = my_ma_config.build()

ma_checkpoint_dir = my_ma_algo.save()

print(
    "An Algorithm checkpoint has been created inside directory: "
    f"'{ma_checkpoint_dir}'.\n"
    "Individual Policy checkpoints can be found in "
    f"'{os.path.join(ma_checkpoint_dir, 'policies')}'."
)

# Create a new Algorithm instance from the above checkpoint, just as you would for
# a single-agent setup:
my_ma_algo_clone = Algorithm.from_checkpoint(ma_checkpoint_dir)

Assuming you would like to restore all policies within the checkpoint, you would do so just as described above in the single-agent case (via algo = Algorithm.from_checkpoint([path to your multi-agent checkpoint])).

However, there may be situations where you have many policies in your algorithm (e.g. you are doing league-based training) and would like to restore a new Algorithm instance from your checkpoint, but only include some of the original policies in this new Algorithm object. In this case, you can also do:

# Here, we use the same (multi-agent Algorithm) checkpoint as above, but only restore
# it with the first Policy ("pol1").

my_ma_algo_only_pol1 = Algorithm.from_checkpoint(
    ma_checkpoint_dir,
    # Tell the `from_checkpoint` util to create a new Algo, but only with "pol1" in it.
    policy_ids=["pol1"],
    # Make sure to update the mapping function (we must not map to "pol2" anymore
    # to avoid a runtime error). Now both agents ("agent0" and "agent1") map to
    # the same policy.
    policy_mapping_fn=lambda agent_id, episode, worker, **kw: "pol1",
    # Since we defined this above, we have to re-define it here with the updated
    # PolicyIDs, otherwise, RLlib will throw an error (it will think that there is an
    # unknown PolicyID in this list ("pol2")).
    policies_to_train=["pol1"],
)

# Make sure, pol2 is NOT in this Algorithm anymore.
assert my_ma_algo_only_pol1.get_policy("pol2") is None

# Continue training (only with pol1).
my_ma_algo_only_pol1.train()
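
The two requirements mentioned in the comments above (the mapping function must not return a removed PolicyID, and policies_to_train must not list one) can be checked before restoring. The following is a sketch in plain Python under stated assumptions: `check_policy_subset` is a hypothetical helper, not an RLlib API, and the agent IDs are assumed for illustration only.

```python
def check_policy_subset(agent_ids, kept_policy_ids, mapping_fn, policies_to_train):
    """Return a list of problems; an empty list means the setup is consistent.

    Hypothetical helper (not an RLlib API).
    """
    kept = set(kept_policy_ids)
    problems = []
    for agent_id in agent_ids:
        # `episode` and `worker` are not needed by the mapping fns used here.
        policy_id = mapping_fn(agent_id, None, None)
        if policy_id not in kept:
            problems.append(f"{agent_id!r} maps to removed policy {policy_id!r}")
    for policy_id in policies_to_train:
        if policy_id not in kept:
            problems.append(f"policies_to_train lists removed policy {policy_id!r}")
    return problems


agent_ids = ["agent0", "agent1"]  # assumed agent IDs, for illustration only


def old_mapping_fn(agent_id, episode, worker, **kw):
    return "pol1" if agent_id == "agent1" else "pol2"


def new_mapping_fn(agent_id, episode, worker, **kw):
    return "pol1"


# Keeping only "pol1" while reusing the old settings yields two problems.
stale = check_policy_subset(agent_ids, ["pol1"], old_mapping_fn, ["pol1", "pol2"])
# The updated settings are consistent with the reduced set of policies.
fixed = check_policy_subset(agent_ids, ["pol1"], new_mapping_fn, ["pol1"])
print(stale)
print(fixed)
```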

Policy checkpoints

We have already looked at the policies/ sub-directory inside an Algorithm checkpoint dir and learned that individual policies inside the Algorithm store all their state information under their policy ID inside that sub-directory. Thus, we now have the entire picture of a checkpoint:

.
..
.is_checkpoint
.tune_metadata

algorithm_state.pkl        # <- state of the Algorithm (excluding Policy states)
rllib_checkpoint.json      # <- checkpoint info, such as checkpoint version, e.g. "1.0"

policies/
  policy_A/
    policy_state.pkl       # <- state of policy_A
    rllib_checkpoint.json  # <- checkpoint info, such as checkpoint version, e.g. "1.0"

  policy_B/
    policy_state.pkl       # <- state of policy_B
    rllib_checkpoint.json  # <- checkpoint info, such as checkpoint version, e.g. "1.0"
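
Given this layout, a quick sanity check of a checkpoint directory could look like the sketch below. It assumes only the file names shown in the tree above. `missing_checkpoint_files` is a hypothetical helper (not an RLlib API), and the temporary directory mimics a checkpoint in which policy_B lacks its state file.

```python
import tempfile
from pathlib import Path


def missing_checkpoint_files(checkpoint_dir):
    """Return the relative paths of expected files that do not exist.

    Hypothetical helper (not an RLlib API).
    """
    root = Path(checkpoint_dir)
    expected = [root / "algorithm_state.pkl", root / "rllib_checkpoint.json"]
    policies_dir = root / "policies"
    if policies_dir.is_dir():
        for policy_dir in sorted(policies_dir.iterdir()):
            if policy_dir.is_dir():
                expected.append(policy_dir / "policy_state.pkl")
                expected.append(policy_dir / "rllib_checkpoint.json")
    else:
        expected.append(policies_dir)
    return [p.relative_to(root).as_posix() for p in expected if not p.exists()]


# Build a mock checkpoint in which policy_B has no policy_state.pkl.
with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt = Path(tmp_dir)
    (ckpt / "algorithm_state.pkl").touch()
    (ckpt / "rllib_checkpoint.json").touch()
    for policy_id in ("policy_A", "policy_B"):
        (ckpt / "policies" / policy_id).mkdir(parents=True)
        (ckpt / "policies" / policy_id / "rllib_checkpoint.json").touch()
    (ckpt / "policies" / "policy_A" / "policy_state.pkl").touch()
    missing = missing_checkpoint_files(ckpt)

print(missing)
```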

How do I create a Policy checkpoint?

You can create a Policy checkpoint by either calling save() on your Algorithm, which will save each individual Policy’s checkpoint under the policies/ sub-directory as described above or - if you need more fine-grained control - by doing the following:

# Retrieve the Policy object from an Algorithm.
# Note that for normal, single-agent Algorithms, the Policy ID is "default_policy".
policy1 = my_ma_algo.get_policy(policy_id="pol1")

# Tell RLlib to store an individual policy checkpoint (only for "pol1") inside
# /tmp/my_policy_checkpoint
policy1.export_checkpoint("/tmp/my_policy_checkpoint")

If you now check out the provided directory (/tmp/my_policy_checkpoint/), you should see the following files in there:

.
..
rllib_checkpoint.json   # <- checkpoint info, such as checkpoint version, e.g. "1.0"
policy_state.pkl        # <- state of "pol1"

How do I restore from a Policy checkpoint?

Assume you would like to serve your trained policy (or policies) in production and would therefore like to use only the RLlib Policy instance, without all the other functionality that normally comes with the Algorithm object, like different RolloutWorkers for collecting training samples or for evaluation (both of which include RL environment copies), etc.

In this case, it would be quite useful if you had a way to restore just the Policy from either a Policy checkpoint or an Algorithm checkpoint, which - as we learned above - contains all its Policies’ checkpoints.

Here is how you can do this:

import numpy as np

from ray.rllib.policy.policy import Policy

# Use the `from_checkpoint` utility of the Policy class:
my_restored_policy = Policy.from_checkpoint("/tmp/my_policy_checkpoint")

# Use the restored policy for serving actions.
obs = np.array([0.0, 0.1, 0.2, 0.3])  # individual CartPole observation
action = my_restored_policy.compute_single_action(obs)

print(f"Computed action {action} from given CartPole observation.")

How do I restore a multi-agent Algorithm with a subset of the original policies?

Imagine you have trained a multi-agent Algorithm with e.g. 100 different Policies and created a checkpoint from this Algorithm. The checkpoint now includes 100 sub-directories in the policies/ dir, named after the different policy IDs.

After careful evaluation of the different policies, you would like to restore the Algorithm and continue training it, but only with a subset of the original 100 policies, for example only with the policies whose IDs are “polA” and “polB”.

You can use the original checkpoint (with the 100 policies in it) and the Algorithm.from_checkpoint() utility to achieve this in an efficient way.

This example shows this for five original policies that you would like to reduce to two policies:

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole

# Set up an Algorithm with 5 Policies.
algo_w_5_policies = (
    PPOConfig()
    .environment(
        env=MultiAgentCartPole,
        env_config={
            "num_agents": 5,
        },
    )
    .multi_agent(
        policies={"pol0", "pol1", "pol2", "pol3", "pol4"},
        # Map "agent0" -> "pol0", etc...
        policy_mapping_fn=(
            lambda agent_id, episode, worker, **kwargs: f"pol{agent_id}"
        ),
    )
    .build()
)

# .. train one iteration ..
algo_w_5_policies.train()
# .. and call `save()` to create a checkpoint.
path_to_checkpoint = algo_w_5_policies.save()
print(
    "An Algorithm checkpoint has been created inside directory: "
    f"'{path_to_checkpoint}'. It should contain 5 policies in the 'policies/' sub dir."
)
# Let's terminate the algo for demonstration purposes.
algo_w_5_policies.stop()

# We will now recreate a new algo from this checkpoint, but only with 2 of the
# original policies ("pol0" and "pol1"). Note that this will require us to change the
# `policy_mapping_fn` (instead of mapping 5 agents to 5 policies, we now have
# to map 5 agents to only 2 policies).


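def new_policy_mapping_fn(agent_id, episode, worker, **kwargs):
    # MultiAgentCartPole uses integer agent IDs (0, 1, ...), which is also why
    # the original `f"pol{agent_id}"` mapping above yields "pol0", "pol1", etc.
    return "pol0" if agent_id in [0, 1] else "pol1"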


algo_w_2_policies = Algorithm.from_checkpoint(
    checkpoint=path_to_checkpoint,
    policy_ids={"pol0", "pol1"},  # <- restore only those policy IDs here.
    policy_mapping_fn=new_policy_mapping_fn,  # <- use this new mapping fn.
)

# Test, whether we can train with this new setup.
algo_w_2_policies.train()
# Terminate the new algo.
algo_w_2_policies.stop()

Note that we had to change our original policy_mapping_fn from one that maps “agent0” to “pol0”, “agent1” to “pol1”, etc., to a new function that maps our five agents to only the two remaining policies: “agent0” and “agent1” to “pol0”, all other agents to “pol1”.

Model Exports

Apart from creating checkpoints for your RLlib objects (such as an RLlib Algorithm or an individual RLlib Policy), it may also be very useful to export only your NN models in their native (non-RLlib dependent) format, for example as a Keras or PyTorch model. You could then use the trained NN models outside of RLlib, e.g. for serving purposes in your production environments.

How do I export my NN Model?

There are several ways of creating native Keras or PyTorch model “exports”.

Here is example code that illustrates them:

from ray.rllib.algorithms.ppo import PPOConfig

# Create a new Algorithm (which contains a Policy, which contains a NN Model).
# Switch on for native models to be included in the Policy checkpoints.
ppo_config = (
    PPOConfig().environment("Pendulum-v1").checkpointing(export_native_model_files=True)
)

# The default framework is TensorFlow, but if you would like to do this example with
# PyTorch, uncomment the following line of code:
# ppo_config.framework("torch")

# Create the Algorithm and train one iteration.
ppo = ppo_config.build()
ppo.train()

# Get the underlying PPOTF1Policy (or PPOTorchPolicy) object.
ppo_policy = ppo.get_policy()

We can now export the Keras NN model (that our PPOTF1Policy inside the PPO Algorithm uses) to disk …

  1. Using the Policy object:

ppo_policy.export_model("/tmp/my_nn_model")
# .. check /tmp/my_nn_model/ for the keras model files. You should be able to recover
# the keras model via:
# keras_model = tf.saved_model.load("/tmp/my_nn_model/")
# And pass in a Pendulum-v1 observation:
# results = keras_model(tf.convert_to_tensor(
#     np.array([[0.0, 0.1, 0.2]]), dtype=np.float32)
# )

# For PyTorch, do:
# pytorch_model = torch.load("/tmp/my_nn_model/model.pt")
# results = pytorch_model(
#     input_dict={
#         "obs": torch.from_numpy(np.array([[0.0, 0.1, 0.2]], dtype=np.float32)),
#     },
#     state=[torch.tensor(0)],  # dummy value
#     seq_lens=torch.tensor(0),  # dummy value
# )

  2. Via the Policy’s checkpointing method:

checkpoint_dir = ppo_policy.export_checkpoint("/tmp/ppo_policy")
# .. check /tmp/ppo_policy/model/ for the keras model files.
# You should be able to recover the keras model via:
# keras_model = tf.saved_model.load("/tmp/ppo_policy/model")
# And pass in a Pendulum-v1 observation:
# results = keras_model(tf.convert_to_tensor(
#     np.array([[0.0, 0.1, 0.2]]), dtype=np.float32)
# )

  3. Via the Algorithm (Policy) checkpoint:

checkpoint_dir = ppo.save()
# .. check `checkpoint_dir` for the Algorithm checkpoint files.
# You should be able to recover the keras model via:
# keras_model = tf.saved_model.load(checkpoint_dir + "/policies/default_policy/model/")
# And pass in a Pendulum-v1 observation
# results = keras_model(tf.convert_to_tensor(
#     np.array([[0.0, 0.1, 0.2]]), dtype=np.float32)
# )

And what about exporting my NN Models in ONNX format?

RLlib also supports exporting your NN models in the ONNX format. For that, use the Policy export_model method, but provide the extra onnx arg as follows:

# Using the same Policy object, we can also export our NN Model in the ONNX format:
ppo_policy.export_model("/tmp/my_nn_model", onnx=True)