ray.rllib.policy.Policy.loss

Policy.loss(model: ray.rllib.models.modelv2.ModelV2, dist_class: ray.rllib.models.action_dist.ActionDistribution, train_batch: ray.rllib.policy.sample_batch.SampleBatch) → Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor, List[Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor]]]

Loss function for this Policy.

Override this method to implement custom loss computations.

Parameters
  • model – The model to calculate the loss(es) for.

  • dist_class – The action distribution class used to sample actions from the model’s outputs.

  • train_batch – The input batch on which to calculate the loss.

Returns

Either a single loss tensor or a list of loss tensors.
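
A minimal sketch of overriding this method on a Torch-based policy (here subclassing ray.rllib.policy.torch_policy_v2.TorchPolicyV2). The simple policy-gradient style objective shown is only illustrative; real losses typically also include value-function and entropy terms:

```python
import torch

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2


class MyCustomLossPolicy(TorchPolicyV2):
    def loss(self, model, dist_class, train_batch):
        # Forward pass: the ModelV2 returns distribution inputs (logits) and
        # any recurrent state outputs.
        logits, _ = model(train_batch)

        # Build the action distribution from the model's outputs.
        action_dist = dist_class(logits, model)

        # Illustrative REINFORCE-style loss: -log pi(a|s) weighted by the
        # per-step rewards from the batch (a stand-in for proper returns).
        log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
        loss = -torch.mean(log_probs * train_batch[SampleBatch.REWARDS])

        # Either a single loss tensor or a list of loss tensors may be returned.
        return loss
```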