Policy.loss(model: ModelV2, dist_class: ActionDistribution, train_batch: SampleBatch) → numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor | List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]

Loss function for this Policy.

Override this method to implement custom loss computations.

Parameters:

  • model – The model used to calculate the loss(es).

  • dist_class – The action distribution class used to sample actions from the model’s outputs.

  • train_batch – The input batch on which to calculate the loss.


Returns:

Either a single loss tensor or a list of loss tensors.
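
The following is a minimal sketch of what an override might look like, not RLlib's own implementation. To keep it runnable without Ray installed, `ToyModel`, `ToyCategorical`, and `MyPolicy` are hypothetical stand-ins for a model, an action distribution class, and a Policy subclass; the batch is a plain dict, and the keys `"obs"`, `"actions"`, and `"advantages"` are assumptions of this example. The loss is a simple policy-gradient surrogate, and a plain Python float stands in for the loss tensor.

```python
import math


class ToyCategorical:
    """Hypothetical stand-in for an action distribution class."""

    def __init__(self, logits, model=None):
        # One row of logits per item in the batch.
        self.logits = logits

    def logp(self, actions):
        # Log-probability of each chosen action under a softmax over the logits.
        result = []
        for row, action in zip(self.logits, actions):
            log_z = math.log(sum(math.exp(x) for x in row))
            result.append(row[action] - log_z)
        return result


class ToyModel:
    """Hypothetical stand-in for a model; calling it returns (outputs, state)."""

    def __call__(self, train_batch):
        # Uniform logits over two actions for every observation.
        return [[0.0, 0.0] for _ in train_batch["obs"]], []


class MyPolicy:
    """Hypothetical class; a real one would subclass an RLlib Policy."""

    def loss(self, model, dist_class, train_batch):
        # Forward pass, then wrap the outputs in the action distribution.
        model_out, _ = model(train_batch)
        action_dist = dist_class(model_out, model)
        logp = action_dist.logp(train_batch["actions"])
        advantages = train_batch["advantages"]
        # Policy-gradient surrogate: -mean(logp * advantage).
        return -sum(l * a for l, a in zip(logp, advantages)) / len(logp)


batch = {"obs": [[0.1], [0.2]], "actions": [0, 1], "advantages": [1.0, 3.0]}
loss_value = MyPolicy().loss(ToyModel(), ToyCategorical, batch)
```

The sketch returns a single scalar; a policy with several optimizers could return a list of losses instead, as the return type allows.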