Trainer API

The Trainer class is the highest-level API in RLlib. It allows you to train and evaluate policies, save an experiment’s progress and restore from a prior saved experiment when continuing an RL run. Trainer is a sub-class of Trainable and thus fully supports distributed hyperparameter tuning for RL.

[Figure: Trainer class overview (trainer_class_overview.svg)]

A typical RLlib Trainer object: The components sitting inside a Trainer are normally N RolloutWorkers and zero or more @ray.remote BaseEnv instances per worker.

Building Custom Trainer Classes

Warning

As of Ray >= 1.9, it is no longer recommended to use the build_trainer() utility function for creating custom Trainer sub-classes. Instead, follow the simple guidelines here for directly sub-classing from Trainer.

In order to create a custom Trainer, sub-class the Trainer class and override one or more of its methods, in particular setup(), get_default_policy_class(), and step_attempt() (all documented below).

See here for an example of how to override Trainer.
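
For illustration, here is a minimal sketch of sub-classing Trainer directly. It assumes the Ray 1.9-era ray.rllib.agents module layout; the choice of PPO's Torch policy as the default policy class is purely illustrative.

from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy


class MyTrainer(Trainer):
    def get_default_policy_class(self, config):
        # Used inside the RolloutWorkers' PolicyMaps whenever no PolicySpec
        # provides an explicit policy class.
        return PPOTorchPolicy

    def setup(self, config):
        # Keep the built-in setup (WorkerSet creation, etc.), then add any
        # custom initialization.
        super().setup(config)
        self.my_custom_counter = 0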

Trainer base class (ray.rllib.agents.trainer.Trainer)

class ray.rllib.agents.trainer.Trainer(config: Optional[dict] = None, env: Union[str, Any, None] = None, logger_creator: Optional[Callable[[], ray.tune.logger.Logger]] = None, remote_checkpoint_dir: Optional[str] = None, sync_function_tpl: Optional[str] = None)[source]

An RLlib algorithm responsible for optimizing one or more Policies.

Trainers contain a WorkerSet under self.workers. A WorkerSet is normally composed of a single local worker (self.workers.local_worker()), used to compute and apply learning updates, and optionally one or more remote workers (self.workers.remote_workers()), used to generate environment samples in parallel.

Each worker (remote or local) contains a PolicyMap, which itself may contain either one policy for single-agent training or one or more policies for multi-agent training. Policies are synchronized automatically from time to time using ray.remote calls. The exact synchronization logic depends on the specific algorithm (Trainer) used, but this usually happens from the local worker to all remote workers and after each training update.

You can also write your own Trainer sub-classes using the rllib.agents.trainer_template.py::build_trainer() utility function (no longer recommended as of Ray 1.9; see the warning above). This allows you to provide a custom execution_plan. You can find the different built-in algorithms’ execution plans in their respective main py files, e.g. rllib.agents.dqn.dqn.py or rllib.agents.impala.impala.py.

The most important API methods a Trainer exposes are train(), evaluate(), save() and restore(). Trainer objects retain internal model state between calls to train(), so you should create a new Trainer instance for each training session.
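
A hedged usage sketch of these four methods, using the built-in PPOTrainer and the gym-registered "CartPole-v0" environment as stand-ins:

import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(env="CartPole-v0", config={"num_workers": 1})
for _ in range(3):
    result = trainer.train()          # one training iteration
    print(result["episode_reward_mean"])
checkpoint = trainer.save()           # save the experiment's progress
trainer.restore(checkpoint)           # continue from the saved state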

__init__(config: Optional[dict] = None, env: Union[str, Any, None] = None, logger_creator: Optional[Callable[[], ray.tune.logger.Logger]] = None, remote_checkpoint_dir: Optional[str] = None, sync_function_tpl: Optional[str] = None)[source]

Initializes a Trainer instance.

Parameters
  • config – Algorithm-specific configuration dict.

  • env – Name of the environment to use (e.g. a gym-registered str), a full class path (e.g. “ray.rllib.examples.env.random_env.RandomEnv”), or an Env class directly. Note that this arg can also be specified via the “env” key in config.

  • logger_creator – Callable that creates a ray.tune.Logger object. If unspecified, a default logger is created.

setup(config: dict)[source]

Subclasses should override this for custom initialization.

New in version 0.8.7.

Parameters

config (dict) – Hyperparameters and other configs given. Copy of self.config.

get_default_policy_class(config: dict)[source]

Returns a default Policy class to use, given a config.

This class will be used inside RolloutWorkers’ PolicyMaps in case the policy class is not provided by the user in any single- or multi-agent PolicySpec.

This method is experimental and currently only used if the Trainer class was not created using the build_trainer utility and if the Trainer sub-class does not override _init() and create its own WorkerSet in _init().

step() → dict[source]

Implements the main Trainer.train() logic.

Takes n attempts to perform a single training step, catching RayErrors resulting from worker failures. After n attempts, fails gracefully.

Override this method in your Trainer sub-classes if you would like to handle worker failures yourself. Otherwise, override self.step_attempt() to keep the n attempts (catch worker failures).

Returns

The results dict with stats/infos on sampling, training, and - if required - evaluation.

step_attempt() → dict[source]

Attempts a single training step, including evaluation, if required.

Override this method in your Trainer sub-classes if you would like to keep the n step-attempts logic (catching worker failures) in place, or override step() directly if you would like to handle worker failures yourself.

Returns

The results dict with stats/infos on sampling, training, and - if required - evaluation.

evaluate(episodes_left_fn=None, duration_fn: Optional[Callable[[int], int]] = None) → dict[source]

Evaluates current policy under evaluation_config settings.

Note that this default implementation does not do anything beyond merging evaluation_config with the normal trainer config.

Parameters

duration_fn – An optional callable taking the already run num episodes as only arg and returning the number of episodes left to run. It’s used to find out whether evaluation should continue.
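
A hedged sketch of setting up a dedicated evaluation WorkerSet via the config and calling evaluate() manually; the config keys follow the common trainer config of this era, and the exact structure of the returned results dict may differ:

from ray.rllib.agents.ppo import PPOTrainer

trainer = PPOTrainer(
    env="CartPole-v0",
    config={
        "evaluation_interval": 1,       # also evaluate after each train() call
        "evaluation_num_workers": 1,    # dedicated remote evaluation worker
        "evaluation_config": {"explore": False},  # e.g. greedy actions for eval
    },
)
trainer.train()
eval_results = trainer.evaluate()
print(eval_results["evaluation"]["episode_reward_mean"])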

compute_single_action(observation: Union[Any, dict, tuple, None] = None, state: Optional[List[Union[Any, dict, tuple]]] = None, *, prev_action: Union[Any, dict, tuple, None] = None, prev_reward: Optional[float] = None, info: Optional[dict] = None, input_dict: Optional[ray.rllib.policy.sample_batch.SampleBatch] = None, policy_id: str = 'default_policy', full_fetch: bool = False, explore: Optional[bool] = None, timestep: Optional[int] = None, episode: Optional[ray.rllib.evaluation.episode.Episode] = None, unsquash_action: Optional[bool] = None, clip_action: Optional[bool] = None, unsquash_actions=-1, clip_actions=-1, **kwargs) → Union[Any, dict, tuple, Tuple[Union[Any, dict, tuple], List[Any], Dict[str, Any]]][source]

Computes an action for the specified policy on the local worker.

Note that you can also access the policy object through self.get_policy(policy_id) and call compute_single_action() on it directly.

Parameters
  • observation – Single (unbatched) observation from the environment.

  • state – List of all RNN hidden (single, unbatched) state tensors.

  • prev_action – Single (unbatched) previous action value.

  • prev_reward – Single (unbatched) previous reward value.

  • info – Env info dict, if any.

  • input_dict – An optional SampleBatch that holds all the values for: obs, state, prev_action, and prev_reward, plus maybe custom defined views of the current env trajectory. Note that exactly one of observation and input_dict must be non-None.

  • policy_id – Policy to query (only applies to multi-agent). Default: “default_policy”.

  • full_fetch – Whether to return extra action fetch results. This is always set to True if state is specified.

  • explore – Whether to apply exploration to the action. Default: None -> use self.config[“explore”].

  • timestep – The current (sampling) time step.

  • episode – This provides access to all of the internal episodes’ state, which may be useful for model-based or multi-agent algorithms.

  • unsquash_action – Should actions be unsquashed according to the env’s/Policy’s action space? If None, use the value of self.config[“normalize_actions”].

  • clip_action – Should actions be clipped according to the env’s/Policy’s action space? If None, use the value of self.config[“clip_actions”].

Keyword Arguments

kwargs – forward compatibility placeholder

Returns

The computed action if full_fetch=False, or a tuple of (action, RNN state outs, extra action fetches dict) if full_fetch=True or the Policy is RNN-based.

Raises

KeyError – If the policy_id cannot be found in this Trainer’s local worker.
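
A hedged sketch of using compute_single_action() in a manual environment loop (the environment and config are illustrative):

import gym
from ray.rllib.agents.ppo import PPOTrainer

trainer = PPOTrainer(env="CartPole-v0", config={"num_workers": 0})
env = gym.make("CartPole-v0")

obs = env.reset()
done = False
while not done:
    # Query the (single) default policy for one unbatched action.
    action = trainer.compute_single_action(obs, explore=False)
    obs, reward, done, info = env.step(action)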

compute_actions(observations: Union[Any, dict, tuple], state: Optional[List[Union[Any, dict, tuple]]] = None, *, prev_action: Union[Any, dict, tuple, None] = None, prev_reward: Union[Any, dict, tuple, None] = None, info: Optional[dict] = None, policy_id: str = 'default_policy', full_fetch: bool = False, explore: Optional[bool] = None, timestep: Optional[int] = None, episodes: Optional[List[ray.rllib.evaluation.episode.Episode]] = None, unsquash_actions: Optional[bool] = None, clip_actions: Optional[bool] = None, normalize_actions=None, **kwargs)[source]

Computes actions for the specified policy on the local worker.

Note that you can also access the policy object through self.get_policy(policy_id) and call compute_actions() on it directly.

Parameters
  • observations – Observations from the environment.

  • state – RNN hidden state, if any. If state is not None, then all of compute_single_action(…) is returned (computed action, rnn state(s), logits dictionary). Otherwise compute_single_action(…)[0] is returned (computed action).

  • prev_action – Previous action value, if any.

  • prev_reward – Previous reward, if any.

  • info – Env info dict, if any.

  • policy_id – Policy to query (only applies to multi-agent).

  • full_fetch – Whether to return extra action fetch results. This is always set to True if RNN state is specified.

  • explore – Whether to pick an exploitation or exploration action (default: None -> use self.config[“explore”]).

  • timestep – The current (sampling) time step.

  • episodes – This provides access to all of the internal episodes’ state, which may be useful for model-based or multi-agent algorithms.

  • unsquash_actions – Should actions be unsquashed according to the env’s/Policy’s action space? If None, use self.config[“normalize_actions”].

  • clip_actions – Should actions be clipped according to the env’s/Policy’s action space? If None, use self.config[“clip_actions”].

Keyword Arguments

kwargs – forward compatibility placeholder

Returns

The computed action if full_fetch=False, or a tuple consisting of the full output of policy.compute_actions_from_input_dict() if full_fetch=True or we have an RNN-based Policy.
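
A hedged sketch of the batched variant: compute_actions() takes a dict mapping ids (e.g. agent or env ids) to individual observations and returns a dict of actions keyed the same way (the ids and environment here are illustrative):

import gym
from ray.rllib.agents.ppo import PPOTrainer

trainer = PPOTrainer(env="CartPole-v0", config={"num_workers": 0})
env = gym.make("CartPole-v0")

obs_batch = {
    "env_0": env.observation_space.sample(),
    "env_1": env.observation_space.sample(),
}
actions = trainer.compute_actions(obs_batch, explore=False)
# actions is a dict keyed like obs_batch, e.g. {"env_0": 1, "env_1": 0}.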

get_policy(policy_id: str = 'default_policy') → ray.rllib.policy.policy.Policy[source]

Return policy for the specified id, or None.

Parameters

policy_id – ID of the policy to return.

get_weights(policies: Optional[List[str]] = None) → dict[source]

Return a dictionary mapping policy ids to weights.

Parameters

policies – Optional list of policies to return weights for, or None for all policies.

set_weights(weights: Dict[str, dict])[source]

Set policy weights by policy id.

Parameters

weights – Map of policy ids to weights to set.
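
A hedged sketch of copying policy weights from one Trainer instance to another, e.g. to seed a fresh Trainer with pre-trained weights (the trainer setup is illustrative):

from ray.rllib.agents.ppo import PPOTrainer

trainer_a = PPOTrainer(env="CartPole-v0", config={"num_workers": 0})
trainer_b = PPOTrainer(env="CartPole-v0", config={"num_workers": 0})

trainer_a.train()
weights = trainer_a.get_weights()      # {policy_id: weights}
trainer_b.set_weights(weights)         # overwrite trainer_b's policies

# Restrict to a subset of policies:
default_only = trainer_a.get_weights(["default_policy"])
trainer_b.set_weights(default_only)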

add_policy(policy_id: str, policy_cls: Type[ray.rllib.policy.policy.Policy], *, observation_space: Optional[gym.spaces.Space] = None, action_space: Optional[gym.spaces.Space] = None, config: Optional[dict] = None, policy_mapping_fn: Optional[Callable[[Any, int], str]] = None, policies_to_train: Optional[List[str]] = None, evaluation_workers: bool = True) → ray.rllib.policy.policy.Policy[source]

Adds a new policy to this Trainer.

Parameters
  • policy_id (PolicyID) – ID of the policy to add.

  • policy_cls (Type[Policy]) – The Policy class to use for constructing the new Policy.

  • observation_space (Optional[gym.spaces.Space]) – The observation space of the policy to add.

  • action_space (Optional[gym.spaces.Space]) – The action space of the policy to add.

  • config (Optional[PartialTrainerConfigDict]) – The config overrides for the policy to add.

  • policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]) – An optional (updated) policy mapping function to use from here on. Note that already ongoing episodes will not change their mapping but will use the old mapping till the end of the episode.

  • policies_to_train (Optional[List[PolicyID]]) – An optional list of policy IDs to be trained. If None, will keep the existing list in place. Policies whose IDs are not in the list will not be updated.

  • evaluation_workers (bool) – Whether to add the new policy also to the evaluation WorkerSet.

Returns

The newly added policy (the copy that got added to the local worker).
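
A hedged sketch of adding a policy to an already running multi-agent Trainer and re-routing some agents to it; the policy class, policy ids, and mapping logic are all illustrative:

from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy

# `trainer` is an existing multi-agent Trainer whose main policy id is
# assumed to be "main" (setup not shown).
new_policy = trainer.add_policy(
    policy_id="opponent_v1",
    policy_cls=PPOTorchPolicy,
    # Route agents whose ids start with "opponent" to the new policy;
    # all other agents keep using the main policy.
    policy_mapping_fn=lambda agent_id, episode=None, worker=None, **kw: (
        "opponent_v1" if str(agent_id).startswith("opponent") else "main"
    ),
    policies_to_train=["main"],  # keep training only the main policy
)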

remove_policy(policy_id: str = 'default_policy', *, policy_mapping_fn: Optional[Callable[[Any], str]] = None, policies_to_train: Optional[List[str]] = None, evaluation_workers: bool = True) → None[source]

Removes a policy from this Trainer.

Parameters
  • policy_id (Optional[PolicyID]) – ID of the policy to be removed.

  • policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]) – An optional (updated) policy mapping function to use from here on. Note that already ongoing episodes will not change their mapping but will use the old mapping till the end of the episode.

  • policies_to_train (Optional[List[PolicyID]]) – An optional list of policy IDs to be trained. If None, will keep the existing list in place. Policies whose IDs are not in the list will not be updated.

  • evaluation_workers (bool) – Whether to also remove the policy from the evaluation WorkerSet.

export_policy_model(export_dir: str, policy_id: str = 'default_policy', onnx: Optional[int] = None) → None[source]

Exports policy model with given policy_id to a local directory.

Parameters
  • export_dir – Writable local directory.

  • policy_id – Optional policy id to export.

  • onnx – If given, will export model in ONNX format. The value of this parameter sets the ONNX OpSet version to use. If None, the output format will be DL framework specific.

Example

>>> trainer = MyTrainer()
>>> for _ in range(10):
...     trainer.train()
>>> trainer.export_policy_model("/tmp/dir")
>>> trainer.export_policy_model("/tmp/dir/onnx", onnx=1)
export_policy_checkpoint(export_dir: str, filename_prefix: str = 'model', policy_id: str = 'default_policy') → None[source]

Exports policy model checkpoint to a local directory.

Parameters
  • export_dir – Writable local directory.

  • filename_prefix – file name prefix of checkpoint files.

  • policy_id – Optional policy id to export.

Example

>>> trainer = MyTrainer()
>>> for _ in range(10):
...     trainer.train()
>>> trainer.export_policy_checkpoint("/tmp/export_dir")
import_policy_model_from_h5(import_file: str, policy_id: str = 'default_policy') → None[source]

Imports a policy’s model with given policy_id from a local h5 file.

Parameters
  • import_file – The h5 file to import from.

  • policy_id – Optional policy id to import into.

Example

>>> trainer = MyTrainer()
>>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
>>> for _ in range(10):
...     trainer.train()
collect_metrics(selected_workers: List[ray.actor.ActorHandle] = None) → dict[source]

Collects metrics from the remote workers of this agent.

This is the same data as returned by a call to train().

save_checkpoint(checkpoint_dir: str) → str[source]

Subclasses should override this to implement save().

Warning

Do not rely on absolute paths in the implementation of Trainable.save_checkpoint and Trainable.load_checkpoint.

Use validate_save_restore to catch Trainable.save_checkpoint/ Trainable.load_checkpoint errors before execution.

>>> from ray.tune.utils import validate_save_restore
>>> validate_save_restore(MyTrainableClass)
>>> validate_save_restore(MyTrainableClass, use_object_store=True)

New in version 0.8.7.

Parameters

checkpoint_dir (str) – The directory where the checkpoint file must be stored. In a Tune run, if the trial is paused, the provided path may be temporary and moved.

Returns

A dict or string. If string, the return value is expected to be prefixed by checkpoint_dir. If dict, the return value will be automatically serialized by Tune and passed to Trainable.load_checkpoint().

Examples

>>> print(trainable1.save_checkpoint("/tmp/checkpoint_1"))
"/tmp/checkpoint_1/my_checkpoint_file"
>>> print(trainable2.save_checkpoint("/tmp/checkpoint_2"))
{"some": "data"}
>>> trainable.save_checkpoint("/tmp/bad_example")
"/tmp/NEW_CHECKPOINT_PATH/my_checkpoint_file" # This will error.
load_checkpoint(checkpoint_path: str) → None[source]

Subclasses should override this to implement restore().

Warning

In this method, do not rely on absolute paths. The absolute path of the checkpoint_dir used in Trainable.save_checkpoint may be changed.

If Trainable.save_checkpoint returned a prefixed string, the prefix of the checkpoint string returned by Trainable.save_checkpoint may be changed. This is because trial pausing depends on temporary directories.

The directory structure under the checkpoint_dir provided to Trainable.save_checkpoint is preserved.

See the example below.

class Example(Trainable):
    def save_checkpoint(self, checkpoint_path):
        print(checkpoint_path)
        return os.path.join(checkpoint_path, "my/check/point")

    def load_checkpoint(self, checkpoint):
        print(checkpoint)

>>> trainer = Example()
>>> obj = trainer.save_to_object()  # This is used when PAUSED.
<logdir>/tmpc8k_c_6hsave_to_object/checkpoint_0/my/check/point
>>> trainer.restore_from_object(obj)  # Note the different prefix.
<logdir>/tmpb87b5axfrestore_from_object/checkpoint_0/my/check/point

New in version 0.8.7.

Parameters

checkpoint (str|dict) – If dict, the return value is as returned by save_checkpoint. If a string, then it is a checkpoint path that may have a different prefix than that returned by save_checkpoint. The directory structure underneath the checkpoint_dir provided to save_checkpoint is preserved.

log_result(result: dict) → None[source]

Subclasses can optionally override this to customize logging.

The logging here is done on the worker process rather than the driver. You may want to turn off driver logging via the loggers parameter in tune.run when overriding this function.

New in version 0.8.7.

Parameters

result (dict) – Training result returned by step().

cleanup() → None[source]

Subclasses should override this for any cleanup on stop.

If any Ray actors are launched in the Trainable (e.g., with an RLlib trainer), be sure to kill the Ray actor processes here.

You can kill a Ray actor by calling actor.__ray_terminate__.remote() on the actor.

New in version 0.8.7.

classmethod default_resource_request(config: dict) → Union[ray.tune.resources.Resources, ray.tune.utils.placement_groups.PlacementGroupFactory][source]

Provides a static resource requirement for the given configuration.

This can be overridden by sub-classes to set the correct trial resource allocation, so the user does not need to.

@classmethod
def default_resource_request(cls, config):
    return PlacementGroupFactory([{"CPU": 1}, {"CPU": 1}])
Parameters

config (Dict[str, Any]) – The Trainable’s config dict.

Returns

A Resources object or a PlacementGroupFactory, consumed by Tune for queueing.

Return type

Union[Resources, PlacementGroupFactory]

classmethod resource_help(config: dict) → str[source]

Returns a help string for configuring this trainable’s resources.

Parameters

config (dict) – The Trainer’s config dict.

static validate_framework(config: dict) → None[source]

Validates the config dictionary with respect to the framework settings.

Parameters

config – The config dictionary to be validated.

validate_config(config: dict) → None[source]

Validates a given config dict for this Trainer.

Users should override this method to implement custom validation behavior. It is recommended to call super().validate_config() in this override.

Parameters

config – The given config dict to check.

Raises

ValueError – If there is something wrong with the config.
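
A hedged sketch of a custom validation override that keeps the built-in checks via super(); the checked config key is hypothetical:

from ray.rllib.agents.trainer import Trainer


class MyTrainer(Trainer):
    def validate_config(self, config):
        # Run RLlib's built-in validation first.
        super().validate_config(config)
        # Hypothetical custom constraint.
        if config.get("my_rollout_horizon", 1) <= 0:
            raise ValueError("`my_rollout_horizon` must be > 0!")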

static validate_env(env: Any, env_context: ray.rllib.env.env_context.EnvContext) → None[source]

Env validator function for this Trainer class.

Override this in child classes to define custom validation behavior.

Parameters
  • env – The (sub-)environment to validate. This is normally a single sub-environment (e.g. a gym.Env) within a vectorized setup.

  • env_context – The EnvContext to configure the environment.

Raises

Exception – If something is wrong with the given environment.

try_recover_from_step_attempt() → None[source]

Try to identify and remove any unhealthy workers.

This method is called after an unexpected remote error is encountered from a worker during the call to self.step_attempt() (within self.step()). It issues check requests to all current workers and removes any that respond with error. If no healthy workers remain, an error is raised. Otherwise, tries to re-build the execution plan with the remaining (healthy) workers.

import_model(import_file: str)[source]

Imports a model from import_file.

Note: Currently, only h5 files are supported.

Parameters

import_file (str) – The file to import the model from.

Returns

A dict that maps ExportFormats to successfully exported models.