RLlib Tutorial

In this guide, we will train and deploy a simple Ray RLlib PPO model. In particular, we show:

  • How to load the model from a checkpoint

  • How to parse the JSON request and evaluate the payload with RLlib

See Core API: Deployments for more general information about Ray Serve.

Let’s import Ray Serve and some other helpers.

import gym
from starlette.requests import Request
import requests

import ray
import ray.rllib.agents.ppo as ppo
from ray import serve

We will train and checkpoint a simple PPO model on the CartPole-v0 environment. For now, we write the checkpoint to local disk; in production, consider loading the weights from cloud storage (such as S3) or a shared file system.

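def train_ppo_model():
    trainer = ppo.PPOTrainer(
        config={
            "framework": "torch",
            "num_workers": 0,
        },
        env="CartPole-v0",
    )
    # Train for one iteration, then write a checkpoint to local disk
    trainer.train()
    trainer.save("/tmp/rllib_checkpoint")
    return "/tmp/rllib_checkpoint/checkpoint_000001/checkpoint-1"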

checkpoint_path = train_ppo_model()
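
The checkpoint path above encodes a directory layout of the form checkpoint_<iteration padded to six digits>/checkpoint-<iteration>. If you train for more than one iteration, a small helper can rebuild that path instead of hard-coding it. The sketch below is hypothetical (checkpoint_file and latest_iteration are not RLlib APIs) and assumes the layout seen in /tmp/rllib_checkpoint/checkpoint_000001/checkpoint-1; other RLlib versions may lay checkpoints out differently.

```python
import re
from typing import Iterable, Optional

# Hypothetical helpers. The layout is inferred from the example path
# /tmp/rllib_checkpoint/checkpoint_000001/checkpoint-1 and is an assumption.
_CHECKPOINT_DIR = re.compile(r"^checkpoint_(\d+)$")


def checkpoint_file(base_dir: str, iteration: int) -> str:
    """Build <base_dir>/checkpoint_<iteration:06d>/checkpoint-<iteration>."""
    return f"{base_dir}/checkpoint_{iteration:06d}/checkpoint-{iteration}"


def latest_iteration(dir_names: Iterable[str]) -> Optional[int]:
    """Return the highest iteration among names like 'checkpoint_000003'."""
    iterations = []
    for name in dir_names:
        match = _CHECKPOINT_DIR.match(name)
        if match:
            iterations.append(int(match.group(1)))
    return max(iterations, default=None)


# In a real run, os.listdir("/tmp/rllib_checkpoint") would supply the names;
# "notes.txt" stands in for an unrelated file that is ignored.
names = ["checkpoint_000001", "checkpoint_000003", "notes.txt"]
print(checkpoint_file("/tmp/rllib_checkpoint", latest_iteration(names)))
# /tmp/rllib_checkpoint/checkpoint_000003/checkpoint-3
```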

Services are defined as ordinary Python classes with __init__ and __call__ methods. Serve invokes __call__ once per request; the method reads the observation from the request's JSON body (request.json()["observation"]) and passes it to the trainer.
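
To see the per-request flow without Ray, here is a plain-Python sketch of what __call__ does: decode the JSON body, pull out "observation", map it to an action, and return a JSON-serializable dict. The simple_policy function is a hypothetical stand-in for trainer.compute_action (a hand-written rule that pushes the cart toward the side the pole is falling, not a trained model), and handle only mirrors the shape of the handler.

```python
import json
from typing import List, Sequence


def simple_policy(obs: Sequence[float]) -> int:
    """Hypothetical stand-in for trainer.compute_action(obs).

    CartPole observations are [cart position, cart velocity,
    pole angle, pole angular velocity]; action 0 pushes left, 1 pushes right.
    This rule pushes toward the side the pole is falling.
    """
    _, _, angle, angular_velocity = obs
    return 1 if angle + angular_velocity > 0 else 0


def handle(body: bytes) -> dict:
    """Mirror of the __call__ logic: JSON body in, {"action": int} out."""
    payload = json.loads(body)
    obs: List[float] = payload["observation"]
    if len(obs) != 4:
        raise ValueError(f"expected 4 observation values, got {len(obs)}")
    action = simple_policy(obs)
    return {"action": int(action)}


print(handle(b'{"observation": [0.04, 0.02, 0.007, 0.03]}'))
# {'action': 1}
```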


Although we use a single input and trainer.compute_action(...) here, you can process a batch of inputs with Ray Serve’s batching feature and call trainer.compute_actions(...) (note the plural!) to handle the whole batch at once.
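
Ray Serve exposes batching through its @serve.batch decorator (with parameters such as max_batch_size and batch_wait_timeout_s), where an async method receives a list of inputs and returns a list of outputs in the same order. The sketch below does not use Ray: it is a minimal asyncio micro-batcher written only to illustrate the idea that individual requests are queued, grouped, and passed to one plural call, the role trainer.compute_actions(...) plays. MicroBatcher and batch_policy are hypothetical names.

```python
import asyncio
from typing import Any, Callable, List, Optional


class MicroBatcher:
    """Hypothetical micro-batcher for illustration; NOT Ray Serve's implementation.

    Each caller submits one item and awaits one result. A background task
    collects up to max_batch_size queued items (waiting at most timeout_s for
    more after the first arrives) and calls batch_fn once with the list.
    batch_fn must return one result per input, in the same order.
    Use one instance per event loop.
    """

    def __init__(self, batch_fn: Callable[[List[Any]], List[Any]],
                 max_batch_size: int = 8, timeout_s: float = 0.1) -> None:
        self.batch_fn = batch_fn
        self.max_batch_size = max_batch_size
        self.timeout_s = timeout_s
        self._queue: Optional[asyncio.Queue] = None
        self._worker: Optional[asyncio.Task] = None

    async def submit(self, item: Any) -> Any:
        if self._worker is None:
            # Create the queue and worker inside the running event loop.
            self._queue = asyncio.Queue()
            self._worker = asyncio.create_task(self._run())
        future = asyncio.get_running_loop().create_future()
        self._queue.put_nowait((item, future))
        return await future

    async def _run(self) -> None:
        loop = asyncio.get_running_loop()
        while True:
            batch = [await self._queue.get()]  # block until the first item arrives
            deadline = loop.time() + self.timeout_s
            while len(batch) < self.max_batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self._queue.get(), remaining))
                except asyncio.TimeoutError:
                    break
            try:
                results = self.batch_fn([item for item, _ in batch])
            except Exception as exc:  # propagate the failure to every caller
                for _, future in batch:
                    if not future.done():
                        future.set_exception(exc)
                continue
            for (_, future), result in zip(batch, results):
                if not future.done():  # the caller may have been cancelled
                    future.set_result(result)


def batch_policy(observations: List[List[float]]) -> List[int]:
    """Hypothetical plural stand-in for trainer.compute_actions(...)."""
    return [1 if obs[2] + obs[3] > 0 else 0 for obs in observations]


async def main() -> List[int]:
    batcher = MicroBatcher(batch_policy, max_batch_size=4)
    observations = [[0.0, 0.0, 0.05, 0.0], [0.0, 0.0, -0.05, 0.0],
                    [0.0, 0.0, 0.01, 0.02], [0.0, 0.0, -0.01, -0.02]]
    # Four concurrent "requests"; each awaits its own action.
    return await asyncio.gather(*(batcher.submit(obs) for obs in observations))


print(asyncio.run(main()))
# [1, 0, 1, 0]
```

Because all four submit() calls are queued before the worker task runs, they are normally served by a single batch_policy call, while each caller still receives only its own result.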

class ServePPOModel:
    def __init__(self, checkpoint_path) -> None:
        self.trainer = ppo.PPOTrainer(
            config={
                "framework": "torch",
                # only 1 "local" worker with an env (not really used here).
                "num_workers": 0,
            },
            env="CartPole-v0",
        )
        # Load the trained weights from the checkpoint
        self.trainer.restore(checkpoint_path)

    async def __call__(self, request: Request):
        json_input = await request.json()
        obs = json_input["observation"]

        action = self.trainer.compute_action(obs)
        return {"action": int(action)}

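Now that we’ve defined our service, let’s deploy the model to Ray Serve. We define a Serve deployment from the class and expose it over an HTTP route.

# Start Ray Serve
serve.start()

# Wrap the class as a Serve deployment served at the HTTP route /cartpole-ppo.
# deploy() passes checkpoint_path through to ServePPOModel.__init__.
ppo_deployment = serve.deployment(route_prefix="/cartpole-ppo")(ServePPOModel)
ppo_deployment.deploy(checkpoint_path)
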
Let’s query it!

# That's it! Let's test it
for _ in range(10):
    env = gym.make("CartPole-v0")
    obs = env.reset()

    print(f"-> Sending observation {obs}")
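    resp = requests.get(
        # Serve's default HTTP port; the path must match the deployment's route_prefix
        "http://localhost:8000/cartpole-ppo",
        json={"observation": obs.tolist()})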
    print(f"<- Received response {resp.json()}")
# Output:
# <- Received response {'action': 1}
# -> Sending observation [0.04228249 0.02289503 0.00690076 0.03095441]
# <- Received response {'action': 0}
# -> Sending observation [ 0.04819471 -0.04702759 -0.00477937 -0.00735569]
# <- Received response {'action': 0}
# ...
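
requests.get(..., json=...) sends an HTTP GET whose body is the JSON-encoded payload, with a Content-Type: application/json header. If you would rather not depend on requests, a standard-library sketch of the same call follows. build_request and query are hypothetical helper names, and the URL assumes Serve's default HTTP port (8000) and a deployment route of /cartpole-ppo; adjust it to match your deployment.

```python
import json
import urllib.request
from typing import List

# Assumed endpoint: Serve's default HTTP address plus the deployment's route.
URL = "http://localhost:8000/cartpole-ppo"


def build_request(observation: List[float], url: str = URL) -> urllib.request.Request:
    """Build a GET request with a JSON body, like requests.get(url, json=...)."""
    body = json.dumps({"observation": observation}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="GET",
    )


def query(observation: List[float], url: str = URL) -> dict:
    """Send the request and decode the JSON reply (needs a running deployment)."""
    with urllib.request.urlopen(build_request(observation, url), timeout=5) as resp:
        return json.loads(resp.read())
```

With the deployment running, query(obs.tolist()) should return the same dict as resp.json() in the loop above.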