Trainer API¶
The Trainer class is the highest-level API in RLlib. It allows you to train and evaluate policies, save an experiment's progress, and restore from a previously saved experiment when continuing an RL run. Trainer is a sub-class of Trainable and thus fully supports distributed hyperparameter tuning for RL.
A typical RLlib Trainer object: The components sitting inside a Trainer are normally N RolloutWorkers and zero or more @ray.remote BaseEnvs per worker.¶
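Because Trainer is a Trainable, any registered Trainer (e.g. "PPO") can be handed straight to Ray Tune. The following is a minimal sketch; the stopping criterion and config values are only illustrative:
>>> import ray
>>> from ray import tune
>>> ray.init()
>>> # Run RLlib's built-in PPO Trainer under Tune and grid-search the learning rate.
>>> tune.run(
...     "PPO",
...     stop={"timesteps_total": 10000},
...     config={
...         "env": "CartPole-v0",
...         "num_workers": 2,
...         "lr": tune.grid_search([0.001, 0.0001]),
...     },
... )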
Building Custom Trainer Classes¶
Warning
As of Ray >= 1.9, it is no longer recommended to use the build_trainer() utility function for creating custom Trainer sub-classes. Instead, follow the simple guidelines here for directly sub-classing from Trainer.
In order to create a custom Trainer, sub-class the Trainer class and override one or more of its methods. Those are in particular:
get_default_config()
execution_plan()
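A minimal sketch of such a sub-class is shown below. It assumes the Ray 1.x-era signatures (get_default_config as a classmethod, execution_plan as a static method); MyAlgoTorchPolicy is a hypothetical Policy class you would provide yourself, and the execution_plan body is left as a placeholder:
>>> from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
>>> class MyTrainer(Trainer):
...     @classmethod
...     def get_default_config(cls):
...         # Start from the common Trainer config and add algorithm-specific keys.
...         return {**COMMON_CONFIG, "my_hyperparam": 0.1}
...     def get_default_policy_class(self, config):
...         # MyAlgoTorchPolicy is a placeholder for your own Policy class.
...         return MyAlgoTorchPolicy
...     @staticmethod
...     def execution_plan(workers, config, **kwargs):
...         # Build and return the iterator over training result dicts here.
...         ...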
Trainer base class (ray.rllib.agents.trainer.Trainer)¶
- class ray.rllib.agents.trainer.Trainer(config: Optional[Union[dict, ray.rllib.agents.trainer_config.TrainerConfig]] = None, env: Optional[Union[str, Any]] = None, logger_creator: Optional[Callable[[], ray.tune.logger.Logger]] = None, remote_checkpoint_dir: Optional[str] = None, sync_function_tpl: Optional[str] = None)[source]¶
An RLlib algorithm responsible for optimizing one or more Policies.
Trainers contain a WorkerSet under self.workers. A WorkerSet is normally composed of a single local worker (self.workers.local_worker()), used to compute and apply learning updates, and optionally one or more remote workers (self.workers.remote_workers()), used to generate environment samples in parallel.
Each worker (remote or local) contains a PolicyMap, which itself may contain either one policy for single-agent training or one or more policies for multi-agent training. Policies are synchronized automatically from time to time using ray.remote calls. The exact synchronization logic depends on the specific algorithm (Trainer) used, but this usually happens from local worker to all remote workers and after each training update.
You can write your own Trainer classes by sub-classing from Trainer or any of its built-in sub-classes. This allows you to override the execution_plan method to implement your own algorithm logic. You can find the different built-in algorithms’ execution plans in their respective main py files, e.g. rllib.algorithms.dqn.dqn.py or rllib.agents.impala.impala.py.
The most important API methods a Trainer exposes are train(), evaluate(), save() and restore(). Trainer objects retain internal model state between calls to train(), so you should create a new Trainer instance for each training session.
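A typical workflow with these methods might look as follows (a sketch; the environment and number of iterations are only illustrative):
>>> from ray.rllib.agents.ppo import PPOTrainer
>>> trainer = PPOTrainer(config={"env": "CartPole-v0"})
>>> for _ in range(3):
...     print(trainer.train()["episode_reward_mean"])
>>> checkpoint = trainer.save()
>>> # Restoring should happen on a freshly created Trainer instance.
>>> new_trainer = PPOTrainer(config={"env": "CartPole-v0"})
>>> new_trainer.restore(checkpoint)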
- __init__(config: Optional[Union[dict, ray.rllib.agents.trainer_config.TrainerConfig]] = None, env: Optional[Union[str, Any]] = None, logger_creator: Optional[Callable[[], ray.tune.logger.Logger]] = None, remote_checkpoint_dir: Optional[str] = None, sync_function_tpl: Optional[str] = None)[source]¶
Initializes a Trainer instance.
- Parameters
config – Algorithm-specific configuration dict.
env – Name of the environment to use (e.g. a gym-registered str), a full class path (e.g. “ray.rllib.examples.env.random_env.RandomEnv”), or an Env class directly. Note that this arg can also be specified via the “env” key in config.
logger_creator – Callable that creates a ray.tune.Logger object. If unspecified, a default logger is created.
- setup(config: dict)[source]¶
Subclasses should override this for custom initialization.
New in version 0.8.7.
- Parameters
config – Hyperparameters and other configs given. Copy of self.config.
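Example
A sub-class can extend setup() to create custom state after the parent has built its workers (a sketch; PPOTrainer is used only as an illustrative base class):
>>> from ray.rllib.agents.ppo import PPOTrainer
>>> class MyTrainer(PPOTrainer):
...     def setup(self, config):
...         # Let the parent create the WorkerSet, policies, etc. first.
...         super().setup(config)
...         # Then initialize any algorithm-specific state.
...         self.custom_buffer = []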
- get_default_policy_class(config: dict) Type[ray.rllib.policy.policy.Policy] [source]¶
Returns a default Policy class to use, given a config.
This class will be used inside RolloutWorkers’ PolicyMaps in case the policy class is not provided by the user in any single- or multi-agent PolicySpec.
This method is experimental and is currently only used if the Trainer class was not created using the build_trainer utility and if the Trainer sub-class does not override _init() and create its own WorkerSet in _init().
- step() dict [source]¶
Implements the main Trainer.train() logic.
Performs up to n attempts at a single training step, catching RayErrors resulting from worker failures. After n failed attempts, fails gracefully.
Override this method in your Trainer sub-classes if you would like to handle worker failures yourself. Otherwise, override self.step_attempt() instead to keep the n-attempts logic (catching worker failures) in place.
- Returns
The results dict with stats/infos on sampling, training, and - if required - evaluation.
- step_attempt() dict [source]¶
Attempts a single training step, including evaluation, if required.
Override this method in your Trainer sub-classes if you would like to keep the n step-attempts logic (catching worker failures) in place, or override step() directly if you would like to handle worker failures yourself.
- Returns
The results dict with stats/infos on sampling, training, and - if required - evaluation.
- evaluate(episodes_left_fn=None, duration_fn: Optional[Callable[[int], int]] = None) dict [source]¶
Evaluates current policy under evaluation_config settings.
Note that this default implementation does not do anything beyond merging evaluation_config with the normal trainer config.
- Parameters
duration_fn – An optional callable taking the number of episodes already run as its only argument and returning the number of episodes left to run. It is used to determine whether evaluation should continue.
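Example
A usage sketch: evaluation behavior is driven by the evaluation_* config keys (here, evaluation workers are assumed to be set up via evaluation_num_workers, and the merged evaluation_config simply disables exploration during evaluation):
>>> from ray.rllib.agents.ppo import PPOTrainer
>>> trainer = PPOTrainer(config={
...     "env": "CartPole-v0",
...     "evaluation_num_workers": 1,
...     "evaluation_config": {"explore": False},
... })
>>> trainer.train()
>>> results = trainer.evaluate()
>>> print(results["evaluation"]["episode_reward_mean"])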
- training_iteration() dict [source]¶
Default single iteration logic of an algorithm.
- Collect on-policy samples (SampleBatches) in parallel using the Trainer's RolloutWorkers (@ray.remote).
- Concatenate the collected SampleBatches into one train batch.
- Call each policy's learn_on_batch (simple optimizer) OR load_batch_into_buffer + learn_on_loaded_batch (multi-GPU optimizer) method to calculate loss and update the model(s); note that there may be more than one policy in the multi-agent case.
- Return all collected metrics for the iteration.
- Returns
The results dict from executing the training iteration.
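Example
The sketch below shows a strongly simplified custom override of this method (it skips weight synchronization, metrics aggregation, and multi-GPU handling, and PPOTrainer is used only as an illustrative base class):
>>> import ray
>>> from ray.rllib.agents.ppo import PPOTrainer
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> class MyTrainer(PPOTrainer):
...     def training_iteration(self):
...         # Sample one batch per rollout worker (remotely, if available).
...         if self.workers.remote_workers():
...             batches = ray.get(
...                 [w.sample.remote() for w in self.workers.remote_workers()])
...         else:
...             batches = [self.workers.local_worker().sample()]
...         train_batch = SampleBatch.concat_samples(batches)
...         # Update the local worker's policies on the collected batch.
...         return self.workers.local_worker().learn_on_batch(train_batch)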
- compute_single_action(observation: Optional[Union[Any, dict, tuple]] = None, state: Optional[List[Union[Any, dict, tuple]]] = None, *, prev_action: Optional[Union[Any, dict, tuple]] = None, prev_reward: Optional[float] = None, info: Optional[dict] = None, input_dict: Optional[ray.rllib.policy.sample_batch.SampleBatch] = None, policy_id: str = 'default_policy', full_fetch: bool = False, explore: Optional[bool] = None, timestep: Optional[int] = None, episode: Optional[ray.rllib.evaluation.episode.Episode] = None, unsquash_action: Optional[bool] = None, clip_action: Optional[bool] = None, unsquash_actions=- 1, clip_actions=- 1, **kwargs) Union[Any, dict, tuple, Tuple[Union[Any, dict, tuple], List[Any], Dict[str, Any]]] [source]¶
Computes an action for the specified policy on the local worker.
Note that you can also access the policy object through self.get_policy(policy_id) and call compute_single_action() on it directly.
- Parameters
observation – Single (unbatched) observation from the environment.
state – List of all RNN hidden (single, unbatched) state tensors.
prev_action – Single (unbatched) previous action value.
prev_reward – Single (unbatched) previous reward value.
info – Env info dict, if any.
input_dict – An optional SampleBatch that holds all the values for: obs, state, prev_action, and prev_reward, plus maybe custom defined views of the current env trajectory. Note that only one of obs or input_dict must be non-None.
policy_id – Policy to query (only applies to multi-agent). Default: “default_policy”.
full_fetch – Whether to return extra action fetch results. This is always set to True if state is specified.
explore – Whether to apply exploration to the action. Default: None -> use self.config[“explore”].
timestep – The current (sampling) time step.
episode – This provides access to all of the internal episodes’ state, which may be useful for model-based or multi-agent algorithms.
unsquash_action – Should actions be unsquashed according to the env’s/Policy’s action space? If None, use the value of self.config[“normalize_actions”].
clip_action – Should actions be clipped according to the env’s/Policy’s action space? If None, use the value of self.config[“clip_actions”].
- Keyword Arguments
kwargs – forward compatibility placeholder
- Returns
The computed action if full_fetch=False, or a tuple consisting of the computed action, a list of RNN state outputs, and a dict of extra action fetches (i.e. the full output of policy.compute_actions()) if full_fetch=True or the Policy is RNN-based.
- Raises
KeyError – If the policy_id cannot be found in this Trainer’s local worker.
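Example
A typical inference loop against a manually created environment might look like this (a sketch assuming CartPole-v0 and a PPOTrainer):
>>> import gym
>>> from ray.rllib.agents.ppo import PPOTrainer
>>> trainer = PPOTrainer(config={"env": "CartPole-v0"})
>>> env = gym.make("CartPole-v0")
>>> obs = env.reset()
>>> done = False
>>> while not done:
...     action = trainer.compute_single_action(obs, explore=False)
...     obs, reward, done, info = env.step(action)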
- compute_actions(observations: Union[Any, dict, tuple], state: Optional[List[Union[Any, dict, tuple]]] = None, *, prev_action: Optional[Union[Any, dict, tuple]] = None, prev_reward: Optional[Union[Any, dict, tuple]] = None, info: Optional[dict] = None, policy_id: str = 'default_policy', full_fetch: bool = False, explore: Optional[bool] = None, timestep: Optional[int] = None, episodes: Optional[List[ray.rllib.evaluation.episode.Episode]] = None, unsquash_actions: Optional[bool] = None, clip_actions: Optional[bool] = None, normalize_actions=None, **kwargs)[source]¶
Computes an action for the specified policy on the local Worker.
Note that you can also access the policy object through self.get_policy(policy_id) and call compute_actions() on it directly.
- Parameters
observations – Observations from the environment.
state – RNN hidden state, if any. If state is not None, then all of compute_single_action(…) is returned (computed action, rnn state(s), logits dictionary). Otherwise compute_single_action(…)[0] is returned (computed action).
prev_action – Previous action value, if any.
prev_reward – Previous reward, if any.
info – Env info dict, if any.
policy_id – Policy to query (only applies to multi-agent).
full_fetch – Whether to return extra action fetch results. This is always set to True if RNN state is specified.
explore – Whether to pick an exploitation or exploration action (default: None -> use self.config[“explore”]).
timestep – The current (sampling) time step.
episodes – This provides access to all of the internal episodes’ state, which may be useful for model-based or multi-agent algorithms.
unsquash_actions – Should actions be unsquashed according to the env’s/Policy’s action space? If None, use self.config[“normalize_actions”].
clip_actions – Should actions be clipped according to the env’s/Policy’s action space? If None, use self.config[“clip_actions”].
- Keyword Arguments
kwargs – forward compatibility placeholder
- Returns
The computed action if full_fetch=False, or a tuple consisting of the full output of policy.compute_actions_from_input_dict() if full_fetch=True or we have an RNN-based Policy.
- get_policy(policy_id: str = 'default_policy') ray.rllib.policy.policy.Policy [source]¶
Return policy for the specified id, or None.
- Parameters
policy_id – ID of the policy to return.
- get_weights(policies: Optional[List[str]] = None) dict [source]¶
Return a dictionary of policy ids to weights.
- Parameters
policies – Optional list of policies to return weights for, or None for all policies.
- set_weights(weights: Dict[str, dict])[source]¶
Set policy weights by policy id.
- Parameters
weights – Map of policy ids to weights to set.
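Example
get_weights() and set_weights() can be combined to copy weights between Trainer instances (a sketch; both Trainers are assumed to hold policies with identical model architectures):
>>> from ray.rllib.agents.ppo import PPOTrainer
>>> trainer_a = PPOTrainer(config={"env": "CartPole-v0"})
>>> trainer_b = PPOTrainer(config={"env": "CartPole-v0"})
>>> # Copy all policy weights from trainer_a into trainer_b.
>>> trainer_b.set_weights(trainer_a.get_weights())
>>> # Or fetch the weights of a single policy only.
>>> default_weights = trainer_a.get_weights(["default_policy"])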
- add_policy(policy_id: str, policy_cls: Type[ray.rllib.policy.policy.Policy], *, observation_space: Optional[gym.spaces.Space] = None, action_space: Optional[gym.spaces.Space] = None, config: Optional[dict] = None, policy_state: Optional[Dict[str, Union[Any, dict, tuple]]] = None, policy_mapping_fn: Optional[Callable[[Any, int], str]] = None, policies_to_train: Optional[Union[Container[str], Callable[[str, Optional[Union[SampleBatch, MultiAgentBatch]]], bool]]] = None, evaluation_workers: bool = True, workers: Optional[List[Union[ray.rllib.evaluation.rollout_worker.RolloutWorker, ray.actor.ActorHandle]]] = None) ray.rllib.policy.policy.Policy [source]¶
Adds a new policy to this Trainer.
- Parameters
policy_id – ID of the policy to add.
policy_cls – The Policy class to use for constructing the new Policy.
observation_space – The observation space of the policy to add. If None, try to infer this space from the environment.
action_space – The action space of the policy to add. If None, try to infer this space from the environment.
config – The config overrides for the policy to add.
policy_state – Optional state dict to apply to the new policy instance, right after its construction.
policy_mapping_fn – An optional (updated) policy mapping function to use from here on. Note that already ongoing episodes will not change their mapping but will use the old mapping till the end of the episode.
policies_to_train – An optional list of policy IDs to be trained, or a callable taking a PolicyID and a SampleBatchType and returning a bool (trainable or not?). If None, will keep the existing setup in place. Policies whose IDs are not in the list (or for which the callable returns False) will not be updated.
evaluation_workers – Whether to add the new policy also to the evaluation WorkerSet.
workers – A list of RolloutWorker/ActorHandles (remote RolloutWorkers) to add this policy to. If defined, will only add the given policy to these workers.
- Returns
The newly added policy (the copy that got added to the local worker).
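Example
A sketch of adding (and later removing) a policy on a running multi-agent Trainer; the mapping function signature follows the annotation above, and the existing default policy's class is simply reused:
>>> new_policy = trainer.add_policy(
...     policy_id="new_policy",
...     policy_cls=type(trainer.get_policy("default_policy")),
...     policy_mapping_fn=lambda agent_id, episode, **kwargs: "new_policy",
...     policies_to_train=["new_policy"],
... )
>>> trainer.train()
>>> trainer.remove_policy("new_policy")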
- remove_policy(policy_id: str = 'default_policy', *, policy_mapping_fn: Optional[Callable[[Any], str]] = None, policies_to_train: Optional[Union[Set[str], Callable[[str, Optional[Union[SampleBatch, MultiAgentBatch]]], bool]]] = None, evaluation_workers: bool = True) None [source]¶
Removes a policy from this Trainer.
- Parameters
policy_id – ID of the policy to be removed.
policy_mapping_fn – An optional (updated) policy mapping function to use from here on. Note that already ongoing episodes will not change their mapping but will use the old mapping till the end of the episode.
policies_to_train – An optional list of policy IDs to be trained, or a callable taking a PolicyID and a SampleBatchType and returning a bool (trainable or not?). If None, will keep the existing setup in place. Policies whose IDs are not in the list (or for which the callable returns False) will not be updated.
evaluation_workers – Whether to also remove the policy from the evaluation WorkerSet.
- export_policy_model(export_dir: str, policy_id: str = 'default_policy', onnx: Optional[int] = None) None [source]¶
Exports policy model with given policy_id to a local directory.
- Parameters
export_dir – Writable local directory.
policy_id – Optional policy id to export.
onnx – If given, will export model in ONNX format. The value of this parameter sets the ONNX OpSet version to use. If None, the output format will be DL framework specific.
Example
>>> from ray.rllib.agents.ppo import PPOTrainer
>>> # Use a Trainer from RLlib or define your own.
>>> trainer = PPOTrainer(...)
>>> for _ in range(10):
...     trainer.train()
>>> trainer.export_policy_model("/tmp/dir")
>>> trainer.export_policy_model("/tmp/dir/onnx", onnx=1)
- export_policy_checkpoint(export_dir: str, filename_prefix: str = 'model', policy_id: str = 'default_policy') None [source]¶
Exports policy model checkpoint to a local directory.
- Parameters
export_dir – Writable local directory.
filename_prefix – file name prefix of checkpoint files.
policy_id – Optional policy id to export.
Example
>>> from ray.rllib.agents.ppo import PPOTrainer
>>> # Use a Trainer from RLlib or define your own.
>>> trainer = PPOTrainer(...)
>>> for _ in range(10):
...     trainer.train()
>>> trainer.export_policy_checkpoint("/tmp/export_dir")
- import_policy_model_from_h5(import_file: str, policy_id: str = 'default_policy') None [source]¶
Imports a policy’s model with given policy_id from a local h5 file.
- Parameters
import_file – The h5 file to import from.
policy_id – Optional policy id to import into.
Example
>>> from ray.rllib.agents.ppo import PPOTrainer
>>> trainer = PPOTrainer(...)
>>> trainer.import_policy_model_from_h5("/tmp/weights.h5")
>>> for _ in range(10):
...     trainer.train()
- save_checkpoint(checkpoint_dir: str) str [source]¶
Subclasses should override this to implement save().
Warning
Do not rely on absolute paths in the implementation of Trainable.save_checkpoint and Trainable.load_checkpoint.
Use validate_save_restore to catch Trainable.save_checkpoint / Trainable.load_checkpoint errors before execution.
>>> from ray.tune.utils import validate_save_restore
>>> MyTrainableClass = ...
>>> validate_save_restore(MyTrainableClass)
>>> validate_save_restore(
...     MyTrainableClass, use_object_store=True)
New in version 0.8.7.
- Parameters
checkpoint_dir – The directory where the checkpoint file must be stored. In a Tune run, if the trial is paused, the provided path may be temporary and moved.
- Returns
A dict or string. If string, the return value is expected to be prefixed by checkpoint_dir. If dict, the return value will be automatically serialized by Tune and passed to Trainable.load_checkpoint().
Example
>>> trainable, trainable1, trainable2 = ...
>>> print(trainable1.save_checkpoint("/tmp/checkpoint_1"))
"/tmp/checkpoint_1/my_checkpoint_file"
>>> print(trainable2.save_checkpoint("/tmp/checkpoint_2"))
{"some": "data"}
>>> trainable.save_checkpoint("/tmp/bad_example")
"/tmp/NEW_CHECKPOINT_PATH/my_checkpoint_file" # This will error.
- load_checkpoint(checkpoint_path: str) None [source]¶
Subclasses should override this to implement restore().
Warning
In this method, do not rely on absolute paths. The absolute path of the checkpoint_dir used in Trainable.save_checkpoint may be changed.
If Trainable.save_checkpoint returned a prefixed string, the prefix of the checkpoint string returned by Trainable.save_checkpoint may be changed. This is because trial pausing depends on temporary directories.
The directory structure under the checkpoint_dir provided to Trainable.save_checkpoint is preserved.
See the example below.
Example
>>> from ray.tune.trainable import Trainable
>>> class Example(Trainable):
...     def save_checkpoint(self, checkpoint_path):
...         print(checkpoint_path)
...         return os.path.join(checkpoint_path, "my/check/point")
...     def load_checkpoint(self, checkpoint):
...         print(checkpoint)
>>> trainer = Example()
>>> # This is used when PAUSED.
>>> obj = trainer.save_to_object()
<logdir>/tmpc8k_c_6hsave_to_object/checkpoint_0/my/check/point
>>> # Note the different prefix.
>>> trainer.restore_from_object(obj)
<logdir>/tmpb87b5axfrestore_from_object/checkpoint_0/my/check/point
New in version 0.8.7.
- Parameters
checkpoint – If dict, the return value is as returned by save_checkpoint. If a string, then it is a checkpoint path that may have a different prefix than that returned by save_checkpoint. The directory structure underneath the checkpoint_dir provided to save_checkpoint is preserved.
- log_result(result: dict) None [source]¶
Subclasses can optionally override this to customize logging.
The logging here is done on the worker process rather than the driver. You may want to turn off driver logging via the loggers parameter in tune.run when overriding this function.
New in version 0.8.7.
- Parameters
result – Training result returned by step().
- cleanup() None [source]¶
Subclasses should override this for any cleanup on stop.
If any Ray actors are launched in the Trainable (i.e., with an RLlib trainer), be sure to kill the Ray actor process here.
You can kill a Ray actor by calling actor.__ray_terminate__.remote() on the actor.
New in version 0.8.7.
- classmethod default_resource_request(config: dict) Union[ray.tune.resources.Resources, ray.tune.utils.placement_groups.PlacementGroupFactory] [source]¶
Provides a static resource requirement for the given configuration.
This can be overridden by sub-classes to set the correct trial resource allocation, so the user does not need to.
@classmethod
def default_resource_request(cls, config):
    return PlacementGroupFactory([{"CPU": 1}, {"CPU": 1}])
- Parameters
config – The Trainable's config dict.
- Returns
A Resources object or PlacementGroupFactory consumed by Tune for queueing.
- Return type
Union[Resources, PlacementGroupFactory]
- classmethod resource_help(config: dict) str [source]¶
Returns a help string for configuring this trainable’s resources.
- Parameters
config – The Trainer’s config dict.
- classmethod merge_trainer_configs(config1: dict, config2: dict, _allow_unknown_configs: Optional[bool] = None) dict [source]¶
Merges a complete Trainer config with a partial override dict.
Respects nested structures within the config dicts. The values in the partial override dict take priority.
- Parameters
config1 – The complete Trainer’s dict to be merged (overridden) with config2.
config2 – The partial override config dict to merge on top of config1.
_allow_unknown_configs – If True, keys in config2 that don’t exist in config1 are allowed and will be added to the final config.
- Returns
The merged full trainer config dict.
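Example
A small sketch using PPO's default config (merge_trainer_configs is a classmethod, so it can be called on any Trainer sub-class):
>>> from ray.rllib.agents.ppo import PPOTrainer, DEFAULT_CONFIG
>>> partial = {"train_batch_size": 2000, "model": {"fcnet_hiddens": [64, 64]}}
>>> full_config = PPOTrainer.merge_trainer_configs(DEFAULT_CONFIG, partial)
>>> print(full_config["train_batch_size"], full_config["model"]["fcnet_hiddens"])
2000 [64, 64]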
- static validate_framework(config: dict) None [source]¶
Validates the config dictionary with respect to the framework settings.
- Parameters
config – The config dictionary to be validated.
- validate_config(config: dict) None [source]¶
Validates a given config dict for this Trainer.
Users should override this method to implement custom validation behavior. It is recommended to call super().validate_config() in this override.
- Parameters
config – The given config dict to check.
- Raises
ValueError – If there is something wrong with the config.
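Example
A sketch of a custom override that keeps the built-in checks and adds one of its own (PPOTrainer is only an illustrative base class):
>>> from ray.rllib.agents.ppo import PPOTrainer
>>> class MyTrainer(PPOTrainer):
...     def validate_config(self, config):
...         super().validate_config(config)
...         if config["num_workers"] < 1:
...             raise ValueError("This Trainer requires at least one rollout worker!")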
- static validate_env(env: Any, env_context: ray.rllib.env.env_context.EnvContext) None [source]¶
Env validator function for this Trainer class.
Override this in child classes to define custom validation behavior.
- Parameters
env – The (sub-)environment to validate. This is normally a single sub-environment (e.g. a gym.Env) within a vectorized setup.
env_context – The EnvContext to configure the environment.
- Raises
Exception – In case something is wrong with the given environment.
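Example
A sketch of a custom validator that only accepts discrete action spaces (PPOTrainer is only an illustrative base class):
>>> import gym
>>> from ray.rllib.agents.ppo import PPOTrainer
>>> class MyTrainer(PPOTrainer):
...     @staticmethod
...     def validate_env(env, env_context):
...         if not isinstance(env.action_space, gym.spaces.Discrete):
...             raise ValueError("Only discrete action spaces are supported!")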
- try_recover_from_step_attempt() None [source]¶
Try to identify and remove any unhealthy workers.
This method is called after an unexpected remote error is encountered from a worker during the call to self.step_attempt() (within self.step()). It issues check requests to all current workers and removes any that respond with error. If no healthy workers remain, an error is raised. Otherwise, tries to re-build the execution plan with the remaining (healthy) workers.