Source code for ray.rllib.execution.train_ops
import logging
import numpy as np
import math
from typing import Dict
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import OldAPIStack
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.metrics import (
    NUM_ENV_STEPS_TRAINED,
    NUM_AGENT_STEPS_TRAINED,
    LEARN_ON_BATCH_TIMER,
    LOAD_BATCH_TIMER,
)
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
from ray.rllib.utils.sgd import do_minibatch_sgd
from ray.util import log_once

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)
@OldAPIStack
def train_one_step(algorithm, train_batch, policies_to_train=None) -> Dict:
"""Function that improves the all policies in `train_batch` on the local worker.
.. testcode::
:skipif: True
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
algo = [...]
train_batch = synchronous_parallel_sample(algo.env_runner_group)
# This trains the policy on one batch.
print(train_one_step(algo, train_batch)))
.. testoutput::
{"default_policy": ...}
Updates the NUM_ENV_STEPS_TRAINED and NUM_AGENT_STEPS_TRAINED counters as well as
the LEARN_ON_BATCH_TIMER timer of the `algorithm` object.
"""
    config = algorithm.config
    workers = algorithm.env_runner_group
    local_worker = workers.local_env_runner

    num_sgd_iter = config.get("num_epochs", config.get("num_sgd_iter", 1))
    minibatch_size = config.get("minibatch_size")
    if minibatch_size is None:
        minibatch_size = config.get("sgd_minibatch_size", 0)
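    # Note: `num_epochs` and `minibatch_size` are the newer config keys; the
    # `num_sgd_iter` and `sgd_minibatch_size` lookups above are fallbacks for
    # older-style configs.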
    learn_timer = algorithm._timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        # Subsample minibatches (size=`minibatch_size`) from the
        # train batch and loop through train batch `num_sgd_iter` times.
        if num_sgd_iter > 1 or minibatch_size > 0:
            info = do_minibatch_sgd(
                train_batch,
                {
                    pid: local_worker.get_policy(pid)
                    for pid in policies_to_train
                    or local_worker.get_policies_to_train(train_batch)
                },
                local_worker,
                num_sgd_iter,
                minibatch_size,
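                # Last positional arg is `standardize_fields`: no columns are
                # standardized here.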
                [],
            )
        # Single update step using train batch.
        else:
            info = local_worker.learn_on_batch(train_batch)

    learn_timer.push_units_processed(train_batch.count)
    algorithm._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count
    algorithm._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

    if algorithm.reward_estimators:
        info[DEFAULT_POLICY_ID]["off_policy_estimation"] = {}
        for name, estimator in algorithm.reward_estimators.items():
            info[DEFAULT_POLICY_ID]["off_policy_estimation"][name] = estimator.train(
                train_batch
            )

    return info
@OldAPIStack
def multi_gpu_train_one_step(algorithm, train_batch) -> Dict:
"""Multi-GPU version of train_one_step.
Uses the policies' `load_batch_into_buffer` and `learn_on_loaded_batch` methods
to be more efficient wrt CPU/GPU data transfers. For example, when doing multiple
passes through a train batch (e.g. for PPO) using `config.num_sgd_iter`, the
actual train batch is only split once and loaded once into the GPU(s).
.. testcode::
:skipif: True
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
algo = [...]
train_batch = synchronous_parallel_sample(algo.env_runner_group)
# This trains the policy on one batch.
print(multi_gpu_train_one_step(algo, train_batch)))
.. testoutput::
{"default_policy": ...}
Updates the NUM_ENV_STEPS_TRAINED and NUM_AGENT_STEPS_TRAINED counters as well as
the LOAD_BATCH_TIMER and LEARN_ON_BATCH_TIMER timers of the Algorithm instance.
"""
if log_once("mulit_gpu_train_one_step_deprecation_warning"):
deprecation_warning(
old=("ray.rllib.execution.train_ops." "multi_gpu_train_one_step")
)
    config = algorithm.config
    workers = algorithm.env_runner_group
    local_worker = workers.local_env_runner

    num_sgd_iter = config.get("num_epochs", config.get("num_sgd_iter", 1))
    minibatch_size = config.get("minibatch_size")
    if minibatch_size is None:
        minibatch_size = config["train_batch_size"]
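    # With no explicit minibatch size configured, the entire train batch is
    # used as a single minibatch.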
    # Determine the number of devices (GPUs or 1 CPU) we use.
    num_devices = int(math.ceil(config["num_gpus"] or 1))

    # Make sure the total batch size is divisible by the number of devices.
    # Batch size per tower.
    per_device_batch_size = minibatch_size // num_devices
    # Total batch size.
    batch_size = per_device_batch_size * num_devices
    assert batch_size % num_devices == 0
    assert batch_size >= num_devices, "Batch size too small!"
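    # Example: minibatch_size=1000 on 3 GPUs -> per_device_batch_size=333 and
    # batch_size=999 (the remainder of the integer division is dropped).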
    # Handle everything as if multi-agent.
    train_batch = train_batch.as_multi_agent()

    # Load data into GPUs.
    load_timer = algorithm._timers[LOAD_BATCH_TIMER]
    with load_timer:
        num_loaded_samples = {}
        for policy_id, batch in train_batch.policy_batches.items():
            # Not a policy-to-train.
            if (
                local_worker.is_policy_to_train is not None
                and not local_worker.is_policy_to_train(policy_id, train_batch)
            ):
                continue

            # Decompress SampleBatch, in case some columns are compressed.
            batch.decompress_if_needed()

            # Load the entire train batch into the Policy's only buffer
            # (idx=0). Policies only have more than one buffer if we are
            # training asynchronously.
            num_loaded_samples[policy_id] = local_worker.policy_map[
                policy_id
            ].load_batch_into_buffer(batch, buffer_index=0)
    # Execute minibatch SGD on loaded data.
    learn_timer = algorithm._timers[LEARN_ON_BATCH_TIMER]
    with learn_timer:
        # Use LearnerInfoBuilder as a unified way to build the final
        # results dict from `learn_on_loaded_batch` call(s).
        # This makes sure results dicts always have the same structure
        # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
        # tf vs torch).
        learner_info_builder = LearnerInfoBuilder(num_devices=num_devices)

        for policy_id, samples_per_device in num_loaded_samples.items():
            policy = local_worker.policy_map[policy_id]
            num_batches = max(
                1, int(samples_per_device) // int(per_device_batch_size)
            )
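            # `num_batches`: how many per-device minibatches fit into the data
            # loaded above (at least one).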
            logger.debug("== sgd epochs for {} ==".format(policy_id))

            for _ in range(num_sgd_iter):
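                # Visit the preloaded minibatches in a new random order in
                # each epoch.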
                permutation = np.random.permutation(num_batches)
                for batch_index in range(num_batches):
                    # Learn on the pre-loaded data in the buffer.
                    # Note: For minibatch SGD, the data is an offset into
                    # the pre-loaded entire train batch.
                    results = policy.learn_on_loaded_batch(
                        permutation[batch_index] * per_device_batch_size,
                        buffer_index=0,
                    )
                    learner_info_builder.add_learn_on_batch_results(results, policy_id)
        # Tower reduce and finalize results.
        learner_info = learner_info_builder.finalize()

    load_timer.push_units_processed(train_batch.count)
    learn_timer.push_units_processed(train_batch.count)

    # TODO: Move this into Algorithm's `training_step` method for
    #  better transparency.
    algorithm._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count
    algorithm._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

    if algorithm.reward_estimators:
        learner_info[DEFAULT_POLICY_ID]["off_policy_estimation"] = {}
        for name, estimator in algorithm.reward_estimators.items():
            learner_info[DEFAULT_POLICY_ID]["off_policy_estimation"][
                name
            ] = estimator.train(train_batch)

    return learner_info