Source code for ray.rllib.env.env_runner

import abc
import logging
from typing import Any, Dict, Tuple, TYPE_CHECKING

import gymnasium as gym
import tree  # pip install dm_tree

from ray.rllib.utils.actor_manager import FaultAwareApply
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.utils.typing import TensorType
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

logger = logging.getLogger("ray.rllib")

tf1, tf, _ = try_import_tf()

ENV_RESET_FAILURE = "env_reset_failure"
ENV_STEP_FAILURE = "env_step_failure"


# TODO (sven): As soon as RolloutWorker is no longer supported, make this base class
#  a Checkpointable. Currently, only some of its subclasses are Checkpointables.
@PublicAPI(stability="alpha")
class EnvRunner(FaultAwareApply, metaclass=abc.ABCMeta):
    """Base class for distributed RL-style data collection from an environment.

    The EnvRunner API's core functionalities can be summarized as:
    - Gets configured via passing an AlgorithmConfig object to the constructor.
      Normally, subclasses of EnvRunner then construct their own environment (possibly
      vectorized) copies and RLModules/Policies and use the latter to step through the
      environment in order to collect training data.
    - Clients of EnvRunner can use the `sample()` method to collect data for training
      from the environment(s).
    - EnvRunner offers parallelism via creating n remote Ray Actors based on this
      class. Use the `ray.remote([resources])(EnvRunner)` method to create the
      corresponding Ray remote class. Then instantiate n Actors using the Ray
      `[ctor].remote(...)` syntax.
    - EnvRunner clients can get information about the server/node on which the
      individual Actors are running.
    """
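    # Usage sketch (illustrative only; `MyEnvRunner` is a hypothetical concrete
    # subclass of EnvRunner): create n remote Ray Actors from this class and
    # sample from them in parallel, as described in the docstring above.
    #
    #     import ray
    #
    #     remote_cls = ray.remote(num_cpus=1)(MyEnvRunner)
    #     workers = [remote_cls.remote(config=config) for _ in range(4)]
    #     results = ray.get([w.sample.remote() for w in workers])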
    def __init__(self, *, config: "AlgorithmConfig", **kwargs):
        """Initializes an EnvRunner instance.

        Args:
            config: The AlgorithmConfig to use to set up this EnvRunner.
            **kwargs: Forward compatibility kwargs.
        """
        self.config = config.copy(copy_frozen=False)
        self.env = None

        super().__init__(**kwargs)

        # This eager check is necessary for certain all-framework tests
        # that use tf's eager_mode() context generator.
        if (
            tf1
            and (self.config.framework_str == "tf2" or config.enable_tf1_exec_eagerly)
            and not tf1.executing_eagerly()
        ):
            tf1.enable_eager_execution()
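    # Construction sketch (hedged example): building an EnvRunner subclass locally
    # from an AlgorithmConfig. `SingleAgentEnvRunner` is RLlib's single-agent
    # implementation; the exact `sample()` kwargs accepted may differ per subclass.
    #
    #     from ray.rllib.algorithms.ppo import PPOConfig
    #     from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner
    #
    #     config = PPOConfig().environment("CartPole-v1")
    #     runner = SingleAgentEnvRunner(config=config)
    #     episodes = runner.sample()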
    @abc.abstractmethod
    def assert_healthy(self):
        """Checks that self.__init__() has been completed properly.

        Useful in case an `EnvRunner` is run as a @ray.remote (Actor) and the owner
        would like to make sure the Ray Actor has been properly initialized.

        Raises:
            AssertionError: If the EnvRunner Actor has NOT been properly initialized.
        """
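    # Health-check sketch (illustrative): when running as a Ray Actor, the owner
    # can verify initialization remotely; `worker` is a hypothetical actor handle.
    #
    #     ray.get(worker.assert_healthy.remote())  # Raises AssertionError on failure.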
    # TODO: Make this an abstract method that must be implemented.
    def make_env(self):
        """Creates the RL environment for this EnvRunner and assigns it to `self.env`.

        Note that users should be able to change the EnvRunner's config (e.g. change
        `self.config.env_config`) and then call this method to create new environments
        with the updated configuration.
        It should also be called after a failure of an earlier env in order to clean
        up the existing env (for example `close()` it), re-create a new one, and then
        continue sampling with that new env.
        """
        pass
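    # Override sketch (hypothetical subclass): a minimal `make_env()` that cleans up
    # any previous env before re-creating one from the (possibly updated) config.
    #
    #     class MyEnvRunner(EnvRunner):
    #         def make_env(self):
    #             if self.env is not None:
    #                 self.env.close()
    #             self.env = gym.make(self.config.env)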
    @abc.abstractmethod
    def sample(self, **kwargs) -> Any:
        """Returns experiences (of any form) sampled from this EnvRunner.

        The exact nature and size of collected data are defined via the EnvRunner's
        config and may be overridden by the given arguments.

        Args:
            **kwargs: Forward compatibility kwargs.

        Returns:
            The collected experience in any form.
        """
    @abc.abstractmethod
    def get_spaces(self) -> Dict[str, Tuple[gym.Space, gym.Space]]:
        """Returns a dict mapping ModuleIDs to 2-tuples of obs- and action space.

        The returned dict might also contain an extra key `__env__`, which maps to
        a 2-tuple of the bare Env's observation- and action spaces.
        """
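    # Example return value (illustrative; the module ID shown assumes a default
    # single-module setup and may differ in multi-agent configurations):
    #
    #     {
    #         "__env__": (env_obs_space, env_action_space),
    #         "default_policy": (module_obs_space, module_action_space),
    #     }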
    def stop(self) -> None:
        """Releases all resources used by this EnvRunner.

        For example, when using a gym.Env in this EnvRunner, you should make sure
        that its `close()` method is called.
        """
        pass
    def __del__(self) -> None:
        """If this Actor is deleted, clears all resources used by it."""
        pass

    def _try_env_reset(self):
        """Tries resetting the env and - if an error occurs - handles it gracefully."""
        # Try to reset.
        try:
            obs, infos = self.env.reset()
            # Everything ok -> return.
            return obs, infos
        # Error.
        except Exception as e:
            # If the user wants to simply restart the env -> recreate the env and try
            # again (calling this method recursively until success).
            if self.config.restart_failed_sub_environments:
                logger.exception(
                    "Resetting the env resulted in an error! The original error "
                    f"is: {e.args[0]}"
                )
                # Recreate the env and simply try again.
                self.make_env()
                return self._try_env_reset()
            else:
                raise e

    def _try_env_step(self, actions):
        """Tries stepping the env and - if an error occurs - handles it gracefully."""
        try:
            results = self.env.step(actions)
            return results
        except Exception as e:
            if self.config.restart_failed_sub_environments:
                logger.exception(
                    "Stepping the env resulted in an error! The original error "
                    f"is: {e.args[0]}"
                )
                # Recreate the env.
                self.make_env()
                # And return that the stepping failed. The caller will then handle
                # specific cleanup operations (for example discarding thus-far
                # collected data and repeating the step attempt).
                return ENV_STEP_FAILURE
            else:
                raise e

    def _convert_to_tensor(self, struct) -> TensorType:
        """Converts structs to a framework-specific tensor."""
        if self.config.framework_str == "torch":
            return convert_to_torch_tensor(struct)
        else:
            return tree.map_structure(tf.convert_to_tensor, struct)
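# Fault-tolerance sketch (illustrative): enabling the env-restart behavior used by
# `_try_env_reset()` and `_try_env_step()` above. With this flag set, a crashing
# env is re-created via `make_env()` instead of the error being re-raised.
#
#     from ray.rllib.algorithms.ppo import PPOConfig
#
#     config = (
#         PPOConfig()
#         .environment("CartPole-v1")
#         .fault_tolerance(restart_failed_sub_environments=True)
#     )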