import copy
import functools
import logging
import math
import os
import threading
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, Union
import gymnasium as gym
import numpy as np
from packaging import version
import tree # pip install dm_tree
import ray
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module import RLModule
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import _directStepOptimizerSingleton
from ray.rllib.utils import NullContextManager, force_list
from ray.rllib.utils.annotations import (
OldAPIStack,
OverrideToImplementCustomLogic,
OverrideToImplementCustomLogic_CallToSuperRecommended,
is_overridden,
override,
)
from ray.rllib.utils.error import ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics import (
DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
NUM_AGENT_STEPS_TRAINED,
NUM_GRAD_UPDATES_LIFETIME,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.torch_utils import (
convert_to_torch_tensor,
TORCH_COMPILE_REQUIRED_VERSION,
)
from ray.rllib.utils.typing import (
AlgorithmConfigDict,
GradInfoDict,
ModelGradients,
ModelWeights,
PolicyState,
TensorStructType,
TensorType,
)
if TYPE_CHECKING:
from ray.rllib.evaluation import Episode # noqa
# Import torch and torch.nn through RLlib's optional-import helper. The
# `torch is not None` check in `_init_model_and_dist_class` suggests `torch`
# may be None when torch is not installed.
torch, nn = try_import_torch()
# Module-level logger for this policy implementation.
logger = logging.getLogger(__name__)
[docs]
@OldAPIStack
class TorchPolicyV2(Policy):
"""PyTorch specific Policy class to use with RLlib."""
[docs]
def __init__(
    self,
    observation_space: gym.spaces.Space,
    action_space: gym.spaces.Space,
    config: AlgorithmConfigDict,
    *,
    max_seq_len: int = 20,
):
    """Initializes a TorchPolicy instance.

    Builds the model (an RLModule on the new API stack, otherwise a ModelV2
    plus action-distribution class) and places it on one or more devices as
    per-device "tower" copies. Then sets up the view requirements and the
    exploration object. On the old API stack it also sets up the optimizers,
    the per-optimizer parameter groups and the multi-GPU batch buffers.

    Args:
        observation_space: Observation space of the policy.
        action_space: Action space of the policy.
        config: The Policy's config dict. Note that `config["framework"]` is
            overwritten with "torch" in place.
        max_seq_len: Max sequence length for LSTM training.

    Raises:
        ValueError: If fewer GPU IDs than `num_gpus` are available in GPU
            mode.
    """
    # Force the framework; this also mutates the caller's config dict.
    self.framework = config["framework"] = "torch"

    self._loss_initialized = False
    super().__init__(observation_space, action_space, config)

    # Create model.
    if self.config.get("enable_rl_module_and_learner", False):
        model = self.make_rl_module()
        dist_class = None
    else:
        model, dist_class = self._init_model_and_dist_class()

    # Create multi-GPU model towers, if necessary.
    # - The central main model will be stored under self.model, residing
    #   on self.device (normally, a CPU).
    # - Each GPU will have a copy of that model under
    #   self.model_gpu_towers, matching the devices in self.devices.
    # - Parallelization is done by splitting the train batch and passing
    #   it through the model copies in parallel, then averaging over the
    #   resulting gradients, applying these averages on the main model and
    #   updating all towers' weights from the main model.
    # - In case of just one device (1 (fake or real) GPU or 1 CPU), no
    #   parallelization will be done.

    # Get devices to build the graph on.
    num_gpus = self._get_num_gpus_for_policy()
    gpu_ids = list(range(torch.cuda.device_count()))
    logger.info(f"Found {len(gpu_ids)} visible cuda devices.")

    # Place on one or more CPU(s) when either:
    # - Fake GPU mode.
    # - num_gpus=0 (either set by user or we are in local_mode=True).
    # - No GPUs available.
    if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
        self.device = torch.device("cpu")
        # One CPU "device" per requested (possibly fractional, rounded up)
        # GPU; at least one.
        self.devices = [self.device for _ in range(int(math.ceil(num_gpus)) or 1)]
        # Tower 0 is the original model object; further towers are deep copies.
        self.model_gpu_towers = [
            model if i == 0 else copy.deepcopy(model)
            for i in range(int(math.ceil(num_gpus)) or 1)
        ]
        # On CPU, all towers share the very same target model object.
        if hasattr(self, "target_model"):
            self.target_models = {
                m: self.target_model for m in self.model_gpu_towers
            }
        self.model = model
    # Place on one or more actual GPU(s), when:
    # - num_gpus > 0 (set by user) AND
    # - local_mode=False AND
    # - actual GPUs available AND
    # - non-fake GPU mode.
    else:
        # We are a remote worker (WORKER_MODE=1):
        # GPUs should be assigned to us by ray.
        if ray._private.worker._mode() == ray._private.worker.WORKER_MODE:
            gpu_ids = ray.get_gpu_ids()

        if len(gpu_ids) < num_gpus:
            raise ValueError(
                "TorchPolicy was not able to find enough GPU IDs! Found "
                f"{gpu_ids}, but num_gpus={num_gpus}."
            )

        # Use the first `num_gpus` entries, addressed as cuda:0, cuda:1, ...
        self.devices = [
            torch.device("cuda:{}".format(i))
            for i, id_ in enumerate(gpu_ids)
            if i < num_gpus
        ]
        self.device = self.devices[0]
        ids = [id_ for i, id_ in enumerate(gpu_ids) if i < num_gpus]
        self.model_gpu_towers = []
        # Unlike the CPU case, every GPU tower (including tower 0) is a deep
        # copy of `model`, moved to its own device.
        for i, _ in enumerate(ids):
            model_copy = copy.deepcopy(model)
            self.model_gpu_towers.append(model_copy.to(self.devices[i]))
        # Each GPU tower gets its own target-model copy on its device.
        if hasattr(self, "target_model"):
            self.target_models = {
                m: copy.deepcopy(self.target_model).to(self.devices[i])
                for i, m in enumerate(self.model_gpu_towers)
            }
        self.model = self.model_gpu_towers[0]

    self.dist_class = dist_class
    self.unwrapped_model = model  # used to support DistributedDataParallel

    # Lock used for locking some methods on the object-level.
    # This prevents possible race conditions when calling the model
    # first, then its value function (e.g. in a loss function), in
    # between of which another model call is made (e.g. to compute an
    # action).
    self._lock = threading.RLock()

    self._state_inputs = self.model.get_initial_state()
    # Recurrent iff the (possibly nested) initial state has any leaves.
    self._is_recurrent = len(tree.flatten(self._state_inputs)) > 0
    if self.config.get("enable_rl_module_and_learner", False):
        # Maybe update view_requirements, e.g. for recurrent case.
        self.view_requirements = self.model.update_default_view_requirements(
            self.view_requirements
        )
    else:
        # Auto-update model's inference view requirements, if recurrent.
        self._update_model_view_requirements_from_init_state()
        # Combine view_requirements for Model and Policy.
        self.view_requirements.update(self.model.view_requirements)

    if self.config.get("enable_rl_module_and_learner", False):
        # We don't need an exploration object with RLModules
        self.exploration = None
    else:
        self.exploration = self._create_exploration()

    # Optimizers (and the multi-GPU buffers) only exist on the old API stack.
    if not self.config.get("enable_rl_module_and_learner", False):
        self._optimizers = force_list(self.optimizer())

        # Backward compatibility workaround so Policy will call self.loss()
        # directly.
        # TODO (jungong): clean up after all policies are migrated to new sub-class
        #  implementation.
        self._loss = None

        # Store, which params (by index within the model's list of
        # parameters) should be updated per optimizer.
        # Maps optimizer idx to set or param indices.
        self.multi_gpu_param_groups: List[Set[int]] = []
        main_params = {p: i for i, p in enumerate(self.model.parameters())}
        for o in self._optimizers:
            param_indices = []
            for pg_idx, pg in enumerate(o.param_groups):
                for p in pg["params"]:
                    param_indices.append(main_params[p])
            self.multi_gpu_param_groups.append(set(param_indices))

        # Create n sample-batch buffers (num_multi_gpu_tower_stacks), each
        # one with m towers (num_gpus).
        num_buffers = self.config.get("num_multi_gpu_tower_stacks", 1)
        self._loaded_batches = [[] for _ in range(num_buffers)]

    # If set, means we are using distributed allreduce during learning.
    self.distributed_world_size = None

    self.batch_divisibility_req = self.get_batch_divisibility_req()
    self.max_seq_len = max_seq_len

    # If model is an RLModule it won't have tower_stats instead there will be a
    # self.tower_state[model] -> dict for each tower.
    self.tower_stats = {}
    if not hasattr(self.model, "tower_stats"):
        # NOTE: the loop variable rebinds the local `model` (unused afterwards).
        for model in self.model_gpu_towers:
            self.tower_stats[model] = {}
def loss_initialized(self):
    """Reports whether the loss of this policy has been initialized.

    Returns:
        The value of the `_loss_initialized` flag, which `__init__` sets to
        False.
    """
    initialized_flag = self._loss_initialized
    return initialized_flag
[docs]
@OverrideToImplementCustomLogic
@override(Policy)
def loss(
    self,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss function.

    On the old API stack, subclasses must override this method. On the new
    API stack (`enable_rl_module_and_learner=True`), this base implementation
    computes no loss. It only accesses the keys the model needs for
    `forward_train` in `train_batch` (see the comment below for why).

    Args:
        model: The Model to calculate the loss for.
        dist_class: The action distr. class.
        train_batch: The training data.

    Returns:
        Loss tensor given the input batch (None on the new API stack).

    Raises:
        NotImplementedError: On the old API stack, if not overridden.
    """
    # Use `.get()` like every other config lookup in this class. Attribute
    # access fails when `self.config` is a plain dict.
    if not self.config.get("enable_rl_module_and_learner", False):
        raise NotImplementedError

    # Under the new API stack the loss function still gets called in order to
    # initialize the view requirements of the sample batches returned by the
    # sampler. We don't compute any loss here. However, accessing the keys
    # needed for a `forward_train` pass makes the sampler include those keys
    # in the sample batches it returns, so they are available when using the
    # learner group API.
    for k in model.input_specs_train():
        train_batch[k]  # Accessed only for its side effect; value unused.
    return None
[docs]
@OverrideToImplementCustomLogic
def action_sampler_fn(
    self,
    model: ModelV2,
    *,
    obs_batch: TensorType,
    state_batches: TensorType,
    **kwargs,
) -> Tuple[TensorType, TensorType, TensorType, List[TensorType]]:
    """Custom hook for sampling new actions from the policy.

    Callers check `is_overridden` on this method. The base version is only a
    placeholder that produces no outputs.

    Args:
        model: Underlying model.
        obs_batch: Observation tensor batch.
        state_batches: Action sampling state batch.
        **kwargs: Additional keyword arguments (unused here).

    Returns:
        A 4-tuple of (sampled actions, log-likelihoods, action distribution
        inputs, updated state). Every entry is None in this base version.
    """
    return (None,) * 4
[docs]
@OverrideToImplementCustomLogic
def action_distribution_fn(
    self,
    model: ModelV2,
    *,
    obs_batch: TensorType,
    state_batches: TensorType,
    **kwargs,
) -> Tuple[TensorType, type, List[TensorType]]:
    """Custom hook that computes the action distribution for this Policy.

    Callers check `is_overridden` on this method. The base version is only a
    placeholder that produces no outputs.

    Args:
        model: Underlying model.
        obs_batch: Observation tensor batch.
        state_batches: Action sampling state batch.
        **kwargs: Additional keyword arguments (unused here).

    Returns:
        A 3-tuple of (distribution inputs, ActionDistribution class, state
        outs). Every entry is None in this base version.
    """
    return (None,) * 3
[docs]
@OverrideToImplementCustomLogic
def make_model(self) -> ModelV2:
    """Custom hook that builds the policy's model.

    Override at most one of `make_model` and `make_model_and_action_dist`.
    `_init_model_and_dist_class` raises a ValueError if both are overridden.

    Returns:
        The ModelV2 to use. This base version returns None.
    """
    no_model = None
    return no_model
@override(Policy)
def maybe_remove_time_dimension(self, input_dict: Dict[str, TensorType]):
    """Folds the time axis into the batch axis for stateful RLModules.

    Only valid on the new API stack. If the model is stateful, every
    non-state entry of `input_dict` (leaf-wise, via `tree.map_structure`) is
    reshaped so that its first two dimensions are merged into one. State-in
    and state-out entries are passed through unchanged. Otherwise the input
    dict is returned as-is.

    Note that this is a temporary workaround to fit the old sampling stack
    to RL Modules.
    """
    new_api_stack = self.config.get("enable_rl_module_and_learner", False)
    assert new_api_stack, "This is a helper method for the new learner API."

    if not (new_api_stack and self.model.is_stateful()):
        return input_dict

    def _merge_batch_and_time(leaf):
        leaf = torch.as_tensor(leaf)
        batch_size, seq_len, *trailing = leaf.size()
        return leaf.reshape([batch_size * seq_len] + list(trailing))

    # State in/out already carry their own layout; leave them alone.
    untouched_keys = (Columns.STATE_IN, Columns.STATE_OUT)
    return {
        key: (
            value
            if key in untouched_keys
            else tree.map_structure(_merge_batch_and_time, value)
        )
        for key, value in input_dict.items()
    }
[docs]
@OverrideToImplementCustomLogic
def make_model_and_action_dist(
    self,
) -> Tuple[ModelV2, Type[TorchDistributionWrapper]]:
    """Custom hook that builds both the model and the action distribution class.

    Returns:
        A (ModelV2, ActionDistribution class) pair. Both are None in this
        base version.
    """
    model, action_dist_class = None, None
    return model, action_dist_class
[docs]
@OverrideToImplementCustomLogic
def get_batch_divisibility_req(self) -> int:
    """Returns the required batch-size divisor N (batches must be K*N long).

    Returns:
        1, meaning any batch size is acceptable by default.
    """
    any_size_ok = 1
    return any_size_ok
[docs]
@OverrideToImplementCustomLogic
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Returns a dict of statistics computed from the training batch.

    Args:
        train_batch: The SampleBatch (already) used for training.

    Returns:
        A fresh, empty stats dict in this base version.
    """
    return dict()
[docs]
@override(Policy)
@OverrideToImplementCustomLogic_CallToSuperRecommended
def postprocess_trajectory(
    self,
    sample_batch: SampleBatch,
    other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
    episode: Optional["Episode"] = None,
) -> SampleBatch:
    """Postprocesses a single-agent, single-episode trajectory.

    - With `config.batch_mode=truncate_episodes` (default), `sample_batch`
      may hold an episode truncated at its end, because
      `config.rollout_fragment_length` was reached by the sampler.
    - With `config.batch_mode=complete_episodes`, `sample_batch` holds
      exactly one full episode (regardless of length).

    Subclasses may add new columns to `sample_batch` or alter existing ones.

    Args:
        sample_batch: The SampleBatch to postprocess.
        other_agent_batches: Optional dict mapping the IDs of other agents to
            their trajectory data from the same episode. NOTE: The other
            agents use the same policy.
        episode: Optional multi-agent episode object in which the agents
            operated.

    Returns:
        The postprocessed SampleBatch. This base version returns
        `sample_batch` itself, unmodified.
    """
    return sample_batch
[docs]
@OverrideToImplementCustomLogic
def optimizer(
    self,
) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]:
    """Creates the local PyTorch optimizer(s) for this Policy.

    By default, this is a single Adam optimizer over all model parameters.
    If the policy has a `config` attribute, `config["lr"]` is used as the
    learning rate. If an exploration object is set, the optimizer list is
    passed through its `get_exploration_optimizer`.

    Returns:
        The local PyTorch optimizer(s) to use for this Policy.
    """
    params = self.model.parameters()
    adam_kwargs = {"lr": self.config["lr"]} if hasattr(self, "config") else {}
    optimizers = [torch.optim.Adam(params, **adam_kwargs)]
    if self.exploration:
        optimizers = self.exploration.get_exploration_optimizer(optimizers)
    return optimizers
def _init_model_and_dist_class(self):
    """Builds the (old API stack) model and its action distribution class.

    Exactly one of three paths is taken:
    - `make_model` overridden: the model comes from it. The action
      distribution class comes from `ModelCatalog.get_action_dist`.
    - `make_model_and_action_dist` overridden: both come from it.
    - Neither overridden: both come from `ModelCatalog`, using
      `config["model"]`.

    If `config["torch_compile_learner"]` is set, the model is additionally
    wrapped with `torch.compile`.

    Returns:
        Tuple of (model, action distribution class).

    Raises:
        ValueError: If both `make_model` and `make_model_and_action_dist` are
            overridden, or if torch.compile is requested with a torch version
            older than `TORCH_COMPILE_REQUIRED_VERSION`.
    """
    if is_overridden(self.make_model) and is_overridden(
        self.make_model_and_action_dist
    ):
        raise ValueError(
            "Only one of make_model or make_model_and_action_dist "
            "can be overridden."
        )

    if is_overridden(self.make_model):
        model = self.make_model()
        dist_class, _ = ModelCatalog.get_action_dist(
            self.action_space, self.config["model"], framework=self.framework
        )
    elif is_overridden(self.make_model_and_action_dist):
        model, dist_class = self.make_model_and_action_dist()
    else:
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            self.action_space, self.config["model"], framework=self.framework
        )
        model = ModelCatalog.get_model_v2(
            obs_space=self.observation_space,
            action_space=self.action_space,
            num_outputs=logit_dim,
            model_config=self.config["model"],
            framework=self.framework,
        )

    # Compile the model, if requested by the user.
    if self.config.get("torch_compile_learner"):
        if (
            torch is not None
            and version.parse(torch.__version__) < TORCH_COMPILE_REQUIRED_VERSION
        ):
            # Report the version actually checked against instead of a
            # hard-coded one that can drift from the constant.
            raise ValueError(
                "`torch.compile` is not supported for torch < "
                f"{TORCH_COMPILE_REQUIRED_VERSION}!"
            )

        # Selects which `torch_compile_{learner|worker}_dynamo_*` keys to read.
        # NOTE(review): a truthy (non-zero) `worker_index` selects the
        # "learner" keys. If index 0 is the local (learning) worker and
        # remote sampling workers have index >= 1, this looks inverted.
        # Confirm before relying on it.
        lw = "learner" if self.config.get("worker_index") else "worker"
        model = torch.compile(
            model,
            backend=self.config.get(
                f"torch_compile_{lw}_dynamo_backend", "inductor"
            ),
            dynamic=False,
            mode=self.config.get(f"torch_compile_{lw}_dynamo_mode"),
        )
    return model, dist_class
