ray.rllib.evaluation.rollout_worker.RolloutWorker.learn_on_batch

RolloutWorker.learn_on_batch(samples: SampleBatch | MultiAgentBatch | Dict[str, Any]) → Dict

Update policies based on the given batch.

This is equivalent to apply_gradients(compute_gradients(samples)), but it can be optimized to avoid pulling gradients into CPU memory.
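As a rough illustration of this equivalence, the sketch below uses a hypothetical `ToyWorker` class (not part of RLlib) that holds a single scalar weight and trains it with a mean-squared-error loss. Its `compute_gradients`, `apply_gradients`, and `learn_on_batch` methods only mimic the shape of the RolloutWorker methods, and the batch format (plain `"x"` and `"y"` lists) is an assumption made for the example.

```python
from typing import Dict, List, Tuple

# Hypothetical batch format for this toy: parallel lists of inputs and targets.
Batch = Dict[str, List[float]]


class ToyWorker:
    """Hypothetical stand-in for a worker with a one-parameter policy (not RLlib code)."""

    def __init__(self, weight: float = 0.0, lr: float = 0.1) -> None:
        self.weight = weight
        self.lr = lr

    def compute_gradients(self, samples: Batch) -> Tuple[float, Dict[str, float]]:
        # Mean-squared error between the prediction weight * x and the target y.
        xs, ys = samples["x"], samples["y"]
        errors = [self.weight * x - y for x, y in zip(xs, ys)]
        loss = sum(e * e for e in errors) / len(errors)
        # d(loss)/d(weight) = mean(2 * error * x)
        grad = sum(2 * e * x for e, x in zip(errors, xs)) / len(errors)
        return grad, {"loss": loss}

    def apply_gradients(self, grad: float) -> None:
        # One plain gradient-descent step.
        self.weight -= self.lr * grad

    def learn_on_batch(self, samples: Batch) -> Dict[str, float]:
        # Same update in a single call. This toy just chains the two steps; a real
        # worker may fuse them so the gradient never has to be copied to CPU memory.
        grad, info = self.compute_gradients(samples)
        self.apply_gradients(grad)
        return info


batch = {"x": [1.0, 2.0], "y": [2.0, 4.0]}

# Two-step path: compute the gradient, then apply it.
two_step = ToyWorker()
grads, two_step_info = two_step.compute_gradients(batch)
two_step.apply_gradients(grads)

# One-call path.
one_call = ToyWorker()
one_call_info = one_call.learn_on_batch(batch)

print(two_step.weight, one_call.weight)  # 1.0 1.0
print(two_step_info, one_call_info)      # {'loss': 10.0} {'loss': 10.0}
```

Both paths end with the same weight and the same metadata dictionary. Calling learn_on_batch directly leaves room for the optimization described above, because the intermediate gradients are never handed back to the caller.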

Parameters:

samples – The SampleBatch or MultiAgentBatch to learn on.

Returns:

Dictionary of extra metadata from compute_gradients().

Example:

import gymnasium as gym
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy
# Create a worker that collects CartPole experience with a PPO (TF1) policy.
worker = RolloutWorker(
  env_creator=lambda _: gym.make("CartPole-v1"),
  default_policy_class=PPOTF1Policy)
# Collect one batch of experience, then update the policy on that batch.
batch = worker.sample()
info = worker.learn_on_batch(batch)