import logging
from typing import Any, Dict, Optional, Tuple, Type, Union
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.dqn.dqn import DQN
from ray.rllib.algorithms.sac.sac_tf_policy import SACTFPolicy
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
AddObservationsFromEpisodesToBatch,
)
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
AddNextObservationsFromEpisodesToTrainBatch,
)
from ray.rllib.core.learner import Learner
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.policy.policy import Policy
from ray.rllib.utils import deep_update
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import (
DEPRECATED_VALUE,
deprecation_warning,
)
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
from ray.rllib.utils.typing import LearningRateOrSchedule, RLModuleSpecType, ResultDict
# Optional TensorFlow / TensorFlow-Probability handles. These may be None when
# the packages are not installed; `SACConfig.validate()` checks `tfp is None`
# (and `tf` truthiness) before relying on them.
tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp()
# Module-level logger (used, e.g., for the missing-tfp warning in `validate()`).
logger = logging.getLogger(__name__)
class SACConfig(AlgorithmConfig):
    """Defines a configuration class from which an SAC Algorithm can be built.

    .. testcode::

        config = (
            SACConfig()
            .environment("Pendulum-v1")
            .env_runners(num_env_runners=1)
            .resources(num_gpus=0)
            .training(
                gamma=0.9,
                actor_lr=0.001,
                critic_lr=0.002,
                train_batch_size=32,
            )
        )
        # Build an Algorithm object from the config and run 1 training iteration.
        algo = config.build()
        algo.train()
    """

    def __init__(self, algo_class=None):
        """Initializes a SACConfig instance holding SAC's default settings.

        Args:
            algo_class: Optional Algorithm class to build from this config.
                Defaults to `SAC`.
        """
        super().__init__(algo_class=algo_class or SAC)

        # fmt: off
        # __sphinx_doc_begin__
        # SAC-specific config settings.
        self.twin_q = True
        self.q_model_config = {
            "fcnet_hiddens": [256, 256],
            "fcnet_activation": "relu",
            "post_fcnet_hiddens": [],
            "post_fcnet_activation": None,
            "custom_model": None,  # Use this to define custom Q-model(s).
            "custom_model_config": {},
        }
        self.policy_model_config = {
            "fcnet_hiddens": [256, 256],
            "fcnet_activation": "relu",
            "post_fcnet_hiddens": [],
            "post_fcnet_activation": None,
            "custom_model": None,  # Use this to define a custom policy model.
            "custom_model_config": {},
        }
        self.clip_actions = False
        self.tau = 5e-3
        self.initial_alpha = 1.0
        self.target_entropy = "auto"
        self.n_step = 1
        self.replay_buffer_config = {
            "_enable_replay_buffer_api": True,
            "type": "MultiAgentPrioritizedReplayBuffer",
            "capacity": int(1e6),
            # If True prioritized replay buffer will be used.
            "prioritized_replay": False,
            "prioritized_replay_alpha": 0.6,
            "prioritized_replay_beta": 0.4,
            "prioritized_replay_eps": 1e-6,
            # Whether to compute priorities already on the remote worker side.
            "worker_side_prioritization": False,
        }
        self.store_buffer_in_checkpoints = False
        self.training_intensity = None
        self.optimization = {
            "actor_learning_rate": 3e-4,
            "critic_learning_rate": 3e-4,
            "entropy_learning_rate": 3e-4,
        }
        self.actor_lr = 3e-5
        self.critic_lr = 3e-4
        self.alpha_lr = 3e-4
        # Set `lr` parameter to `None` and ensure it is not used. SAC uses the
        # dedicated `actor_lr`, `critic_lr` and `alpha_lr` settings instead, and
        # `validate()` rejects a non-None `lr` when the RLModule/Learner API is
        # enabled.
        self.lr = None
        self.grad_clip = None
        self.target_network_update_freq = 0

        # .env_runners()
        # Set to `self.n_step`, if 'auto'.
        self.rollout_fragment_length = "auto"
        self.compress_observations = False
        self.exploration_config = {
            # The Exploration class to use. In the simplest case, this is the name
            # (str) of any class present in the `rllib.utils.exploration` package.
            # You can also provide the python class directly or the full location
            # of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy.
            # EpsilonGreedy").
            "type": "StochasticSampling",
            # Add constructor kwargs here (if any).
        }

        # .training()
        self.train_batch_size_per_learner = 256
        self.train_batch_size = 256  # @OldAPIstack
        # Number of timesteps to collect from rollout workers before we start
        # sampling from replay buffers for learning. Whether we count this in agent
        # steps or environment steps depends on config.multi_agent(count_steps_by=..).
        self.num_steps_sampled_before_learning_starts = 1500

        # .reporting()
        self.min_time_s_per_iteration = 1
        self.min_sample_timesteps_per_iteration = 100
        # __sphinx_doc_end__
        # fmt: on

        self._deterministic_loss = False
        self._use_beta_distribution = False

        # Deprecated settings.
        self.use_state_preprocessor = DEPRECATED_VALUE
        self.worker_side_prioritization = DEPRECATED_VALUE

    @override(AlgorithmConfig)
    def training(
        self,
        *,
        twin_q: Optional[bool] = NotProvided,
        q_model_config: Optional[Dict[str, Any]] = NotProvided,
        policy_model_config: Optional[Dict[str, Any]] = NotProvided,
        tau: Optional[float] = NotProvided,
        initial_alpha: Optional[float] = NotProvided,
        target_entropy: Optional[Union[str, float]] = NotProvided,
        n_step: Optional[Union[int, Tuple[int, int]]] = NotProvided,
        store_buffer_in_checkpoints: Optional[bool] = NotProvided,
        replay_buffer_config: Optional[Dict[str, Any]] = NotProvided,
        training_intensity: Optional[float] = NotProvided,
        clip_actions: Optional[bool] = NotProvided,
        grad_clip: Optional[float] = NotProvided,
        optimization_config: Optional[Dict[str, Any]] = NotProvided,
        actor_lr: Optional[LearningRateOrSchedule] = NotProvided,
        critic_lr: Optional[LearningRateOrSchedule] = NotProvided,
        alpha_lr: Optional[LearningRateOrSchedule] = NotProvided,
        target_network_update_freq: Optional[int] = NotProvided,
        _deterministic_loss: Optional[bool] = NotProvided,
        _use_beta_distribution: Optional[bool] = NotProvided,
        num_steps_sampled_before_learning_starts: Optional[int] = NotProvided,
        **kwargs,
    ) -> "SACConfig":
        """Sets the training related configuration.

        Args:
            twin_q: Use two Q-networks (instead of one) for action-value estimation.
                Note: Each Q-network will have its own target network.
            q_model_config: Model configs for the Q network(s). These will override
                MODEL_DEFAULTS. This is treated just as the top-level `model` dict in
                setting up the Q-network(s) (2 if twin_q=True).
                That means, you can do for different observation spaces:
                `obs=Box(1D)` -> `Tuple(Box(1D) + Action)` -> `concat` -> `post_fcnet`
                obs=Box(3D) -> Tuple(Box(3D) + Action) -> vision-net -> concat w/ action
                -> post_fcnet
                obs=Tuple(Box(1D), Box(3D)) -> Tuple(Box(1D), Box(3D), Action)
                -> vision-net -> concat w/ Box(1D) and action -> post_fcnet
                You can also have SAC use your custom_model as Q-model(s), by simply
                specifying the `custom_model` sub-key in below dict (just like you
                would do in the top-level `model` dict).
            policy_model_config: Model options for the policy function (see
                `q_model_config` above for details). The difference to
                `q_model_config` above is that no action concat'ing is performed
                before the post_fcnet stack.
            tau: Update the target by tau * policy + (1 - tau) * target_policy.
            initial_alpha: Initial value to use for the entropy weight alpha.
            target_entropy: Target entropy lower bound. If "auto", will be set
                to `-|A|` (e.g. -2.0 for Discrete(2), -3.0 for Box(shape=(3,))).
                This is the inverse of reward scale, and will be optimized
                automatically.
            n_step: N-step target updates. If >1, sars' tuples in trajectories will
                be postprocessed to become sa[discounted sum of R][s t+n] tuples. An
                integer will be interpreted as a fixed n-step value. If a tuple of 2
                ints is provided here, the n-step value will be drawn for each
                sample(!) in the train batch from a uniform distribution over the
                closed interval defined by `[n_step[0], n_step[1]]`.
            store_buffer_in_checkpoints: Set this to True, if you want the contents
                of your buffer(s) to be stored in any saved checkpoints as well.
                Warnings will be created if:
                - This is True AND restoring from a checkpoint that contains no
                buffer data.
                - This is False AND restoring from a checkpoint that does contain
                buffer data.
            replay_buffer_config: Replay buffer config. If the `type` key changes,
                the entire existing `replay_buffer_config` is overridden; otherwise
                it is updated with the given keys.
                Examples:
                {
                "_enable_replay_buffer_api": True,
                "type": "MultiAgentReplayBuffer",
                "capacity": 50000,
                "replay_batch_size": 32,
                "replay_sequence_length": 1,
                }
                - OR -
                {
                "_enable_replay_buffer_api": True,
                "type": "MultiAgentPrioritizedReplayBuffer",
                "capacity": 50000,
                "prioritized_replay_alpha": 0.6,
                "prioritized_replay_beta": 0.4,
                "prioritized_replay_eps": 1e-6,
                "replay_sequence_length": 1,
                }
                - Where -
                prioritized_replay_alpha: Alpha parameter controls the degree of
                prioritization in the buffer. In other words, when a buffer sample
                has a higher temporal-difference error, with how much more
                probability should it be drawn to update the parametrized
                Q-network. 0.0 corresponds to uniform probability. Setting it much
                above 1.0 may quickly make the sampling distribution heavily
                "pointy" with low entropy.
                prioritized_replay_beta: Beta parameter controls the degree of
                importance sampling which suppresses the influence of gradient
                updates from samples that have higher probability of being sampled
                via alpha parameter and the temporal-difference error.
                prioritized_replay_eps: Epsilon parameter sets the baseline
                probability for sampling so that when the temporal-difference error
                of a sample is zero, there is still a chance of drawing the sample.
            training_intensity: The intensity with which to update the model (vs
                collecting samples from the env).
                If None, uses "natural" values of:
                `train_batch_size` / (`rollout_fragment_length` x `num_env_runners` x
                `num_envs_per_env_runner`).
                If not None, will make sure that the ratio between timesteps inserted
                into and sampled from the buffer matches the given values.
                Example:
                training_intensity=1000.0
                train_batch_size=250
                rollout_fragment_length=1
                num_env_runners=1 (or 0)
                num_envs_per_env_runner=1
                -> natural value = 250 / 1 = 250.0
                -> will make sure that replay+train op will be executed 4x as often
                as rollout+insert op (4 * 250 = 1000).
                See: rllib/algorithms/dqn/dqn.py::calculate_rr_weights for further
                details.
            clip_actions: Whether to clip actions. If actions are already
                normalized, this should be set to False.
            grad_clip: If not None, clip gradients during optimization at this
                value. Must be > 0.0 (checked in `validate()`).
            optimization_config: Config dict for optimization. Set the supported
                keys `actor_learning_rate`, `critic_learning_rate`, and
                `entropy_learning_rate` in here. Replaces the existing dict.
            actor_lr: The learning rate (float) or learning rate schedule for the
                policy in the format of
                [[timestep, lr-value], [timestep, lr-value], ...] In case of a
                schedule, intermediary timesteps will be assigned to linearly
                interpolated learning rate values. A schedule config's first entry
                must start with timestep 0, i.e.: [[0, initial_value], [...]].
                Note: It is common practice (two-timescale approach) to use a
                smaller learning rate for the policy than for the critic to ensure
                that the critic gives adequate values for improving the policy.
                Note: If you require a) more than one optimizer (per RLModule),
                b) optimizer types that are not Adam, c) a learning rate schedule
                that is not a linearly interpolated, piecewise schedule as
                described above, or d) specifying c'tor arguments of the optimizer
                that are not the learning rate (e.g. Adam's epsilon), then you must
                override your Learner's `configure_optimizer_for_module()` method
                and handle lr-scheduling yourself.
                The default value is 3e-5, one decimal less than the respective
                learning rate of the critic (see `critic_lr`).
            critic_lr: The learning rate (float) or learning rate schedule for the
                critic in the format of
                [[timestep, lr-value], [timestep, lr-value], ...] In case of a
                schedule, intermediary timesteps will be assigned to linearly
                interpolated learning rate values. A schedule config's first entry
                must start with timestep 0, i.e.: [[0, initial_value], [...]].
                Note: It is common practice (two-timescale approach) to use a
                smaller learning rate for the policy than for the critic to ensure
                that the critic gives adequate values for improving the policy.
                Note: If you require a) more than one optimizer (per RLModule),
                b) optimizer types that are not Adam, c) a learning rate schedule
                that is not a linearly interpolated, piecewise schedule as
                described above, or d) specifying c'tor arguments of the optimizer
                that are not the learning rate (e.g. Adam's epsilon), then you must
                override your Learner's `configure_optimizer_for_module()` method
                and handle lr-scheduling yourself.
                The default value is 3e-4, one decimal higher than the respective
                learning rate of the actor (policy) (see `actor_lr`).
            alpha_lr: The learning rate (float) or learning rate schedule for the
                hyperparameter alpha in the format of
                [[timestep, lr-value], [timestep, lr-value], ...] In case of a
                schedule, intermediary timesteps will be assigned to linearly
                interpolated learning rate values. A schedule config's first entry
                must start with timestep 0, i.e.: [[0, initial_value], [...]].
                Note: If you require a) more than one optimizer (per RLModule),
                b) optimizer types that are not Adam, c) a learning rate schedule
                that is not a linearly interpolated, piecewise schedule as
                described above, or d) specifying c'tor arguments of the optimizer
                that are not the learning rate (e.g. Adam's epsilon), then you must
                override your Learner's `configure_optimizer_for_module()` method
                and handle lr-scheduling yourself.
                The default value is 3e-4, identical to the critic learning rate
                (`critic_lr`).
            target_network_update_freq: Update the target network every
                `target_network_update_freq` steps.
            _deterministic_loss: Whether the loss should be calculated
                deterministically (w/o the stochastic action sampling step). True
                only useful for continuous actions and for debugging.
            _use_beta_distribution: Use a Beta-distribution instead of a
                `SquashedGaussian` for bounded, continuous action spaces (not
                recommended; for debugging only).
            num_steps_sampled_before_learning_starts: Number of timesteps to
                collect before sampling from the replay buffer for learning.

        Returns:
            This updated AlgorithmConfig object.
        """
        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)

        if twin_q is not NotProvided:
            self.twin_q = twin_q
        if q_model_config is not NotProvided:
            # Merged into (not replacing) the existing Q-model config.
            self.q_model_config.update(q_model_config)
        if policy_model_config is not NotProvided:
            # Merged into (not replacing) the existing policy-model config.
            self.policy_model_config.update(policy_model_config)
        if tau is not NotProvided:
            self.tau = tau
        if initial_alpha is not NotProvided:
            self.initial_alpha = initial_alpha
        if target_entropy is not NotProvided:
            self.target_entropy = target_entropy
        if n_step is not NotProvided:
            self.n_step = n_step
        if store_buffer_in_checkpoints is not NotProvided:
            self.store_buffer_in_checkpoints = store_buffer_in_checkpoints
        if replay_buffer_config is not NotProvided:
            # Override entire `replay_buffer_config` if `type` key changes.
            # Update, if `type` key remains the same or is not specified.
            new_replay_buffer_config = deep_update(
                {"replay_buffer_config": self.replay_buffer_config},
                {"replay_buffer_config": replay_buffer_config},
                False,
                ["replay_buffer_config"],
                ["replay_buffer_config"],
            )
            self.replay_buffer_config = new_replay_buffer_config["replay_buffer_config"]
        if training_intensity is not NotProvided:
            self.training_intensity = training_intensity
        if clip_actions is not NotProvided:
            self.clip_actions = clip_actions
        if grad_clip is not NotProvided:
            self.grad_clip = grad_clip
        if optimization_config is not NotProvided:
            self.optimization = optimization_config
        if actor_lr is not NotProvided:
            self.actor_lr = actor_lr
        if critic_lr is not NotProvided:
            self.critic_lr = critic_lr
        if alpha_lr is not NotProvided:
            self.alpha_lr = alpha_lr
        if target_network_update_freq is not NotProvided:
            self.target_network_update_freq = target_network_update_freq
        if _deterministic_loss is not NotProvided:
            self._deterministic_loss = _deterministic_loss
        if _use_beta_distribution is not NotProvided:
            self._use_beta_distribution = _use_beta_distribution
        if num_steps_sampled_before_learning_starts is not NotProvided:
            self.num_steps_sampled_before_learning_starts = (
                num_steps_sampled_before_learning_starts
            )

        return self

    @override(AlgorithmConfig)
    def validate(self) -> None:
        """Checks SAC-specific settings for consistency.

        Raises:
            ValueError: If `rollout_fragment_length` is too small for `n_step`,
                `grad_clip` is not positive, a non-episode replay buffer is used
                with the new EnvRunner API (outside the offline case), or `lr` is
                set while the RLModule/Learner API is enabled.
        """
        # Call super's validation method.
        super().validate()

        # Check rollout_fragment_length to be compatible with n_step. A 2-item
        # `n_step` (tuple or list, matching `get_rollout_fragment_length()`)
        # describes a sampling range, so its upper bound is the minimum length.
        if isinstance(self.n_step, (tuple, list)):
            min_rollout_fragment_length = self.n_step[1]
        else:
            min_rollout_fragment_length = self.n_step

        if (
            not self.in_evaluation
            and self.rollout_fragment_length != "auto"
            and self.rollout_fragment_length < min_rollout_fragment_length
        ):
            # Note: Do not index `self.n_step` here; it may be a plain int.
            raise ValueError(
                f"Your `rollout_fragment_length` ({self.rollout_fragment_length}) is "
                f"smaller than needed for `n_step` ({self.n_step})! Try setting "
                f"`rollout_fragment_length={min_rollout_fragment_length}` (or "
                "'auto')."
            )

        if self.use_state_preprocessor != DEPRECATED_VALUE:
            deprecation_warning(
                old="config['use_state_preprocessor']",
                error=False,
            )
            self.use_state_preprocessor = DEPRECATED_VALUE

        if self.grad_clip is not None and self.grad_clip <= 0.0:
            raise ValueError("`grad_clip` value must be > 0.0!")

        if self.framework in ["tf", "tf2"] and tfp is None:
            logger.warning(
                "You need `tensorflow_probability` in order to run SAC! "
                "Install it via `pip install tensorflow_probability`. Your "
                f"tf.__version__={tf.__version__ if tf else None}. "
                "Trying to import tfp results in the following error:"
            )
            # Re-import with `error=True` so the underlying import error surfaces.
            try_import_tfp(error=True)

        # Validate that we use the corresponding `EpisodeReplayBuffer` when using
        # episodes.
        # TODO (sven, simon): Implement the multi-agent case for replay buffers.
        episode_buffer_types = [
            "EpisodeReplayBuffer",
            "PrioritizedEpisodeReplayBuffer",
            "MultiAgentEpisodeReplayBuffer",
            "MultiAgentPrioritizedEpisodeReplayBuffer",
        ]
        # TODO (simon): Set up an indicator `is_offline_new_stack` that
        #  includes all these variable checks.
        is_offline_new_stack = (
            self.input_
            and (
                isinstance(self.input_, str)
                or (
                    isinstance(self.input_, list)
                    and isinstance(self.input_[0], str)
                )
            )
            and self.input_ != "sampler"
            and self.enable_rl_module_and_learner
        )
        if (
            self.enable_env_runner_and_connector_v2
            and self.replay_buffer_config["type"] not in episode_buffer_types
            and not is_offline_new_stack
        ):
            raise ValueError(
                "When using the new `EnvRunner API` the replay buffer must be one "
                "of `EpisodeReplayBuffer`, `PrioritizedEpisodeReplayBuffer`, "
                "`MultiAgentEpisodeReplayBuffer`, or "
                "`MultiAgentPrioritizedEpisodeReplayBuffer` (got "
                f"{self.replay_buffer_config['type']!r})."
            )

        if self.enable_rl_module_and_learner and self.lr is not None:
            raise ValueError(
                "Basic learning rate parameter `lr` is not `None`. For SAC "
                "use the specific learning rate parameters `actor_lr`, `critic_lr` "
                "and `alpha_lr`, for the actor, critic, and the hyperparameter "
                "`alpha`, respectively."
            )

    @override(AlgorithmConfig)
    def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
        """Returns the rollout fragment length.

        If `rollout_fragment_length` is "auto", this is `n_step` (or its upper
        bound, if `n_step` is a tuple/list); otherwise the configured value.
        """
        if self.rollout_fragment_length == "auto":
            return (
                self.n_step[1]
                if isinstance(self.n_step, (tuple, list))
                else self.n_step
            )
        else:
            return self.rollout_fragment_length

    @override(AlgorithmConfig)
    def get_default_rl_module_spec(self) -> RLModuleSpecType:
        """Returns the default RLModuleSpec (torch only).

        Raises:
            ValueError: If the framework is not "torch".
        """
        from ray.rllib.algorithms.sac.sac_catalog import SACCatalog

        if self.framework_str == "torch":
            from ray.rllib.algorithms.sac.torch.sac_torch_rl_module import (
                SACTorchRLModule,
            )

            return RLModuleSpec(module_class=SACTorchRLModule, catalog_class=SACCatalog)
        else:
            raise ValueError(
                f"The framework {self.framework_str} is not supported. Use `torch`."
            )

    @override(AlgorithmConfig)
    def get_default_learner_class(self) -> Union[Type["Learner"], str]:
        """Returns the default Learner class (torch only).

        Raises:
            ValueError: If the framework is not "torch".
        """
        if self.framework_str == "torch":
            from ray.rllib.algorithms.sac.torch.sac_torch_learner import SACTorchLearner

            return SACTorchLearner
        else:
            raise ValueError(
                f"The framework {self.framework_str} is not supported. Use `torch`."
            )

    @override(AlgorithmConfig)
    def build_learner_connector(
        self,
        input_observation_space,
        input_action_space,
        device=None,
    ):
        """Builds the default learner connector plus a NEXT_OBS piece."""
        pipeline = super().build_learner_connector(
            input_observation_space=input_observation_space,
            input_action_space=input_action_space,
            device=device,
        )

        # Insert the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece
        # right after the corresponding "add-OBS-..." default piece.
        pipeline.insert_after(
            AddObservationsFromEpisodesToBatch,
            AddNextObservationsFromEpisodesToTrainBatch(),
        )

        return pipeline

    @property
    def _model_config_auto_includes(self):
        # Forward `twin_q` into the model config automatically.
        return super()._model_config_auto_includes | {"twin_q": self.twin_q}