@override(Policy)
def compute_actions_from_input_dict(
    self,
    input_dict: Dict[str, TensorType],
    explore: bool = None,
    timestep: Optional[int] = None,
    **kwargs,
) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
    """Computes actions for a batched input dict, without gradient tracking.

    On the new API stack, the lazy tensor dict is handed to
    `_compute_action_helper` without explicit state or sequence lengths.
    On the old stack, the "state_in*" entries are collected as state
    batches, and (if any exist) a sequence-length tensor of all ones is
    built on the same device as the first state batch.
    """
    with torch.no_grad():
        # Hand the model a lazy (torch) tensor dict.
        input_dict = self._lazy_tensor_dict(input_dict)
        # NOTE(review): the training flag is set to True here, while
        # `compute_actions` builds its input dict with "is_training": False.
        # Verify this is intended.
        input_dict.set_training(True)

        if self.config.get("enable_rl_module_and_learner", False):
            return self._compute_action_helper(
                input_dict,
                state_batches=None,
                seq_lens=None,
                explore=explore,
                timestep=timestep,
            )

        # Gather internal RNN state inputs (keys beginning with "state_in"),
        # accessed via item lookup on the lazy dict.
        state_batches = [
            input_dict[key] for key in input_dict.keys() if "state_in" in key[:8]
        ]
        # Each sequence is a single timestep when computing actions.
        seq_lens = (
            torch.tensor(
                [1] * len(state_batches[0]),
                dtype=torch.long,
                device=state_batches[0].device,
            )
            if state_batches
            else None
        )
        return self._compute_action_helper(
            input_dict, state_batches, seq_lens, explore, timestep
        )
