# Source code for ray.rllib.algorithms.bc.bc
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.bc.bc_catalog import BCCatalog
from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import ResultDict, RLModuleSpecType
# [docs]  (Sphinx viewcode cross-reference marker from the HTML export)
class BCConfig(MARWILConfig):
    """Configuration class from which a new BC Algorithm can be built.

    Behavioral Cloning reuses the MARWIL machinery with `beta` pinned to 0.0,
    which reduces the MARWIL loss to plain imitation of the offline actions.

    .. testcode::
        :skipif: True

        from ray.rllib.algorithms.bc import BCConfig
        # Run this from the ray directory root.
        config = BCConfig().training(lr=0.00001, gamma=0.99)
        config = config.offline_data(
            input_="./rllib/tests/data/cartpole/large.json")

        # Build an Algorithm object from the config and run 1 training iteration.
        algo = config.build()
        algo.train()

    .. testcode::
        :skipif: True

        from ray.rllib.algorithms.bc import BCConfig
        from ray import tune

        config = BCConfig()
        # Print out some default values.
        print(config.beta)
        # Update the config object.
        config.training(
            lr=tune.grid_search([0.001, 0.0001]), beta=0.75
        )
        # Set the config object's data path.
        # Run this from the ray directory root.
        config.offline_data(
            input_="./rllib/tests/data/cartpole/large.json"
        )
        # Set the config object's env, used for evaluation.
        config.environment(env="CartPole-v1")
        # Use to_dict() to get the old-style python config dict
        # when running with tune.
        tune.Tuner(
            "BC",
            param_space=config.to_dict(),
        ).fit()
    """

    def __init__(self, algo_class=None):
        """Initializes a BCConfig instance.

        Args:
            algo_class: The Algorithm class to build; defaults to `BC`.
        """
        super().__init__(algo_class=algo_class or BC)

        # fmt: off
        # __sphinx_doc_begin__
        # BC never weights by advantages, so `beta` stays at 0.0 (also
        # enforced in `validate()` below).
        self.beta = 0.0
        # Postprocessing (advantage calculation) is irrelevant for pure
        # behavioral cloning, so skip it.
        self.postprocess_inputs = False
        # Only materialize the mapped data; this is optimal as long as no
        # connector in the connector pipeline holds state.
        self.materialize_data = False
        self.materialize_mapped_data = True
        # __sphinx_doc_end__
        # fmt: on

    @override(AlgorithmConfig)
    def get_default_rl_module_spec(self) -> RLModuleSpecType:
        """Returns the default RLModule spec (torch-only for BC).

        Raises:
            ValueError: If `self.framework_str` is anything other than "torch".
        """
        if self.framework_str != "torch":
            raise ValueError(
                f"The framework {self.framework_str} is not supported. "
                "Use `torch` instead."
            )

        from ray.rllib.algorithms.bc.torch.bc_torch_rl_module import BCTorchRLModule

        return RLModuleSpec(
            module_class=BCTorchRLModule,
            catalog_class=BCCatalog,
        )

    @override(AlgorithmConfig)
    def build_learner_connector(
        self,
        input_observation_space,
        input_action_space,
        device=None,
    ):
        """Builds the learner connector pipeline, stripping MARWIL-only pieces."""
        learner_pipeline = super().build_learner_connector(
            input_observation_space=input_observation_space,
            input_action_space=input_action_space,
            device=device,
        )

        # These MARWIL connectors exist solely for advantage estimation,
        # which BC does not need.
        for connector_name in (
            "AddOneTsToEpisodesAndTruncate",
            "GeneralAdvantageEstimation",
        ):
            learner_pipeline.remove(connector_name)

        return learner_pipeline

    @override(MARWILConfig)
    def validate(self) -> None:
        """Validates the config; BC additionally requires `beta` == 0.0."""
        # Run all of MARWIL's checks first.
        super().validate()

        if self.beta != 0.0:
            raise ValueError("For behavioral cloning, `beta` parameter must be 0.0!")
class BC(MARWIL):
    """Behavioral Cloning (derived from MARWIL).

    Simply uses MARWIL with beta force-set to 0.0, which removes the
    advantage weighting and leaves pure imitation of the offline actions.
    """

    @classmethod
    @override(MARWIL)
    def get_default_config(cls) -> AlgorithmConfig:
        # A fresh BCConfig already carries beta=0.0 and the other
        # BC-specific defaults.
        return BCConfig()

    @override(MARWIL)
    def training_step(self) -> ResultDict:
        # Delegate entirely to MARWIL's training step; with beta=0.0 the
        # MARWIL update is exactly behavioral cloning.
        return super().training_step()