ray.rllib.models.distributions.Distribution.from_logits

classmethod Distribution.from_logits(logits: numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor, **kwargs) → Distribution

Creates a Distribution from logits.

The caller does not need to know the details of the distribution class in order to create and sample from it. The batched logits vectors that are passed in may be split up and forwarded as keyword arguments to the distribution class's constructor.

Parameters:
  • logits – The logits to create the distribution from.

  • **kwargs – Forward compatibility placeholder.

Returns:

The created distribution.

import numpy as np
from ray.rllib.models.distributions import Distribution

class Uniform(Distribution):
    """Uniform distribution over ``[lower, upper)``, illustrating ``from_logits``.

    Only ``__init__``, ``sample`` and ``from_logits`` are implemented; the
    remaining methods are left as ``...`` placeholders in this example.
    """

    def __init__(self, lower, upper):
        # Bounds may be Python scalars or (batched) NumPy arrays.
        self.lower = lower
        self.upper = upper

    def sample(self):
        """Draw one sample for every element of the (broadcast) bounds.

        A separate unit draw is taken per element: scaling a single scalar
        ``np.random.rand()`` would place every batch element at the same
        relative position inside its interval, i.e. fully correlated samples.

        Returns:
            A Python float when both bounds are scalars, otherwise an array
            with the broadcast shape of ``lower`` and ``upper``.
        """
        span = self.upper - self.lower
        # np.random.rand(*()) == np.random.rand(), so scalar bounds still
        # yield a plain float as before.
        return self.lower + span * np.random.rand(*np.shape(span))

    def logp(self, x):
        ...

    def kl(self, other):
        ...

    def entropy(self):
        ...

    @staticmethod
    def required_input_dim(space):
        ...

    def rsample(self):
        ...

    @classmethod
    def from_logits(cls, logits, **kwargs):
        """Create a distribution by splitting the last axis of ``logits``.

        Args:
            logits: Array whose last axis holds the lower bound at index 0
                and the upper bound at index 1. Any leading axes are kept as
                batch dimensions (e.g. shape ``(batch, 2)`` or just ``(2,)``).
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            A new instance of ``cls``, so subclasses of ``Uniform`` construct
            their own type rather than a plain ``Uniform``.
        """
        return cls(logits[..., 0], logits[..., 1])

# Each row packs one (lower, upper) pair; ``from_logits`` splits the
# columns into the two constructor arguments.
logits = np.array(
    [
        [0.0, 1.0],
        [2.0, 3.0],
    ],
    dtype=np.float64,
)
my_dist = Uniform.from_logits(logits)
sample = my_dist.sample()