"""This is the next version of the action distribution base class."""
import abc
from typing import Optional, Tuple, Type

import gymnasium as gym

from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TensorType, Union
@ExperimentalAPI
class Distribution(abc.ABC):
    """The base class for distribution over a random variable.

    Examples:

    .. testcode::

        import torch
        from ray.rllib.core.models.configs import MLPHeadConfig
        from ray.rllib.models.torch.torch_distributions import TorchCategorical

        model = MLPHeadConfig(input_dims=[1]).build(framework="torch")

        # Create an action distribution from model logits
        action_logits = model(torch.Tensor([[1]]))
        action_dist = TorchCategorical.from_logits(action_logits)
        action = action_dist.sample()

        # Create another distribution from a dummy Tensor
        action_dist2 = TorchCategorical.from_logits(torch.Tensor([0]))

        # Compute some common metrics
        logp = action_dist.logp(action)
        kl = action_dist.kl(action_dist2)
        entropy = action_dist.entropy()
    """

    @abc.abstractmethod
    def sample(
        self,
        *,
        sample_shape: Optional[Tuple[int, ...]] = None,
        return_logp: bool = False,
        **kwargs,
    ) -> Union[TensorType, Tuple[TensorType, TensorType]]:
        """Draw a sample from the distribution.

        Args:
            sample_shape: The shape of the sample to draw. ``None`` leaves the
                choice to the concrete implementation.
            return_logp: Whether to return the logp of the sampled values.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The sampled values. If return_logp is True, returns a tuple of the
            sampled values and their logp.
        """

    @abc.abstractmethod
    def rsample(
        self,
        *,
        sample_shape: Optional[Tuple[int, ...]] = None,
        return_logp: bool = False,
        **kwargs,
    ) -> Union[TensorType, Tuple[TensorType, TensorType]]:
        """Draw a re-parameterized sample from the action distribution.

        If this method is implemented, we can take gradients of samples w.r.t. the
        distribution parameters.

        Args:
            sample_shape: The shape of the sample to draw. ``None`` leaves the
                choice to the concrete implementation.
            return_logp: Whether to return the logp of the sampled values.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The sampled values. If return_logp is True, returns a tuple of the
            sampled values and their logp.
        """

    @abc.abstractmethod
    def logp(self, value: TensorType, **kwargs) -> TensorType:
        """The log-likelihood of the distribution computed at `value`.

        Args:
            value: The value to compute the log-likelihood at.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The log-likelihood of the value.
        """

    @abc.abstractmethod
    def kl(self, other: "Distribution", **kwargs) -> TensorType:
        """The KL-divergence between two distributions.

        Args:
            other: The other distribution.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The KL-divergence between the two distributions.
        """

    @abc.abstractmethod
    def entropy(self, **kwargs) -> TensorType:
        """The entropy of the distribution.

        Args:
            **kwargs: Forward compatibility placeholder.

        Returns:
            The entropy of the distribution.
        """

    @classmethod
    def from_logits(cls, logits: TensorType, **kwargs) -> "Distribution":
        """Creates a Distribution from logits.

        The caller does not need to have knowledge of the distribution class in order
        to create it and sample from it. The passed batched logits vectors might be
        split up and are passed to the distribution class' constructor as kwargs.

        Args:
            logits: The logits to create the distribution from.
            **kwargs: Forward compatibility placeholder.

        Returns:
            The created distribution.

        Raises:
            NotImplementedError: Always, in this base class; subclasses must
                override it.

        .. testcode::

            import numpy as np
            from ray.rllib.models.distributions import Distribution

            class Uniform(Distribution):
                def __init__(self, lower, upper):
                    self.lower = lower
                    self.upper = upper

                def sample(self):
                    return self.lower + (self.upper - self.lower) * np.random.rand()

                def logp(self, x):
                    ...

                def kl(self, other):
                    ...

                def entropy(self):
                    ...

                @staticmethod
                def required_input_dim(space):
                    ...

                def rsample(self):
                    ...

                @classmethod
                def from_logits(cls, logits, **kwargs):
                    return Uniform(logits[:, 0], logits[:, 1])

            logits = np.array([[0.0, 1.0], [2.0, 3.0]])
            my_dist = Uniform.from_logits(logits)
            sample = my_dist.sample()
        """
        raise NotImplementedError

    @classmethod
    def get_partial_dist_cls(
        parent_cls: Type["Distribution"], **partial_kwargs
    ) -> Type["Distribution"]:
        """Returns a partial subclass of the class this is called on.

        This is useful if inputs needed to instantiate the Distribution from logits
        are available, but the logits are not. The returned class merges
        ``partial_kwargs`` into every call to its ``from_logits`` and
        ``required_input_dim``.

        Args:
            **partial_kwargs: Keyword arguments to bind now. Callers of the
                returned class may not pass any of these keys again.

        Returns:
            A subclass of the calling class, named ``<ParentName>Partial``.
        """

        class DistributionPartial(parent_cls):
            @staticmethod
            def _merge_kwargs(**kwargs):
                """Merges call-time kwargs with the pre-bound ``partial_kwargs``.

                Raises:
                    ValueError: If any key in ``kwargs`` was already bound.
                """
                overlap = set(kwargs) & set(partial_kwargs)
                if overlap:
                    raise ValueError(
                        f"Cannot override the following kwargs: {overlap}.\n"
                        "This is because they were already set at the time this "
                        "partial class was defined."
                    )
                return {**partial_kwargs, **kwargs}

            @classmethod
            @override(parent_cls)
            def required_input_dim(cls, space: gym.Space, **kwargs) -> int:
                merged_kwargs = cls._merge_kwargs(**kwargs)
                # `space` is a named parameter, so it never reaches
                # `_merge_kwargs`. If a space was pre-bound, it must agree with
                # the one passed here; otherwise the caller's space is used.
                # (`merged_kwargs` is a fresh dict, so popping is safe.)
                bound_space = merged_kwargs.pop("space", space)
                if space != bound_space:
                    raise ValueError(
                        f"Passed space {space} does not match the space "
                        f"{bound_space} bound to this partial class."
                    )
                return parent_cls.required_input_dim(space=space, **merged_kwargs)

            @classmethod
            @override(parent_cls)
            def from_logits(
                cls,
                logits: TensorType,
                **kwargs,
            ) -> "DistributionPartial":
                merged_kwargs = cls._merge_kwargs(**kwargs)
                distribution = parent_cls.from_logits(logits, **merged_kwargs)
                # Replace the class of the returned distribution with this partial
                # This makes it so that we can use type() on this distribution and
                # get back the partial class.
                distribution.__class__ = cls
                return distribution

        # Substitute name of this partial class to match the original class.
        # Use `__name__` explicitly: interpolating the class object itself would
        # embed its repr ("<class '...'>") in the name.
        DistributionPartial.__name__ = f"{parent_cls.__name__}Partial"

        return DistributionPartial

    def to_deterministic(self) -> "Distribution":
        """Returns a deterministic equivalent for this distribution.

        Specifically, the deterministic equivalent for a Categorical distribution is a
        Deterministic distribution that selects the action with maximum logit value.
        Generally, the choice of the deterministic replacement is informed by
        established conventions.

        This base implementation returns ``self`` unchanged.
        """
        return self