# Source code for ray.rllib.env.env_runner

import abc
from typing import Any, TYPE_CHECKING

import tree  # pip install dm_tree

from ray.rllib.utils.actor_manager import FaultAwareApply
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.utils.typing import TensorType
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

tf1, tf, _ = try_import_tf()


# TODO (sven): As soon as RolloutWorker is no longer supported, make this base class
#  a Checkpointable. Currently, only some of its subclasses are Checkpointables.
@PublicAPI(stability="alpha")
class EnvRunner(FaultAwareApply, metaclass=abc.ABCMeta):
    """Base class for distributed RL-style data collection from an environment.

    The EnvRunner API's core functionalities can be summarized as:

    - Gets configured via passing an AlgorithmConfig object to the constructor.
      Normally, subclasses of EnvRunner then construct their own environment
      (possibly vectorized) copies and RLModules/Policies and use the latter to
      step through the environment in order to collect training data.
    - Clients of EnvRunner can use the `sample()` method to collect data for
      training from the environment(s).
    - EnvRunner offers parallelism via creating n remote Ray Actors based on
      this class. Use `ray.remote([resources])(EnvRunner)` method to create the
      corresponding Ray remote class. Then instantiate n Actors using the Ray
      `[ctor].remote(...)` syntax.
    - EnvRunner clients can get information about the server/node on which the
      individual Actors are running.
    """

    def __init__(self, *, config: "AlgorithmConfig", **kwargs):
        """Initializes an EnvRunner instance.

        Args:
            config: The AlgorithmConfig to use to setup this EnvRunner. A copy
                (made via `config.copy(copy_frozen=False)`) is stored as
                `self.config`.
            **kwargs: Forward compatibility kwargs. These are passed on to the
                parent class' constructor.
        """
        self.config = config.copy(copy_frozen=False)
        super().__init__(**kwargs)

        # This eager check is necessary for certain all-framework tests
        # that use tf's eager_mode() context generator.
        if (
            tf1
            and (self.config.framework_str == "tf2" or config.enable_tf1_exec_eagerly)
            and not tf1.executing_eagerly()
        ):
            tf1.enable_eager_execution()

    @abc.abstractmethod
    def assert_healthy(self):
        """Checks that self.__init__() has been completed properly.

        Useful in case an `EnvRunner` is run as @ray.remote (Actor) and the
        owner would like to make sure the Ray Actor has been properly
        initialized.

        Raises:
            AssertionError: If the EnvRunner Actor has NOT been properly
                initialized.
        """

    @abc.abstractmethod
    def sample(self, **kwargs) -> Any:
        """Returns experiences (of any form) sampled from this EnvRunner.

        The exact nature and size of collected data are defined via the
        EnvRunner's config and may be overridden by the given arguments.

        Args:
            **kwargs: Forward compatibility kwargs.

        Returns:
            The collected experience in any form.
        """

    def stop(self) -> None:
        """Releases all resources used by this EnvRunner.

        For example, when using a gym.Env in this EnvRunner, you should make
        sure that its `close()` method is called.

        The base implementation does nothing.
        """
        pass

    def __del__(self) -> None:
        """If this Actor is deleted, clears all resources used by it.

        The base implementation does nothing.
        """
        pass

    def _convert_to_tensor(self, struct) -> TensorType:
        """Converts structs to a framework-specific tensor.

        Args:
            struct: The (possibly nested) structure to convert.

        Returns:
            The result of `convert_to_torch_tensor(struct)` if
            `self.config.framework_str` is "torch". Otherwise, `struct` with
            `tf.convert_to_tensor` mapped over it via `tree.map_structure`.

        Raises:
            ImportError: If the configured framework is not "torch" and the
                module-level tensorflow import yielded no `tf` module.
        """
        if self.config.framework_str == "torch":
            return convert_to_torch_tensor(struct)

        # `tf` is None when the module-level `try_import_tf()` could not provide
        # tensorflow. Fail with an actionable message instead of an opaque
        # "'NoneType' object has no attribute 'convert_to_tensor'".
        if tf is None:
            raise ImportError(
                "Cannot convert to a tensorflow tensor for framework "
                f"'{self.config.framework_str}': tensorflow could not be imported."
            )
        return tree.map_structure(tf.convert_to_tensor, struct)