Source code for ray.train._internal.session

import functools
import logging
import os
import platform
import queue
import sys
import threading
import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type

import ray
from ray.air._internal.util import RunnerThread, StartTraceback
from ray.air.constants import (
    _ERROR_FETCH_TIMEOUT,
    _RESULT_FETCH_TIMEOUT,
    SESSION_MISUSE_LOG_ONCE_KEY,
    TIME_THIS_ITER_S,
    TIMESTAMP,
)
from ray.data import Dataset
from ray.train import Checkpoint
from ray.train._internal.accelerator import Accelerator
from ray.train._internal.storage import StorageContext
from ray.train.constants import (
    CHECKPOINT_DIR_NAME,
    DETAILED_AUTOFILLED_KEYS,
    RAY_CHDIR_TO_TRIAL_DIR,
    TIME_TOTAL_S,
    WORKER_HOSTNAME,
    WORKER_NODE_IP,
    WORKER_PID,
    _v2_migration_warnings_enabled,
)
from ray.train.error import SessionMisuseError
from ray.train.utils import _log_deprecation_warning
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.util.debug import log_once
from ray.util.placement_group import _valid_resource_shape
from ray.util.scheduling_strategies import (
    PlacementGroupSchedulingStrategy,
    SchedulingStrategyT,
)

if TYPE_CHECKING:
    from ray.data import DataIterator
    from ray.tune.execution.placement_groups import PlacementGroupFactory


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


@dataclass
class TrialInfo:
    """The trial information to propagate to TrainSession.

    The fields are surfaced through ``_TrainSession`` properties
    (``trial_name``, ``trial_id``, ``trial_resources``, ``trial_dir``,
    ``experiment_name`` and ``run_id``).
    """

    # Trial name, returned by `_TrainSession.trial_name`.
    name: str
    # Trial id, returned by `_TrainSession.trial_id`.
    id: str
    # Returned by `_TrainSession.trial_resources`.
    # NOTE(review): annotated as a dict, but `_tune_task_and_actor_launch_hook`
    # reads `.head_bundle_is_empty` from it (via `get_trial_resources()`), so it
    # appears to actually hold a `PlacementGroupFactory` — confirm.
    resources: Dict[str, float]
    # Trial log directory, returned by `_TrainSession.trial_dir`.
    logdir: str
    # presumably the driver's node IP and node id — verify against caller.
    driver_ip: str
    driver_node_id: str
    # Returned by `_TrainSession.experiment_name`.
    experiment_name: Optional[str] = None
    # Returned by `_TrainSession.run_id`.
    run_id: Optional[str] = None


class _FutureTrainingResult:
    """Wraps an ``ObjectRef`` that will eventually hold a ``_TrainingResult``.

    Some schedulers (e.g. PBT) schedule saves and therefore need a handle to a
    result that is not available yet.

    This wrapper should be removed after refactoring PBT to not schedule saves anymore.
    """

    def __init__(self, future: ray.ObjectRef):
        self.future = future

    def resolve(self, block: bool = True) -> Optional["_TrainingResult"]:
        """Fetch the wrapped value.

        Args:
            block: If True, wait until the value is available. If False, use
                a tiny timeout so the call returns almost immediately.

        Returns:
            The resolved value. Returns None if it is not ready yet (non-blocking
            mode), if resolving failed (the error is logged), or if the object
            itself is None (e.g. function trainables that never saved a
            checkpoint).
        """
        timeout = None if block else 1e-9
        try:
            return ray.get(self.future, timeout=timeout)
        except TimeoutError:
            # Not ready, yet
            return None
        except Exception as exc:
            logger.error(f"Error resolving result: {exc}")
            return None


class _TrainingResult:
    """Pairs an optional checkpoint with the metrics dict reported alongside it."""

    def __init__(self, checkpoint: Optional[Checkpoint], metrics: Dict[str, Any]):
        self.checkpoint = checkpoint
        self.metrics = metrics

    def __repr__(self) -> str:
        fields = f"checkpoint={self.checkpoint}, metrics={self.metrics}"
        return f"TrainingResult({fields})"


