ray.rllib.core.learner.learner.Learner.additional_update
- Learner.additional_update(*, module_ids_to_update: Sequence[str] | None = None, timestep: int, **kwargs) → Dict[str, Any]
Apply additional non-gradient-based updates to this Learner.
For example, this can be used to perform a polyak averaging update of a target network in off-policy algorithms like SAC or DQN (see the sketch after the example below).
Example:
```python
import gymnasium as gym

from ray.rllib.algorithms.ppo.ppo import (
    LEARNER_RESULTS_CURR_KL_COEFF_KEY,
    PPOConfig,
)
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec

env = gym.make("CartPole-v1")

config = (
    PPOConfig()
    .training(
        kl_coeff=0.2,
        kl_target=0.01,
        clip_param=0.3,
        vf_clip_param=10.0,
        # Taper down entropy coeff. from 0.01 to 0.0 over 20M ts.
        entropy_coeff=[
            [0, 0.01],
            [20000000, 0.0],
        ],
        vf_loss_coeff=0.5,
    )
)

# Create a single agent RL module spec.
module_spec = SingleAgentRLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict={"hidden": [128, 128]},
    catalog_class=PPOCatalog,
)


class CustomPPOLearner(PPOTorchLearner):
    def additional_update_for_module(
        self, *, module_id, config, timestep, sampled_kl_values
    ):
        results = super().additional_update_for_module(
            module_id=module_id,
            config=config,
            timestep=timestep,
            sampled_kl_values=sampled_kl_values,
        )

        # Try something else than the PPO paper here.
        sampled_kl = sampled_kl_values[module_id]
        curr_var = self.curr_kl_coeffs_per_module[module_id]
        if sampled_kl > 1.2 * self.config.kl_target:
            curr_var.data *= 1.2
        elif sampled_kl < 0.8 * self.config.kl_target:
            curr_var.data *= 0.4
        results.update({LEARNER_RESULTS_CURR_KL_COEFF_KEY: curr_var.item()})

        return results


# Construct the Learner object.
learner = CustomPPOLearner(
    config=config,
    module_spec=module_spec,
)
# Note: Learners need to be built before they can be used.
learner.build()

# Inside a training loop, we can now call the additional update as we like:
for i in range(100):
    # sample = ...
    # learner.update(sample)
    if i % 10 == 0:
        learner.additional_update(
            timestep=i,
            sampled_kl_values={"default_policy": 0.5},
        )
```
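As a contrast to the KL-coefficient example above, a polyak-style target network update (the off-policy use case mentioned at the top of this page) could look roughly like the following. This is a minimal sketch, not RLlib's actual SAC/DQN implementation: the `qf` and `qf_target` attributes on the module and the `tau` coefficient are assumed names for illustration.

```python
from ray.rllib.core.learner.torch.torch_learner import TorchLearner


class MyOffPolicyLearner(TorchLearner):
    def additional_update_for_module(self, *, module_id, config, timestep, **kwargs):
        results = super().additional_update_for_module(
            module_id=module_id, config=config, timestep=timestep, **kwargs
        )
        module = self.module[module_id]
        # Assumed hyperparameter: the polyak averaging coefficient.
        tau = 0.005
        # Move each target parameter a small step toward its online
        # counterpart: theta_target <- (1 - tau) * theta_target + tau * theta.
        # `qf` and `qf_target` are hypothetical attributes of the RLModule.
        for target_param, param in zip(
            module.qf_target.parameters(), module.qf.parameters()
        ):
            target_param.data.mul_(1.0 - tau).add_(tau * param.data)
        return results
```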
- Parameters:
module_ids_to_update – The IDs of the modules to update. If None, all modules are updated (see the usage sketch below).
timestep – The current timestep.
**kwargs – Keyword arguments to use for the additional update.
- Returns:
A dictionary of results from the update.
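To illustrate the `module_ids_to_update` argument, the call below restricts the non-gradient update to a single module. This is a minimal sketch reusing the `learner` built in the example above; in a multi-module (multi-agent) setup, any module IDs omitted from the list are left untouched.

```python
# Only update the "default_policy" module; any other modules (if present)
# are skipped. Extra kwargs are forwarded to `additional_update_for_module`.
results = learner.additional_update(
    module_ids_to_update=["default_policy"],
    timestep=100,
    sampled_kl_values={"default_policy": 0.5},
)
# `results` holds the per-module stats returned by
# `additional_update_for_module`, e.g. the current KL coefficient.
```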