from typing import Callable, Optional, Type, Union
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.marwil.marwil_catalog import MARWILCatalog
from ray.rllib.connectors.learner import (
AddObservationsFromEpisodesToBatch,
AddOneTsToEpisodesAndTruncate,
AddNextObservationsFromEpisodesToTrainBatch,
GeneralAdvantageEstimation,
)
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.execution.rollout_ops import (
synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
multi_gpu_train_one_step,
train_one_step,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.metrics import (
ALL_MODULES,
LEARNER_RESULTS,
LEARNER_UPDATE_TIMER,
NUM_AGENT_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED,
NUM_ENV_STEPS_TRAINED,
NUM_ENV_STEPS_TRAINED_LIFETIME,
NUM_MODULE_STEPS_TRAINED,
NUM_MODULE_STEPS_TRAINED_LIFETIME,
OFFLINE_SAMPLING_TIMER,
SAMPLE_TIMER,
SYNCH_WORKER_WEIGHTS_TIMER,
TIMERS,
)
from ray.rllib.utils.typing import (
EnvType,
ResultDict,
RLModuleSpecType,
)
from ray.tune.logger import Logger
class MARWILConfig(AlgorithmConfig):
"""Defines a configuration class from which a MARWIL Algorithm can be built.
.. testcode::
from pathlib import Path
from ray.rllib.algorithms.marwil import MARWILConfig
# Get the base path (to ray/rllib)
base_path = Path(__file__).parents[2]
# Get the path to the data in rllib folder.
data_path = base_path / "tests/data/cartpole/cartpole-v1_large"
config = MARWILConfig()
# Enable the new API stack.
config.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
# Define the environment for which to learn a policy
# from offline data.
config.environment("CartPole-v1")
# Set the training parameters.
config.training(
beta=1.0,
lr=1e-5,
gamma=0.99,
# We must define a train batch size for each
# learner (here 1 local learner).
train_batch_size_per_learner=2000,
)
# Define the data source for offline data.
config.offline_data(
input_=[data_path.as_posix()],
# Run exactly one update per training iteration.
dataset_num_iters_per_learner=1,
)
# Build an `Algorithm` object from the config and run 1 training
# iteration.
algo = config.build()
algo.train()
.. testcode::
from pathlib import Path
from ray.rllib.algorithms.marwil import MARWILConfig
from ray import train, tune
# Get the base path (to ray/rllib)
base_path = Path(__file__).parents[2]
# Get the path to the data in rllib folder.
data_path = base_path / "tests/data/cartpole/cartpole-v1_large"
config = MARWILConfig()
# Enable the new API stack.
config.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
# Print out some default values
print(f"beta: {config.beta}")
# Update the config object.
config.training(
lr=tune.grid_search([1e-3, 1e-4]),
beta=0.75,
# We must define a train batch size for each
# learner (here 1 local learner).
train_batch_size_per_learner=2000,
)
# Set the config's data path.
config.offline_data(
input_=[data_path.as_posix()],
# Set the number of updates to be run per learner
# per training step.
dataset_num_iters_per_learner=1,
)
        # Set the config's environment for evaluation.
config.environment(env="CartPole-v1")
# Set up a tuner to run the experiment.
tuner = tune.Tuner(
"MARWIL",
param_space=config,
run_config=train.RunConfig(
stop={"training_iteration": 1},
),
)
# Run the experiment.
tuner.fit()
"""
def __init__(self, algo_class=None):
"""Initializes a MARWILConfig instance."""
super().__init__(algo_class=algo_class or MARWIL)
# fmt: off
# __sphinx_doc_begin__
# MARWIL specific settings:
self.beta = 1.0
self.bc_logstd_coeff = 0.0
self.moving_average_sqd_adv_norm_update_rate = 1e-8
self.moving_average_sqd_adv_norm_start = 100.0
self.vf_coeff = 1.0
self.model["vf_share_layers"] = False
self.grad_clip = None
# Override some of AlgorithmConfig's default values with MARWIL-specific values.
# You should override input_ to point to an offline dataset
# (see algorithm.py and algorithm_config.py).
# The dataset may have an arbitrary number of timesteps
# (and even episodes) per line.
# However, each line must only contain consecutive timesteps in
# order for MARWIL to be able to calculate accumulated
# discounted returns. It is ok, though, to have multiple episodes in
# the same line.
self.input_ = "sampler"
self.postprocess_inputs = True
self.lr = 1e-4
self.lambda_ = 1.0
self.train_batch_size = 2000
# TODO (Artur): MARWIL should not need an exploration config as an offline
# algorithm. However, the current implementation of the CRR algorithm
# requires it. Investigate.
self.exploration_config = {
# The Exploration class to use. In the simplest case, this is the name
# (str) of any class present in the `rllib.utils.exploration` package.
# You can also provide the python class directly or the full location
# of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy.
# EpsilonGreedy").
"type": "StochasticSampling",
# Add constructor kwargs here (if any).
}
        # Materialize only the data in raw format, but not the mapped data, because
        # MARWIL uses a connector to compute values and therefore the module
        # needs to be updated frequently. This updating would not work if we
        # mapped the data once at the beginning.
        # TODO (simon, sven): The module is only updated when the OfflinePreLearner
        # gets re-instantiated, i.e. when the iterator gets re-instantiated. This
        # happens frequently enough with a small dataset, but with a big one it does
        # not update often enough. We might need to push model weights (and maybe
        # also connector states) into the object store every couple of iterations.
self.materialize_data = True
self.materialize_mapped_data = False
# __sphinx_doc_end__
# fmt: on
self._set_off_policy_estimation_methods = False
@override(AlgorithmConfig)
def training(
self,
*,
beta: Optional[float] = NotProvided,
bc_logstd_coeff: Optional[float] = NotProvided,
moving_average_sqd_adv_norm_update_rate: Optional[float] = NotProvided,
moving_average_sqd_adv_norm_start: Optional[float] = NotProvided,
vf_coeff: Optional[float] = NotProvided,
grad_clip: Optional[float] = NotProvided,
**kwargs,
) -> "MARWILConfig":
"""Sets the training related configuration.
Args:
beta: Scaling of advantages in exponential terms. When beta is 0.0,
MARWIL is reduced to behavior cloning (imitation learning);
see bc.py algorithm in this same directory.
bc_logstd_coeff: A coefficient to encourage higher action distribution
entropy for exploration.
            moving_average_sqd_adv_norm_update_rate: The rate for updating the
                squared moving average advantage norm (c^2). A higher rate leads
                to faster updates of this moving average.
moving_average_sqd_adv_norm_start: Starting value for the
squared moving average advantage norm (c^2).
vf_coeff: Balancing value estimation loss and policy optimization loss.
grad_clip: If specified, clip the global norm of gradients by this amount.
Returns:
This updated AlgorithmConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
if beta is not NotProvided:
self.beta = beta
if bc_logstd_coeff is not NotProvided:
self.bc_logstd_coeff = bc_logstd_coeff
if moving_average_sqd_adv_norm_update_rate is not NotProvided:
self.moving_average_sqd_adv_norm_update_rate = (
moving_average_sqd_adv_norm_update_rate
)
if moving_average_sqd_adv_norm_start is not NotProvided:
self.moving_average_sqd_adv_norm_start = moving_average_sqd_adv_norm_start
if vf_coeff is not NotProvided:
self.vf_coeff = vf_coeff
if grad_clip is not NotProvided:
self.grad_clip = grad_clip
return self
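    # Illustrative usage sketch of the `training()` settings documented above.
    # Values are assumptions for demonstration only, not tuned defaults.
    #
    #     from ray.rllib.algorithms.marwil import MARWILConfig
    #
    #     config = MARWILConfig()
    #     # beta=0.0 would reduce MARWIL to plain behavior cloning.
    #     config.training(beta=0.5, vf_coeff=1.0, grad_clip=40.0)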
@override(AlgorithmConfig)
def get_default_rl_module_spec(self) -> RLModuleSpecType:
if self.framework_str == "torch":
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
PPOTorchRLModule,
)
return RLModuleSpec(
module_class=PPOTorchRLModule,
catalog_class=MARWILCatalog,
)
else:
raise ValueError(
f"The framework {self.framework_str} is not supported. "
"Use 'torch' instead."
)
@override(AlgorithmConfig)
def get_default_learner_class(self) -> Union[Type["Learner"], str]:
if self.framework_str == "torch":
from ray.rllib.algorithms.marwil.torch.marwil_torch_learner import (
MARWILTorchLearner,
)
return MARWILTorchLearner
else:
raise ValueError(
f"The framework {self.framework_str} is not supported. "
"Use 'torch' instead."
)
@override(AlgorithmConfig)
def evaluation(
self,
**kwargs,
) -> "MARWILConfig":
"""Sets the evaluation related configuration.
Returns:
This updated AlgorithmConfig object.
"""
# Pass kwargs onto super's `evaluation()` method.
super().evaluation(**kwargs)
if "off_policy_estimation_methods" in kwargs:
# User specified their OPE methods.
self._set_off_policy_estimation_methods = True
return self
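    # Illustrative sketch (assuming the estimator classes under
    # `ray.rllib.offline.estimators` are available in this RLlib version):
    # explicitly passing `off_policy_estimation_methods` to `.evaluation()` sets
    # `_set_off_policy_estimation_methods=True` and thus silences the deprecation
    # warning emitted in `build()` below.
    #
    #     from ray.rllib.offline.estimators import ImportanceSampling
    #
    #     config = MARWILConfig().evaluation(
    #         off_policy_estimation_methods={"is": {"type": ImportanceSampling}},
    #     )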
@override(AlgorithmConfig)
def offline_data(self, **kwargs) -> "MARWILConfig":
super().offline_data(**kwargs)
        # Check if the passed-in class implements the `OfflinePreLearner`
        # interface.
if "prelearner_class" in kwargs:
from ray.rllib.offline.offline_data import OfflinePreLearner
if not issubclass(kwargs.get("prelearner_class"), OfflinePreLearner):
raise ValueError(
f"`prelearner_class` {kwargs.get('prelearner_class')} is not a "
"subclass of `OfflinePreLearner`. Any class passed to "
"`prelearner_class` needs to implement the interface given by "
"`OfflinePreLearner`."
)
return self
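    # Illustrative sketch of the `prelearner_class` check above (`MyPreLearner` is
    # a hypothetical name): any custom class passed in must subclass the
    # `OfflinePreLearner` interface imported in the check.
    #
    #     class MyPreLearner(OfflinePreLearner):
    #         ...
    #
    #     config = MARWILConfig().offline_data(prelearner_class=MyPreLearner)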
@override(AlgorithmConfig)
def build(
self,
env: Optional[Union[str, EnvType]] = None,
logger_creator: Optional[Callable[[], Logger]] = None,
) -> "Algorithm":
if not self._set_off_policy_estimation_methods:
deprecation_warning(
old=r"MARWIL used to have off_policy_estimation_methods "
"is and wis by default. This has"
r"changed to off_policy_estimation_methods: \{\}."
"If you want to use an off-policy estimator, specify it in"
".evaluation(off_policy_estimation_methods=...)",
error=False,
)
return super().build(env, logger_creator)
@override(AlgorithmConfig)
def build_learner_connector(
self,
input_observation_space,
input_action_space,
device=None,
):
pipeline = super().build_learner_connector(
input_observation_space=input_observation_space,
input_action_space=input_action_space,
device=device,
)
# Before anything, add one ts to each episode (and record this in the loss
# mask, so that the computations at this extra ts are not used to compute
# the loss).
pipeline.prepend(AddOneTsToEpisodesAndTruncate())
        # Insert the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece
        # right after the corresponding "add-OBS-..." default piece.
pipeline.insert_after(
AddObservationsFromEpisodesToBatch,
AddNextObservationsFromEpisodesToTrainBatch(),
)
# At the end of the pipeline (when the batch is already completed), add the
# GAE connector, which performs a vf forward pass, then computes the GAE
# computations, and puts the results of this (advantages, value targets)
# directly back in the batch. This is then the batch used for
# `forward_train` and `compute_losses`.
pipeline.append(
GeneralAdvantageEstimation(gamma=self.gamma, lambda_=self.lambda_)
)
return pipeline
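    # For orientation (derived from the calls above), the resulting learner
    # connector pipeline roughly looks like:
    #   [AddOneTsToEpisodesAndTruncate] -> ... -> [AddObservationsFromEpisodesToBatch]
    #   -> [AddNextObservationsFromEpisodesToTrainBatch] -> ...
    #   -> [GeneralAdvantageEstimation]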
@override(AlgorithmConfig)
def validate(self) -> None:
# Call super's validation method.
super().validate()
if self.beta < 0.0 or self.beta > 1.0:
raise ValueError("`beta` must be within 0.0 and 1.0!")
if self.postprocess_inputs is False and self.beta > 0.0:
raise ValueError(
"`postprocess_inputs` must be True for MARWIL (to "
"calculate accum., discounted returns)! Try setting "
"`config.offline_data(postprocess_inputs=True)`."
)
        # Ensure that for a local Learner the number of iterations per learner is
        # defined. Note, this is needed because there are no iterators in this
        # case; instead, a single batch is returned directly from the
        # `OfflineData.sample` method.
if (
self.num_learners == 0
and not self.dataset_num_iters_per_learner
and self.enable_rl_module_and_learner
):
raise ValueError(
"When using a local Learner (`config.num_learners=0`), the number of "
"iterations per learner (`dataset_num_iters_per_learner`) has to be "
"defined! Set this hyperparameter through `config.offline_data("
"dataset_num_iters_per_learner=...)`."
)
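    # Illustrative sketch of the local-Learner setup validated above (assuming this
    # RLlib version exposes the `learners()` setter; values are examples only):
    #
    #     config = (
    #         MARWILConfig()
    #         .learners(num_learners=0)  # local Learner
    #         .offline_data(dataset_num_iters_per_learner=1)
    #     )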
@property
def _model_auto_keys(self):
return super()._model_auto_keys | {"beta": self.beta, "vf_share_layers": False}
class MARWIL(Algorithm):
@classmethod
@override(Algorithm)
def get_default_config(cls) -> AlgorithmConfig:
return MARWILConfig()
@classmethod
@override(Algorithm)
def get_default_policy_class(
cls, config: AlgorithmConfig
) -> Optional[Type[Policy]]:
if config["framework"] == "torch":
from ray.rllib.algorithms.marwil.marwil_torch_policy import (
MARWILTorchPolicy,
)
return MARWILTorchPolicy
elif config["framework"] == "tf":
from ray.rllib.algorithms.marwil.marwil_tf_policy import (
MARWILTF1Policy,
)
return MARWILTF1Policy
else:
from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTF2Policy
return MARWILTF2Policy
@override(Algorithm)
def training_step(self) -> ResultDict:
if self.config.enable_env_runner_and_connector_v2:
return self._training_step_new_api_stack()
else:
return self._training_step_old_api_stack()
def _training_step_new_api_stack(self) -> ResultDict:
"""Implements training logic for the new stack
Note, this includes so far training with the `OfflineData`
class (multi-/single-learner setup) and evaluation on
`EnvRunner`s. Note further, evaluation on the dataset itself
using estimators is not implemented, yet.
"""
# Implement logic using RLModule and Learner API.
        # TODO (simon): Take care of sampler metrics: right now all rewards are
        #  `nan`, which may make users think something is wrong, even though it
        #  is expected because we do not step the env.
with self.metrics.log_time((TIMERS, OFFLINE_SAMPLING_TIMER)):
# Sampling from offline data.
batch_or_iterator = self.offline_data.sample(
num_samples=self.config.train_batch_size_per_learner,
num_shards=self.config.num_learners,
return_iterator=self.config.num_learners > 1,
)
with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
# Updating the policy.
            # TODO (simon, sven): Check whether we should directly execute something
            #  like `LearnerGroup.update_from_iterator()`.
learner_results = self.learner_group._update(
batch=batch_or_iterator,
minibatch_size=self.config.train_batch_size_per_learner,
num_iters=self.config.dataset_num_iters_per_learner,
)
# Log training results.
self.metrics.merge_and_log_n_dicts(learner_results, key=LEARNER_RESULTS)
self.metrics.log_value(
NUM_ENV_STEPS_TRAINED_LIFETIME,
self.metrics.peek(
(LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED)
),
reduce="sum",
)
self.metrics.log_dict(
{
(LEARNER_RESULTS, mid, NUM_MODULE_STEPS_TRAINED_LIFETIME): (
stats[NUM_MODULE_STEPS_TRAINED]
)
for mid, stats in self.metrics.peek(LEARNER_RESULTS).items()
},
reduce="sum",
)
        # Synchronize weights.
        # The results contain the loss for each module and, in addition, the total
        # loss over all modules (under `ALL_MODULES`); the latter has to be removed
        # before selecting the modules whose weights get synced.
modules_to_update = set(learner_results[0].keys()) - {ALL_MODULES}
# Update weights - after learning on the local worker -
# on all remote workers.
with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
self.env_runner_group.sync_weights(
# Sync weights from learner_group to all EnvRunners.
from_worker_or_learner_group=self.learner_group,
policies=modules_to_update,
inference_only=True,
)
return self.metrics.reduce()
def _training_step_old_api_stack(self) -> ResultDict:
"""Implements training step for the old stack.
Note, there is no hybrid stack anymore. If you need to use `RLModule`s,
use the new api stack.
"""
# Collect SampleBatches from sample workers.
with self._timers[SAMPLE_TIMER]:
train_batch = synchronous_parallel_sample(worker_set=self.env_runner_group)
train_batch = train_batch.as_multi_agent(
module_id=list(self.config.policies)[0]
)
self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps()
self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps()
# Train.
if self.config.simple_optimizer:
train_results = train_one_step(self, train_batch)
else:
train_results = multi_gpu_train_one_step(self, train_batch)
# TODO: Move training steps counter update outside of `train_one_step()` method.
# # Update train step counters.
# self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
# self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()
global_vars = {
"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
}
# Update weights - after learning on the local worker - on all remote
# workers (only those policies that were actually trained).
if self.env_runner_group.remote_workers():
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.env_runner_group.sync_weights(
policies=list(train_results.keys()), global_vars=global_vars
)
# Update global vars on local worker as well.
self.env_runner.set_global_vars(global_vars)
return train_results