ray.train.rl.RLTrainer
- class ray.train.rl.RLTrainer(*args, **kwargs)
Bases: ray.train.base_trainer.BaseTrainer
Reinforcement learning trainer.
This trainer provides an interface to RLlib trainables.
If datasets and preprocessors are provided, they are used for offline training, for example with behavior cloning. Otherwise, the trainer trains online.
- Parameters
algorithm – Algorithm to train on. Can be a string reference (e.g. "PPO") or an RLlib trainer class.
scaling_config – Configuration for how to scale training.
run_config – Configuration for the execution of the training run.
datasets – Any Ray Datasets to use for training. Use the key "train" to denote which dataset is the training dataset. If a preprocessor is provided and has not already been fit, it will be fit on the training dataset. All datasets will be transformed by the preprocessor if one is provided. If specified, datasets will be used for offline training and will be configured as an RLlib input config item.
preprocessor – A preprocessor to preprocess the provided datasets.
resume_from_checkpoint – A checkpoint to resume training from.
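The dataset handling described for the datasets parameter can be sketched in plain Python. This is only an illustration of the documented behavior (fit the preprocessor on "train" if it is not yet fit, transform every dataset, then pass the data as the input config item), not Ray's implementation. MeanCenterer, prepare_datasets and build_config are hypothetical names, and plain lists stand in for Ray Datasets.

```python
from typing import Any, Dict, List, Optional


class MeanCenterer:
    """Hypothetical stand-in for a preprocessor; not a Ray class."""

    def __init__(self) -> None:
        self.mean: Optional[float] = None  # None means "not fit yet"

    @property
    def is_fit(self) -> bool:
        return self.mean is not None

    def fit(self, rows: List[float]) -> "MeanCenterer":
        self.mean = sum(rows) / len(rows)
        return self

    def transform(self, rows: List[float]) -> List[float]:
        return [x - self.mean for x in rows]


def prepare_datasets(
    datasets: Dict[str, List[float]],
    preprocessor: Optional[MeanCenterer] = None,
) -> Dict[str, List[float]]:
    """Fit on "train" only if needed, then transform every dataset."""
    if preprocessor is None:
        return dict(datasets)
    if not preprocessor.is_fit:
        preprocessor.fit(datasets["train"])
    return {name: preprocessor.transform(rows) for name, rows in datasets.items()}


def build_config(config: Dict[str, Any], datasets: Dict[str, Any]) -> Dict[str, Any]:
    """Assumption: offline data is handed over under the "input" key."""
    merged = dict(config)
    if datasets:
        merged["input"] = datasets["train"]  # offline training
    return merged  # without datasets, "input" is left untouched (online)


datasets = {"train": [1.0, 2.0, 3.0], "validation": [2.0, 4.0]}
prepared = prepare_datasets(datasets, MeanCenterer())
offline_config = build_config({"env": "CartPole-v0"}, prepared)
online_config = build_config({"env": "CartPole-v0"}, {})
```

Note that the validation data is centered with the mean learned from "train", which mirrors the rule that only the training dataset is used for fitting.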
Example
Online training:

    from ray.air.config import RunConfig, ScalingConfig
    from ray.train.rl import RLTrainer

    trainer = RLTrainer(
        # Stop after five training iterations.
        run_config=RunConfig(stop={"training_iteration": 5}),
        scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
        algorithm="PPO",
        config={
            "env": "CartPole-v0",
            "framework": "tf",
            "evaluation_num_workers": 1,
            "evaluation_interval": 1,
            "evaluation_config": {"input": "sampler"},
        },
    )
    result = trainer.fit()
Example
Offline training (assumes data is stored in /tmp/data-dir):

    import ray
    from ray.air.config import RunConfig, ScalingConfig
    from ray.train.rl import RLTrainer
    from ray.rllib.algorithms.bc.bc import BC

    # Read the offline experiences from JSON files.
    dataset = ray.data.read_json(
        "/tmp/data-dir", parallelism=2, ray_remote_args={"num_cpus": 1}
    )

    trainer = RLTrainer(
        run_config=RunConfig(stop={"training_iteration": 5}),
        scaling_config=ScalingConfig(
            num_workers=2,
            use_gpu=False,
        ),
        datasets={"train": dataset},
        # Pass the imported BC class (behavior cloning).
        algorithm=BC,
        config={
            "env": "CartPole-v0",
            "framework": "tf",
            "evaluation_num_workers": 1,
            "evaluation_interval": 1,
            "evaluation_config": {"input": "sampler"},
        },
    )
    result = trainer.fit()
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
- training_loop() -> None
Loop called by fit() to run training and report results to Tune.
Note
This method runs on a remote process.
self.datasets have already been preprocessed by self.preprocessor.
You can use the Tune Function API functions (session.report() and session.get_checkpoint()) inside this training loop.
Example:
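
    from ray.air import session
    # BaseTrainer is imported from the module named in "Bases" above.
    from ray.train.base_trainer import BaseTrainer

    class MyTrainer(BaseTrainer):
        def training_loop(self):
            for epoch_idx in range(5):
                ...
                # Report per-epoch metrics to Tune.
                session.report({"epoch": epoch_idx})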
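The relationship between fit() and training_loop() can be illustrated with a small Ray-free sketch. SketchTrainer and its report() method are hypothetical stand-ins for BaseTrainer and session.report(); the real trainer runs the loop on a remote process and hands results to Tune, which this toy version does not attempt to reproduce.

```python
from typing import Any, Dict, List


class SketchTrainer:
    """Hypothetical miniature of the fit() / training_loop() contract."""

    def __init__(self) -> None:
        self.reported: List[Dict[str, Any]] = []

    def report(self, metrics: Dict[str, Any]) -> None:
        # Stand-in for session.report(): just collect the metrics locally.
        self.reported.append(dict(metrics))

    def training_loop(self) -> None:
        # Subclasses override this with their training logic.
        raise NotImplementedError

    def fit(self) -> List[Dict[str, Any]]:
        # fit() drives the subclass-provided loop and returns what it reported.
        self.training_loop()
        return self.reported


class MyTrainer(SketchTrainer):
    def training_loop(self) -> None:
        for epoch_idx in range(5):
            self.report({"epoch": epoch_idx})


results = MyTrainer().fit()
```

The subclass only supplies the loop body, while the base class decides when and where it runs.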
- as_trainable() -> Type[ray.tune.trainable.trainable.Trainable]
Convert self to a tune.Trainable class.