Source code for ray.rllib.algorithms.bc.bc

from typing import Type, Union
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.bc.bc_catalog import BCCatalog
from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig
from ray.rllib.core.learner import Learner
from ray.rllib.core.learner.learner_group_config import ModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.utils.annotations import override, ExperimentalAPI
from ray.rllib.utils.metrics import (
    ALL_MODULES,
    NUM_AGENT_STEPS_SAMPLED,
    NUM_ENV_STEPS_SAMPLED,
    SAMPLE_TIMER,
    SYNCH_WORKER_WEIGHTS_TIMER,
)
from ray.rllib.utils.typing import ResultDict


class BCConfig(MARWILConfig):
    """Defines a configuration class from which a new BC Algorithm can be built.

    Example:
        >>> from ray.rllib.algorithms.bc import BCConfig
        >>> # Run this from the ray directory root.
        >>> config = BCConfig().training(lr=0.00001, gamma=0.99)
        >>> config = config.offline_data(  # doctest: +SKIP
        ...     input_="./rllib/tests/data/cartpole/large.json")
        >>> print(config.to_dict())  # doctest: +SKIP
        >>> # Build an Algorithm object from the config and run 1 training iteration.
        >>> algo = config.build()  # doctest: +SKIP
        >>> algo.train()  # doctest: +SKIP

    Example:
        >>> from ray.rllib.algorithms.bc import BCConfig
        >>> from ray import tune
        >>> config = BCConfig()
        >>> # Print out some default values.
        >>> print(config.beta)  # doctest: +SKIP
        >>> # Update the config object.
        >>> config.training(  # doctest: +SKIP
        ...     lr=tune.grid_search([0.001, 0.0001]), beta=0.75
        ... )
        >>> # Set the config object's data path.
        >>> # Run this from the ray directory root.
        >>> config.offline_data(  # doctest: +SKIP
        ...     input_="./rllib/tests/data/cartpole/large.json"
        ... )
        >>> # Set the config object's env, used for evaluation.
        >>> config.environment(env="CartPole-v1")  # doctest: +SKIP
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.Tuner(  # doctest: +SKIP
        ...     "BC",
        ...     param_space=config.to_dict(),
        ... ).fit()
    """

    def __init__(self, algo_class=None):
        super().__init__(algo_class=algo_class or BC)

        # fmt: off
        # __sphinx_doc_begin__
        # No need to calculate advantages (or do anything else with the rewards).
        self.beta = 0.0
        # Advantages (calculated during postprocessing) are
        # not important for behavioral cloning.
        self.postprocess_inputs = False
        # Set RLModule as default.
        self.rl_module(_enable_rl_module_api=True)
        self.training(_enable_learner_api=True)
        # __sphinx_doc_end__
        # fmt: on

    @override(AlgorithmConfig)
    def get_default_rl_module_spec(self) -> ModuleSpec:
        if self.framework_str == "torch":
            from ray.rllib.algorithms.bc.torch.bc_torch_rl_module import (
                BCTorchRLModule,
            )

            return SingleAgentRLModuleSpec(
                module_class=BCTorchRLModule,
                catalog_class=BCCatalog,
            )
        elif self.framework_str == "tf2":
            from ray.rllib.algorithms.bc.tf.bc_tf_rl_module import BCTfRLModule

            return SingleAgentRLModuleSpec(
                module_class=BCTfRLModule,
                catalog_class=BCCatalog,
            )
        else:
            raise ValueError(
                f"The framework {self.framework_str} is not supported. "
                "Use either 'torch' or 'tf2'."
            )

    @override(AlgorithmConfig)
    def get_default_learner_class(self) -> Union[Type[Learner], str]:
        if self.framework_str == "torch":
            from ray.rllib.algorithms.bc.torch.bc_torch_learner import BCTorchLearner

            return BCTorchLearner
        elif self.framework_str == "tf2":
            from ray.rllib.algorithms.bc.tf.bc_tf_learner import BCTfLearner

            return BCTfLearner
        else:
            raise ValueError(
                f"The framework {self.framework_str} is not supported. "
                "Use either 'torch' or 'tf2'."
            )

    @override(MARWILConfig)
    def validate(self) -> None:
        # Cannot use the old static-graph `tf` framework with the Learner API.
        if self.framework_str == "tf":
            self.rl_module(_enable_rl_module_api=False)
            self.training(_enable_learner_api=False)

        # Call super's validation method.
        super().validate()

        if self.beta != 0.0:
            raise ValueError("For behavioral cloning, `beta` parameter must be 0.0!")
class BC(MARWIL):
    """Behavioral Cloning (derived from MARWIL).

    Simply uses MARWIL with beta force-set to 0.0.
    """

    @classmethod
    @override(MARWIL)
    def get_default_config(cls) -> AlgorithmConfig:
        return BCConfig()

    @ExperimentalAPI
    def training_step(self) -> ResultDict:
        if not self.config["_enable_rl_module_api"]:
            # Using ModelV2.
            return super().training_step()
        else:
            # Implement logic using RLModule and Learner API.
            # TODO (sven): Remove RolloutWorkers/EnvRunners for
            #  datasets. Use RolloutWorker/EnvRunner only for
            #  env stepping.
            # TODO (simon): Take care of sampler metrics: right
            #  now all rewards are `nan`, which may make users think
            #  something is wrong, although it is expected because
            #  we do not step the env.
            with self._timers[SAMPLE_TIMER]:
                # Sampling from offline data.
                # TODO (simon): We have to remove the `RolloutWorker`
                #  here and just use the already distributed `dataset`
                #  for sampling. Only in online evaluation should
                #  `RolloutWorker/EnvRunner` be used.
                if self.config.count_steps_by == "agent_steps":
                    train_batch = synchronous_parallel_sample(
                        worker_set=self.workers,
                        max_agent_steps=self.config.train_batch_size,
                    )
                else:
                    train_batch = synchronous_parallel_sample(
                        worker_set=self.workers,
                        max_env_steps=self.config.train_batch_size,
                    )

                train_batch = train_batch.as_multi_agent()
                self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
                self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()

            # Updating the policy.
            is_module_trainable = self.workers.local_worker().is_policy_to_train
            self.learner_group.set_is_module_trainable(is_module_trainable)
            train_results = self.learner_group.update(train_batch)

            # Synchronize weights.
            # The train results contain the loss per policy as well as the total
            # loss over all policies; the total-loss key must be excluded from the
            # set of policies to update.
            policies_to_update = set(train_results.keys()) - {ALL_MODULES}

            global_vars = {
                "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
            }

            with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                if self.workers.num_remote_workers() > 0:
                    self.workers.sync_weights(
                        from_worker_or_learner_group=self.learner_group,
                        policies=policies_to_update,
                        global_vars=global_vars,
                    )
                # Get weights from Learner to local worker.
                else:
                    self.workers.local_worker().set_weights(
                        self.learner_group.get_weights()
                    )

            return train_results
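
For reference, a minimal end-to-end usage sketch assembled from the docstring examples above. The offline data path, learning rate, and evaluation env are placeholder values taken from those examples, not requirements of this module; point `input_` at your own offline dataset.

    # Build a BC algorithm from offline data and run one training iteration.
    from ray.rllib.algorithms.bc import BCConfig

    config = (
        BCConfig()
        .training(lr=0.00001, gamma=0.99)
        # Placeholder dataset path; run this from the ray directory root.
        .offline_data(input_="./rllib/tests/data/cartpole/large.json")
        # The env is only needed for (optional) evaluation.
        .environment(env="CartPole-v1")
    )

    algo = config.build()
    print(algo.train())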