ray.rllib.policy.Policy.loss
- Policy.loss(model: ModelV2, dist_class: ActionDistribution, train_batch: SampleBatch) → numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor | List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]
Loss function for this Policy.
Override this method in order to implement custom loss computations.
- Parameters:
model – The model with which to calculate the loss(es).
dist_class – The action distribution class used to sample actions from the model’s outputs.
train_batch – The input batch on which to calculate the loss.
- Returns:
Either a single loss tensor or a list of loss tensors.
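To illustrate the override pattern, here is a minimal sketch of a custom `loss()` implementation. It is hedged: `MyPolicy` is a hypothetical numpy-based stand-in rather than a real `ray.rllib.policy.Policy` subclass (no `ray` or `torch` imports), and the model, batch keys (`"obs"`, `"actions"`, `"advantages"`), and the simple policy-gradient surrogate loss are all illustrative assumptions, not RLlib's actual training logic.

```python
import numpy as np


class MyPolicy:
    """Hypothetical stand-in for an RLlib Policy subclass.

    In real code you would subclass ray.rllib.policy.Policy (or a
    framework-specific subclass) and override loss(); RLlib would then
    call it during training with a ModelV2, an ActionDistribution
    class, and a SampleBatch.
    """

    def loss(self, model, dist_class, train_batch):
        # model: callable mapping observations to action logits
        # (stand-in for ModelV2). dist_class: unused in this sketch.
        # train_batch: dict with "obs", "actions", "advantages"
        # (stand-in for SampleBatch).
        logits = model(train_batch["obs"])

        # Log-probabilities via a numerically direct log-softmax.
        log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))

        # Pick out the log-probability of each action actually taken.
        idx = np.arange(len(train_batch["actions"]))
        taken = log_probs[idx, train_batch["actions"]]

        # Simple policy-gradient surrogate: minimize the negative
        # advantage-weighted log-likelihood (a single scalar loss).
        return -(taken * train_batch["advantages"]).mean()


# Usage with a trivial linear "model" mapping 2-dim obs to 2 action logits:
def model(obs):
    return obs @ np.array([[1.0, -1.0], [0.5, 0.5]])


batch = {
    "obs": np.array([[1.0, 0.0], [0.0, 1.0]]),
    "actions": np.array([0, 1]),
    "advantages": np.array([1.0, 2.0]),
}
loss_value = MyPolicy().loss(model, None, batch)
```

A real override would instead build the loss from framework tensors in `train_batch` (so gradients flow through `model`), and could equally return a list of loss tensors, one per optimizer.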