import copy
import dataclasses
from enum import Enum
import logging
import math
import sys
from typing import (
Any,
Callable,
Collection,
Dict,
List,
Optional,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
import gymnasium as gym
import tree
from packaging import version
import ray
from ray.rllib.callbacks.callbacks import RLlibCallback
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module import validate_module_id
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.env import INPUT_ENV_SPACES
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.wrappers.atari_wrappers import is_atari
from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
from ray.rllib.evaluation.collectors.simple_list_collector import SimpleListCollector
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.io_context import IOContext
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils import deep_update, merge_dicts
from ray.rllib.utils.annotations import (
OldAPIStack,
OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.rllib.utils.deprecation import (
DEPRECATED_VALUE,
Deprecated,
deprecation_warning,
)
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import NotProvided, from_config
from ray.rllib.utils.schedules.scheduler import Scheduler
from ray.rllib.utils.serialization import (
NOT_SERIALIZABLE,
deserialize_type,
serialize_type,
)
from ray.rllib.utils.test_utils import check
from ray.rllib.utils.torch_utils import TORCH_COMPILE_REQUIRED_VERSION
from ray.rllib.utils.typing import (
AgentID,
AlgorithmConfigDict,
EnvConfigDict,
EnvType,
LearningRateOrSchedule,
ModuleID,
MultiAgentPolicyConfigDict,
PartialAlgorithmConfigDict,
PolicyID,
RLModuleSpecType,
SampleBatchType,
)
from ray.tune.logger import Logger
from ray.tune.registry import get_trainable_cls
from ray.tune.result import TRIAL_INFO
from ray.tune.tune import _Config
# Short alias for gymnasium's base `Space` class.
Space = gym.Space
# These imports only run under static type checkers (`TYPE_CHECKING` is False
# at runtime); the names are used in string annotations such as "Algorithm".
if TYPE_CHECKING:
    from ray.rllib.algorithms.algorithm import Algorithm
    from ray.rllib.connectors.connector_v2 import ConnectorV2
    from ray.rllib.core.learner import Learner
    from ray.rllib.core.learner.learner_group import LearnerGroup
    from ray.rllib.core.rl_module.rl_module import RLModule
    from ray.rllib.utils.typing import EpisodeType
# Module-level logger.
logger = logging.getLogger(__name__)
def _check_rl_module_spec(module_spec: RLModuleSpecType) -> None:
    """Checks that `module_spec` is an RLModuleSpec or a MultiRLModuleSpec.

    Args:
        module_spec: The object to check.

    Raises:
        ValueError: If `module_spec` is an instance of neither `RLModuleSpec`
            nor `MultiRLModuleSpec`.
    """
    if not isinstance(module_spec, (RLModuleSpec, MultiRLModuleSpec)):
        raise ValueError(
            "rl_module_spec must be an instance of "
            # Trailing space: the literals are concatenated, and without it the
            # message read "...MultiRLModuleSpec.Got <type> instead.".
            "RLModuleSpec or MultiRLModuleSpec. "
            f"Got {type(module_spec)} instead."
        )
[docs]
class AlgorithmConfig(_Config):
"""A RLlib AlgorithmConfig builds an RLlib Algorithm from a given configuration.
.. testcode::
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.callbacks import MemoryTrackingCallbacks
# Construct a generic config object, specifying values within different
# sub-categories, e.g. "training".
config = (
PPOConfig()
.training(gamma=0.9, lr=0.01)
.environment(env="CartPole-v1")
.env_runners(num_env_runners=0)
.callbacks(MemoryTrackingCallbacks)
)
# A config object can be used to construct the respective Algorithm.
rllib_algo = config.build()
.. testcode::
from ray.rllib.algorithms.ppo import PPOConfig
from ray import tune
# In combination with a tune.grid_search:
config = PPOConfig()
config.training(lr=tune.grid_search([0.01, 0.001]))
# Use `to_dict()` method to get the legacy plain python config dict
# for usage with `tune.Tuner().fit()`.
tune.Tuner("PPO", param_space=config.to_dict())
"""
@staticmethod
def DEFAULT_AGENT_TO_MODULE_MAPPING_FN(agent_id, episode):
    """Fallback agent-to-module mapping: every agent maps to DEFAULT_MODULE_ID.

    Used in the multi-agent case when no mapping function is provided. Both
    arguments are accepted for signature compatibility and ignored.
    """
    return DEFAULT_MODULE_ID
# TODO (sven): Deprecate in new API stack.
@staticmethod
def DEFAULT_POLICY_MAPPING_FN(aid, episode, worker, **kwargs):
    """Fallback policy mapping: every agent maps to DEFAULT_POLICY_ID.

    Used if no policy mapping function is provided. All arguments are
    accepted for signature compatibility and ignored.
    """
    return DEFAULT_POLICY_ID
[docs]
@classmethod
def from_dict(cls, config_dict: dict) -> "AlgorithmConfig":
    """Creates an AlgorithmConfig from a legacy python config dict.

    .. testcode::

        from ray.rllib.algorithms.ppo.ppo import PPOConfig

        # pass a RLlib config dict
        ppo_config = PPOConfig.from_dict({})
        ppo = ppo_config.build(env="Pendulum-v1")

    Args:
        config_dict: The legacy formatted python config dict for some algorithm.
            Its entries are not removed by this method; a `_is_frozen` entry,
            if present, is ignored. Note that `update_from_dict()` may still
            resolve class paths inside some nested sub-dicts in place.

    Returns:
        A new AlgorithmConfig object that matches the given python config dict.
    """
    # Create a default config object of this class.
    config_obj = cls()
    # Ignore the `_is_frozen` flag in case the AlgorithmConfig that the dict
    # was derived from was already frozen (we don't want to copy the
    # frozenness). Filter into a new dict instead of popping from
    # `config_dict`, so the caller's dict keeps all of its entries.
    config_dict = {k: v for k, v in config_dict.items() if k != "_is_frozen"}
    config_obj.update_from_dict(config_dict)
    return config_obj
[docs]
@classmethod
def overrides(cls, **kwargs):
    """Generates and validates a set of config key/value pairs (passed via kwargs).

    Validation whether given config keys are valid is done immediately upon
    construction (by comparing against the properties of a default AlgorithmConfig
    object of this class).

    Allows combination with a full AlgorithmConfig object to yield a new
    AlgorithmConfig object.

    Used anywhere, we would like to enable the user to only define a few config
    settings that would change with respect to some main config, e.g. in multi-agent
    setups and evaluation configs.

    .. testcode::

        from ray.rllib.algorithms.ppo import PPOConfig
        from ray.rllib.policy.policy import PolicySpec

        config = (
            PPOConfig()
            .multi_agent(
                policies={
                    "pol0": PolicySpec(config=PPOConfig.overrides(lambda_=0.95))
                },
            )
        )

    .. testcode::

        from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
        from ray.rllib.algorithms.ppo import PPOConfig

        config = (
            PPOConfig()
            .evaluation(
                evaluation_num_env_runners=1,
                evaluation_interval=1,
                evaluation_config=AlgorithmConfig.overrides(explore=False),
            )
        )

    Args:
        **kwargs: Config property names mapped to their override values. Legacy
            names that `_translate_special_keys()` maps onto a property (e.g.
            "lambda" -> `lambda_`) are accepted as well.

    Returns:
        A dict mapping valid config property-names to values.

    Raises:
        KeyError: In case a non-existing property name (kwargs key) is being
            passed in. Valid property names are taken from a default
            AlgorithmConfig object of `cls`.
    """
    default_config = cls()
    config_overrides = {}
    for key, value in kwargs.items():
        # Allow things like "lambda" as well: translate BEFORE validating, so
        # a legacy name isn't rejected just because only its translated
        # property (e.g. `lambda_`) exists. The raw name stays accepted too,
        # so nothing that validated before is rejected now.
        translated_key = cls._translate_special_keys(key, warn_deprecated=True)
        if not (
            hasattr(default_config, key) or hasattr(default_config, translated_key)
        ):
            raise KeyError(
                f"Invalid property name {key} for config class {cls.__name__}!"
            )
        config_overrides[translated_key] = value
    return config_overrides
[docs]
def __init__(self, algo_class: Optional[type] = None):
    """Initializes an AlgorithmConfig instance.

    Sets every config property to its default value. Properties are grouped
    by the builder method (e.g. `self.training()`, `self.env_runners()`) that
    is used to change them.

    Args:
        algo_class: An optional Algorithm class that this config class belongs to.
            Used (if provided) to build a respective Algorithm instance from this
            config.
    """
    # Define all settings and their default values.

    # Define the default RLlib Algorithm class that this AlgorithmConfig is applied
    # to.
    self.algo_class = algo_class

    # `self.python_environment()`
    self.extra_python_environs_for_driver = {}
    self.extra_python_environs_for_worker = {}

    # `self.resources()`
    self.placement_strategy = "PACK"
    self.num_gpus = 0  # @OldAPIStack
    self._fake_gpus = False  # @OldAPIStack
    self.num_cpus_for_main_process = 1

    # `self.framework()`
    self.framework_str = "torch"
    self.eager_tracing = True
    self.eager_max_retraces = 20
    self.tf_session_args = {
        # note: overridden by `local_tf_session_args`
        "intra_op_parallelism_threads": 2,
        "inter_op_parallelism_threads": 2,
        "gpu_options": {
            "allow_growth": True,
        },
        "log_device_placement": False,
        "device_count": {"CPU": 1},
        # Required by multi-GPU (num_gpus > 1).
        "allow_soft_placement": True,
    }
    self.local_tf_session_args = {
        # Allow a higher level of parallelism by default, but not unlimited
        # since that can cause crashes with many concurrent drivers.
        "intra_op_parallelism_threads": 8,
        "inter_op_parallelism_threads": 8,
    }
    # Torch compile settings
    self.torch_compile_learner = False
    self.torch_compile_learner_what_to_compile = (
        TorchCompileWhatToCompile.FORWARD_TRAIN
    )
    # AOT Eager is a dummy backend and doesn't result in speedups.
    self.torch_compile_learner_dynamo_backend = (
        "aot_eager" if sys.platform == "darwin" else "inductor"
    )
    self.torch_compile_learner_dynamo_mode = None
    self.torch_compile_worker = False
    # AOT Eager is a dummy backend and doesn't result in speedups.
    self.torch_compile_worker_dynamo_backend = (
        "aot_eager" if sys.platform == "darwin" else "onnxrt"
    )
    self.torch_compile_worker_dynamo_mode = None
    # Default kwargs for `torch.nn.parallel.DistributedDataParallel`.
    self.torch_ddp_kwargs = {}
    # Default setting for skipping `nan` gradient updates.
    self.torch_skip_nan_gradients = False

    # `self.environment()`
    self.env = None
    self.env_config = {}
    self.observation_space = None
    self.action_space = None
    self.clip_rewards = None
    self.normalize_actions = True
    self.clip_actions = False
    self._is_atari = None
    self.disable_env_checking = False
    # Deprecated settings:
    self.render_env = False
    self.action_mask_key = "action_mask"

    # `self.env_runners()`
    self.env_runner_cls = None
    self.num_env_runners = 0
    self.num_envs_per_env_runner = 1
    # TODO (sven): Once new ormsgpack system in place, replace the string
    #  with proper `gym.envs.registration.VectorizeMode.SYNC`.
    self.gym_env_vectorize_mode = "SYNC"
    self.num_cpus_per_env_runner = 1
    self.num_gpus_per_env_runner = 0
    self.custom_resources_per_env_runner = {}
    self.validate_env_runners_after_construction = True
    self.episodes_to_numpy = True
    self.max_requests_in_flight_per_env_runner = 1
    self.sample_timeout_s = 60.0
    self.create_env_on_local_worker = False
    self._env_to_module_connector = None
    self.add_default_connectors_to_env_to_module_pipeline = True
    self._module_to_env_connector = None
    self.add_default_connectors_to_module_to_env_pipeline = True
    self.episode_lookback_horizon = 1
    # TODO (sven): Rename into `sample_timesteps` (or `sample_duration`
    #  and `sample_duration_unit` (replacing batch_mode), like we do it
    #  in the evaluation config).
    self.rollout_fragment_length = 200
    # TODO (sven): Rename into `sample_mode`.
    self.batch_mode = "truncate_episodes"
    self.compress_observations = False
    # @OldAPIStack
    self.remote_worker_envs = False
    self.remote_env_batch_wait_ms = 0
    self.enable_tf1_exec_eagerly = False
    self.sample_collector = SimpleListCollector
    self.preprocessor_pref = "deepmind"
    self.observation_filter = "NoFilter"
    self.update_worker_filter_stats = True
    self.use_worker_filter_stats = True
    self.sampler_perf_stats_ema_coef = None

    # `self.learners()`
    self.num_learners = 0
    self.num_gpus_per_learner = 0
    self.num_cpus_per_learner = 1
    self.num_aggregator_actors_per_learner = 0
    self.max_requests_in_flight_per_aggregator_actor = 100
    self.local_gpu_idx = 0
    # TODO (sven): This probably works even without any restriction
    #  (allowing for any arbitrary number of requests in-flight). Test with
    #  3 first, then with unlimited, and if both show the same behavior on
    #  an async algo, remove this restriction entirely.
    self.max_requests_in_flight_per_learner = 3

    # `self.training()`
    self.gamma = 0.99
    self.lr = 0.001
    self.grad_clip = None
    self.grad_clip_by = "global_norm"
    # Simple logic for now: If None, use `train_batch_size`.
    self._train_batch_size_per_learner = None
    self.train_batch_size = 32  # @OldAPIStack

    # These settings have been adopted from the original PPO batch settings:
    # num_sgd_iter, minibatch_size, and shuffle_sequences.
    self.num_epochs = 1
    self.minibatch_size = None
    self.shuffle_batch_per_epoch = False

    # TODO (sven): Unsolved problem with RLModules sometimes requiring settings from
    #  the main AlgorithmConfig. We should not require the user to provide those
    #  settings in both, the AlgorithmConfig (as property) AND the model config
    #  dict. We should generally move to a world, in which there exists an
    #  AlgorithmConfig that a) has-a user provided model config object and b)
    #  is given a chance to compile a final model config (dict or object) that is
    #  then passed into the RLModule/Catalog. This design would then match our
    #  "compilation" pattern, where we compile automatically those settings that
    #  should NOT be touched by the user.
    #  In case, an Algorithm already uses the above described pattern (and has
    #  `self.model` as a @property, ignore AttributeError (for trying to set this
    #  property).
    try:
        self.model = copy.deepcopy(MODEL_DEFAULTS)
    except AttributeError:
        pass

    self._learner_connector = None
    self.add_default_connectors_to_learner_pipeline = True
    self.learner_config_dict = {}
    self.optimizer = {}  # @OldAPIStack
    self._learner_class = None

    # `self.callbacks()`
    # TODO (sven): Set this default to None, once the old API stack has been
    #  deprecated.
    self.callbacks_class = RLlibCallback
    self.callbacks_on_algorithm_init = None
    self.callbacks_on_env_runners_recreated = None
    self.callbacks_on_checkpoint_loaded = None
    self.callbacks_on_environment_created = None
    self.callbacks_on_episode_created = None
    self.callbacks_on_episode_start = None
    self.callbacks_on_episode_step = None
    self.callbacks_on_episode_end = None
    self.callbacks_on_evaluate_start = None
    self.callbacks_on_evaluate_end = None
    self.callbacks_on_sample_end = None
    self.callbacks_on_train_result = None

    # `self.explore()`
    self.explore = True
    # This is not compatible with RLModules, which have a method
    # `forward_exploration` to specify custom exploration behavior.
    if not hasattr(self, "exploration_config"):
        # Helper to keep track of the original exploration config when dis-/enabling
        # rl modules.
        self._prior_exploration_config = None
        self.exploration_config = {}

    # `self.api_stack()`
    self.enable_rl_module_and_learner = True
    self.enable_env_runner_and_connector_v2 = True
    self.api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )

    # `self.multi_agent()`
    # TODO (sven): Prepare multi-agent setup for logging each agent's and each
    #  RLModule's steps taken thus far (and passing this information into the
    #  EnvRunner metrics and the RLModule's forward pass). Thereby, deprecate the
    #  `count_steps_by` config setting AND - at the same time - allow users to
    #  specify the batch size unit instead (agent- vs env steps).
    self.count_steps_by = "env_steps"
    # self.agent_to_module_mapping_fn = self.DEFAULT_AGENT_TO_MODULE_MAPPING_FN
    # Soon to be Deprecated.
    self.policies = {DEFAULT_POLICY_ID: PolicySpec()}
    self.policy_map_capacity = 100
    self.policy_mapping_fn = self.DEFAULT_POLICY_MAPPING_FN
    self.policies_to_train = None
    self.policy_states_are_swappable = False
    self.observation_fn = None

    # `self.offline_data()`
    self.input_ = "sampler"
    self.offline_data_class = None
    self.input_read_method = "read_parquet"
    self.input_read_method_kwargs = {}
    self.input_read_schema = {}
    self.input_read_episodes = False
    self.input_read_sample_batches = False
    self.input_read_batch_size = None
    self.input_filesystem = None
    self.input_filesystem_kwargs = {}
    self.input_compress_columns = [Columns.OBS, Columns.NEXT_OBS]
    self.input_spaces_jsonable = True
    self.materialize_data = False
    self.materialize_mapped_data = True
    self.map_batches_kwargs = {}
    self.iter_batches_kwargs = {}
    self.prelearner_class = None
    self.prelearner_buffer_class = None
    self.prelearner_buffer_kwargs = {}
    self.prelearner_module_synch_period = 10
    self.dataset_num_iters_per_learner = None
    self.input_config = {}
    self.actions_in_input_normalized = False
    self.postprocess_inputs = False
    self.shuffle_buffer_size = 0
    self.output = None
    self.output_config = {}
    self.output_compress_columns = [Columns.OBS, Columns.NEXT_OBS]
    self.output_max_file_size = 64 * 1024 * 1024
    self.output_max_rows_per_file = None
    self.output_write_remaining_data = False
    self.output_write_method = "write_parquet"
    self.output_write_method_kwargs = {}
    self.output_filesystem = None
    self.output_filesystem_kwargs = {}
    self.output_write_episodes = True
    self.offline_sampling = False

    # `self.evaluation()`
    self.evaluation_interval = None
    self.evaluation_duration = 10
    self.evaluation_duration_unit = "episodes"
    self.evaluation_sample_timeout_s = 120.0
    self.evaluation_parallel_to_training = False
    self.evaluation_force_reset_envs_before_iteration = True
    self.evaluation_config = None
    self.off_policy_estimation_methods = {}
    self.ope_split_batch_by_episode = True
    self.evaluation_num_env_runners = 0
    self.custom_evaluation_function = None
    # TODO: Set this flag still in the config or - much better - in the
    #  RolloutWorker as a property.
    self.in_evaluation = False
    # TODO (sven): Deprecate this setting (it's not user-accessible right now any
    #  way). Replace by logic within `training_step` to merge and broadcast the
    #  EnvRunner (connector) states.
    self.sync_filters_on_rollout_workers_timeout_s = 10.0

    # `self.reporting()`
    self.keep_per_episode_custom_metrics = False
    self.metrics_episode_collection_timeout_s = 60.0
    self.metrics_num_episodes_for_smoothing = 100
    self.min_time_s_per_iteration = None
    self.min_train_timesteps_per_iteration = 0
    self.min_sample_timesteps_per_iteration = 0
    self.log_gradients = True

    # `self.checkpointing()`
    self.export_native_model_files = False
    self.checkpoint_trainable_policies_only = False

    # `self.debugging()`
    self.logger_creator = None
    self.logger_config = None
    self.log_level = "WARN"
    self.log_sys_usage = True
    self.fake_sampler = False
    self.seed = None

    # `self.fault_tolerance()`
    self.restart_failed_env_runners = True
    self.ignore_env_runner_failures = False
    # By default, restart failed worker a thousand times.
    # This should be enough to handle normal transient failures.
    # This also prevents infinite number of restarts in case the worker or env has
    # a bug.
    self.max_num_env_runner_restarts = 1000
    # Small delay between worker restarts. In case EnvRunners or eval EnvRunners
    # have remote dependencies, this delay can be adjusted to make sure we don't
    # flood them with re-connection requests, and allow them enough time to recover.
    # This delay also gives Ray time to stream back error logging and exceptions.
    self.delay_between_env_runner_restarts_s = 60.0
    self.restart_failed_sub_environments = False
    self.num_consecutive_env_runner_failures_tolerance = 100
    self.env_runner_health_probe_timeout_s = 30.0
    self.env_runner_restore_timeout_s = 1800.0

    # `self.rl_module()`
    self._model_config = {}
    self._rl_module_spec = None
    # Module ID specific config overrides.
    self.algorithm_config_overrides_per_module = {}
    # Cached, actual AlgorithmConfig objects derived from
    # `self.algorithm_config_overrides_per_module`.
    self._per_module_overrides: Dict[ModuleID, "AlgorithmConfig"] = {}

    # `self.experimental()`
    self._validate_config = True
    self._use_msgpack_checkpoints = False
    self._torch_grad_scaler_class = None
    self._torch_lr_scheduler_classes = None
    self._tf_policy_handles_more_than_one_loss = False
    self._disable_preprocessor_api = False
    self._disable_action_flattening = False
    self._disable_initialize_loss_from_dummy_batch = False
    self._dont_auto_sync_env_runner_states = False

    # Has this config object been frozen (cannot alter its attributes anymore).
    self._is_frozen = False

    # TODO: Remove, once all deprecation_warning calls upon using these keys
    #  have been removed.
    # === Deprecated keys ===
    self.env_task_fn = DEPRECATED_VALUE
    self.enable_connectors = DEPRECATED_VALUE
    self.simple_optimizer = DEPRECATED_VALUE
    self.monitor = DEPRECATED_VALUE
    self.evaluation_num_episodes = DEPRECATED_VALUE
    self.metrics_smoothing_episodes = DEPRECATED_VALUE
    self.timesteps_per_iteration = DEPRECATED_VALUE
    self.min_iter_time_s = DEPRECATED_VALUE
    self.collect_metrics_timeout = DEPRECATED_VALUE
    self.min_time_s_per_reporting = DEPRECATED_VALUE
    self.min_train_timesteps_per_reporting = DEPRECATED_VALUE
    self.min_sample_timesteps_per_reporting = DEPRECATED_VALUE
    self.input_evaluation = DEPRECATED_VALUE
    self.policy_map_cache = DEPRECATED_VALUE
    self.worker_cls = DEPRECATED_VALUE
    self.synchronize_filters = DEPRECATED_VALUE
    self.enable_async_evaluation = DEPRECATED_VALUE
    self.custom_async_evaluation_function = DEPRECATED_VALUE
    self._enable_rl_module_api = DEPRECATED_VALUE
    self.auto_wrap_old_gym_envs = DEPRECATED_VALUE
    self.always_attach_evaluation_results = DEPRECATED_VALUE

    # The following values have moved because of the new ReplayBuffer API
    self.buffer_size = DEPRECATED_VALUE
    self.prioritized_replay = DEPRECATED_VALUE
    self.learning_starts = DEPRECATED_VALUE
    self.replay_batch_size = DEPRECATED_VALUE
    # -1 = DEPRECATED_VALUE is a valid value for replay_sequence_length
    self.replay_sequence_length = None
    self.replay_mode = DEPRECATED_VALUE
    self.prioritized_replay_alpha = DEPRECATED_VALUE
    self.prioritized_replay_beta = DEPRECATED_VALUE
    self.prioritized_replay_eps = DEPRECATED_VALUE
    self._disable_execution_plan_api = DEPRECATED_VALUE
[docs]
def to_dict(self) -> AlgorithmConfigDict:
    """Converts all settings into a legacy config dict for backward compatibility.

    The dict is built from a deep copy of this config's attributes, so its
    nested containers are copies that can be mutated without affecting `self`.
    Python-safe property names are mapped back to their legacy keys (e.g.
    `lambda_` -> "lambda", `input_` -> "input", `callbacks_class` ->
    "callbacks", `framework_str` -> "framework").

    Returns:
        A complete AlgorithmConfigDict, usable in backward-compatible Tune/RLlib
        use cases.
    """
    config = copy.deepcopy(vars(self))
    config.pop("algo_class")
    config.pop("_is_frozen")

    # Worst naming convention ever: NEVER EVER use reserved key-words...
    # Take the values from the deep copy (not from `self`), so the returned
    # dict never shares mutable objects (e.g. a list of input files) with
    # this config.
    if "lambda_" in config:
        config["lambda"] = config.pop("lambda_")
    if "input_" in config:
        config["input"] = config.pop("input_")

    # Convert `policies` (PolicySpecs?) into dict.
    # Convert policies dict such that each policy ID maps to a old-style.
    # 4-tuple: class, obs-, and action space, config.
    if "policies" in config and isinstance(config["policies"], dict):
        policies_dict = {}
        for policy_id, policy_spec in config.pop("policies").items():
            if isinstance(policy_spec, PolicySpec):
                policies_dict[policy_id] = policy_spec.get_state()
            else:
                policies_dict[policy_id] = policy_spec
        config["policies"] = policies_dict

    # Switch out deprecated vs new config keys.
    config["callbacks"] = config.pop("callbacks_class", None)
    config["create_env_on_driver"] = config.pop("create_env_on_local_worker", 1)
    config["custom_eval_function"] = config.pop("custom_evaluation_function", None)
    config["framework"] = config.pop("framework_str", None)

    # Simplify: Remove all deprecated keys that have as value `DEPRECATED_VALUE`.
    # These would be useless in the returned dict anyways.
    for dep_k in [
        "monitor",
        "evaluation_num_episodes",
        "metrics_smoothing_episodes",
        "timesteps_per_iteration",
        "min_iter_time_s",
        "collect_metrics_timeout",
        "buffer_size",
        "prioritized_replay",
        "learning_starts",
        "replay_batch_size",
        "replay_mode",
        "prioritized_replay_alpha",
        "prioritized_replay_beta",
        "prioritized_replay_eps",
        "min_time_s_per_reporting",
        "min_train_timesteps_per_reporting",
        "min_sample_timesteps_per_reporting",
        "input_evaluation",
        "_enable_new_api_stack",
    ]:
        if config.get(dep_k) == DEPRECATED_VALUE:
            config.pop(dep_k, None)

    return config
[docs]
def update_from_dict(
    self,
    config_dict: PartialAlgorithmConfigDict,
) -> "AlgorithmConfig":
    """Modifies this AlgorithmConfig via the provided python config dict.

    Warns if `config_dict` contains deprecated keys.
    NOTE(review): keys are translated below with `warn_deprecated=False`, so
    whether deprecated keys actually trigger a warning here depends on code
    not visible in this method — confirm.

    Silently sets even properties of `self` that do NOT exist. This way, this method
    may be used to configure custom Policies which do not have their own specific
    AlgorithmConfig classes, e.g.
    `ray.rllib.examples.policy.random_policy::RandomPolicy`.

    Some keys go through dedicated builder methods instead of a plain
    `setattr` (e.g. "multiagent" -> `self.multi_agent()`, "env_config" ->
    `self.environment()`, all "evaluation_*" keys -> a single final
    `self.evaluation()` call). Sub-dicts under "exploration_config", "model",
    and "replay_buffer_config" may be modified in place when class paths in
    them are deserialized.

    Args:
        config_dict: The old-style python config dict (PartialAlgorithmConfigDict)
            to use for overriding some properties defined in there.

    Returns:
        This updated AlgorithmConfig object.
    """
    # Collects all "evaluation_*" keys; applied after the loop in one
    # `self.evaluation()` call.
    eval_call = {}

    # The api-stack flags are applied first (via `self.api_stack()`), because
    # they influence how other keys, e.g. "exploration_config", are handled
    # below. If only one of the two flags is given, its value is used for both.
    enable_new_api_stack = config_dict.get(
        "enable_rl_module_and_learner",
        config_dict.get("enable_env_runner_and_connector_v2"),
    )
    if enable_new_api_stack is not None:
        self.api_stack(
            enable_rl_module_and_learner=enable_new_api_stack,
            enable_env_runner_and_connector_v2=enable_new_api_stack,
        )

    # Modify our properties one by one.
    for key, value in config_dict.items():
        # Map legacy/special key names onto this class' property names.
        key = self._translate_special_keys(key, warn_deprecated=False)

        # Ray Tune saves additional data under this magic keyword.
        # This should not get treated as AlgorithmConfig field.
        if key == TRIAL_INFO:
            continue

        if key in ["_enable_new_api_stack"]:
            # Legacy flag: ignored. (The current api-stack flags were applied
            # above and are additionally set via the generic `setattr` branch.)
            continue

        # Set our multi-agent settings.
        elif key == "multiagent":
            # Only forward the multi-agent sub-keys that are actually present.
            kwargs = {
                k: value[k]
                for k in [
                    "policies",
                    "policy_map_capacity",
                    "policy_mapping_fn",
                    "policies_to_train",
                    "policy_states_are_swappable",
                    "observation_fn",
                    "count_steps_by",
                ]
                if k in value
            }
            self.multi_agent(**kwargs)

        # Some keys specify config sub-dicts and therefore should go through the
        # correct methods to properly `.update()` those from given config dict
        # (to not lose any sub-keys).
        # NOTE(review): a "callbacks_class" value equal to NOT_SERIALIZABLE
        # matches no other branch and ends up in the generic `setattr()` below,
        # i.e. it is stored as-is — confirm this is intended.
        elif key == "callbacks_class" and value != NOT_SERIALIZABLE:
            # For backward compatibility reasons, only resolve possible
            # classpath if value is a str type.
            if isinstance(value, str):
                value = deserialize_type(value, error=True)
            self.callbacks(callbacks_class=value)
        elif key == "env_config":
            self.environment(env_config=value)
        elif key.startswith("evaluation_"):
            # Deferred: applied in one `self.evaluation()` call after the loop.
            eval_call[key] = value
        elif key == "exploration_config":
            if enable_new_api_stack:
                # Stored as-is: no "type" deserialization and no
                # `self.env_runners()` call.
                self.exploration_config = value
                continue
            if isinstance(value, dict) and "type" in value:
                # Resolve possible classpath (modifies `value` in place).
                value["type"] = deserialize_type(value["type"])
            self.env_runners(exploration_config=value)
        elif key == "model":
            # Resolve possible classpath (modifies `value` in place).
            if isinstance(value, dict) and value.get("custom_model"):
                value["custom_model"] = deserialize_type(value["custom_model"])
            self.training(**{key: value})
        elif key == "optimizer":
            self.training(**{key: value})
        elif key == "replay_buffer_config":
            # Resolve possible classpath (modifies `value` in place).
            if isinstance(value, dict) and "type" in value:
                value["type"] = deserialize_type(value["type"])
            self.training(**{key: value})
        elif key == "sample_collector":
            # Resolve possible classpath.
            value = deserialize_type(value)
            self.env_runners(sample_collector=value)
        # Set the property named `key` to `value`.
        else:
            setattr(self, key, value)

    # Called even if no "evaluation_*" keys were given (then with no kwargs).
    self.evaluation(**eval_call)

    return self
[docs]
def get_state(self) -> Dict[str, Any]:
    """Returns a dict state that can be pickled.

    The state is a shallow copy of this config's attributes (without
    `algo_class` and `_is_frozen`) plus a `"class"` entry holding `type(self)`,
    from which `AlgorithmConfig.from_state()` rebuilds the config.

    An entry holding `DEPRECATED_VALUE` is dropped only if a freshly
    constructed config of the same class also holds `DEPRECATED_VALUE` for it.
    `from_state()` starts from such a fresh instance, so those entries are
    restored unchanged.

    Returns:
        A dictionary containing all attributes of the instance.
    """
    state = self.__dict__.copy()
    state["class"] = type(self)
    state.pop("algo_class")
    state.pop("_is_frozen")

    # Filtering on the value alone would also drop legitimate settings that
    # happen to equal DEPRECATED_VALUE (including custom keys set through
    # `update_from_dict()`). After a `from_state()` round trip, those would
    # silently revert to their defaults or disappear altogether.
    defaults = vars(type(self)())
    state = {
        k: v
        for k, v in state.items()
        if not (
            v == DEPRECATED_VALUE
            and k in defaults
            and defaults[k] == DEPRECATED_VALUE
        )
    }

    # Convert `policies` (PolicySpecs?) into dict.
    # Convert policies dict such that each policy ID maps to a old-style.
    # 4-tuple: class, obs-, and action space, config.
    # TODO (simon, sven): Remove when deprecating old stack.
    if "policies" in state and isinstance(state["policies"], dict):
        policies_dict = {}
        for policy_id, policy_spec in state.pop("policies").items():
            if isinstance(policy_spec, PolicySpec):
                policies_dict[policy_id] = policy_spec.get_state()
            else:
                policies_dict[policy_id] = policy_spec
        state["policies"] = policies_dict

    return state
[docs]
@classmethod
def from_state(cls, state: Dict[str, Any]) -> "AlgorithmConfig":
    """Returns an instance constructed from the state.

    Args:
        cls: An `AlgorithmConfig` class.
        state: A dictionary containing the state of an `AlgorithmConfig`.
            See `AlgorithmConfig.get_state` for creating a state. Its
            `"class"` entry determines the type of the returned config. The
            dict itself is not modified.

    Returns:
        An `AlgorithmConfig` instance with attributes from the `state`.
    """
    # Instantiate the class recorded in the state (not necessarily `cls`), so
    # its defaults exist before being overwritten from `state`.
    ctor = state["class"]
    config = ctor()
    # `"class"` is bookkeeping added by `get_state()`, not a config property.
    # Don't copy it into the new instance's attribute dict.
    config.__dict__.update({k: v for k, v in state.items() if k != "class"})
    return config
# TODO(sven): We might want to have a `deserialize` method as well. Right now,
# simply using the from_dict() API works in this same (deserializing) manner,
# whether the dict used is actually code-free (already serialized) or not
# (i.e. a classic RLlib config dict with e.g. "callbacks" key still pointing to
# a class).
[docs]
def serialize(self) -> Dict[str, Any]:
    """Returns a mapping from str to JSON'able values representing this config.

    The resulting values don't have any code in them.
    Classes (such as `callbacks_class`) are converted to their full
    classpath, e.g. `ray.rllib.callbacks.callbacks.RLlibCallback`.
    Actual code such as lambda functions are written as their source
    code (str) plus any closure information for properly restoring the
    code inside the AlgorithmConfig object made from the returned dict data.
    Dataclass objects get converted to dicts.

    The conversion starts from the legacy dict returned by `self.to_dict()`,
    which is then passed through `self._serialize_dict()`.

    Returns:
        A dict mapping from str to JSON'able values.
    """
    return self._serialize_dict(self.to_dict())
[docs]
def copy(self, copy_frozen: Optional[bool] = None) -> "AlgorithmConfig":
    """Returns a deep copy of this config, optionally changing its frozen state.

    Args:
        copy_frozen: `True` freezes the copy. `False` unfreezes it, including an
            `AlgorithmConfig` stored as its `evaluation_config`. `None` (or any
            other value) leaves the frozen state exactly as deep-copied from
            `self`.

    Returns:
        The deep copy of `self`.
    """
    duplicate = copy.deepcopy(self)

    if copy_frozen is True:
        duplicate.freeze()
        return duplicate

    # Only an explicit `False` unfreezes; anything else keeps the copied state.
    if copy_frozen is not False:
        return duplicate

    duplicate._is_frozen = False
    nested_eval_config = duplicate.evaluation_config
    if isinstance(nested_eval_config, AlgorithmConfig):
        nested_eval_config._is_frozen = False
    return duplicate
[docs]
def freeze(self) -> None:
    """Marks this config as frozen, i.e. no attributes can be set anymore.

    Algorithms should call this to make sure their config objects stay
    read-only afterwards. If `evaluation_config` is itself an
    `AlgorithmConfig`, it gets frozen as well. Calling this on an already
    frozen config does nothing.
    """
    if not self._is_frozen:
        self._is_frozen = True
        # Propagate to a nested evaluation config object, if there is one.
        nested_eval_config = self.evaluation_config
        if isinstance(nested_eval_config, AlgorithmConfig):
            nested_eval_config.freeze()

    # TODO: Flip out all set/dict/list values into frozen versions
    #  of themselves? This way, users won't even be able to alter those values
    #  directly anymore.
[docs]
@OverrideToImplementCustomLogic_CallToSuperRecommended
def validate(self) -> None:
    """Validates all values in this config.

    Runs every category-specific `_validate_*_settings()` check, in a fixed
    order. Does nothing if validation was switched off via the
    `_validate_config` flag.
    """
    if self._validate_config:
        # Each check is looked up right before it runs (same as calling them
        # one after another by name).
        for check_name in (
            "_validate_env_runner_settings",
            "_validate_callbacks_settings",
            "_validate_framework_settings",
            "_validate_resources_settings",
            "_validate_multi_agent_settings",
            "_validate_input_settings",
            "_validate_evaluation_settings",
            "_validate_offline_settings",
            "_validate_new_api_stack_settings",
            "_validate_to_be_deprecated_settings",
        ):
            getattr(self, check_name)()
[docs]
def build_algo(
    self,
    env: Optional[Union[str, EnvType]] = None,
    logger_creator: Optional[Callable[[], Logger]] = None,
    use_copy: bool = True,
) -> "Algorithm":
    """Builds an Algorithm from this AlgorithmConfig (or a deep copy of it).

    Note that `env` and `logger_creator`, if given, are written into `self`
    (and `env` also into `self.evaluation_config`, if that is not None)
    before any copy is made.

    Args:
        env: Name of the environment to use (e.g. a gym-registered str),
            a full class path (e.g.
            "ray.rllib.examples.envs.classes.random_env.RandomEnv"), or an Env
            class directly. Note that this arg can also be specified via
            the "env" key in `config`.
        logger_creator: Callable that creates a ray.tune.Logger
            object. If unspecified, a default logger is created.
        use_copy: Whether to deepcopy `self` and pass the copy to the Algorithm
            (instead of `self`) as config. This is useful in case you would like to
            recycle the same AlgorithmConfig over and over, e.g. in a test case, in
            which we loop over different DL-frameworks.

    Returns:
        A ray.rllib.algorithms.algorithm.Algorithm object.
    """
    if env is not None:
        self.env = env
        if self.evaluation_config is not None:
            self.evaluation_config["env"] = env
    if logger_creator is not None:
        self.logger_creator = logger_creator

    # A str `algo_class` is resolved through Tune's trainable registry.
    registered_or_given_cls = (
        get_trainable_cls(self.algo_class)
        if isinstance(self.algo_class, str)
        else self.algo_class
    )
    algo_config = copy.deepcopy(self) if use_copy else self
    return registered_or_given_cls(
        config=algo_config,
        logger_creator=self.logger_creator,
    )
def build_env_to_module_connector(self, env, device=None):
    """Builds the env-to-module ConnectorV2 pipeline for `env`.

    Any user-provided pieces come first. They are returned by the callable set
    via `AlgorithmConfig.env_runners(env_to_module_connector=..)`, which is
    called with `env` only. If
    `self.add_default_connectors_to_env_to_module_pipeline` is True, RLlib's
    default env-to-module pieces are appended after them.

    The pipeline's input spaces come from `env`: its `single_...` spaces if
    present, otherwise its plain spaces. In multi-agent setups where that
    yields None, a `gym.spaces.Dict` mapping each of
    `env.unwrapped.possible_agents` to its per-agent space is used instead.

    Args:
        env: The env to build the pipeline for.
        device: Optional device handed to the default `NumpyToTensor` piece.
            It is not forwarded to the user-provided builder.

    Returns:
        The constructed `EnvToModulePipeline`.

    Raises:
        ValueError: If the user-provided builder returns neither a ConnectorV2
            nor a list/tuple of ConnectorV2 pieces.
    """
    from ray.rllib.connectors.env_to_module import (
        AddObservationsFromEpisodesToBatch,
        AddStatesFromEpisodesToBatch,
        AddTimeDimToBatchAndZeroPad,
        AgentToModuleMapping,
        BatchIndividualItems,
        EnvToModulePipeline,
        NumpyToTensor,
    )

    custom_connectors = []
    if self._env_to_module_connector is not None:
        built = self._env_to_module_connector(env)

        from ray.rllib.connectors.connector_v2 import ConnectorV2

        if isinstance(built, ConnectorV2):
            # A single piece (or an entire pipeline).
            custom_connectors = [built]
        elif isinstance(built, (list, tuple)):
            # A sequence of individual pieces.
            custom_connectors = list(built)
        else:
            raise ValueError(
                "`AlgorithmConfig.env_runners(env_to_module_connector=..)` must "
                "return a ConnectorV2 object or a list thereof (to be added to a "
                f"pipeline)! Your function returned {built}."
            )

    def _resolve_space(single_attr, fallback, per_agent_getter):
        # Prefer the env's `single_...` space; otherwise use `fallback` (the
        # env's plain space). The per-agent getter is only looked up if the
        # multi-agent Dict actually has to be built.
        space = getattr(env, single_attr, fallback)
        if space is None and self.is_multi_agent:
            space = gym.spaces.Dict(
                {
                    aid: getattr(env, per_agent_getter)(aid)
                    for aid in env.unwrapped.possible_agents
                }
            )
        return space

    obs_space = _resolve_space(
        "single_observation_space", env.observation_space, "get_observation_space"
    )
    act_space = _resolve_space(
        "single_action_space", env.action_space, "get_action_space"
    )

    pipeline = EnvToModulePipeline(
        input_observation_space=obs_space,
        input_action_space=act_space,
        connectors=custom_connectors,
    )
    if not self.add_default_connectors_to_env_to_module_pipeline:
        return pipeline

    # Default pieces, appended behind any custom ones: observations, time-rank
    # handling, then STATE_IN/STATE_OUT handling.
    pipeline.append(AddObservationsFromEpisodesToBatch())
    pipeline.append(AddTimeDimToBatchAndZeroPad())
    pipeline.append(AddStatesFromEpisodesToBatch())
    # Multi-agent only: map AgentID-based data to ModuleID-based data.
    if self.is_multi_agent:
        module_specs = (
            self.rl_module_spec.rl_module_specs
            if isinstance(self.rl_module_spec, MultiRLModuleSpec)
            else set(self.policies)
        )
        pipeline.append(
            AgentToModuleMapping(
                rl_module_specs=module_specs,
                agent_to_module_mapping_fn=self.policy_mapping_fn,
            )
        )
    # Batch everything, then convert to tensors.
    pipeline.append(BatchIndividualItems(multi_agent=self.is_multi_agent))
    pipeline.append(NumpyToTensor(device=device))
    return pipeline
def build_module_to_env_connector(self, env):
    """Builds the module-to-env ConnectorV2 pipeline for `env`.

    The core of the pipeline is any user-provided pieces. They are returned by
    the callable set via `AlgorithmConfig.env_runners(module_to_env_connector=..)`,
    which is called with `env` only. If
    `self.add_default_connectors_to_module_to_env_pipeline` is True, RLlib's
    default pieces are placed around them, giving this final order:

        GetActions -> TensorToNumpy -> UnBatchToIndividualItems
        -> [ModuleToAgentUnmapping (multi-agent only)]
        -> RemoveSingleTsTimeRankFromBatch -> <custom pieces>
        -> NormalizeAndClipActions -> ListifyDataForVectorEnv

    The pipeline's input spaces come from `env`: its `single_...` spaces if
    present, otherwise its plain spaces. In multi-agent setups where that
    yields None, a `gym.spaces.Dict` over `env.unwrapped.possible_agents` is
    built instead.

    Args:
        env: The env to build the pipeline for.

    Returns:
        The constructed `ModuleToEnvPipeline`.

    Raises:
        ValueError: If the user-provided builder returns neither a ConnectorV2
            nor a list/tuple of ConnectorV2 pieces.
    """
    from ray.rllib.connectors.module_to_env import (
        GetActions,
        ListifyDataForVectorEnv,
        ModuleToAgentUnmapping,
        ModuleToEnvPipeline,
        NormalizeAndClipActions,
        RemoveSingleTsTimeRankFromBatch,
        TensorToNumpy,
        UnBatchToIndividualItems,
    )

    custom_connectors = []
    if self._module_to_env_connector is not None:
        built = self._module_to_env_connector(env)

        from ray.rllib.connectors.connector_v2 import ConnectorV2

        if isinstance(built, ConnectorV2):
            # A single piece (or an entire pipeline).
            custom_connectors = [built]
        elif isinstance(built, (list, tuple)):
            # A sequence of individual pieces.
            custom_connectors = list(built)
        else:
            raise ValueError(
                "`AlgorithmConfig.env_runners(module_to_env_connector=..)` must "
                "return a ConnectorV2 object or a list thereof (to be added to a "
                f"pipeline)! Your function returned {built}."
            )

    def _resolve_space(single_attr, fallback, per_agent_getter):
        # Prefer the env's `single_...` space; otherwise use `fallback` (the
        # env's plain space). The per-agent getter is only looked up if the
        # multi-agent Dict actually has to be built.
        space = getattr(env, single_attr, fallback)
        if space is None and self.is_multi_agent:
            space = gym.spaces.Dict(
                {
                    aid: getattr(env, per_agent_getter)(aid)
                    for aid in env.unwrapped.possible_agents
                }
            )
        return space

    obs_space = _resolve_space(
        "single_observation_space", env.observation_space, "get_observation_space"
    )
    act_space = _resolve_space(
        "single_action_space", env.action_space, "get_action_space"
    )

    pipeline = ModuleToEnvPipeline(
        input_observation_space=obs_space,
        input_action_space=act_space,
        connectors=custom_connectors,
    )
    if not self.add_default_connectors_to_module_to_env_pipeline:
        return pipeline

    # Front of the pipeline (plain data processing). Each `prepend` puts its
    # piece at the very front, so these calls run in reverse of the final order.
    pipeline.prepend(RemoveSingleTsTimeRankFromBatch())
    if self.is_multi_agent:
        # Map ModuleID-based data back to AgentID-based data.
        pipeline.prepend(ModuleToAgentUnmapping())
    pipeline.prepend(UnBatchToIndividualItems())
    pipeline.prepend(TensorToNumpy())
    # Samples actions from ACTION_DIST_INPUTS (if ACTIONS not present).
    pipeline.prepend(GetActions())

    # Back of the pipeline (action handling): unsquash/clip actions according
    # to the config and action space, then listify per env vector index.
    pipeline.append(
        NormalizeAndClipActions(
            normalize_actions=self.normalize_actions,
            clip_actions=self.clip_actions,
        )
    )
    pipeline.append(ListifyDataForVectorEnv())
    return pipeline
def build_learner_connector(
    self,
    input_observation_space,
    input_action_space,
    device=None,
):
    """Builds the Learner ConnectorV2 pipeline.

    Any user-provided pieces come first. They are returned by the callable set
    via `AlgorithmConfig.training(learner_connector=..)`, which is called with
    the two input spaces. If `self.add_default_connectors_to_learner_pipeline`
    is True, RLlib's default Learner pieces are appended after them.

    Args:
        input_observation_space: The pipeline's input observation space (also
            passed to the user-provided builder).
        input_action_space: The pipeline's input action space (also passed to
            the user-provided builder).
        device: Optional device handed to the default `NumpyToTensor` piece.
            Not (yet) forwarded to the user-provided builder.

    Returns:
        The constructed `LearnerConnectorPipeline`.

    Raises:
        ValueError: If the user-provided builder returns neither a ConnectorV2
            nor a list/tuple of ConnectorV2 pieces.
    """
    from ray.rllib.connectors.learner import (
        AddColumnsFromEpisodesToTrainBatch,
        AddObservationsFromEpisodesToBatch,
        AddStatesFromEpisodesToBatch,
        AddTimeDimToBatchAndZeroPad,
        AgentToModuleMapping,
        BatchIndividualItems,
        LearnerConnectorPipeline,
        NumpyToTensor,
    )

    custom_connectors = []
    if self._learner_connector is not None:
        built = self._learner_connector(
            input_observation_space,
            input_action_space,
            # device,  # TODO (sven): Also pass device into custom builder.
        )

        from ray.rllib.connectors.connector_v2 import ConnectorV2

        if isinstance(built, ConnectorV2):
            # A single piece (or an entire pipeline).
            custom_connectors = [built]
        elif isinstance(built, (list, tuple)):
            # A sequence of individual pieces.
            custom_connectors = list(built)
        else:
            raise ValueError(
                "`AlgorithmConfig.training(learner_connector=..)` must return "
                "a ConnectorV2 object or a list thereof (to be added to a "
                f"pipeline)! Your function returned {built}."
            )

    pipeline = LearnerConnectorPipeline(
        connectors=custom_connectors,
        input_observation_space=input_observation_space,
        input_action_space=input_action_space,
    )
    if not self.add_default_connectors_to_learner_pipeline:
        return pipeline

    # Default pieces, appended behind any custom ones. First observations, then
    # all other columns from the episodes.
    pipeline.append(AddObservationsFromEpisodesToBatch(as_learner_connector=True))
    pipeline.append(AddColumnsFromEpisodesToTrainBatch())
    # Time-rank handling, then STATE_IN/STATE_OUT handling.
    pipeline.append(AddTimeDimToBatchAndZeroPad(as_learner_connector=True))
    pipeline.append(AddStatesFromEpisodesToBatch(as_learner_connector=True))
    # Multi-agent only: map AgentID-based data to ModuleID-based data.
    if self.is_multi_agent:
        module_specs = (
            self.rl_module_spec.rl_module_specs
            if isinstance(self.rl_module_spec, MultiRLModuleSpec)
            else set(self.policies)
        )
        pipeline.append(
            AgentToModuleMapping(
                rl_module_specs=module_specs,
                agent_to_module_mapping_fn=self.policy_mapping_fn,
            )
        )
    # Batch everything, then convert to tensors.
    pipeline.append(BatchIndividualItems(multi_agent=self.is_multi_agent))
    pipeline.append(NumpyToTensor(as_learner_connector=True, device=device))
    return pipeline
[docs]
def build_learner_group(
    self,
    *,
    env: Optional[EnvType] = None,
    spaces: Optional[Dict[ModuleID, Tuple[gym.Space, gym.Space]]] = None,
    rl_module_spec: Optional[RLModuleSpecType] = None,
) -> "LearnerGroup":
    """Creates a new LearnerGroup from the settings in `self`.

    The LearnerGroup receives a copy of this config (`self.copy()`), not `self`.

    Args:
        env: An optional EnvType object (e.g. a gym.Env) useful for extracting
            space information for the to-be-constructed RLModule inside the
            LearnerGroup's Learner workers. Note that if RLlib cannot infer any
            space information either from this `env` arg, from the optional
            `spaces` arg or from `self`, the LearnerGroup cannot be created.
        spaces: An optional dict mapping ModuleIDs to
            (observation-space, action-space)-tuples for the to-be-constructed
            RLModule inside the LearnerGroup's Learner workers. Note that if
            RLlib cannot infer any space information either from this `spaces`
            arg, from the optional `env` arg or from `self`, the LearnerGroup
            cannot be created.
        rl_module_spec: An optional (single-agent or multi-agent) RLModuleSpec
            to use for the constructed LearnerGroup. If None, RLlib tries to
            infer the RLModuleSpec (via `self.get_multi_rl_module_spec()`) from
            the other information given and stored in this `AlgorithmConfig`.

    Returns:
        The newly created `LearnerGroup` object.
    """
    from ray.rllib.core.learner.learner_group import LearnerGroup

    # Without an explicit spec, derive a MultiRLModuleSpec from `env`/`spaces`
    # (plus the settings stored in `self`).
    module_spec = (
        rl_module_spec
        if rl_module_spec is not None
        else self.get_multi_rl_module_spec(env=env, spaces=spaces)
    )
    return LearnerGroup(config=self.copy(), module_spec=module_spec)
[docs]
def build_learner(
    self,
    *,
    env: Optional[EnvType] = None,
    spaces: Optional[Dict[ModuleID, Tuple[gym.Space, gym.Space]]] = None,
) -> "Learner":
    """Builds and returns a new Learner object based on settings in `self`.

    This Learner object already has its `build()` method called, meaning
    its RLModule is already constructed.

    Args:
        env: An optional EnvType object (e.g. a gym.Env) useful for extracting
            space information for the to-be-constructed RLModule inside the
            Learner. Note that if RLlib cannot infer any space information
            either from this `env` arg, from the optional `spaces` arg or from
            `self`, the Learner cannot be created.
        spaces: An optional dict mapping ModuleIDs to
            (observation-space, action-space)-tuples for the to-be-constructed
            RLModule inside the Learner. Note that if RLlib cannot infer any
            space information either from this `spaces` arg, from the optional
            `env` arg or from `self`, the Learner cannot be created.

    Returns:
        The newly created (and already built) Learner object.
    """
    # Only if `spaces` or `env` are provided: create a MultiRLModuleSpec first,
    # to be passed into the Learner constructor. Otherwise, `module_spec=None`
    # is passed and the Learner has to determine its spec itself (presumably
    # from the config).
    rl_module_spec = None
    if env is not None or spaces is not None:
        rl_module_spec = self.get_multi_rl_module_spec(env=env, spaces=spaces)
    # Construct the actual Learner object. Note: `self` (not a copy) is passed
    # as its config.
    learner = self.learner_class(config=self, module_spec=rl_module_spec)
    # `build()` the Learner (internal structures such as RLModule, etc..).
    learner.build()
    return learner
[docs]
def get_config_for_module(self, module_id: ModuleID) -> "AlgorithmConfig":
    """Returns the AlgorithmConfig to use for the given module ID.

    In a multi-agent setup, individual modules might override one or more
    AlgorithmConfig properties (e.g. `train_batch_size`, `lr`) using the
    `overrides()` method. This method returns a full AlgorithmConfig instance
    with those per-module overrides already applied.

    The module-specific config is created on first request (a copy of `self`,
    updated with the module's entry in
    `self.algorithm_config_overrides_per_module`) and cached in
    `self._per_module_overrides`. Later calls return the cached object.

    Args:
        module_id: The module ID for which to get the final AlgorithmConfig
            object.

    Returns:
        The module-specific AlgorithmConfig if overrides exist for `module_id`,
        otherwise `self`.
    """
    if module_id not in self._per_module_overrides:
        # No overrides defined for this module -> `self` applies unchanged.
        if module_id not in self.algorithm_config_overrides_per_module:
            return self
        # First request for this module: build its config and cache it.
        self._per_module_overrides[module_id] = self.copy().update_from_dict(
            self.algorithm_config_overrides_per_module[module_id]
        )
    return self._per_module_overrides[module_id]
[docs]
def python_environment(
    self,
    *,
    extra_python_environs_for_driver: Optional[dict] = NotProvided,
    extra_python_environs_for_worker: Optional[dict] = NotProvided,
) -> "AlgorithmConfig":
    """Configures extra python env vars for the driver and worker processes.

    Args:
        extra_python_environs_for_driver: Any extra python env vars to set in
            the algorithm's process, e.g., {"OMP_NUM_THREADS": "16"}.
        extra_python_environs_for_worker: Any extra python env vars to set in
            the worker processes.

    Returns:
        This updated AlgorithmConfig object.
    """
    # Only args that were actually passed overwrite the current settings.
    for attr_name, environs in (
        ("extra_python_environs_for_driver", extra_python_environs_for_driver),
        ("extra_python_environs_for_worker", extra_python_environs_for_worker),
    ):
        if environs is not NotProvided:
            setattr(self, attr_name, environs)
    return self
[docs]
def resources(
    self,
    *,
    num_cpus_for_main_process: Optional[int] = NotProvided,
    num_gpus: Optional[Union[float, int]] = NotProvided,  # @OldAPIStack
    _fake_gpus: Optional[bool] = NotProvided,  # @OldAPIStack
    placement_strategy: Optional[str] = NotProvided,
    # Deprecated args.
    num_cpus_per_worker=DEPRECATED_VALUE,  # moved to `env_runners`
    num_gpus_per_worker=DEPRECATED_VALUE,  # moved to `env_runners`
    custom_resources_per_worker=DEPRECATED_VALUE,  # moved to `env_runners`
    num_learner_workers=DEPRECATED_VALUE,  # moved to `learners`
    num_cpus_per_learner_worker=DEPRECATED_VALUE,  # moved to `learners`
    num_gpus_per_learner_worker=DEPRECATED_VALUE,  # moved to `learners`
    local_gpu_idx=DEPRECATED_VALUE,  # moved to `learners`
    num_cpus_for_local_worker=DEPRECATED_VALUE,
) -> "AlgorithmConfig":
    """Specifies resources allocated for an Algorithm and its ray actors/workers.

    Args:
        num_cpus_for_main_process: Number of CPUs to allocate for the main
            algorithm process that runs `Algorithm.training_step()`.
            Note: This is only relevant when running RLlib through Tune.
            Otherwise, `Algorithm.training_step()` runs in the main program
            (driver).
        num_gpus: Number of GPUs to allocate to the algorithm process.
            Note that not all algorithms can take advantage of GPUs.
            Support for multi-GPU is currently only available for
            tf-[PPO/IMPALA/DQN/PG]. This can be fractional (e.g., 0.3 GPUs).
        _fake_gpus: Set to True for debugging (multi-)GPU functionality on a
            CPU machine. GPU towers are simulated by graphs located on
            CPUs in this case. Use `num_gpus` to test for different numbers of
            fake GPUs.
        placement_strategy: The strategy for the placement group factory
            returned by `Algorithm.default_resource_request()`. A PlacementGroup
            defines which devices (resources) should always be co-located on
            the same node. For example, an Algorithm with 2 EnvRunners and
            1 Learner (with 1 GPU) requests a placement group with the bundles:
            [{"cpu": 1}, {"gpu": 1, "cpu": 1}, {"cpu": 1}, {"cpu": 1}], where
            the first bundle is for the local (main Algorithm) process, the
            second one for the 1 Learner worker and the last 2 bundles are for
            the two EnvRunners. These bundles can now be "placed" on the same
            or different nodes depending on the value of `placement_strategy`:
            "PACK": Packs bundles into as few nodes as possible.
            "SPREAD": Places bundles across distinct nodes as evenly as possible.
            "STRICT_PACK": Packs bundles into one node. The group is not allowed
            to span multiple nodes.
            "STRICT_SPREAD": Places each bundle on a separate node.
        num_cpus_per_worker, num_gpus_per_worker, custom_resources_per_worker,
        num_learner_workers, num_cpus_per_learner_worker,
        num_gpus_per_learner_worker, local_gpu_idx, num_cpus_for_local_worker:
            Deprecated. Each one still takes effect. It triggers a non-fatal
            deprecation warning (`error=False`) and is stored under the new
            setting named in that warning.

    Returns:
        This updated AlgorithmConfig object.
    """
    if num_cpus_per_worker != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.resources(num_cpus_per_worker)",
            new="AlgorithmConfig.env_runners(num_cpus_per_env_runner)",
            error=False,
        )
        self.num_cpus_per_env_runner = num_cpus_per_worker

    if num_gpus_per_worker != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.resources(num_gpus_per_worker)",
            new="AlgorithmConfig.env_runners(num_gpus_per_env_runner)",
            error=False,
        )
        self.num_gpus_per_env_runner = num_gpus_per_worker

    if custom_resources_per_worker != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.resources(custom_resources_per_worker)",
            new="AlgorithmConfig.env_runners(custom_resources_per_env_runner)",
            error=False,
        )
        self.custom_resources_per_env_runner = custom_resources_per_worker

    if num_learner_workers != DEPRECATED_VALUE:
        # The replacement setting is `num_learners` (the attribute set below).
        deprecation_warning(
            old="AlgorithmConfig.resources(num_learner_workers)",
            new="AlgorithmConfig.learners(num_learners)",
            error=False,
        )
        self.num_learners = num_learner_workers

    if num_cpus_per_learner_worker != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.resources(num_cpus_per_learner_worker)",
            new="AlgorithmConfig.learners(num_cpus_per_learner)",
            error=False,
        )
        self.num_cpus_per_learner = num_cpus_per_learner_worker

    if num_gpus_per_learner_worker != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.resources(num_gpus_per_learner_worker)",
            new="AlgorithmConfig.learners(num_gpus_per_learner)",
            error=False,
        )
        self.num_gpus_per_learner = num_gpus_per_learner_worker

    if local_gpu_idx != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.resources(local_gpu_idx)",
            new="AlgorithmConfig.learners(local_gpu_idx)",
            error=False,
        )
        self.local_gpu_idx = local_gpu_idx

    if num_cpus_for_local_worker != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.resources(num_cpus_for_local_worker)",
            new="AlgorithmConfig.resources(num_cpus_for_main_process)",
            error=False,
        )
        self.num_cpus_for_main_process = num_cpus_for_local_worker

    # New-style args are applied last, so e.g. an explicitly passed
    # `num_cpus_for_main_process` wins over the deprecated
    # `num_cpus_for_local_worker`.
    if num_cpus_for_main_process is not NotProvided:
        self.num_cpus_for_main_process = num_cpus_for_main_process
    if num_gpus is not NotProvided:
        self.num_gpus = num_gpus
    if _fake_gpus is not NotProvided:
        self._fake_gpus = _fake_gpus
    if placement_strategy is not NotProvided:
        self.placement_strategy = placement_strategy

    return self