@override(Policy)
def compute_actions(
    self,
    obs_batch: Union[List[TensorStructType], TensorStructType],
    state_batches: Optional[List[TensorType]] = None,
    prev_action_batch: Union[List[TensorStructType], TensorStructType] = None,
    prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None,
    info_batch: Optional[Dict[str, list]] = None,
    episodes: Optional[List["Episode"]] = None,
    explore: Optional[bool] = None,
    timestep: Optional[int] = None,
    **kwargs,
) -> Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]]:
    """Computes actions for a batch of observations, without gradient tracking.

    Builds a lazy tensor input dict ("is_training" set to False) from the
    observations and the optional previous actions/rewards. Converts any
    RNN state inputs to torch tensors on `self.device`, then delegates to
    `_compute_action_helper`. `info_batch`, `episodes` and extra kwargs are
    accepted for API compatibility but are not used here.
    """
    with torch.no_grad():
        # A sequence length of one per observation.
        seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
        input_dict = self._lazy_tensor_dict(
            {SampleBatch.CUR_OBS: obs_batch, "is_training": False}
        )
        optional_columns = (
            (SampleBatch.PREV_ACTIONS, prev_action_batch),
            (SampleBatch.PREV_REWARDS, prev_reward_batch),
        )
        for column, values in optional_columns:
            if values is not None:
                input_dict[column] = np.asarray(values)

        torch_states = [
            convert_to_torch_tensor(state, self.device)
            for state in state_batches or []
        ]
        return self._compute_action_helper(
            input_dict, torch_states, seq_lens, explore, timestep
        )
