ray.train.rl.RLTrainer#

class ray.train.rl.RLTrainer(*args, **kwargs)[source]#

Bases: ray.train.base_trainer.BaseTrainer

Reinforcement learning trainer.

This trainer provides an interface to RLlib trainables.

If datasets and preprocessors are provided, they are used for offline training, e.g. with behavior cloning. Otherwise, this trainer performs online training.

Parameters
  • algorithm – Algorithm to train on. Can be a string reference (e.g. "PPO") or an RLlib trainer class.

  • scaling_config – Configuration for how to scale training.

  • run_config – Configuration for the execution of the training run.

  • datasets – Any Ray Datasets to use for training. Use the key “train” to denote which dataset is the training dataset. If a preprocessor is provided and has not already been fit, it will be fit on the training dataset. All datasets will be transformed by the preprocessor if one is provided. If specified, datasets will be used for offline training. Will be configured as an RLlib input config item.

  • preprocessor – A preprocessor to preprocess the provided datasets.

  • resume_from_checkpoint – A checkpoint to resume training from.

Example

Online training:

from ray.air.config import RunConfig, ScalingConfig
from ray.train.rl import RLTrainer

trainer = RLTrainer(
    run_config=RunConfig(stop={"training_iteration": 5}),
    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
    algorithm="PPO",
    config={
        "env": "CartPole-v0",
        "framework": "tf",
        "evaluation_num_workers": 1,
        "evaluation_interval": 1,
        "evaluation_config": {"input": "sampler"},
    },
)
result = trainer.fit()
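
The returned result object exposes the outcome of the run, for example (a brief sketch; attribute names follow the Ray AIR Result API):

# Inspect the run outcome after fit() returns.
print(result.metrics)           # final reported training metrics
checkpoint = result.checkpoint  # last checkpoint saved during training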

Example

Offline training (assumes data is stored in /tmp/data-dir):

import ray
from ray.air.config import RunConfig, ScalingConfig
from ray.train.rl import RLTrainer
from ray.rllib.algorithms.bc.bc import BC

dataset = ray.data.read_json(
    "/tmp/data-dir", parallelism=2, ray_remote_args={"num_cpus": 1}
)

trainer = RLTrainer(
    run_config=RunConfig(stop={"training_iteration": 5}),
    scaling_config=ScalingConfig(
        num_workers=2,
        use_gpu=False,
    ),
    datasets={"train": dataset},
    algorithm=BC,
    config={
        "env": "CartPole-v0",
        "framework": "tf",
        "evaluation_num_workers": 1,
        "evaluation_interval": 1,
        "evaluation_config": {"input": "sampler"},
    },
)
result = trainer.fit()

PublicAPI (alpha): This API is in alpha and may change before becoming stable.

training_loop() → None[source]#

Loop called by fit() to run training and report results to Tune.

Note

This method runs on a remote process.

self.datasets have already been preprocessed by self.preprocessor.

You can use the Tune Function API (session.report() and session.get_checkpoint()) inside this training loop.

Example:

from ray.train.trainer import BaseTrainer
from ray.air import session

class MyTrainer(BaseTrainer):
    def training_loop(self):
        for epoch_idx in range(5):
            ...
            session.report({"epoch": epoch_idx})
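
A custom trainer defined this way runs like any other trainer (a minimal sketch; MyTrainer is the hypothetical class above):

# fit() executes training_loop() on a remote process and returns a Result.
trainer = MyTrainer()
result = trainer.fit()
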
as_trainable() → Type[ray.tune.trainable.trainable.Trainable][source]#

Convert self to a tune.Trainable class.
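
This makes it possible to hand the trainer to Ray Tune directly, for example (a minimal sketch; assumes the Tuner API from ray.tune):

from ray import tune

# Convert the trainer into a Tune Trainable and run it with a Tuner.
trainable_cls = trainer.as_trainable()
tuner = tune.Tuner(trainable_cls)
results = tuner.fit()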