[docs]
def framework(
    self,
    framework: Optional[str] = NotProvided,
    *,
    eager_tracing: Optional[bool] = NotProvided,
    eager_max_retraces: Optional[int] = NotProvided,
    tf_session_args: Optional[Dict[str, Any]] = NotProvided,
    local_tf_session_args: Optional[Dict[str, Any]] = NotProvided,
    torch_compile_learner: Optional[bool] = NotProvided,
    torch_compile_learner_what_to_compile: Optional[str] = NotProvided,
    torch_compile_learner_dynamo_mode: Optional[str] = NotProvided,
    torch_compile_learner_dynamo_backend: Optional[str] = NotProvided,
    torch_compile_worker: Optional[bool] = NotProvided,
    torch_compile_worker_dynamo_backend: Optional[str] = NotProvided,
    torch_compile_worker_dynamo_mode: Optional[str] = NotProvided,
    torch_ddp_kwargs: Optional[Dict[str, Any]] = NotProvided,
    torch_skip_nan_gradients: Optional[bool] = NotProvided,
) -> "AlgorithmConfig":
    """Sets the config's deep-learning framework settings.

    Only explicitly passed args change the config. `framework` is stored as
    `self.framework_str`. Every other arg is stored under the attribute of the
    same name.

    Args:
        framework: torch: PyTorch; tf2: TensorFlow 2.x (eager execution or
            traced if eager_tracing=True); tf: TensorFlow (static-graph). The
            deprecated value "tfe" triggers a deprecation error (`error=True`).
        eager_tracing: Enable tracing in eager mode. This greatly improves
            performance (speedup ~2x), but makes it slightly harder to debug
            since Python code won't be evaluated after the initial eager pass.
            Only possible if framework=tf2.
        eager_max_retraces: Maximum number of tf.function re-traces before a
            runtime error is raised. This is to prevent unnoticed retraces of
            methods inside the `..._eager_traced` Policy, which could slow down
            execution by a factor of 4, without the user noticing what the root
            cause for this slowdown could be.
            Only necessary for framework=tf2.
            Set to None to ignore the re-trace count and never throw an error.
        tf_session_args: Configures TF for single-process operation by default.
        local_tf_session_args: Override the following tf session args on the
            local worker.
        torch_compile_learner: If True, forward_train methods on TorchRLModules
            on the learner are compiled. If not specified, the default is to
            compile forward train on the learner.
        torch_compile_learner_what_to_compile: A TorchCompileWhatToCompile
            mode specifying what to compile on the learner side if
            torch_compile_learner is True. See TorchCompileWhatToCompile for
            details and advice on its usage.
        torch_compile_learner_dynamo_mode: The torch dynamo mode to use on the
            learner.
        torch_compile_learner_dynamo_backend: The torch dynamo backend to use on
            the learner.
        torch_compile_worker: If True, forward exploration and inference methods
            on TorchRLModules on the workers are compiled. If not specified,
            the default is to not compile forward methods on the workers because
            retracing can be expensive.
        torch_compile_worker_dynamo_backend: The torch dynamo backend to use on
            the workers.
        torch_compile_worker_dynamo_mode: The torch dynamo mode to use on the
            workers.
        torch_ddp_kwargs: The kwargs to pass into
            `torch.nn.parallel.DistributedDataParallel` when using
            `num_learners > 1`. This is specifically helpful when searching for
            unused parameters that are not used in the backward pass. This can
            give hints for errors in custom models where some parameters do not
            get touched in the backward pass although they should.
        torch_skip_nan_gradients: If updates with `nan` gradients should be
            entirely skipped. This skips updates in the optimizer entirely if
            they contain any `nan` gradient. This can help to avoid biasing
            moving-average based optimizers - like Adam. This can help in
            training phases where policy updates can be highly unstable such as
            during the early stages of training or with highly exploratory
            policies. In such phases many gradients might turn `nan` and setting
            them to zero could corrupt the optimizer's internal state. The
            default is `False` and turns `nan` gradients to zero. If many `nan`
            gradients are encountered consider (a) monitoring gradients by
            setting `log_gradients` in `AlgorithmConfig` to `True`, (b) use
            proper weight initialization (e.g. Xavier, Kaiming) via the model
            config in `AlgorithmConfig.rl_module` and/or (c) gradient clipping
            via `grad_clip` in `AlgorithmConfig.training`.

    Returns:
        This updated AlgorithmConfig object.
    """
    if framework is not NotProvided:
        if framework == "tfe":
            deprecation_warning(
                old="AlgorithmConfig.framework('tfe')",
                new="AlgorithmConfig.framework('tf2')",
                error=True,
            )
        self.framework_str = framework

    # All remaining args map 1:1 onto same-named attributes; they are applied
    # in this (insertion) order.
    settings = {
        "eager_tracing": eager_tracing,
        "eager_max_retraces": eager_max_retraces,
        "tf_session_args": tf_session_args,
        "local_tf_session_args": local_tf_session_args,
        "torch_compile_learner": torch_compile_learner,
        "torch_compile_learner_dynamo_backend": torch_compile_learner_dynamo_backend,
        "torch_compile_learner_dynamo_mode": torch_compile_learner_dynamo_mode,
        "torch_compile_learner_what_to_compile": (
            torch_compile_learner_what_to_compile
        ),
        "torch_compile_worker": torch_compile_worker,
        "torch_compile_worker_dynamo_backend": torch_compile_worker_dynamo_backend,
        "torch_compile_worker_dynamo_mode": torch_compile_worker_dynamo_mode,
        "torch_ddp_kwargs": torch_ddp_kwargs,
        "torch_skip_nan_gradients": torch_skip_nan_gradients,
    }
    for attr_name, value in settings.items():
        if value is not NotProvided:
            setattr(self, attr_name, value)
    return self
