ray.rllib.evaluation.rollout_worker.RolloutWorker.learn_on_batch
- RolloutWorker.learn_on_batch(samples: SampleBatch | MultiAgentBatch | Dict[str, Any]) → Dict
Update policies based on the given batch.
This is equivalent to apply_gradients(compute_gradients(samples)), but it can be optimized to avoid copying gradients into CPU memory.
- Parameters:
samples – The SampleBatch or MultiAgentBatch to learn on.
- Returns:
Dictionary of extra metadata from compute_gradients().
- Example:
import gymnasium as gym
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy

# Build a worker that runs a PPO (TF1) policy on CartPole.
worker = RolloutWorker(
    env_creator=lambda _: gym.make("CartPole-v1"),
    default_policy_class=PPOTF1Policy)
# Collect a batch of experience, then update the policy on that batch.
batch = worker.sample()
info = worker.learn_on_batch(batch)
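The stated equivalence between learn_on_batch(samples) and apply_gradients(compute_gradients(samples)) can be illustrated with a toy, framework-free sketch. `ToyWorker`, its scalar weight, its learning rate, and the `{"x": ..., "y": ...}` batch format are all hypothetical and are not part of RLlib. As an assumption, compute_gradients() here returns a (gradients, info) pair, and learn_on_batch() returns that info dict, matching the "extra metadata from compute_gradients()" described under Returns.

```python
from typing import Dict, List, Tuple

# Hypothetical batch format (not an RLlib SampleBatch): parallel lists of inputs and targets.
Batch = Dict[str, List[float]]


class ToyWorker:
    """Hypothetical stand-in for a worker that holds one scalar weight.

    Not part of RLlib; it only illustrates the contract
    learn_on_batch(b) == apply_gradients(compute_gradients(b)).
    """

    def __init__(self, weight: float = 0.0, lr: float = 0.1) -> None:
        self.weight = weight
        self.lr = lr

    def compute_gradients(self, samples: Batch) -> Tuple[float, Dict[str, float]]:
        # Mean-squared-error loss L(w) = mean((w*x - y)^2), so dL/dw = mean(2*(w*x - y)*x).
        xs, ys = samples["x"], samples["y"]
        n = len(xs)
        errors = [self.weight * x - y for x, y in zip(xs, ys)]
        grad = sum(2.0 * e * x for e, x in zip(errors, xs)) / n
        loss = sum(e * e for e in errors) / n
        return grad, {"loss": loss}

    def apply_gradients(self, grad: float) -> None:
        # Plain gradient-descent step on the single weight.
        self.weight -= self.lr * grad

    def learn_on_batch(self, samples: Batch) -> Dict[str, float]:
        # Fused path: a real framework could keep the gradient on-device here;
        # this toy simply performs both steps and returns the metadata.
        grad, info = self.compute_gradients(samples)
        self.apply_gradients(grad)
        return info


batch: Batch = {"x": [1.0, 2.0, 3.0], "y": [2.0, 4.0, 6.0]}

# One-call update.
fused = ToyWorker()
info = fused.learn_on_batch(batch)

# Equivalent two-step update.
two_step = ToyWorker()
grads, _ = two_step.compute_gradients(batch)
two_step.apply_gradients(grads)

# Both workers end up with the same weight after one update.
print(info)
print(fused.weight, two_step.weight)
```

The two-step form is useful when gradients must be inspected or combined before being applied; the fused call exists so that an implementation can skip that round trip when it is not needed.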