ray.rllib.policy.torch_policy_v2.TorchPolicyV2.loss
- TorchPolicyV2.loss(model: ModelV2, dist_class: Type[TorchDistributionWrapper], train_batch: SampleBatch) → numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor | List[numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]
Constructs the loss function, i.e., computes the loss for a given training batch.
- Parameters:
model – The model to calculate the loss for.
dist_class – The action distribution class.
train_batch – The training data.
- Returns:
A single loss tensor, or a list of loss tensors, computed for the given training batch.
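The parameters above suggest a common three-step pattern inside a custom loss: run the model on the batch, wrap its output in dist_class, then score the actions that were actually taken. The sketch below illustrates that pattern with a REINFORCE-style surrogate, -mean(log pi(a|s) * reward). It is an assumption about how a typical override might look, not RLlib's own implementation. To stay runnable without ray or torch, ToyModel and ToyCategorical are hypothetical stdlib stand-ins for ModelV2 and TorchDistributionWrapper, the batch is a plain dict whose "obs", "actions", and "rewards" keys are assumed to mirror SampleBatch columns, and loss is written as a free function, so it omits the self argument a real TorchPolicyV2 subclass method would take.

```python
import math
from typing import Dict, List, Sequence, Tuple


class ToyModel:
    """Hypothetical stand-in for ModelV2: maps each observation to action logits."""

    def __init__(self, num_actions: int) -> None:
        self.num_actions = num_actions

    def __call__(self, batch: Dict[str, list]) -> Tuple[List[List[float]], list]:
        # Returns (logits, state_out); here every observation gets all-zero logits.
        logits = [[0.0] * self.num_actions for _ in batch["obs"]]
        return logits, []


class ToyCategorical:
    """Hypothetical stand-in for a TorchDistributionWrapper over discrete actions."""

    def __init__(self, inputs: List[List[float]], model: object) -> None:
        self.inputs = inputs
        self.model = model

    def logp(self, actions: Sequence[int]) -> List[float]:
        out = []
        for row, a in zip(self.inputs, actions):
            # Log-softmax of the chosen action, shifted by the max logit for stability.
            m = max(row)
            log_z = m + math.log(sum(math.exp(x - m) for x in row))
            out.append(row[a] - log_z)
        return out


def loss(model, dist_class, train_batch):
    # 1. Forward pass over the batch to obtain the distribution inputs (logits).
    logits, _ = model(train_batch)
    # 2. Build the action distribution from the model output.
    action_dist = dist_class(logits, model)
    # 3. Log-probabilities of the actions that were actually taken.
    log_probs = action_dist.logp(train_batch["actions"])
    # 4. REINFORCE-style surrogate loss: -mean(log pi(a|s) * reward).
    weighted = [lp * r for lp, r in zip(log_probs, train_batch["rewards"])]
    return -sum(weighted) / len(weighted)


batch = {"obs": [0, 1], "actions": [0, 1], "rewards": [1.0, 1.0]}
print(loss(ToyModel(num_actions=2), ToyCategorical, batch))  # ≈ 0.6931, i.e. ln 2
```

With two equally likely actions each log-probability is -ln 2, so unit rewards give a loss of ln 2. In an actual torch-based override, the same steps would operate on tensors so the returned loss can be backpropagated.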