ray.rllib.utils.torch_utils.update_target_network
- ray.rllib.utils.torch_utils.update_target_network(main_net: torch.nn.Module | tf.keras.Model, target_net: torch.nn.Module | tf.keras.Model, tau: float) → None
Updates a torch.nn.Module target network using Polyak averaging:
new_target_net_weight = tau * main_net_weight + (1.0 - tau) * current_target_net_weight
- Parameters:
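main_net – The nn.Module whose weights are blended into the target network.
target_net – The target network to update. Its weights are modified in place; the function returns None.
tau – The interpolation coefficient in the Polyak averaging formula. With tau = 1.0 the target weights are overwritten by the main network's weights; with tau = 0.0 the target network is left unchanged.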
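The averaging formula can be traced by hand on plain Python floats. The sketch below is not RLlib's implementation: `polyak_update` is a hypothetical helper, and the dicts of floats keyed by parameter name are assumed stand-ins for the tensors a real network would hold. It only illustrates how `tau` blends the two sets of weights.

```python
def polyak_update(main_weights, target_weights, tau):
    """Blend main weights into target weights (hypothetical helper).

    Applies tau * main + (1.0 - tau) * target to every named weight
    and returns the result as a new dict.
    """
    return {
        name: tau * main_weights[name] + (1.0 - tau) * target_weights[name]
        for name in target_weights
    }


# Toy "networks": one float per parameter name (illustrative values only).
main = {"w": 1.0, "b": 0.5}
target = {"w": 0.0, "b": 0.0}

soft = polyak_update(main, target, tau=0.25)   # partial step toward main
hard = polyak_update(main, target, tau=1.0)    # full copy of main
frozen = polyak_update(main, target, tau=0.0)  # target unchanged

print(soft)  # {'w': 0.25, 'b': 0.125}
```

A small `tau` moves the target only slightly toward the main network on each call, while `tau = 1.0` amounts to copying the weights outright.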