class SAC(DQN):
    """Soft Actor Critic (SAC) Algorithm class.

    This file defines the distributed Algorithm class for the soft actor critic
    algorithm.
    See `sac_[tf|torch]_policy.py` for the definition of the policy loss.

    Detailed documentation:
    https://docs.ray.io/en/master/rllib-algorithms.html#sac
    """

    def __init__(self, *args, **kwargs):
        # Allow arbitrary (user-defined) sub-keys in the SAC model config dicts
        # when configs are merged. `_allow_unknown_subkeys` is presumably a list
        # inherited from the base Algorithm class (NOTE(review): confirm); it is
        # extended in place, as the previous `+=` did, but each key is only
        # added once so repeated SAC instantiations don't accumulate duplicates.
        for key in ("policy_model_config", "q_model_config"):
            if key not in self._allow_unknown_subkeys:
                self._allow_unknown_subkeys.append(key)
        super().__init__(*args, **kwargs)

    @classmethod
    @override(DQN)
    def get_default_config(cls) -> AlgorithmConfig:
        """Returns a default `SACConfig` instance."""
        return SACConfig()

    @classmethod
    @override(DQN)
    def get_default_policy_class(
        cls, config: AlgorithmConfig
    ) -> Optional[Type[Policy]]:
        """Returns the old-API-stack policy class for the configured framework.

        `SACTorchPolicy` for "torch", `SACTFPolicy` for anything else.
        """
        if config["framework"] == "torch":
            from ray.rllib.algorithms.sac.sac_torch_policy import SACTorchPolicy

            return SACTorchPolicy
        else:
            return SACTFPolicy

    @override(DQN)
    def training_step(self) -> ResultDict:
        """SAC training iteration function.

        Each training iteration, we:
        - Sample (MultiAgentBatch) from workers.
        - Store new samples in replay buffer.
        - Sample training batch (MultiAgentBatch) from replay buffer.
        - Learn on training batch.
        - Update remote workers' new policy weights.
        - Update target network every `target_network_update_freq` sample steps.
        - Return all collected metrics for the iteration.

        Returns:
            The results dict from executing the training iteration.
        """
        # New API stack (RLModule, Learner, EnvRunner, ConnectorV2).
        if self.config.enable_env_runner_and_connector_v2:
            # SAC explores via stochastic sampling, so no noise reset is needed.
            return self._training_step_new_api_stack(with_noise_reset=False)
        # Old and hybrid API stacks (Policy, RolloutWorker, Connector, maybe RLModule,
        # maybe Learner).
        else:
            return self._training_step_old_and_hybrid_api_stack()