@with_lock
@override(Policy)
def compute_log_likelihoods(
    self,
    actions: Union[List[TensorStructType], TensorStructType],
    obs_batch: Union[List[TensorStructType], TensorStructType],
    state_batches: Optional[List[TensorType]] = None,
    prev_action_batch: Optional[
        Union[List[TensorStructType], TensorStructType]
    ] = None,
    prev_reward_batch: Optional[
        Union[List[TensorStructType], TensorStructType]
    ] = None,
    actions_normalized: bool = True,
    in_training: bool = True,
) -> TensorType:
    """Computes the log-likelihoods of `actions` under the current policy.

    The action distribution comes from one of three sources:
    - an overridden `action_distribution_fn`;
    - on the new API stack, the RLModule's `forward_train` (if
      `in_training`) or `forward_exploration` output, together with the
      matching action-distribution class;
    - otherwise, `self.model` plus `self.dist_class`.

    If `actions_normalized` is False and `config["normalize_actions"]` is
    set, the actions are normalized before `logp` is evaluated.

    Raises:
        ValueError: If only `action_sampler_fn` (but not
            `action_distribution_fn`) is overridden, or if an RLModule does
            not provide the needed action-distribution class or inputs.
    """
    if is_overridden(self.action_sampler_fn) and not is_overridden(
        self.action_distribution_fn
    ):
        raise ValueError(
            "Cannot compute log-prob/likelihood w/o an "
            "`action_distribution_fn` and a provided "
            "`action_sampler_fn`!"
        )

    with torch.no_grad():
        input_dict = self._lazy_tensor_dict(
            {SampleBatch.CUR_OBS: obs_batch, SampleBatch.ACTIONS: actions}
        )
        if prev_action_batch is not None:
            input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
        if prev_reward_batch is not None:
            input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
        # A sequence length of one per observation.
        seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
        state_batches = [
            convert_to_torch_tensor(s, self.device) for s in (state_batches or [])
        ]

        if self.exploration:
            # Exploration hook before each forward pass.
            self.exploration.before_compute_actions(explore=False)

        # Action dist class and inputs are generated via custom function.
        if is_overridden(self.action_distribution_fn):
            dist_inputs, dist_class, state_out = self.action_distribution_fn(
                self.model,
                obs_batch=input_dict,
                state_batches=state_batches,
                seq_lens=seq_lens,
                explore=False,
                is_training=False,
            )
            action_dist = dist_class(dist_inputs, self.model)
        # Default action-dist inputs calculation.
        else:
            if self.config.get("enable_rl_module_and_learner", False):
                # `in_training` picks the train vs. exploration forward pass
                # (the error messages below call the opposite flag
                # "is_eval_mode").
                if in_training:
                    output = self.model.forward_train(input_dict)
                    action_dist_cls = self.model.get_train_action_dist_cls()
                    if action_dist_cls is None:
                        raise ValueError(
                            "The RLModules must provide an appropriate action "
                            "distribution class for training if is_eval_mode is "
                            "False."
                        )
                else:
                    output = self.model.forward_exploration(input_dict)
                    action_dist_cls = self.model.get_exploration_action_dist_cls()
                    if action_dist_cls is None:
                        raise ValueError(
                            "The RLModules must provide an appropriate action "
                            "distribution class for exploration if is_eval_mode is "
                            "True."
                        )

                action_dist_inputs = output.get(
                    SampleBatch.ACTION_DIST_INPUTS, None
                )
                if action_dist_inputs is None:
                    raise ValueError(
                        "The RLModules must provide inputs to create the action "
                        "distribution. These should be part of the output of the "
                        "appropriate forward method under the key "
                        "SampleBatch.ACTION_DIST_INPUTS."
                    )
                action_dist = action_dist_cls.from_logits(action_dist_inputs)
            else:
                dist_class = self.dist_class
                dist_inputs, _ = self.model(input_dict, state_batches, seq_lens)
                action_dist = dist_class(dist_inputs, self.model)

        # Normalize actions if necessary.
        actions = input_dict[SampleBatch.ACTIONS]
        if not actions_normalized and self.config["normalize_actions"]:
            actions = normalize_action(actions, self.action_space_struct)

        log_likelihoods = action_dist.logp(actions)

        return log_likelihoods
@with_lock
@override(Policy)
def learn_on_batch(self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
    """Performs a single gradient update on `postprocessed_batch`.

    Puts the model into train mode and fires the `on_learn_on_batch`
    callback. Then computes gradients, steps all optimizers directly and
    increments `num_grad_updates`.

    Returns:
        The fetches from `compute_gradients`, extended with model metrics,
        the callback's custom metrics, the trained agent-step count and the
        grad-update counters.
    """
    if self.model:
        self.model.train()

    # Callbacks may record custom metrics into this dict.
    custom_metrics = {}
    self.callbacks.on_learn_on_batch(
        policy=self, train_batch=postprocessed_batch, result=custom_metrics
    )

    # Computes all losses and `backward()`s them. The optimizers are then
    # stepped directly, so the returned gradients are not needed here.
    _, fetches = self.compute_gradients(postprocessed_batch)
    self.apply_gradients(_directStepOptimizerSingleton)
    self.num_grad_updates += 1

    has_metrics = self.model and hasattr(self.model, "metrics")
    fetches["model"] = self.model.metrics() if has_metrics else {}

    # -1, b/c we have to measure this diff before we do the update above.
    updates_before_this_one = self.num_grad_updates - 1
    sampler_updates = postprocessed_batch.num_grad_updates or 0
    fetches.update(
        {
            "custom_metrics": custom_metrics,
            NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
            NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
            DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                updates_before_this_one - sampler_updates
            ),
        }
    )
    return fetches
@override(Policy)
def load_batch_into_buffer(
    self,
    batch: SampleBatch,
    buffer_index: int = 0,
) -> int:
    """Pads `batch` and stores it (split per device) in a loaded-batch buffer.

    With a single CPU device, the whole padded batch is stored in buffer 0.
    Otherwise the batch is split into one time-slice per device. Each slice
    is padded, moved to its device and stored in buffer `buffer_index`.

    Args:
        batch: The SampleBatch to load. Its training flag is set to True.
        buffer_index: Index of the buffer to load into. Must be 0 in the
            single-CPU case.

    Returns:
        Number of samples loaded per device. In the single-CPU case this is
        the full batch length.
    """
    # Set the is_training flag of the batch.
    batch.set_training(True)

    new_api_stack = self.config.get("enable_rl_module_and_learner", False)

    def _pad(sub_batch: SampleBatch) -> None:
        # Pads sequences of `sub_batch` to `self.max_seq_len`, using "last"
        # padding on the new API stack and zero padding otherwise.
        pad_batch_to_sequences_of_same_size(
            batch=sub_batch,
            max_seq_len=self.max_seq_len,
            shuffle=False,
            batch_divisibility_req=self.batch_divisibility_req,
            view_requirements=self.view_requirements,
            _enable_new_api_stack=new_api_stack,
            padding="last" if new_api_stack else "zero",
        )

    # Shortcut for 1 CPU only: Store batch in `self._loaded_batches`.
    if len(self.devices) == 1 and self.devices[0].type == "cpu":
        assert buffer_index == 0
        _pad(batch)
        self._lazy_tensor_dict(batch)
        self._loaded_batches[0] = [batch]
        return len(batch)

    # Batch (len=28, seq-lens=[4, 7, 4, 10, 3]):
    # 0123 0123456 0123 0123456789ABC

    # 1) split into n per-GPU sub batches (n=2).
    # [0123 0123456] [012] [3 0123456789 ABC]
    # (len=14, 14 seq-lens=[4, 7, 3] [1, 10, 3])
    slices = batch.timeslices(num_slices=len(self.devices))

    # 2) zero-padding (max-seq-len=10).
    # - [0123000000 0123456000 0120000000]
    # - [3000000000 0123456789 ABC0000000]
    # (Named `batch_slice` to avoid shadowing the builtin `slice`.)
    for batch_slice in slices:
        _pad(batch_slice)

    # 3) Load splits into the given buffer (consisting of n GPUs).
    slices = [
        batch_slice.to_device(self.devices[i])
        for i, batch_slice in enumerate(slices)
    ]
    self._loaded_batches[buffer_index] = slices

    # Return loaded samples per-device.
    return len(slices[0])
