import inspect
import os
from typing import Optional, Dict, Type, Union, Any, TYPE_CHECKING
import ray.cloudpickle as cpickle
from ray.air.checkpoint import Checkpoint
from ray.air.config import ScalingConfig, RunConfig
from ray.train.trainer import BaseTrainer, GenDataset
from ray.air._internal.checkpointing import (
save_preprocessor_to_dir,
)
from ray.rllib.algorithms.algorithm import Algorithm as RLlibAlgo
from ray.rllib.utils.typing import PartialAlgorithmConfigDict
from ray.tune import Trainable, PlacementGroupFactory
from ray.tune.registry import get_trainable_cls
from ray.tune.resources import Resources
from ray.util.annotations import PublicAPI
from ray.train.rl.rl_checkpoint import RL_TRAINER_CLASS_FILE, RL_CONFIG_FILE
from ray._private.dict import merge_dicts
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@PublicAPI(stability="alpha")
class RLTrainer(BaseTrainer):
    """Reinforcement learning trainer.

    This trainer provides an interface to RLlib trainables.

    If datasets and preprocessors are used, they can be utilized for
    offline training, e.g. using behavior cloning. Otherwise, this trainer
    will use online training.

    Args:
        algorithm: Algorithm to train on. Can be a string reference,
            (e.g. ``"PPO"``) or a RLlib trainer class.
        config: RLlib configuration dict. Used as the base configuration
            that is merged with any config passed to the trainable returned
            by ``as_trainable()``. Defaults to an empty dict.
        scaling_config: Configuration for how to scale training.
        run_config: Configuration for the execution of the training run.
        datasets: Any Datasets to use for training. Use the key "train"
            to denote which dataset is the training
            dataset. If a ``preprocessor`` is provided and has not already been fit,
            it will be fit on the training dataset. All datasets will be transformed
            by the ``preprocessor`` if one is provided.
            If specified, datasets will be used for offline training. Will be
            configured as an RLlib ``input`` config item.
        preprocessor: A preprocessor to preprocess the provided datasets.
        resume_from_checkpoint: A checkpoint to resume training from.

    Example:
        Online training:

        .. code-block:: python

            from ray.air.config import RunConfig, ScalingConfig
            from ray.train.rl import RLTrainer

            trainer = RLTrainer(
                run_config=RunConfig(stop={"training_iteration": 5}),
                scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
                algorithm="PPO",
                config={
                    "env": "CartPole-v0",
                    "framework": "tf",
                    "evaluation_num_workers": 1,
                    "evaluation_interval": 1,
                    "evaluation_config": {"input": "sampler"},
                },
            )
            result = trainer.fit()

    Example:
        Offline training (assumes data is stored in ``/tmp/data-dir``):

        .. code-block:: python

            import ray
            from ray.air.config import RunConfig, ScalingConfig
            from ray.train.rl import RLTrainer
            from ray.rllib.algorithms.bc.bc import BC

            dataset = ray.data.read_json(
                "/tmp/data-dir", parallelism=2, ray_remote_args={"num_cpus": 1}
            )

            trainer = RLTrainer(
                run_config=RunConfig(stop={"training_iteration": 5}),
                scaling_config=ScalingConfig(
                    num_workers=2,
                    use_gpu=False,
                ),
                datasets={"train": dataset},
                algorithm=BC,
                config={
                    "env": "CartPole-v0",
                    "framework": "tf",
                    "evaluation_num_workers": 1,
                    "evaluation_interval": 1,
                    "evaluation_config": {"input": "sampler"},
                },
            )
            result = trainer.fit()

    """

    def __init__(
        self,
        algorithm: Union[str, Type[RLlibAlgo]],
        config: Optional[Dict[str, Any]] = None,
        scaling_config: Optional[ScalingConfig] = None,
        run_config: Optional[RunConfig] = None,
        datasets: Optional[Dict[str, GenDataset]] = None,
        preprocessor: Optional["Preprocessor"] = None,
        resume_from_checkpoint: Optional[Checkpoint] = None,
    ):
        # Both attributes are assigned before delegating to the base
        # constructor so that they already exist while it runs.
        self._algorithm = algorithm
        self._config = config if config is not None else {}

        super(RLTrainer, self).__init__(
            scaling_config=scaling_config,
            run_config=run_config,
            datasets=datasets,
            preprocessor=preprocessor,
            resume_from_checkpoint=resume_from_checkpoint,
        )

    def _validate_attributes(self):
        """Validate ``algorithm`` and ``config`` on top of the base checks.

        Raises:
            ValueError: If ``algorithm`` is neither a string nor a subclass
                of the RLlib ``Algorithm`` class, or if ``config`` is not a
                dict.
        """
        super(RLTrainer, self)._validate_attributes()

        if not isinstance(self._algorithm, str) and not (
            inspect.isclass(self._algorithm) and issubclass(self._algorithm, RLlibAlgo)
        ):
            raise ValueError(
                f"`algorithm` should be either a string or a RLlib trainer class, "
                f"found {type(self._algorithm)} with value `{self._algorithm}`."
            )

        if not isinstance(self._config, dict):
            raise ValueError(
                f"`config` should be either a dict, "
                f"found {type(self._config)} with value `{self._config}`."
            )

    def _get_rllib_config(self, process_datasets: bool = False) -> Dict:
        """Build the RLlib config dict from the user config and scaling config.

        Args:
            process_datasets: If True, ``self.preprocess_datasets()`` is
                called before the dataset input is configured.

        Returns:
            A shallow copy of the user-provided config with the worker,
            GPU, driver and (if datasets are set) input settings applied.
        """
        # Shallow copy: only top-level keys are (re)assigned below, so the
        # user-provided dict itself is left untouched.
        config = self._config.copy()

        num_workers = self.scaling_config.num_workers
        if num_workers is not None:
            config["num_workers"] = num_workers

        worker_resources = self.scaling_config.resources_per_worker
        if worker_resources:
            # Whatever remains after removing CPU/GPU is passed on as custom
            # resources.
            res = worker_resources.copy()
            config["num_cpus_per_worker"] = res.pop("CPU", 1)
            config["num_gpus_per_worker"] = res.pop("GPU", 0)
            config["custom_resources_per_worker"] = res

        use_gpu = self.scaling_config.use_gpu
        if use_gpu:
            config["num_gpus"] = 1

        trainer_resources = self.scaling_config.trainer_resources
        if trainer_resources:
            config["num_cpus_for_driver"] = trainer_resources.get("CPU", 1)

        if process_datasets:
            self.preprocess_datasets()

        # Up for discussion: If datasets is passed, should we always
        # set the input config? Is the sampler config required here, too?
        if self.datasets:
            config["input"] = "dataset"
            config["input_config"] = {
                # Evaluated lazily: the "train" dataset is looked up only
                # when the loader function is actually called.
                "loader_fn": lambda: self.datasets["train"],
            }

        return config

    def training_loop(self) -> None:
        """No-op.

        ``as_trainable()`` is overridden to return a subclass of the RLlib
        algorithm class, so no training logic is implemented here.
        """
        pass

    def as_trainable(self) -> Type[Trainable]:
        """Return a subclass of the RLlib algorithm wired to this trainer.

        The returned class derives from the configured RLlib algorithm
        (resolved through ``get_trainable_cls`` when given as a string). It
        computes its RLlib config via ``_get_rllib_config`` and additionally
        writes the trainer class, the config and the preprocessor (if any)
        into every checkpoint directory.

        Returns:
            The generated class, named ``AIR<name of the RLlib class>``.
        """
        param_dict = self._param_dict
        base_config = self._config or {}
        trainer_cls = self.__class__
        preprocessor = self.preprocessor

        if isinstance(self._algorithm, str):
            rllib_trainer = get_trainable_cls(self._algorithm)
        else:
            rllib_trainer = self._algorithm

        class AIRRLTrainer(rllib_trainer):
            def __init__(
                self,
                config: Optional[PartialAlgorithmConfigDict] = None,
                **kwargs,
            ):
                resolved_config = merge_dicts(base_config, config or {})
                # Build fresh constructor kwargs rather than writing into the
                # shared ``param_dict``; mutating it would leak the resolved
                # config back into the original trainer's stored parameters.
                trainer_kwargs = {**param_dict, "config": resolved_config}

                trainer = trainer_cls(**trainer_kwargs)
                rllib_config = trainer._get_rllib_config(process_datasets=True)

                super(AIRRLTrainer, self).__init__(config=rllib_config, **kwargs)

            def save_checkpoint(self, checkpoint_dir: str):
                checkpoint_path = super(AIRRLTrainer, self).save_checkpoint(
                    checkpoint_dir
                )

                # Store the trainer class and its config next to the RLlib
                # checkpoint data.
                trainer_class_path = os.path.join(checkpoint_dir, RL_TRAINER_CLASS_FILE)
                with open(trainer_class_path, "wb") as fp:
                    cpickle.dump(self.__class__, fp)

                config_path = os.path.join(checkpoint_dir, RL_CONFIG_FILE)
                with open(config_path, "wb") as fp:
                    cpickle.dump(self.config, fp)

                if preprocessor:
                    save_preprocessor_to_dir(preprocessor, checkpoint_dir)

                return checkpoint_path

            @classmethod
            def default_resource_request(
                cls, config: PartialAlgorithmConfigDict
            ) -> Union[Resources, PlacementGroupFactory]:
                # ``config or {}`` mirrors ``__init__`` so that a missing
                # config is handled the same way in both places.
                resolved_config = merge_dicts(base_config, config or {})
                # Same as in ``__init__``: never mutate the shared dict.
                trainer_kwargs = {**param_dict, "config": resolved_config}

                trainer = trainer_cls(**trainer_kwargs)
                # Datasets are not preprocessed just to compute resources.
                rllib_config = trainer._get_rllib_config(process_datasets=False)

                return rllib_trainer.default_resource_request(rllib_config)

        AIRRLTrainer.__name__ = f"AIR{rllib_trainer.__name__}"
        return AIRRLTrainer