ray.rllib.policy.Policy.loss
Policy.loss(model: ray.rllib.models.modelv2.ModelV2, dist_class: ray.rllib.models.action_dist.ActionDistribution, train_batch: ray.rllib.policy.sample_batch.SampleBatch) -> Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor, List[Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor]]]
Loss function for this Policy.
Override this method in order to implement custom loss computations.
Parameters:
model – The model to calculate the loss(es) for.
dist_class – The action distribution class to sample actions from the model’s outputs.
train_batch – The input batch on which to calculate the loss.
Returns:
Either a single loss tensor or a list of loss tensors.
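Below is a minimal sketch of how such an override might look for a Torch-based policy. It assumes a TorchPolicyV2 subclass and that the train batch has been postprocessed to contain an "advantages" column; the class name and the advantages key are illustrative assumptions, not part of this API.

```python
import torch

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2


class MyPGTorchPolicy(TorchPolicyV2):
    def loss(self, model, dist_class, train_batch):
        # Forward pass: compute action-distribution inputs from the batch.
        logits, _ = model(train_batch)
        # Build the action distribution from the model's outputs.
        action_dist = dist_class(logits, model)
        # Log-probabilities of the actions actually taken in the batch.
        log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
        # Simple policy-gradient loss: -E[log pi(a|s) * advantage].
        # "advantages" is assumed to have been added during postprocessing.
        return -torch.mean(log_probs * train_batch["advantages"])
```

Returning a list of tensors instead of a single tensor is also allowed, in which case each entry is treated as a separate loss term.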