@override(Policy)
def get_num_samples_loaded_into_buffer(self, buffer_index: int = 0) -> int:
    """Returns the number of samples currently held in a loaded-batch buffer.

    Args:
        buffer_index: Index of the buffer to inspect. Must be 0 in the
            single-CPU case, matching `load_batch_into_buffer`.

    Returns:
        The sum of the lengths of all per-device batches in that buffer.
    """
    # `self.devices` holds `torch.device` objects, so check the device type,
    # as `load_batch_into_buffer` and `learn_on_loaded_batch` do. The former
    # comparison against the TF-style string "/cpu:0" never matched, which
    # left this sanity check dead.
    if len(self.devices) == 1 and self.devices[0].type == "cpu":
        assert buffer_index == 0
    return sum(len(b) for b in self._loaded_batches[buffer_index])
@override(Policy)
def learn_on_loaded_batch(self, offset: int = 0, buffer_index: int = 0):
    """Runs one (possibly multi-device) update on a previously loaded batch.

    Uses the `[offset, offset + device_batch_size)` slice of each
    per-device batch in buffer `buffer_index`. `device_batch_size` is
    `sgd_minibatch_size` (falling back to `train_batch_size`) divided by the
    number of devices.
    - Single CPU: delegates to `learn_on_batch`.
    - Multiple devices: syncs tower weights from the main model, computes
      per-tower gradients, averages them onto the main model's parameters
      and steps the optimizers.

    Returns:
        The single-CPU path returns `learn_on_batch`'s fetches. Otherwise a
        dict with one "tower_{i}" entry per tower (custom metrics, learner
        stats, model metrics, grad-update counters), plus
        `extra_compute_grad_fetches()`.

    Raises:
        ValueError: If nothing has been loaded into `buffer_index`.
    """
    if not self._loaded_batches[buffer_index]:
        raise ValueError(
            "Must call Policy.load_batch_into_buffer() before "
            "Policy.learn_on_loaded_batch()!"
        )

    # Get the correct slice of the already loaded batch to use,
    # based on offset and batch size.
    device_batch_size = self.config.get(
        "sgd_minibatch_size", self.config["train_batch_size"]
    ) // len(self.devices)

    # Set Model to train mode.
    if self.model_gpu_towers:
        for t in self.model_gpu_towers:
            t.train()

    # Shortcut for 1 CPU only: Batch should already be stored in
    # `self._loaded_batches`.
    if len(self.devices) == 1 and self.devices[0].type == "cpu":
        assert buffer_index == 0
        if device_batch_size >= len(self._loaded_batches[0][0]):
            batch = self._loaded_batches[0][0]
        else:
            batch = self._loaded_batches[0][0][offset : offset + device_batch_size]

        return self.learn_on_batch(batch)

    if len(self.devices) > 1:
        # Copy weights of main model (tower-0) to all other towers.
        state_dict = self.model.state_dict()
        # Just making sure tower-0 is really the same as self.model.
        assert self.model_gpu_towers[0] is self.model
        for tower in self.model_gpu_towers[1:]:
            tower.load_state_dict(state_dict)

    # Use the full per-device batches if the requested size covers everything
    # loaded (summed over devices); otherwise slice each one at `offset`.
    if device_batch_size >= sum(len(s) for s in self._loaded_batches[buffer_index]):
        device_batches = self._loaded_batches[buffer_index]
    else:
        device_batches = [
            b[offset : offset + device_batch_size]
            for b in self._loaded_batches[buffer_index]
        ]

    # Callback handling.
    batch_fetches = {}
    for i, batch in enumerate(device_batches):
        custom_metrics = {}
        self.callbacks.on_learn_on_batch(
            policy=self, train_batch=batch, result=custom_metrics
        )
        batch_fetches[f"tower_{i}"] = {"custom_metrics": custom_metrics}

    # Do the (maybe parallelized) gradient calculation step.
    tower_outputs = self._multi_gpu_parallel_grad_calc(device_batches)

    # Mean-reduce gradients over towers onto `self.device`. That is the CPU
    # in CPU mode, but the first GPU (`self.devices[0]`) in GPU mode. A
    # gradient slot that is None on tower 0 stays None.
    all_grads = []
    for i in range(len(tower_outputs[0][0])):
        if tower_outputs[0][0][i] is not None:
            all_grads.append(
                torch.mean(
                    torch.stack([t[0][i].to(self.device) for t in tower_outputs]),
                    dim=0,
                )
            )
        else:
            all_grads.append(None)
    # Set main model's grads to mean-reduced values.
    for i, p in enumerate(self.model.parameters()):
        p.grad = all_grads[i]

    self.apply_gradients(_directStepOptimizerSingleton)
    self.num_grad_updates += 1

    for i, (model, batch) in enumerate(zip(self.model_gpu_towers, device_batches)):
        batch_fetches[f"tower_{i}"].update(
            {
                LEARNER_STATS_KEY: self.stats_fn(batch),
                "model": {}
                if self.config.get("enable_rl_module_and_learner", False)
                else model.metrics(),
                NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                # -1, b/c we have to measure this diff before we do the update
                # above.
                DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                    self.num_grad_updates - 1 - (batch.num_grad_updates or 0)
                ),
            }
        )
    batch_fetches.update(self.extra_compute_grad_fetches())

    return batch_fetches
@with_lock
@override(Policy)
def compute_gradients(self, postprocessed_batch: SampleBatch) -> ModelGradients:
    """Computes gradients on a single-device batch.

    Pads the batch if it has not been zero-padded yet, marks it as training
    data and converts it to a lazy tensor dict on the (only) device. Then
    runs the gradient calculation.

    Returns:
        Tuple of (gradients, fetches). The fetches contain
        `extra_compute_grad_fetches()` plus the grad info (with
        "allreduce_latency" averaged over optimizers and `stats_fn` output
        merged in) under `LEARNER_STATS_KEY`.
    """
    assert len(self.devices) == 1

    # If not done yet, see whether we have to zero-pad this batch.
    if not postprocessed_batch.zero_padded:
        new_api_stack = self.config.get("enable_rl_module_and_learner", False)
        pad_batch_to_sequences_of_same_size(
            batch=postprocessed_batch,
            max_seq_len=self.max_seq_len,
            shuffle=False,
            batch_divisibility_req=self.batch_divisibility_req,
            view_requirements=self.view_requirements,
            _enable_new_api_stack=new_api_stack,
            padding="last" if new_api_stack else "zero",
        )

    postprocessed_batch.set_training(True)
    self._lazy_tensor_dict(postprocessed_batch, device=self.devices[0])

    # Single tower: take the first (and only) tower's output.
    grads, grad_info = self._multi_gpu_parallel_grad_calc([postprocessed_batch])[0]

    grad_info["allreduce_latency"] /= len(self._optimizers)
    grad_info.update(self.stats_fn(postprocessed_batch))

    fetches = dict(self.extra_compute_grad_fetches())
    fetches[LEARNER_STATS_KEY] = grad_info
    return grads, fetches
@override(Policy)
def apply_gradients(self, gradients: ModelGradients) -> None:
    """Applies gradients to the main model and steps the optimizer(s).

    Args:
        gradients: Either the `_directStepOptimizerSingleton` sentinel, or a
            sequence of per-parameter gradients aligned with
            `self.model.parameters()`. With the sentinel, every optimizer
            steps on the `.grad` values already stored on the parameters.
            Gradient entries may be torch tensors, numpy arrays or None;
            None entries are skipped.
    """
    # Identity check: `gradients` is compared against a sentinel object, and
    # `==` on array-like gradient containers may compare element-wise
    # instead of returning a plain bool.
    if gradients is _directStepOptimizerSingleton:
        for opt in self._optimizers:
            opt.step()
        return

    # TODO(sven): Not supported for multiple optimizers yet.
    assert len(self._optimizers) == 1
    for grad, param in zip(gradients, self.model.parameters()):
        if grad is None:
            continue
        # Non-tensor gradients are expected to be numpy arrays.
        if not torch.is_tensor(grad):
            grad = torch.from_numpy(grad)
        param.grad = grad.to(self.device)
    self._optimizers[0].step()
