Source code for ray.rllib.utils.schedules.scheduler

from typing import Optional

from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.schedules.piecewise_schedule import PiecewiseSchedule
from ray.rllib.utils.typing import LearningRateOrSchedule, TensorType
from ray.util.annotations import DeveloperAPI


_, tf, _ = try_import_tf()
torch, _ = try_import_torch()


@DeveloperAPI
class Scheduler:
    """Class to manage a scheduled (framework-dependent) tensor variable.

    Uses the PiecewiseSchedule (for maximum configuration flexibility).
    """

    def __init__(
        self,
        fixed_value_or_schedule: LearningRateOrSchedule,
        *,
        framework: str = "torch",
        device: Optional[str] = None,
    ):
        """Initializes a Scheduler instance.

        Args:
            fixed_value_or_schedule: A fixed, constant value (in case no schedule
                should be used) or a schedule configuration in the format of
                [[timestep, value], [timestep, value], ...]
                Intermediary timesteps will be assigned linearly interpolated
                values. A schedule config's first entry must start with
                timestep 0, i.e.: [[0, initial_value], [...]].
            framework: The framework string, for which to create the tensor
                variable that holds the current value. This is the variable that
                can be used in the graph, e.g. in a loss function.
            device: Optional device (for torch) to place the tensor variable on.
        """
        self.framework = framework
        self.device = device

        self.use_schedule = isinstance(fixed_value_or_schedule, (list, tuple))

        if self.use_schedule:
            # Custom schedule, based on list of
            # ([ts], [value to be reached by ts])-tuples.
            self._schedule = PiecewiseSchedule(
                fixed_value_or_schedule,
                outside_value=fixed_value_or_schedule[-1][-1],
                framework=None,
            )
            # As initial tensor value, use the first timestep's (must be 0)
            # value.
            self._curr_value = self._create_tensor_variable(
                initial_value=fixed_value_or_schedule[0][1]
            )
        # If no schedule, pin (fix) the given value.
        else:
            self._curr_value = fixed_value_or_schedule
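
    # Construction sketch (illustrative only; not part of the original source):
    #
    #   # Piecewise schedule: linearly anneal from 0.1 (ts=0) to 0.0 (ts=100).
    #   scheduler = Scheduler([[0, 0.1], [100, 0.0]], framework="torch")
    #
    #   # Fixed value: no schedule is created; `update()` always returns 0.001
    #   # and `_curr_value` stays the plain python float.
    #   fixed = Scheduler(0.001, framework="torch")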

    @staticmethod
    def validate(
        *,
        fixed_value_or_schedule: LearningRateOrSchedule,
        setting_name: str,
        description: str,
    ) -> None:
        """Performs checking of a certain schedule configuration.

        The first entry in `fixed_value_or_schedule` (if it's not a fixed value)
        must have a timestep of 0.

        Args:
            fixed_value_or_schedule: A fixed, constant value (in case no schedule
                should be used) or a schedule configuration in the format of
                [[timestep, value], [timestep, value], ...]
                Intermediary timesteps will be assigned linearly interpolated
                values. A schedule config's first entry must start with
                timestep 0, i.e.: [[0, initial_value], [...]].
            setting_name: The property name of the schedule setting (within a
                config), e.g. `lr` or `entropy_coeff`.
            description: A full text description of the property that's being
                scheduled, e.g. `learning rate`.

        Raises:
            ValueError: In case errors are found in the schedule's format.
        """
        # Fixed (single) value.
        if (
            isinstance(fixed_value_or_schedule, (int, float))
            or fixed_value_or_schedule is None
        ):
            return

        if not isinstance(fixed_value_or_schedule, (list, tuple)) or (
            len(fixed_value_or_schedule) < 2
        ):
            raise ValueError(
                f"Invalid `{setting_name}` ({fixed_value_or_schedule}) specified! "
                f"Must be a list of 2 or more tuples, each of the form "
                f"(`timestep`, `{description} to reach`), for example "
                "`[(0, 0.001), (1e6, 0.0001), (2e6, 0.00005)]`."
            )
        elif fixed_value_or_schedule[0][0] != 0:
            raise ValueError(
                f"When providing a `{setting_name}` schedule, the first timestep "
                f"must be 0 and the corresponding value is the initial "
                f"{description}! You provided ts={fixed_value_or_schedule[0][0]} "
                f"{description}={fixed_value_or_schedule[0][1]}."
            )
        elif any(len(pair) != 2 for pair in fixed_value_or_schedule):
            raise ValueError(
                f"When providing a `{setting_name}` schedule, each tuple in the "
                f"schedule list must have exactly 2 items of the form "
                f"(`timestep`, `{description} to reach`), for example "
                "`[(0, 0.001), (1e6, 0.0001), (2e6, 0.00005)]`."
            )
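
    # Validation sketch (illustrative only; not part of the original source):
    #
    #   Scheduler.validate(
    #       fixed_value_or_schedule=[[50, 0.1], [100, 0.0]],
    #       setting_name="lr",
    #       description="learning rate",
    #   )  # -> raises ValueError (first timestep must be 0)
    #
    #   Scheduler.validate(
    #       fixed_value_or_schedule=0.001,
    #       setting_name="lr",
    #       description="learning rate",
    #   )  # -> returns None (fixed values are always valid)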

    def get_current_value(self) -> TensorType:
        """Returns the current value (as a tensor variable).

        This method should be used in loss functions or other (in-graph) places
        where the current value is needed.

        Returns:
            The tensor variable (holding the current value to be used).
        """
        return self._curr_value

    def update(self, timestep: int) -> float:
        """Updates the underlying (framework-specific) tensor variable.

        In case of a fixed value, this method does nothing and only returns the
        fixed value as-is.

        Args:
            timestep: The current timestep that the update might depend on.

        Returns:
            The current value of the tensor variable as a python float.
        """
        if self.use_schedule:
            python_value = self._schedule.value(t=timestep)
            if self.framework == "torch":
                self._curr_value.data = torch.tensor(python_value)
            else:
                self._curr_value.assign(python_value)
        else:
            python_value = self._curr_value

        return python_value
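
    # Update sketch (illustrative only; not part of the original source),
    # assuming the schedule `[[0, 0.1], [100, 0.0]]` from above:
    #
    #   scheduler.update(timestep=0)    # -> 0.1
    #   scheduler.update(timestep=50)   # -> ~0.05 (linear interpolation)
    #   scheduler.update(timestep=200)  # -> 0.0 (outside value = last entry)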

    def _create_tensor_variable(self, initial_value: float) -> TensorType:
        """Creates a framework-specific tensor variable to be scheduled.

        Args:
            initial_value: The initial (float) value for the variable to hold.

        Returns:
            The created framework-specific tensor variable.
        """
        if self.framework == "torch":
            return torch.tensor(
                initial_value,
                requires_grad=False,
                dtype=torch.float32,
                device=self.device,
            )
        else:
            return tf.Variable(
                initial_value,
                trainable=False,
                dtype=tf.float32,
            )
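

# Hedged end-to-end sketch, appended for illustration (not part of the original
# module). Assumes `torch` is installed, so that `framework="torch"` resolves.
if __name__ == "__main__":
    # Linearly anneal from 0.1 at ts=0 to 0.0 at ts=100; constant afterwards.
    scheduler = Scheduler([[0, 0.1], [100, 0.0]], framework="torch")
    print(scheduler.update(timestep=50))   # ~0.05
    print(scheduler.update(timestep=200))  # 0.0
    # The tensor variable to be referenced in (in-graph) loss code.
    print(scheduler.get_current_value())   # tensor(0.)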