import gymnasium as gym
import numpy as np
import tree # pip install dm_tree
import random
from typing import Union, Optional
from ray.rllib.models.torch.torch_action_dist import TorchMultiActionDistribution
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.utils.annotations import override, OldAPIStack
from ray.rllib.utils.exploration.exploration import Exploration, TensorType
from ray.rllib.utils.framework import try_import_tf, try_import_torch, get_variable
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.schedules import Schedule, PiecewiseSchedule
from ray.rllib.utils.torch_utils import FLOAT_MIN
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
@OldAPIStack
class EpsilonGreedy(Exploration):
"""Epsilon-greedy Exploration class that produces exploration actions.
When given a Model's output and a current epsilon value (based on some
Schedule), it produces a random action (if rand(1) < eps) or
uses the model-computed one (if rand(1) >= eps).
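    Example (a rough sketch of the per-timestep decision rule, not the exact
    implementation below):
        if explore and rand(1) < epsilon:
            action = action_space.sample()  # uniform random (valid) action
        else:
            action = argmax(q_values)  # greedy/exploit action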
"""

def __init__(
self,
action_space: gym.spaces.Space,
*,
framework: str,
initial_epsilon: float = 1.0,
final_epsilon: float = 0.05,
warmup_timesteps: int = 0,
epsilon_timesteps: int = int(1e5),
epsilon_schedule: Optional[Schedule] = None,
**kwargs,
):
"""Create an EpsilonGreedy exploration class.
Args:
action_space: The action space the exploration should occur in.
framework: The framework specifier.
initial_epsilon: The initial epsilon value to use.
final_epsilon: The final epsilon value to use.
            warmup_timesteps: The number of timesteps at the beginning over which
                epsilon is held constant at `initial_epsilon`.
            epsilon_timesteps: The number of timesteps (on top of `warmup_timesteps`)
                over which epsilon is annealed down to `final_epsilon`, after which
                it stays at that final value.
                E.g. warmup_timesteps=20k and epsilon_timesteps=50k -> epsilon
                reaches its final value after 70k timesteps.
epsilon_schedule: An optional Schedule object
to use (instead of constructing one from the given parameters).
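        Example (a sketch of the default schedule this constructor builds for
        warmup_timesteps=20000 and epsilon_timesteps=50000; interpolation is
        linear, the PiecewiseSchedule default):
            schedule = PiecewiseSchedule(
                endpoints=[(0, 1.0), (20000, 1.0), (70000, 0.05)],
                outside_value=0.05,
                framework=None,
            )
            schedule(0)       # 1.0 (still warming up)
            schedule(45000)   # 0.525 (halfway through the linear anneal)
            schedule(100000)  # 0.05 (past the last endpoint -> outside_value)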
"""
assert framework is not None
super().__init__(action_space=action_space, framework=framework, **kwargs)
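        # Build the epsilon schedule: either the user-provided one or a default
        # piecewise-linear schedule (hold at `initial_epsilon` during warmup,
        # then anneal linearly to `final_epsilon`).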
self.epsilon_schedule = from_config(
Schedule, epsilon_schedule, framework=framework
) or PiecewiseSchedule(
endpoints=[
(0, initial_epsilon),
(warmup_timesteps, initial_epsilon),
(warmup_timesteps + epsilon_timesteps, final_epsilon),
],
outside_value=final_epsilon,
framework=self.framework,
)
# The current timestep value (tf-var or python int).
self.last_timestep = get_variable(
np.array(0, np.int64),
framework=framework,
tf_name="timestep",
dtype=np.int64,
)
# Build the tf-info-op.
if self.framework == "tf":
self._tf_state_op = self.get_state()
@override(Exploration)
def get_exploration_action(
self,
*,
action_distribution: ActionDistribution,
timestep: Union[int, TensorType],
explore: Optional[Union[bool, TensorType]] = True,
):
if self.framework in ["tf2", "tf"]:
return self._get_tf_exploration_action_op(
action_distribution, explore, timestep
)
else:
return self._get_torch_exploration_action(
action_distribution, explore, timestep
)
def _get_tf_exploration_action_op(
self,
action_distribution: ActionDistribution,
explore: Union[bool, TensorType],
timestep: Union[int, TensorType],
) -> "tf.Tensor":
"""TF method to produce the tf op for an epsilon exploration action.
Args:
            action_distribution: The instantiated ActionDistribution object
                to work with when creating exploration actions.
            explore: Whether exploration is enabled. May be a python bool or
                a boolean tf tensor.
            timestep: The current (sampling) timestep; used to look up the
                current epsilon from the schedule.
Returns:
The tf exploration-action op.
"""
# TODO: Support MultiActionDistr for tf.
q_values = action_distribution.inputs
epsilon = self.epsilon_schedule(
timestep if timestep is not None else self.last_timestep
)
# Get the exploit action as the one with the highest logit value.
exploit_action = tf.argmax(q_values, axis=1)
batch_size = tf.shape(q_values)[0]
# Mask out actions with q-value=-inf so that we don't even consider
# them for exploration.
random_valid_action_logits = tf.where(
tf.equal(q_values, tf.float32.min),
tf.ones_like(q_values) * tf.float32.min,
tf.ones_like(q_values),
)
random_actions = tf.squeeze(
tf.random.categorical(random_valid_action_logits, 1), axis=1
)
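        # Per batch item, draw a uniform random number and compare it against
        # epsilon to decide whether to use the random or the greedy action.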
chose_random = (
tf.random.uniform(
tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32
)
< epsilon
)
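        # Only apply the epsilon-greedy mixing if `explore` is True; otherwise
        # always return the greedy (exploit) action.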
action = tf.cond(
pred=tf.constant(explore, dtype=tf.bool)
if isinstance(explore, bool)
else explore,
true_fn=(lambda: tf.where(chose_random, random_actions, exploit_action)),
false_fn=lambda: exploit_action,
)
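        # Update `self.last_timestep`: assign directly in eager mode; under
        # graph mode (or eager tracing), chain an assign op via control
        # dependencies.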
if self.framework == "tf2" and not self.policy_config["eager_tracing"]:
self.last_timestep = timestep
return action, tf.zeros_like(action, dtype=tf.float32)
else:
assign_op = tf1.assign(self.last_timestep, tf.cast(timestep, tf.int64))
with tf1.control_dependencies([assign_op]):
return action, tf.zeros_like(action, dtype=tf.float32)
def _get_torch_exploration_action(
self,
action_distribution: ActionDistribution,
explore: bool,
timestep: Union[int, TensorType],
) -> "torch.Tensor":
"""Torch method to produce an epsilon exploration action.
Args:
            action_distribution: The instantiated ActionDistribution object
                to work with when creating exploration actions.
            explore: Whether exploration is enabled. If False, the greedy
                (deterministic) action is always returned.
            timestep: The current (sampling) timestep; used to look up the
                current epsilon from the schedule.
Returns:
The exploration-action.
"""
q_values = action_distribution.inputs
self.last_timestep = timestep
exploit_action = action_distribution.deterministic_sample()
batch_size = q_values.size()[0]
action_logp = torch.zeros(batch_size, dtype=torch.float)
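        # Epsilon-greedy actions (random or greedy) get a dummy log-prob of 0.0.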
# Explore.
if explore:
# Get the current epsilon.
epsilon = self.epsilon_schedule(self.last_timestep)
if isinstance(action_distribution, TorchMultiActionDistribution):
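                # For (possibly nested) multi-action spaces: flatten the action
                # struct, overwrite single batch rows with uniform samples from
                # the action space with probability epsilon, then re-nest.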
exploit_action = tree.flatten(exploit_action)
for i in range(batch_size):
if random.random() < epsilon:
# TODO: (bcahlit) Mask out actions
random_action = tree.flatten(self.action_space.sample())
for j in range(len(exploit_action)):
exploit_action[j][i] = torch.tensor(random_action[j])
exploit_action = tree.unflatten_as(
action_distribution.action_space_struct, exploit_action
)
return exploit_action, action_logp
else:
                # Mask out actions whose Q-values are -inf so that we don't
                # even consider them for exploration.
random_valid_action_logits = torch.where(
q_values <= FLOAT_MIN,
torch.ones_like(q_values) * 0.0,
torch.ones_like(q_values),
)
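                # Note: these are sampling weights for `torch.multinomial`
                # (0.0 for masked actions, 1.0 otherwise), so the draw below is
                # uniform over all valid actions.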
# A random action.
random_actions = torch.squeeze(
torch.multinomial(random_valid_action_logits, 1), axis=1
)
# Pick either random or greedy.
action = torch.where(
torch.empty((batch_size,)).uniform_().to(self.device) < epsilon,
random_actions,
exploit_action,
)
return action, action_logp
# Return the deterministic "sample" (argmax) over the logits.
else:
return exploit_action, action_logp
@override(Exploration)
def get_state(self, sess: Optional["tf.Session"] = None):
if sess:
return sess.run(self._tf_state_op)
eps = self.epsilon_schedule(self.last_timestep)
return {
"cur_epsilon": convert_to_numpy(eps) if self.framework != "tf" else eps,
"last_timestep": convert_to_numpy(self.last_timestep)
if self.framework != "tf"
else self.last_timestep,
}
@override(Exploration)
def set_state(self, state: dict, sess: Optional["tf.Session"] = None) -> None:
if self.framework == "tf":
self.last_timestep.load(state["last_timestep"], session=sess)
elif isinstance(self.last_timestep, int):
self.last_timestep = state["last_timestep"]
else:
self.last_timestep.assign(state["last_timestep"])