RLlib Tutorial
In this guide, we will train a simple Ray RLlib PPO model and deploy it with Ray Serve. In particular, we show:
- How to load the model from a checkpoint
- How to parse the JSON request and evaluate the payload in RLlib
See the Core APIs guide for more general information about Ray Serve.
Let’s import Ray Serve and some other helpers.
import gym
from starlette.requests import Request
import requests
import ray
import ray.rllib.agents.ppo as ppo
from ray import serve
We will train and checkpoint a simple PPO model on the CartPole-v0 environment.
For now, we just write the checkpoint to local disk. In production, you might want to load the weights from cloud storage (such as S3) or a shared file system instead.
def train_ppo_model():
    trainer = ppo.PPOTrainer(
        config={
            "framework": "torch",
            # Use only the local worker; no remote rollout workers.
            "num_workers": 0,
        },
        env="CartPole-v0",
    )
    # Train for one iteration.
    trainer.train()
    # save() returns the path of the checkpoint it just wrote, so use that
    # value instead of hardcoding the checkpoint directory layout.
    return trainer.save("/tmp/rllib_checkpoint")


checkpoint_path = train_ppo_model()
Services are defined as normal Python classes with __init__ and __call__ methods.
The __call__ method is invoked once per request. For each request, it reads request.json()["observation"] as the model input.
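The following is a minimal, framework-free sketch of the parsing step that __call__ performs. It assumes the client sends a JSON body of the form {"observation": [...]} containing the four floats of a CartPole observation. The helper name parse_observation and its length check are hypothetical additions for illustration. They are not part of Ray Serve or RLlib.

```python
import json
from typing import List

# CartPole observations have four components: cart position, cart velocity,
# pole angle, and pole angular velocity.
CARTPOLE_OBS_SIZE = 4


def parse_observation(body: bytes) -> List[float]:
    """Hypothetical helper: extract and validate the "observation" field."""
    payload = json.loads(body)
    obs = payload["observation"]  # raises KeyError if the field is missing
    if not isinstance(obs, list) or len(obs) != CARTPOLE_OBS_SIZE:
        raise ValueError(
            f"expected a list of {CARTPOLE_OBS_SIZE} numbers, got {obs!r}")
    return [float(x) for x in obs]


body = b'{"observation": [0.04, 0.02, 0.007, 0.03]}'
print(parse_observation(body))  # [0.04, 0.02, 0.007, 0.03]
```

Validating the payload up front turns a malformed request into a clear error before it ever reaches the policy.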
Tip
Although we use a single input and trainer.compute_action(...) here, you can process a batch of inputs using Ray Serve’s batching feature and call trainer.compute_actions(...) (notice the plural!) to process the whole batch at once.
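The idea behind the batching tip can be sketched without Ray:

1. Collect the observations from several pending requests.
2. Run one batched policy call over all of them.
3. Return each caller its own result.

In this sketch, batched_policy is a hypothetical stand-in for what trainer.compute_actions(...) would do. It is a simple pole-angle heuristic, not the trained PPO policy. handle_batch is likewise hypothetical; in Ray Serve, the batching feature is what groups the requests for you.

```python
import json
from typing import Dict, List


def batched_policy(observations: List[List[float]]) -> List[int]:
    """Hypothetical stand-in for a batched policy call.

    Returns action 1 (push right) when pole angle plus pole angular velocity
    is positive, otherwise action 0 (push left). NOT the trained PPO policy.
    """
    return [1 if obs[2] + obs[3] > 0 else 0 for obs in observations]


def handle_batch(bodies: List[bytes]) -> List[Dict[str, int]]:
    """Hypothetical batch handler: one policy call for many requests."""
    # Parse every request body and gather the observations in order.
    observations = [json.loads(body)["observation"] for body in bodies]
    # A single call processes the whole batch.
    actions = batched_policy(observations)
    # The i-th response corresponds to the i-th request.
    return [{"action": int(a)} for a in actions]


pending_bodies = [
    b'{"observation": [0.04, 0.02, 0.007, 0.03]}',
    b'{"observation": [0.05, -0.05, -0.005, -0.007]}',
]
print(handle_batch(pending_bodies))  # [{'action': 1}, {'action': 0}]
```

The payoff of batching is that the per-call overhead of the policy (for example, a forward pass through the network) is paid once per batch rather than once per request.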
class ServePPOModel:
    def __init__(self, checkpoint_path) -> None:
        self.trainer = ppo.PPOTrainer(
            config={
                "framework": "torch",
                # only 1 "local" worker with an env (not really used here).
                "num_workers": 0,
            },
            env="CartPole-v0")
        # Load the trained weights from the checkpoint written above.
        self.trainer.restore(checkpoint_path)

    async def __call__(self, request: Request):
        # Parse the JSON body and pull out the observation.
        json_input = await request.json()
        obs = json_input["observation"]

        # Compute a single action and return it as a JSON-serializable int.
        action = self.trainer.compute_action(obs)
        return {"action": int(action)}
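The int(action) cast in __call__ matters because the action may come back as a NumPy integer rather than a Python int. Python's json module cannot serialize NumPy scalar types. The sketch below shows the failure and the fix. It assumes NumPy is installed and uses np.int64(1) as a stand-in for an action returned by compute_action.

```python
import json

import numpy as np

# Stand-in for a discrete action as it may come back from the policy.
action = np.int64(1)

try:
    json.dumps({"action": action})
except TypeError as err:
    # NumPy integer scalars are not JSON serializable.
    print(f"Serialization failed: {err}")

# Converting to a built-in int first makes the response serializable.
print(json.dumps({"action": int(action)}))  # {"action": 1}
```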
Now that we’ve defined our services, let’s deploy the model to Ray Serve. We will define an endpoint for the route representing the PPO model and a backend corresponding to the physical implementation, and then connect the two.
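# Start Ray Serve and get a client for managing backends and endpoints.
client = serve.start()
# Register the model class as a backend; checkpoint_path is passed to __init__.
client.create_backend("ppo", ServePPOModel, checkpoint_path)
# Expose the backend over HTTP at /cartpole-ppo through an endpoint.
client.create_endpoint("ppo-endpoint", backend="ppo", route="/cartpole-ppo")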
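Conceptually, these calls wire a request path to an endpoint, and the endpoint to a backend that holds the model. The toy sketch below models only that indirection, using plain dictionaries. It is a hypothetical illustration of the idea, not how Ray Serve is implemented. The names ROUTES, ENDPOINTS, BACKENDS, create_backend, create_endpoint, and dispatch are invented for this example, and the backends are stand-ins rather than the PPO model.

```python
from typing import Callable, Dict

# Hypothetical toy registry: route -> endpoint -> backend.
ROUTES: Dict[str, str] = {}
ENDPOINTS: Dict[str, str] = {}
BACKENDS: Dict[str, Callable[[dict], dict]] = {}


def create_backend(name: str, handler: Callable[[dict], dict]) -> None:
    BACKENDS[name] = handler


def create_endpoint(name: str, backend: str, route: str) -> None:
    ENDPOINTS[name] = backend
    ROUTES[route] = name


def dispatch(route: str, payload: dict) -> dict:
    """Resolve route -> endpoint -> backend and call the backend."""
    endpoint = ROUTES[route]
    backend = ENDPOINTS[endpoint]
    return BACKENDS[backend](payload)


# A stand-in backend that always returns action 0 (not the PPO model).
create_backend("ppo", lambda payload: {"action": 0})
create_endpoint("ppo-endpoint", backend="ppo", route="/cartpole-ppo")
print(dispatch("/cartpole-ppo", {"observation": [0.0, 0.0, 0.0, 0.0]}))  # {'action': 0}

# Re-pointing the endpoint at another backend leaves the route unchanged.
create_backend("ppo-v2", lambda payload: {"action": 1})
ENDPOINTS["ppo-endpoint"] = "ppo-v2"
print(dispatch("/cartpole-ppo", {"observation": [0.0, 0.0, 0.0, 0.0]}))  # {'action': 1}
```

In the toy, the extra level of indirection means the implementation behind an endpoint can be swapped without changing the route that clients call.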
Let’s query it!
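# That's it! Let's test it
# Create the environment once and reset it before each query.
env = gym.make("CartPole-v0")
for _ in range(10):
    obs = env.reset()

    print(f"-> Sending observation {obs}")
    # NumPy arrays are not JSON serializable, so send the observation as a list.
    resp = requests.get(
        "http://localhost:8000/cartpole-ppo",
        json={"observation": obs.tolist()})
    print(f"<- Received response {resp.json()}")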
# Output:
# <- Received response {'action': 1}
# -> Sending observation [0.04228249 0.02289503 0.00690076 0.03095441]
# <- Received response {'action': 0}
# -> Sending observation [ 0.04819471 -0.04702759 -0.00477937 -0.00735569]
# <- Received response {'action': 0}
# ...
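The loop above resets the environment before every request, so each query sends only an initial observation. To see whether the served policy actually balances the pole, a client can instead feed each returned action back into env.step(...) for a whole episode.

The sketch below assumes the classic Gym API used in this tutorial: reset() returns the observation, and step() returns a 4-tuple. run_episode and get_action are hypothetical names. get_action stands for whatever sends an observation to the endpoint and returns the action, such as a requests.get call like the one above.

```python
from typing import Callable, List


def run_episode(env, get_action: Callable[[List[float]], int],
                max_steps: int = 200) -> float:
    """Play one episode, asking get_action at every step; return total reward.

    Assumes the classic Gym API: reset() returns an observation and
    step(action) returns (obs, reward, done, info).
    """
    obs = env.reset()
    total_reward = 0.0
    for _ in range(max_steps):
        # obs may be a NumPy array; convert it to a JSON-friendly list.
        action = get_action(list(map(float, obs)))
        obs, reward, done, _info = env.step(action)
        total_reward += reward
        if done:
            break
    return total_reward


# Hypothetical usage with the endpoint deployed above (requires the running
# server, gym, and requests):
# env = gym.make("CartPole-v0")
# score = run_episode(env, lambda obs: requests.get(
#     "http://localhost:8000/cartpole-ppo",
#     json={"observation": obs}).json()["action"])
```

Because CartPole gives a reward of 1 for every step the pole stays up, the returned total is the episode length. This makes it a quick check of how well the served checkpoint performs.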