[docs]
def api_stack(
    self,
    enable_rl_module_and_learner: Optional[bool] = NotProvided,
    enable_env_runner_and_connector_v2: Optional[bool] = NotProvided,
) -> "AlgorithmConfig":
    """Sets the config's API stack settings.

    Switching `enable_rl_module_and_learner` to True stashes away any existing
    `exploration_config` and clears it. Switching it back to False, while no
    exploration config is set, restores the stashed one (or logs a warning if
    there is none).

    Args:
        enable_rl_module_and_learner: Enables the usage of `RLModule` (instead
            of `ModelV2`) and Learner (instead of the training-related parts of
            `Policy`). Must be used with
            `enable_env_runner_and_connector_v2=True`. Together, these two
            settings activate the "new API stack" of RLlib.
        enable_env_runner_and_connector_v2: Enables the usage of EnvRunners
            (SingleAgentEnvRunner and MultiAgentEnvRunner) and ConnectorV2.
            When setting this to True, `enable_rl_module_and_learner` must be
            True as well. Together, these two settings activate the "new API
            stack" of RLlib.

    Returns:
        This updated AlgorithmConfig object.
    """
    if enable_rl_module_and_learner is not NotProvided:
        self.enable_rl_module_and_learner = enable_rl_module_and_learner

        if enable_rl_module_and_learner is True:
            # Stash the (old API stack) exploration config and clear it.
            if self.exploration_config:
                self._prior_exploration_config = self.exploration_config
                self.exploration_config = {}
        elif enable_rl_module_and_learner is False and not self.exploration_config:
            # Try to bring back a previously stashed exploration config.
            stashed = self._prior_exploration_config
            if stashed is None:
                logger.warning(
                    "config.enable_rl_module_and_learner was set to False, but no "
                    "prior exploration config was found to be restored."
                )
            else:
                self.exploration_config = stashed
                self._prior_exploration_config = None

    if enable_env_runner_and_connector_v2 is not NotProvided:
        self.enable_env_runner_and_connector_v2 = enable_env_runner_and_connector_v2
    return self
[docs]
def environment(
    self,
    env: Optional[Union[str, EnvType]] = NotProvided,
    *,
    env_config: Optional[EnvConfigDict] = NotProvided,
    observation_space: Optional[gym.spaces.Space] = NotProvided,
    action_space: Optional[gym.spaces.Space] = NotProvided,
    render_env: Optional[bool] = NotProvided,
    clip_rewards: Optional[Union[bool, float]] = NotProvided,
    normalize_actions: Optional[bool] = NotProvided,
    clip_actions: Optional[bool] = NotProvided,
    disable_env_checking: Optional[bool] = NotProvided,
    is_atari: Optional[bool] = NotProvided,
    action_mask_key: Optional[str] = NotProvided,
    # Deprecated args.
    env_task_fn=DEPRECATED_VALUE,
) -> "AlgorithmConfig":
    """Sets the config's RL-environment settings.

    Only explicitly passed args change the config. `env_config` is merged into
    the existing `self.env_config` (via `deep_update`) rather than replacing it.
    `is_atari` is stored as `self._is_atari`.

    Args:
        env: The environment specifier. This can either be a tune-registered
            env, via `tune.register_env([name], lambda env_ctx: [env object])`,
            or a string specifier of an RLlib supported type. In the latter
            case, RLlib tries to interpret the specifier as either a
            Farama-Foundation gymnasium env, a PyBullet env, or a fully
            qualified classpath to an Env class, e.g.
            "ray.rllib.examples.envs.classes.random_env.RandomEnv".
        env_config: Arguments dict passed to the env creator as an EnvContext
            object (which is a dict plus the properties: `num_env_runners`,
            `worker_index`, `vector_index`, and `remote`).
        observation_space: The observation space for the Policies of this
            Algorithm.
        action_space: The action space for the Policies of this Algorithm.
        render_env: If True, try to render the environment on the local worker
            or on worker 1 (if num_env_runners > 0). For vectorized envs, this
            usually means that only the first sub-environment is rendered.
            In order for this to work, your env has to implement the
            `render()` method which either:
            a) handles window generation and rendering itself (returning True)
            or b) returns a numpy uint8 image of shape
            [height x width x 3 (RGB)].
        clip_rewards: Whether to clip rewards during Policy's postprocessing.
            None (default): Clip for Atari only (r=sign(r)).
            True: r=sign(r): Fixed rewards -1.0, 1.0, or 0.0.
            False: Never clip.
            [float value]: Clip at -value and + value.
            Tuple[value1, value2]: Clip at value1 and value2.
        normalize_actions: If True, RLlib learns entirely inside a normalized
            action space (0.0 centered with small stddev; only affecting Box
            components). RLlib unsquashes actions (and clips them, just in
            case) to the bounds of the env's action space before sending
            actions back to the env.
        clip_actions: If True, the RLlib default ModuleToEnv connector clips
            actions according to the env's bounds (before sending them into the
            `env.step()` call).
        disable_env_checking: Disable RLlib's env checks after a gymnasium.Env
            instance has been constructed in an EnvRunner. Note that the checks
            include an `env.reset()` and `env.step()` (with a random action),
            which might tinker with your env's logic and behavior and thus
            negatively influence sample collection- and/or learning behavior.
        is_atari: This config can be used to explicitly specify whether the env
            is an Atari env or not. If not specified, RLlib tries to auto-detect
            this.
        action_mask_key: If observation is a dictionary, expect the value by
            the key `action_mask_key` to contain a valid actions mask
            (`numpy.int8` array of zeros and ones). Defaults to "action_mask".
        env_task_fn: Deprecated; passing it triggers a deprecation error
            (`error=True`).

    Returns:
        This updated AlgorithmConfig object.
    """
    if env_task_fn != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.environment(env_task_fn=..)",
            error=True,
        )

    if env is not NotProvided:
        self.env = env
    if env_config is not NotProvided:
        # Merge into (rather than assign over) the existing `self.env_config`;
        # the return value is unused, so `deep_update` presumably works in place.
        deep_update(self.env_config, env_config, True)

    # All remaining args are plain attribute assignments, applied in this order.
    for attr_name, value in (
        ("observation_space", observation_space),
        ("action_space", action_space),
        ("render_env", render_env),
        ("clip_rewards", clip_rewards),
        ("normalize_actions", normalize_actions),
        ("clip_actions", clip_actions),
        ("disable_env_checking", disable_env_checking),
        ("_is_atari", is_atari),
        ("action_mask_key", action_mask_key),
    ):
        if value is not NotProvided:
            setattr(self, attr_name, value)
    return self
[docs]
def env_runners(
    self,
    *,
    env_runner_cls: Optional[type] = NotProvided,
    num_env_runners: Optional[int] = NotProvided,
    num_envs_per_env_runner: Optional[int] = NotProvided,
    gym_env_vectorize_mode: Optional[str] = NotProvided,
    num_cpus_per_env_runner: Optional[int] = NotProvided,
    num_gpus_per_env_runner: Optional[Union[float, int]] = NotProvided,
    custom_resources_per_env_runner: Optional[dict] = NotProvided,
    validate_env_runners_after_construction: Optional[bool] = NotProvided,
    sample_timeout_s: Optional[float] = NotProvided,
    max_requests_in_flight_per_env_runner: Optional[int] = NotProvided,
    env_to_module_connector: Optional[
        Callable[[EnvType], Union["ConnectorV2", List["ConnectorV2"]]]
    ] = NotProvided,
    module_to_env_connector: Optional[
        Callable[[EnvType, "RLModule"], Union["ConnectorV2", List["ConnectorV2"]]]
    ] = NotProvided,
    add_default_connectors_to_env_to_module_pipeline: Optional[bool] = NotProvided,
    add_default_connectors_to_module_to_env_pipeline: Optional[bool] = NotProvided,
    episode_lookback_horizon: Optional[int] = NotProvided,
    use_worker_filter_stats: Optional[bool] = NotProvided,
    update_worker_filter_stats: Optional[bool] = NotProvided,
    compress_observations: Optional[bool] = NotProvided,
    rollout_fragment_length: Optional[Union[int, str]] = NotProvided,
    batch_mode: Optional[str] = NotProvided,
    explore: Optional[bool] = NotProvided,
    episodes_to_numpy: Optional[bool] = NotProvided,
    # @OldAPIStack settings.
    exploration_config: Optional[dict] = NotProvided,  # @OldAPIStack
    create_env_on_local_worker: Optional[bool] = NotProvided,  # @OldAPIStack
    sample_collector: Optional[Type[SampleCollector]] = NotProvided,  # @OldAPIStack
    remote_worker_envs: Optional[bool] = NotProvided,  # @OldAPIStack
    remote_env_batch_wait_ms: Optional[float] = NotProvided,  # @OldAPIStack
    preprocessor_pref: Optional[str] = NotProvided,  # @OldAPIStack
    observation_filter: Optional[str] = NotProvided,  # @OldAPIStack
    enable_tf1_exec_eagerly: Optional[bool] = NotProvided,  # @OldAPIStack
    sampler_perf_stats_ema_coef: Optional[float] = NotProvided,  # @OldAPIStack
    # Deprecated args.
    num_rollout_workers=DEPRECATED_VALUE,
    num_envs_per_worker=DEPRECATED_VALUE,
    validate_workers_after_construction=DEPRECATED_VALUE,
    ignore_worker_failures=DEPRECATED_VALUE,
    recreate_failed_workers=DEPRECATED_VALUE,
    restart_failed_sub_environments=DEPRECATED_VALUE,
    num_consecutive_worker_failures_tolerance=DEPRECATED_VALUE,
    worker_health_probe_timeout_s=DEPRECATED_VALUE,
    worker_restore_timeout_s=DEPRECATED_VALUE,
    synchronize_filter=DEPRECATED_VALUE,
    enable_connectors=DEPRECATED_VALUE,
) -> "AlgorithmConfig":
    """Sets the EnvRunner (rollout worker) related configuration.

    Only arguments that are explicitly passed (i.e. not `NotProvided`) are
    written to this config. All deprecated arguments and all argument values
    are checked before any attribute is modified, so a call that errors out
    leaves this config unchanged.

    Args:
        env_runner_cls: The EnvRunner class to use for environment rollouts (data
            collection).
        num_env_runners: Number of EnvRunner actors to create for parallel sampling.
            Setting this to 0 forces sampling to be done in the local
            EnvRunner (main process or the Algorithm's actor when using Tune).
        num_envs_per_env_runner: Number of environments to step through
            (vector-wise) per EnvRunner. This enables batching when computing
            actions through RLModule inference, which can improve performance
            for inference-bottlenecked workloads. Must be > 0.
        gym_env_vectorize_mode: The gymnasium vectorization mode for vector envs.
            Must be a `gymnasium.envs.registration.VectorizeMode` (enum) value.
            Default is SYNC. Set this to ASYNC to parallelize the individual sub
            environments within the vector. This can speed up your EnvRunners
            significantly when using heavier environments.
        num_cpus_per_env_runner: Number of CPUs to allocate per EnvRunner.
        num_gpus_per_env_runner: Number of GPUs to allocate per EnvRunner. This can
            be fractional. This is usually needed only if your env itself requires a
            GPU (i.e., it is a GPU-intensive video game), or model inference is
            unusually expensive.
        custom_resources_per_env_runner: Any custom Ray resources to allocate per
            EnvRunner.
        validate_env_runners_after_construction: Whether to validate that each
            created remote EnvRunner is healthy after its construction process.
        sample_timeout_s: The timeout in seconds for calling `sample()` on remote
            EnvRunner workers. Results (episode list) from workers that take longer
            than this time are discarded. Only used by algorithms that sample
            synchronously in turn with their update step (e.g., PPO or DQN). Not
            relevant for any algos that sample asynchronously, such as APPO or
            IMPALA.
        max_requests_in_flight_per_env_runner: Max number of in-flight requests
            to each EnvRunner (actor). See the
            `ray.rllib.utils.actor_manager.FaultTolerantActorManager` class for more
            details.
            Tuning these values is important when running experiments with
            large sample batches, where there is the risk that the object store may
            fill up, causing spilling of objects to disk. This can cause any
            asynchronous requests to become very slow, making your experiment run
            slowly as well. You can inspect the object store during your experiment
            through a call to `ray memory` on your head node, and by using the Ray
            dashboard. If you're seeing that the object store is filling up,
            turn down the number of remote requests in flight or enable compression
            or increase the object store memory through, for example:
            `ray.init(object_store_memory=10 * 1024 * 1024 * 1024)  # =10 GB`
        env_to_module_connector: A callable taking an Env as input arg and returning
            an env-to-module ConnectorV2 (might be a pipeline) object.
        module_to_env_connector: A callable taking an Env and an RLModule as input
            args and returning a module-to-env ConnectorV2 (might be a pipeline)
            object.
        add_default_connectors_to_env_to_module_pipeline: If True (default), RLlib's
            EnvRunners automatically add the default env-to-module ConnectorV2
            pieces to the EnvToModulePipeline. These automatically perform adding
            observations and states (in case of stateful Module(s)), agent-to-module
            mapping, batching, and conversion to tensor data. Only if you know
            exactly what you are doing, you should set this setting to False.
            Note that this setting is only relevant if the new API stack is used
            (including the new EnvRunner classes).
        add_default_connectors_to_module_to_env_pipeline: If True (default), RLlib's
            EnvRunners automatically add the default module-to-env ConnectorV2
            pieces to the ModuleToEnvPipeline. These automatically perform removing
            the additional time-rank (if applicable, in case of stateful
            Module(s)), module-to-agent unmapping, un-batching (to lists), and
            conversion from tensor data to numpy. Only if you know exactly what you
            are doing, you should set this setting to False.
            Note that this setting is only relevant if the new API stack is used
            (including the new EnvRunner classes).
        episode_lookback_horizon: The amount of data (in timesteps) to keep from the
            preceding episode chunk when a new chunk (for the same episode) is
            generated to continue sampling at a later time. The larger this value,
            the more an env-to-module connector can look back in time
            and compile RLModule input data from this information. For example, if
            your custom env-to-module connector (and your custom RLModule) requires
            the previous 10 rewards as inputs, you must set this to at least 10.
        use_worker_filter_stats: Whether to use the workers in the EnvRunnerGroup to
            update the central filters (held by the local worker). If False, stats
            from the workers aren't used and are discarded.
        update_worker_filter_stats: Whether to push filter updates from the central
            filters (held by the local worker) to the remote workers' filters.
            Setting this to True might be useful within the evaluation config in
            order to disable the usage of evaluation trajectories for synching
            the central filter (used for training).
        compress_observations: Whether to LZ4 compress individual observations
            in the SampleBatches collected during rollouts.
        rollout_fragment_length: Divide episodes into fragments of this many steps
            each during sampling. Trajectories of this size are collected from
            EnvRunners and combined into a larger batch of `train_batch_size`
            for learning. Must be an int > 0 or "auto".
            For example, given rollout_fragment_length=100 and
            train_batch_size=1000:
            1. RLlib collects 10 fragments of 100 steps each from rollout workers.
            2. These fragments are concatenated and we perform an epoch of SGD.
            When using multiple envs per worker, the fragment size is multiplied by
            `num_envs_per_env_runner`. This is since we are collecting steps from
            multiple envs in parallel. For example, if num_envs_per_env_runner=5,
            then EnvRunners return experiences in chunks of 5*100 = 500 steps.
            The dataflow here can vary per algorithm. For example, PPO further
            divides the train batch into minibatches for multi-epoch SGD.
            Set `rollout_fragment_length` to "auto" to have RLlib compute an exact
            value to match the given batch size.
        batch_mode: How to build individual batches with the EnvRunner(s). Batches
            coming from distributed EnvRunners are usually concat'd to form the
            train batch. Note that "steps" below can mean different things (either
            env- or agent-steps) and depends on the `count_steps_by` setting,
            adjustable via `AlgorithmConfig.multi_agent(count_steps_by=..)`:
            1) "truncate_episodes": Each call to `EnvRunner.sample()` returns a
            batch of at most `rollout_fragment_length * num_envs_per_env_runner` in
            size. The batch is exactly
            `rollout_fragment_length * num_envs_per_env_runner` in size if
            postprocessing does not change batch sizes. Episodes may be truncated
            in order to meet this size requirement.
            This mode guarantees evenly sized batches, but increases
            variance as the future return must now be estimated at truncation
            boundaries.
            2) "complete_episodes": Each call to `EnvRunner.sample()` returns a
            batch of at least `rollout_fragment_length * num_envs_per_env_runner` in
            size. Episodes aren't truncated, but multiple episodes
            may be packed within one batch to meet the (minimum) batch size.
            Note that when `num_envs_per_env_runner > 1`, episode steps are
            buffered until the episode completes, and hence batches may contain
            significant amounts of off-policy data.
        explore: Default exploration behavior, iff `explore=None` is passed into
            compute_action(s). Set to False for no exploration behavior (e.g.,
            for evaluation).
        episodes_to_numpy: Whether to numpy'ize episodes before
            returning them from an EnvRunner. False by default. If True, EnvRunners
            call `to_numpy()` on those episode (chunks) to be returned by
            `EnvRunners.sample()`.
        exploration_config: A dict specifying the Exploration object's config.
            If the `type` key changes, the entire stored config is replaced;
            otherwise the given dict is deep-merged into the stored one.
        create_env_on_local_worker: When `num_env_runners` > 0, the driver
            (local_worker; worker-idx=0) does not need an environment. This is
            because it doesn't have to sample (done by remote_workers;
            worker_indices > 0) nor evaluate (done by evaluation workers;
            see below).
        sample_collector: For the old API stack only. The SampleCollector class to
            be used to collect and retrieve environment-, model-, and sampler data.
            Override the SampleCollector base class to implement your own
            collection/buffering/retrieval logic.
        remote_worker_envs: If using num_envs_per_env_runner > 1, whether to create
            those new envs in remote processes instead of in the same worker.
            This adds overheads, but can make sense if your envs can take much
            time to step / reset (e.g., for StarCraft). Use this cautiously;
            overheads are significant.
        remote_env_batch_wait_ms: Timeout that remote workers are waiting when
            polling environments. 0 (continue when at least one env is ready) is
            a reasonable default, but optimal value could be obtained by measuring
            your environment step / reset and model inference perf.
        preprocessor_pref: Whether to use "rllib" or "deepmind" preprocessors by
            default. Set to None for using no preprocessor. In this case, the
            model has to handle possibly complex observations from the
            environment.
        observation_filter: Element-wise observation filter, either "NoFilter"
            or "MeanStdFilter".
        enable_tf1_exec_eagerly: Explicitly tells the rollout worker to enable
            TF eager execution. This is useful for example when framework is
            "torch", but a TF2 policy needs to be restored for evaluation or
            league-based purposes.
        sampler_perf_stats_ema_coef: If specified, perf stats are in EMAs. This
            is the coeff of how much new data points contribute to the averages.
            Default is None, which uses simple global average instead.
            The EMA update rule is: updated = (1 - ema_coef) * old + ema_coef * new

    Deprecated args:
        `enable_connectors` only triggers a (non-fatal) deprecation warning.
        All other deprecated args trigger `deprecation_warning(..., error=True)`.

    Returns:
        This updated AlgorithmConfig object.

    Raises:
        ValueError: If `num_envs_per_env_runner` is <= 0, if
            `rollout_fragment_length` is neither an int > 0 nor "auto", or if
            `batch_mode` is not one of "truncate_episodes"/"complete_episodes".
    """
    # ---------------------------------------------------------------------
    # 1) Deprecated arguments.
    # Checked before ANY attribute is written, so that a call rejected by
    # `deprecation_warning(..., error=True)` leaves this config unchanged.
    # ---------------------------------------------------------------------
    if enable_connectors != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.env_runners(enable_connectors=...)",
            error=False,
        )
    if num_rollout_workers != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.env_runners(num_rollout_workers)",
            new="AlgorithmConfig.env_runners(num_env_runners)",
            error=True,
        )
    if num_envs_per_worker != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.env_runners(num_envs_per_worker)",
            new="AlgorithmConfig.env_runners(num_envs_per_env_runner)",
            error=True,
        )
    if validate_workers_after_construction != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.env_runners(validate_workers_after_construction)",
            new="AlgorithmConfig.env_runners(validate_env_runners_after_"
            "construction)",
            error=True,
        )
    if synchronize_filter != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.env_runners(synchronize_filter=..)",
            new="AlgorithmConfig.env_runners(update_worker_filter_stats=..)",
            error=True,
        )
    if ignore_worker_failures != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.env_runners(ignore_worker_failures=..)",
            error=True,
        )
    if recreate_failed_workers != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.env_runners(recreate_failed_workers=..)",
            new="AlgorithmConfig.fault_tolerance(recreate_failed_workers=..)",
            error=True,
        )
    if restart_failed_sub_environments != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.env_runners(restart_failed_sub_environments=..)",
            new=(
                "AlgorithmConfig.fault_tolerance("
                "restart_failed_sub_environments=..)"
            ),
            error=True,
        )
    if num_consecutive_worker_failures_tolerance != DEPRECATED_VALUE:
        deprecation_warning(
            old=(
                "AlgorithmConfig.env_runners("
                "num_consecutive_worker_failures_tolerance=..)"
            ),
            new=(
                "AlgorithmConfig.fault_tolerance("
                "num_consecutive_worker_failures_tolerance=..)"
            ),
            error=True,
        )
    if worker_health_probe_timeout_s != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.env_runners(worker_health_probe_timeout_s=..)",
            new="AlgorithmConfig.fault_tolerance(worker_health_probe_timeout_s=..)",
            error=True,
        )
    if worker_restore_timeout_s != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.env_runners(worker_restore_timeout_s=..)",
            new="AlgorithmConfig.fault_tolerance(worker_restore_timeout_s=..)",
            error=True,
        )

    # ---------------------------------------------------------------------
    # 2) Value validation (also before any attribute is written).
    # ---------------------------------------------------------------------
    if num_envs_per_env_runner is not NotProvided and num_envs_per_env_runner <= 0:
        raise ValueError(
            f"`num_envs_per_env_runner` ({num_envs_per_env_runner}) must be "
            "larger 0!"
        )
    if rollout_fragment_length is not NotProvided and not (
        (isinstance(rollout_fragment_length, int) and rollout_fragment_length > 0)
        or rollout_fragment_length == "auto"
    ):
        raise ValueError("`rollout_fragment_length` must be int >0 or 'auto'!")
    if batch_mode is not NotProvided and batch_mode not in [
        "truncate_episodes",
        "complete_episodes",
    ]:
        raise ValueError(
            f"`batch_mode` ({batch_mode}) must be one of [truncate_episodes|"
            "complete_episodes]!"
        )

    # ---------------------------------------------------------------------
    # 3) Apply all explicitly provided settings.
    # ---------------------------------------------------------------------
    if env_runner_cls is not NotProvided:
        self.env_runner_cls = env_runner_cls
    if num_env_runners is not NotProvided:
        self.num_env_runners = num_env_runners
    if num_envs_per_env_runner is not NotProvided:
        self.num_envs_per_env_runner = num_envs_per_env_runner
    if gym_env_vectorize_mode is not NotProvided:
        self.gym_env_vectorize_mode = gym_env_vectorize_mode
    if num_cpus_per_env_runner is not NotProvided:
        self.num_cpus_per_env_runner = num_cpus_per_env_runner
    if num_gpus_per_env_runner is not NotProvided:
        self.num_gpus_per_env_runner = num_gpus_per_env_runner
    if custom_resources_per_env_runner is not NotProvided:
        self.custom_resources_per_env_runner = custom_resources_per_env_runner
    if sample_timeout_s is not NotProvided:
        self.sample_timeout_s = sample_timeout_s
    if max_requests_in_flight_per_env_runner is not NotProvided:
        self.max_requests_in_flight_per_env_runner = (
            max_requests_in_flight_per_env_runner
        )
    if sample_collector is not NotProvided:
        self.sample_collector = sample_collector
    if create_env_on_local_worker is not NotProvided:
        self.create_env_on_local_worker = create_env_on_local_worker
    if env_to_module_connector is not NotProvided:
        self._env_to_module_connector = env_to_module_connector
    if module_to_env_connector is not NotProvided:
        self._module_to_env_connector = module_to_env_connector
    if add_default_connectors_to_env_to_module_pipeline is not NotProvided:
        self.add_default_connectors_to_env_to_module_pipeline = (
            add_default_connectors_to_env_to_module_pipeline
        )
    if add_default_connectors_to_module_to_env_pipeline is not NotProvided:
        self.add_default_connectors_to_module_to_env_pipeline = (
            add_default_connectors_to_module_to_env_pipeline
        )
    if episode_lookback_horizon is not NotProvided:
        self.episode_lookback_horizon = episode_lookback_horizon
    if use_worker_filter_stats is not NotProvided:
        self.use_worker_filter_stats = use_worker_filter_stats
    if update_worker_filter_stats is not NotProvided:
        self.update_worker_filter_stats = update_worker_filter_stats
    if rollout_fragment_length is not NotProvided:
        self.rollout_fragment_length = rollout_fragment_length
    if batch_mode is not NotProvided:
        self.batch_mode = batch_mode
    if explore is not NotProvided:
        self.explore = explore
    if episodes_to_numpy is not NotProvided:
        self.episodes_to_numpy = episodes_to_numpy

    # @OldAPIStack
    if exploration_config is not NotProvided:
        # Override entire `exploration_config` if `type` key changes.
        # Update, if `type` key remains the same or is not specified.
        new_exploration_config = deep_update(
            {"exploration_config": self.exploration_config},
            {"exploration_config": exploration_config},
            False,
            ["exploration_config"],
            ["exploration_config"],
        )
        self.exploration_config = new_exploration_config["exploration_config"]
    if remote_worker_envs is not NotProvided:
        self.remote_worker_envs = remote_worker_envs
    if remote_env_batch_wait_ms is not NotProvided:
        self.remote_env_batch_wait_ms = remote_env_batch_wait_ms
    if validate_env_runners_after_construction is not NotProvided:
        self.validate_env_runners_after_construction = (
            validate_env_runners_after_construction
        )
    if preprocessor_pref is not NotProvided:
        self.preprocessor_pref = preprocessor_pref
    if observation_filter is not NotProvided:
        self.observation_filter = observation_filter
    # NOTE: `synchronize_filter` is intentionally not stored. Its default is
    # `DEPRECATED_VALUE` (not `NotProvided`), and any other value is rejected
    # above, so storing it could only ever overwrite
    # `self.synchronize_filters` with `DEPRECATED_VALUE`.
    if compress_observations is not NotProvided:
        self.compress_observations = compress_observations
    if enable_tf1_exec_eagerly is not NotProvided:
        self.enable_tf1_exec_eagerly = enable_tf1_exec_eagerly
    if sampler_perf_stats_ema_coef is not NotProvided:
        self.sampler_perf_stats_ema_coef = sampler_perf_stats_ema_coef

    return self
[docs]
def learners(
    self,
    *,
    num_learners: Optional[int] = NotProvided,
    num_cpus_per_learner: Optional[Union[float, int]] = NotProvided,
    num_gpus_per_learner: Optional[Union[float, int]] = NotProvided,
    num_aggregator_actors_per_learner: Optional[int] = NotProvided,
    max_requests_in_flight_per_aggregator_actor: Optional[float] = NotProvided,
    local_gpu_idx: Optional[int] = NotProvided,
    max_requests_in_flight_per_learner: Optional[int] = NotProvided,
) -> "AlgorithmConfig":
    """Sets LearnerGroup and Learner worker related configurations.

    Only arguments that are explicitly passed (i.e. not `NotProvided`) are
    written to this config; all others keep their current values.

    Args:
        num_learners: Number of Learner workers used for updating the RLModule.
            A value of 0 means training takes place on a local Learner on main
            process CPUs or 1 GPU (determined by `num_gpus_per_learner`).
            For multi-gpu training, you have to set `num_learners` to > 1 and set
            `num_gpus_per_learner` accordingly (e.g., 4 GPUs total and model fits on
            1 GPU: `num_learners=4; num_gpus_per_learner=1` OR 4 GPUs total and
            model requires 2 GPUs: `num_learners=2; num_gpus_per_learner=2`).
        num_cpus_per_learner: Number of CPUs allocated per Learner worker.
            Only necessary for custom processing pipeline inside each Learner
            requiring multiple CPU cores. Ignored if `num_learners=0`.
        num_gpus_per_learner: Number of GPUs allocated per Learner worker. If
            `num_learners=0`, any value greater than 0 runs the
            training on a single GPU on the main process, while a value of 0 runs
            the training on main process CPUs. If `num_gpus_per_learner` is > 0,
            then you shouldn't change `num_cpus_per_learner` (from its default
            value of 1).
        num_aggregator_actors_per_learner: The number of aggregator actors per
            Learner (if num_learners=0, one local learner is created). Must be at
            least 1. Aggregator actors perform the task of a) converting episodes
            into a train batch and b) moving that train batch to the same GPU that
            the corresponding learner is located on. Good values are 1 or 2, but
            this strongly depends on your setup and `EnvRunner` throughput.
        max_requests_in_flight_per_aggregator_actor: The maximum number of
            in-flight requests allowed per aggregator actor before new requests
            are dropped.
        local_gpu_idx: If `num_gpus_per_learner` > 0, and
            `num_learners` < 2, then RLlib uses this GPU index for training. This is
            an index into the available
            CUDA devices. For example if `os.environ["CUDA_VISIBLE_DEVICES"] = "1"`
            and `local_gpu_idx=0`, RLlib uses the GPU with ID=1 on the node.
        max_requests_in_flight_per_learner: Max number of in-flight requests
            to each Learner (actor). You normally do not have to tune this setting
            (default is 3), however, for asynchronous algorithms, this determines
            the "queue" size for incoming batches (or lists of episodes) into each
            Learner worker, thus also determining, how much off-policy'ness would be
            acceptable. The off-policy'ness is the difference between the numbers of
            updates a policy has undergone on the Learner vs the EnvRunners.
            See the `ray.rllib.utils.actor_manager.FaultTolerantActorManager` class
            for more details.

    Returns:
        This updated AlgorithmConfig object.
    """
    if num_learners is not NotProvided:
        self.num_learners = num_learners
    if num_cpus_per_learner is not NotProvided:
        self.num_cpus_per_learner = num_cpus_per_learner
    if num_gpus_per_learner is not NotProvided:
        self.num_gpus_per_learner = num_gpus_per_learner
    if num_aggregator_actors_per_learner is not NotProvided:
        self.num_aggregator_actors_per_learner = num_aggregator_actors_per_learner
    if max_requests_in_flight_per_aggregator_actor is not NotProvided:
        self.max_requests_in_flight_per_aggregator_actor = (
            max_requests_in_flight_per_aggregator_actor
        )
    if local_gpu_idx is not NotProvided:
        self.local_gpu_idx = local_gpu_idx
    if max_requests_in_flight_per_learner is not NotProvided:
        self.max_requests_in_flight_per_learner = max_requests_in_flight_per_learner
    return self
