ray.rllib.core.learner.learner.Learner.additional_update

Learner.additional_update(*, module_ids_to_update: Sequence[str] | None = None, timestep: int, **kwargs) -> Dict[str, Any]

Apply additional non-gradient-based updates to this Learner.

For example, this could be used to perform a Polyak-averaging update of a target network in off-policy algorithms like SAC or DQN.

Example:

import gymnasium as gym

from ray.rllib.algorithms.ppo.ppo import (
    LEARNER_RESULTS_CURR_KL_COEFF_KEY,
    PPOConfig,
)
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import (
    PPOTorchLearner
)
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
    PPOTorchRLModule
)
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec

env = gym.make("CartPole-v1")

# Entropy-coefficient schedule given as [timestep, value] pairs: taper the
# coefficient down from 0.01 at timestep 0 to 0.0 at 20M timesteps.
entropy_schedule = [
    [0, 0.01],
    [20_000_000, 0.0],
]

# PPO training hyperparameters used by the learner below.
config = PPOConfig().training(
    kl_coeff=0.2,
    kl_target=0.01,
    clip_param=0.3,
    vf_clip_param=10.0,
    entropy_coeff=entropy_schedule,
    vf_loss_coeff=0.5,
)

# Spec for a single-agent PPO torch module sized to CartPole's observation
# and action spaces, with the model config's "hidden" entry set to [128, 128].
module_spec = SingleAgentRLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict={"hidden": [128, 128]},
    catalog_class=PPOCatalog,
)

class CustomPPOLearner(PPOTorchLearner):
    """PPOTorchLearner with a custom KL-coefficient adaptation rule.

    After running the parent's per-module additional update, the KL
    coefficient of the given module is scaled by 1.2 when the sampled KL
    exceeds 1.2x the KL target, and by 0.4 when it falls below 0.8x the
    KL target.
    """

    def additional_update_for_module(
        self, *, module_id, config, timestep, sampled_kl_values
    ):
        """Run the parent's per-module update, then apply the custom KL rule.

        Args:
            module_id: ID of the module being updated; used to look up its
                sampled KL value and its KL-coefficient variable.
            config: The config passed in for this module (also forwarded to
                the parent); its ``kl_target`` is used as the KL target.
            timestep: The current timestep, forwarded to the parent.
            sampled_kl_values: Mapping from module ID to sampled KL value.

        Returns:
            The parent's results dict, with
            ``LEARNER_RESULTS_CURR_KL_COEFF_KEY`` set to the updated KL
            coefficient.
        """
        results = super().additional_update_for_module(
            module_id=module_id,
            config=config,
            timestep=timestep,
            sampled_kl_values=sampled_kl_values,
        )

        # Try something other than the PPO paper's adaptation rule here.
        # NOTE(review): the parent implementation may already have adjusted
        # this KL coefficient; if so, the rule below compounds with it.
        sampled_kl = sampled_kl_values[module_id]
        curr_var = self.curr_kl_coeffs_per_module[module_id]
        # Use the per-module config handed to this method (the same one the
        # parent receives) rather than the learner-wide ``self.config``.
        kl_target = config.kl_target
        if sampled_kl > 1.2 * kl_target:
            curr_var.data *= 1.2
        elif sampled_kl < 0.8 * kl_target:
            curr_var.data *= 0.4
        results[LEARNER_RESULTS_CURR_KL_COEFF_KEY] = curr_var.item()
        # Without this return the method yields None and the caller
        # loses the per-module results.
        return results

# Build the custom learner from the config and module spec.
learner = CustomPPOLearner(config=config, module_spec=module_spec)
# Learners must be built before they can be used.
learner.build()

# Inside a training loop, the additional update can be triggered at any
# cadence; here it runs on every 10th iteration.
for iteration in range(100):
    # sample = ...
    # learner.update(sample)
    if iteration % 10 != 0:
        continue
    learner.additional_update(
        timestep=iteration,
        sampled_kl_values={"default_policy": 0.5},
    )
Parameters:
  • module_ids_to_update – The IDs of the modules to update. If None, all modules are updated.

  • timestep – The current timestep.

  • **kwargs – Additional keyword arguments for the update (e.g., `sampled_kl_values` in the example above, which is received by `additional_update_for_module`).

Returns:

A dictionary of results from the update.