Exploration API#
Exploration is crucial in RL: it lets a learning agent reach regions of the environment it has not visited yet and thereby discover new states that may yield high rewards.
RLlib comes with several built-in exploration components that its algorithms use. Users can also customize an algorithm’s exploration behavior by subclassing the Exploration base class (documented below) and implementing their own logic.
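As a quick orientation, exploration is typically switched on or off and parameterized through the algorithm’s config: the top-level explore flag toggles exploration globally, while exploration_config selects the Exploration class (“type”) and its constructor arguments. The following is a minimal, hedged sketch; the DQN-style setup, the CartPole-v1 environment, and the concrete values are assumptions for illustration only.

```python
# Hedged sketch: classic RLlib config dict (assumptions: DQN-style algorithm,
# CartPole-v1 environment; values are for illustration only).
config = {
    "env": "CartPole-v1",
    # Globally toggle exploration (compute_actions(..., explore=...) can still
    # override this per call).
    "explore": True,
    # Select and parameterize the Exploration class.
    "exploration_config": {
        "type": "EpsilonGreedy",      # built-in class name or full class path
        "initial_epsilon": 1.0,
        "final_epsilon": 0.05,
        "epsilon_timesteps": 100000,  # timesteps over which to anneal epsilon
    },
}
```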
Base Exploration class (ray.rllib.utils.exploration.exploration.Exploration)#
- class ray.rllib.utils.exploration.exploration.Exploration(action_space: gym.spaces.Space, *, framework: str, policy_config: dict, model: ray.rllib.models.modelv2.ModelV2, num_workers: int, worker_index: int)[source]#
Implements an exploration strategy for Policies.
An Exploration takes model outputs, a distribution, and a timestep from the agent and computes an action to apply to the environment using an implemented exploration schema.
- before_compute_actions(*, timestep: Optional[Union[numpy.array, tf.Tensor, torch.Tensor, int]] = None, explore: Optional[Union[numpy.array, tf.Tensor, torch.Tensor, bool]] = None, tf_sess: Optional[tf.Session] = None, **kwargs)[source]#
Hook for preparations before policy.compute_actions() is called.
- Parameters
timestep – An optional timestep tensor.
explore – An optional explore boolean flag.
tf_sess – The tf-session object to use.
**kwargs – Forward compatibility kwargs.
- get_exploration_action(*, action_distribution: ray.rllib.models.action_dist.ActionDistribution, timestep: Union[numpy.array, tf.Tensor, torch.Tensor, int], explore: bool = True)[source]#
Returns a (possibly) exploratory action and its log-likelihood.
Given the Model’s logits outputs and action distribution, returns an exploratory action.
- Parameters
action_distribution – The instantiated ActionDistribution object to work with when creating exploration actions.
timestep – The current sampling time step. It can be a tensor for TF graph mode, otherwise an integer.
explore – True: “Normal” exploration behavior. False: Suppress all exploratory behavior and return a deterministic action.
- Returns
A tuple consisting of 1) the chosen exploration action or a tf-op to fetch the exploration action from the graph and 2) the log-likelihood of the exploration action.
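A custom exploration strategy mainly needs to implement get_exploration_action. Below is a minimal, hedged sketch of such a subclass (the class name MyGreedyOnly is hypothetical, not an RLlib component); it relies only on the ActionDistribution methods deterministic_sample() and logp() and simply ignores the explore flag.

```python
# Hedged sketch of a custom Exploration subclass (hypothetical name).
from ray.rllib.utils.exploration.exploration import Exploration


class MyGreedyOnly(Exploration):
    """Always acts greedily, i.e. never explores (illustration only)."""

    def get_exploration_action(self, *, action_distribution, timestep, explore=True):
        # Deterministic action: the mode/argmax of the action distribution.
        action = action_distribution.deterministic_sample()
        # Log-likelihood of that action under the distribution.
        logp = action_distribution.logp(action)
        return action, logp
```

Such a class could then be selected via exploration_config, e.g. "type": "my_module.MyGreedyOnly" or the class object itself; this wiring mirrors how the built-in classes are configured and is stated here as an assumption.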
- on_episode_start(policy: Policy, *, environment: ray.rllib.env.base_env.BaseEnv = None, episode: int = None, tf_sess: Optional[tf.Session] = None)[source]#
Handles necessary exploration logic at the beginning of an episode.
- Parameters
policy – The Policy object that holds this Exploration.
environment – The environment object we are acting in.
episode – The number of the episode that is starting.
tf_sess – In case of tf, the session object.
- on_episode_end(policy: Policy, *, environment: ray.rllib.env.base_env.BaseEnv = None, episode: int = None, tf_sess: Optional[tf.Session] = None)[source]#
Handles necessary exploration logic at the end of an episode.
- Parameters
policy – The Policy object that holds this Exploration.
environment – The environment object we are acting in.
episode – The number of the episode that just ended.
tf_sess – In case of tf, the session object.
- postprocess_trajectory(policy: Policy, sample_batch: ray.rllib.policy.sample_batch.SampleBatch, tf_sess: Optional[tf.Session] = None)[source]#
Handles post-processing of done episode trajectories.
Changes the given batch in place. This callback is invoked by the sampler after policy.postprocess_trajectory() is called.
- Parameters
policy – The owning policy object.
sample_batch – The SampleBatch object to post-process.
tf_sess – An optional tf.Session object.
- get_exploration_optimizer(optimizers: List[Union[tf.keras.optimizers.Optimizer, torch.optim.Optimizer]]) List[Union[tf.keras.optimizers.Optimizer, torch.optim.Optimizer]] [source]#
May add optimizer(s) to the Policy’s own optimizers. The number of optimizers (the Policy’s plus the Exploration component’s) must match the number of loss terms produced by the Policy’s loss function and the Exploration component.
- Parameters
optimizers – The list of the Policy’s local optimizers.
- Returns
The updated list of local optimizers to use on the different loss terms.
- get_state(sess: Optional[tf.Session] = None) Dict[str, Union[numpy.array, tf.Tensor, torch.Tensor]] [source]#
Returns the current exploration state.
- Parameters
sess – An optional tf Session object to use.
- Returns
The Exploration object’s current state.
- set_state(state: object, sess: Optional[tf.Session] = None) None [source]#
Sets the Exploration object’s state to the given values.
Note that some exploration components are stateless even though they decay some values over time (e.g. EpsilonGreedy). However, the decay depends only on the policy’s current global timestep, so it does not need to be tracked as part of the state.
- Parameters
state – The state to set this Exploration to.
sess – An optional tf Session object to use.
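As a hedged usage sketch (assuming an already-built algorithm object algo whose default policy holds its exploration component under policy.exploration), the exploration state can be read and later restored:

```python
# Hedged sketch: `algo` is assumed to be a built RLlib algorithm/trainer.
policy = algo.get_policy()               # default (local) policy
state = policy.exploration.get_state()   # dict of current exploration values
# ... later, e.g. on a freshly created but identical policy:
policy.exploration.set_state(state)
```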
All built-in Exploration classes#
- class ray.rllib.utils.exploration.random.Random(action_space: gym.spaces.Space, *, model: ray.rllib.models.modelv2.ModelV2, framework: Optional[str], **kwargs)[source]#
A random action selector (deterministic/greedy for explore=False).
If explore=True, returns actions randomly from self.action_space (via Space.sample()). If explore=False, returns the greedy/max-likelihood action.
- __init__(action_space: gym.spaces.Space, *, model: ray.rllib.models.modelv2.ModelV2, framework: Optional[str], **kwargs)[source]#
Initialize a Random Exploration object.
- Parameters
action_space – The gym action space used by the environment.
framework – One of None, “tf”, “torch”.
- get_exploration_action(*, action_distribution: ray.rllib.models.action_dist.ActionDistribution, timestep: Union[int, numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor], explore: bool = True)[source]#
Returns a (possibly) exploratory action and its log-likelihood.
Given the Model’s logits outputs and action distribution, returns an exploratory action.
- Parameters
action_distribution – The instantiated ActionDistribution object to work with when creating exploration actions.
timestep – The current sampling time step. It can be a tensor for TF graph mode, otherwise an integer.
explore – True: “Normal” exploration behavior. False: Suppress all exploratory behavior and return a deterministic action.
- Returns
A tuple consisting of 1) the chosen exploration action or a tf-op to fetch the exploration action from the graph and 2) the log-likelihood of the exploration action.
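A hedged configuration sketch (the surrounding algorithm is an assumption): selecting Random makes the agent sample uniformly from the action space whenever explore=True.

```python
# Hedged sketch: act fully at random while exploring.
config = {
    "explore": True,
    "exploration_config": {"type": "Random"},
}
```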
- class ray.rllib.utils.exploration.stochastic_sampling.StochasticSampling(action_space: gym.spaces.Space, *, framework: str, model: ray.rllib.models.modelv2.ModelV2, random_timesteps: int = 0, **kwargs)[source]#
An exploration that simply samples from a distribution.
The sampling can be made deterministic by passing explore=False into the call to get_exploration_action. Also allows for scheduled parameters for the distributions, such as lowering stddev, temperature, etc. over time.
- __init__(action_space: gym.spaces.Space, *, framework: str, model: ray.rllib.models.modelv2.ModelV2, random_timesteps: int = 0, **kwargs)[source]#
Initializes a StochasticSampling Exploration object.
- Parameters
action_space – The gym action space used by the environment.
framework – One of None, “tf”, “torch”.
model – The ModelV2 used by the owning Policy.
random_timesteps – The number of timesteps for which to act completely randomly. Only after this many timesteps are actual samples drawn to get exploration actions.
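A hedged configuration sketch (the algorithm choice and the value are assumptions); random_timesteps delays distribution sampling by an initial fully random phase:

```python
# Hedged sketch: sample actions from the policy's action distribution,
# after an initial fully random phase of 1,000 timesteps.
config = {
    "explore": True,
    "exploration_config": {
        "type": "StochasticSampling",
        "random_timesteps": 1000,
    },
}
```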
- get_exploration_action(*, action_distribution: ray.rllib.models.action_dist.ActionDistribution, timestep: Optional[Union[numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor, int]] = None, explore: bool = True)[source]#
Returns a (possibly) exploratory action and its log-likelihood.
Given the Model’s logits outputs and action distribution, returns an exploratory action.
- Parameters
action_distribution – The instantiated ActionDistribution object to work with when creating exploration actions.
timestep – The current sampling time step. It can be a tensor for TF graph mode, otherwise an integer.
explore – True: “Normal” exploration behavior. False: Suppress all exploratory behavior and return a deterministic action.
- Returns
A tuple consisting of 1) the chosen exploration action or a tf-op to fetch the exploration action from the graph and 2) the log-likelihood of the exploration action.
- class ray.rllib.utils.exploration.epsilon_greedy.EpsilonGreedy(action_space: gym.spaces.Space, *, framework: str, initial_epsilon: float = 1.0, final_epsilon: float = 0.05, warmup_timesteps: int = 0, epsilon_timesteps: int = 100000, epsilon_schedule: Optional[ray.rllib.utils.schedules.schedule.Schedule] = None, **kwargs)[source]#
Epsilon-greedy Exploration class that produces exploration actions.
When given a Model’s output and a current epsilon value (based on some Schedule), it produces a random action (if rand(1) < eps) or uses the model-computed one (if rand(1) >= eps).
- __init__(action_space: gym.spaces.Space, *, framework: str, initial_epsilon: float = 1.0, final_epsilon: float = 0.05, warmup_timesteps: int = 0, epsilon_timesteps: int = 100000, epsilon_schedule: Optional[ray.rllib.utils.schedules.schedule.Schedule] = None, **kwargs)[source]#
Create an EpsilonGreedy exploration class.
- Parameters
action_space – The action space the exploration should occur in.
framework – The framework specifier.
initial_epsilon – The initial epsilon value to use.
final_epsilon – The final epsilon value to use.
warmup_timesteps – The timesteps over which to not change epsilon in the beginning.
epsilon_timesteps – The timesteps (additional to warmup_timesteps) after which epsilon should always be final_epsilon. E.g.: warmup_timesteps=20k, epsilon_timesteps=50k -> after 70k timesteps, epsilon has reached its final value.
epsilon_schedule – An optional Schedule object to use (instead of constructing one from the given parameters).
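A hedged configuration sketch (e.g. for a DQN-style algorithm; the concrete values are illustrations and the annealing behavior assumes the default schedule): with warmup_timesteps=20000 and epsilon_timesteps=50000, epsilon stays at 1.0 for the first 20k steps, is then annealed, and reaches final_epsilon after 70k steps in total.

```python
# Hedged sketch (values are illustrations only).
config = {
    "explore": True,
    "exploration_config": {
        "type": "EpsilonGreedy",
        "initial_epsilon": 1.0,
        "final_epsilon": 0.05,
        "warmup_timesteps": 20000,   # epsilon stays at 1.0 until step 20k
        "epsilon_timesteps": 50000,  # then anneals; final_epsilon from step 70k on
    },
}
```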
- get_exploration_action(*, action_distribution: ray.rllib.models.action_dist.ActionDistribution, timestep: Union[int, numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor], explore: Optional[Union[numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor, bool]] = True)[source]#
Returns a (possibly) exploratory action and its log-likelihood.
Given the Model’s logits outputs and action distribution, returns an exploratory action.
- Parameters
action_distribution – The instantiated ActionDistribution object to work with when creating exploration actions.
timestep – The current sampling time step. It can be a tensor for TF graph mode, otherwise an integer.
explore – True: “Normal” exploration behavior. False: Suppress all exploratory behavior and return a deterministic action.
- Returns
A tuple consisting of 1) the chosen exploration action or a tf-op to fetch the exploration action from the graph and 2) the log-likelihood of the exploration action.
- get_state(sess: Optional[tf.Session] = None)[source]#
Returns the current exploration state.
- Parameters
sess – An optional tf Session object to use.
- Returns
The Exploration object’s current state.
- set_state(state: dict, sess: Optional[tf.Session] = None) None [source]#
Sets the Exploration object’s state to the given values.
Note that some exploration components are stateless even though they decay some values over time (e.g. EpsilonGreedy). However, the decay depends only on the policy’s current global timestep, so it does not need to be tracked as part of the state.
- Parameters
state – The state to set this Exploration to.
sess – An optional tf Session object to use.
- class ray.rllib.utils.exploration.gaussian_noise.GaussianNoise(action_space: gym.spaces.Space, *, framework: str, model: ray.rllib.models.modelv2.ModelV2, random_timesteps: int = 1000, stddev: float = 0.1, initial_scale: float = 1.0, final_scale: float = 0.02, scale_timesteps: int = 10000, scale_schedule: Optional[ray.rllib.utils.schedules.schedule.Schedule] = None, **kwargs)[source]#
An exploration that adds white noise to continuous actions.
If explore=True, returns actions plus scale (annealed over time) x Gaussian noise. Also, some completely random period is possible at the beginning.
If explore=False, returns the deterministic action.
- __init__(action_space: gym.spaces.Space, *, framework: str, model: ray.rllib.models.modelv2.ModelV2, random_timesteps: int = 1000, stddev: float = 0.1, initial_scale: float = 1.0, final_scale: float = 0.02, scale_timesteps: int = 10000, scale_schedule: Optional[ray.rllib.utils.schedules.schedule.Schedule] = None, **kwargs)[source]#
Initializes a GaussianNoise instance.
- Parameters
random_timesteps – The number of timesteps for which to act completely randomly. Only after this number of timesteps does the self.scale annealing process start (see below).
stddev – The stddev (sigma) to use for the Gaussian noise to be added to the actions.
initial_scale – The initial scaling weight to multiply the noise with.
final_scale – The final scaling weight to multiply the noise with.
scale_timesteps – The timesteps over which to linearly anneal the scaling factor (after(!) having used random actions for random_timesteps steps).
scale_schedule – An optional Schedule object to use (instead of constructing one from the given parameters).
- get_exploration_action(*, action_distribution: ray.rllib.models.action_dist.ActionDistribution, timestep: Union[int, numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor], explore: bool = True)[source]#
Returns a (possibly) exploratory action and its log-likelihood.
Given the Model’s logits outputs and action distribution, returns an exploratory action.
- Parameters
action_distribution – The instantiated ActionDistribution object to work with when creating exploration actions.
timestep – The current sampling time step. It can be a tensor for TF graph mode, otherwise an integer.
explore – True: “Normal” exploration behavior. False: Suppress all exploratory behavior and return a deterministic action.
- Returns
A tuple consisting of 1) the chosen exploration action or a tf-op to fetch the exploration action from the graph and 2) the log-likelihood of the exploration action.
- get_state(sess: Optional[tf.Session] = None)[source]#
Returns the current scale value.
- Returns
The current scale value.
- Return type
Union[float,tf.Tensor[float]]
- set_state(state: dict, sess: Optional[tf.Session] = None) None [source]#
Sets the Exploration object’s state to the given values.
Note that some exploration components are stateless even though they decay some values over time (e.g. EpsilonGreedy). However, the decay depends only on the policy’s current global timestep, so it does not need to be tracked as part of the state.
- Parameters
state – The state to set this Exploration to.
sess – An optional tf Session object to use.
- class ray.rllib.utils.exploration.ornstein_uhlenbeck_noise.OrnsteinUhlenbeckNoise(action_space, *, framework: str, ou_theta: float = 0.15, ou_sigma: float = 0.2, ou_base_scale: float = 0.1, random_timesteps: int = 1000, initial_scale: float = 1.0, final_scale: float = 0.02, scale_timesteps: int = 10000, scale_schedule: Optional[ray.rllib.utils.schedules.schedule.Schedule] = None, **kwargs)[source]#
An exploration that adds Ornstein-Uhlenbeck noise to continuous actions.
If explore=True, returns sampled actions plus a noise term X, which changes according to the formula X_{t+1} = -theta*X_t + sigma*N[0, stddev], where theta, sigma, and stddev are constants. Also, some completely random period is possible at the beginning. If explore=False, returns the deterministic action.
- __init__(action_space, *, framework: str, ou_theta: float = 0.15, ou_sigma: float = 0.2, ou_base_scale: float = 0.1, random_timesteps: int = 1000, initial_scale: float = 1.0, final_scale: float = 0.02, scale_timesteps: int = 10000, scale_schedule: Optional[ray.rllib.utils.schedules.schedule.Schedule] = None, **kwargs)[source]#
Initializes an Ornstein-Uhlenbeck Exploration object.
- Parameters
action_space – The gym action space used by the environment.
ou_theta – The theta parameter of the Ornstein-Uhlenbeck process.
ou_sigma – The sigma parameter of the Ornstein-Uhlenbeck process.
ou_base_scale – A fixed scaling factor, by which all OU- noise is multiplied. NOTE: This is on top of the parent GaussianNoise’s scaling.
random_timesteps – The number of timesteps for which to act completely randomly. Only after this number of timesteps does the self.scale annealing process start (see below).
initial_scale – The initial scaling weight to multiply the noise with.
final_scale – The final scaling weight to multiply the noise with.
scale_timesteps – The timesteps over which to linearly anneal the scaling factor (after(!) having used random actions for random_timesteps steps).
scale_schedule – An optional Schedule object to use (instead of constructing one from the given parameters).
framework – One of None, “tf”, “torch”.
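A hedged configuration sketch (pairing with a DDPG-style continuous-control algorithm is an assumption; the values simply mirror the defaults listed above):

```python
# Hedged sketch: temporally correlated (OU) action noise.
config = {
    "explore": True,
    "exploration_config": {
        "type": "OrnsteinUhlenbeckNoise",
        "ou_theta": 0.15,
        "ou_sigma": 0.2,
        "ou_base_scale": 0.1,       # applied on top of the annealed scale
        "random_timesteps": 1000,
        "initial_scale": 1.0,
        "final_scale": 0.02,
        "scale_timesteps": 10000,
    },
}
```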
- get_state(sess: Optional[tf.Session] = None)[source]#
Returns the current scale value.
- Returns
The current scale value.
- Return type
Union[float,tf.Tensor[float]]
- set_state(state: dict, sess: Optional[tf.Session] = None) None [source]#
Sets the Exploration object’s state to the given values.
Note that some exploration components are stateless even though they decay some values over time (e.g. EpsilonGreedy). However, the decay depends only on the policy’s current global timestep, so it does not need to be tracked as part of the state.
- Parameters
state – The state to set this Exploration to.
sess – An optional tf Session object to use.
- class ray.rllib.utils.exploration.random_encoder.RE3(action_space: gym.spaces.Space, *, framework: str, model: ray.rllib.models.modelv2.ModelV2, embeds_dim: int = 128, encoder_net_config: Optional[dict] = None, beta: float = 0.2, beta_schedule: str = 'constant', rho: float = 0.1, k_nn: int = 50, random_timesteps: int = 10000, sub_exploration: Optional[Union[Dict[str, Any], type, str]] = None, **kwargs)[source]#
Random Encoder for Efficient Exploration.
Implementation of: [1] State entropy maximization with random encoders for efficient exploration. Seo, Chen, Shin, Lee, Abbeel, & Lee, (2021). arXiv preprint arXiv:2102.09430.
Estimates state entropy using a particle-based k-nearest neighbors (k-NN) estimator in the latent space. The state’s latent representation is calculated using an encoder with randomly initialized parameters.
The entropy of a state is treated as an intrinsic reward and added to the environment’s extrinsic reward for policy optimization. Entropy is calculated per batch; it does not take the distribution of the entire replay buffer into consideration.
- __init__(action_space: gym.spaces.Space, *, framework: str, model: ray.rllib.models.modelv2.ModelV2, embeds_dim: int = 128, encoder_net_config: Optional[dict] = None, beta: float = 0.2, beta_schedule: str = 'constant', rho: float = 0.1, k_nn: int = 50, random_timesteps: int = 10000, sub_exploration: Optional[Union[Dict[str, Any], type, str]] = None, **kwargs)[source]#
Initialize RE3.
- Parameters
action_space – The action space in which to explore.
framework – Supports “tf” only; this implementation does not support torch.
model – The policy’s model.
embeds_dim – The dimensionality of the observation embedding vectors in latent space.
encoder_net_config – Optional model configuration for the encoder network, producing embedding vectors from observations. This can be used to configure fcnet- or conv_net setups to properly process any observation space.
beta – Hyperparameter to choose between exploration and exploitation.
beta_schedule – Schedule to use for beta decay, one of “constant” or “linear_decay”.
rho – Beta decay factor, used for on-policy algorithms.
k_nn – Number of neighbours to use for k-NN entropy estimation.
random_timesteps – The number of timesteps to act completely randomly (see [1]).
sub_exploration – The config dict for the underlying Exploration to use (e.g. epsilon-greedy for DQN). If None, uses the FromSpecDict provided in the Policy’s default config.
- Raises
ValueError – If the input framework is Torch.
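A hedged configuration sketch (framework must be tf; pairing RE3 with an off-policy algorithm such as SAC, and the values shown, are assumptions):

```python
# Hedged sketch: intrinsic entropy reward via a random encoder (tf only).
config = {
    "framework": "tf",
    "explore": True,
    "exploration_config": {
        "type": "RE3",
        "embeds_dim": 128,            # latent size of the random encoder
        "beta": 0.2,                  # weight of the intrinsic reward
        "beta_schedule": "constant",
        "k_nn": 50,                   # neighbors for the k-NN entropy estimate
        "sub_exploration": {"type": "StochasticSampling"},
    },
}
```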
- get_exploration_action(*, action_distribution: ray.rllib.models.action_dist.ActionDistribution, timestep: Union[int, numpy.array, tf.Tensor, torch.Tensor], explore: bool = True)[source]#
Returns a (possibly) exploratory action and its log-likelihood.
Given the Model’s logits outputs and action distribution, returns an exploratory action.
- Parameters
action_distribution – The instantiated ActionDistribution object to work with when creating exploration actions.
timestep – The current sampling time step. It can be a tensor for TF graph mode, otherwise an integer.
explore – True: “Normal” exploration behavior. False: Suppress all exploratory behavior and return a deterministic action.
- Returns
A tuple consisting of 1) the chosen exploration action or a tf-op to fetch the exploration action from the graph and 2) the log-likelihood of the exploration action.
- class ray.rllib.utils.exploration.curiosity.Curiosity(action_space: gym.spaces.Space, *, framework: str, model: ray.rllib.models.modelv2.ModelV2, feature_dim: int = 288, feature_net_config: Optional[dict] = None, inverse_net_hiddens: Tuple[int] = (256,), inverse_net_activation: str = 'relu', forward_net_hiddens: Tuple[int] = (256,), forward_net_activation: str = 'relu', beta: float = 0.2, eta: float = 1.0, lr: float = 0.001, sub_exploration: Optional[Union[Dict[str, Any], type, str]] = None, **kwargs)[source]#
Implementation of: [1] Curiosity-driven Exploration by Self-supervised Prediction. Pathak, Agrawal, Efros, and Darrell. UC Berkeley, ICML 2017. https://arxiv.org/pdf/1705.05363.pdf
Learns a simplified model of the environment based on three networks: 1) Embedding observations into latent space (“feature” network). 2) Predicting the action, given two consecutive embedded observations (“inverse” network). 3) Predicting the next embedded obs, given an obs and action (“forward” network).
The less the agent is able to predict the actually observed next feature vector given obs and action (through the forward network), the larger the “intrinsic reward”, which is added to the extrinsic reward. Thus, if a state transition was unexpected, the agent becomes “curious” and further explores that transition, leading to better exploration in sparse-reward environments.
- __init__(action_space: gym.spaces.Space, *, framework: str, model: ray.rllib.models.modelv2.ModelV2, feature_dim: int = 288, feature_net_config: Optional[dict] = None, inverse_net_hiddens: Tuple[int] = (256,), inverse_net_activation: str = 'relu', forward_net_hiddens: Tuple[int] = (256,), forward_net_activation: str = 'relu', beta: float = 0.2, eta: float = 1.0, lr: float = 0.001, sub_exploration: Optional[Union[Dict[str, Any], type, str]] = None, **kwargs)[source]#
Initializes a Curiosity object.
Uses as defaults the hyperparameters described in [1].
- Parameters
feature_dim – The dimensionality of the feature (phi) vectors.
feature_net_config – Optional model configuration for the feature network, producing feature vectors (phi) from observations. This can be used to configure fcnet- or conv_net setups to properly process any observation space.
inverse_net_hiddens – Tuple of the layer sizes of the inverse (action predicting) NN head (on top of the feature outputs for phi and phi’).
inverse_net_activation – Activation specifier for the inverse net.
forward_net_hiddens – Tuple of the layer sizes of the forward (phi’ predicting) NN head.
forward_net_activation – Activation specifier for the forward net.
beta – Weight for the forward loss (over the inverse loss, which gets weight=1.0-beta) in the common loss term.
eta – Weight for intrinsic rewards before being added to extrinsic ones.
lr – The learning rate for the curiosity-specific optimizer, optimizing feature-, inverse-, and forward nets.
sub_exploration – The config dict for the underlying Exploration to use (e.g. epsilon-greedy for DQN). If None, uses the FromSpecDict provided in the Policy’s default config.
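A hedged configuration sketch (pairing Curiosity with a sparse-reward environment and an on-policy algorithm such as PPO is an assumption; the values follow the defaults listed above):

```python
# Hedged sketch: ICM-style curiosity with intrinsic rewards.
config = {
    "explore": True,
    "exploration_config": {
        "type": "Curiosity",
        "feature_dim": 288,                  # size of the phi embedding
        "inverse_net_hiddens": [256],
        "inverse_net_activation": "relu",
        "forward_net_hiddens": [256],
        "forward_net_activation": "relu",
        "beta": 0.2,                         # forward- vs. inverse-loss weight
        "eta": 1.0,                          # intrinsic-reward weight
        "lr": 0.001,                         # lr of the curiosity optimizer
        "sub_exploration": {"type": "StochasticSampling"},
    },
}
```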
- get_exploration_action(*, action_distribution: ray.rllib.models.action_dist.ActionDistribution, timestep: Union[int, numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor], explore: bool = True)[source]#
Returns a (possibly) exploratory action and its log-likelihood.
Given the Model’s logits outputs and action distribution, returns an exploratory action.
- Parameters
action_distribution – The instantiated ActionDistribution object to work with when creating exploration actions.
timestep – The current sampling time step. It can be a tensor for TF graph mode, otherwise an integer.
explore – True: “Normal” exploration behavior. False: Suppress all exploratory behavior and return a deterministic action.
- Returns
A tuple consisting of 1) the chosen exploration action or a tf-op to fetch the exploration action from the graph and 2) the log-likelihood of the exploration action.
- get_exploration_optimizer(optimizers)[source]#
May add optimizer(s) to the Policy’s own optimizers. The number of optimizers (the Policy’s plus the Exploration component’s) must match the number of loss terms produced by the Policy’s loss function and the Exploration component.
- Parameters
optimizers – The list of the Policy’s local optimizers.
- Returns
The updated list of local optimizers to use on the different loss terms.
- class ray.rllib.utils.exploration.parameter_noise.ParameterNoise(action_space, *, framework: str, policy_config: dict, model: ray.rllib.models.modelv2.ModelV2, initial_stddev: float = 1.0, random_timesteps: int = 10000, sub_exploration: Optional[dict] = None, **kwargs)[source]#
An exploration that changes a Model’s parameters.
Implemented based on: [1] https://blog.openai.com/better-exploration-with-parameter-noise/ [2] https://arxiv.org/pdf/1706.01905.pdf
At the beginning of an episode, Gaussian noise is added to all weights of the model. At the end of the episode, the noise is undone and an action diff (pi-delta) is calculated, from which we determine the changes in the noise’s stddev for the next episode.
- __init__(action_space, *, framework: str, policy_config: dict, model: ray.rllib.models.modelv2.ModelV2, initial_stddev: float = 1.0, random_timesteps: int = 10000, sub_exploration: Optional[dict] = None, **kwargs)[source]#
Initializes a ParameterNoise Exploration object.
- Parameters
initial_stddev – The initial stddev to use for the noise.
random_timesteps – The number of timesteps to act completely randomly (see [1]).
sub_exploration – Optional sub-exploration config. None for auto-detection/setup.
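A hedged configuration sketch (using ParameterNoise with a DQN-style algorithm is an assumption):

```python
# Hedged sketch: perturb model weights per episode instead of actions.
config = {
    "explore": True,
    "exploration_config": {
        "type": "ParameterNoise",
        "initial_stddev": 1.0,      # starting stddev of the weight noise
        "random_timesteps": 10000,  # fully random actions at the beginning
    },
}
```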
- before_compute_actions(*, timestep: Optional[int] = None, explore: Optional[bool] = None, tf_sess: Optional[tf.Session] = None)[source]#
Hook for preparations before policy.compute_actions() is called.
- Parameters
timestep – An optional timestep tensor.
explore – An optional explore boolean flag.
tf_sess – The tf-session object to use.
**kwargs – Forward compatibility kwargs.
- get_exploration_action(*, action_distribution: ray.rllib.models.action_dist.ActionDistribution, timestep: Union[numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor, int], explore: Union[numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor, bool])[source]#
Returns a (possibly) exploratory action and its log-likelihood.
Given the Model’s logits outputs and action distribution, returns an exploratory action.
- Parameters
action_distribution – The instantiated ActionDistribution object to work with when creating exploration actions.
timestep – The current sampling time step. It can be a tensor for TF graph mode, otherwise an integer.
explore – True: “Normal” exploration behavior. False: Suppress all exploratory behavior and return a deterministic action.
- Returns
A tuple consisting of 1) the chosen exploration action or a tf-op to fetch the exploration action from the graph and 2) the log-likelihood of the exploration action.
- on_episode_start(policy: Policy, *, environment: ray.rllib.env.base_env.BaseEnv = None, episode: int = None, tf_sess: Optional[tf.Session] = None)[source]#
Handles necessary exploration logic at the beginning of an episode.
- Parameters
policy – The Policy object that holds this Exploration.
environment – The environment object we are acting in.
episode – The number of the episode that is starting.
tf_sess – In case of tf, the session object.
- on_episode_end(policy, *, environment=None, episode=None, tf_sess=None)[source]#
Handles necessary exploration logic at the end of an episode.
- Parameters
policy – The Policy object that holds this Exploration.
environment – The environment object we are acting in.
episode – The number of the episode that just ended.
tf_sess – In case of tf, the session object.
- postprocess_trajectory(policy: Policy, sample_batch: ray.rllib.policy.sample_batch.SampleBatch, tf_sess: Optional[tf.Session] = None)[source]#
Handles post-processing of done episode trajectories.
Changes the given batch in place. This callback is invoked by the sampler after policy.postprocess_trajectory() is called.
- Parameters
policy – The owning policy object.
sample_batch – The SampleBatch object to post-process.
tf_sess – An optional tf.Session object.
- get_state(sess=None)[source]#
Returns the current exploration state.
- Parameters
sess – An optional tf Session object to use.
- Returns
The Exploration object’s current state.
- set_state(state: dict, sess: Optional[tf.Session] = None) None [source]#
Sets the Exploration object’s state to the given values.
Note that some exploration components are stateless even though they decay some values over time (e.g. EpsilonGreedy). However, the decay depends only on the policy’s current global timestep, so it does not need to be tracked as part of the state.
- Parameters
state – The state to set this Exploration to.
sess – An optional tf Session object to use.