[docs]
def get_tower_stats(self, stats_name: str) -> List[TensorStructType]:
    """Collects one stats entry per tower, moved to this Policy's device.

    Per-tower stats come from `self.tower_stats[tower]` when that dict is
    non-empty (used when the model has no `tower_stats` attribute).
    Otherwise they come from each tower model's own `tower_stats`. Towers
    lacking `stats_name` are skipped.

    Args:
        stats_name: The key to look up in each tower's stats dict.

    Returns:
        A list with the (possibly nested) stats of every tower that has
        `stats_name`, with each leaf moved to `self.device`.

    Raises:
        AssertionError: If no tower has `stats_name` in its stats dict.
    """

    def _to_policy_device(tensor):
        return tensor.to(self.device)

    collected = []
    for tower in self.model_gpu_towers:
        stats = self.tower_stats[tower] if self.tower_stats else tower.tower_stats
        if stats_name not in stats:
            continue
        collected.append(tree.map_structure(_to_policy_device, stats[stats_name]))

    assert collected, (
        f"Stats `{stats_name}` not found in any of the towers (you have "
        f"{len(self.model_gpu_towers)} towers in total)! Make "
        "sure you call the loss function on at least one of the towers."
    )
    return collected
@override(Policy)
def get_weights(self) -> ModelWeights:
    """Returns the model's state dict with every tensor as a CPU numpy array."""
    weights = {}
    for name, tensor in self.model.state_dict().items():
        weights[name] = tensor.cpu().detach().numpy()
    return weights
@override(Policy)
def set_weights(self, weights: ModelWeights) -> None:
    """Loads `weights` into the model.

    The weights are first converted to torch tensors on `self.device`. On
    the new API stack they are applied via `set_state`, otherwise via
    `load_state_dict`.
    """
    torch_weights = convert_to_torch_tensor(weights, device=self.device)
    if not self.config.get("enable_rl_module_and_learner", False):
        self.model.load_state_dict(torch_weights)
        return
    self.model.set_state(torch_weights)
@override(Policy)
def is_recurrent(self) -> bool:
    """Returns the cached flag computed in `__init__`.

    The flag is True iff the model's initial state has at least one leaf.
    """
    recurrent = self._is_recurrent
    return recurrent
@override(Policy)
def num_state_tensors(self) -> int:
    """Returns the length of the model's initial-state container."""
    initial_state = self.model.get_initial_state()
    return len(initial_state)
@override(Policy)
def get_initial_state(self) -> List[TensorType]:
    """Returns the model's initial state as numpy data.

    On the new API stack, the (possibly nested) state structure is converted
    leaf-wise with `convert_to_numpy`. Otherwise each state tensor is
    detached, moved to the CPU and converted to a numpy array.
    """
    new_api_stack = self.config.get("enable_rl_module_and_learner", False)
    initial_state = self.model.get_initial_state()
    if new_api_stack:
        return tree.map_structure(convert_to_numpy, initial_state)
    return [tensor.detach().cpu().numpy() for tensor in initial_state]
@override(Policy)
@OverrideToImplementCustomLogic_CallToSuperRecommended
def get_state(self) -> PolicyState:
    """Returns the policy state, extended by optimizer and exploration state.

    The state always contains "_optimizer_variables". On the new API stack
    (where the optimizers live in the Learner) it is an empty list.
    Otherwise it holds a numpy copy of each optimizer's state dict. On the
    old stack, "_exploration_state" is added when an exploration object is
    set.
    """
    # Legacy Policy state (w/o torch.nn.Module and w/o PolicySpec).
    state = super().get_state()
    state["_optimizer_variables"] = []

    if self.config.get("enable_rl_module_and_learner", False):
        return state

    state["_optimizer_variables"].extend(
        convert_to_numpy(optim.state_dict()) for optim in self._optimizers
    )
    # RLModules express exploration via `forward_exploration` instead, so
    # this only applies on the old stack.
    if self.exploration:
        state["_exploration_state"] = self.exploration.get_state()
    return state
@override(Policy)
@OverrideToImplementCustomLogic_CallToSuperRecommended
def set_state(self, state: PolicyState) -> None:
    """Restores optimizer, exploration and policy state from `state`.

    Args:
        state: A policy state dict, e.g. as produced by `get_state()`.
            A non-empty "_optimizer_variables" list must have one entry per
            optimizer. "_exploration_state" is applied only if this policy
            has an exploration object. The "global_timestep" key is
            required.
    """
    # Set optimizer vars first.
    optimizer_vars = state.get("_optimizer_variables", None)
    if optimizer_vars:
        assert len(optimizer_vars) == len(self._optimizers)
        for o, s in zip(self._optimizers, optimizer_vars):
            # Torch optimizer param_groups include things like beta, etc. These
            # parameters should be left as scalar and not converted to tensors.
            # Otherwise, torch.optim.step() will start to complain.
            optim_state_dict = {"param_groups": s["param_groups"]}
            optim_state_dict["state"] = convert_to_torch_tensor(
                s["state"], device=self.device
            )
            o.load_state_dict(optim_state_dict)

    # Set exploration's state. On the new API stack `self.exploration` exists
    # but is None, so `hasattr` alone would let us call `None.set_state`.
    exploration = getattr(self, "exploration", None)
    if exploration is not None and "_exploration_state" in state:
        exploration.set_state(state=state["_exploration_state"])

    # Restore global timestep.
    self.global_timestep = state["global_timestep"]

    # Then the Policy's (NN) weights and connectors.
    super().set_state(state)
[docs]
@override(Policy)
def export_model(self, export_dir: str, onnx: Optional[int] = None) -> None:
    """Exports the Policy's Model to a local directory for serving.

    With `onnx` set, exports `self.model` to `<export_dir>/model.onnx` via
    `torch.onnx.export`, using the policy's dummy batch as example input.
    Otherwise saves the whole model object with `torch.save` to
    `<export_dir>/model.pt`. If saving fails, any partially written file is
    removed and a warning is logged instead of raising.

    Args:
        export_dir: Local writable directory (created if it does not exist).
        onnx: If given, will export model in ONNX format. The
            value of this parameter set the ONNX OpSet version to use.

    Raises:
        ValueError: If ONNX export is requested on the RLModule (new API)
            stack.
    """
    os.makedirs(export_dir, exist_ok=True)

    enable_rl_module = self.config.get("enable_rl_module_and_learner", False)
    if enable_rl_module and onnx:
        raise ValueError("ONNX export not supported for RLModule API.")

    if onnx:
        # Wrap the dummy batch as a lazy torch-tensor dict.
        self._lazy_tensor_dict(self._dummy_batch)
        # Provide dummy state inputs if not an RNN (torch cannot jit with
        # returned empty internal states list).
        # NOTE: chained assignment - "state_in_0" and SEQ_LENS get the same
        # array object, and any existing SEQ_LENS entry is overwritten.
        if "state_in_0" not in self._dummy_batch:
            self._dummy_batch["state_in_0"] = self._dummy_batch[
                SampleBatch.SEQ_LENS
            ] = np.array([1.0])
        seq_lens = self._dummy_batch[SampleBatch.SEQ_LENS]

        # Collect the contiguous "state_in_0", "state_in_1", ... entries.
        state_ins = []
        i = 0
        while "state_in_{}".format(i) in self._dummy_batch:
            state_ins.append(self._dummy_batch["state_in_{}".format(i)])
            i += 1
        # Every dummy-batch entry except the "is_training" flag.
        dummy_inputs = {
            k: self._dummy_batch[k]
            for k in self._dummy_batch.keys()
            if k != "is_training"
        }

        file_name = os.path.join(export_dir, "model.onnx")
        torch.onnx.export(
            self.model,
            (dummy_inputs, state_ins, seq_lens),
            file_name,
            export_params=True,
            opset_version=onnx,
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys())
            + ["state_ins", SampleBatch.SEQ_LENS],
            output_names=["output", "state_outs"],
            # Mark dim 0 of every named input as the dynamic batch dimension.
            dynamic_axes={
                k: {0: "batch_size"}
                for k in list(dummy_inputs.keys())
                + ["state_ins", SampleBatch.SEQ_LENS]
            },
        )
    # Save the torch.Model (architecture and weights, so it can be retrieved
    # w/o access to the original (custom) Model or Policy code).
    else:
        filename = os.path.join(export_dir, "model.pt")
        try:
            torch.save(self.model, f=filename)
        except Exception:
            # Best-effort: remove any partially written file and warn
            # instead of raising.
            if os.path.exists(filename):
                os.remove(filename)
            logger.warning(ERR_MSG_TORCH_POLICY_CANNOT_SAVE_MODEL)
