# Source code for ray.rllib.algorithms.iql.iql
from typing import Optional, Type, Union
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
AddObservationsFromEpisodesToBatch,
)
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
AddNextObservationsFromEpisodesToTrainBatch,
)
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import LearningRateOrSchedule, RLModuleSpecType
class IQLConfig(MARWILConfig):
"""Defines a configuration class from which a new IQL Algorithm can be built
.. testcode::
:skipif: True
from ray.rllib.algorithms.iql import IQLConfig
# Run this from the ray directory root.
config = IQLConfig().training(actor_lr=0.00001, gamma=0.99)
config = config.offline_data(
input_="./rllib/tests/data/pendulum/pendulum-v1_enormous")
# Build an Algorithm object from the config and run 1 training iteration.
algo = config.build()
algo.train()
.. testcode::
:skipif: True
from ray.rllib.algorithms.iql import IQLConfig
from ray import tune
config = IQLConfig()
# Print out some default values.
print(config.beta)
# Update the config object.
config.training(
actor_lr=tune.grid_search([0.001, 0.0001]), beta=0.75
)
# Set the config object's data path.
# Run this from the ray directory root.
config.offline_data(
input_="./rllib/tests/data/pendulum-v1_enormous"
)
# Set the config object's env, used for evaluation.
config.environment(env="Pendulum-v1")
# Use to_dict() to get the old-style python config dict
# when running with tune.
tune.Tuner(
"IQL",
param_space=config.to_dict(),
).fit()
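A minimal sketch of setting the IQL-specific hyperparameters; the values
below are illustrative placeholders, not tuned recommendations.
.. testcode::
:skipif: True
from ray.rllib.algorithms.iql import IQLConfig
config = IQLConfig().training(
    beta=0.5,
    expectile=0.9,
    twin_q=True,
    tau=0.005,
    actor_lr=3e-4,
    critic_lr=3e-4,
    value_lr=3e-4,
)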
"""
def __init__(self, algo_class=None):
super().__init__(algo_class=algo_class or IQL)
# fmt: off
# __sphinx_doc_begin__
# The temperature for the actor loss.
self.beta = 0.1
# The expectile to use in expectile regression.
self.expectile = 0.8
# The learning rates for the actor, critic and value network(s).
self.actor_lr = 3e-4
self.critic_lr = 3e-4
self.value_lr = 3e-4
# Set `lr` parameter to `None` and ensure it is not used.
self.lr = None
# Whether a twin-Q architecture should be used (advisable).
self.twin_q = True
# How often the target network should be updated.
self.target_network_update_freq = 0
# The weight for Polyak averaging.
self.tau = 1.0
# __sphinx_doc_end__
# fmt: on
@override(MARWILConfig)
def training(
self,
*,
twin_q: Optional[bool] = NotProvided,
expectile: Optional[float] = NotProvided,
actor_lr: Optional[LearningRateOrSchedule] = NotProvided,
critic_lr: Optional[LearningRateOrSchedule] = NotProvided,
value_lr: Optional[LearningRateOrSchedule] = NotProvided,
target_network_update_freq: Optional[int] = NotProvided,
tau: Optional[float] = NotProvided,
**kwargs,
) -> "IQLConfig":
"""Sets the training related configuration.
Args:
beta: The temperature for scaling advantages in the exponential term.
Must be >> 0.0. The higher this parameter, the less greedy
(exploitative) the policy becomes; the policy also fits less
tightly to the best actions in the dataset.
twin_q: Whether a twin-Q architecture should be used (advisable).
expectile: The expectile to use in expectile regression for the value
function. For high expectiles the value function tries to match
the upper tail of the Q-value distribution.
actor_lr: The learning rate for the actor network. Actor learning rates
greater than critic learning rates work well in experiments.
critic_lr: The learning rate for the Q-network. Critic learning rates
greater than value function learning rates work well in experiments.
value_lr: The learning rate for the value function network.
target_network_update_freq: The number of timesteps for which the target
Q-network is held fixed between updates. Note that too-high values
here could harm convergence. The target network is updated via
Polyak averaging.
tau: The update parameter for Polyak averaging of the target Q-network.
The higher this value, the faster the target weights move towards
the current Q-network.
Returns:
This updated `AlgorithmConfig` object.
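For intuition, `expectile` controls an asymmetric L2 loss on the value
function. A minimal sketch of the standard IQL formulation (illustrative;
not necessarily the exact code in the IQL Learner):
.. testcode::
:skipif: True
import torch
def expectile_loss(q_values, v_values, expectile=0.8):
    # Errors where Q > V get weight `expectile`, errors where Q < V
    # get weight (1 - expectile). For expectile > 0.5 the value net
    # is pushed towards the upper tail of the Q-value distribution.
    diff = q_values - v_values
    weight = torch.abs(expectile - (diff < 0).float())
    return (weight * diff**2).mean()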
"""
super().training(**kwargs)
if twin_q is not NotProvided:
self.twin_q = twin_q
if expectile is not NotProvided:
self.expectile = expectile
if actor_lr is not NotProvided:
self.actor_lr = actor_lr
if critic_lr is not NotProvided:
self.critic_lr = critic_lr
if value_lr is not NotProvided:
self.value_lr = value_lr
if target_network_update_freq is not NotProvided:
self.target_network_update_freq = target_network_update_freq
if tau is not NotProvided:
self.tau = tau
return self
@override(MARWILConfig)
def get_default_learner_class(self) -> Union[Type["Learner"], str]:
if self.framework_str == "torch":
from ray.rllib.algorithms.iql.torch.iql_torch_learner import IQLTorchLearner
return IQLTorchLearner
else:
raise ValueError(
f"The framework {self.framework_str} is not supported. "
"Use `'torch'` instead."
)
@override(MARWILConfig)
def get_default_rl_module_spec(self) -> RLModuleSpecType:
if self.framework_str == "torch":
from ray.rllib.algorithms.iql.torch.default_iql_torch_rl_module import (
DefaultIQLTorchRLModule,
)
return RLModuleSpec(module_class=DefaultIQLTorchRLModule)
else:
raise ValueError(
f"The framework {self.framework_str} is not supported. "
"Use `torch` instead."
)
@override(MARWILConfig)
def build_learner_connector(
self,
input_observation_space,
input_action_space,
device=None,
):
pipeline = super().build_learner_connector(
input_observation_space=input_observation_space,
input_action_space=input_action_space,
device=device,
)
# Remove unneeded connectors from the MARWIL connector pipeline.
pipeline.remove("AddOneTsToEpisodesAndTruncate")
pipeline.remove("GeneralAdvantageEstimation")
# Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right
# after the corresponding "add-OBS-..." default piece).
pipeline.insert_after(
AddObservationsFromEpisodesToBatch,
AddNextObservationsFromEpisodesToTrainBatch(),
)
return pipeline
@override(MARWILConfig)
def validate(self) -> None:
# Call super's validation method.
super().validate()
# Ensure hyperparameters are meaningful.
if self.beta <= 0.0:
self._value_error(
"For meaningful results, `beta` (temperature) parameter must be >> 0.0!"
)
if not 0.0 < self.expectile < 1.0:
self._value_error(
"For meaningful results, `expectile` parameter must be in (0, 1)."
)
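# Forward `twin_q` into the RLModule's `model_config` (used by the default IQL
# RLModule to decide whether to build twin Q-networks).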
@property
def _model_config_auto_includes(self):
return super()._model_config_auto_includes | {"twin_q": self.twin_q}
class IQL(MARWIL):
"""Implicit Q-learning (derived from MARWIL).
Uses MARWIL training step.
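A minimal sketch of building the algorithm from its default config (the
dataset path is a placeholder):
.. testcode::
:skipif: True
from ray.rllib.algorithms.iql.iql import IQL
config = (
    IQL.get_default_config()
    .environment("Pendulum-v1")
    .offline_data(input_="/path/to/offline/data")  # placeholder path
)
algo = config.build()
algo.train()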
"""
@classmethod
@override(MARWIL)
def get_default_config(cls) -> AlgorithmConfig:
return IQLConfig()