Source code for ray.rllib.utils.exploration.random

from gymnasium.spaces import Discrete, Box, MultiDiscrete, Space
import numpy as np
import tree  # pip install dm_tree
from typing import Union, Optional

from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils.annotations import OldAPIStack, override
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils import force_tuple
from ray.rllib.utils.framework import try_import_tf, try_import_torch, TensorType
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.tf_utils import zero_logps_from_actions

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


[docs] @OldAPIStack class Random(Exploration): """A random action selector (deterministic/greedy for explore=False). If explore=True, returns actions randomly from `self.action_space` (via Space.sample()). If explore=False, returns the greedy/max-likelihood action. """
[docs] def __init__( self, action_space: Space, *, model: ModelV2, framework: Optional[str], **kwargs ): """Initialize a Random Exploration object. Args: action_space: The gym action space used by the environment. framework: One of None, "tf", "torch". """ super().__init__( action_space=action_space, model=model, framework=framework, **kwargs ) self.action_space_struct = get_base_struct_from_space(self.action_space)
@override(Exploration) def get_exploration_action( self, *, action_distribution: ActionDistribution, timestep: Union[int, TensorType], explore: bool = True ): # Instantiate the distribution object. if self.framework in ["tf2", "tf"]: return self.get_tf_exploration_action_op(action_distribution, explore) else: return self.get_torch_exploration_action(action_distribution, explore) def get_tf_exploration_action_op( self, action_dist: ActionDistribution, explore: Optional[Union[bool, TensorType]], ): def true_fn(): batch_size = 1 req = force_tuple( action_dist.required_model_output_shape( self.action_space, getattr(self.model, "model_config", None) ) ) # Add a batch dimension? if len(action_dist.inputs.shape) == len(req) + 1: batch_size = tf.shape(action_dist.inputs)[0] # Function to produce random samples from primitive space # components: (Multi)Discrete or Box. def random_component(component): # Have at least an additional shape of (1,), even if the # component is Box(-1.0, 1.0, shape=()). shape = component.shape or (1,) if isinstance(component, Discrete): return tf.random.uniform( shape=(batch_size,) + component.shape, maxval=component.n, dtype=component.dtype, ) elif isinstance(component, MultiDiscrete): return tf.concat( [ tf.random.uniform( shape=(batch_size, 1), maxval=n, dtype=component.dtype ) for n in component.nvec ], axis=1, ) elif isinstance(component, Box): if component.bounded_above.all() and component.bounded_below.all(): if component.dtype.name.startswith("int"): return tf.random.uniform( shape=(batch_size,) + shape, minval=component.low.flat[0], maxval=component.high.flat[0], dtype=component.dtype, ) else: return tf.random.uniform( shape=(batch_size,) + shape, minval=component.low, maxval=component.high, dtype=component.dtype, ) else: return tf.random.normal( shape=(batch_size,) + shape, dtype=component.dtype ) else: assert isinstance(component, Simplex), ( "Unsupported distribution component '{}' for random " "sampling!".format(component) ) return tf.nn.softmax( tf.random.uniform( shape=(batch_size,) + shape, minval=0.0, maxval=1.0, dtype=component.dtype, ) ) actions = tree.map_structure(random_component, self.action_space_struct) return actions def false_fn(): return action_dist.deterministic_sample() action = tf.cond( pred=tf.constant(explore, dtype=tf.bool) if isinstance(explore, bool) else explore, true_fn=true_fn, false_fn=false_fn, ) logp = zero_logps_from_actions(action) return action, logp def get_torch_exploration_action( self, action_dist: ActionDistribution, explore: bool ): if explore: req = force_tuple( action_dist.required_model_output_shape( self.action_space, getattr(self.model, "model_config", None) ) ) # Add a batch dimension? if len(action_dist.inputs.shape) == len(req) + 1: batch_size = action_dist.inputs.shape[0] a = np.stack([self.action_space.sample() for _ in range(batch_size)]) else: a = self.action_space.sample() # Convert action to torch tensor. action = torch.from_numpy(a).to(self.device) else: action = action_dist.deterministic_sample() logp = torch.zeros((action.size()[0],), dtype=torch.float32, device=self.device) return action, logp