ray.rllib.models.distributions.Distribution.from_logits
- classmethod Distribution.from_logits(logits: numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor, **kwargs) → Distribution
Creates a Distribution from logits.
The caller does not need to know the concrete distribution class in order to create a distribution and sample from it. The batched logits vectors may be split up and passed to the distribution class's constructor as kwargs.
- Parameters:
logits – The logits to create the distribution from.
**kwargs – Forward compatibility placeholder.
- Returns:
The created distribution.
import numpy as np
from ray.rllib.models.distributions import Distribution

class Uniform(Distribution):
    def __init__(self, lower, upper):
        self.lower = lower
        self.upper = upper

    def sample(self):
        return self.lower + (self.upper - self.lower) * np.random.rand()

    def logp(self, x):
        ...

    def kl(self, other):
        ...

    def entropy(self):
        ...

    @staticmethod
    def required_input_dim(space):
        ...

    def rsample(self):
        ...

    @classmethod
    def from_logits(cls, logits, **kwargs):
        return Uniform(logits[:, 0], logits[:, 1])

logits = np.array([[0.0, 1.0], [2.0, 3.0]])
my_dist = Uniform.from_logits(logits)
sample = my_dist.sample()
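
The description notes that the batched logits might be split up and passed to the constructor as kwargs. The following is a minimal sketch of that idea, not RLlib's own implementation: `DiagGaussianSketch` is a hypothetical standalone class (it does not subclass `Distribution`), and the layout of the logits (means in the first half of the last axis, log standard deviations in the second half) is an assumption made for illustration.

```python
import numpy as np


class DiagGaussianSketch:
    """Hypothetical stand-in for a distribution class; not part of RLlib."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def sample(self, rng=None):
        # Draw one sample per batch element and per dimension.
        rng = np.random.default_rng() if rng is None else rng
        return self.loc + self.scale * rng.standard_normal(self.loc.shape)

    @classmethod
    def from_logits(cls, logits, **kwargs):
        # Assumed layout: the last axis holds [means | log-stds].
        # Split it in half and pass the pieces to the constructor as kwargs.
        loc, log_std = np.split(logits, 2, axis=-1)
        return cls(loc=loc, scale=np.exp(log_std))


# Batch of two logits vectors, each describing a 2-dimensional Gaussian.
logits = np.array([[0.0, 1.0, 0.0, 0.0], [2.0, 3.0, -1.0, -1.0]])
dist = DiagGaussianSketch.from_logits(logits)
sample = dist.sample(np.random.default_rng(0))
```

The caller only hands over `logits`; how they map onto constructor arguments stays inside `from_logits`, which is the point of the classmethod.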