[docs]
def training(
    self,
    *,
    gamma: Optional[float] = NotProvided,
    lr: Optional[LearningRateOrSchedule] = NotProvided,
    grad_clip: Optional[float] = NotProvided,
    grad_clip_by: Optional[str] = NotProvided,
    train_batch_size: Optional[int] = NotProvided,
    train_batch_size_per_learner: Optional[int] = NotProvided,
    num_epochs: Optional[int] = NotProvided,
    minibatch_size: Optional[int] = NotProvided,
    shuffle_batch_per_epoch: Optional[bool] = NotProvided,
    model: Optional[dict] = NotProvided,
    optimizer: Optional[dict] = NotProvided,
    learner_class: Optional[Type["Learner"]] = NotProvided,
    learner_connector: Optional[
        Callable[["RLModule"], Union["ConnectorV2", List["ConnectorV2"]]]
    ] = NotProvided,
    add_default_connectors_to_learner_pipeline: Optional[bool] = NotProvided,
    learner_config_dict: Optional[Dict[str, Any]] = NotProvided,
    # Deprecated args.
    num_aggregator_actors_per_learner=DEPRECATED_VALUE,
    max_requests_in_flight_per_aggregator_actor=DEPRECATED_VALUE,
    num_sgd_iter=DEPRECATED_VALUE,
    max_requests_in_flight_per_sampler_worker=DEPRECATED_VALUE,
) -> "AlgorithmConfig":
    """Sets the training related configuration.

    Only arguments that are explicitly passed (i.e. not `NotProvided`) are
    written to this config; all others keep their current values.

    Args:
        gamma: Float specifying the discount factor of the Markov Decision process.
        lr: The learning rate (float) or learning rate schedule in the format of
            [[timestep, lr-value], [timestep, lr-value], ...]
            In case of a schedule, intermediary timesteps are assigned to
            linearly interpolated learning rate values. A schedule config's first
            entry must start with timestep 0, i.e.: [[0, initial_value], [...]].
            Note: If you require a) more than one optimizer (per RLModule),
            b) optimizer types that are not Adam, c) a learning rate schedule that
            is not the linearly interpolated, piecewise schedule described above,
            or d) specifying c'tor arguments of the optimizer that are not the
            learning rate (e.g. Adam's epsilon), then you must override your
            Learner's `configure_optimizer_for_module()` method and handle
            lr-scheduling yourself.
        grad_clip: If None, no gradient clipping is applied. Otherwise,
            depending on the setting of `grad_clip_by`, the (float) value of
            `grad_clip` has the following effect:
            If `grad_clip_by=value`: Clips all computed gradients individually
            inside the interval [-`grad_clip`, +`grad_clip`].
            If `grad_clip_by=norm`, computes the L2-norm of each weight/bias
            gradient tensor individually and then clips all gradients such that
            these L2-norms do not exceed `grad_clip`. The L2-norm of a tensor is
            computed via: `sqrt(SUM(w0^2, w1^2, ..., wn^2))` where w[i] are the
            elements of the tensor (no matter what the shape of this tensor is).
            If `grad_clip_by=global_norm`, computes the square of the L2-norm of
            each weight/bias gradient tensor individually, sums up all these
            squared L2-norms across all given gradient tensors (e.g. the entire
            module to be updated), takes the square root of that overall sum, and
            then clips all gradients such that this global L2-norm does not exceed
            the given value.
            The global L2-norm over a list of tensors (e.g. W and V) is computed
            via:
            `sqrt[SUM(w0^2, w1^2, ..., wn^2) + SUM(v0^2, v1^2, ..., vm^2)]`, where
            w[i] and v[j] are the elements of the tensors W and V (no matter what
            the shapes of these tensors are).
        grad_clip_by: See `grad_clip` for the effect of this setting on gradient
            clipping. Allowed values are `value`, `norm`, and `global_norm`.
        train_batch_size_per_learner: Train batch size per individual Learner
            worker. This setting only applies to the new API stack. The number
            of Learner workers can be set via `config.learners(
            num_learners=...)`. The total effective batch size is then
            `num_learners` x `train_batch_size_per_learner` and you can
            access it with the property `AlgorithmConfig.total_train_batch_size`.
        train_batch_size: Training batch size, if applicable. When on the new API
            stack, this setting should no longer be used. Instead, use
            `train_batch_size_per_learner` (in combination with
            `num_learners`).
        num_epochs: The number of complete passes over the entire train batch (per
            Learner). Each pass might be further split into n minibatches (if
            `minibatch_size` provided).
        minibatch_size: The size of minibatches to use to further split the train
            batch into.
        shuffle_batch_per_epoch: Whether to shuffle the train batch once per epoch.
            If the train batch has a time rank (axis=1), shuffling only takes
            place along the batch axis to not disturb any intact (episode)
            trajectories.
        model: Arguments passed into the policy model. See models/catalog.py for a
            full list of the available model options. The given dict is merged
            (via `dict.update`) into the currently stored model config.
            TODO: Provide ModelConfig objects instead of dicts.
        optimizer: Arguments to pass to the policy optimizer. This setting is not
            used when `enable_rl_module_and_learner=True`.
        learner_class: The `Learner` class to use for (distributed) updating of the
            RLModule. Only used when `enable_rl_module_and_learner=True`.
        learner_connector: A callable taking an env observation space and an env
            action space as inputs and returning a learner ConnectorV2 (might be
            a pipeline) object.
            NOTE(review): the type annotation of this argument declares a
            callable taking an RLModule, which disagrees with this description;
            confirm which one is correct.
        add_default_connectors_to_learner_pipeline: If True (default), RLlib's
            Learners automatically add the default Learner ConnectorV2
            pieces to the LearnerPipeline. These automatically perform:
            a) adding observations from episodes to the train batch, if this has not
            already been done by a user-provided connector piece
            b) if RLModule is stateful, add a time rank to the train batch, zero-pad
            the data, and add the correct state inputs, if this has not already been
            done by a user-provided connector piece.
            c) add all other information (actions, rewards, terminateds, etc..) to
            the train batch, if this has not already been done by a user-provided
            connector piece.
            Only if you know exactly what you are doing, you
            should set this setting to False.
            Note that this setting is only relevant if the new API stack is used
            (including the new EnvRunner classes).
        learner_config_dict: A dict to insert any settings accessible from within
            the Learner instance. This should only be used in connection with custom
            Learner subclasses and in case the user doesn't want to write an extra
            `AlgorithmConfig` subclass just to add a few settings to the base Algo's
            own config class. The given dict is merged (via `dict.update`) into the
            currently stored one.

    Deprecated args:
        `num_aggregator_actors_per_learner` and
        `max_requests_in_flight_per_aggregator_actor` are forwarded to
        `self.learners()`, `max_requests_in_flight_per_sampler_worker` to
        `self.env_runners(max_requests_in_flight_per_env_runner=..)`, and
        `num_sgd_iter` is used as `num_epochs` (overriding an explicitly passed
        `num_epochs`). Each of these logs a deprecation warning.

    Returns:
        This updated AlgorithmConfig object.

    Raises:
        ValueError: If `grad_clip_by` is not one of "value", "norm", or
            "global_norm". This is checked before any attribute is written.
    """
    # Validate before writing any attribute, so an invalid value leaves this
    # config unchanged.
    if grad_clip_by is not NotProvided and grad_clip_by not in [
        "value",
        "norm",
        "global_norm",
    ]:
        raise ValueError(
            f"`grad_clip_by` ({grad_clip_by}) must be one of: 'value', 'norm', "
            "or 'global_norm'!"
        )

    # Deprecated args: warn, then route each value through its new setter.
    if num_aggregator_actors_per_learner != DEPRECATED_VALUE:
        deprecation_warning(
            old="config.training(num_aggregator_actors_per_learner=..)",
            new="config.learners(num_aggregator_actors_per_learner=..)",
            error=False,
        )
        self.learners(
            num_aggregator_actors_per_learner=num_aggregator_actors_per_learner
        )
    if max_requests_in_flight_per_aggregator_actor != DEPRECATED_VALUE:
        deprecation_warning(
            old="config.training(max_requests_in_flight_per_aggregator_actor=..)",
            new="config.learners(max_requests_in_flight_per_aggregator_actor=..)",
            error=False,
        )
        self.learners(
            max_requests_in_flight_per_aggregator_actor=(
                max_requests_in_flight_per_aggregator_actor
            )
        )
    if num_sgd_iter != DEPRECATED_VALUE:
        deprecation_warning(
            old="config.training(num_sgd_iter=..)",
            new="config.training(num_epochs=..)",
            error=False,
        )
        num_epochs = num_sgd_iter
    if max_requests_in_flight_per_sampler_worker != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.training("
            "max_requests_in_flight_per_sampler_worker=...)",
            new="AlgorithmConfig.env_runners("
            "max_requests_in_flight_per_env_runner=...)",
            error=False,
        )
        self.env_runners(
            max_requests_in_flight_per_env_runner=(
                max_requests_in_flight_per_sampler_worker
            ),
        )

    if gamma is not NotProvided:
        self.gamma = gamma
    if lr is not NotProvided:
        self.lr = lr
    if grad_clip is not NotProvided:
        self.grad_clip = grad_clip
    if grad_clip_by is not NotProvided:
        self.grad_clip_by = grad_clip_by
    if train_batch_size_per_learner is not NotProvided:
        self._train_batch_size_per_learner = train_batch_size_per_learner
    if train_batch_size is not NotProvided:
        self.train_batch_size = train_batch_size
    if num_epochs is not NotProvided:
        self.num_epochs = num_epochs
    if minibatch_size is not NotProvided:
        self.minibatch_size = minibatch_size
    if shuffle_batch_per_epoch is not NotProvided:
        self.shuffle_batch_per_epoch = shuffle_batch_per_epoch
    if model is not NotProvided:
        self.model.update(model)
        if (
            model.get("_use_default_native_models", DEPRECATED_VALUE)
            != DEPRECATED_VALUE
        ):
            deprecation_warning(
                old="AlgorithmConfig.training(_use_default_native_models=True)",
                help="_use_default_native_models is not supported "
                "anymore. To get rid of this error, set `config.api_stack("
                "enable_rl_module_and_learner=True)`. Native models will "
                "be better supported by the upcoming RLModule API.",
                # Error out if user tries to enable this.
                error=model["_use_default_native_models"],
            )
    if optimizer is not NotProvided:
        self.optimizer = merge_dicts(self.optimizer, optimizer)
    if learner_class is not NotProvided:
        self._learner_class = learner_class
    if learner_connector is not NotProvided:
        self._learner_connector = learner_connector
    if add_default_connectors_to_learner_pipeline is not NotProvided:
        self.add_default_connectors_to_learner_pipeline = (
            add_default_connectors_to_learner_pipeline
        )
    if learner_config_dict is not NotProvided:
        self.learner_config_dict.update(learner_config_dict)
    return self
[docs]
def callbacks(
    self,
    callbacks_class: Optional[
        Union[Type[RLlibCallback], List[Type[RLlibCallback]]]
    ] = NotProvided,
    *,
    on_algorithm_init: Optional[Union[Callable, List[Callable]]] = NotProvided,
    on_train_result: Optional[Union[Callable, List[Callable]]] = NotProvided,
    on_evaluate_start: Optional[Union[Callable, List[Callable]]] = NotProvided,
    on_evaluate_end: Optional[Union[Callable, List[Callable]]] = NotProvided,
    on_env_runners_recreated: Optional[
        Union[Callable, List[Callable]]
    ] = NotProvided,
    on_checkpoint_loaded: Optional[Union[Callable, List[Callable]]] = NotProvided,
    on_environment_created: Optional[Union[Callable, List[Callable]]] = NotProvided,
    on_episode_created: Optional[Union[Callable, List[Callable]]] = NotProvided,
    on_episode_start: Optional[Union[Callable, List[Callable]]] = NotProvided,
    on_episode_step: Optional[Union[Callable, List[Callable]]] = NotProvided,
    on_episode_end: Optional[Union[Callable, List[Callable]]] = NotProvided,
    on_sample_end: Optional[Union[Callable, List[Callable]]] = NotProvided,
) -> "AlgorithmConfig":
    """Sets the callbacks configuration.

    Args:
        callbacks_class: RLlibCallback class (or a list of such classes), whose
            methods are called during various phases of training and RL
            environment sample collection. Passing None resets this setting to
            `RLlibCallback`.
            TODO (sven): Change the link to new rst callbacks page.
            See the `RLlibCallback` class and
            `examples/metrics/custom_metrics_and_callbacks.py` for more information.
        on_algorithm_init: A callable or a list of callables. If a list, RLlib calls
            the items in the same sequence. `on_algorithm_init` methods overridden
            in `callbacks_class` take precedence and are called first.
            See
            :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_algorithm_init` # noqa
            for more information.
        on_train_result: A callable or a list of callables. If a list, RLlib calls
            the items in the same sequence. `on_train_result` methods overridden
            in `callbacks_class` take precedence and are called first.
            See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_train_result` # noqa
            for more information.
        on_evaluate_start: A callable or a list of callables. If a list, RLlib calls
            the items in the same sequence. `on_evaluate_start` methods overridden
            in `callbacks_class` take precedence and are called first.
            See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_evaluate_start` # noqa
            for more information.
        on_evaluate_end: A callable or a list of callables. If a list, RLlib calls
            the items in the same sequence. `on_evaluate_end` methods overridden
            in `callbacks_class` take precedence and are called first.
            See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_evaluate_end` # noqa
            for more information.
        on_env_runners_recreated: A callable or a list of callables. If a list,
            RLlib calls the items in the same sequence. `on_env_runners_recreated`
            methods overridden in `callbacks_class` take precedence and are called
            first.
            See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_env_runners_recreated` # noqa
            for more information.
        on_checkpoint_loaded: A callable or a list of callables. If a list,
            RLlib calls the items in the same sequence. `on_checkpoint_loaded`
            methods overridden in `callbacks_class` take precedence and are called
            first.
            See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_checkpoint_loaded` # noqa
            for more information.
        on_environment_created: A callable or a list of callables. If a list,
            RLlib calls the items in the same sequence. `on_environment_created`
            methods overridden in `callbacks_class` take precedence and are called
            first.
            See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_environment_created` # noqa
            for more information.
        on_episode_created: A callable or a list of callables. If a list,
            RLlib calls the items in the same sequence. `on_episode_created` methods
            overridden in `callbacks_class` take precedence and are called first.
            See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_episode_created` # noqa
            for more information.
        on_episode_start: A callable or a list of callables. If a list,
            RLlib calls the items in the same sequence. `on_episode_start` methods
            overridden in `callbacks_class` take precedence and are called first.
            See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_episode_start` # noqa
            for more information.
        on_episode_step: A callable or a list of callables. If a list,
            RLlib calls the items in the same sequence. `on_episode_step` methods
            overridden in `callbacks_class` take precedence and are called first.
            See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_episode_step` # noqa
            for more information.
        on_episode_end: A callable or a list of callables. If a list,
            RLlib calls the items in the same sequence. `on_episode_end` methods
            overridden in `callbacks_class` take precedence and are called first.
            See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_episode_end` # noqa
            for more information.
        on_sample_end: A callable or a list of callables. If a list,
            RLlib calls the items in the same sequence. `on_sample_end` methods
            overridden in `callbacks_class` take precedence and are called first.
            See :py:meth:`~ray.rllib.callbacks.callbacks.RLlibCallback.on_sample_end` # noqa
            for more information.

    Returns:
        This updated AlgorithmConfig object.

    Raises:
        ValueError: If `callbacks_class` (or, if it is a list, any of its items)
            is not callable.
    """
    if callbacks_class is None:
        callbacks_class = RLlibCallback
    if callbacks_class is not NotProvided:
        # The type hint allows either a single class or a list of classes. A list
        # object itself is never callable, so validate each item individually.
        # TODO (sven): Once the old API stack is deprecated, this can also be None
        #  (which should then become the default value for this attribute).
        to_check = (
            callbacks_class if isinstance(callbacks_class, list) else [callbacks_class]
        )
        if not all(callable(c) for c in to_check):
            raise ValueError(
                "`config.callbacks_class` must be a callable (or a list of "
                "callables) that returns a subclass of RLlibCallback, got "
                f"{callbacks_class}!"
            )
        self.callbacks_class = callbacks_class
    if on_algorithm_init is not NotProvided:
        self.callbacks_on_algorithm_init = on_algorithm_init
    if on_train_result is not NotProvided:
        self.callbacks_on_train_result = on_train_result
    if on_evaluate_start is not NotProvided:
        self.callbacks_on_evaluate_start = on_evaluate_start
    if on_evaluate_end is not NotProvided:
        self.callbacks_on_evaluate_end = on_evaluate_end
    if on_env_runners_recreated is not NotProvided:
        self.callbacks_on_env_runners_recreated = on_env_runners_recreated
    if on_checkpoint_loaded is not NotProvided:
        self.callbacks_on_checkpoint_loaded = on_checkpoint_loaded
    if on_environment_created is not NotProvided:
        self.callbacks_on_environment_created = on_environment_created
    if on_episode_created is not NotProvided:
        self.callbacks_on_episode_created = on_episode_created
    if on_episode_start is not NotProvided:
        self.callbacks_on_episode_start = on_episode_start
    if on_episode_step is not NotProvided:
        self.callbacks_on_episode_step = on_episode_step
    if on_episode_end is not NotProvided:
        self.callbacks_on_episode_end = on_episode_end
    if on_sample_end is not NotProvided:
        self.callbacks_on_sample_end = on_sample_end
    return self
[docs]
def evaluation(
    self,
    *,
    evaluation_interval: Optional[int] = NotProvided,
    evaluation_duration: Optional[Union[int, str]] = NotProvided,
    evaluation_duration_unit: Optional[str] = NotProvided,
    evaluation_sample_timeout_s: Optional[float] = NotProvided,
    evaluation_parallel_to_training: Optional[bool] = NotProvided,
    evaluation_force_reset_envs_before_iteration: Optional[bool] = NotProvided,
    evaluation_config: Optional[
        Union["AlgorithmConfig", PartialAlgorithmConfigDict]
    ] = NotProvided,
    off_policy_estimation_methods: Optional[Dict] = NotProvided,
    ope_split_batch_by_episode: Optional[bool] = NotProvided,
    evaluation_num_env_runners: Optional[int] = NotProvided,
    custom_evaluation_function: Optional[Callable] = NotProvided,
    # Deprecated args.
    always_attach_evaluation_results=DEPRECATED_VALUE,
    evaluation_num_workers=DEPRECATED_VALUE,
) -> "AlgorithmConfig":
    """Sets the config's evaluation settings.

    Args:
        evaluation_interval: Run evaluation once every `evaluation_interval`
            training iterations. Evaluation stats are reported under the
            "evaluation" metric key. Set to None (or 0) to disable evaluation.
        evaluation_duration: How long to evaluate per `evaluation_interval`.
            The unit is chosen via `evaluation_duration_unit` and is either
            "episodes" (default) or "timesteps". With multiple evaluation
            EnvRunners (`evaluation_num_env_runners > 1`), the requested
            episodes/timesteps are split amongst them.
            The special value "auto" can be used together with
            `evaluation_parallel_to_training=True`; this is the recommended way
            to spend as little time on evaluation as possible. The Algorithm
            then runs as many timesteps on the evaluation workers as it can
            without taking longer than the training step running in parallel,
            so neither the training- nor the evaluation workers sit idle. When
            using `evaluation_duration="auto"`, it's strongly advised to also
            set `evaluation_interval=1` and
            `evaluation_force_reset_envs_before_iteration=True`.
        evaluation_duration_unit: The unit in which `evaluation_duration` is
            counted: "episodes" (default) or "timesteps". Ignored if
            `evaluation_duration="auto"`.
        evaluation_sample_timeout_s: Timeout (in seconds) for evaluation workers
            to sample a complete episode, applicable when
            `evaluation_duration != "auto"` and
            `evaluation_duration_unit="episodes"`. After this time, the user
            receives a warning and instructions on how to fix the issue.
        evaluation_parallel_to_training: Whether to run evaluation in parallel to
            the `Algorithm.training_step()` call, using threading. Default=False.
            E.g. with `evaluation_interval=1`, every call to `Algorithm.train()`
            runs `Algorithm.training_step()` and `Algorithm.evaluate()` in
            parallel. Although very efficient (no extra time is spent on
            evaluation), this makes the evaluation results lag one iteration
            behind the rest of the training results. Keep this in mind when
            picking a checkpoint: if iteration 42 reports a good evaluation
            `episode_return_mean`, those results were achieved with the weights
            trained in iteration 41, so you should probably pick the iteration
            41 checkpoint instead.
        evaluation_force_reset_envs_before_iteration: Whether to force-reset all
            environments (even those not done yet) right before the evaluation
            step of an iteration begins. Setting this to True (default) keeps
            the evaluation results free of episode statistics that were (at
            least partially) collected with an earlier set of weights. Only
            supported on the new API stack w/ EnvRunners and ConnectorV2
            (`config.enable_rl_module_and_learner=True` AND
            `config.enable_env_runner_and_connector_v2=True`).
        evaluation_config: Overrides applied to the config when evaluating.
            Typically used to pass extra args to the evaluation env creator and
            to disable exploration by computing deterministic actions. The given
            overrides are merged into (not replacing) any previously set ones;
            passing None clears them.
            IMPORTANT NOTE: Policy gradient algorithms are able to find the optimal
            policy, even if this is a stochastic one. Setting "explore=False" here
            results in the evaluation workers not using this optimal policy!
        off_policy_estimation_methods: How to evaluate the current policy, along
            with any optional config parameters. Only has an effect when reading
            offline experiences ("input" is not "sampler").
            Available keys:
            {ope_method_name: {"type": ope_type, ...}} where `ope_method_name`
            is a user-defined string to save the OPE results under, and
            `ope_type` can be any subclass of OffPolicyEstimator, e.g.
            ray.rllib.offline.estimators.is::ImportanceSampling
            or your own custom subclass, or the full class path to the subclass.
            Additional config arguments for the OffPolicyEstimator can be added
            to the dict, e.g.
            {"qreg_dr": {"type": DoublyRobust, "q_model_type": "qreg", "k": 5}}
        ope_split_batch_by_episode: Whether to use SampleBatch.split_by_episode() to
            split the input batch into episodes before estimating the OPE metrics.
            For bandits, set this to False to speed up OPE evaluation; splitting
            isn't needed there, since each record is already one timestep.
            The default is True.
        evaluation_num_env_runners: Number of parallel EnvRunners to use for
            evaluation. This is zero by default, meaning evaluation runs in the
            algorithm process (only if `evaluation_interval` is not 0 or None).
            Increasing it also increases the Ray resource usage of the
            algorithm, since evaluation workers are created separately from the
            EnvRunners that sample training data.
        custom_evaluation_function: Customizes the evaluation method. Must be a
            function with signature
            `(algo: Algorithm, eval_workers: EnvRunnerGroup) ->
            (metrics: dict, env_steps: int, agent_steps: int)`, or returning only
            `metrics: dict` if `enable_env_runner_and_connector_v2=True`.
            `env_steps` and `agent_steps` are the number of steps sampled during
            the evaluation iteration. See `Algorithm.evaluate()` for the default
            implementation. The Algorithm guarantees that all eval workers have
            the latest policy state before this function is called.

    Returns:
        This updated AlgorithmConfig object.
    """

    def _assign_provided(name_value_pairs):
        # Copy every (attribute name, value) pair whose value was explicitly
        # passed, in the given order.
        for attr_name, attr_value in name_value_pairs:
            if attr_value is not NotProvided:
                setattr(self, attr_name, attr_value)

    # Deprecated arguments first.
    if always_attach_evaluation_results != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.evaluation(always_attach_evaluation_results=..)",
            help="This setting is no longer needed, b/c Tune does not error "
            "anymore (only warns) when a metrics key can't be found in the "
            "results.",
            error=True,
        )
    if evaluation_num_workers != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.evaluation(evaluation_num_workers=..)",
            new="AlgorithmConfig.evaluation(evaluation_num_env_runners=..)",
            error=False,
        )
        # Forward the old value to its new name. An explicitly passed
        # `evaluation_num_env_runners` (applied further below) overrides it.
        self.evaluation_num_env_runners = evaluation_num_workers

    _assign_provided(
        (
            ("evaluation_interval", evaluation_interval),
            ("evaluation_duration", evaluation_duration),
            ("evaluation_duration_unit", evaluation_duration_unit),
            ("evaluation_sample_timeout_s", evaluation_sample_timeout_s),
            ("evaluation_parallel_to_training", evaluation_parallel_to_training),
            (
                "evaluation_force_reset_envs_before_iteration",
                evaluation_force_reset_envs_before_iteration,
            ),
        )
    )

    if evaluation_config is None:
        # An explicit None is honored as-is rather than turned into an empty dict.
        self.evaluation_config = None
    elif evaluation_config is not NotProvided:
        # Merge the given overrides into the existing ones. The import is done
        # lazily here (presumably to avoid a circular import at module load).
        from ray.rllib.algorithms.algorithm import Algorithm

        self.evaluation_config = deep_update(
            self.evaluation_config or {},
            evaluation_config,
            True,
            Algorithm._allow_unknown_subkeys,
            Algorithm._override_all_subkeys_if_type_changes,
            Algorithm._override_all_key_list,
        )

    _assign_provided(
        (
            ("off_policy_estimation_methods", off_policy_estimation_methods),
            ("evaluation_num_env_runners", evaluation_num_env_runners),
            ("custom_evaluation_function", custom_evaluation_function),
            ("ope_split_batch_by_episode", ope_split_batch_by_episode),
        )
    )
    return self