[docs]
@override(Policy)
def import_model_from_h5(self, import_file: str) -> None:
    """Imports weights from an h5 file into this policy's torch model.

    Delegates entirely to the model's own `import_from_h5` method and
    returns whatever that method returns.

    Args:
        import_file: Path of the h5 file to import from.
    """
    model = self.model
    return model.import_from_h5(import_file)
@with_lock
def _compute_action_helper(
    self, input_dict, state_batches, seq_lens, explore, timestep
):
    """Shared forward pass logic (w/ and w/o trajectory view API).

    Handles three cases:
    1. `self.model` is an `RLModule` (new API stack): calls
       `forward_exploration()` (if exploring) or `forward_inference()`.
       Actions are taken from the output's `ACTIONS` entry if present.
       Otherwise they are sampled from the action distribution built from
       `ACTION_DIST_INPUTS`; at inference this distribution is made
       deterministic.
    2. `action_sampler_fn` is overridden: it computes the actions directly.
    3. Otherwise: an action distribution is built (via
       `action_distribution_fn`, or via the model's forward pass plus
       `self.dist_class`), and `self.exploration` picks the actions.

    Args:
        input_dict: The input batch (dict or SampleBatch).
        state_batches: List of RNN state-in batches (used in cases 2 and 3).
        seq_lens: Optional sequence lengths. May be None, a list, a numpy
            array or a torch tensor.
        explore: Whether to explore. If None, uses `self.config["explore"]`.
        timestep: Current timestep. If None, uses `self.global_timestep`.

    Returns:
        A tuple consisting of a) actions, b) state_out, c) extra_fetches,
        all converted to numpy.

    Side effects:
        Puts `self.model` into eval mode (if set) and increments
        `self.global_timestep` by the number of observations in the batch.
    """
    explore = explore if explore is not None else self.config["explore"]
    timestep = timestep if timestep is not None else self.global_timestep

    # Switch to eval mode.
    if self.model:
        self.model.eval()

    extra_fetches = dist_inputs = logp = None

    # New API stack: `self.model` is-a RLModule.
    if isinstance(self.model, RLModule):
        if self.model.is_stateful():
            # For recurrent models, we need to add a time dimension.
            # `seq_lens` may be a numpy array or torch tensor. For those, a
            # plain truthiness test (`not seq_lens`) raises ("ambiguous truth
            # value") as soon as they hold more than one element, so check
            # for None/empty explicitly.
            if seq_lens is None or len(seq_lens) == 0:
                # In order to calculate the batch size ad hoc, we need a sample
                # batch.
                if not isinstance(input_dict, SampleBatch):
                    input_dict = SampleBatch(input_dict)
                seq_lens = np.array([1] * len(input_dict))
            input_dict = self.maybe_add_time_dimension(
                input_dict, seq_lens=seq_lens
            )
        input_dict = convert_to_torch_tensor(input_dict, device=self.device)

        # Batches going into the RL Module should not have seq_lens.
        if SampleBatch.SEQ_LENS in input_dict:
            del input_dict[SampleBatch.SEQ_LENS]

        if explore:
            fwd_out = self.model.forward_exploration(input_dict)
            # For recurrent models, we need to remove the time dimension.
            fwd_out = self.maybe_remove_time_dimension(fwd_out)

            # ACTION_DIST_INPUTS field returned by `forward_exploration()` ->
            # Create a distribution object.
            action_dist = None
            # Maybe the RLModule has already computed actions.
            if SampleBatch.ACTION_DIST_INPUTS in fwd_out:
                dist_inputs = fwd_out[SampleBatch.ACTION_DIST_INPUTS]
                action_dist_class = self.model.get_exploration_action_dist_cls()
                action_dist = action_dist_class.from_logits(dist_inputs)

            # If `forward_exploration()` returned actions, use them here as-is.
            if SampleBatch.ACTIONS in fwd_out:
                actions = fwd_out[SampleBatch.ACTIONS]
            # Otherwise, sample actions from the distribution.
            else:
                if action_dist is None:
                    raise KeyError(
                        "Your RLModule's `forward_exploration()` method must return"
                        f" a dict with either the {SampleBatch.ACTIONS} key or the "
                        f"{SampleBatch.ACTION_DIST_INPUTS} key in it (or both)!"
                    )
                actions = action_dist.sample()

            # Compute action-logp and action-prob from distribution and add to
            # `extra_fetches`, if possible.
            if action_dist is not None:
                logp = action_dist.logp(actions)
        else:
            fwd_out = self.model.forward_inference(input_dict)
            # For recurrent models, we need to remove the time dimension.
            fwd_out = self.maybe_remove_time_dimension(fwd_out)

            # ACTION_DIST_INPUTS field returned by `forward_inference()` ->
            # Create a (deterministic) distribution object.
            action_dist = None
            if SampleBatch.ACTION_DIST_INPUTS in fwd_out:
                dist_inputs = fwd_out[SampleBatch.ACTION_DIST_INPUTS]
                action_dist_class = self.model.get_inference_action_dist_cls()
                action_dist = action_dist_class.from_logits(dist_inputs)
                action_dist = action_dist.to_deterministic()

            # If `forward_inference()` returned actions, use them here as-is.
            if SampleBatch.ACTIONS in fwd_out:
                actions = fwd_out[SampleBatch.ACTIONS]
            # Otherwise, sample actions from the (deterministic) distribution.
            else:
                if action_dist is None:
                    raise KeyError(
                        "Your RLModule's `forward_inference()` method must return"
                        f" a dict with either the {SampleBatch.ACTIONS} key or the "
                        f"{SampleBatch.ACTION_DIST_INPUTS} key in it (or both)!"
                    )
                actions = action_dist.sample()

        # Anything but actions and state_out is an extra fetch.
        state_out = fwd_out.pop(Columns.STATE_OUT, {})
        extra_fetches = fwd_out

    elif is_overridden(self.action_sampler_fn):
        action_dist = None
        actions, logp, dist_inputs, state_out = self.action_sampler_fn(
            self.model,
            obs_batch=input_dict,
            state_batches=state_batches,
            explore=explore,
            timestep=timestep,
        )
    else:
        # Call the exploration before_compute_actions hook.
        self.exploration.before_compute_actions(explore=explore, timestep=timestep)
        if is_overridden(self.action_distribution_fn):
            dist_inputs, dist_class, state_out = self.action_distribution_fn(
                self.model,
                obs_batch=input_dict,
                state_batches=state_batches,
                seq_lens=seq_lens,
                explore=explore,
                timestep=timestep,
                is_training=False,
            )
        else:
            dist_class = self.dist_class
            dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens)

        if not (
            isinstance(dist_class, functools.partial)
            or issubclass(dist_class, TorchDistributionWrapper)
        ):
            raise ValueError(
                "`dist_class` ({}) not a TorchDistributionWrapper "
                "subclass! Make sure your `action_distribution_fn` or "
                "`make_model_and_action_dist` return a correct "
                "distribution class.".format(dist_class.__name__)
            )
        action_dist = dist_class(dist_inputs, self.model)

        # Get the exploration action from the forward results.
        actions, logp = self.exploration.get_exploration_action(
            action_distribution=action_dist, timestep=timestep, explore=explore
        )

    # Add default and custom fetches (only if the RLModule path did not
    # already provide them via its forward output).
    if extra_fetches is None:
        extra_fetches = self.extra_action_out(
            input_dict, state_batches, self.model, action_dist
        )

    # Action-dist inputs.
    if dist_inputs is not None:
        extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

    # Action-logp and action-prob.
    if logp is not None:
        extra_fetches[SampleBatch.ACTION_PROB] = torch.exp(logp.float())
        extra_fetches[SampleBatch.ACTION_LOGP] = logp

    # Update our global timestep by the batch size.
    self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])
    return convert_to_numpy((actions, state_out, extra_fetches))
