Torch-Specific Policy: TorchPolicy

class ray.rllib.policy.torch_policy.TorchPolicy(observation_space: gym.spaces.Space, action_space: gym.spaces.Space, config: dict, *, model: Optional[ray.rllib.models.torch.torch_modelv2.TorchModelV2] = None, loss: Optional[Callable[[ray.rllib.policy.policy.Policy, ray.rllib.models.modelv2.ModelV2, Type[ray.rllib.models.torch.torch_action_dist.TorchDistributionWrapper], ray.rllib.policy.sample_batch.SampleBatch], Union[Any, List[Any]]]] = None, action_distribution_class: Optional[Type[ray.rllib.models.torch.torch_action_dist.TorchDistributionWrapper]] = None, action_sampler_fn: Optional[Callable[[Any, List[Any]], Tuple[Any, Any]]] = None, action_distribution_fn: Optional[Callable[[ray.rllib.policy.policy.Policy, ray.rllib.models.modelv2.ModelV2, Any, Any, Any], Tuple[Any, Type[ray.rllib.models.torch.torch_action_dist.TorchDistributionWrapper], List[Any]]]] = None, max_seq_len: int = 20, get_batch_divisibility_req: Optional[Callable[[ray.rllib.policy.policy.Policy], int]] = None)[source]

PyTorch specific Policy class to use with RLlib.

__init__(observation_space: gym.spaces.Space, action_space: gym.spaces.Space, config: dict, *, model: Optional[ray.rllib.models.torch.torch_modelv2.TorchModelV2] = None, loss: Optional[Callable[[ray.rllib.policy.policy.Policy, ray.rllib.models.modelv2.ModelV2, Type[ray.rllib.models.torch.torch_action_dist.TorchDistributionWrapper], ray.rllib.policy.sample_batch.SampleBatch], Union[Any, List[Any]]]] = None, action_distribution_class: Optional[Type[ray.rllib.models.torch.torch_action_dist.TorchDistributionWrapper]] = None, action_sampler_fn: Optional[Callable[[Any, List[Any]], Tuple[Any, Any]]] = None, action_distribution_fn: Optional[Callable[[ray.rllib.policy.policy.Policy, ray.rllib.models.modelv2.ModelV2, Any, Any, Any], Tuple[Any, Type[ray.rllib.models.torch.torch_action_dist.TorchDistributionWrapper], List[Any]]]] = None, max_seq_len: int = 20, get_batch_divisibility_req: Optional[Callable[[ray.rllib.policy.policy.Policy], int]] = None)[source]

Initializes a TorchPolicy instance.

Parameters
  • observation_space – Observation space of the policy.

  • action_space – Action space of the policy.

  • config – The Policy’s config dict.

  • model – PyTorch policy module. Given observations as input, this module must return a list of outputs where the first item is action logits, and the rest can be any value.

  • loss – Callable that returns one or more (a list of) scalar loss terms (see the sketch after this parameter list).

  • action_distribution_class – Class for a torch action distribution.

  • action_sampler_fn – A callable returning a sampled action and its log-likelihood, given Policy, ModelV2, input_dict (SampleBatch), state_batches (optional), explore, and timestep. Provide action_sampler_fn if you would like full control over the action computation step, including the model forward pass, possible sampling from a distribution, and exploration logic. Note: If action_sampler_fn is given, action_distribution_fn must be None. If both action_sampler_fn and action_distribution_fn are None, RLlib will simply pass inputs through self.model to get distribution inputs, create the distribution object, sample from it, and apply some exploration logic to the results.

  • action_distribution_fn – A callable returning distribution inputs (parameters), a dist-class to generate an action distribution object from, and internal-state outputs (or an empty list if not applicable). Provide action_distribution_fn if you would like to only customize the model forward pass call. The resulting distribution parameters are then used by RLlib to create a distribution object, sample from it, and execute any exploration logic. Note: If action_distribution_fn is given, action_sampler_fn must be None. If both action_sampler_fn and action_distribution_fn are None, RLlib will simply pass inputs through self.model to get distribution inputs, create the distribution object, sample from it, and apply some exploration logic to the results. The callable takes as inputs: Policy, ModelV2, ModelInputDict, explore, timestep, is_training.

  • max_seq_len – Max sequence length for LSTM training.

  • get_batch_divisibility_req – Optional callable that returns the divisibility requirement for sample batches given the Policy.
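
For orientation, here is a minimal, illustrative sketch of a loss callable matching the signature described in the parameter list above. The REINFORCE-style objective and all names in it are placeholders (not RLlib’s built-in loss); the callable would be handed to the constructor via the loss argument.

    import torch

    from ray.rllib.policy.sample_batch import SampleBatch

    def my_custom_loss(policy, model, dist_class, train_batch):
        # Forward pass through the model to obtain distribution inputs.
        logits, _ = model(train_batch)
        # Build the action distribution from the model outputs.
        action_dist = dist_class(logits, model)
        # Placeholder policy-gradient objective: -logp(a|s) * r, batch-averaged.
        log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
        return -torch.mean(log_probs * train_batch[SampleBatch.REWARDS])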

compute_actions_from_input_dict(input_dict: Dict[str, Any], explore: bool = None, timestep: Optional[int] = None, **kwargs) → Tuple[Any, List[Any], Dict[str, Any]][source]

Computes actions from collected samples (across multiple agents).

Takes an input dict (usually a SampleBatch) as its main data input. This allows the method to be used with more complex input patterns (view requirements), for example when the Model requires the last n observations, the last m actions/rewards, or a combination of these.

Parameters
  • input_dict – A SampleBatch or input dict containing the Tensors to compute actions. input_dict already abides by the Policy’s as well as the Model’s view requirements and can thus be passed to the Model as-is.

  • explore – Whether to pick an exploitation or exploration action (default: None -> use self.config["explore"]).

  • timestep – The current (sampling) time step.

  • episodes – This provides access to all of the internal episodes’ state, which may be useful for model-based or multi-agent algorithms.

Keyword Arguments

kwargs – Forward compatibility placeholder.

Returns

  • actions – Batch of output actions, with shape like [BATCH_SIZE, ACTION_SHAPE].

  • state_outs – List of RNN state output batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].

  • info – Dictionary of extra feature batches, if any, with shape like {"f1": [BATCH_SIZE, …], "f2": [BATCH_SIZE, …]}.
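
A hedged usage sketch follows; policy is assumed to be an already constructed TorchPolicy over a Box(4,) observation space, and for a simple feed-forward model the current observation is the only view requirement.

    import numpy as np

    from ray.rllib.policy.sample_batch import SampleBatch

    # Input dict satisfying the Policy's/Model's view requirements.
    input_dict = SampleBatch({"obs": np.random.random((2, 4)).astype(np.float32)})

    actions, state_outs, info = policy.compute_actions_from_input_dict(
        input_dict, explore=False)
    # actions: shape [2, ACTION_SHAPE]; state_outs: [] for non-recurrent models.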

compute_actions(obs_batch: Union[List[Union[Any, dict, tuple]], Any, dict, tuple], state_batches: Optional[List[Any]] = None, prev_action_batch: Union[List[Union[Any, dict, tuple]], Any, dict, tuple] = None, prev_reward_batch: Union[List[Union[Any, dict, tuple]], Any, dict, tuple] = None, info_batch: Optional[Dict[str, list]] = None, episodes: Optional[List[Episode]] = None, explore: Optional[bool] = None, timestep: Optional[int] = None, **kwargs) → Tuple[Union[Any, dict, tuple], List[Any], Dict[str, Any]][source]

Computes actions for the current policy.

Parameters
  • obs_batch – Batch of observations.

  • state_batches – List of RNN state input batches, if any.

  • prev_action_batch – Batch of previous action values.

  • prev_reward_batch – Batch of previous rewards.

  • info_batch – Batch of info objects.

  • episodes – List of Episode objects, one for each obs in obs_batch. This provides access to all of the internal episode state, which may be useful for model-based or multi-agent algorithms.

  • explore – Whether to pick an exploitation or exploration action. Set to None (default) to use the value of self.config["explore"].

  • timestep – The current (sampling) time step.

Keyword Arguments

kwargs – Forward compatibility placeholder.

Returns

  • actions (TensorType) – Batch of output actions, with shape like [BATCH_SIZE, ACTION_SHAPE].

  • state_outs (List[TensorType]) – List of RNN state output batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].

  • info (List[dict]) – Dictionary of extra feature batches, if any, with shape like {"f1": [BATCH_SIZE, …], "f2": [BATCH_SIZE, …]}.
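
For comparison with compute_actions_from_input_dict(), a sketch using a raw observation batch (again assuming an already built policy and a Box(4,) observation space):

    import numpy as np

    obs_batch = np.random.random((2, 4)).astype(np.float32)
    actions, state_outs, info = policy.compute_actions(obs_batch, explore=True)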

load_batch_into_buffer(batch: ray.rllib.policy.sample_batch.SampleBatch, buffer_index: int = 0) → int[source]

Bulk-loads the given SampleBatch into the devices’ memories.

The data is split equally across all the Policy’s devices. If the data is not evenly divisible by the batch size, excess data should be discarded.

Parameters
  • batch – The SampleBatch to load.

  • buffer_index – The index of the buffer (a MultiGPUTowerStack) to use on the devices. The number of buffers on each device depends on the value of the num_multi_gpu_tower_stacks config key.

Returns

The number of tuples loaded per device.

get_num_samples_loaded_into_buffer(buffer_index: int = 0) → int[source]

Returns the number of currently loaded samples in the given buffer.

Parameters

buffer_index – The index of the buffer (a MultiGPUTowerStack) to use on the devices. The number of buffers on each device depends on the value of the num_multi_gpu_tower_stacks config key.

Returns

The number of tuples loaded per device.

learn_on_loaded_batch(offset: int = 0, buffer_index: int = 0)[source]

Runs a single step of SGD on already loaded data in a buffer.

Runs an SGD step over a slice of the pre-loaded batch, offset by the offset argument (useful for performing n minibatch SGD updates repeatedly on the same, already pre-loaded data).

Updates the model weights based on the averaged per-device gradients.

Parameters
  • offset – Offset into the preloaded data. Used for pre-loading a train-batch once to a device, then iterating over (subsampling through) this batch n times doing minibatch SGD.

  • buffer_index – The index of the buffer (a MultiGPUTowerStack) to take the already pre-loaded data from. The number of buffers on each device depends on the value of the num_multi_gpu_tower_stacks config key.

Returns

The outputs of extra_ops evaluated over the batch.
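
The three buffer methods above are meant to be used together. The following is a sketch of the pre-load-then-iterate pattern with hypothetical minibatch settings; in practice the slice size comes from the config (e.g. sgd_minibatch_size).

    # Pre-load a large train batch once onto the Policy's device(s) ...
    samples_per_device = policy.load_batch_into_buffer(train_batch, buffer_index=0)
    # (get_num_samples_loaded_into_buffer(0) reports the loaded count.)

    # ... then run several minibatch SGD passes over slices of the loaded data.
    minibatch_size = 128  # Hypothetical; normally derived from the config.
    num_sgd_iter = 10     # Hypothetical number of passes.
    for _ in range(num_sgd_iter):
        for offset in range(0, samples_per_device, minibatch_size):
            results = policy.learn_on_loaded_batch(offset=offset, buffer_index=0)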

apply_gradients(gradients: Union[List[Tuple[Any, Any]], List[Any]]) → None[source]

Applies the (previously) computed gradients.

Subclasses must implement either this method in combination with compute_gradients(), or learn_on_batch().

Parameters

gradients – The already calculated gradients to apply to this Policy.
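
A sketch of the split workflow this method supports, assuming two policies with identical model architectures and a prepared SampleBatch named train_batch:

    # Compute gradients on one policy ...
    grads, grad_info = policy.compute_gradients(train_batch)

    # ... and apply them to another policy holding the same model.
    other_policy.apply_gradients(grads)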

get_tower_stats(stats_name: str) → List[Union[Any, dict, tuple]][source]

Returns list of per-tower stats, copied to this Policy’s device.

Parameters

stats_name – The name of the stats to average over (this str must exist as a key inside each tower’s tower_stats dict).

Returns

The list of stats tensors (or structs) of all towers, copied to this Policy’s device.

Raises
  • AssertionError – If the given stats_name cannot be found in any one of the towers’ tower_stats dicts.
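
For example, to average a per-tower stat on the Policy’s device ("total_loss" is a hypothetical key; it must match whatever the loss function actually writes into each tower’s tower_stats dict):

    import torch

    per_tower = policy.get_tower_stats("total_loss")
    mean_total_loss = torch.mean(
        torch.stack([torch.as_tensor(s) for s in per_tower]))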

get_weights() → dict[source]

Returns model weights.

Note: The return value of this method will reside under the "weights" key in the return value of Policy.get_state(). Model weights are only one part of a Policy’s state. Other state information includes: optimizer variables, exploration state, and global state vars such as the sampling timestep.

Returns

Serializable copy or view of model weights.

set_weights(weights: dict) → None[source]

Sets this Policy’s model’s weights.

Note: Model weights are only one part of a Policy’s state. Other state information includes: optimizer variables, exploration state, and global state vars such as the sampling timestep.

Parameters

weights – Serializable copy or view of model weights.
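
A typical use is syncing weights between two policies that share the same model architecture (trained_policy and target_policy are placeholder names):

    # Copy model weights from one policy to another, e.g. for a weight sync.
    weights = trained_policy.get_weights()
    target_policy.set_weights(weights)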

is_recurrent() → bool[source]

Whether this Policy holds a recurrent Model.

Returns

True if this Policy has an RNN-based Model.

num_state_tensors() → int[source]

The number of internal states needed by the RNN-Model of the Policy.

Returns

The number of RNN internal states kept by this Policy’s Model.

Return type

int

get_initial_state() → List[Any][source]

Returns initial RNN state for the current policy.

Returns

Initial RNN state for the current policy.

Return type

List[TensorType]
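
A short sketch tying the three RNN-related methods above together (policy is an assumed, already built TorchPolicy):

    # Seed the state for a fresh episode; empty list for non-recurrent models.
    state = policy.get_initial_state() if policy.is_recurrent() else []
    assert len(state) == policy.num_state_tensors()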

get_state() → Union[Dict[str, Any], List[Any]][source]

Returns the entire current state of this Policy.

Note: Not to be confused with an RNN model’s internal state. State includes the Model(s)’ weights, optimizer weights, the exploration component’s state, as well as global variables, such as sampling timesteps.

Returns

Serialized local state.

set_state(state: dict) → None[source]

Restores the entire current state of this Policy from state.

Parameters

state – The new state to set this policy to. Can be obtained by calling self.get_state().
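
A simple pickle-based sketch of saving and restoring the full Policy state; this is only an illustration, and the file path as well as the restored_policy name are placeholders.

    import pickle

    # Snapshot the full Policy state (weights, optimizer vars, exploration state, ...).
    state = policy.get_state()
    with open("/tmp/policy_state.pkl", "wb") as f:
        pickle.dump(state, f)

    # Later: restore into a policy built with the same spaces and config.
    with open("/tmp/policy_state.pkl", "rb") as f:
        restored_policy.set_state(pickle.load(f))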

extra_grad_process(optimizer: torch.optim.Optimizer, loss: Any) → Dict[str, Any][source]

Called after each optimizer.zero_grad() + loss.backward() call.

Called for each self._optimizers/loss-value pair. Allows for gradient processing before optimizer.step() is called. E.g. for gradient clipping.

Parameters
  • optimizer – A torch optimizer object.

  • loss – The loss tensor associated with the optimizer.

Returns

A dict with information on the gradient processing step.
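
A sketch of an override that clips the global gradient norm before optimizer.step() is called; the max_norm value is an arbitrary assumption.

    import torch

    from ray.rllib.policy.torch_policy import TorchPolicy

    class MyClippingTorchPolicy(TorchPolicy):
        def extra_grad_process(self, optimizer, loss):
            # Clip the global grad norm over all params this optimizer updates.
            params = [
                p
                for group in optimizer.param_groups
                for p in group["params"]
                if p.grad is not None
            ]
            grad_gnorm = torch.nn.utils.clip_grad_norm_(params, max_norm=40.0)
            return {"grad_gnorm": float(grad_gnorm)}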

extra_compute_grad_fetches() → Dict[str, Any][source]

Extra values to fetch and return from compute_gradients().

Returns

Extra fetch dict to be added to the fetch dict of the compute_gradients call.

extra_action_out(input_dict: Dict[str, Any], state_batches: List[Any], model: ray.rllib.models.torch.torch_modelv2.TorchModelV2, action_dist: ray.rllib.models.torch.torch_action_dist.TorchDistributionWrapper) → Dict[str, Any][source]

Returns dict of extra info to include in experience batch.

Parameters
  • input_dict – Dict of model input tensors.

  • state_batches – List of state tensors.

  • model – Reference to the model object.

  • action_dist – Torch action dist object to get log-probs (e.g. for already sampled actions).

Returns

Extra outputs to return in a compute_actions_from_input_dict() call (3rd return value).
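
A sketch of an override that additionally records the model’s value estimates in the experience batch; it assumes the model implements value_function() and is illustrative, not RLlib’s default behavior.

    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.torch_policy import TorchPolicy

    class MyValueOutTorchPolicy(TorchPolicy):
        def extra_action_out(self, input_dict, state_batches, model, action_dist):
            # Record the value-function output alongside the sampled actions.
            return {SampleBatch.VF_PREDS: model.value_function()}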

extra_grad_info(train_batch: ray.rllib.policy.sample_batch.SampleBatch) → Dict[str, Any][source]

Returns a dict of extra grad info.

Parameters

train_batch – The training batch for which to produce extra grad info.

Returns

The info dict carrying grad info per str key.

optimizer() → Union[List[torch.optim.Optimizer], torch.optim.Optimizer][source]

Returns the custom local PyTorch optimizer(s) to use.

Returns

The local PyTorch optimizer(s) to use for this Policy.

export_model(export_dir: str, onnx: Optional[int] = None) → None[source]

Exports the Policy’s Model to local directory for serving.

Creates a TorchScript model and saves it (or an ONNX model, if onnx is given); see the sketch below.

Parameters
  • export_dir – Local writable directory or filename.

  • onnx – If given, will export model in ONNX format. The value of this parameter sets the ONNX OpSet version to use.
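
A usage sketch (directory paths are placeholders; opset 11 is just an example value):

    # Export as a TorchScript model ...
    policy.export_model("/tmp/torchscript_export")

    # ... or as ONNX, pinning the OpSet version via the onnx arg.
    policy.export_model("/tmp/onnx_export", onnx=11)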

export_checkpoint(export_dir: str) → None[source]

Exports the Policy checkpoint to a local directory.

Parameters

export_dir – Local writable directory.

import_model_from_h5(import_file: str) → None[source]

Imports weights into the torch model.