[docs]
def offline_data(
    self,
    *,
    input_: Optional[
        Union[
            str,
            List[str],
            Dict[str, float],
            Callable[[IOContext], InputReader],
        ]
    ] = NotProvided,
    offline_data_class: Optional[Type] = NotProvided,
    input_read_method: Optional[Union[str, Callable]] = NotProvided,
    input_read_method_kwargs: Optional[Dict] = NotProvided,
    input_read_schema: Optional[Dict[str, str]] = NotProvided,
    input_read_episodes: Optional[bool] = NotProvided,
    input_read_sample_batches: Optional[bool] = NotProvided,
    input_read_batch_size: Optional[int] = NotProvided,
    input_filesystem: Optional[str] = NotProvided,
    input_filesystem_kwargs: Optional[Dict] = NotProvided,
    input_compress_columns: Optional[List[str]] = NotProvided,
    materialize_data: Optional[bool] = NotProvided,
    materialize_mapped_data: Optional[bool] = NotProvided,
    map_batches_kwargs: Optional[Dict] = NotProvided,
    iter_batches_kwargs: Optional[Dict] = NotProvided,
    prelearner_class: Optional[Type] = NotProvided,
    prelearner_buffer_class: Optional[Type] = NotProvided,
    prelearner_buffer_kwargs: Optional[Dict] = NotProvided,
    prelearner_module_synch_period: Optional[int] = NotProvided,
    dataset_num_iters_per_learner: Optional[int] = NotProvided,
    input_config: Optional[Dict] = NotProvided,
    actions_in_input_normalized: Optional[bool] = NotProvided,
    postprocess_inputs: Optional[bool] = NotProvided,
    shuffle_buffer_size: Optional[int] = NotProvided,
    output: Optional[str] = NotProvided,
    output_config: Optional[Dict] = NotProvided,
    output_compress_columns: Optional[List[str]] = NotProvided,
    output_max_file_size: Optional[float] = NotProvided,
    output_max_rows_per_file: Optional[int] = NotProvided,
    output_write_remaining_data: Optional[bool] = NotProvided,
    output_write_method: Optional[str] = NotProvided,
    output_write_method_kwargs: Optional[Dict] = NotProvided,
    output_filesystem: Optional[str] = NotProvided,
    output_filesystem_kwargs: Optional[Dict] = NotProvided,
    output_write_episodes: Optional[bool] = NotProvided,
    offline_sampling: Optional[str] = NotProvided,
) -> "AlgorithmConfig":
    """Sets the config's offline data settings.

    Args:
        input_: Specify how to generate experiences:
            - "sampler": Generate experiences via online (env) simulation (default).
            - A local directory or file glob expression (e.g., "/tmp/*.json").
            - A list of individual file paths/URIs (e.g., ["/tmp/1.json",
            "s3://bucket/2.json"]).
            - A dict with string keys and sampling probabilities as values (e.g.,
            {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}).
            - A callable that takes an `IOContext` object as only arg and returns a
            `ray.rllib.offline.InputReader`.
            - A string key that indexes a callable with
            `tune.registry.register_input`
        offline_data_class: An optional `OfflineData` class that is used to define
            the offline data pipeline, including the dataset and the sampling
            methodology. Override the `OfflineData` class and pass your derived
            class here, if you need some primer transformations specific to your
            data or your loss. Usually overriding the `OfflinePreLearner` and using
            the resulting customization via `prelearner_class` suffices for most
            cases. The default is `None` which uses the base `OfflineData` defined
            in `ray.rllib.offline.offline_data.OfflineData`.
        input_read_method: Read method for the `ray.data.Dataset` to read in the
            offline data from `input_`. The default is `read_parquet` for Parquet
            files. See https://docs.ray.io/en/latest/data/api/input_output.html for
            more info about available read methods in `ray.data`.
        input_read_method_kwargs: Keyword args for `input_read_method`. These
            are passed by RLlib into the read method without checking. Use these
            keyword args together with `map_batches_kwargs` and
            `iter_batches_kwargs` to tune the performance of the data pipeline.
            It is strongly recommended to rely on Ray Data's automatic read
            performance tuning.
        input_read_schema: Table schema for converting offline data to episodes.
            This schema maps the offline data columns to
            ray.rllib.core.columns.Columns:
            `{Columns.OBS: 'o_t', Columns.ACTIONS: 'a_t', ...}`. Columns in
            the data set that are not mapped via this schema are sorted into
            episodes' `extra_model_outputs`. If no schema is passed in, the default
            schema used is `ray.rllib.offline.offline_data.SCHEMA`. If your data set
            already contains the names in this schema, no `input_read_schema` is
            needed. The same applies if the data is in RLlib's `EpisodeType` or its
            old `SampleBatch` format.
        input_read_episodes: Whether offline data is already stored in RLlib's
            `EpisodeType` format, i.e. `ray.rllib.env.SingleAgentEpisode`
            (multi-agent is planned but not supported, yet). Reading episodes
            directly avoids additional transform steps and is usually faster and
            therefore the recommended format when your application remains fully
            inside of RLlib's schema. The other format is a columnar format and is
            agnostic to the RL framework used. Use the latter format, if you are
            unsure when to use the data or in which RL framework. The default is
            to read column data, i.e. `False`. `input_read_episodes` and
            `input_read_sample_batches` can't be `True` at the same time. See
            also `output_write_episodes` to define the output data format when
            recording.
        input_read_sample_batches: Whether offline data is stored in RLlib's old
            stack `SampleBatch` type. This is usually the case for older data
            recorded with RLlib in JSON line format. Reading in `SampleBatch`
            data needs extra transforms and might not concatenate episode chunks
            contained in different `SampleBatch`es in the data. If possible, avoid
            reading `SampleBatch`es and convert them in a controlled form into
            RLlib's `EpisodeType` (i.e. `SingleAgentEpisode`). The default is
            `False`. `input_read_episodes` and `input_read_sample_batches` can't
            be `True` at the same time.
        input_read_batch_size: Batch size to pull from the data set. This could
            differ from the `train_batch_size_per_learner`, if a dataset holds
            `EpisodeType` (i.e., `SingleAgentEpisode`) or `SampleBatch`, or any
            other data type that contains multiple timesteps in a single row of
            the dataset. In such cases a single batch of size
            `train_batch_size_per_learner` will potentially pull a multiple of
            `train_batch_size_per_learner` timesteps from the offline dataset. The
            default is `None`, in which case `train_batch_size_per_learner` is
            pulled.
        input_filesystem: A cloud filesystem to handle access to cloud storage when
            reading experiences. Can be either "gcs" for Google Cloud Storage,
            "s3" for AWS S3 buckets, "abs" for Azure Blob Storage, or any
            filesystem supported by PyArrow. In general the file path is sufficient
            for accessing data from public or local storage systems. See
            https://arrow.apache.org/docs/python/filesystems.html for details.
        input_filesystem_kwargs: A dictionary holding the kwargs for the filesystem
            given by `input_filesystem`. See `gcsfs.GCSFilesystem` for GCS,
            `pyarrow.fs.S3FileSystem` for S3, and `adlfs.AzureBlobFileSystem` for
            ABS filesystem arguments.
        input_compress_columns: What input columns are compressed with LZ4 in the
            input data. This also applies if data is stored in RLlib's
            `SingleAgentEpisode` format (`MultiAgentEpisode` is not supported,
            yet). Note that providing `rllib.core.columns.Columns.OBS` also tries
            to decompress `rllib.core.columns.Columns.NEXT_OBS`.
        materialize_data: Whether the raw data should be materialized in memory.
            This boosts performance, but requires enough memory to avoid an OOM, so
            make sure that your cluster has the resources available. For very large
            data you might want to switch to streaming mode by setting this to
            `False` (default). If your algorithm does not need the RLModule in the
            Learner connector pipeline or all (learner) connectors are stateless
            you should consider setting `materialize_mapped_data` to `True`
            instead (and set `materialize_data` to `False`). If your data does not
            fit into memory and your Learner connector pipeline requires an RLModule
            or is stateful, set both `materialize_data` and
            `materialize_mapped_data` to `False`.
        materialize_mapped_data: Whether the data should be materialized after
            running it through the Learner connector pipeline (i.e. after running
            the `OfflinePreLearner`). This improves performance, but should only be
            used in case the (learner) connector pipeline does not require an
            RLModule and the (learner) connector pipeline is stateless. For example,
            MARWIL's Learner connector pipeline requires the RLModule for value
            function predictions and training batches would become stale after some
            iterations causing learning degradation or divergence. Also ensure that
            your cluster has enough memory available to avoid an OOM. If set to
            `True`, make sure that `materialize_data` is set to `False` to
            avoid materialization of two datasets. If your data does not fit into
            memory and your Learner connector pipeline requires an RLModule or is
            stateful, set both `materialize_data` and `materialize_mapped_data` to
            `False`.
        map_batches_kwargs: Keyword args for the `map_batches` method. These are
            passed into the `ray.data.Dataset.map_batches` method when sampling
            without checking. If no arguments are passed in, the default arguments
            `{'concurrency': max(2, num_learners), 'zero_copy_batch': True}` are
            used. Use these keyword args together with `input_read_method_kwargs`
            and `iter_batches_kwargs` to tune the performance of the data pipeline.
        iter_batches_kwargs: Keyword args for the `iter_batches` method. These are
            passed into the `ray.data.Dataset.iter_batches` method when sampling
            without checking. If no arguments are passed in, the default argument
            `{'prefetch_batches': 2}` is used. Use these keyword args
            together with `input_read_method_kwargs` and `map_batches_kwargs` to
            tune the performance of the data pipeline.
        prelearner_class: An optional `OfflinePreLearner` class that is used to
            transform data batches in `ray.data.map_batches` used in the
            `OfflineData` class to transform data from columns to batches that can
            be used in the `Learner.update...()` methods. Override the
            `OfflinePreLearner` class and pass your derived class in here, if you
            need to make some further transformations specific for your data or
            loss. The default is None which uses the base `OfflinePreLearner`
            defined in `ray.rllib.offline.offline_prelearner`.
        prelearner_buffer_class: An optional `EpisodeReplayBuffer` class that RLlib
            uses to buffer experiences when data is in `EpisodeType` or
            RLlib's previous `SampleBatch` type format. In this case, a single
            data row may contain multiple timesteps and the buffer serves two
            purposes: (a) to store intermediate data in memory, and (b) to ensure
            that RLlib samples exactly `train_batch_size_per_learner` experiences
            per batch. The default is RLlib's `EpisodeReplayBuffer`.
        prelearner_buffer_kwargs: Optional keyword arguments for initializing the
            `EpisodeReplayBuffer`. In most cases this value is simply the `capacity`
            for the default buffer that RLlib uses (`EpisodeReplayBuffer`), but it
            may differ if the `prelearner_buffer_class` uses a custom buffer.
        prelearner_module_synch_period: The period (number of batches converted)
            after which the `RLModule` held by the `PreLearner` should sync weights.
            The `PreLearner` is used to preprocess batches for the learners. The
            higher this value, the more off-policy the `PreLearner`'s module is.
            Values too small force the `PreLearner` to sync more frequently
            and thus might slow down the data pipeline. The default value chosen
            by the `OfflinePreLearner` is 10.
        dataset_num_iters_per_learner: Number of updates to run in each learner
            during a single training iteration. If None, each learner runs a
            complete epoch over its data block (the dataset is partitioned into
            at least as many blocks as there are learners). The default is `None`.
            This value must be set to `1`, if RLlib uses a single (local) learner.
        input_config: Arguments that describe the settings for reading the input.
            If input is "sampler", this is the environment configuration, e.g.
            `env_name` and `env_config`, etc. See `EnvContext` for more info.
            If the input is "dataset", this contains e.g. `format`, `path`.
            Must be a dict and must not contain the keys `num_cpus_per_read_task`
            or `parallelism` (use the corresponding EnvRunner settings instead).
        actions_in_input_normalized: True, if the actions in a given offline "input"
            are already normalized (between -1.0 and 1.0). This is usually the case
            when the offline file has been generated by another RLlib algorithm
            (e.g. PPO or SAC), while "normalize_actions" was set to True.
        postprocess_inputs: Whether to run postprocess_trajectory() on the
            trajectory fragments from offline inputs. Note that postprocessing is
            done using the *current* policy, not the *behavior* policy, which
            is typically undesirable for on-policy algorithms.
        shuffle_buffer_size: If positive, input batches are shuffled via a
            sliding window buffer of this number of batches. Use this if the input
            data is not in random enough order. Input is delayed until the shuffle
            buffer is filled.
        output: Specify where experiences should be saved:
            - None: don't save any experiences
            - "logdir" to save to the agent log dir
            - a path/URI to save to a custom output directory (e.g., "s3://bckt/")
            - a function that returns a rllib.offline.OutputWriter
        output_config: Arguments accessible from the IOContext for configuring
            custom output.
        output_compress_columns: What sample batch columns to LZ4 compress in the
            output data. Note that providing `rllib.core.columns.Columns.OBS` also
            compresses `rllib.core.columns.Columns.NEXT_OBS`.
        output_max_file_size: Max output file size (in bytes) before rolling over
            to a new file.
        output_max_rows_per_file: Max output row numbers before rolling over to a
            new file.
        output_write_remaining_data: Determines whether any remaining data in the
            recording buffers should be stored to disk. It is only applicable if
            `output_max_rows_per_file` is defined. When sampling data, it is
            buffered until the threshold specified by `output_max_rows_per_file`
            is reached. Only complete multiples of `output_max_rows_per_file` are
            written to disk, while any leftover data remains in the buffers. If a
            recording session is stopped, residual data may still reside in these
            buffers. Setting `output_write_remaining_data` to `True` ensures this
            data is flushed to disk. By default, this attribute is set to `False`.
        output_write_method: Write method for the `ray.data.Dataset` to write the
            offline data to `output`. The default is `write_parquet` for Parquet
            files. See https://docs.ray.io/en/latest/data/api/input_output.html for
            more info about available write methods in `ray.data`.
        output_write_method_kwargs: `kwargs` for the `output_write_method`. These
            are passed into the write method without checking.
        output_filesystem: A cloud filesystem to handle access to cloud storage when
            writing experiences. Should be either "gcs" for Google Cloud Storage,
            "s3" for AWS S3 buckets, or "abs" for Azure Blob Storage.
        output_filesystem_kwargs: A dictionary holding the kwargs for the filesystem
            given by `output_filesystem`. See `gcsfs.GCSFilesystem` for GCS,
            `pyarrow.fs.S3FileSystem` for S3, and `adlfs.AzureBlobFileSystem` for
            ABS filesystem arguments.
        output_write_episodes: Whether RLlib should record data in its
            `EpisodeType` format (that is, `SingleAgentEpisode` objects). Use this
            format, if you need RLlib to order data in time and directly group by
            episodes for example to train stateful modules or if you plan to use
            recordings exclusively in RLlib. Otherwise RLlib records data in tabular
            (columnar) format. Default is `True`.
        offline_sampling: Whether sampling for the Algorithm happens via
            reading from offline data. If True, EnvRunners don't limit the number
            of collected batches within the same `sample()` call based on
            the number of sub-environments within the worker (no sub-environments
            present).

    Returns:
        This updated AlgorithmConfig object.

    Raises:
        ValueError: If `input_config` is not a dict, or if it sets
            `num_cpus_per_read_task` or `parallelism`.
    """
    if input_ is not NotProvided:
        self.input_ = input_
    if offline_data_class is not NotProvided:
        self.offline_data_class = offline_data_class
    if input_read_method is not NotProvided:
        self.input_read_method = input_read_method
    if input_read_method_kwargs is not NotProvided:
        self.input_read_method_kwargs = input_read_method_kwargs
    if input_read_schema is not NotProvided:
        self.input_read_schema = input_read_schema
    if input_read_episodes is not NotProvided:
        self.input_read_episodes = input_read_episodes
    if input_read_sample_batches is not NotProvided:
        self.input_read_sample_batches = input_read_sample_batches
    if input_read_batch_size is not NotProvided:
        self.input_read_batch_size = input_read_batch_size
    if input_filesystem is not NotProvided:
        self.input_filesystem = input_filesystem
    if input_filesystem_kwargs is not NotProvided:
        self.input_filesystem_kwargs = input_filesystem_kwargs
    if input_compress_columns is not NotProvided:
        self.input_compress_columns = input_compress_columns
    if materialize_data is not NotProvided:
        self.materialize_data = materialize_data
    if materialize_mapped_data is not NotProvided:
        self.materialize_mapped_data = materialize_mapped_data
    if map_batches_kwargs is not NotProvided:
        self.map_batches_kwargs = map_batches_kwargs
    if iter_batches_kwargs is not NotProvided:
        self.iter_batches_kwargs = iter_batches_kwargs
    if prelearner_class is not NotProvided:
        self.prelearner_class = prelearner_class
    if prelearner_buffer_class is not NotProvided:
        self.prelearner_buffer_class = prelearner_buffer_class
    if prelearner_buffer_kwargs is not NotProvided:
        self.prelearner_buffer_kwargs = prelearner_buffer_kwargs
    if prelearner_module_synch_period is not NotProvided:
        self.prelearner_module_synch_period = prelearner_module_synch_period
    if dataset_num_iters_per_learner is not NotProvided:
        self.dataset_num_iters_per_learner = dataset_num_iters_per_learner
    if input_config is not NotProvided:
        if not isinstance(input_config, dict):
            raise ValueError(
                f"input_config must be a dict, got {type(input_config)}."
            )
        # TODO (Kourosh) Once we use a complete separation between rollout worker
        #  and input dataset reader we can remove this.
        # For now Error out if user attempts to set these parameters.
        msg = "{} should not be set in the input_config. RLlib uses {} instead."
        if input_config.get("num_cpus_per_read_task") is not None:
            raise ValueError(
                msg.format(
                    "num_cpus_per_read_task",
                    "config.env_runners(num_cpus_per_env_runner=..)",
                )
            )
        if input_config.get("parallelism") is not None:
            # Point the user to the setting matching the config they are editing.
            replacement = (
                "config.evaluation(evaluation_num_env_runners=..)"
                if self.in_evaluation
                else "config.env_runners(num_env_runners=..)"
            )
            raise ValueError(msg.format("parallelism", replacement))
        self.input_config = input_config
    if actions_in_input_normalized is not NotProvided:
        self.actions_in_input_normalized = actions_in_input_normalized
    if postprocess_inputs is not NotProvided:
        self.postprocess_inputs = postprocess_inputs
    if shuffle_buffer_size is not NotProvided:
        self.shuffle_buffer_size = shuffle_buffer_size
    # TODO (simon): Enable storing to general log-directory.
    if output is not NotProvided:
        self.output = output
    if output_config is not NotProvided:
        self.output_config = output_config
    if output_compress_columns is not NotProvided:
        self.output_compress_columns = output_compress_columns
    if output_max_file_size is not NotProvided:
        self.output_max_file_size = output_max_file_size
    if output_max_rows_per_file is not NotProvided:
        self.output_max_rows_per_file = output_max_rows_per_file
    if output_write_remaining_data is not NotProvided:
        self.output_write_remaining_data = output_write_remaining_data
    if output_write_method is not NotProvided:
        self.output_write_method = output_write_method
    if output_write_method_kwargs is not NotProvided:
        self.output_write_method_kwargs = output_write_method_kwargs
    if output_filesystem is not NotProvided:
        self.output_filesystem = output_filesystem
    if output_filesystem_kwargs is not NotProvided:
        self.output_filesystem_kwargs = output_filesystem_kwargs
    if output_write_episodes is not NotProvided:
        self.output_write_episodes = output_write_episodes
    if offline_sampling is not NotProvided:
        self.offline_sampling = offline_sampling
    return self
[docs]
def multi_agent(
    self,
    *,
    policies: Optional[
        Union[MultiAgentPolicyConfigDict, Collection[PolicyID]]
    ] = NotProvided,
    policy_map_capacity: Optional[int] = NotProvided,
    policy_mapping_fn: Optional[
        Callable[[AgentID, "EpisodeType"], PolicyID]
    ] = NotProvided,
    policies_to_train: Optional[
        Union[Collection[PolicyID], Callable[[PolicyID, SampleBatchType], bool]]
    ] = NotProvided,
    policy_states_are_swappable: Optional[bool] = NotProvided,
    observation_fn: Optional[Callable] = NotProvided,
    count_steps_by: Optional[str] = NotProvided,
    # Deprecated args:
    algorithm_config_overrides_per_module=DEPRECATED_VALUE,
    replay_mode=DEPRECATED_VALUE,
    # Now done via Ray object store, which has its own cloud-supported
    # spillover mechanism.
    policy_map_cache=DEPRECATED_VALUE,
) -> "AlgorithmConfig":
    """Sets the config's multi-agent settings.

    Validates the new multi-agent settings and translates everything into
    a unified multi-agent setup format. For example a `policies` list or set
    of IDs is properly converted into a dict mapping these IDs to PolicySpecs.

    Args:
        policies: Map of type MultiAgentPolicyConfigDict from policy ids to either
            4-tuples of (policy_cls, obs_space, act_space, config) or PolicySpecs.
            These tuples or PolicySpecs define the class of the policy, the
            observation- and action spaces of the policies, and any extra config.
            A set/tuple/list of policy IDs is also accepted and converted into a
            dict mapping each ID to a default `PolicySpec()`.
        policy_map_capacity: Keep this many policies in the "policy_map" (before
            writing least-recently used ones to disk/S3).
        policy_mapping_fn: Function mapping agent ids to policy ids. The signature
            is: `(agent_id, episode, worker, **kwargs) -> PolicyID`. A dict is
            also accepted and turned into a callable via `from_config()`.
        policies_to_train: Determines those policies that should be updated.
            Options are:
            - None, for training all policies.
            - An iterable of PolicyIDs that should be trained.
            - A callable, taking a PolicyID and a SampleBatch or MultiAgentBatch
            and returning a bool (indicating whether the given policy is trainable
            or not, given the particular batch). This allows you to have a policy
            trained only on certain data (e.g. when playing against a certain
            opponent).
        policy_states_are_swappable: Whether all Policy objects in this map can be
            "swapped out" via a simple `state = A.get_state(); B.set_state(state)`,
            where `A` and `B` are policy instances in this map. You should set
            this to True for significantly speeding up the PolicyMap's cache lookup
            times, iff your policies all share the same neural network
            architecture and optimizer types. If True, the PolicyMap doesn't
            have to garbage collect old, least recently used policies, but instead
            keeps them in memory and simply overrides their state with the state
            of the most recently accessed one.
            For example, in a league-based training setup, you might have 100s of
            the same policies in your map (playing against each other in various
            combinations), but all of them share the same state structure
            (are "swappable").
        observation_fn: Optional function that can be used to enhance the local
            agent observations to include more state. See
            rllib/evaluation/observation_function.py for more info.
        count_steps_by: Which metric to use as the "batch size" when building a
            MultiAgentBatch. The two supported values are:
            "env_steps": Count each time the env is "stepped" (no matter how many
            multi-agent actions are passed/how many multi-agent observations
            have been returned in the previous step).
            "agent_steps": Count each individual agent step as one step.

    Returns:
        This updated AlgorithmConfig object.

    Raises:
        ValueError: If `policies` is not a dict or set/tuple/list, contains an
            invalid ID or malformed spec, or if `count_steps_by` is not one of
            "env_steps"/"agent_steps".
    """
    if policies is not NotProvided:
        # Reject unsupported container types before iterating over them, so that
        # e.g. `None` or a scalar yields this descriptive ValueError instead of an
        # opaque "object is not iterable" TypeError.
        if not isinstance(policies, (dict, set, tuple, list)):
            raise ValueError(
                "`policies` must be dict mapping PolicyID to PolicySpec OR a "
                "set/tuple/list of PolicyIDs!"
            )
        # Make sure our Policy IDs are ok (iterating a dict yields its keys, so
        # this works for dicts and sequences alike).
        for pid in policies:
            validate_module_id(pid, error=True)
        # Collection of IDs: Convert to dict with default PolicySpecs.
        if not isinstance(policies, dict):
            policies = {p: PolicySpec() for p in policies}
        # Validate each policy spec in the dict.
        for pid, spec in policies.items():
            # If not a PolicySpec object, values must be lists/tuples of len 4.
            if not isinstance(spec, PolicySpec):
                if not isinstance(spec, (list, tuple)) or len(spec) != 4:
                    raise ValueError(
                        "Policy specs must be tuples/lists of "
                        "(cls or None, obs_space, action_space, config), "
                        f"got {spec} for PolicyID={pid}"
                    )
            # TODO: Switch from dict to AlgorithmConfigOverride, once available.
            # Config not a dict.
            elif (
                not isinstance(spec.config, (AlgorithmConfig, dict))
                and spec.config is not None
            ):
                raise ValueError(
                    f"Multi-agent policy config for {pid} must be a dict or "
                    f"AlgorithmConfig object, but got {type(spec.config)}!"
                )
        self.policies = policies

    if algorithm_config_overrides_per_module != DEPRECATED_VALUE:
        # Soft deprecation: warn, then forward the value to its new home.
        deprecation_warning(
            old=(
                "AlgorithmConfig.multi_agent("
                "algorithm_config_overrides_per_module=..)"
            ),
            new=(
                "AlgorithmConfig.rl_module("
                "algorithm_config_overrides_per_module=..)"
            ),
            error=False,
        )
        self.rl_module(
            algorithm_config_overrides_per_module=(
                algorithm_config_overrides_per_module
            )
        )
    if policy_map_capacity is not NotProvided:
        self.policy_map_capacity = policy_map_capacity
    if policy_mapping_fn is not NotProvided:
        # Create `policy_mapping_fn` from a config dict.
        # Helpful if users would like to specify custom callable classes in
        # yaml files.
        if isinstance(policy_mapping_fn, dict):
            policy_mapping_fn = from_config(policy_mapping_fn)
        self.policy_mapping_fn = policy_mapping_fn
    if observation_fn is not NotProvided:
        self.observation_fn = observation_fn
    if policy_map_cache != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.multi_agent(policy_map_cache=..)",
            error=True,
        )
    if replay_mode != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.multi_agent(replay_mode=..)",
            new="AlgorithmConfig.training("
            "replay_buffer_config={'replay_mode': ..})",
            error=True,
        )
    if count_steps_by is not NotProvided:
        if count_steps_by not in ["env_steps", "agent_steps"]:
            raise ValueError(
                "config.multi_agent(count_steps_by=..) must be one of "
                f"[env_steps|agent_steps], not {count_steps_by}!"
            )
        self.count_steps_by = count_steps_by
    if policies_to_train is not NotProvided:
        assert (
            isinstance(policies_to_train, (list, set, tuple))
            or callable(policies_to_train)
            or policies_to_train is None
        ), (
            "ERROR: `policies_to_train` must be a [list|set|tuple] or a "
            "callable taking PolicyID and SampleBatch and returning "
            "True|False (trainable or not?) or None (for always training all "
            "policies)."
        )
        # Check `policies_to_train` for invalid entries.
        if isinstance(policies_to_train, (list, set, tuple)):
            if len(policies_to_train) == 0:
                logger.warning(
                    "`config.multi_agent(policies_to_train=..)` is empty! "
                    "Make sure - if you would like to learn at least one policy - "
                    "to add its ID to that list."
                )
        self.policies_to_train = policies_to_train
    if policy_states_are_swappable is not NotProvided:
        self.policy_states_are_swappable = policy_states_are_swappable
    return self
[docs]
def reporting(
    self,
    *,
    keep_per_episode_custom_metrics: Optional[bool] = NotProvided,
    metrics_episode_collection_timeout_s: Optional[float] = NotProvided,
    metrics_num_episodes_for_smoothing: Optional[int] = NotProvided,
    min_time_s_per_iteration: Optional[float] = NotProvided,
    min_train_timesteps_per_iteration: Optional[int] = NotProvided,
    min_sample_timesteps_per_iteration: Optional[int] = NotProvided,
    log_gradients: Optional[bool] = NotProvided,
) -> "AlgorithmConfig":
    """Sets the config's reporting settings.

    Every argument left as `NotProvided` keeps its current value; every other
    argument is stored on this config under the attribute of the same name.

    Args:
        keep_per_episode_custom_metrics: Store raw custom metrics without
            calculating max, min, mean.
        metrics_episode_collection_timeout_s: Wait for metric batches for at most
            this many seconds. Those that have not returned in time are collected
            in the next train iteration.
        metrics_num_episodes_for_smoothing: Smooth rollout metrics over this many
            episodes, if possible.
            In case rollouts (sample collection) just started, there may be fewer
            than this many episodes in the buffer and we'll compute metrics
            over this smaller number of available episodes.
            In case there are more than this many episodes collected in a single
            training iteration, use all of these episodes for metrics computation,
            meaning don't ever cut any "excess" episodes.
            Set this to 1 to disable smoothing and to always report only the most
            recently collected episode's return.
        min_time_s_per_iteration: Minimum time (in sec) to accumulate within a
            single `Algorithm.train()` call. This value does not affect learning,
            only the number of times `Algorithm.training_step()` is called by
            `Algorithm.train()`. If - after one such step attempt, the time taken
            has not reached `min_time_s_per_iteration`, performs n more
            `Algorithm.training_step()` calls until the minimum time has been
            consumed. Set to 0 or None for no minimum time.
        min_train_timesteps_per_iteration: Minimum training timesteps to accumulate
            within a single `train()` call. This value does not affect learning,
            only the number of times `Algorithm.training_step()` is called by
            `Algorithm.train()`. If - after one such step attempt, the training
            timestep count has not been reached, performs n more
            `training_step()` calls until the minimum timesteps have been
            executed. Set to 0 or None for no minimum timesteps.
        min_sample_timesteps_per_iteration: Minimum env sampling timesteps to
            accumulate within a single `train()` call. This value does not affect
            learning, only the number of times `Algorithm.training_step()` is
            called by `Algorithm.train()`. If - after one such step attempt, the
            env sampling timestep count has not been reached, performs n more
            `training_step()` calls until the minimum timesteps have been
            executed. Set to 0 or None for no minimum timesteps.
        log_gradients: Log gradients to results. If this is `True` the global norm
            of the gradients dictionary for each optimizer is logged to results.
            The default is `True`.

    Returns:
        This updated AlgorithmConfig object.
    """
    # Attribute name -> requested value, in the order they get applied.
    requested = {
        "keep_per_episode_custom_metrics": keep_per_episode_custom_metrics,
        "metrics_episode_collection_timeout_s": metrics_episode_collection_timeout_s,
        "metrics_num_episodes_for_smoothing": metrics_num_episodes_for_smoothing,
        "min_time_s_per_iteration": min_time_s_per_iteration,
        "min_train_timesteps_per_iteration": min_train_timesteps_per_iteration,
        "min_sample_timesteps_per_iteration": min_sample_timesteps_per_iteration,
        "log_gradients": log_gradients,
    }
    for attr_name, new_value in requested.items():
        if new_value is not NotProvided:
            setattr(self, attr_name, new_value)
    return self
[docs]
def checkpointing(
    self,
    export_native_model_files: Optional[bool] = NotProvided,
    checkpoint_trainable_policies_only: Optional[bool] = NotProvided,
) -> "AlgorithmConfig":
    """Sets the config's checkpointing settings.

    Arguments left as `NotProvided` keep their current value.

    Args:
        export_native_model_files: Whether an individual Policy-
            or the Algorithm's checkpoints also contain (tf or torch) native
            model files. These could be used to restore just the NN models
            from these files w/o requiring RLlib. These files are generated
            by calling the tf- or torch- built-in saving utility methods on
            the actual models.
        checkpoint_trainable_policies_only: Whether to only add Policies to the
            Algorithm checkpoint (in sub-directory "policies/") that are trainable
            according to the `is_trainable_policy` callable of the local worker.

    Returns:
        This updated AlgorithmConfig object.
    """
    for attr_name, new_value in (
        ("export_native_model_files", export_native_model_files),
        ("checkpoint_trainable_policies_only", checkpoint_trainable_policies_only),
    ):
        if new_value is not NotProvided:
            setattr(self, attr_name, new_value)
    return self
[docs]
def debugging(
    self,
    *,
    logger_creator: Optional[Callable[[], Logger]] = NotProvided,
    logger_config: Optional[dict] = NotProvided,
    log_level: Optional[str] = NotProvided,
    log_sys_usage: Optional[bool] = NotProvided,
    fake_sampler: Optional[bool] = NotProvided,
    seed: Optional[int] = NotProvided,
) -> "AlgorithmConfig":
    """Sets the config's debugging settings.

    Arguments left as `NotProvided` keep their current value; the others are
    stored on this config under the attribute of the same name.

    Args:
        logger_creator: Callable that creates a ray.tune.Logger
            object. If unspecified, a default logger is created.
        logger_config: Define logger-specific configuration to be used inside
            Logger. Default value None allows overwriting with nested dicts.
        log_level: Set the ray.rllib.* log level for the agent process and its
            workers. Should be one of DEBUG, INFO, WARN, or ERROR. The DEBUG level
            also periodically prints out summaries of relevant internal dataflow
            (this is also printed out once at startup at the INFO level).
        log_sys_usage: Log system resource metrics to results. This requires
            `psutil` to be installed for sys stats, and `gputil` for GPU metrics.
        fake_sampler: Use fake (infinite speed) sampler. For testing only.
        seed: This argument, in conjunction with worker_index, sets the random
            seed of each worker, so that identically configured trials have
            identical results. This makes experiments reproducible.

    Returns:
        This updated AlgorithmConfig object.
    """
    settings = (
        ("logger_creator", logger_creator),
        ("logger_config", logger_config),
        ("log_level", log_level),
        ("log_sys_usage", log_sys_usage),
        ("fake_sampler", fake_sampler),
        ("seed", seed),
    )
    for attr_name, new_value in settings:
        if new_value is NotProvided:
            continue
        setattr(self, attr_name, new_value)
    return self
[docs]
def fault_tolerance(
    self,
    *,
    restart_failed_env_runners: Optional[bool] = NotProvided,
    ignore_env_runner_failures: Optional[bool] = NotProvided,
    max_num_env_runner_restarts: Optional[int] = NotProvided,
    delay_between_env_runner_restarts_s: Optional[float] = NotProvided,
    restart_failed_sub_environments: Optional[bool] = NotProvided,
    num_consecutive_env_runner_failures_tolerance: Optional[int] = NotProvided,
    env_runner_health_probe_timeout_s: Optional[float] = NotProvided,
    env_runner_restore_timeout_s: Optional[float] = NotProvided,
    # Deprecated args.
    recreate_failed_env_runners=DEPRECATED_VALUE,
    ignore_worker_failures=DEPRECATED_VALUE,
    recreate_failed_workers=DEPRECATED_VALUE,
    max_num_worker_restarts=DEPRECATED_VALUE,
    delay_between_worker_restarts_s=DEPRECATED_VALUE,
    num_consecutive_worker_failures_tolerance=DEPRECATED_VALUE,
    worker_health_probe_timeout_s=DEPRECATED_VALUE,
    worker_restore_timeout_s=DEPRECATED_VALUE,
):
    """Sets the config's fault tolerance settings.

    Passing any of the deprecated (worker-named) arguments raises through
    `deprecation_warning(..., error=True)`, pointing at its replacement.

    Args:
        restart_failed_env_runners: Whether - upon an EnvRunner failure - RLlib
            tries to restart the lost EnvRunner(s) as an identical copy of the
            failed one(s). You should set this to True when training on SPOT
            instances that may preempt any time. The new, recreated EnvRunner(s)
            only differ from the failed one in their `self.recreated_worker=True`
            property value and have the same `worker_index` as the original(s).
            If this setting is True, the value of the `ignore_env_runner_failures`
            setting is ignored.
        ignore_env_runner_failures: Whether to ignore any EnvRunner failures
            and continue running with the remaining EnvRunners. This setting is
            ignored, if `restart_failed_env_runners=True`.
        max_num_env_runner_restarts: The maximum number of times any EnvRunner
            is allowed to be restarted (if `restart_failed_env_runners` is True).
        delay_between_env_runner_restarts_s: The delay (in seconds) between two
            consecutive EnvRunner restarts (if `restart_failed_env_runners` is
            True).
        restart_failed_sub_environments: If True and any sub-environment (within
            a vectorized env) throws any error during env stepping, the
            Sampler tries to restart the faulty sub-environment. This is done
            without disturbing the other (still intact) sub-environment and without
            the EnvRunner crashing.
        num_consecutive_env_runner_failures_tolerance: The number of consecutive
            times an EnvRunner failure (also for evaluation) is tolerated before
            finally crashing the Algorithm. Only useful if either
            `ignore_env_runner_failures` or `restart_failed_env_runners` is True.
            Note that for `restart_failed_sub_environments` and sub-environment
            failures, the EnvRunner itself is NOT affected and won't throw any
            errors as the flawed sub-environment is silently restarted under the
            hood.
        env_runner_health_probe_timeout_s: Max amount of time in seconds, we should
            spend waiting for EnvRunner health probe calls
            (`EnvRunner.ping.remote()`) to respond. Health pings are very cheap,
            however, we perform the health check via a blocking `ray.get()`, so the
            default value should not be too large.
        env_runner_restore_timeout_s: Max amount of time we should wait to restore
            states on recovered EnvRunner actors. Default is 30 mins.

    Returns:
        This updated AlgorithmConfig object.
    """
    # (passed value, old usage, replacement usage) - checked in this order; the
    # first deprecated argument that was actually passed raises.
    deprecated_args = (
        (
            recreate_failed_env_runners,
            "AlgorithmConfig.fault_tolerance(recreate_failed_env_runners)",
            "AlgorithmConfig.fault_tolerance(restart_failed_env_runners)",
        ),
        (
            ignore_worker_failures,
            "AlgorithmConfig.fault_tolerance(ignore_worker_failures)",
            "AlgorithmConfig.fault_tolerance(ignore_env_runner_failures)",
        ),
        (
            recreate_failed_workers,
            "AlgorithmConfig.fault_tolerance(recreate_failed_workers)",
            "AlgorithmConfig.fault_tolerance(restart_failed_env_runners)",
        ),
        (
            max_num_worker_restarts,
            "AlgorithmConfig.fault_tolerance(max_num_worker_restarts)",
            "AlgorithmConfig.fault_tolerance(max_num_env_runner_restarts)",
        ),
        (
            delay_between_worker_restarts_s,
            "AlgorithmConfig.fault_tolerance(delay_between_worker_restarts_s)",
            "AlgorithmConfig.fault_tolerance(delay_between_env_runner_"
            "restarts_s)",
        ),
        (
            num_consecutive_worker_failures_tolerance,
            "AlgorithmConfig.fault_tolerance(num_consecutive_worker_"
            "failures_tolerance)",
            "AlgorithmConfig.fault_tolerance(num_consecutive_env_runner_"
            "failures_tolerance)",
        ),
        (
            worker_health_probe_timeout_s,
            "AlgorithmConfig.fault_tolerance(worker_health_probe_timeout_s)",
            "AlgorithmConfig.fault_tolerance("
            "env_runner_health_probe_timeout_s)",
        ),
        (
            worker_restore_timeout_s,
            "AlgorithmConfig.fault_tolerance(worker_restore_timeout_s)",
            "AlgorithmConfig.fault_tolerance(env_runner_restore_timeout_s)",
        ),
    )
    for passed_value, old_usage, new_usage in deprecated_args:
        if passed_value != DEPRECATED_VALUE:
            deprecation_warning(old=old_usage, new=new_usage, error=True)

    # Attribute name -> requested value, applied in this order.
    new_settings = (
        ("ignore_env_runner_failures", ignore_env_runner_failures),
        ("restart_failed_env_runners", restart_failed_env_runners),
        ("max_num_env_runner_restarts", max_num_env_runner_restarts),
        (
            "delay_between_env_runner_restarts_s",
            delay_between_env_runner_restarts_s,
        ),
        ("restart_failed_sub_environments", restart_failed_sub_environments),
        (
            "num_consecutive_env_runner_failures_tolerance",
            num_consecutive_env_runner_failures_tolerance,
        ),
        ("env_runner_health_probe_timeout_s", env_runner_health_probe_timeout_s),
        ("env_runner_restore_timeout_s", env_runner_restore_timeout_s),
    )
    for attr_name, new_value in new_settings:
        if new_value is not NotProvided:
            setattr(self, attr_name, new_value)
    return self
