TorchModelV2.custom_loss(policy_loss: numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor, loss_inputs: Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]) -> List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor] | numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor

Override to customize the loss function used to optimize this model.

This can be used to incorporate self-supervised losses (by defining a loss over existing input and output tensors of this model), and supervised losses (by defining losses over a variable-sharing copy of this model’s layers).

You can find a runnable example in examples/custom_loss.py.

  • policy_loss – A single policy loss, or a list of policy losses, from the policy.

  • loss_inputs – Map of input placeholders for rollout data.


Returns a single scalar tensor, or a list of tensors, holding the customized loss(es) for this model.
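
The contract described above (accept a single loss or a list of losses, return the same form with an extra term added) can be sketched as follows. This is a minimal, framework-free illustration and not RLlib's own example: the class `ModelWithAuxLoss`, the `aux_weight` value, the `auxiliary_loss` helper, and the `"obs"` key are all assumptions made for the sketch. In real use the method would live on a `TorchModelV2` subclass and the values would be `torch.Tensor` objects rather than plain floats.

```python
from typing import Any, Dict, List, Union

# A loss is either one value or a list of values (one per loss term).
LossType = Union[Any, List[Any]]


class ModelWithAuxLoss:
    """Hypothetical stand-in for a TorchModelV2 subclass (no RLlib needed)."""

    aux_weight = 0.1  # assumed weighting of the auxiliary term

    def auxiliary_loss(self, loss_inputs: Dict[str, Any]) -> Any:
        # Assumed self-supervised term: mean squared value of the observations.
        # The "obs" key is an assumption about what loss_inputs contains.
        obs = loss_inputs["obs"]
        return sum(x * x for x in obs) / len(obs)

    def custom_loss(self, policy_loss: LossType,
                    loss_inputs: Dict[str, Any]) -> LossType:
        aux = self.aux_weight * self.auxiliary_loss(loss_inputs)
        # Return the same form that was passed in: a list stays a list,
        # a single loss stays a single loss.
        if isinstance(policy_loss, list):
            return [loss + aux for loss in policy_loss]
        return policy_loss + aux


model = ModelWithAuxLoss()
batch = {"obs": [1.0, 2.0, 3.0]}
single = model.custom_loss(1.0, batch)        # single loss in, single loss out
several = model.custom_loss([1.0, 2.0], batch)  # list in, list out
```

Because the sketch only relies on `+` and `*`, the same branching logic would carry over to tensors; what the auxiliary term should be depends entirely on the model and is left open here.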