from typing import Optional
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.schedules.piecewise_schedule import PiecewiseSchedule
from ray.rllib.utils.typing import LearningRateOrSchedule, TensorType
from ray.util.annotations import DeveloperAPI
_, tf, _ = try_import_tf()
torch, _ = try_import_torch()
@DeveloperAPI
class Scheduler:
"""Class to manage a scheduled (framework-dependent) tensor variable.
Uses the PiecewiseSchedule (for maximum configuration flexibility)
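
    A minimal usage sketch (a schedule decaying linearly from 0.1 at timestep 0
    to 0.0 at timestep 100; torch framework and import path assumed):

    .. testcode::

        from ray.rllib.utils.schedules.scheduler import Scheduler

        scheduler = Scheduler([[0, 0.1], [100, 0.0]], framework="torch")
        # The variable initially holds the value at timestep 0 (here: 0.1).
        print(scheduler.get_current_value())
        # Moving to timestep 50 linearly interpolates between the two
        # entries, yielding ~0.05.
        scheduler.update(timestep=50)
        print(scheduler.get_current_value())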
"""
def __init__(
self,
fixed_value_or_schedule: LearningRateOrSchedule,
*,
framework: str = "torch",
device: Optional[str] = None,
):
"""Initializes a Scheduler instance.
Args:
fixed_value_or_schedule: A fixed, constant value (in case no schedule should
be used) or a schedule configuration in the format of
[[timestep, value], [timestep, value], ...]
Intermediary timesteps will be assigned to linerarly interpolated
values. A schedule config's first entry must
start with timestep 0, i.e.: [[0, initial_value], [...]].
framework: The framework string, for which to create the tensor variable
that hold the current value. This is the variable that can be used in
the graph, e.g. in a loss function.
device: Optional device (for torch) to place the tensor variable on.
"""
self.framework = framework
self.device = device
self.use_schedule = isinstance(fixed_value_or_schedule, (list, tuple))
if self.use_schedule:
            # Custom schedule, based on a list of
            # (ts, value to be reached by ts) tuples.
self._schedule = PiecewiseSchedule(
fixed_value_or_schedule,
outside_value=fixed_value_or_schedule[-1][-1],
framework=None,
)
            # As the initial tensor value, use the value at the first timestep
            # (which must be 0).
self._curr_value = self._create_tensor_variable(
initial_value=fixed_value_or_schedule[0][1]
)
# If no schedule, pin (fix) given value.
else:
self._curr_value = fixed_value_or_schedule
@staticmethod
def validate(
*,
fixed_value_or_schedule: LearningRateOrSchedule,
setting_name: str,
description: str,
) -> None:
"""Performs checking of a certain schedule configuration.
The first entry in `value_or_schedule` (if it's not a fixed value) must have a
timestep of 0.
Args:
            fixed_value_or_schedule: A fixed, constant value (in case no schedule
                should be used) or a schedule configuration in the format of
                [[timestep, value], [timestep, value], ...]. Intermediate
                timesteps are assigned linearly interpolated values. A schedule
                config's first entry must start at timestep 0, i.e.:
                [[0, initial_value], [...]].
setting_name: The property name of the schedule setting (within a config),
e.g. `lr` or `entropy_coeff`.
description: A full text description of the property that's being scheduled,
e.g. `learning rate`.

        Raises:
            ValueError: If errors are found in the schedule's format.
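
        A quick sketch of a valid vs. an invalid configuration (the second
        call is expected to raise, because its schedule does not start at
        timestep 0; import path assumed):

        .. testcode::

            from ray.rllib.utils.schedules.scheduler import Scheduler

            # OK: two entries, starting at timestep 0.
            Scheduler.validate(
                fixed_value_or_schedule=[[0, 0.001], [1_000_000, 0.0001]],
                setting_name="lr",
                description="learning rate",
            )
            # Not OK: first timestep is 100, not 0 -> raises ValueError.
            try:
                Scheduler.validate(
                    fixed_value_or_schedule=[[100, 0.001], [1_000_000, 0.0001]],
                    setting_name="lr",
                    description="learning rate",
                )
            except ValueError:
                pass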
"""
# Fixed (single) value.
if (
isinstance(fixed_value_or_schedule, (int, float))
or fixed_value_or_schedule is None
):
return
if not isinstance(fixed_value_or_schedule, (list, tuple)) or (
len(fixed_value_or_schedule) < 2
):
raise ValueError(
f"Invalid `{setting_name}` ({fixed_value_or_schedule}) specified! "
f"Must be a list of 2 or more tuples, each of the form "
f"(`timestep`, `{description} to reach`), for example "
"`[(0, 0.001), (1e6, 0.0001), (2e6, 0.00005)]`."
)
elif fixed_value_or_schedule[0][0] != 0:
raise ValueError(
f"When providing a `{setting_name}` schedule, the first timestep must "
f"be 0 and the corresponding lr value is the initial {description}! "
f"You provided ts={fixed_value_or_schedule[0][0]} {description}="
f"{fixed_value_or_schedule[0][1]}."
)
elif any(len(pair) != 2 for pair in fixed_value_or_schedule):
raise ValueError(
f"When providing a `{setting_name}` schedule, each tuple in the "
f"schedule list must have exctly 2 items of the form "
f"(`timestep`, `{description} to reach`), for example "
"`[(0, 0.001), (1e6, 0.0001), (2e6, 0.00005)]`."
)
def get_current_value(self) -> TensorType:
"""Returns the current value (as a tensor variable).
This method should be used in loss functions of other (in-graph) places
where the current value is needed.
Returns:
The tensor variable (holding the current value to be used).
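
        A sketch of in-graph usage, scaling a hypothetical `entropy` tensor by
        a scheduled coefficient (torch framework and import path assumed):

        .. testcode::

            import torch
            from ray.rllib.utils.schedules.scheduler import Scheduler

            scheduler = Scheduler([[0, 0.01], [100, 0.0]], framework="torch")
            entropy = torch.tensor(1.5)
            # The returned tensor variable can be used directly in a loss term.
            loss_term = -scheduler.get_current_value() * entropy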
"""
return self._curr_value
def update(self, timestep: int) -> float:
"""Updates the underlying (framework specific) tensor variable.
In case of a fixed value, this method does nothing and only returns the fixed
value as-is.
Args:
            timestep: The current timestep that the update might depend on.

        Returns:
The current value of the tensor variable as a python float.
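
        A sketch of driving the schedule from a hypothetical training loop,
        assuming a linear decay from 0.1 at t=0 to 0.0 at t=100 (import path
        assumed):

        .. testcode::

            from ray.rllib.utils.schedules.scheduler import Scheduler

            scheduler = Scheduler([[0, 0.1], [100, 0.0]], framework="torch")
            for ts in [0, 50, 100, 200]:
                # `update` moves the variable to the schedule's value at `ts`
                # and returns it as a python float. Timesteps past the last
                # entry stick to the final (outside) value, 0.0 here.
                current = scheduler.update(timestep=ts)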
"""
if self.use_schedule:
python_value = self._schedule.value(t=timestep)
if self.framework == "torch":
self._curr_value.data = torch.tensor(python_value)
else:
self._curr_value.assign(python_value)
else:
python_value = self._curr_value
return python_value
def _create_tensor_variable(self, initial_value: float) -> TensorType:
"""Creates a framework-specific tensor variable to be scheduled.
Args:
initial_value: The initial (float) value for the variable to hold.
Returns:
The created framework-specific tensor variable.
"""
if self.framework == "torch":
return torch.tensor(
initial_value,
requires_grad=False,
dtype=torch.float32,
device=self.device,
)
else:
return tf.Variable(
initial_value,
trainable=False,
dtype=tf.float32,
)
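

# A minimal smoke-test sketch (illustrative only): exercises both the
# fixed-value and the schedule code paths on the torch framework.
if __name__ == "__main__":
    # Fixed value: `update` is a no-op and returns the value as-is.
    fixed = Scheduler(0.001, framework="torch")
    assert fixed.update(timestep=100) == 0.001

    # Schedule: linear decay from 0.1 (t=0) to 0.0 (t=100).
    sched = Scheduler([[0, 0.1], [100, 0.0]], framework="torch")
    assert abs(sched.update(timestep=50) - 0.05) < 1e-6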