[docs]
def rl_module(
    self,
    *,
    model_config: Optional[Union[Dict[str, Any], DefaultModelConfig]] = NotProvided,
    rl_module_spec: Optional[RLModuleSpecType] = NotProvided,
    algorithm_config_overrides_per_module: Optional[
        Dict[ModuleID, PartialAlgorithmConfigDict]
    ] = NotProvided,
    # Deprecated arg.
    model_config_dict=DEPRECATED_VALUE,
    _enable_rl_module_api=DEPRECATED_VALUE,
) -> "AlgorithmConfig":
    """Sets the config's RLModule settings.

    Args:
        model_config: The DefaultModelConfig object (or a config dictionary) passed
            as `model_config` arg into each RLModule's constructor. This is used
            for all RLModules, if not otherwise specified through `rl_module_spec`.
        rl_module_spec: The RLModule spec to use for this config. It can be either
            a RLModuleSpec or a MultiRLModuleSpec. If the
            observation_space, action_space, catalog_class, or the model config is
            not specified it is inferred from the env and other parts of the
            algorithm config object.
        algorithm_config_overrides_per_module: Only used if
            `enable_rl_module_and_learner=True`.
            A mapping from ModuleIDs to per-module AlgorithmConfig override dicts,
            which apply certain settings,
            e.g. the learning rate, from the main AlgorithmConfig only to this
            particular module (within a MultiRLModule).
            Given entries are merged (`dict.update`) into any previously set
            overrides. You can create override dicts by using the
            `AlgorithmConfig.overrides` utility. For example, to override your
            learning rate and (PPO) lambda setting just for a single RLModule
            within your MultiRLModule, do:
            config.rl_module(algorithm_config_overrides_per_module={
                "module_1": PPOConfig.overrides(lr=0.0002, lambda_=0.75),
            })
        model_config_dict: Deprecated alias for `model_config` (warns, and takes
            precedence over `model_config` when given).
        _enable_rl_module_api: Deprecated; raises. Use
            `AlgorithmConfig.api_stack(enable_rl_module_and_learner=..)`.

    Returns:
        This updated AlgorithmConfig object.

    Raises:
        ValueError: If `algorithm_config_overrides_per_module` is not a dict.
    """
    # Deprecated arguments first: the hard-deprecated one raises right away.
    if _enable_rl_module_api != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.rl_module(_enable_rl_module_api=..)",
            new="AlgorithmConfig.api_stack(enable_rl_module_and_learner=..)",
            error=True,
        )
    if model_config_dict != DEPRECATED_VALUE:
        deprecation_warning(
            old="AlgorithmConfig.rl_module(model_config_dict=..)",
            new="AlgorithmConfig.rl_module(model_config=..)",
            error=False,
        )
        # Soft deprecation: route the old arg into the new one.
        model_config = model_config_dict

    for attr_name, new_value in (
        ("_model_config", model_config),
        ("_rl_module_spec", rl_module_spec),
    ):
        if new_value is not NotProvided:
            setattr(self, attr_name, new_value)

    if algorithm_config_overrides_per_module is NotProvided:
        return self
    if not isinstance(algorithm_config_overrides_per_module, dict):
        raise ValueError(
            "`algorithm_config_overrides_per_module` must be a dict mapping "
            "module IDs to config override dicts! You provided "
            f"{algorithm_config_overrides_per_module}."
        )
    # Merge into (rather than replace) the existing per-module overrides.
    self.algorithm_config_overrides_per_module.update(
        algorithm_config_overrides_per_module
    )
    return self
[docs]
def experimental(
    self,
    *,
    _validate_config: Optional[bool] = NotProvided,
    _use_msgpack_checkpoints: Optional[bool] = NotProvided,
    _torch_grad_scaler_class: Optional[Type] = NotProvided,
    _torch_lr_scheduler_classes: Optional[
        Union[List[Type], Dict[ModuleID, List[Type]]]
    ] = NotProvided,
    _tf_policy_handles_more_than_one_loss: Optional[bool] = NotProvided,
    _disable_preprocessor_api: Optional[bool] = NotProvided,
    _disable_action_flattening: Optional[bool] = NotProvided,
    _disable_initialize_loss_from_dummy_batch: Optional[bool] = NotProvided,
) -> "AlgorithmConfig":
    """Sets the config's experimental settings.

    Like all other setters, arguments left as `NotProvided` keep their current
    value. (Previously `_validate_config` defaulted to `True`, so any call to
    this method silently re-enabled config validation.)

    Args:
        _validate_config: Whether to run `validate()` on this config. If False,
            ignores any calls to `self.validate()`.
        _use_msgpack_checkpoints: Create state files in all checkpoints through
            msgpack rather than pickle.
        _torch_grad_scaler_class: Class to use for torch loss scaling (and gradient
            unscaling). The class must implement the following methods to be
            compatible with a `TorchLearner`. These methods/APIs match exactly those
            of torch's own `torch.amp.GradScaler` (see here for more details
            https://pytorch.org/docs/stable/amp.html#gradient-scaling):
            `scale([loss])` to scale the loss by some factor.
            `get_scale()` to get the current scale factor value.
            `step([optimizer])` to unscale the grads (divide by the scale factor)
            and step the given optimizer.
            `update()` to update the scaler after an optimizer step (for example to
            adjust the scale factor).
        _torch_lr_scheduler_classes: A list of `torch.lr_scheduler.LRScheduler`
            (see here for more details
            https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate)
            classes or a dictionary mapping module IDs to such a list of respective
            scheduler classes. Multiple scheduler classes can be applied in sequence
            and are stepped in the same sequence as defined here. Note, most
            learning rate schedulers need arguments to be configured, that is, you
            might have to partially initialize the schedulers in the list(s) using
            `functools.partial`.
        _tf_policy_handles_more_than_one_loss: Experimental flag.
            If True, TFPolicy handles more than one loss or optimizer.
            Set this to True, if you would like to return more than
            one loss term from your `loss_fn` and an equal number of optimizers
            from your `optimizer_fn`.
        _disable_preprocessor_api: Experimental flag.
            If True, no (observation) preprocessor is created and
            observations arrive in model as they are returned by the env.
        _disable_action_flattening: Experimental flag.
            If True, RLlib doesn't flatten the policy-computed actions into
            a single tensor (for storage in SampleCollectors/output files/etc..),
            but leave (possibly nested) actions as-is. Disabling flattening affects:
            - SampleCollectors: Have to store possibly nested action structs.
            - Models that have the previous action(s) as part of their input.
            - Algorithms reading from offline files (incl. action information).
        _disable_initialize_loss_from_dummy_batch: Experimental flag, stored as
            `self._disable_initialize_loss_from_dummy_batch`. Presumably, if True,
            skips initializing a policy's loss via a dummy batch - TODO confirm
            against the policy code that reads it.

    Returns:
        This updated AlgorithmConfig object.
    """
    if _validate_config is not NotProvided:
        self._validate_config = _validate_config
    if _use_msgpack_checkpoints is not NotProvided:
        self._use_msgpack_checkpoints = _use_msgpack_checkpoints
    if _tf_policy_handles_more_than_one_loss is not NotProvided:
        self._tf_policy_handles_more_than_one_loss = (
            _tf_policy_handles_more_than_one_loss
        )
    if _disable_preprocessor_api is not NotProvided:
        self._disable_preprocessor_api = _disable_preprocessor_api
    if _disable_action_flattening is not NotProvided:
        self._disable_action_flattening = _disable_action_flattening
    if _disable_initialize_loss_from_dummy_batch is not NotProvided:
        self._disable_initialize_loss_from_dummy_batch = (
            _disable_initialize_loss_from_dummy_batch
        )
    if _torch_grad_scaler_class is not NotProvided:
        self._torch_grad_scaler_class = _torch_grad_scaler_class
    if _torch_lr_scheduler_classes is not NotProvided:
        self._torch_lr_scheduler_classes = _torch_lr_scheduler_classes
    return self
@property
def is_atari(self) -> bool:
    """True if the specified env is an Atari env.

    The answer is computed lazily on first access and cached in
    `self._is_atari`. Only string env specs are inspected (by creating the env
    via `gym.make()`). Any other env type, or a string gymnasium cannot make,
    yields False without caching.
    """
    # Already determined (computed earlier or set explicitly) -> use it.
    if self._is_atari is not None:
        return self._is_atari

    # Atari envs are usually specified via a string like "PongNoFrameskip-v4"
    # or "ale_py:ALE/Breakout-v5".
    # We do NOT attempt to auto-detect Atari env for other specified types like
    # a callable, to avoid running heavy logics in validate().
    # For these cases, users can explicitly set `environment(atari=True)`.
    if type(self.env) is not str:
        return False
    try:
        env = gym.make(self.env)
    # Any gymnasium error -> Cannot be an Atari env.
    except gym.error.Error:
        return False
    try:
        self._is_atari = is_atari(env)
    finally:
        # Clean up env's resources, if any - even if the Atari check raised.
        env.close()
    return self._is_atari
@property
def is_multi_agent(self) -> bool:
    """Tells whether this config describes a multi-agent setup.

    Returns:
        True if more than one policy is configured, or if the single configured
        policy uses an ID other than DEFAULT_POLICY_ID; False otherwise.
    """
    if len(self.policies) > 1:
        return True
    # Zero or one policy: multi-agent only if the default ID is not among them.
    return DEFAULT_POLICY_ID not in self.policies
@property
def learner_class(self) -> Type["Learner"]:
    """The Learner sub-class this Algorithm uses.

    Resolution order:
    a) A class explicitly set by the user via `.training(learner_class=...)`.
    b) Otherwise (unset/None), whatever `self.get_default_learner_class()`
       returns.
    """
    explicit_cls = self._learner_class
    if explicit_cls:
        return explicit_cls
    return self.get_default_learner_class()
@property
def model_config(self):
    """The effective model configuration as a plain dict.

    Starts from the algorithm's automatic defaults in
    `self._model_config_auto_includes` and overlays the user-defined
    `self._model_config` on top (user keys win). If `self._model_config` is a
    dataclass (e.g. DefaultModelConfig) rather than a dict, it is converted
    via `dataclasses.asdict()` first. The result is used to configure the
    `RLModule` in the new stack and the `ModelV2` in the old stack.

    Returns:
        A new dictionary with the merged model configuration.
    """
    user_config = self._model_config
    if not isinstance(user_config, dict):
        user_config = dataclasses.asdict(user_config)
    return {**self._model_config_auto_includes, **user_config}
@property
def rl_module_spec(self):
    """The RLModule spec to use, merged with the algorithm's default spec.

    - No user spec (set via `self.rl_module()`): the default spec is returned.
    - User RLModuleSpec: its values are merged into the (freshly obtained)
      default RLModuleSpec via `update()`. Combining it with a default
      MultiRLModuleSpec raises a ValueError.
    - Any other user spec (e.g. a MultiRLModuleSpec): a deep copy of it gets
      missing items filled in from the default spec via `update()`. The stored
      user spec is left untouched.
    """
    base_spec = self.get_default_rl_module_spec()
    _check_rl_module_spec(base_spec)

    user_spec = self._rl_module_spec
    if user_spec is None:
        return base_spec

    _check_rl_module_spec(user_spec)

    # Merge given spec with default one (in case items are missing, such as
    # spaces, module class, etc.)
    if not isinstance(user_spec, RLModuleSpec):
        merged_spec = copy.deepcopy(user_spec)
        merged_spec.update(base_spec)
        return merged_spec

    if isinstance(base_spec, RLModuleSpec):
        base_spec.update(user_spec)
        return base_spec
    if isinstance(base_spec, MultiRLModuleSpec):
        raise ValueError(
            "Cannot merge MultiRLModuleSpec with RLModuleSpec!"
        )
    return None
@property
def train_batch_size_per_learner(self) -> int:
    """The train batch size each individual Learner works with.

    If not set explicitly (`self._train_batch_size_per_learner` is None), the
    value is inferred as `self.train_batch_size // num_learners`, where a
    `num_learners` of 0 (or None) counts as a single Learner.
    """
    # NOTE: This property (and its setter) used to be defined twice in a row;
    # the first, shadowed definition was dead code and has been removed.
    # If not set explicitly, try to infer the value.
    if self._train_batch_size_per_learner is None:
        return self.train_batch_size // (self.num_learners or 1)
    return self._train_batch_size_per_learner

@train_batch_size_per_learner.setter
def train_batch_size_per_learner(self, value: int) -> None:
    """Explicitly sets the per-Learner train batch size (None to re-infer)."""
    self._train_batch_size_per_learner = value
@property
def total_train_batch_size(self) -> int:
    """The effective total train batch size across all Learners.

    New API stack: `train_batch_size_per_learner` * [effective num Learners],
    where 0 (or None) Learners counts as one.
    @OldAPIStack: User never touches `train_batch_size_per_learner` or
    `num_learners` -> `train_batch_size`.
    """
    per_learner = self.train_batch_size_per_learner
    learner_count = self.num_learners or 1
    return per_learner * learner_count
# TODO: Make rollout_fragment_length as read-only property and replace the current
# self.rollout_fragment_length a private variable.
[docs]
def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
    """Automatically infers a proper rollout_fragment_length setting if "auto".

    Uses the simple formula:
    `rollout_fragment_length` = `total_train_batch_size` /
    (`num_envs_per_env_runner` * `num_env_runners`)
    where `num_env_runners=0` (or None) counts as a single EnvRunner.

    If the result is a fraction, the first workers (by `worker_index`) get one
    additional timestep each. A worker gets the extra step while
    `(worker_index - 1) * num_envs_per_env_runner` is still below the remainder.
    As a result, the overall batch size across all workers is at least (possibly
    slightly more than) `total_train_batch_size`. The default `worker_index=0`
    (local worker) always falls into the "extra step" group.

    Args:
        worker_index: Index of the EnvRunner asking for its fragment length.

    Returns:
        The user-provided `rollout_fragment_length`, or a computed one (if the
        user-provided value is "auto") that makes sure
        `total_train_batch_size` is reached in each iteration.
    """
    if self.rollout_fragment_length != "auto":
        return self.rollout_fragment_length

    # Example:
    # 2 workers, 2 envs per worker, 2000 train batch size:
    # -> 2000 / 4 -> 500
    # 4 workers, 3 envs per worker, 2500 train batch size:
    # -> 2500 / 12 -> 208.333 -> remainder=4 (208 * 12 = 2496)
    # -> worker 1, 2: 209, workers 3, 4: 208
    # 2 workers, 20 envs per worker, 512 train batch size:
    # -> 512 / 40 -> 12.8 -> remainder=32 (12 * 40 = 480)
    # -> worker 1: 13, worker 2: 13 (its offset (2-1)*20=20 is still < 32)
    num_envs = self.num_envs_per_env_runner
    total_envs = num_envs * (self.num_env_runners or 1)
    # Exact integer arithmetic (no float rounding for large batch sizes).
    base, remainder = divmod(self.total_train_batch_size, total_envs)
    base = int(base)
    if remainder and (worker_index - 1) * num_envs < remainder:
        return base + 1
    return base
# TODO: Make evaluation_config as read-only property and replace the current
# self.evaluation_config a private variable.
[docs]
def get_evaluation_config_object(
    self,
) -> Optional["AlgorithmConfig"]:
    """Creates a full AlgorithmConfig object from `self.evaluation_config`.

    If `self.evaluation_config` already is an AlgorithmConfig, an unfrozen copy
    of it is used. Otherwise, an unfrozen copy of `self` gets updated with the
    (dict) overrides found in `self.evaluation_config`. In both cases, the
    returned object has `in_evaluation=True`, no `evaluation_config`, and its
    `num_env_runners` set to `self.evaluation_num_env_runners`. Its `batch_mode`
    and `rollout_fragment_length` are derived from
    `self.evaluation_duration_unit` and `self.evaluation_duration`.

    Returns:
        A fully valid AlgorithmConfig object that can be used for the evaluation
        EnvRunnerGroup. If `self` is already an evaluation config object, returns
        None.
    """
    # An evaluation config never carries a nested evaluation config itself.
    if self.in_evaluation:
        assert self.evaluation_config is None
        return None

    overrides = self.evaluation_config
    if isinstance(overrides, AlgorithmConfig):
        # Already a complete config object -> Use an unfrozen copy as-is.
        eval_cfg = overrides.copy(copy_frozen=False)
    else:
        # Start from an unfrozen copy of `self` and apply the override dict.
        eval_cfg = self.copy(copy_frozen=False)
        eval_cfg.update_from_dict(overrides or {})

    eval_cfg.in_evaluation = True
    eval_cfg.evaluation_config = None
    # `self.evaluation_num_env_runners` is merely a convenience setting (users
    # could also pass `evaluation_config={"num_env_runners": ...}`), but it
    # always wins here.
    eval_cfg.num_env_runners = self.evaluation_num_env_runners

    # NOTE: The settings below only matter for the old API stack. On the new API
    # stack, Algorithm's evaluation methods tell each EnvRunner explicitly how
    # many timesteps or episodes to collect per sample call.
    if self.evaluation_duration_unit == "episodes":
        # Sample complete episodes, one rollout fragment never holding more
        # than a single episode.
        eval_cfg.batch_mode = "complete_episodes"
        eval_cfg.rollout_fragment_length = 1
    else:
        # Sample across episode borders.
        eval_cfg.batch_mode = "truncate_episodes"
        if self.evaluation_duration == "auto":
            # Moderately small: Don't overshoot the parallelly running
            # `training_step` too much, but also avoid too many remote
            # `sample()` calls.
            fragment_length = 100
        else:
            # Split the requested timesteps evenly among the eval workers.
            fragment_length = int(
                math.ceil(
                    self.evaluation_duration / (self.evaluation_num_env_runners or 1)
                )
            )
        eval_cfg.rollout_fragment_length = fragment_length

    return eval_cfg