def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None):
    """Installs a torch-tensor conversion interceptor on a batch.

    If `postprocessed_batch` is not already a `SampleBatch`, it is wrapped
    in a new one first. Then `convert_to_torch_tensor` (bound to `device`,
    or `self.device` if `device` is falsy) is set as the batch's get
    interceptor. Presumably this makes values convert to torch tensors on
    access; that behavior lives in `SampleBatch`.

    Args:
        postprocessed_batch: The batch (or dict) to wrap.
        device: Optional target torch device. Defaults to `self.device`.

    Returns:
        The batch with the interceptor installed. This is the same object if
        a `SampleBatch` was passed in.
    """
    batch = (
        postprocessed_batch
        if isinstance(postprocessed_batch, SampleBatch)
        else SampleBatch(postprocessed_batch)
    )
    target_device = device or self.device
    interceptor = functools.partial(convert_to_torch_tensor, device=target_device)
    batch.set_get_interceptor(interceptor)
    return batch
def _multi_gpu_parallel_grad_calc(
    self, sample_batches: List[SampleBatch]
) -> List[Tuple[List[TensorType], GradInfoDict]]:
    """Performs a parallelized loss and gradient calculation over shards.

    Each given shard is passed through the loss function using the tower
    model at the same index in `self.model_gpu_towers`, on the device at the
    same index in `self.devices`. With a single device (or `_fake_gpus`),
    towers run serially in the calling thread, and an error is raised right
    after the failing tower. Otherwise, one thread per tower is used.

    If `self.distributed_world_size` is set, each optimizer's gradients are
    summed across processes via `torch.distributed.all_reduce(_coalesced)`.
    They are then divided by the world size.

    Args:
        sample_batches: A list of SampleBatch shards (one per tower) to
            calculate loss and gradients for.

    Returns:
        A list (one item per device) of 2-tuples, each with 1) gradient
        list and 2) grad info dict. The gradient list is aligned with the
        tower's `parameters()`. Entries stay None for parameters that have
        no gradient or are not in any optimizer's param group.

    Raises:
        ValueError: If a tower fails. The message carries the original error
            text, its traceback and the tower/device, and the original
            exception is chained as the cause.
    """
    assert len(self.model_gpu_towers) == len(sample_batches)
    lock = threading.Lock()
    results = {}
    # Capture the caller's grad mode so each worker can re-apply it (torch's
    # grad mode is thread-local).
    grad_enabled = torch.is_grad_enabled()

    def _worker(shard_idx, model, sample_batch, device):
        torch.set_grad_enabled(grad_enabled)
        try:
            with NullContextManager() if device.type == "cpu" else torch.cuda.device(  # noqa: E501
                device
            ):
                loss_out = force_list(
                    self.loss(model, self.dist_class, sample_batch)
                )

                # Call Model's custom-loss with Policy loss outputs and
                # train_batch.
                if hasattr(model, "custom_loss"):
                    loss_out = model.custom_loss(loss_out, sample_batch)

                assert len(loss_out) == len(self._optimizers)

                # Loop through all optimizers.
                grad_info = {"allreduce_latency": 0.0}

                parameters = list(model.parameters())
                all_grads = [None for _ in range(len(parameters))]
                for opt_idx, opt in enumerate(self._optimizers):
                    # Erase gradients in all vars of the tower that this
                    # optimizer would affect.
                    param_indices = self.multi_gpu_param_groups[opt_idx]
                    for param_idx, param in enumerate(parameters):
                        if param_idx in param_indices and param.grad is not None:
                            param.grad.data.zero_()
                    # Recompute gradients of loss over all variables.
                    loss_out[opt_idx].backward(retain_graph=True)
                    grad_info.update(
                        self.extra_grad_process(opt, loss_out[opt_idx])
                    )

                    grads = []
                    # Note that return values are just references;
                    # Calling zero_grad would modify the values.
                    for param_idx, param in enumerate(parameters):
                        if param_idx in param_indices:
                            if param.grad is not None:
                                grads.append(param.grad)
                            all_grads[param_idx] = param.grad

                    if self.distributed_world_size:
                        start = time.time()
                        if torch.cuda.is_available():
                            # Sadly, allreduce_coalesced does not work with
                            # CUDA yet.
                            for g in grads:
                                torch.distributed.all_reduce(
                                    g, op=torch.distributed.ReduceOp.SUM
                                )
                        else:
                            torch.distributed.all_reduce_coalesced(
                                grads, op=torch.distributed.ReduceOp.SUM
                            )

                        for param_group in opt.param_groups:
                            for p in param_group["params"]:
                                if p.grad is not None:
                                    p.grad /= self.distributed_world_size

                        grad_info["allreduce_latency"] += time.time() - start

            with lock:
                results[shard_idx] = (all_grads, grad_info)
        except Exception as e:
            import traceback

            # `e.args` may be empty or hold a non-str first element (e.g. an
            # int). Concatenating that directly would raise inside this
            # handler, and no result would ever be recorded for this tower.
            reason = str(e.args[0]) if e.args else repr(e)
            with lock:
                results[shard_idx] = (
                    ValueError(
                        reason
                        + "\n traceback"
                        + traceback.format_exc()
                        + "\n"
                        + "In tower {} on device {}".format(shard_idx, device)
                    ),
                    e,
                )

    # Single device (GPU) or fake-GPU case (serialize for better
    # debugging).
    if len(self.devices) == 1 or self.config["_fake_gpus"]:
        for shard_idx, (model, sample_batch, device) in enumerate(
            zip(self.model_gpu_towers, sample_batches, self.devices)
        ):
            _worker(shard_idx, model, sample_batch, device)
            # Raise errors right away for better debugging.
            shard_result = results[shard_idx]
            if isinstance(shard_result[0], Exception):
                raise shard_result[0] from shard_result[1]

    # Multi device (GPU) case: Parallelize via threads.
    else:
        threads = [
            threading.Thread(
                target=_worker, args=(shard_idx, model, sample_batch, device)
            )
            for shard_idx, (model, sample_batch, device) in enumerate(
                zip(self.model_gpu_towers, sample_batches, self.devices)
            )
        ]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

    # Gather all threads' outputs and return.
    outputs = []
    for shard_idx in range(len(sample_batches)):
        output = results[shard_idx]
        if isinstance(output[0], Exception):
            raise output[0] from output[1]
        outputs.append(output)
    return outputs