ray.rllib.utils.torch_utils.update_target_network

ray.rllib.utils.torch_utils.update_target_network(main_net: torch.nn.Module | tf.keras.Model, target_net: torch.nn.Module | tf.keras.Model, tau: float) -> None

Updates a torch.nn.Module target network using Polyak averaging.

new_target_net_weight = (
    tau * main_net_weight + (1.0 - tau) * current_target_net_weight
)
Parameters:
  • main_net – The main network (an nn.Module) whose weights are the source of the update.

  • target_net – The target network to update.

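  • tau – The interpolation coefficient in the Polyak averaging formula above. With tau = 1.0 the target weights become a copy of the main weights; values closer to 0.0 leave the target nearly unchanged.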
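
The formula above can be illustrated with a minimal PyTorch sketch. This is not RLlib's implementation: `polyak_update` is a hypothetical helper written for illustration, and it assumes both networks share the same architecture and that only their parameters (not buffers) need blending.

```python
import torch
from torch import nn


def polyak_update(main_net: nn.Module, target_net: nn.Module, tau: float) -> None:
    """Hypothetical helper: blend main_net's parameters into target_net in place."""
    with torch.no_grad():
        for main_param, target_param in zip(
            main_net.parameters(), target_net.parameters()
        ):
            # new_target = tau * main + (1.0 - tau) * current_target
            target_param.mul_(1.0 - tau).add_(main_param, alpha=tau)


# Two networks with identical architecture (an assumption of this sketch).
main_net = nn.Linear(2, 1)
target_net = nn.Linear(2, 1)

# Use constant weights so the effect of the update is easy to read.
with torch.no_grad():
    for p in main_net.parameters():
        p.fill_(1.0)
    for p in target_net.parameters():
        p.fill_(0.0)

polyak_update(main_net, target_net, tau=0.1)

# Each target weight is now 0.1 * 1.0 + 0.9 * 0.0.
target_weight = target_net.weight.detach().clone()
main_weight = main_net.weight.detach().clone()
```

With the library function, the equivalent call would be `update_target_network(main_net, target_net, tau)` using the signature documented above. A small tau makes the target network track the main network slowly, which is the usual reason for using Polyak averaging instead of a periodic hard copy.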