[docs]
def validate_train_batch_size_vs_rollout_fragment_length(self) -> None:
    """Detects mismatches for `train_batch_size` vs `rollout_fragment_length`.

    Only applicable for algorithms, whose train_batch_size should be directly
    dependent on rollout_fragment_length (synchronous sampling, on-policy PG algos).

    If rollout_fragment_length != "auto" (and this is not an evaluation config),
    makes sure that the smallest multiple of
    `rollout_fragment_length` x `num_env_runners` x `num_envs_per_env_runner`
    that reaches `total_train_batch_size` exceeds it by at most 10%. Otherwise,
    errors, asking the user to set rollout_fragment_length to `auto` or to a
    matching value.

    Raises:
        ValueError: If there is a mismatch between user provided
            `rollout_fragment_length` and `total_train_batch_size`, or if one
            round of sampling would yield <= 0 timesteps. Only raised if config
            validation is enabled; otherwise, the problem is logged.
    """
    if self.rollout_fragment_length == "auto" or self.in_evaluation:
        return

    total_batch_size = self.total_train_batch_size
    # Timesteps collected by one fragment on every env of every worker.
    min_batch_size = (
        max(self.num_env_runners, 1)
        * self.num_envs_per_env_runner
        * self.rollout_fragment_length
    )
    if min_batch_size <= 0:
        # Guard: no (positive) multiple of this could ever reach the target.
        self._value_error(
            f"`rollout_fragment_length` ({self.rollout_fragment_length}) and "
            f"`num_envs_per_env_runner` ({self.num_envs_per_env_runner}) must "
            "both be > 0!"
        )
        return

    # Smallest multiple of `min_batch_size` (at least one) that is >=
    # `total_batch_size`. `-(-a // b)` is ceil-division.
    num_rounds = max(1, -(-total_batch_size // min_batch_size))
    batch_size = num_rounds * min_batch_size

    # Only the overshoot needs checking: the next-lower multiple
    # (`batch_size - min_batch_size`) can never be >10% ABOVE the target
    # without `batch_size` itself being so.
    if batch_size - total_batch_size > 0.1 * total_batch_size:
        suggested_rollout_fragment_length = total_batch_size // (
            self.num_envs_per_env_runner * (self.num_env_runners or 1)
        )
        self._value_error(
            "Your desired `total_train_batch_size` "
            f"({total_batch_size}={self.num_learners or 1} "
            f"learners x {self.train_batch_size_per_learner}) "
            "or a value 10% off of that cannot be achieved with your other "
            f"settings (num_env_runners={self.num_env_runners}; "
            f"num_envs_per_env_runner={self.num_envs_per_env_runner}; "
            f"rollout_fragment_length={self.rollout_fragment_length})! "
            "Try setting `rollout_fragment_length` to 'auto' OR to a value of "
            f"{suggested_rollout_fragment_length}."
        )
[docs]
def get_torch_compile_worker_config(self):
    """Builds the TorchCompileConfig to use on workers.

    Returns:
        A TorchCompileConfig carrying this config's
        `torch_compile_worker_dynamo_backend` and
        `torch_compile_worker_dynamo_mode` settings.
    """
    # Imported lazily, only when this method is called.
    from ray.rllib.core.rl_module.torch.torch_compile_config import (
        TorchCompileConfig,
    )

    backend = self.torch_compile_worker_dynamo_backend
    mode = self.torch_compile_worker_dynamo_mode
    return TorchCompileConfig(torch_dynamo_backend=backend, torch_dynamo_mode=mode)
[docs]
def get_default_rl_module_spec(self) -> RLModuleSpecType:
    """Returns the default RLModule spec to use for this algorithm.

    Algorithm-specific subclasses must override this method to return the
    RLModuleSpec (or MultiRLModuleSpec), given the configured framework.

    Returns:
        The RLModuleSpec (or MultiRLModuleSpec) to use for this algorithm's
        RLModule.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()
[docs]
def get_default_learner_class(self) -> Union[Type["Learner"], str]:
    """Returns the Learner class to use for this algorithm.

    Algorithm-specific subclasses must override this method to return the
    Learner class, given the configured framework.

    Returns:
        The Learner class to use for this algorithm, either as a class type or
        as a string (e.g. "ray.rllib.algorithms.ppo.ppo_learner.PPOLearner").

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()
[docs]
def get_rl_module_spec(
    self,
    env: Optional[EnvType] = None,
    spaces: Optional[Dict[str, gym.Space]] = None,
    inference_only: Optional[bool] = None,
) -> RLModuleSpec:
    """Returns the RLModuleSpec based on the given env/spaces.

    Starts from a deep copy of `self.rl_module_spec`. If neither `spaces` nor
    `env` is provided, whatever spaces that spec already holds are kept.

    Args:
        env: An optional environment instance, from which to infer the
            observation- and action spaces for the RLModule. Only used if
            `spaces` is None. For a `gym.vector.VectorEnv`, its single
            (per-sub-env) spaces are used, otherwise the env's
            `observation_space` and `action_space`.
        spaces: Optional dict mapping ModuleIDs to 2-tuples of observation- and
            action space. Only the entry under `DEFAULT_MODULE_ID` is used.
            These spaces are usually provided by an already instantiated remote
            EnvRunner (call `EnvRunner.get_spaces()`). Takes precedence over
            `env`.
        inference_only: If `True`, the returned module spec is used in an
            inference-only setting (sampling) and the RLModule can thus be built in
            its light version (if available). For example, the `inference_only`
            version of an RLModule might only contain the networks required for
            computing actions, but misses additional target- or critic networks.
            If None, the spec's own `inference_only` setting is left unchanged.

    Returns:
        A new RLModuleSpec instance that can be used to build an RLModule.

    Raises:
        ValueError: If `self.rl_module_spec` is a MultiRLModuleSpec that cannot
            be reduced to its `DEFAULT_MODULE_ID` sub-spec (the key is missing,
            or other sub-specs exist that are not all `learner_only=True` when
            `inference_only` is True).
    """
    rl_module_spec = copy.deepcopy(self.rl_module_spec)

    # If a MultiRLModuleSpec -> Reduce to single-agent (and assert that
    # all non DEFAULT_MODULE_IDs are `learner_only` (so they are not built on
    # EnvRunner).
    if isinstance(rl_module_spec, MultiRLModuleSpec):
        error = False
        if DEFAULT_MODULE_ID not in rl_module_spec:
            error = True
        if inference_only:
            for mid, spec in rl_module_spec.rl_module_specs.items():
                if mid != DEFAULT_MODULE_ID:
                    if not spec.learner_only:
                        error = True
        elif len(rl_module_spec) > 1:
            error = True
        if error:
            raise ValueError(
                "When calling `AlgorithmConfig.get_rl_module_spec()`, the "
                "configuration must contain the `DEFAULT_MODULE_ID` key and all "
                "other keys' specs must have the setting `learner_only=True`! If "
                "you are using a more complex setup, call "
                "`AlgorithmConfig.get_multi_rl_module_spec(...)` instead."
            )
        rl_module_spec = rl_module_spec[DEFAULT_MODULE_ID]

    if spaces is not None:
        rl_module_spec.observation_space = spaces[DEFAULT_MODULE_ID][0]
        rl_module_spec.action_space = spaces[DEFAULT_MODULE_ID][1]
    elif env is not None:
        if isinstance(env, gym.vector.VectorEnv):
            # A vector env's top-level spaces are batched; use the per-sub-env ones.
            rl_module_spec.observation_space = env.single_observation_space
            rl_module_spec.action_space = env.single_action_space
        else:
            # Previously, a non-vectorized env was silently ignored here.
            rl_module_spec.observation_space = env.observation_space
            rl_module_spec.action_space = env.action_space

    # If module_config_dict is not defined, set to our generic one.
    if rl_module_spec.model_config is None:
        rl_module_spec.model_config = self.model_config

    if inference_only is not None:
        rl_module_spec.inference_only = inference_only

    return rl_module_spec
[docs]
def get_multi_rl_module_spec(
    self,
    *,
    env: Optional[EnvType] = None,
    spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None,
    inference_only: bool = False,
    # @HybridAPIStack
    policy_dict: Optional[Dict[str, PolicySpec]] = None,
    single_agent_rl_module_spec: Optional[RLModuleSpec] = None,
) -> MultiRLModuleSpec:
    """Returns the MultiRLModuleSpec based on the given env/spaces.

    Args:
        env: An optional environment instance, from which to infer the different
            spaces for the individual RLModules. If not provided, tries to infer
            from `spaces`, otherwise from `self.observation_space` and
            `self.action_space`. Raises an error, if no information on spaces can be
            inferred.
        spaces: Optional dict mapping ModuleIDs to 2-tuples of observation- and
            action space that should be used for the respective RLModule.
            These spaces are usually provided by an already instantiated remote
            EnvRunner (call `EnvRunner.get_spaces()`). If not provided, tries
            to infer from `env`, otherwise from `self.observation_space` and
            `self.action_space`. Raises an error, if no information on spaces can be
            inferred.
        inference_only: If `True`, the returned module spec is used in an
            inference-only setting (sampling) and the RLModule can thus be built in
            its light version (if available). For example, the `inference_only`
            version of an RLModule might only contain the networks required for
            computing actions, but misses additional target- or critic networks.
            Also, if `True`, the returned spec does NOT contain those (sub)
            RLModuleSpecs that have their `learner_only` flag set to True.
        policy_dict: Optional dict mapping ModuleIDs to PolicySpecs. If None,
            computed via `self.get_multi_agent_setup(env=env, spaces=spaces)`.
            Its keys determine (together with any per-module specs already
            configured) which RLModules the returned spec contains. Its
            PolicySpecs provide the observation- and action spaces for those
            module specs that don't define their own.
        single_agent_rl_module_spec: Optional RLModuleSpec to use as the
            template for the individual RLModules (instead of the configured or
            default one). This object is deep-copied and never modified.

    Returns:
        A new MultiRLModuleSpec instance that can be used to build a MultiRLModule.

    Raises:
        ValueError: If the individual RLModuleSpecs, or the module- or catalog
            class of one of them, cannot be inferred.
    """
    # TODO (Kourosh,sven): When we replace policy entirely there is no need for
    #  this function to map policy_dict to multi_rl_module_specs anymore. The
    #  module spec is directly given by the user or inferred from env and spaces.
    if policy_dict is None:
        policy_dict, _ = self.get_multi_agent_setup(env=env, spaces=spaces)

    # TODO (Kourosh): Raise an error if the config is not frozen
    # If the module is single-agent convert it to multi-agent spec

    # The default RLModuleSpec (might be multi-agent or single-agent).
    default_rl_module_spec = self.get_default_rl_module_spec()
    # The currently configured RLModuleSpec (might be multi-agent or
    # single-agent). If None, use the default one.
    current_rl_module_spec = self._rl_module_spec or default_rl_module_spec

    # NOTE: Every template spec below is deep-copied BEFORE its `inference_only`
    # flag is set. Otherwise, the caller's `single_agent_rl_module_spec`, the
    # default spec, or the spec stored in this config (`self._rl_module_spec`)
    # would be mutated in place and the flag would leak into later calls.

    # Algorithm is currently set up as a single-agent one.
    if isinstance(current_rl_module_spec, RLModuleSpec):
        # Use either the provided `single_agent_rl_module_spec` (a
        # RLModuleSpec), the currently configured one of this
        # AlgorithmConfig object, or the default one.
        template_spec = copy.deepcopy(
            single_agent_rl_module_spec or current_rl_module_spec
        )
        template_spec.inference_only = inference_only
        # Now construct the proper MultiRLModuleSpec.
        multi_rl_module_spec = MultiRLModuleSpec(
            rl_module_specs={
                k: copy.deepcopy(template_spec) for k in policy_dict.keys()
            },
        )

    # Algorithm is currently set up as a multi-agent one.
    else:
        # The user currently has a MultiRLModuleSpec setup (either via
        # self._rl_module_spec or the default spec of this AlgorithmConfig).
        assert isinstance(current_rl_module_spec, MultiRLModuleSpec)

        # Default is single-agent but the user has provided a multi-agent spec
        # so the use-case is multi-agent.
        if isinstance(default_rl_module_spec, RLModuleSpec):
            # The currently set up MultiRLModuleSpec holds a single RLModuleSpec
            # -> Use it (or the provided override) for all RLModules.
            if isinstance(current_rl_module_spec.rl_module_specs, RLModuleSpec):
                template_spec = copy.deepcopy(
                    single_agent_rl_module_spec
                    or current_rl_module_spec.rl_module_specs
                )
                template_spec.inference_only = inference_only
                module_specs = {
                    k: copy.deepcopy(template_spec) for k in policy_dict.keys()
                }
            # The currently set up MultiRLModuleSpec holds a dict of per-module
            # specs -> Use these where available and fill in the remaining
            # modules with the provided single-agent spec or the default spec
            # (which is also a RLModuleSpec in this case).
            else:
                template_spec = copy.deepcopy(
                    single_agent_rl_module_spec or default_rl_module_spec
                )
                template_spec.inference_only = inference_only
                # NOTE: Per-module specs taken from `current_rl_module_spec`
                # keep their own `inference_only` setting.
                module_specs = {
                    k: copy.deepcopy(
                        current_rl_module_spec.rl_module_specs.get(k, template_spec)
                    )
                    for k in (
                        policy_dict | current_rl_module_spec.rl_module_specs
                    ).keys()
                }

            # Now construct the proper MultiRLModuleSpec.
            # We need to infer the multi-agent class from `current_rl_module_spec`
            # and fill in the module_specs dict.
            multi_rl_module_spec = current_rl_module_spec.__class__(
                multi_rl_module_class=current_rl_module_spec.multi_rl_module_class,
                rl_module_specs=module_specs,
                modules_to_load=current_rl_module_spec.modules_to_load,
                load_state_path=current_rl_module_spec.load_state_path,
            )

        # Default is multi-agent and user wants to override it -> Don't use the
        # default.
        else:
            # User has NOT provided an override RLModuleSpec.
            if single_agent_rl_module_spec is None:
                # But the currently set up multi-agent spec has a single
                # RLModuleSpec defined -> Use that to construct the individual
                # RLModules within the MultiRLModuleSpec.
                if isinstance(current_rl_module_spec.rl_module_specs, RLModuleSpec):
                    single_agent_rl_module_spec = (
                        current_rl_module_spec.rl_module_specs
                    )
                # The currently set up multi-agent spec has NO single
                # RLModuleSpec in it -> Error (there is no way we can
                # infer this information from anywhere at this point).
                else:
                    raise ValueError(
                        "We have a MultiRLModuleSpec "
                        f"({current_rl_module_spec}), but no "
                        "`RLModuleSpec`s to compile the individual "
                        "RLModules' specs! Use "
                        "`AlgorithmConfig.get_multi_rl_module_spec("
                        "policy_dict=.., single_agent_rl_module_spec=..)`."
                    )

            template_spec = copy.deepcopy(single_agent_rl_module_spec)
            template_spec.inference_only = inference_only

            # Now construct the proper MultiRLModuleSpec.
            multi_rl_module_spec = current_rl_module_spec.__class__(
                multi_rl_module_class=current_rl_module_spec.multi_rl_module_class,
                rl_module_specs={
                    k: copy.deepcopy(template_spec) for k in policy_dict.keys()
                },
                modules_to_load=current_rl_module_spec.modules_to_load,
                load_state_path=current_rl_module_spec.load_state_path,
            )

    # Fill in the missing values from the specs that we already have. By combining
    # PolicySpecs and the default RLModuleSpec.
    # NOTE: Iterates over a new (union) dict, so removing modules from
    # `multi_rl_module_spec` inside the loop is safe.
    for module_id in policy_dict | multi_rl_module_spec.rl_module_specs:
        # Remove/skip `learner_only=True` RLModules if `inference_only` is True.
        module_spec = multi_rl_module_spec.rl_module_specs[module_id]
        if inference_only and module_spec.learner_only:
            multi_rl_module_spec.remove_modules(module_id)
            continue

        policy_spec = policy_dict.get(module_id)
        if policy_spec is None:
            policy_spec = policy_dict[DEFAULT_MODULE_ID]

        if module_spec.module_class is None:
            if isinstance(default_rl_module_spec, RLModuleSpec):
                module_spec.module_class = default_rl_module_spec.module_class
            elif isinstance(default_rl_module_spec.rl_module_specs, RLModuleSpec):
                module_class = default_rl_module_spec.rl_module_specs.module_class
                # This should be already checked in validate() but we check it
                # again here just in case
                if module_class is None:
                    raise ValueError(
                        "The default rl_module spec cannot have an empty "
                        "module_class under its RLModuleSpec."
                    )
                module_spec.module_class = module_class
            elif module_id in default_rl_module_spec.rl_module_specs:
                module_spec.module_class = default_rl_module_spec.rl_module_specs[
                    module_id
                ].module_class
            else:
                raise ValueError(
                    f"Module class for module {module_id} cannot be inferred. "
                    f"It is neither provided in the rl_module_spec that "
                    "is passed in nor in the default module spec used in "
                    "the algorithm."
                )
        if module_spec.catalog_class is None:
            if isinstance(default_rl_module_spec, RLModuleSpec):
                module_spec.catalog_class = default_rl_module_spec.catalog_class
            elif isinstance(default_rl_module_spec.rl_module_specs, RLModuleSpec):
                catalog_class = default_rl_module_spec.rl_module_specs.catalog_class
                module_spec.catalog_class = catalog_class
            elif module_id in default_rl_module_spec.rl_module_specs:
                module_spec.catalog_class = default_rl_module_spec.rl_module_specs[
                    module_id
                ].catalog_class
            else:
                raise ValueError(
                    f"Catalog class for module {module_id} cannot be inferred. "
                    f"It is neither provided in the rl_module_spec that "
                    "is passed in nor in the default module spec used in "
                    "the algorithm."
                )
        # TODO (sven): Find a good way to pack module specific parameters from
        #  the algorithms into the `model_config_dict`.
        if module_spec.observation_space is None:
            module_spec.observation_space = policy_spec.observation_space
        if module_spec.action_space is None:
            module_spec.action_space = policy_spec.action_space
        # In case the `RLModuleSpec` does not have a model config dict, we use
        # the one defined by the auto keys and the `model_config_dict` arguments
        # in `self.rl_module()`.
        if module_spec.model_config is None:
            module_spec.model_config = self.model_config
        # Otherwise we combine the two dictionaries where settings from the
        # `RLModuleSpec` have higher priority.
        else:
            module_spec.model_config = (
                self.model_config | module_spec._get_model_config()
            )

    return multi_rl_module_spec
def __setattr__(self, key, value):
    """Sets an attribute, refusing to do so once this config is frozen.

    On a frozen config, only `simple_optimizer`, `worker_index`, and
    `_is_frozen` itself may still be set. The legacy key `rl_module_spec` is
    redirected to `_rl_module_spec`.

    Raises:
        AttributeError: If this config is frozen and `key` is not one of the
            exempt names above.
    """
    is_frozen = getattr(self, "_is_frozen", False)
    # TODO: Remove `simple_optimizer` entirely.
    #  Remove need to set `worker_index` in RolloutWorker's c'tor.
    exempt_keys = ("simple_optimizer", "worker_index", "_is_frozen")
    if is_frozen and key not in exempt_keys:
        raise AttributeError(
            f"Cannot set attribute ({key}) of an already frozen "
            "AlgorithmConfig!"
        )

    # `rl_module_spec` is a property now, but checkpoints taken with older
    # wheels still try to set it directly.
    target_key = "_rl_module_spec" if key == "rl_module_spec" else key
    super().__setattr__(target_key, value)
def __getitem__(self, item):
    """Returns the setting stored under `item`, mimicking a dict lookup.

    Keeps AlgorithmConfig objects usable as if they were dicts, e.g. by Ray
    Tune. "Old" keys (e.g. "num_workers") are first translated into their
    current property names.

    Examples:
        .. testcode::

            from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
            config = AlgorithmConfig()
            print(config["lr"])

        .. testoutput::

            0.001
    """
    # TODO: Once all algorithms (as well as Ray Tune) use AlgorithmConfigs under
    #  the hood, warn (once) here that AlgorithmConfig objects should NOT be
    #  used as dicts and that `item` should be accessed directly as a property.
    property_name = self._translate_special_keys(item)
    return getattr(self, property_name)
def __setitem__(self, key, value):
    """Sets `key` to `value`, mimicking dict-style assignment.

    NOTE: This calls the parent class' `__setattr__` via `super()`, so it does
    not go through this class' own `__setattr__`: no frozen-state check and no
    `rl_module_spec` key redirection.

    Raises:
        AttributeError: If `key` is "multiagent". Multi-agent settings must be
            configured through `config.multi_agent()`.
    """
    # TODO: Once all methods/functions only support AlgorithmConfigs (and there
    #  is no more ambiguity about whether an AlgorithmConfig or an old python
    #  config dict is used), raise an AttributeError here, asking users to set
    #  properties directly instead of using dict-style assignment.
    if key != "multiagent":
        super().__setattr__(key, value)
        return

    raise AttributeError(
        "Cannot set `multiagent` key in an AlgorithmConfig!\nTry setting "
        "the multi-agent components of your AlgorithmConfig object via the "
        "`multi_agent()` method and its arguments.\nE.g. `config.multi_agent("
        "policies=.., policy_mapping_fn.., policies_to_train=..)`."
    )
def __contains__(self, item) -> bool:
    """Shim method to help pretend we are a dict.

    Returns:
        Whether this config has an attribute for `item` (after translating
        "old" keys into their current property names, without deprecation
        warnings).
    """
    return hasattr(self, self._translate_special_keys(item, warn_deprecated=False))
[docs]
def get(self, key, default=None):
    """Shim method to help pretend we are a dict.

    Returns:
        The value of the (translated) property `key`, or `default` if this
        config has no such attribute.
    """
    return getattr(
        self, self._translate_special_keys(key, warn_deprecated=False), default
    )
[docs]
def pop(self, key, default=None):
    """Shim method to help pretend we are a dict.

    NOTE: Unlike `dict.pop()`, this does NOT remove anything from the config.
    It merely returns the result of `self.get(key, default)`.
    """
    looked_up_value = self.get(key, default)
    return looked_up_value
[docs]
def keys(self):
    """Shim method to help pretend we are a dict.

    Returns:
        The keys view of `self.to_dict()`.
    """
    as_dict = self.to_dict()
    return as_dict.keys()
[docs]
def values(self):
    """Shim method to help pretend we are a dict.

    Returns:
        The values view of `self.to_dict()`.
    """
    as_dict = self.to_dict()
    return as_dict.values()
[docs]
def items(self):
    """Shim method to help pretend we are a dict.

    Returns:
        The items view of `self.to_dict()`.
    """
    as_dict = self.to_dict()
    return as_dict.items()
@property
def _model_config_auto_includes(self) -> Dict[str, Any]:
    """The `AlgorithmConfig` settings to auto-include into `self.model_config`.

    Algorithm-specific configs can override this property to inject their own
    default settings into the configuration sent to the `RLModule`. The base
    implementation includes nothing.

    Returns:
        A dictionary with the automatically included properties/settings of this
        `AlgorithmConfig` object into `self.model_config` (empty here).
    """
    auto_includes: Dict[str, Any] = dict()
    return auto_includes
# -----------------------------------------------------------
# Various validation methods for different types of settings.
# -----------------------------------------------------------
def _value_error(self, errmsg) -> None:
    """Raises a ValueError with `errmsg`, or only logs it if validation is off.

    Args:
        errmsg: The validation error message.

    Raises:
        ValueError: If `self._validate_config` is truthy. The raised message
            additionally explains how to suppress validation errors.
    """
    message_with_hint = errmsg + (
        "\nTo suppress all validation errors, set "
        "`config.experimental(_validate_config=False)` at your own risk."
    )
    if not self._validate_config:
        # Validation is suppressed -> Only log the plain message.
        logger.warning(errmsg)
        return
    raise ValueError(message_with_hint)
def _validate_env_runner_settings(self) -> None:
    """Checks that `gym_env_vectorize_mode` is a valid gymnasium vectorize mode.

    Accepted are the member names of `gym.envs.registration.VectorizeMode` as
    well as the members themselves.
    """
    vectorize_modes = gym.envs.registration.VectorizeMode.__members__
    allowed_vectorize_modes = set(
        [*vectorize_modes.keys(), *vectorize_modes.values()]
    )
    if self.gym_env_vectorize_mode in allowed_vectorize_modes:
        return
    self._value_error(
        f"`gym_env_vectorize_mode` ({self.gym_env_vectorize_mode}) must be a "
        "member of `gym.envs.registration.VectorizeMode`! Allowed values "
        f"are {allowed_vectorize_modes}."
    )
def _validate_callbacks_settings(self) -> None:
    """Validates callbacks settings.

    On the old API stack (`enable_env_runner_and_connector_v2=False`), the
    `config.callbacks(on_...=...)` shortcuts are not supported, so all
    `callbacks_on_...` attributes listed below must be None.
    """
    if self.enable_env_runner_and_connector_v2:
        return

    # Each callback attribute is listed exactly once (the previous chained
    # condition checked `callbacks_on_environment_created` twice).
    callback_attribute_names = (
        "callbacks_on_environment_created",
        "callbacks_on_algorithm_init",
        "callbacks_on_train_result",
        "callbacks_on_evaluate_start",
        "callbacks_on_evaluate_end",
        "callbacks_on_sample_end",
        "callbacks_on_episode_created",
        "callbacks_on_episode_start",
        "callbacks_on_episode_step",
        "callbacks_on_episode_end",
        "callbacks_on_checkpoint_loaded",
        "callbacks_on_env_runners_recreated",
    )
    if any(getattr(self, name) is not None for name in callback_attribute_names):
        self._value_error(
            "Config settings `config.callbacks(on_....=lambda ..)` aren't "
            "supported on the old API stack! Switch to the new API stack "
            "through `config.api_stack(enable_env_runner_and_connector_v2=True,"
            " enable_rl_module_and_learner=True)`."
        )
def _validate_framework_settings(self) -> None:
    """Validates framework settings and checks whether framework is installed.

    Does nothing for frameworks other than "tf", "tf2", and "torch".
    """
    if self.framework_str not in ("tf", "tf2", "torch"):
        return

    _tf1 = _tf = _tfv = None
    _torch = None
    if self.framework_str == "torch":
        _torch, _ = try_import_torch()
    else:
        _tf1, _tf, _tfv = try_import_tf()

    # Can not use "tf" with learner API.
    if self.framework_str == "tf" and self.enable_rl_module_and_learner:
        self._value_error(
            "Cannot use `framework=tf` with the new API stack! Either switch to tf2"
            " via `config.framework('tf2')` OR disable the new API stack via "
            "`config.api_stack(enable_rl_module_and_learner=False)`."
        )

    # Check if torch framework supports torch.compile.
    torch_too_old_to_compile = (
        _torch is not None
        and self.framework_str == "torch"
        and version.parse(_torch.__version__) < TORCH_COMPILE_REQUIRED_VERSION
    )
    if torch_too_old_to_compile and (
        self.torch_compile_learner or self.torch_compile_worker
    ):
        self._value_error("torch.compile is only supported from torch 2.0.0")

    # Make sure the Learner's torch-what-to-compile setting is supported.
    if self.torch_compile_learner:
        from ray.rllib.core.learner.torch.torch_learner import (
            TorchCompileWhatToCompile,
        )

        supported_targets = [
            TorchCompileWhatToCompile.FORWARD_TRAIN,
            TorchCompileWhatToCompile.COMPLETE_UPDATE,
        ]
        if self.torch_compile_learner_what_to_compile not in supported_targets:
            self._value_error(
                f"`config.torch_compile_learner_what_to_compile` must be one of ["
                f"TorchCompileWhatToCompile.forward_train, "
                f"TorchCompileWhatToCompile.complete_update] but is"
                f" {self.torch_compile_learner_what_to_compile}"
            )

    self._check_if_correct_nn_framework_installed(_tf1, _tf, _torch)
    self._resolve_tf_settings(_tf1, _tfv)
def _validate_resources_settings(self):
    """Checks, whether resources related settings make sense.

    Currently, only rejects combining `num_cpus_per_learner` > 1 with
    `num_gpus_per_learner` > 0.
    """
    # TODO @Avnishn: This is a short-term work around due to
    #  https://github.com/ray-project/ray/issues/35409
    #  Remove this once we are able to specify placement group bundle index in
    #  RLlib.
    if self.num_cpus_per_learner > 1 and self.num_gpus_per_learner > 0:
        # Fixed: the message previously contained a double space
        # ("and  `num_gpus_per_learner`").
        self._value_error(
            "Can't set both `num_cpus_per_learner` > 1 and "
            "`num_gpus_per_learner` > 0! Either set "
            "`num_cpus_per_learner` > 1 (and `num_gpus_per_learner`"
            "=0) OR set `num_gpus_per_learner` > 0 (and leave "
            "`num_cpus_per_learner` at its default value of 1). "
            "This is due to issues with placement group fragmentation. See "
            "https://github.com/ray-project/ray/issues/35409 for more details."
        )
def _validate_multi_agent_settings(self):
    """Checks, whether multi-agent related settings make sense.

    Reports every ID in `policies_to_train` (if given as list, set, or tuple)
    that is not in `policies`, as well as env vectorization combined with
    multi-agent on the new EnvRunners.
    """
    policies_to_train = self.policies_to_train
    if isinstance(policies_to_train, (list, set, tuple)):
        # Lazily filtered, so each unknown ID is reported as soon as it's found.
        unknown_ids = (
            pid for pid in policies_to_train if pid not in self.policies
        )
        for pid in unknown_ids:
            self._value_error(
                "`config.multi_agent(policies_to_train=..)` contains "
                f"policy ID ({pid}) that was not defined in "
                f"`config.multi_agent(policies=..)`!"
            )

    # TODO (sven): For now, vectorization is not allowed on new EnvRunners with
    #  multi-agent.
    vectorized_multi_agent_on_new_stack = (
        self.is_multi_agent
        and self.enable_env_runner_and_connector_v2
        and self.num_envs_per_env_runner > 1
    )
    if vectorized_multi_agent_on_new_stack:
        self._value_error(
            "For now, using env vectorization "
            "(`config.num_envs_per_env_runner > 1`) in combination with "
            "multi-agent AND the new EnvRunners is not supported! Try setting "
            "`config.num_envs_per_env_runner = 1`."
        )
def _validate_evaluation_settings(self):
    """Checks, whether evaluation related settings make sense.

    Errors (via `self._value_error`) on deprecated async evaluation, on parallel
    evaluation without evaluation EnvRunners, on `evaluation_duration="auto"`
    without parallel evaluation, and on non-positive or non-int durations.
    Only logs warnings for settings that are allowed but likely unintended.
    """
    # Async evaluation has been deprecated. Use "simple" parallel mode instead
    # (which is also async):
    # `config.evaluation(evaluation_parallel_to_training=True)`.
    if self.enable_async_evaluation is True:
        self._value_error(
            "`enable_async_evaluation` has been deprecated (you should set this to "
            "False)! Use `config.evaluation(evaluation_parallel_to_training=True)` "
            "instead."
        )

    # If `evaluation_num_env_runners` > 0, warn if `evaluation_interval` is 0 or
    # None.
    if self.evaluation_num_env_runners > 0 and not self.evaluation_interval:
        logger.warning(
            f"You have specified {self.evaluation_num_env_runners} "
            "evaluation workers, but your `evaluation_interval` is 0 or None! "
            "Therefore, evaluation doesn't occur automatically with each"
            " call to `Algorithm.train()`. Instead, you have to call "
            "`Algorithm.evaluate()` manually in order to trigger an "
            "evaluation run."
        )

    # If `evaluation_num_env_runners=0` and
    # `evaluation_parallel_to_training=True`, error: Parallel training and
    # evaluation requires at least one remote evaluation EnvRunner. The setting
    # is NOT changed automatically; the user has to either set
    # `evaluation_parallel_to_training=False` or add evaluation EnvRunners.
    if (
        self.evaluation_num_env_runners == 0
        and self.evaluation_parallel_to_training
    ):
        self._value_error(
            "`evaluation_parallel_to_training` can only be done if "
            "`evaluation_num_env_runners` > 0! Try setting "
            "`config.evaluation_parallel_to_training` to False."
        )

    # If `evaluation_duration=auto`, error if
    # `evaluation_parallel_to_training=False`.
    if self.evaluation_duration == "auto":
        if not self.evaluation_parallel_to_training:
            self._value_error(
                "`evaluation_duration=auto` not supported for "
                "`evaluation_parallel_to_training=False`!"
            )
        elif self.evaluation_duration_unit == "episodes":
            logger.warning(
                "When using `config.evaluation_duration='auto'`, the sampling unit "
                "used is always 'timesteps'! You have set "
                "`config.evaluation_duration_unit='episodes'`, which is ignored."
            )
    # Make sure, `evaluation_duration` is an int otherwise.
    elif (
        not isinstance(self.evaluation_duration, int)
        or self.evaluation_duration <= 0
    ):
        self._value_error(
            f"`evaluation_duration` ({self.evaluation_duration}) must be an "
            f"int and >0!"
        )
def _validate_input_settings(self):
    """Checks, whether input related settings make sense.

    For dataset input, also fills `self.input_config` with
    "num_cpus_per_read_task" (from `num_cpus_per_env_runner`) and "parallelism"
    (from `evaluation_num_env_runners` when in evaluation, else from
    `num_env_runners`; at least 1).
    """
    if self.input_ == "sampler" and self.off_policy_estimation_methods:
        self._value_error(
            "Off-policy estimation methods can only be used if the input is a "
            "dataset. We currently do not support applying off_policy_estimation_"
            "method on a sampler input."
        )

    if self.input_ != "dataset":
        return

    # Reading a Ray dataset: derive the read tasks' resources and parallelism
    # from the (evaluation) EnvRunner settings, for backward compatibility.
    # Users only need to set the number of (evaluation) EnvRunners.
    input_config = self.input_config
    input_config["num_cpus_per_read_task"] = self.num_cpus_per_env_runner
    num_readers = (
        self.evaluation_num_env_runners if self.in_evaluation else self.num_env_runners
    )
    input_config["parallelism"] = num_readers or 1
def _validate_new_api_stack_settings(self):
    """Checks, whether settings related to the new API stack make sense.

    Old API stack (`enable_rl_module_and_learner=False`): Warns if new-stack-only
    settings (an RLModuleSpec or a custom Learner class) are configured, reports
    an error if `enable_env_runner_and_connector_v2` is True, then returns.

    New API stack: Logs a warning that the new stack is in use, reports an error
    for the no-longer-supported "hybrid" stack
    (`enable_env_runner_and_connector_v2=False`), warns if the `model` config
    dict differs from `MODEL_DEFAULTS` (ignoring "vf_share_layers"), validates
    `lr` (fixed value or schedule), and reports errors for the old-stack-only
    settings `exploration_config`, `model["custom_model"]`, and
    `model["custom_model_config"]`.

    Errors are reported through `self._value_error()` (defined elsewhere in
    this class; presumably raises a ValueError - confirm).
    """
    # Old API stack checks.
    if not self.enable_rl_module_and_learner:
        # Throw a warning if the user has used `self.rl_module(rl_module_spec=...)`
        # but has not enabled the new API stack at the same time.
        # Both flags must be switched on: the hybrid combination is rejected.
        if self._rl_module_spec is not None:
            logger.warning(
                "You have set up an RLModuleSpec (via calling "
                "`config.rl_module(...)`), but have not enabled the new API stack. "
                "To enable it, call `config.api_stack(enable_rl_module_and_learner="
                "True, enable_env_runner_and_connector_v2=True)`."
            )
        # Throw a warning if the user has used `self.training(learner_class=...)`
        # but has not enabled the new API stack at the same time.
        if self._learner_class is not None:
            logger.warning(
                "You specified a custom Learner class (via "
                f"`AlgorithmConfig.training(learner_class={self._learner_class})`), "
                "but have the new API stack disabled. You need to enable it via "
                "`AlgorithmConfig.api_stack(enable_rl_module_and_learner=True, "
                "enable_env_runner_and_connector_v2=True)`."
            )
        # User is using the new EnvRunners, but forgot to switch on
        # `enable_rl_module_and_learner`.
        if self.enable_env_runner_and_connector_v2:
            self._value_error(
                "You are using the new API stack EnvRunners (SingleAgentEnvRunner "
                "or MultiAgentEnvRunner), but have forgotten to switch on the new "
                "API stack! Try setting "
                "`config.api_stack(enable_rl_module_and_learner=True)`."
            )
        # Early out. The rest of this method is only for
        # `enable_rl_module_and_learner=True`.
        return

    # Warn about new API stack on by default.
    logger.warning(
        f"You are running {self.algo_class.__name__} on the new API stack! "
        "This is the new default behavior for this algorithm. If you don't "
        "want to use the new API stack, set `config.api_stack("
        "enable_rl_module_and_learner=False, "
        "enable_env_runner_and_connector_v2=False)`. For a detailed migration "
        "guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html"  # noqa
    )

    # Disabled hybrid API stack. Now, both `enable_rl_module_and_learner` and
    # `enable_env_runner_and_connector_v2` must be True or both False.
    if not self.enable_env_runner_and_connector_v2:
        self._value_error(
            "Setting `enable_rl_module_and_learner` to True and "
            "`enable_env_runner_and_connector_v2` to False ('hybrid API stack'"
            ") is no longer supported! Set both to True (new API stack) or both "
            "to False (old API stack), instead."
        )

    # For those users that accidentally use the new API stack (because it's the
    # default now for many algos), we need to make sure they are warned.
    # Any failure of the structure/value comparison (deliberately broad: this
    # is a best-effort heuristic) means the user customized the old `model` dict.
    try:
        tree.assert_same_structure(self.model, MODEL_DEFAULTS)
        # Compare copies excluding the "vf_share_layers" key.
        check(
            {k: v for k, v in self.model.items() if k != "vf_share_layers"},
            {k: v for k, v in MODEL_DEFAULTS.items() if k != "vf_share_layers"},
        )
    except Exception:
        logger.warning(
            "You configured a custom `model` config (probably through calling "
            "config.training(model=..), whereas your config uses the new API "
            "stack! In order to switch off the new API stack, set in your config: "
            "`config.api_stack(enable_rl_module_and_learner=False, "
            "enable_env_runner_and_connector_v2=False)`. If you DO want to use "
            "the new API stack, configure your model, instead, through: "
            "`config.rl_module(model_config={..})`."
        )

    # LR-schedule checking.
    Scheduler.validate(
        fixed_value_or_schedule=self.lr,
        setting_name="lr",
        description="learning rate",
    )

    # This is not compatible with RLModules, which all have a method
    # `forward_exploration` to specify custom exploration behavior.
    if self.exploration_config:
        self._value_error(
            "When the RLModule API is enabled, exploration_config can not be "
            "set. If you want to implement custom exploration behaviour, "
            "please modify the `forward_exploration` method of the "
            "RLModule at hand. On configs that have a default exploration "
            "config, this must be done via "
            "`config.exploration_config={}`."
        )

    # Old ModelV2 options. Switching back to the old stack requires BOTH flags
    # to be False (see the hybrid-stack check above), so the hint names both.
    not_compatible_w_rlm_msg = (
        "Cannot use `{}` option with the new API stack (RLModule and "
        "Learner APIs)! `{}` is part of the ModelV2 API and Policy API,"
        " which are not compatible with the new API stack. You can either "
        "deactivate the new stack via `config.api_stack("
        "enable_rl_module_and_learner=False, "
        "enable_env_runner_and_connector_v2=False)`, "
        "or use the new stack (incl. RLModule API) and implement your "
        "custom model as an RLModule."
    )
    if self.model["custom_model"] is not None:
        self._value_error(
            not_compatible_w_rlm_msg.format("custom_model", "custom_model")
        )
    if self.model["custom_model_config"] != {}:
        self._value_error(
            not_compatible_w_rlm_msg.format(
                "custom_model_config", "custom_model_config"
            )
        )
# TODO (sven): Once everything is on the new API stack, we won't need this method
#  anymore.
def _validate_to_be_deprecated_settings(self):
    """Validates (and partly resolves) settings slated for deprecation.

    Covers `render_env` (new API stack), `preprocessor_pref`, propagation of the
    `_disable_preprocessor_api` / `_disable_action_flattening` flags into
    `self.model`, the removed `custom_preprocessor` model option, and the
    old-stack `simple_optimizer` / multi-GPU settings.

    Side effects: may set keys in `self.model` and may (re)assign
    `self.simple_optimizer` (resolving the `DEPRECATED_VALUE` auto-setting to a
    bool).

    Errors are reported through `self._value_error()` or
    `deprecation_warning(..., error=True)`.
    """
    # `render_env` is deprecated on new API stack.
    if self.enable_env_runner_and_connector_v2 and self.render_env is not False:
        deprecation_warning(
            old="AlgorithmConfig.render_env",
            help="The `render_env` setting is not supported on the new API stack! "
            "In order to log videos to WandB (or other loggers), take a look at "
            "this example here: "
            "https://github.com/ray-project/ray/blob/master/rllib/examples/envs/env_rendering_and_recording.py",  # noqa
        )

    if self.preprocessor_pref not in ["rllib", "deepmind", None]:
        self._value_error(
            "`config.preprocessor_pref` must be either 'rllib', 'deepmind' or None!"
        )

    # Check model config.
    # If no preprocessing, propagate into model's config as well
    # (so model knows whether inputs are preprocessed or not).
    if self._disable_preprocessor_api is True:
        self.model["_disable_preprocessor_api"] = True
    # If no action flattening, propagate into model's config as well
    # (so model knows whether action inputs are already flattened or not).
    if self._disable_action_flattening is True:
        self.model["_disable_action_flattening"] = True
    if self.model.get("custom_preprocessor"):
        deprecation_warning(
            old="AlgorithmConfig.training(model={'custom_preprocessor': ...})",
            help="Custom preprocessors are deprecated, "
            "since they sometimes conflict with the built-in "
            "preprocessors for handling complex observation spaces. "
            "Please use wrapper classes around your environment "
            "instead.",
            error=True,
        )

    # Multi-GPU settings.
    # An explicit `simple_optimizer=True` is accepted as-is (no further checks).
    if self.simple_optimizer is True:
        pass
    # Multi-GPU setting: Must use MultiGPUTrainOneStep.
    elif not self.enable_rl_module_and_learner and self.num_gpus > 1:
        # TODO: AlphaStar uses >1 GPUs differently (1 per policy actor), so this is
        #  ok for tf2 here.
        #  Remove this hacky check, once we have fully moved to the Learner API.
        if self.framework_str == "tf2" and type(self).__name__ != "AlphaStar":
            self._value_error(
                "`num_gpus` > 1 not supported yet for "
                f"framework={self.framework_str}!"
            )
        # NOTE(review): This branch is unreachable: `simple_optimizer is True`
        #  is already matched by the outer `if` above, so this error for
        #  `simple_optimizer=True` + `num_gpus > 1` is never reported. Confirm
        #  whether the check should move ahead of the outer `if` or be removed.
        elif self.simple_optimizer is True:
            self._value_error(
                "Cannot use `simple_optimizer` if `num_gpus` > 1! "
                "Consider not setting `simple_optimizer` in your config."
            )
        self.simple_optimizer = False
    # Auto-setting: Use simple-optimizer for tf-eager or multiagent,
    # otherwise: MultiGPUTrainOneStep (if supported by the algo's execution
    # plan).
    elif self.simple_optimizer == DEPRECATED_VALUE:
        # tf-eager: Must use simple optimizer.
        if self.framework_str not in ["tf", "torch"]:
            self.simple_optimizer = True
        # Multi-agent case: Try using MultiGPU optimizer (only
        # if all policies used are DynamicTFPolicies or TorchPolicies).
        elif self.is_multi_agent:
            from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
            from ray.rllib.policy.torch_policy import TorchPolicy

            default_policy_cls = None
            if self.algo_class:
                default_policy_cls = self.algo_class.get_default_policy_class(self)

            # Normalize `policies` (dict of PolicySpec/tuple/list, or a
            # collection of IDs) into a list of PolicySpecs.
            policies = self.policies
            policy_specs = (
                [
                    PolicySpec(*spec) if isinstance(spec, (tuple, list)) else spec
                    for spec in policies.values()
                ]
                if isinstance(policies, dict)
                else [PolicySpec() for _ in policies]
            )

            # Fall back to the simple optimizer if any policy's class (or the
            # algo's default class) is unknown or not multi-GPU capable.
            if any(
                (spec.policy_class or default_policy_cls) is None
                or not issubclass(
                    spec.policy_class or default_policy_cls,
                    (DynamicTFPolicy, TorchPolicy),
                )
                for spec in policy_specs
            ):
                self.simple_optimizer = True
            else:
                self.simple_optimizer = False
        else:
            self.simple_optimizer = False

    # User manually set simple-optimizer to False -> Error if tf-eager.
    elif self.simple_optimizer is False:
        if self.framework_str == "tf2":
            self._value_error(
                "`simple_optimizer=False` not supported for "
                f"config.framework({self.framework_str})!"
            )
def _validate_offline_settings(self):
    """Validates settings used for offline RL (reading/writing recorded data).

    Reports errors (via `self._value_error`) if:
    - this is an offline config (`self.is_offline`) that runs no evaluation but
      lacks an explicit `action_space` or `observation_space`,
    - `offline_data_class`, `prelearner_class`, or `prelearner_buffer_class`
      is set but does not subclass `OfflineData`, `OfflinePreLearner`, or
      `EpisodeReplayBuffer`, respectively,
    - `input_read_batch_size` is set although neither `input_read_episodes`
      nor `input_read_sample_batches` is set,
    - episodes are recorded (`output` and `output_write_episodes`) with a
      `batch_mode` other than "complete_episodes".
    """
    # Without evaluation there is no env to take the spaces from, and the env
    # is intentionally not created just to read them, so both spaces must be
    # given explicitly.
    if self.is_offline:
        runs_evaluation = (
            self.evaluation_num_env_runners > 0 or self.evaluation_interval
        )
        if not runs_evaluation and (
            self.action_space is None or self.observation_space is None
        ):
            self._value_error(
                "If no evaluation should be run, `action_space` and "
                "`observation_space` must be provided."
            )

    def _ensure_subclass(candidate, base, message):
        # An unset (falsy) class is fine; a set one must derive from `base`.
        if candidate and not issubclass(candidate, base):
            self._value_error(message)

    from ray.rllib.offline.offline_data import OfflineData
    from ray.rllib.offline.offline_prelearner import OfflinePreLearner

    _ensure_subclass(
        self.offline_data_class,
        OfflineData,
        "Unknown `offline_data_class`. OfflineData class needs to inherit "
        "from `OfflineData` class.",
    )
    _ensure_subclass(
        self.prelearner_class,
        OfflinePreLearner,
        "Unknown `prelearner_class`. PreLearner class needs to inherit "
        "from `OfflinePreLearner` class.",
    )

    from ray.rllib.utils.replay_buffers.episode_replay_buffer import (
        EpisodeReplayBuffer,
    )

    _ensure_subclass(
        self.prelearner_buffer_class,
        EpisodeReplayBuffer,
        "Unknown `prelearner_buffer_class`. The buffer class for the "
        "prelearner needs to inherit from `EpisodeReplayBuffer`. "
        "Specifically it needs to store and sample lists of "
        "`Single-/MultiAgentEpisode`s.",
    )

    # A read batch size only makes sense for rows holding multiple timesteps.
    if self.input_read_batch_size:
        if not (self.input_read_episodes or self.input_read_sample_batches):
            self._value_error(
                "Setting `input_read_batch_size` is only allowed in case of a "
                "dataset that holds either `EpisodeType` or `BatchType` data (i.e. "
                "rows that contains multiple timesteps), but neither "
                "`input_read_episodes` nor `input_read_sample_batches` is set to "
                "`True`."
            )

    # Recorded episodes must be complete to be usable for training later.
    records_episodes = self.output and self.output_write_episodes
    if records_episodes and self.batch_mode != "complete_episodes":
        self._value_error(
            "When recording episodes only complete episodes should be "
            "recorded (i.e. `batch_mode=='complete_episodes'`). Otherwise "
            "recorded episodes cannot be read in for training."
        )
@property
def is_offline(self) -> bool:
    """Defines, if this config is for offline RL.

    True only if `input_` is a non-empty path string (or a non-empty list whose
    first element is a string), is not the "sampler" default, and the new API
    stack (`enable_rl_module_and_learner`) is enabled. Old-stack offline inputs
    are handled very differently and do not count here.
    """
    source = self.input_
    # No input path/class given at all.
    if not source:
        return False
    # Must be a real string path or a list of such paths.
    is_path_like = isinstance(source, str) or (
        isinstance(source, list) and isinstance(source[0], str)
    )
    if not is_path_like or source == "sampler":
        return False
    return self.enable_rl_module_and_learner
@staticmethod
def _serialize_dict(config):
    """Converts a config dict into a JSON-friendly form, modifying it in place.

    - Renames the `callbacks_class` key (if present) to `callbacks`.
    - Replaces classes with their class paths (via `serialize_type`):
      `class` (if present), `callbacks` and `sample_collector` (always), and
      `env`, `replay_buffer_config["type"]` (if that key exists),
      `exploration_config["type"]`, and `model["custom_model"]` (each only if
      the value is a type).
    - List'ifies set/tuple `policies` and replaces truthy `policy_mapping_fn` /
      `policies_to_train` values with `NOT_SERIALIZABLE`, first inside the
      legacy "multiagent" sub-dict (if any), then on the top level.

    Args:
        config: The config dict to convert. It is mutated.

    Returns:
        The same (mutated) `config` object.
    """
    if "callbacks_class" in config:
        config["callbacks"] = config.pop("callbacks_class")

    # Top-level class entries -> class paths.
    if "class" in config:
        config["class"] = serialize_type(config["class"])
    for key in ("callbacks", "sample_collector"):
        config[key] = serialize_type(config[key])
    if isinstance(config["env"], type):
        config["env"] = serialize_type(config["env"])

    # (sub-dict, entry) pairs whose value becomes a class path iff it's a type.
    typed_entries = [("exploration_config", "type"), ("model", "custom_model")]
    if "replay_buffer_config" in config:
        typed_entries.insert(0, ("replay_buffer_config", "type"))
    for section, entry in typed_entries:
        sub_config = config[section]
        if isinstance(sub_config.get(entry), type):
            sub_config[entry] = serialize_type(sub_config[entry])

    def _strip_unserializable(d):
        # Sets/tuples are not JSON'able -> list'ify them.
        if isinstance(d.get("policies"), (set, tuple)):
            d["policies"] = list(d["policies"])
        # Functions/lambdas are never serialized; only a marker is kept.
        for fn_key in ("policy_mapping_fn", "policies_to_train"):
            if d.get(fn_key):
                d[fn_key] = NOT_SERIALIZABLE

    # Legacy "multiagent" sub-dict first, then the (preferred) top-level keys.
    multiagent_config = config.get("multiagent")
    if multiagent_config is not None:
        _strip_unserializable(multiagent_config)
    _strip_unserializable(config)
    return config
@staticmethod
def _translate_special_keys(key: str, warn_deprecated: bool = True) -> str:
    """Maps a config-dict key to the matching `AlgorithmConfig` attribute name.

    Keys whose attribute has a different name (e.g. "lambda" -> "lambda_",
    "num_workers" -> "num_env_runners") are renamed; all other keys are
    returned unchanged. If `warn_deprecated` is True and the key is one of the
    removed settings, `deprecation_warning(..., error=True)` is called for it.

    Args:
        key: The config-dict key to translate.
        warn_deprecated: Whether to report removed keys.

    Returns:
        The (possibly renamed) key.
    """
    renamed = {
        "callbacks": "callbacks_class",
        "create_env_on_driver": "create_env_on_local_worker",
        "custom_eval_function": "custom_evaluation_function",
        "framework": "framework_str",
        "input": "input_",
        "lambda": "lambda_",
        "num_cpus_for_driver": "num_cpus_for_main_process",
        "num_workers": "num_env_runners",
    }
    key = renamed.get(key, key)

    if warn_deprecated:
        # Removed key -> (`old`, `new`) arguments for `deprecation_warning`.
        removed = {
            "collect_metrics_timeout": (
                "collect_metrics_timeout",
                "metrics_episode_collection_timeout_s",
            ),
            "metrics_smoothing_episodes": (
                "config.metrics_smoothing_episodes",
                "config.metrics_num_episodes_for_smoothing",
            ),
            "min_iter_time_s": (
                "config.min_iter_time_s",
                "config.min_time_s_per_iteration",
            ),
            "min_time_s_per_reporting": (
                "config.min_time_s_per_reporting",
                "config.min_time_s_per_iteration",
            ),
            "min_sample_timesteps_per_reporting": (
                "config.min_sample_timesteps_per_reporting",
                "config.min_sample_timesteps_per_iteration",
            ),
            "min_train_timesteps_per_reporting": (
                "config.min_train_timesteps_per_reporting",
                "config.min_train_timesteps_per_iteration",
            ),
            "timesteps_per_iteration": (
                "config.timesteps_per_iteration",
                "`config.min_sample_timesteps_per_iteration` OR "
                "`config.min_train_timesteps_per_iteration`",
            ),
            "evaluation_num_episodes": (
                "config.evaluation_num_episodes",
                "`config.evaluation_duration` and "
                "`config.evaluation_duration_unit=episodes`",
            ),
        }
        if key in removed:
            old_name, new_name = removed[key]
            deprecation_warning(old=old_name, new=new_name, error=True)

    return key
def _check_if_correct_nn_framework_installed(self, _tf1, _tf, _torch):
    """Check if tf/torch experiment is running and tf/torch installed.

    Args:
        _tf1: A TensorFlow (v1-compat) module handle, or a falsy value.
        _tf: A TensorFlow module handle, or a falsy value.
        _torch: A PyTorch module handle, or a falsy value.

    Raises:
        ImportError: If `framework_str` is "tf"/"tf2" and neither `_tf1` nor
            `_tf` is set, or if it is "torch" and `_torch` is not set.
    """
    framework = self.framework_str
    if framework in {"tf", "tf2"} and not (_tf1 or _tf):
        raise ImportError(
            "TensorFlow was specified as the framework to use (via `config."
            "framework([tf|tf2])`)! However, no installation was "
            "found. You can install TensorFlow via `pip install tensorflow`"
        )
    if framework == "torch" and not _torch:
        raise ImportError(
            "PyTorch was specified as the framework to use (via `config."
            "framework('torch')`)! However, no installation was found. You "
            "can install PyTorch via `pip install torch`."
        )
def _resolve_tf_settings(self, _tf1, _tfv):
    """Check and resolve tf settings.

    For `framework_str == "tf2"` (with a truthy `_tf1`): errors if `_tfv < 2`,
    enables eager execution if it is not already active, and logs a hint about
    `eager_tracing`. For `framework_str == "tf"` (with a truthy `_tf1`): logs a
    hint recommending "tf2" with eager tracing. Does nothing otherwise (e.g. if
    `_tf1` is falsy because TensorFlow is not installed).

    Args:
        _tf1: TensorFlow module handle offering `executing_eagerly()` and
            `enable_eager_execution()` (presumably the v1-compat module from
            `try_import_tf()` - verify), or a falsy value.
        _tfv: The installed TensorFlow major version (compared against 2).

    Raises:
        ValueError: If `framework_str == "tf2"` but `_tfv < 2`.
    """
    if _tf1 and self.framework_str == "tf2":
        # `framework_str == "tf2"` is already guaranteed by the outer condition.
        if _tfv < 2:
            raise ValueError(
                "You configured `framework`=tf2, but your installed "
                "pip tf-version is < 2.0! Make sure your TensorFlow "
                "version is >= 2.x."
            )
        if not _tf1.executing_eagerly():
            _tf1.enable_eager_execution()
        # Recommend setting tracing to True for speedups.
        logger.info(
            f"Executing eagerly (framework='{self.framework_str}'),"
            f" with eager_tracing={self.eager_tracing}. For "
            "production workloads, make sure to set eager_tracing=True"
            " in order to match the speed of tf-static-graph "
            "(framework='tf'). For debugging purposes, "
            "`eager_tracing=False` is the best choice."
        )
    # Tf-static-graph (framework=tf): Recommend upgrading to tf2 and
    # enabling eager tracing for similar speed.
    elif _tf1 and self.framework_str == "tf":
        logger.info(
            "Your framework setting is 'tf', meaning you are using "
            "static-graph mode. Set framework='tf2' to enable eager "
            "execution with tf2.x. You may also then want to set "
            "eager_tracing=True in order to reach similar execution "
            "speed as with static-graph mode."
        )
@OldAPIStack
def get_multi_agent_setup(
    self,
    *,
    policies: Optional[MultiAgentPolicyConfigDict] = None,
    env: Optional[EnvType] = None,
    spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None,
    default_policy_class: Optional[Type[Policy]] = None,
) -> Tuple[MultiAgentPolicyConfigDict, Callable[[PolicyID, SampleBatchType], bool]]:
    r"""Compiles complete multi-agent config (dict) from the information in `self`.

    Infers the observation- and action spaces, the policy classes, and the policy's
    configs. The returned `MultiAgentPolicyConfigDict` maps PolicyIDs to
    PolicySpec objects with inferred spaces and a full AlgorithmConfig as
    `config`.

    NOTE: A spec's `policy_class` stays None if neither the spec nor
    `default_policy_class` provides one. For a MultiAgentEnv whose agents have
    differing spaces, a policy that no agent maps to (or any policy, if
    `self.policy_mapping_fn` is not set) keeps a None space.

    Examples:
        .. testcode::

            import gymnasium as gym
            from ray.rllib.algorithms.ppo import PPOConfig
            config = (
                PPOConfig()
                .environment("CartPole-v1")
                .framework("torch")
                .multi_agent(policies={"pol1", "pol2"}, policies_to_train=["pol1"])
            )
            policy_dict, is_policy_to_train = config.get_multi_agent_setup(
                env=gym.make("CartPole-v1"))
            is_policy_to_train("pol1")
            is_policy_to_train("pol2")

    Args:
        policies: An optional multi-agent `policies` dict, mapping policy IDs
            to PolicySpec objects. If not provided (or empty), uses
            `self.policies` instead. A set/list/tuple of policy IDs is also
            accepted (each gets an empty PolicySpec). Note that the
            `policy_class`, `observation_space`, and `action_space` properties
            in these PolicySpecs may be None and must therefore be inferred here.
            The input is deep-copied, not modified.
        env: An optional env instance (or a ray actor handle exposing
            `_get_spaces()`), from which to infer the different spaces for
            the different policies. If not provided, tries to infer from
            `spaces`. Otherwise from `self.observation_space` and
            `self.action_space`. Raises an error, if no information on spaces can be
            inferred.
        spaces: Optional dict mapping policy IDs to tuples of 1) observation space
            and 2) action space that should be used for the respective policy.
            These spaces were usually provided by an already instantiated remote
            EnvRunner. Per-policy entries in `spaces` take precedence over the
            env; its `INPUT_ENV_SPACES` entry is used only as a fallback for the
            env-wide spaces.
        default_policy_class: The Policy class to use should a PolicySpec have its
            policy_class property set to None.

    Returns:
        A tuple consisting of 1) a MultiAgentPolicyConfigDict and 2) an
        `is_policy_to_train(PolicyID, SampleBatchType) -> bool` callable. If
        `self.policies_to_train` is already a callable (or None), it is returned
        as-is as the second item.

    Raises:
        ValueError: In case, no spaces can be inferred for the policy/ies.
        ValueError: In case, two agents in the env map to the same PolicyID
            (according to `self.policy_mapping_fn`), but have different action- or
            observation spaces according to the inferred space information.
    """
    policies = copy.deepcopy(policies or self.policies)

    # Policies given as set/list/tuple (of PolicyIDs) -> Setup each policy
    # automatically via empty PolicySpec (makes RLlib infer observation- and
    # action spaces as well as the Policy's class).
    if isinstance(policies, (set, list, tuple)):
        policies = {pid: PolicySpec() for pid in policies}

    # Try extracting spaces from env or from given spaces dict.
    env_obs_space = None
    env_act_space = None

    # Env is a ray.remote: Get spaces via its (automatically added)
    # `_get_spaces()` method.
    if isinstance(env, ray.actor.ActorHandle):
        env_obs_space, env_act_space = ray.get(env._get_spaces.remote())
    # Normal env (gym.Env or MultiAgentEnv): These should have the
    # `observation_space` and `action_space` properties.
    elif env is not None:
        # `env` is a gymnasium.vector.Env.
        if hasattr(env, "single_observation_space") and isinstance(
            env.single_observation_space, gym.Space
        ):
            env_obs_space = env.single_observation_space
        # `env` is a gymnasium.Env.
        elif hasattr(env, "observation_space") and isinstance(
            env.observation_space, gym.Space
        ):
            env_obs_space = env.observation_space

        # `env` is a gymnasium.vector.Env.
        if hasattr(env, "single_action_space") and isinstance(
            env.single_action_space, gym.Space
        ):
            env_act_space = env.single_action_space
        # `env` is a gymnasium.Env.
        elif hasattr(env, "action_space") and isinstance(
            env.action_space, gym.Space
        ):
            env_act_space = env.action_space

    # Last resort: Try getting the env's spaces from the spaces
    # dict's special `INPUT_ENV_SPACES` key.
    if spaces is not None:
        if env_obs_space is None:
            env_obs_space = spaces.get(INPUT_ENV_SPACES, [None])[0]
        if env_act_space is None:
            env_act_space = spaces.get(INPUT_ENV_SPACES, [None, None])[1]

    # Check each defined policy ID and unify its spec.
    # Iterate over a (shallow) copy, since entries may be replaced in the loop.
    for pid, policy_spec in policies.copy().items():
        # Convert to PolicySpec if plain list/tuple.
        if not isinstance(policy_spec, PolicySpec):
            policies[pid] = policy_spec = PolicySpec(*policy_spec)

        # Infer policy classes for policies dict, if not provided (None).
        if policy_spec.policy_class is None and default_policy_class is not None:
            policies[pid].policy_class = default_policy_class

        # Infer observation space.
        if policy_spec.observation_space is None:
            env_unwrapped = env.unwrapped if hasattr(env, "unwrapped") else env
            # Module's space is provided -> Use it as-is.
            if spaces is not None and pid in spaces:
                obs_space = spaces[pid][0]
            # MultiAgentEnv -> Check, whether agents have different spaces.
            elif isinstance(env_unwrapped, MultiAgentEnv):
                obs_space = None
                mapping_fn = self.policy_mapping_fn
                aids = list(
                    env_unwrapped.possible_agents
                    if hasattr(env_unwrapped, "possible_agents")
                    and env_unwrapped.possible_agents
                    else env_unwrapped.get_agent_ids()
                )
                if len(aids) == 0:
                    one_obs_space = env_unwrapped.observation_space
                else:
                    one_obs_space = env_unwrapped.get_observation_space(aids[0])
                # If all obs spaces are the same, just use the first space.
                if all(
                    env_unwrapped.get_observation_space(aid) == one_obs_space
                    for aid in aids
                ):
                    obs_space = one_obs_space
                # Need to reverse-map spaces (for the different agents) to certain
                # policy IDs. We have to compare the ModuleID with all possible
                # AgentIDs and find the agent ID that matches.
                elif mapping_fn:
                    for aid in aids:
                        # Match: Assign spaces for this agentID to the PolicyID.
                        if mapping_fn(aid, None, worker=None) == pid:
                            # Make sure, different agents that map to the same
                            # policy don't have different spaces.
                            if (
                                obs_space is not None
                                and env_unwrapped.get_observation_space(aid)
                                != obs_space
                            ):
                                raise ValueError(
                                    "Two agents in your environment map to the "
                                    "same policyID (as per your `policy_mapping"
                                    "_fn`), however, these agents also have "
                                    "different observation spaces!"
                                )
                            obs_space = env_unwrapped.get_observation_space(aid)
            # Just use env's obs space as-is.
            elif env_obs_space is not None:
                obs_space = env_obs_space
            # Space given directly in config.
            elif self.observation_space:
                obs_space = self.observation_space
            else:
                raise ValueError(
                    "`observation_space` not provided in PolicySpec for "
                    f"{pid} and env does not have an observation space OR "
                    "no spaces received from other workers' env(s) OR no "
                    "`observation_space` specified in config!"
                )

            policies[pid].observation_space = obs_space

        # Infer action space.
        if policy_spec.action_space is None:
            env_unwrapped = env.unwrapped if hasattr(env, "unwrapped") else env
            # Module's space is provided -> Use it as-is.
            if spaces is not None and pid in spaces:
                act_space = spaces[pid][1]
            # MultiAgentEnv -> Check, whether agents have different spaces.
            elif isinstance(env_unwrapped, MultiAgentEnv):
                act_space = None
                mapping_fn = self.policy_mapping_fn
                aids = list(
                    env_unwrapped.possible_agents
                    if hasattr(env_unwrapped, "possible_agents")
                    and env_unwrapped.possible_agents
                    else env_unwrapped.get_agent_ids()
                )
                if len(aids) == 0:
                    one_act_space = env_unwrapped.action_space
                else:
                    one_act_space = env_unwrapped.get_action_space(aids[0])
                # If all action spaces are the same, just use the first space.
                if all(
                    env_unwrapped.get_action_space(aid) == one_act_space
                    for aid in aids
                ):
                    act_space = one_act_space
                # Need to reverse-map spaces (for the different agents) to certain
                # policy IDs. We have to compare the ModuleID with all possible
                # AgentIDs and find the agent ID that matches.
                elif mapping_fn:
                    for aid in aids:
                        # Match: Assign spaces for this AgentID to the PolicyID.
                        if mapping_fn(aid, None, worker=None) == pid:
                            # Make sure, different agents that map to the same
                            # policy don't have different spaces.
                            if (
                                act_space is not None
                                and env_unwrapped.get_action_space(aid) != act_space
                            ):
                                raise ValueError(
                                    "Two agents in your environment map to the "
                                    "same policyID (as per your `policy_mapping"
                                    "_fn`), however, these agents also have "
                                    "different action spaces!"
                                )
                            act_space = env_unwrapped.get_action_space(aid)
            # Just use env's action space as-is.
            elif env_act_space is not None:
                act_space = env_act_space
            # Space given directly in config.
            elif self.action_space:
                act_space = self.action_space
            else:
                raise ValueError(
                    "`action_space` not provided in PolicySpec for "
                    f"{pid} and env does not have an action space OR "
                    "no spaces received from other workers' env(s) OR no "
                    "`action_space` specified in config!"
                )
            policies[pid].action_space = act_space

        # Create entire AlgorithmConfig object from the provided override.
        # If None, use {} as override.
        if not isinstance(policies[pid].config, AlgorithmConfig):
            assert policies[pid].config is None or isinstance(
                policies[pid].config, dict
            )
            policies[pid].config = self.copy(copy_frozen=False).update_from_dict(
                policies[pid].config or {}
            )

    # If collection given, construct a simple default callable returning True
    # if the PolicyID is found in the list/set of IDs.
    if self.policies_to_train is not None and not callable(self.policies_to_train):
        pols = set(self.policies_to_train)

        def is_policy_to_train(pid, batch=None):
            return pid in pols

    else:
        # Already a callable (or None): pass through unchanged.
        is_policy_to_train = self.policies_to_train

    return policies, is_policy_to_train
@Deprecated(new="AlgorithmConfig.build_algo", error=False)
def build(self, *args, **kwargs):
    """Deprecated alias of `build_algo()`; forwards all arguments to it."""
    algo = self.build_algo(*args, **kwargs)
    return algo

@Deprecated(new="AlgorithmConfig.get_multi_rl_module_spec()", error=True)
def get_marl_module_spec(self, *args, **kwargs):
    """Deprecated (`error=True`); use `get_multi_rl_module_spec()` instead."""

@Deprecated(new="AlgorithmConfig.env_runners(..)", error=True)
def rollouts(self, *args, **kwargs):
    """Deprecated (`error=True`); use `env_runners(..)` instead."""

@Deprecated(new="AlgorithmConfig.env_runners(..)", error=True)
def exploration(self, *args, **kwargs):
    """Deprecated (`error=True`); use `env_runners(..)` instead."""
@property
@Deprecated(
    new="AlgorithmConfig.fault_tolerance(restart_failed_env_runners=..)",
    error=True,
)
def recreate_failed_env_runners(self):
    """Deprecated (`error=True`); see `fault_tolerance(restart_failed_env_runners=..)`."""

@recreate_failed_env_runners.setter
def recreate_failed_env_runners(self, value):
    """Reports the removed setting via `deprecation_warning(..., error=True)`."""
    deprecation_warning(
        old="AlgorithmConfig.recreate_failed_env_runners",
        new="AlgorithmConfig.restart_failed_env_runners",
        error=True,
    )
@property
@Deprecated(new="AlgorithmConfig.enable_rl_module_and_learner", error=True)
def _enable_new_api_stack(self):
    """Deprecated (`error=True`); use `enable_rl_module_and_learner` instead.

    The `new=` target previously pointed at this deprecated attribute itself;
    it now names the replacement, matching the setter below.
    """

@_enable_new_api_stack.setter
def _enable_new_api_stack(self, value):
    """Reports the removed setting via `deprecation_warning(..., error=True)`."""
    deprecation_warning(
        old="AlgorithmConfig._enable_new_api_stack",
        new="AlgorithmConfig.enable_rl_module_and_learner",
        error=True,
    )
@property
@Deprecated(new="AlgorithmConfig.enable_env_runner_and_connector_v2", error=True)
def uses_new_env_runners(self):
    """Deprecated (`error=True`); use `enable_env_runner_and_connector_v2`."""

@property
@Deprecated(new="AlgorithmConfig.num_env_runners", error=True)
def num_rollout_workers(self):
    """Deprecated (`error=True`); use `num_env_runners` instead."""

@num_rollout_workers.setter
def num_rollout_workers(self, value):
    """Reports the removed setting via `deprecation_warning(..., error=True)`."""
    deprecation_warning(
        old="AlgorithmConfig.num_rollout_workers",
        new="AlgorithmConfig.num_env_runners",
        error=True,
    )
@property
@Deprecated(new="AlgorithmConfig.evaluation_num_env_runners", error=True)
def evaluation_num_workers(self):
    """Deprecated (`error=True`); use `evaluation_num_env_runners` instead.

    The `new=` target previously pointed at this deprecated attribute itself;
    it now names the replacement, matching the setter below.
    """

@evaluation_num_workers.setter
def evaluation_num_workers(self, value):
    """Reports the removed setting via `deprecation_warning(..., error=True)`."""
    deprecation_warning(
        old="AlgorithmConfig.evaluation_num_workers",
        new="AlgorithmConfig.evaluation_num_env_runners",
        error=True,
    )
@property
@Deprecated(new="AlgorithmConfig.num_envs_per_env_runner", error=True)
def num_envs_per_worker(self):
    """Deprecated (`error=True`); use `num_envs_per_env_runner` instead."""

@num_envs_per_worker.setter
def num_envs_per_worker(self, value):
    """Reports the removed setting via `deprecation_warning(..., error=True)`."""
    deprecation_warning(
        old="AlgorithmConfig.num_envs_per_worker",
        new="AlgorithmConfig.num_envs_per_env_runner",
        error=True,
    )

@property
@Deprecated(new="AlgorithmConfig.ignore_env_runner_failures", error=True)
def ignore_worker_failures(self):
    """Deprecated (`error=True`); use `ignore_env_runner_failures` instead."""

@ignore_worker_failures.setter
def ignore_worker_failures(self, value):
    """Reports the removed setting via `deprecation_warning(..., error=True)`."""
    deprecation_warning(
        old="AlgorithmConfig.ignore_worker_failures",
        new="AlgorithmConfig.ignore_env_runner_failures",
        error=True,
    )

@property
@Deprecated(new="AlgorithmConfig.restart_failed_env_runners", error=True)
def recreate_failed_workers(self):
    """Deprecated (`error=True`); use `restart_failed_env_runners` instead."""

@recreate_failed_workers.setter
def recreate_failed_workers(self, value):
    """Reports the removed setting via `deprecation_warning(..., error=True)`."""
    deprecation_warning(
        old="AlgorithmConfig.recreate_failed_workers",
        new="AlgorithmConfig.restart_failed_env_runners",
        error=True,
    )

@property
@Deprecated(new="AlgorithmConfig.max_num_env_runner_restarts", error=True)
def max_num_worker_restarts(self):
    """Deprecated (`error=True`); use `max_num_env_runner_restarts` instead."""

@max_num_worker_restarts.setter
def max_num_worker_restarts(self, value):
    """Reports the removed setting via `deprecation_warning(..., error=True)`."""
    deprecation_warning(
        old="AlgorithmConfig.max_num_worker_restarts",
        new="AlgorithmConfig.max_num_env_runner_restarts",
        error=True,
    )

@property
@Deprecated(new="AlgorithmConfig.delay_between_env_runner_restarts_s", error=True)
def delay_between_worker_restarts_s(self):
    """Deprecated (`error=True`); use `delay_between_env_runner_restarts_s`."""

@delay_between_worker_restarts_s.setter
def delay_between_worker_restarts_s(self, value):
    """Reports the removed setting via `deprecation_warning(..., error=True)`."""
    deprecation_warning(
        old="AlgorithmConfig.delay_between_worker_restarts_s",
        new="AlgorithmConfig.delay_between_env_runner_restarts_s",
        error=True,
    )

@property
@Deprecated(
    new="AlgorithmConfig.num_consecutive_env_runner_failures_tolerance", error=True
)
def num_consecutive_worker_failures_tolerance(self):
    """Deprecated (`error=True`); see `num_consecutive_env_runner_failures_tolerance`."""

@num_consecutive_worker_failures_tolerance.setter
def num_consecutive_worker_failures_tolerance(self, value):
    """Reports the removed setting via `deprecation_warning(..., error=True)`."""
    deprecation_warning(
        old="AlgorithmConfig.num_consecutive_worker_failures_tolerance",
        new="AlgorithmConfig.num_consecutive_env_runner_failures_tolerance",
        error=True,
    )

@property
@Deprecated(new="AlgorithmConfig.env_runner_health_probe_timeout_s", error=True)
def worker_health_probe_timeout_s(self):
    """Deprecated (`error=True`); use `env_runner_health_probe_timeout_s`."""

@worker_health_probe_timeout_s.setter
def worker_health_probe_timeout_s(self, value):
    """Reports the removed setting via `deprecation_warning(..., error=True)`."""
    deprecation_warning(
        old="AlgorithmConfig.worker_health_probe_timeout_s",
        new="AlgorithmConfig.env_runner_health_probe_timeout_s",
        error=True,
    )

@property
@Deprecated(new="AlgorithmConfig.env_runner_restore_timeout_s", error=True)
def worker_restore_timeout_s(self):
    """Deprecated (`error=True`); use `env_runner_restore_timeout_s` instead."""

@worker_restore_timeout_s.setter
def worker_restore_timeout_s(self, value):
    """Reports the removed setting via `deprecation_warning(..., error=True)`."""
    deprecation_warning(
        old="AlgorithmConfig.worker_restore_timeout_s",
        new="AlgorithmConfig.env_runner_restore_timeout_s",
        error=True,
    )

@property
@Deprecated(
    new="AlgorithmConfig.validate_env_runners_after_construction",
    error=True,
)
def validate_workers_after_construction(self):
    """Deprecated (`error=True`); see `validate_env_runners_after_construction`."""

@validate_workers_after_construction.setter
def validate_workers_after_construction(self, value):
    """Reports the removed setting via `deprecation_warning(..., error=True)`."""
    deprecation_warning(
        old="AlgorithmConfig.validate_workers_after_construction",
        new="AlgorithmConfig.validate_env_runners_after_construction",
        error=True,
    )
# Cleanups from `resources()`.
@property
@Deprecated(new="AlgorithmConfig.num_cpus_per_env_runner", error=True)
def num_cpus_per_worker(self):
pass
@num_cpus_per_worker.setter
def num_cpus_per_worker(self, value):
deprecation_warning(
old="AlgorithmConfig.num_cpus_per_worker",
new="AlgorithmConfig.num_cpus_per_env_runner",
error=True,
)
pass
@property
@Deprecated(new="AlgorithmConfig.num_gpus_per_env_runner", error=True)
def num_gpus_per_worker(self):
pass
@num_gpus_per_worker.setter
def num_gpus_per_worker(self, value):
deprecation_warning(
old="AlgorithmConfig.num_gpus_per_worker",
new="AlgorithmConfig.num_gpus_per_env_runner",
error=True,
)
pass
@property
@Deprecated(new="AlgorithmConfig.custom_resources_per_env_runner", error=True)
def custom_resources_per_worker(self):
    """Removed alias of ``custom_resources_per_env_runner``.

    Reading it is routed through ``Deprecated(..., error=True)``, which is
    expected to raise.
    """

@custom_resources_per_worker.setter
def custom_resources_per_worker(self, value):
    # `value` is ignored; only the rename is reported.
    deprecation_warning(
        old="AlgorithmConfig.custom_resources_per_worker",
        new="AlgorithmConfig.custom_resources_per_env_runner",
        error=True,
    )
@property
@Deprecated(new="AlgorithmConfig.num_learners", error=True)
def num_learner_workers(self):
    """Removed alias of ``num_learners``.

    Access is wrapped by ``Deprecated(..., error=True)`` and is expected to
    raise.
    """

@num_learner_workers.setter
def num_learner_workers(self, value):
    # The assigned `value` is never kept.
    deprecation_warning(
        old="AlgorithmConfig.num_learner_workers",
        new="AlgorithmConfig.num_learners",
        error=True,
    )
@property
@Deprecated(new="AlgorithmConfig.num_cpus_per_learner", error=True)
def num_cpus_per_learner_worker(self):
    """Removed alias of ``num_cpus_per_learner``.

    The ``Deprecated(..., error=True)`` wrapper is expected to raise on read.
    """

@num_cpus_per_learner_worker.setter
def num_cpus_per_learner_worker(self, value):
    # Only reports the rename; `value` is discarded.
    deprecation_warning(
        old="AlgorithmConfig.num_cpus_per_learner_worker",
        new="AlgorithmConfig.num_cpus_per_learner",
        error=True,
    )
@property
@Deprecated(new="AlgorithmConfig.num_gpus_per_learner", error=True)
def num_gpus_per_learner_worker(self):
    """Removed alias of ``num_gpus_per_learner``.

    Reading it goes through ``Deprecated(..., error=True)``, which is
    expected to raise.
    """

@num_gpus_per_learner_worker.setter
def num_gpus_per_learner_worker(self, value):
    # Assignment is refused; `value` is not stored.
    deprecation_warning(
        old="AlgorithmConfig.num_gpus_per_learner_worker",
        new="AlgorithmConfig.num_gpus_per_learner",
        error=True,
    )
@property
# `new` must name the replacement setting, matching the setter below, rather
# than this deprecated name itself; otherwise the error gives no migration path.
@Deprecated(new="AlgorithmConfig.num_cpus_for_main_process", error=True)
def num_cpus_for_local_worker(self):
    """Removed alias of ``num_cpus_for_main_process``.

    Reads go through ``Deprecated(..., error=True)``, which is expected to
    raise and point users at ``AlgorithmConfig.num_cpus_for_main_process``.
    """

@num_cpus_for_local_worker.setter
def num_cpus_for_local_worker(self, value):
    """Reject assignment to the removed setting; `value` is never stored."""
    deprecation_warning(
        old="AlgorithmConfig.num_cpus_for_local_worker",
        new="AlgorithmConfig.num_cpus_for_main_process",
        error=True,
    )
class TorchCompileWhatToCompile(str, Enum):
    """Selects which part of a TorchLearner is passed through torch.compile.

    Two granularities exist: the learner's entire update step, or only the
    RLModule's forward methods (which include ``forward_train``).

    .. note::
        - torch-compiled code can slow down on graph breaks, and unsupported
          operations can make it raise. In practice, compiling
          ``forward_train`` tends to introduce few graph breaks and no
          errors, while giving a speedup comparable to compiling the
          complete update.
        - ``complete_update`` is experimental and may result in errors.
    """

    #: Compile the learner's whole update step: the RLModule forward pass,
    #: the loss computation, and the optimizer step.
    COMPLETE_UPDATE = "complete_update"

    #: Compile only the RLModule's forward methods (``forward_train`` among
    #: them).
    FORWARD_TRAIN = "forward_train"