ray.rllib.models.distributions.Distribution

class ray.rllib.models.distributions.Distribution

Bases: ABC

The base class for distributions over a random variable.

Examples:

import torch
from ray.rllib.core.models.configs import MLPHeadConfig
from ray.rllib.models.torch.torch_distributions import TorchCategorical

model = MLPHeadConfig(input_dims=[1]).build(framework="torch")

# Create an action distribution from model logits
action_logits = model(torch.Tensor([[1]]))
action_dist = TorchCategorical.from_logits(action_logits)
action = action_dist.sample()

# Create another distribution from a dummy Tensor
action_dist2 = TorchCategorical.from_logits(torch.Tensor([0]))

# Compute some common metrics
logp = action_dist.logp(action)
kl = action_dist.kl(action_dist2)
entropy = action_dist.entropy()
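
To make the three metrics above concrete, here is a minimal pure-Python sketch of what logp, entropy, and kl compute for a categorical distribution defined by logits. This is not RLlib code: the helper names (softmax, categorical_logp, categorical_entropy, categorical_kl) are hypothetical and exist only for this illustration, and RLlib's tensor-based implementations may differ in detail.

```python
import math


def softmax(logits):
    """Turn raw logits into probabilities (max-subtraction for numerical stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_logp(probs, action):
    """Log-likelihood of a single discrete action."""
    return math.log(probs[action])


def categorical_entropy(probs):
    """Shannon entropy H(p) = -sum_i p_i * log(p_i), in nats."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def categorical_kl(p_probs, q_probs):
    """KL(p || q) = sum_i p_i * log(p_i / q_i), in nats."""
    return sum(p * math.log(p / q) for p, q in zip(p_probs, q_probs) if p > 0.0)


# Illustrative logits (assumed values, not taken from any model).
p = softmax([2.0, 0.0, 0.0])   # peaked on action 0
q = softmax([0.0, 0.0, 0.0])   # uniform over three actions

print("logp of action 0 under p:", categorical_logp(p, 0))
print("entropy of q (equals log(3)):", categorical_entropy(q))
print("KL(p || q):", categorical_kl(p, q))
```

Because q is uniform here, KL(p || q) reduces to log(3) minus the entropy of p, which is a handy sanity check when comparing against a library implementation.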

Methods

__init__

entropy

The entropy of the distribution.

from_logits

Creates a Distribution from logits.

get_partial_dist_cls

Returns a partial child of TorchMultiActionDistribution.

kl

The KL-divergence between two distributions.

logp

The log-likelihood of the distribution, computed at value.

required_input_dim

Returns the required length of an input parameter tensor.

rsample

Draw a reparameterized sample from the distribution.

sample

Draw a sample from the distribution.

to_deterministic

Returns a deterministic equivalent for this distribution.
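
The difference between sample, rsample, and to_deterministic in the list above can be sketched with a toy one-dimensional Gaussian. Everything in this block is hypothetical and written in plain Python for illustration: ToyDistribution, ToyNormal, ToyDeterministic, and the optional noise argument are assumptions of this sketch and are not part of RLlib. The interface only echoes a few of the method names listed above.

```python
import math
import random
from abc import ABC, abstractmethod


class ToyDistribution(ABC):
    """Hypothetical interface echoing a few method names from the list above."""

    @abstractmethod
    def sample(self): ...

    @abstractmethod
    def rsample(self): ...

    @abstractmethod
    def logp(self, value): ...

    @abstractmethod
    def to_deterministic(self): ...


class ToyNormal(ToyDistribution):
    """Scalar Gaussian with the given mean and standard deviation."""

    def __init__(self, mean, std, rng=None):
        self.mean = mean
        self.std = std
        self.rng = rng or random.Random()

    def sample(self):
        # Plain draw: the result is not expressed as a function of the parameters.
        return self.rng.gauss(self.mean, self.std)

    def rsample(self, noise=None):
        # Reparameterization: x = mean + std * eps with eps ~ N(0, 1).
        # `noise` is a hypothetical hook that fixes eps for demonstration.
        eps = self.rng.gauss(0.0, 1.0) if noise is None else noise
        return self.mean + self.std * eps

    def logp(self, value):
        z = (value - self.mean) / self.std
        return -0.5 * z * z - math.log(self.std) - 0.5 * math.log(2.0 * math.pi)

    def to_deterministic(self):
        # Collapse onto the mean, the most likely value of a Gaussian.
        return ToyDeterministic(self.mean)


class ToyDeterministic(ToyDistribution):
    """Distribution that always returns one fixed value."""

    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def rsample(self):
        return self.value

    def logp(self, value):
        return 0.0 if value == self.value else float("-inf")

    def to_deterministic(self):
        return self


dist = ToyNormal(mean=1.0, std=2.0)
print("stochastic sample:", dist.sample())
print("reparameterized sample with eps fixed at 0.5:", dist.rsample(noise=0.5))
print("deterministic sample:", dist.to_deterministic().sample())
```

With eps held fixed, the reparameterized sample is a differentiable function of mean and std, which is what lets gradient-based methods backpropagate through sampling in frameworks with automatic differentiation.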