# TODO(xwjiang): This needs a better name.
@DeveloperAPI
class _TrainSession:
    """Holds information for training on each worker.

    The user's ``training_func`` runs in a ``RunnerThread``. Results reported
    from that thread reach the main thread through ``result_queue`` (max size
    1). Errors travel through ``error_queue`` (max size 1). The training thread
    is paused and resumed with the ``continue_lock`` semaphore.
    """

    def __init__(
        self,
        training_func: Callable,
        world_rank: Optional[int],
        local_rank: Optional[int],
        node_rank: Optional[int],
        local_world_size: Optional[int],
        world_size: Optional[int],
        trial_info: Optional[TrialInfo] = None,
        dataset_shard: Optional[Dict[str, Dataset]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        checkpoint: Optional[Checkpoint] = None,
        detailed_autofilled_metrics: bool = False,
        storage: Optional[StorageContext] = None,
        synchronous_result_reporting: bool = False,
    ):
        # `synchronous_result_reporting` controls when the training thread is
        # unblocked after the main thread fetches one of its results
        # (see `get_next`):
        # - False: the training thread is released right after its result is
        #   fetched, so it continues onto the next iteration immediately.
        #   Ex: for 2 Ray Train workers, the worker that produces a result
        #   first will continue first.
        # - True: the training thread stays blocked in `train.report` until
        #   the next explicit call to `get_next`. Ex: Tune function Trainables
        #   need this so that Tune schedulers can stop the execution of the
        #   training function at will, for advanced pausing schedulers
        #   (PBT, BOHB) and actor reuse.
        self.synchronous_result_reporting = synchronous_result_reporting

        # Ray Train worker properties
        # Note: These are set to None for Tune function Trainables.
        self.dataset_shard = dataset_shard
        self.metadata = metadata

        self.world_rank = world_rank
        self.local_rank = local_rank
        self.node_rank = node_rank
        self.local_world_size = local_world_size
        self.world_size = world_size

        # A StorageContext is required: `reset` uses it immediately.
        assert storage
        logger.debug(f"StorageContext on SESSION (rank={world_rank}):\n{storage}")

        # NOTE: `reset` will initialize many properties needed to start running the
        # training_func as a thread, including the `_state` scratch dict.
        self.reset(
            training_func=training_func,
            trial_info=trial_info,
            storage=storage,
            loaded_checkpoint=checkpoint,
        )

        # Autofilled metrics attributes.
        self.detailed_autofilled_metrics = detailed_autofilled_metrics
        self.last_report_time = time.time()
        self.iteration = 0
        self.time_total = 0.0
        self.local_ip = self.get_current_ip()

        self.accelerator = None

    def get_state(self, key: str) -> Any:
        """Return the scratch-state value stored under ``key`` (None if unset)."""
        return self._state.get(key)

    def set_state(self, key: str, value: Any):
        """Store ``value`` under ``key`` in the scratch state (cleared by ``reset``)."""
        self._state[key] = value

    def get_current_ip(self):
        """Look up this node's IP, cache it on ``self.local_ip`` and return it."""
        self.local_ip = ray.util.get_node_ip_address()
        return self.local_ip

    def start(self):
        """Starts the training thread."""
        self.training_started = True
        self.training_thread.start()

    def reset(
        self,
        training_func: Callable,
        trial_info: Optional[TrialInfo],
        storage: StorageContext,
        loaded_checkpoint: Optional[Checkpoint] = None,
    ):
        """(Re)initialize per-run state so ``training_func`` can be started.

        Creates fresh synchronization primitives, queues and training thread.
        Replaces the trial info, storage and loaded checkpoint, and clears the
        scratch state. Also changes the working directory to
        ``storage.trial_working_directory`` unless the environment variable
        named by ``RAY_CHDIR_TO_TRIAL_DIR`` is set to ``0``.
        """
        # This lock is used to control the execution of the training thread.
        self.continue_lock = threading.Semaphore(0)

        # This event is used to signal the training thread to stop.
        self.stop_event = threading.Event()

        # Queue for sending results across threads.
        self.result_queue = queue.Queue(1)

        # Queue for raising exceptions from runner thread to main thread.
        # The error queue has a max size of one to prevent stacking error and force
        # error reporting to block until finished.
        self.error_queue = queue.Queue(1)

        # The Thread object that is running the training function.
        self.training_thread = RunnerThread(
            target=training_func, daemon=True, error_queue=self.error_queue
        )

        # Possibly override with new state
        self.trial_info = trial_info
        self.storage = storage
        self.loaded_checkpoint = loaded_checkpoint

        # Reset state
        self._state = {}
        self.ignore_report = False
        self.training_started = False
        self._first_report = True

        # Change the working directory to a special trial folder.
        # This is to ensure that all Ray Train workers have a common working directory.
        os.makedirs(storage.trial_working_directory, exist_ok=True)
        if bool(int(os.environ.get(RAY_CHDIR_TO_TRIAL_DIR, "1"))):
            logger.debug(
                f"Changing the working directory to: {storage.trial_working_directory}"
            )
            os.chdir(storage.trial_working_directory)

    def pause_reporting(self):
        """Ignore all future ``session.report()`` calls."""
        self.ignore_report = True

    def finish(self, timeout: Optional[float] = None) -> Optional[Any]:
        """Finishes the training thread.

        Sets the stop event, releases the training thread and forces a final
        artifact sync. Then, if training was started, joins the thread.

        Args:
            timeout: Passed to the training thread's ``join``.

        Returns:
            The output of joining the training thread (the result of the
            training function), or None if training was never started.

        Raises:
            Any Exception from training.
        """
        # Set the stop event for the training thread to gracefully exit.
        self.stop_event.set()

        # Release the lock so that training thread can process this event.
        self.continue_lock.release()

        # Force a final (blocking) sync of artifacts in the trial path to storage.
        self.storage.persist_artifacts(force=True)

        # Wait for training to finish.
        # This will raise any errors that occur during training, including SystemError
        # This returns the result of the training function.
        output = None
        if self.training_started:
            output = self.training_thread.join(timeout=timeout)

        return output

    def get_next(self) -> Optional[_TrainingResult]:
        """Gets the next ``_TrainingResult`` from the result queue.

        While the training thread is alive, this waits for a result. Once the
        thread has exited and no result remains, any error queued by the
        training thread is raised. Otherwise, this returns ``None``.

        Raises:
            RuntimeError: if called before ``start``.
        """
        if not self.training_started:
            raise RuntimeError("Please call start before calling get_next.")

        if self.synchronous_result_reporting:
            # There's no need to release the lock on the first report
            # since `start` already started the training thread.
            if not self._first_report:
                # Release the lock to trigger training to continue,
                # until the next call to report.
                self.continue_lock.release()
            self._first_report = False

        result = None
        # While training is still ongoing, attempt to get the result.
        while result is None and self.training_thread.is_alive():
            try:
                result = self.result_queue.get(
                    block=True, timeout=_RESULT_FETCH_TIMEOUT
                )
            except queue.Empty:
                pass

        # If no result was found, then the runner must no longer be alive.
        if result is None:
            # Try one last time to fetch results in case results were
            # reported in between the time of the last check and the
            # termination of the thread runner.
            try:
                result = self.result_queue.get_nowait()
            except queue.Empty:
                pass

        # check if error occurred inside the thread runner.
        if result is None:
            # only raise an error from the runner if all results are consumed
            self._report_thread_runner_error(block=True)
        else:
            if not self.error_queue.empty():
                logger.debug(
                    (
                        "Runner error waiting to be raised in main thread. "
                        "Logging all available results first."
                    )
                )

        if not self.synchronous_result_reporting:
            # At this point, the training thread has reached
            # the `train.report` and is blocked there.
            # If performing asynchronous result reporting,
            # release the lock to allow each worker to keep training
            # immediately after the coordinator fetches their result.
            self.continue_lock.release()

        # Return None if there are no more results to fetch.
        return result

    def _auto_fill_metrics(self, result: dict) -> dict:
        """Return a copy of ``result`` with autofilled metrics added.

        Also advances ``iteration``, ``time_total`` and ``last_report_time``.
        A user-provided ``TIME_THIS_ITER_S`` is used as the iteration time if
        present. Otherwise, the iteration time is measured since the last report.
        Detailed keys are dropped unless ``detailed_autofilled_metrics`` is set.
        """
        current_time = time.time()
        current_datetime = datetime.now()
        if TIME_THIS_ITER_S in result:
            time_this_iter = result[TIME_THIS_ITER_S]
        else:
            time_this_iter = current_time - self.last_report_time
        self.iteration += 1
        self.time_total += time_this_iter
        self.last_report_time = current_time

        auto_filled_metrics = {
            TIMESTAMP: int(time.mktime(current_datetime.timetuple())),
            TIME_TOTAL_S: self.time_total,
            WORKER_PID: os.getpid(),
            WORKER_HOSTNAME: platform.node(),
            WORKER_NODE_IP: self.local_ip,
        }

        if not self.detailed_autofilled_metrics:
            auto_filled_metrics = {
                k: v
                for k, v in auto_filled_metrics.items()
                if k not in DETAILED_AUTOFILLED_KEYS
            }

        result = result.copy()
        result.update(auto_filled_metrics)
        return result

    def _auto_fill_checkpoint_metrics(self, result: dict) -> dict:
        """Return a copy of ``result`` with only the ``TIMESTAMP`` metric added."""
        current_datetime = datetime.now()

        auto_filled_metrics = {
            TIMESTAMP: int(time.mktime(current_datetime.timetuple()))
        }
        result = result.copy()
        result.update(auto_filled_metrics)
        return result

    def _report_thread_runner_error(self, block=False):
        """Re-raise an error queued by the training thread, if there is one.

        Args:
            block: If True, wait up to ``_ERROR_FETCH_TIMEOUT`` for an error.

        Raises:
            StartTraceback: chained from the training thread's exception.
        """
        try:
            e = self.error_queue.get(block=block, timeout=_ERROR_FETCH_TIMEOUT)
            raise StartTraceback from e
        except queue.Empty:
            pass

    def _report_training_result(self, training_result: _TrainingResult) -> None:
        """Place a training result on the result queue for the main thread to process,
        then block until the main thread signals that training should continue.

        NOTE: This is used internally to report results from Train to Tune
        without persisting checkpoints to storage 2 times.
        `report` is the public API that directly persists to storage, which
        should only be called by user code.
        """
        if training_result.checkpoint:
            # NOTE: This populates `train.get_checkpoint`
            self.loaded_checkpoint = training_result.checkpoint

        # Add result to a thread-safe queue.
        self.result_queue.put(training_result, block=True)

        # Acquire lock to stop the training thread until main thread
        # triggers resume.
        self.continue_lock.acquire()

        # If the trial should be terminated, exit gracefully.
        # NOTE: This is only really useful if `synchronous_result_reporting=True`.
        # Otherwise, the lock is immediately released on reporting, and this
        # check is skipped before the main thread decides to set the stop event.
        if self.stop_event.is_set():
            self.stop_event.clear()
            sys.exit(0)

    def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
        """Autofill ``metrics``, persist ``checkpoint`` and hand the result off.

        Blocks (via ``_report_training_result``) until the main thread lets
        training continue. This is a no-op after ``pause_reporting``.

        Raises:
            ValueError: if ``metrics`` contains Torch tensors.
        """
        # Special case: early fail for Torch tensors
        if "torch" in sys.modules:
            from ray.air._internal.torch_utils import contains_tensor

            if contains_tensor(metrics):
                raise ValueError(
                    "Passing objects containg Torch tensors as metrics "
                    "is not supported as it will throw an exception on "
                    "deserialization. You can either convert the tensors "
                    "to Python objects or report a `train.Checkpoint` "
                    "with `ray.train.report` to store your Torch objects."
                )

        if self.ignore_report:
            return

        metrics = self._auto_fill_metrics(metrics)

        persisted_checkpoint = None
        if checkpoint:
            self.storage._update_checkpoint_index(metrics)

            # Persist the reported checkpoint files to storage.
            persisted_checkpoint = self.storage.persist_current_checkpoint(checkpoint)

            metrics[CHECKPOINT_DIR_NAME] = self.storage.checkpoint_dir_name
        else:
            metrics[CHECKPOINT_DIR_NAME] = None

        # Persist trial artifacts to storage. Force a sync only when a checkpoint
        # was persisted and the sync config asks for it. Coerce to bool so `force`
        # never receives the checkpoint object itself.
        force_artifact_sync = bool(
            persisted_checkpoint
            and self.storage.sync_config.sync_artifacts_on_checkpoint
        )
        self.storage.persist_artifacts(force=force_artifact_sync)

        # Set additional user metadata from the Trainer.
        if persisted_checkpoint and self.metadata:
            user_metadata = persisted_checkpoint.get_metadata()
            for k, v in self.metadata.items():
                # Update keys not already set by the user. This gives user-set keys
                # precedence over keys set at the Trainer level.
                if k not in user_metadata:
                    user_metadata[k] = v
            persisted_checkpoint.set_metadata(user_metadata)

        result = _TrainingResult(checkpoint=persisted_checkpoint, metrics=metrics)

        self._report_training_result(result)

    @property
    def experiment_name(self) -> str:
        return self.trial_info.experiment_name

    @property
    def trial_name(self) -> str:
        return self.trial_info.name

    @property
    def trial_id(self) -> str:
        return self.trial_info.id

    @property
    def run_id(self) -> str:
        return self.trial_info.run_id

    @property
    def trial_resources(self) -> "PlacementGroupFactory":
        return self.trial_info.resources

    @property
    def trial_dir(self) -> str:
        return self.trial_info.logdir

    def get_dataset_shard(
        self,
        dataset_name: Optional[str] = None,
    ) -> Optional["DataIterator"]:
        """Return this worker's dataset shard.

        Args:
            dataset_name: Required when the shard is a dict of named datasets.

        Returns:
            The shard (or the named entry of a dict of shards, None if that
            name is missing). Returns None, with a warning, if no dataset was
            passed.

        Raises:
            RuntimeError: if the shard is a dict and no ``dataset_name`` is given.
        """
        shard = self.dataset_shard
        if shard is None:
            warnings.warn(
                "No dataset passed in. Returning None. Make sure to "
                "pass in a Dataset to Trainer.run to use this "
                "function."
            )
        elif isinstance(shard, dict):
            if not dataset_name:
                raise RuntimeError(
                    "Multiple datasets were passed into ``Trainer``, "
                    "but no ``dataset_name`` is passed into "
                    "``get_dataset_shard``. Please specify which "
                    "dataset shard to retrieve."
                )
            return shard.get(dataset_name)
        return shard


# Cache of resource dicts that have been checked by the launch hook already.
# Each key is a frozenset of the positive (resource, amount) pairs of a request.
_checked_resources: Set[frozenset] = set()

# Global _TrainSession object initialized by Ray Tune function trainables
# and Ray Train V1 workers.
# Set by `init_session`, returned by `get_session`, cleared by `shutdown_session`.
_session: Optional[_TrainSession] = None


def _tune_task_and_actor_launch_hook(
    fn, resources: Dict[str, float], strategy: Optional[SchedulingStrategyT]
):
    """Launch hook to catch nested tasks that can't fit in the placement group.

    Installed by ``init_session`` unless ``TUNE_DISABLE_RESOURCE_CHECKS`` is
    set. If a nested task or actor is launched into the current trial's
    placement group but none of the available bundles can fit its request,
    this raises an error with a suggested ``resources_per_trial`` instead of
    letting the launch wait for resources that were never reserved.

    Args:
        fn: Descriptor of the function/actor being launched. Its
            ``module_name``, ``class_name`` and ``function_name`` are used for
            the error message.
        resources: The requested resources.
        strategy: The scheduling strategy of the launch request.

    Raises:
        RuntimeError: if the request targets the current placement group and
            cannot be satisfied by its available bundles.
    """

    # Already checked, skip for performance reasons.
    key = frozenset({(k, v) for k, v in resources.items() if v > 0})
    if not key or key in _checked_resources:
        return

    # No need to check if placement group is None.
    if (
        not isinstance(strategy, PlacementGroupSchedulingStrategy)
        or strategy.placement_group is None
    ):
        return

    # Check if the resource request is targeting the current placement group.
    cur_pg = ray.util.get_current_placement_group()
    if not cur_pg or strategy.placement_group.id != cur_pg.id:
        return

    # The hooks are not uninstalled by `shutdown_session`, so there may be no
    # active session (`get_trial_resources` then returns None). Nothing to check
    # against in that case; don't cache the key so it is checked once a session
    # exists.
    pgf = get_trial_resources()
    if pgf is None:
        return

    # NOTE: the key is cached before validation, so a given resource shape is
    # validated (and can raise) at most once per process.
    _checked_resources.add(key)

    # A non-empty head bundle is excluded from the bundles available to nested
    # launches (presumably reserved for the trial itself).
    if pgf.head_bundle_is_empty:
        available_bundles = cur_pg.bundle_specs[0:]
    else:
        available_bundles = cur_pg.bundle_specs[1:]

    # Check if the request can be fulfilled by the current placement group.
    if _valid_resource_shape(resources, available_bundles):
        return

    if fn.class_name:
        submitted = "actor"
        name = f"{fn.module_name}.{fn.class_name}.{fn.function_name}"
    else:
        submitted = "task"
        name = f"{fn.module_name}.{fn.function_name}"

    # Normalize the resource spec so it looks the same as the placement group bundle.
    main_resources = cur_pg.bundle_specs[0]
    requested = {k: float(v) for k, v in resources.items() if v > 0}

    raise RuntimeError(
        f"No trial resources are available for launching the {submitted} `{name}`. "
        "To resolve this, specify the Tune option:\n\n"
        ">  resources_per_trial=tune.PlacementGroupFactory(\n"
        f">    [{main_resources}] + [{requested}] * N\n"
        ">  )\n\n"
        f"Where `N` is the number of slots to reserve for trial {submitted}s. "
        "If you are using a Ray training library, there might be a utility function "
        "to set this automatically for you. For more information, refer to "
        "https://docs.ray.io/en/latest/tune/tutorials/tune-resources.html"
    )


def init_session(*args, **kwargs) -> None:
    """Create the process-global ``_TrainSession``.

    All arguments are forwarded to ``_TrainSession``. Unless the
    ``TUNE_DISABLE_RESOURCE_CHECKS`` environment variable is present, this also
    installs ``_tune_task_and_actor_launch_hook`` as both the actor and the
    task launch hook.

    Raises:
        ValueError: if a session has already been initialized.
    """
    global _session
    if _session:
        raise ValueError(
            "A Train session is already in use. Do not call "
            "`init_session()` manually."
        )

    # Setup hooks for generating placement group resource deadlock warnings.
    from ray import actor, remote_function

    checks_disabled = "TUNE_DISABLE_RESOURCE_CHECKS" in os.environ
    if not checks_disabled:
        hook = _tune_task_and_actor_launch_hook
        actor._actor_launch_hook = hook
        remote_function._task_launch_hook = hook

    _session = _TrainSession(*args, **kwargs)


def get_session() -> Optional[_TrainSession]:
    """Return the process-global session, or ``None`` if none is active."""
    active = _session
    return active


def shutdown_session():
    """Shuts down the initialized session by clearing the process-global slot.

    NOTE(review): the launch hooks installed by ``init_session`` are left in place.
    """
    global _session
    _session = None


def _raise_accelerator_session_misuse():
    """Always raise ``SessionMisuseError`` for accelerator utilities used outside training."""
    message = (
        "prepare/accelerate utility functions should be called inside a training "
        "function executed by `Trainer.run`"
    )
    raise SessionMisuseError(message)


def get_accelerator(default_accelerator_cls: Type[Accelerator]) -> Accelerator:
    """Return the session's accelerator, creating a default one if unset.

    Args:
        default_accelerator_cls: Class instantiated (with no arguments) and
            stored on the session when no accelerator has been set yet.

    Raises:
        SessionMisuseError: if the session is uninitialized.
    """
    session = get_session()
    if session is None:
        _raise_accelerator_session_misuse()

    existing = session.accelerator
    if existing is not None:
        return existing

    # First use: build the default accelerator and cache it on the session.
    session.accelerator = default_accelerator_cls()
    return session.accelerator


def set_accelerator(accelerator: Accelerator) -> None:
    """Install ``accelerator`` on the current session (only allowed once).

    Args:
        accelerator: The accelerator to use for training.

    Raises:
        SessionMisuseError: if the session is uninitialized.
        RuntimeError: if the accelerator has already been set.
    """
    session = get_session()
    if session is None:
        _raise_accelerator_session_misuse()

    already_set = session.accelerator is not None
    if already_set:
        raise RuntimeError("Cannot change accelerator once set.")
    session.accelerator = accelerator


def _warn_session_misuse(default_value: Any = None):
    """Decorator factory guarding functions that need an active session.

    The wrapped function runs normally when a session exists. Otherwise it is
    not called: a warning is emitted (once per function name, via
    ``log_once``) and ``default_value`` is returned instead.
    """

    def decorator(fn: Callable):
        name = fn.__name__

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if get_session():
                return fn(*args, **kwargs)

            if log_once(f"{SESSION_MISUSE_LOG_ONCE_KEY}-{name}"):
                warnings.warn(
                    f"`{name}` is meant to only be "
                    "called inside a function that is executed by a Tuner"
                    f" or Trainer. Returning `{default_value}`."
                )
            return default_value

        return wrapper

    return decorator


@PublicAPI(stability="stable")
@_warn_session_misuse()
def report(
    metrics: Dict,
    *,
    checkpoint: Optional[Checkpoint] = None,
    checkpoint_dir_name: Optional[str] = None,
) -> None:
    """Report metrics and optionally save a checkpoint.

    If a checkpoint is provided, it will be
    :ref:`persisted to storage <persistent-storage-guide>`.

    If this is called in multiple distributed training workers:

    - Only the metrics reported by the rank 0 worker will be tracked by Ray Train.
      See :ref:`the metrics logging guide <train-monitoring-and-logging>`.
    - A checkpoint will be registered as long as one or more workers reports
      a checkpoint that is not None.
      See the :ref:`checkpointing guide <train-dl-saving-checkpoints>`.
    - Checkpoints from multiple workers will be merged into one directory
      in persistent storage.
      See :ref:`the distributed checkpointing guide <train-distributed-checkpointing>`.

    .. note::

        Each invocation of this method will automatically increment the underlying
        ``training_iteration`` number. The physical meaning of this "iteration" is
        defined by user depending on how often they call ``report``.
        It does not necessarily map to one epoch.

    .. warning::

        All workers must call `ray.train.report` the same number of times
        so that Ray Train can properly synchronize the training state across
        workers. Otherwise, your training will hang.

    .. warning::

        This method does NOT act as a barrier for distributed training workers.
        Workers will upload their checkpoint, then continue training immediately.
        If you need to synchronize workers, you can use a framework-native barrier
        such as `torch.distributed.barrier()`.

    Example:

        .. testcode::

            import tempfile

            from ray import train
            from ray.train import Checkpoint
            from ray.train.torch import TorchTrainer


            def train_func(config):
                start_epoch = 0
                checkpoint = train.get_checkpoint()
                if checkpoint:
                    with checkpoint.as_directory() as checkpoint_dir:
                        # Load back training state
                        ...

                for epoch in range(start_epoch, config.get("num_epochs", 10)):
                    # Do training...

                    metrics = {"loss": ...}

                    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                        # Save the checkpoint...
                        # torch.save(...)

                        checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

                        # Example: Only the rank 0 worker uploads the checkpoint.
                        if train.get_context().get_world_rank() == 0:
                            train.report(metrics, checkpoint=checkpoint)
                        else:
                            train.report(metrics, checkpoint=None)

            trainer = TorchTrainer(
                train_func, scaling_config=train.ScalingConfig(num_workers=2)
            )

    Args:
        metrics: The metrics you want to report.
        checkpoint: The optional checkpoint you want to report.
        checkpoint_dir_name: Only supported by the new Ray Train implementation
            (``RAY_TRAIN_V2_ENABLED=1``). Here it is ignored with a warning.
    """
    if checkpoint_dir_name is not None:
        logger.warning(
            "`checkpoint_dir_name` is only supported in the new Ray Train "
            "implementation, which can be enabled with `RAY_TRAIN_V2_ENABLED=1`. "
            "This argument will be ignored."
        )

    # If we are running in a Tune function, switch to `ray.tune.report`.
    from ray.tune.trainable.trainable_fn_utils import _in_tune_session

    if _in_tune_session():
        import ray.tune

        if _v2_migration_warnings_enabled():
            _log_deprecation_warning(
                "`ray.train.report` should be switched to "
                "`ray.tune.report` when running in a function "
                "passed to Ray Tune. This will be an error in the future. "
                "See this issue for more context: "
                "https://github.com/ray-project/ray/issues/49454"
            )
        return ray.tune.report(metrics, checkpoint=checkpoint)

    get_session().report(metrics, checkpoint=checkpoint)
@PublicAPI(stability="stable")
@_warn_session_misuse()
def get_checkpoint() -> Optional[Checkpoint]:
    """Access the latest reported checkpoint to resume from if one exists.

    Example:

        .. testcode::

            import tempfile

            from ray import train
            from ray.train import Checkpoint
            from ray.train.torch import TorchTrainer


            def train_func(config):
                start_epoch = 0
                checkpoint = train.get_checkpoint()
                if checkpoint:
                    with checkpoint.as_directory() as checkpoint_dir:
                        # Load back training state
                        ...

                for epoch in range(start_epoch, config.get("num_epochs", 10)):
                    # Do training...

                    metrics = {"loss": ...}

                    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                        # Save the checkpoint...

                        checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
                        train.report(metrics, checkpoint=checkpoint)

            trainer = TorchTrainer(
                train_func, scaling_config=train.ScalingConfig(num_workers=2)
            )

    Returns:
        The session's latest checkpoint: the one it was started from, or a
        checkpoint reported since then. Returns None if there is none. Inside a
        Ray Tune function, this defers to ``ray.tune.get_checkpoint()``.
    """
    from ray.tune.trainable.trainable_fn_utils import _in_tune_session

    # Plain Train session: answer directly from the session.
    if not _in_tune_session():
        return get_session().loaded_checkpoint

    # Running in a Tune function: switch to `ray.tune.get_checkpoint`.
    import ray.tune

    if _v2_migration_warnings_enabled():
        _log_deprecation_warning(
            "`ray.train.get_checkpoint` should be switched to "
            "`ray.tune.get_checkpoint` when running in a function "
            "passed to Ray Tune. This will be an error in the future. "
            "See this issue for more context: "
            "https://github.com/ray-project/ray/issues/49454"
        )
    return ray.tune.get_checkpoint()
def _require_train_session_attr(attr_name: str, api_name: str) -> Any:
    """Read ``attr_name`` from the current session, failing if it is missing.

    Shared by the rank/size accessors below, which are only meaningful for a
    session object that exposes the requested attribute (a ``TrainSession``).

    Args:
        attr_name: Attribute to read from the object returned by ``get_session()``.
        api_name: Public function name quoted in the error message.

    Returns:
        The value of ``attr_name`` on the current session.

    Raises:
        RuntimeError: If the current session has no attribute ``attr_name``.
    """
    session = get_session()
    if not hasattr(session, attr_name):
        # Each literal piece ends with a space so the concatenated message
        # reads "... function that is passed ..." rather than "functionthat".
        raise RuntimeError(
            f"`{api_name}` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return getattr(session, attr_name)


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_metadata() -> Dict[str, Any]:
    """User metadata dict passed to the Trainer constructor."""
    return get_session().metadata


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_experiment_name() -> str:
    """Experiment name for the corresponding trial."""
    return get_session().experiment_name


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_name() -> str:
    """Trial name for the corresponding trial."""
    return get_session().trial_name


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_id() -> str:
    """Trial id for the corresponding trial."""
    return get_session().trial_id


@PublicAPI(stability="alpha")
@_warn_session_misuse()
def get_run_id() -> str:
    """Unique Train Run id for the corresponding trial."""
    return get_session().run_id


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_resources() -> "PlacementGroupFactory":
    """Trial resources for the corresponding trial."""
    return get_session().trial_resources


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_dir() -> str:
    """Log directory corresponding to the trial directory for a Tune session.

    If calling from a Train session, this will give the trial directory of its
    parent Tune session.

    .. testcode::

        import ray.tune

        def train_func(config):
            print(ray.tune.get_context().get_trial_dir())

        tuner = ray.tune.Tuner(train_func)
        tuner.fit()

    .. testoutput::
        :options: +MOCK

        /Users/root/ray_results/train_func_2023-07-19_15-01-37/train_func_d620c_00000_0_2023-07-19_15-01-40
    """
    return get_session().trial_dir


@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=1)
def get_world_size() -> int:
    """Get the current world size (i.e. total number of workers) for this run.

    .. testcode::

        import ray
        from ray import train
        from ray.train import ScalingConfig
        from ray.train.tensorflow import TensorflowTrainer

        NUM_WORKERS = 2

        def train_loop_per_worker(config):
            assert train.get_context().get_world_size() == NUM_WORKERS

        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
        trainer = TensorflowTrainer(
            train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=NUM_WORKERS),
            datasets={"train": train_dataset}
        )
        trainer.fit()

    .. testoutput::
        :hide:

        ...

    Raises:
        RuntimeError: If the current session does not expose ``world_size``
            (i.e. it is not a TrainSession).
    """
    return _require_train_session_attr("world_size", "get_world_size")


@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_world_rank() -> int:
    """Get the world rank of this worker.

    .. testcode::

        import ray
        from ray import train
        from ray.train import ScalingConfig
        from ray.train.tensorflow import TensorflowTrainer

        def train_loop_per_worker(config):
            if train.get_context().get_world_rank() == 0:
                print("Worker 0")

        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
        trainer = TensorflowTrainer(
            train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=2),
            datasets={"train": train_dataset}
        )
        trainer.fit()

    .. testoutput::
        :hide:

        ...

    Raises:
        RuntimeError: If the current session does not expose ``world_rank``
            (i.e. it is not a TrainSession).
    """
    return _require_train_session_attr("world_rank", "get_world_rank")


@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_local_rank() -> int:
    """Get the local rank of this worker (rank of the worker on its node).

    .. testcode::

        import torch

        import ray
        from ray import train
        from ray.train import ScalingConfig
        from ray.train.torch import TorchTrainer

        def train_loop_per_worker(config):
            if torch.cuda.is_available():
                torch.cuda.set_device(train.get_context().get_local_rank())
            ...

        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
        trainer = TorchTrainer(
            train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
            datasets={"train": train_dataset}
        )
        trainer.fit()

    .. testoutput::
        :hide:

        ...

    Raises:
        RuntimeError: If the current session does not expose ``local_rank``
            (i.e. it is not a TrainSession).
    """
    return _require_train_session_attr("local_rank", "get_local_rank")


# NOTE(review): default_value=0 here differs from get_world_size's
# default_value=1; confirm which fallback is intended outside a session.
@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_local_world_size() -> int:
    """Get the local world size of this node (i.e. number of workers on this node).

    Example:

        .. testcode::

            import ray
            from ray import train
            from ray.train import ScalingConfig
            from ray.train.torch import TorchTrainer

            def train_loop_per_worker():
                print(train.get_context().get_local_world_size())

            train_dataset = ray.data.from_items(
                [{"x": x, "y": x + 1} for x in range(32)])
            trainer = TorchTrainer(train_loop_per_worker,
                scaling_config=ScalingConfig(num_workers=1),
                datasets={"train": train_dataset})
            trainer.fit()

        .. testoutput::
            :hide:

            ...

    Raises:
        RuntimeError: If the current session does not expose
            ``local_world_size`` (i.e. it is not a TrainSession).
    """
    return _require_train_session_attr("local_world_size", "get_local_world_size")


@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_node_rank() -> int:
    """Get the rank of this node.

    Example:

        .. testcode::

            import ray
            from ray import train
            from ray.train import ScalingConfig
            from ray.train.torch import TorchTrainer

            def train_loop_per_worker():
                print(train.get_context().get_node_rank())

            train_dataset = ray.data.from_items(
                [{"x": x, "y": x + 1} for x in range(32)])
            trainer = TorchTrainer(train_loop_per_worker,
                scaling_config=ScalingConfig(num_workers=1),
                datasets={"train": train_dataset})
            trainer.fit()

        .. testoutput::
            :hide:

            ...

    Raises:
        RuntimeError: If the current session does not expose ``node_rank``
            (i.e. it is not a TrainSession).
    """
    return _require_train_session_attr("node_rank", "get_node_rank")
@PublicAPI(stability="stable")
@_warn_session_misuse()
def get_dataset_shard(
    dataset_name: Optional[str] = None,
) -> Optional["DataIterator"]:
    """Returns the :class:`ray.data.DataIterator` shard for this worker.

    Call :meth:`~ray.data.DataIterator.iter_torch_batches` or
    :meth:`~ray.data.DataIterator.to_tf` on this shard to convert it to the
    appropriate framework-specific data type.

    .. testcode::

        import ray
        from ray import train
        from ray.train import ScalingConfig
        from ray.train.torch import TorchTrainer

        def train_loop_per_worker(config):
            ...
            for epoch in range(2):
                # Trainer will automatically handle sharding.
                data_shard = train.get_dataset_shard("train")
                for batch in data_shard.iter_torch_batches():
                    ...

        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
        trainer = TorchTrainer(
            train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=2),
            datasets={"train": train_dataset}
        )
        trainer.fit()

    .. testoutput::
        :hide:

        ...

    Args:
        dataset_name: If a Dictionary of Datasets was passed to ``Trainer``, then
            specifies which dataset shard to return.

    Returns:
        The ``DataIterator`` shard to use for this worker.
        If no dataset is passed into Trainer, then return None.

    Raises:
        RuntimeError: If the current session does not provide
            ``get_dataset_shard`` (i.e. it is not a TrainSession).
    """
    session = get_session()
    if not hasattr(session, "get_dataset_shard"):
        # Each literal piece ends with a space so the concatenated message
        # reads "... function that is passed ..." rather than "functionthat".
        raise RuntimeError(
            "`get_dataset_shard` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.get_dataset_shard(dataset_name)
@DeveloperAPI
@_warn_session_misuse()
def get_storage() -> StorageContext:
    """Return the storage context attached to the active session.

    The returned :class:`~ray.train._internal.storage.StorageContext` gives
    advanced access to the filesystem and paths configured through `RunConfig`.

    NOTE: This is a developer API, and the `StorageContext` interface may change
    without notice between minor versions.
    """
    current_session = get_session()
    return current_session.storage