# Serving RLlib Models

In this guide, we will train and deploy a simple Ray RLlib model. In particular, we show:

• How to train and store an RLlib model.

• How to load this model from a checkpoint.

• How to parse the JSON request and evaluate the payload in RLlib.

We will train and checkpoint a simple PPO model with the CartPole-v0 environment from gym. In this tutorial we simply write to local disk, but in production you might want to consider using a cloud storage solution like S3 or a shared file system.

Let’s get started by defining a PPO instance, training it for one iteration and then creating a checkpoint:

import ray
import ray.rllib.algorithms.ppo as ppo
from ray import serve


def train_ppo_model():
    # Configure our PPO algorithm.
    config = (
        ppo.PPOConfig()
        .framework("torch")
        .rollouts(num_rollout_workers=0)
    )
    # Create a PPO instance from the config.
    algo = config.build(env="CartPole-v0")
    # Train for one iteration.
    algo.train()
    # Save state of the trained Algorithm in a checkpoint.
    algo.save("/tmp/rllib_checkpoint")
    # Path of the checkpoint written after the first training iteration.
    return "/tmp/rllib_checkpoint/checkpoint_000001/checkpoint-1"


checkpoint_path = train_ppo_model()
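
The path returned by train_ppo_model() is hard-coded for exactly one training iteration. If you train for longer, the newest checkpoint directory could be located at runtime instead. The following is a minimal sketch that assumes only the checkpoint_<number> directory naming shown above; latest_checkpoint_dir is a hypothetical helper, not part of RLlib, and whether restore() expects the directory or a file inside it may depend on your Ray version.

```python
from pathlib import Path


def latest_checkpoint_dir(root: str) -> Path:
    """Return the newest ``checkpoint_<number>`` directory under ``root``.

    Hypothetical helper: it only assumes the directory naming used above.
    """
    candidates = [
        p
        for p in Path(root).iterdir()
        if p.is_dir()
        and p.name.startswith("checkpoint_")
        and p.name.split("_", 1)[1].isdigit()
    ]
    if not candidates:
        raise FileNotFoundError(f"no checkpoint_* directory found in {root}")
    # Compare numerically so checkpoint_000010 ranks above checkpoint_000009.
    return max(candidates, key=lambda p: int(p.name.split("_", 1)[1]))
```

For example, latest_checkpoint_dir("/tmp/rllib_checkpoint") would return the directory of the most recent save under that root.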

You create deployments with Ray Serve by applying the @serve.deployment decorator to a class that implements two methods:

• The __init__ method creates the deployment instance and loads your data once. In the example below, we restore our PPO Algorithm from the checkpoint we just created.

• The __call__ method is invoked on every request. For each incoming request, this method has access to a request object, which is a Starlette Request.

The request body can be parsed as JSON with await request.json(). Assuming the payload has a key called observation, the deployment reads the observation (obs) from json_input["observation"] and passes it to the restored Algorithm through the compute_single_action method.

from starlette.requests import Request


@serve.deployment(route_prefix="/cartpole-ppo")
class ServePPOModel:
    def __init__(self, checkpoint_path) -> None:
        # Re-create the originally used config.
        config = (
            ppo.PPOConfig()
            .framework("torch")
            .rollouts(num_rollout_workers=0)
        )
        # Build the Algorithm instance using the config.
        self.algorithm = config.build(env="CartPole-v0")
        # Restore the algo's state from the checkpoint.
        self.algorithm.restore(checkpoint_path)

    async def __call__(self, request: Request):
        # Parse the JSON body and pull out the observation.
        json_input = await request.json()
        obs = json_input["observation"]

        # Compute a single action and return it as a JSON-serializable dict.
        action = self.algorithm.compute_single_action(obs)
        return {"action": int(action)}
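
The __call__ method above assumes a well-formed payload, so a request without an observation key fails with a KeyError. As a rough sketch, the payload could be checked before it reaches the Algorithm. parse_observation and CARTPOLE_OBS_SIZE are hypothetical names introduced for illustration; the only assumption about the environment is that CartPole observations have four numeric components.

```python
from numbers import Real
from typing import Any, List

# CartPole observations have four components: cart position, cart velocity,
# pole angle, and pole angular velocity.
CARTPOLE_OBS_SIZE = 4


def parse_observation(payload: Any) -> List[float]:
    """Extract and validate the "observation" entry of a decoded JSON body.

    Hypothetical helper, not part of RLlib or Ray Serve.
    """
    if not isinstance(payload, dict) or "observation" not in payload:
        raise ValueError('request body must be a JSON object with an "observation" key')
    obs = payload["observation"]
    if not isinstance(obs, list) or len(obs) != CARTPOLE_OBS_SIZE:
        raise ValueError(f"observation must be a list of {CARTPOLE_OBS_SIZE} numbers")
    # bool is a subclass of int, so reject it explicitly.
    if any(isinstance(x, bool) or not isinstance(x, Real) for x in obs):
        raise ValueError("observation entries must be numbers")
    return [float(x) for x in obs]
```

Inside __call__, this could be used as obs = parse_observation(await request.json()), with a ValueError translated into whatever error response suits your service.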

Tip

Although we used a single input and Algorithm.compute_single_action(...) here, you can also combine Ray Serve’s batching feature with Algorithm.compute_actions(...) to process a batch of inputs in one call.
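
To illustrate the batched variant, here is a minimal sketch of the glue code between a list of decoded request payloads and one compute_actions call. compute_batch_actions is a hypothetical helper, and it assumes that compute_actions accepts a dict mapping arbitrary keys to observations and returns a dict of actions under the same keys; check this against the RLlib version you use.

```python
from typing import Any, Dict, List


def compute_batch_actions(
    algorithm: Any, payloads: List[Dict[str, Any]]
) -> List[Dict[str, int]]:
    """Compute one action per payload with a single compute_actions call.

    Hypothetical helper. Assumes algorithm.compute_actions takes a dict of
    observations and returns a dict of actions with the same keys.
    """
    # Key each observation by its position in the batch.
    observations = {i: payload["observation"] for i, payload in enumerate(payloads)}
    actions = algorithm.compute_actions(observations)
    # Return the responses in the same order as the incoming payloads.
    return [{"action": int(actions[i])} for i in range(len(payloads))]
```

In a deployment, a helper like this would be called from a method decorated with @serve.batch, which receives a list of inputs and must return a list of results of the same length.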

Now that we’ve defined our ServePPOModel service, let’s deploy it to Ray Serve. The deployment will be exposed through the /cartpole-ppo route.

serve.start()
ServePPOModel.deploy(checkpoint_path)

Note that the checkpoint_path that we passed to the deploy() method will be passed to the __init__ method of the ServePPOModel class that we defined above.

Now that the model is deployed, let’s query it!

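import gym
import requests

# Create the environment once and reuse it for every query.
env = gym.make("CartPole-v0")

for _ in range(5):
    # Sample a fresh initial observation.
    obs = env.reset()

    print(f"-> Sending observation {obs}")
    resp = requests.get(
        "http://localhost:8000/cartpole-ppo", json={"observation": obs.tolist()}
    )
    # The deployment responds with a JSON object of the form {"action": ...}.
    print(f"<- Received response {resp.json()}")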

You should see output like this (observation values will differ):