from typing import Type, Union
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.bc.bc_catalog import BCCatalog
from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig
from ray.rllib.core.learner import Learner
from ray.rllib.core.learner.learner_group_config import ModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.utils.annotations import override, ExperimentalAPI
from ray.rllib.utils.metrics import (
ALL_MODULES,
NUM_AGENT_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED,
SAMPLE_TIMER,
SYNCH_WORKER_WEIGHTS_TIMER,
)
from ray.rllib.utils.typing import ResultDict
class BCConfig(MARWILConfig):
"""Defines a configuration class from which a new BC Algorithm can be built
Example:
>>> from ray.rllib.algorithms.bc import BCConfig
>>> # Run this from the ray directory root.
>>> config = BCConfig().training(lr=0.00001, gamma=0.99)
>>> config = config.offline_data( # doctest: +SKIP
... input_="./rllib/tests/data/cartpole/large.json")
>>> print(config.to_dict()) # doctest:+SKIP
>>> # Build an Algorithm object from the config and run 1 training iteration.
>>> algo = config.build() # doctest: +SKIP
>>> algo.train() # doctest: +SKIP
Example:
>>> from ray.rllib.algorithms.bc import BCConfig
>>> from ray import tune
>>> config = BCConfig()
>>> # Print out some default values.
>>> print(config.beta) # doctest: +SKIP
>>> # Update the config object.
>>> config.training( # doctest:+SKIP
... lr=tune.grid_search([0.001, 0.0001]), beta=0.75
... )
>>> # Set the config object's data path.
>>> # Run this from the ray directory root.
>>> config.offline_data( # doctest:+SKIP
... input_="./rllib/tests/data/cartpole/large.json"
... )
>>> # Set the config object's env, used for evaluation.
>>> config.environment(env="CartPole-v1") # doctest:+SKIP
>>> # Use to_dict() to get the old-style python config dict
>>> # when running with tune.
>>> tune.Tuner( # doctest:+SKIP
... "BC",
... param_space=config.to_dict(),
... ).fit()
"""
def __init__(self, algo_class=None):
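        """Initializes a BCConfig instance."""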
super().__init__(algo_class=algo_class or BC)
# fmt: off
# __sphinx_doc_begin__
# No need to calculate advantages (or do anything else with the rewards).
self.beta = 0.0
        # Advantages (calculated during postprocessing) are not important
        # for behavioral cloning.
self.postprocess_inputs = False
# Set RLModule as default.
self.rl_module(_enable_rl_module_api=True)
self.training(_enable_learner_api=True)
# __sphinx_doc_end__
# fmt: on
@override(AlgorithmConfig)
def get_default_rl_module_spec(self) -> ModuleSpec:
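        """Returns the default framework-specific RLModule spec for BC.

        Torch and TF2 are supported (`BCTorchRLModule` / `BCTfRLModule`, both
        paired with `BCCatalog`); any other framework raises a ValueError.
        """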
if self.framework_str == "torch":
from ray.rllib.algorithms.bc.torch.bc_torch_rl_module import BCTorchRLModule
return SingleAgentRLModuleSpec(
module_class=BCTorchRLModule,
catalog_class=BCCatalog,
)
elif self.framework_str == "tf2":
from ray.rllib.algorithms.bc.tf.bc_tf_rl_module import BCTfRLModule
return SingleAgentRLModuleSpec(
module_class=BCTfRLModule,
catalog_class=BCCatalog,
)
else:
raise ValueError(
f"The framework {self.framework_str} is not supported. "
"Use either 'torch' or 'tf2'."
)
@override(AlgorithmConfig)
def get_default_learner_class(self) -> Union[Type[Learner], str]:
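        """Returns the default Learner class for BC.

        This is `BCTorchLearner` for the "torch" framework and `BCTfLearner`
        for "tf2"; any other framework raises a ValueError.
        """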
if self.framework_str == "torch":
from ray.rllib.algorithms.bc.torch.bc_torch_learner import BCTorchLearner
return BCTorchLearner
elif self.framework_str == "tf2":
from ray.rllib.algorithms.bc.tf.bc_tf_learner import BCTfLearner
return BCTfLearner
else:
raise ValueError(
f"The framework {self.framework_str} is not supported. "
"Use either 'torch' or 'tf2'."
)
@override(MARWILConfig)
def validate(self) -> None:
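        """Validates this config.

        Falls back to the ModelV2 stack for the old "tf" framework and enforces
        that `beta` remains 0.0 for behavioral cloning.
        """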
        # The old "tf" (static-graph) framework does not support the new
        # RLModule / Learner API stack, so fall back to ModelV2.
if self.framework_str == "tf":
self.rl_module(_enable_rl_module_api=False)
self.training(_enable_learner_api=False)
# Call super's validation method.
super().validate()
if self.beta != 0.0:
raise ValueError("For behavioral cloning, `beta` parameter must be 0.0!")
class BC(MARWIL):
"""Behavioral Cloning (derived from MARWIL).
Simply uses MARWIL with beta force-set to 0.0.
"""
@classmethod
@override(MARWIL)
def get_default_config(cls) -> AlgorithmConfig:
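        """Returns a fresh BCConfig as this Algorithm's default config."""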
return BCConfig()
@ExperimentalAPI
def training_step(self) -> ResultDict:
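        """Runs one iteration of training.

        With the RLModule/Learner API enabled, this samples a batch from the
        offline data via the workers, updates the module(s) through the
        LearnerGroup, and synchronizes the updated weights back to the workers.
        Otherwise, it falls back to MARWIL's ModelV2-based `training_step()`.
        """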
if not self.config["_enable_rl_module_api"]:
# Using ModelV2.
return super().training_step()
else:
# Implement logic using RLModule and Learner API.
# TODO (sven): Remove RolloutWorkers/EnvRunners for
# datasets. Use RolloutWorker/EnvRunner only for
# env stepping.
            # TODO (simon): Take care of sampler metrics: right now all
            # rewards are `nan`, which may make users think something is
            # wrong, even though this is expected because we do not step
            # the env.
with self._timers[SAMPLE_TIMER]:
# Sampling from offline data.
# TODO (simon): We have to remove the `RolloutWorker`
# here and just use the already distributed `dataset`
# for sampling. Only in online evaluation
# `RolloutWorker/EnvRunner` should be used.
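                # The train batch size is counted in either agent steps or
                # env steps, depending on `config.count_steps_by`.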
if self.config.count_steps_by == "agent_steps":
train_batch = synchronous_parallel_sample(
worker_set=self.workers,
max_agent_steps=self.config.train_batch_size,
)
else:
train_batch = synchronous_parallel_sample(
worker_set=self.workers,
max_env_steps=self.config.train_batch_size,
)
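            # Convert the sampled batch into a MultiAgentBatch so the
            # LearnerGroup update below can consume it.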
train_batch = train_batch.as_multi_agent()
self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
            # Update the module(s) via the LearnerGroup, restricting the update
            # to those modules the local worker considers trainable.
            is_module_trainable = self.workers.local_worker().is_policy_to_train
            self.learner_group.set_is_module_trainable(is_module_trainable)
train_results = self.learner_group.update(train_batch)
# Synchronize weights.
            # The results contain the loss for each module plus the total loss
            # over all modules (under the `ALL_MODULES` key); exclude the latter
            # when determining which modules' weights to synchronize.
policies_to_update = set(train_results.keys()) - {ALL_MODULES}
global_vars = {
"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
}
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
if self.workers.num_remote_workers() > 0:
self.workers.sync_weights(
from_worker_or_learner_group=self.learner_group,
policies=policies_to_update,
global_vars=global_vars,
)
                else:
                    # No remote workers: copy the weights from the LearnerGroup
                    # directly to the local worker.
self.workers.local_worker().set_weights(
self.learner_group.get_weights()
)
return train_results