import functools
import logging
import os
import platform
import queue
import sys
import threading
import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type
import ray
from ray.air._internal.util import RunnerThread, StartTraceback
from ray.air.constants import (
    _ERROR_FETCH_TIMEOUT,
    _RESULT_FETCH_TIMEOUT,
    SESSION_MISUSE_LOG_ONCE_KEY,
    TIME_THIS_ITER_S,
    TIMESTAMP,
)
from ray.data import Dataset
from ray.train import Checkpoint
from ray.train._internal.accelerator import Accelerator
from ray.train._internal.storage import StorageContext
from ray.train.constants import (
    CHECKPOINT_DIR_NAME,
    DETAILED_AUTOFILLED_KEYS,
    RAY_CHDIR_TO_TRIAL_DIR,
    TIME_TOTAL_S,
    WORKER_HOSTNAME,
    WORKER_NODE_IP,
    WORKER_PID,
    _v2_migration_warnings_enabled,
)
from ray.train.error import SessionMisuseError
from ray.train.utils import _log_deprecation_warning
from ray.util import queue as ray_queue
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.util.debug import log_once
from ray.util.placement_group import _valid_resource_shape
from ray.util.scheduling_strategies import (
    PlacementGroupSchedulingStrategy,
    SchedulingStrategyT,
)
if TYPE_CHECKING:
    from ray.data import DataIterator
    from ray.tune.execution.placement_groups import PlacementGroupFactory
logger = logging.getLogger(__name__)
@dataclass
class TrialInfo:
    """The trial information to propagate to TrainSession."""
    name: str
    id: str
    resources: Dict[str, float]
    logdir: str
    driver_ip: str
    driver_node_id: str
    experiment_name: Optional[str] = None
    run_id: Optional[str] = None
class _FutureTrainingResult:
    """A future that will be resolved to a `_TrainingResult`.
    This is needed for specific schedulers such as PBT that schedule saves.
    This wrapper should be removed after refactoring PBT to not schedule saves anymore.
    """
    def __init__(self, future: ray.ObjectRef):
        self.future = future
    def resolve(self, block: bool = True) -> Optional["_TrainingResult"]:
        """Resolve into ``_TrainingResult``.
        This will return None for function trainables if no checkpoint has been
        saved before.
        """
        if block:
            timeout = None
        else:
            timeout = 1e-9
        try:
            return ray.get(self.future, timeout=timeout)
        except TimeoutError:
            # Not ready yet.
            pass
        except Exception as exc:
            logger.error(f"Error resolving result: {exc}")
class _TrainingResult:
    """A (checkpoint, metrics) result reported by the user."""
    def __init__(self, checkpoint: Optional[Checkpoint], metrics: Dict[str, Any]):
        self.checkpoint = checkpoint
        self.metrics = metrics
    def __repr__(self) -> str:
        return f"TrainingResult(checkpoint={self.checkpoint}, metrics={self.metrics})"
# TODO(xwjiang): This needs a better name.
@DeveloperAPI
class _TrainSession:
    """Holds information for training on each worker."""
    def __init__(
        self,
        training_func: Callable,
        world_rank: Optional[int],
        local_rank: Optional[int],
        node_rank: Optional[int],
        local_world_size: Optional[int],
        world_size: Optional[int],
        trial_info: Optional[TrialInfo] = None,
        dataset_shard: Optional[Dict[str, Dataset]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        checkpoint: Optional[Checkpoint] = None,
        detailed_autofilled_metrics: bool = False,
        storage: Optional[StorageContext] = None,
        synchronous_result_reporting: bool = False,
    ):
        # `synchronous_result_reporting` refers to whether or not the
        # training function is immediately unblocked to continue running
        # after the main thread receives its result.
        # Ex 1: For 2 Ray Train workers with `synchronous_result_reporting=False`,
        # the worker that produces a result first will immediately continue
        # onto the next iteration.
        # Ex 2: For a Tune function Trainable with `synchronous_result_reporting=True`,
        # training will only continue with an explicit call to `session.get_next`.
        # Synchronous reporting in example 2 is needed for Tune schedulers to
        # be able to stop the execution of the training function at will,
        # for advanced pausing schedulers (PBT, BOHB) and actor reuse.
        self.synchronous_result_reporting = synchronous_result_reporting
        # Ray Train worker properties
        # Note: These are set to None for Tune function Trainables.
        self.dataset_shard = dataset_shard
        self.metadata = metadata
        self.world_rank = world_rank
        self.local_rank = local_rank
        self.node_rank = node_rank
        self.local_world_size = local_world_size
        self.world_size = world_size
        assert storage
        logger.debug(f"StorageContext on SESSION (rank={world_rank}):\n{storage}")
        # NOTE: `reset` will initialize many properties needed to start running the
        # training_func as a thread.
        self.reset(
            training_func=training_func,
            trial_info=trial_info,
            storage=storage,
            loaded_checkpoint=checkpoint,
        )
        # Autofilled metrics attributes.
        self.detailed_autofilled_metrics = detailed_autofilled_metrics
        self.last_report_time = time.time()
        self.iteration = 0
        self.time_total = 0.0
        self.local_ip = self.get_current_ip()
        self.accelerator = None
        self._state = {}
    def get_state(self, key: str) -> Any:
        return self._state.get(key)
    def set_state(self, key: str, value: Any):
        self._state[key] = value
    def get_current_ip(self):
        self.local_ip = ray.util.get_node_ip_address()
        return self.local_ip
    def start(self):
        """Starts the training thread."""
        self.training_started = True
        self.training_thread.start()
    def reset(
        self,
        training_func: Callable,
        trial_info: TrialInfo,
        storage: StorageContext,
        loaded_checkpoint=None,
    ):
        # This lock is used to control the execution of the training thread.
        self.continue_lock = threading.Semaphore(0)
        # This event is used to signal the training thread to stop.
        self.stop_event = threading.Event()
        # Queue for sending results across threads.
        self.result_queue = queue.Queue(1)
        # Queue for sending results from training actor to main thread.
        self._inter_actor_queue: Optional[ray_queue.Queue[Dict]] = None
        # Queue for raising exceptions from runner thread to main thread.
        # The error queue has a max size of one to prevent stacking errors and to
        # force error reporting to block until finished.
        self.error_queue = queue.Queue(1)
        # The Thread object that is running the training function.
        self.training_thread = RunnerThread(
            target=training_func, daemon=True, error_queue=self.error_queue
        )
        # Possibly override with new state
        self.trial_info = trial_info
        self.storage = storage
        self.loaded_checkpoint = loaded_checkpoint
        # Reset state
        self._state = {}
        self.ignore_report = False
        self.training_started = False
        self._first_report = True
        # Change the working directory to a special trial folder.
        # This is to ensure that all Ray Train workers have a common working directory.
        os.makedirs(storage.trial_working_directory, exist_ok=True)
        if bool(int(os.environ.get(RAY_CHDIR_TO_TRIAL_DIR, "1"))):
            logger.debug(
                f"Changing the working directory to: {storage.trial_working_directory}"
            )
            os.chdir(storage.trial_working_directory)
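    # Usage note (an illustrative sketch, not executed as part of this module):
    # setting the environment variable checked above before launching training
    # keeps workers in their original working directory instead of the trial
    # working directory. The script name is hypothetical:
    #
    #   RAY_CHDIR_TO_TRIAL_DIR=0 python my_training_script.py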
    def pause_reporting(self):
        """Ignore all future ``session.report()`` calls."""
        self.ignore_report = True
    def finish(self, timeout: Optional[float] = None) -> Optional[Any]:
        """Finishes the training thread.
        Raises any Exception from training.
        """
        # Set the stop event for the training thread to gracefully exit.
        self.stop_event.set()
        # Release the lock so that training thread can process this event.
        self.continue_lock.release()
        # Force a final (blocking) sync of artifacts in the trial path to storage.
        self.storage.persist_artifacts(force=True)
        # Wait for training to finish.
        # This will raise any errors that occur during training, including SystemError
        # This returns the result of the training function.
        output = None
        if self.training_started:
            output = self.training_thread.join(timeout=timeout)
        return output
    def get_next(self) -> Optional[_TrainingResult]:
        """Gets the next ``_TrainingResult`` from the result queue.
        If the result queue is empty, then this function returns ``None``.
        """
        if not self.training_started:
            raise RuntimeError("Please call start before calling get_next.")
        if self.synchronous_result_reporting:
            # There's no need to release the lock on the first report
            # since `start` already started the training thread.
            if not self._first_report:
                # Release the lock to trigger training to continue,
                # until the next call to report.
                self.continue_lock.release()
            self._first_report = False
        result = None
        # While training is still ongoing, attempt to get the result.
        while result is None and self.training_thread.is_alive():
            result = self._get_result_from_queues(block=True)
        # If no result was found, then the runner must no longer be alive.
        if result is None:
            # Try one last time to fetch results in case results were
            # reported in between the time of the last check and the
            # termination of the thread runner.
            result = self._get_result_from_queues(block=False)
        # check if error occurred inside the thread runner.
        if result is None:
            # only raise an error from the runner if all results are consumed
            self._report_thread_runner_error(block=True)
        else:
            if not self.error_queue.empty():
                logger.debug(
                    (
                        "Runner error waiting to be raised in main thread. "
                        "Logging all available results first."
                    )
                )
        if not self.synchronous_result_reporting:
            # At this point, the training thread has reached
            # the `train.report` and is blocked there.
            # If performing asynchronous result reporting,
            # release the lock to allow each worker to keep training
            # immediately after the coordinator fetches their result.
            self.continue_lock.release()
        # Return None if there are no more results to fetch.
        return result
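    # Illustrative sketch (not executed here) of the coordinator-side protocol
    # that `start`, `get_next`, and `finish` implement, assuming `session` is an
    # already-constructed `_TrainSession`:
    #
    #   session.start()
    #   while True:
    #       training_result = session.get_next()
    #       if training_result is None:
    #           # The training function finished and all results were consumed.
    #           break
    #       ...  # Process the (checkpoint, metrics) pair.
    #   session.finish()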
    def _get_or_create_inter_actor_queue(self):
        """Get or create the inter-actor queue."""
        if self._inter_actor_queue is None:
            self._inter_actor_queue = ray_queue.Queue(1, actor_options={"num_cpus": 0})
        return self._inter_actor_queue
    def _get_result_from_queues(self, block: bool) -> Optional[_TrainingResult]:
        """Get result from result queue. Pass result from training actor result queue if needed."""
        result = None
        if self._inter_actor_queue is not None:
            try:
                inter_actor_item = self._inter_actor_queue.get(
                    block=block, timeout=_RESULT_FETCH_TIMEOUT
                )
                if inter_actor_item:
                    # Must release continue_lock to allow report to work.
                    self.continue_lock.release()
                    self.report(inter_actor_item)
            except ray_queue.Empty:
                pass
        try:
            result = self.result_queue.get(block=block, timeout=_RESULT_FETCH_TIMEOUT)
        except queue.Empty:
            pass
        return result
    def _auto_fill_metrics(self, result: dict) -> dict:
        """Add autofilled metrics and update attributes."""
        current_time = time.time()
        current_datetime = datetime.now()
        if TIME_THIS_ITER_S in result:
            time_this_iter = result[TIME_THIS_ITER_S]
        else:
            time_this_iter = current_time - self.last_report_time
        self.iteration += 1
        self.time_total += time_this_iter
        self.last_report_time = current_time
        auto_filled_metrics = {
            TIMESTAMP: int(time.mktime(current_datetime.timetuple())),
            TIME_TOTAL_S: self.time_total,
            WORKER_PID: os.getpid(),
            WORKER_HOSTNAME: platform.node(),
            WORKER_NODE_IP: self.local_ip,
        }
        if not self.detailed_autofilled_metrics:
            auto_filled_metrics = {
                k: v
                for k, v in auto_filled_metrics.items()
                if k not in DETAILED_AUTOFILLED_KEYS
            }
        result = result.copy()
        result.update(auto_filled_metrics)
        return result
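    # Illustrative sketch (not executed here) of what `_auto_fill_metrics` adds;
    # the reported value is hypothetical:
    #
    #   result = session._auto_fill_metrics({"loss": 0.5})
    #   # `result` now also carries the auto-filled keys built above (TIMESTAMP,
    #   # TIME_TOTAL_S, worker PID/hostname/IP), except that keys listed in
    #   # DETAILED_AUTOFILLED_KEYS are dropped unless
    #   # `detailed_autofilled_metrics` is enabled.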
    def _auto_fill_checkpoint_metrics(self, result: dict) -> dict:
        """Add autofilled metrics and update attributes."""
        current_datetime = datetime.now()
        auto_filled_metrics = {
            TIMESTAMP: int(time.mktime(current_datetime.timetuple()))
        }
        result = result.copy()
        result.update(auto_filled_metrics)
        return result
    def _report_thread_runner_error(self, block=False):
        try:
            e = self.error_queue.get(block=block, timeout=_ERROR_FETCH_TIMEOUT)
            raise StartTraceback from e
        except queue.Empty:
            pass
    def _report_training_result(self, training_result: _TrainingResult) -> None:
        """Place a training result on the result queue for the main thread to process,
        then block until the main thread signals that training should continue.
        NOTE: This is used internally to report results from Train to Tune
        without persisting checkpoints to storage twice.
        `report` is the public API that directly persists to storage, which
        should only be called by user code.
        """
        if training_result.checkpoint:
            # NOTE: This populates `train.get_checkpoint`
            self.loaded_checkpoint = training_result.checkpoint
        # Add result to a thread-safe queue.
        self.result_queue.put(training_result, block=True)
        # Acquire lock to stop the training thread until main thread
        # triggers resume.
        self.continue_lock.acquire()
        # If the trial should be terminated, exit gracefully.
        # NOTE: This is only really useful if `synchronous_result_reporting=True`.
        # Otherwise, the lock is immediately released on reporting, and this
        # check is skipped before the main thread decides to set the stop event.
        if self.stop_event.is_set():
            self.stop_event.clear()
            sys.exit(0)
    def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
        # Special case: early fail for Torch tensors
        if "torch" in sys.modules:
            from ray.air._internal.torch_utils import contains_tensor
            if contains_tensor(metrics):
                raise ValueError(
                    "Passing objects containg Torch tensors as metrics "
                    "is not supported as it will throw an exception on "
                    "deserialization. You can either convert the tensors "
                    "to Python objects or report a `train.Checkpoint` "
                    "with `ray.train.report` to store your Torch objects."
                )
        if self.ignore_report:
            return
        metrics = self._auto_fill_metrics(metrics)
        persisted_checkpoint = None
        if checkpoint:
            self.storage._update_checkpoint_index(metrics)
            # Persist the reported checkpoint files to storage.
            persisted_checkpoint = self.storage.persist_current_checkpoint(checkpoint)
            metrics[CHECKPOINT_DIR_NAME] = self.storage.checkpoint_dir_name
        else:
            metrics[CHECKPOINT_DIR_NAME] = None
        # Persist trial artifacts to storage.
        force_artifact_sync = (
            persisted_checkpoint
            and self.storage.sync_config.sync_artifacts_on_checkpoint
        )
        self.storage.persist_artifacts(force=force_artifact_sync)
        # Set additional user metadata from the Trainer.
        if persisted_checkpoint and self.metadata:
            user_metadata = persisted_checkpoint.get_metadata()
            for k, v in self.metadata.items():
                # Update keys not already set by the user. This gives user-set keys
                # precedence over keys set at the Trainer level.
                if k not in user_metadata:
                    user_metadata[k] = v
            persisted_checkpoint.set_metadata(user_metadata)
        result = _TrainingResult(checkpoint=persisted_checkpoint, metrics=metrics)
        self._report_training_result(result)
    @property
    def experiment_name(self) -> str:
        return self.trial_info.experiment_name
    @property
    def trial_name(self) -> str:
        return self.trial_info.name
    @property
    def trial_id(self) -> str:
        return self.trial_info.id
    @property
    def run_id(self) -> str:
        return self.trial_info.run_id
    @property
    def trial_resources(self) -> "PlacementGroupFactory":
        return self.trial_info.resources
    @property
    def trial_dir(self) -> str:
        return self.trial_info.logdir
    def get_dataset_shard(
        self,
        dataset_name: Optional[str] = None,
    ) -> Optional["DataIterator"]:
        shard = self.dataset_shard
        if shard is None:
            warnings.warn(
                "No dataset passed in. Returning None. Make sure to "
                "pass in a Dataset to Trainer.run to use this "
                "function."
            )
        elif isinstance(shard, dict):
            if not dataset_name:
                raise RuntimeError(
                    "Multiple datasets were passed into ``Trainer``, "
                    "but no ``dataset_name`` is passed into "
                    "``get_dataset_shard``. Please specify which "
                    "dataset shard to retrieve."
                )
            return shard.get(dataset_name)
        return shard
# Cache of resource dicts that have been checked by the launch hook already.
_checked_resources: Set[frozenset] = set()
# Global _TrainSession object initialized by Ray Tune function trainables
# and Ray Train V1 workers.
_session: Optional[_TrainSession] = None
def _tune_task_and_actor_launch_hook(
    fn, resources: Dict[str, float], strategy: Optional[SchedulingStrategyT]
):
    """Launch hook to catch nested tasks that can't fit in the placement group.
    This gives users a nice warning in case they launch a nested task in a Tune trial
    without reserving resources in the trial placement group to fit it.
    """
    # Already checked, skip for performance reasons.
    key = frozenset({(k, v) for k, v in resources.items() if v > 0})
    if not key or key in _checked_resources:
        return
    # No need to check if placement group is None.
    if (
        not isinstance(strategy, PlacementGroupSchedulingStrategy)
        or strategy.placement_group is None
    ):
        return
    # Check if the resource request is targeting the current placement group.
    cur_pg = ray.util.get_current_placement_group()
    if not cur_pg or strategy.placement_group.id != cur_pg.id:
        return
    _checked_resources.add(key)
    # Determine the bundles available for nested requests in the trial's placement group.
    pgf = get_trial_resources()
    if pgf.head_bundle_is_empty:
        available_bundles = cur_pg.bundle_specs[0:]
    else:
        available_bundles = cur_pg.bundle_specs[1:]
    # Check if the request can be fulfilled by the current placement group.
    if _valid_resource_shape(resources, available_bundles):
        return
    if fn.class_name:
        submitted = "actor"
        name = fn.module_name + "." + fn.class_name + "." + fn.function_name
    else:
        submitted = "task"
        name = fn.module_name + "." + fn.function_name
    # Normalize the resource spec so it looks the same as the placement group bundle.
    main_resources = cur_pg.bundle_specs[0]
    resources = {k: float(v) for k, v in resources.items() if v > 0}
    raise RuntimeError(
        f"No trial resources are available for launching the {submitted} `{name}`. "
        "To resolve this, specify the Tune option:\n\n"
        ">  resources_per_trial=tune.PlacementGroupFactory(\n"
        f">    [{main_resources}] + [{resources}] * N\n"
        ">  )\n\n"
        f"Where `N` is the number of slots to reserve for trial {submitted}s. "
        "If you are using a Ray training library, there might be a utility function "
        "to set this automatically for you. For more information, refer to "
        "https://docs.ray.io/en/latest/tune/tutorials/tune-resources.html"
    )
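# Illustrative sketch (not executed here) of the situation this hook guards
# against: a trial whose placement group reserves no extra bundles while the
# trial body launches a nested resource-requesting task. The names are
# hypothetical:
#
#   @ray.remote(num_cpus=1)
#   def nested_task():
#       ...
#
#   def trial_fn(config):
#       # Without an additional {"CPU": 1} bundle in the trial's
#       # PlacementGroupFactory, this launch raises the RuntimeError above.
#       ray.get(nested_task.remote())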
def init_session(*args, **kwargs) -> None:
    global _session
    if _session:
        raise ValueError(
            "A Train session is already in use. Do not call "
            "`init_session()` manually."
        )
    # Setup hooks for generating placement group resource deadlock warnings.
    from ray import actor, remote_function
    if "TUNE_DISABLE_RESOURCE_CHECKS" not in os.environ:
        actor._actor_launch_hook = _tune_task_and_actor_launch_hook
        remote_function._task_launch_hook = _tune_task_and_actor_launch_hook
    _session = _TrainSession(*args, **kwargs)
def get_session() -> Optional[_TrainSession]:
    return _session
def shutdown_session():
    """Shuts down the initialized session."""
    global _session
    _session = None
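# Illustrative sketch (not executed here) of the session lifecycle managed by
# the three functions above; the keyword arguments mirror
# `_TrainSession.__init__` and the values are placeholders:
#
#   init_session(
#       training_func=train_fn,
#       world_rank=0,
#       local_rank=0,
#       node_rank=0,
#       local_world_size=1,
#       world_size=1,
#       storage=storage_context,
#   )
#   session = get_session()
#   session.start()
#   ...
#   shutdown_session()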
def _raise_accelerator_session_misuse():
    """Raises a SessionMisuseError because a utility function was used improperly."""
    raise SessionMisuseError(
        "prepare/accelerate utility functions should be called inside a training "
        "function executed by `Trainer.run`"
    )
def get_accelerator(default_accelerator_cls: Type[Accelerator]) -> Accelerator:
    """The accelerator for this training session.
    If an accelerator has not been set, then this method will construct an
    accelerator using the provided accelerator class.
    Raises:
        SessionMisuseError: if the session is uninitialized.
    """
    session = get_session()
    if session is None:
        _raise_accelerator_session_misuse()
    if session.accelerator is None:
        session.accelerator = default_accelerator_cls()
    return session.accelerator
def set_accelerator(accelerator: Accelerator) -> None:
    """Sets the accelerator for this training session.
    Args:
        accelerator: The accelerator to use for training.
    Raises:
        SessionMisuseError: if the session is uninitialized.
        RuntimeError: if the accelerator has already been set.
    """
    session = get_session()
    if session is None:
        _raise_accelerator_session_misuse()
    if session.accelerator is not None:
        raise RuntimeError("Cannot change accelerator once set.")
    session.accelerator = accelerator
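# Illustrative sketch (not executed here) of how the two accelerator helpers
# behave inside a training function; `MyAccelerator` is a hypothetical
# `Accelerator` subclass:
#
#   # Lazily constructs a `MyAccelerator` if none has been set on the session.
#   accelerator = get_accelerator(MyAccelerator)
#   # Any later call to `set_accelerator(...)` raises RuntimeError because the
#   # accelerator can only be set once per session.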
def _warn_session_misuse(default_value: Any = None):
    """Warns if fn is being used outside of session and returns ``default_value``."""
    def inner(fn: Callable):
        fn_name = fn.__name__
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            session = get_session()
            if not session:
                if log_once(f"{SESSION_MISUSE_LOG_ONCE_KEY}-{fn_name}"):
                    warnings.warn(
                        f"`{fn_name}` is meant to only be "
                        "called inside a function that is executed by a Tuner"
                        f" or Trainer. Returning `{default_value}`."
                    )
                return default_value
            return fn(*args, **kwargs)
        return wrapper
    return inner
@PublicAPI(stability="stable")
@_warn_session_misuse()
def report(
    metrics: Dict,
    *,
    checkpoint: Optional[Checkpoint] = None,
    checkpoint_dir_name: Optional[str] = None,
) -> None:
    """Report metrics and optionally save a checkpoint.
    If a checkpoint is provided, it will be
    :ref:`persisted to storage <persistent-storage-guide>`.
    If this is called in multiple distributed training workers:
    - Only the metrics reported by the rank 0 worker will be tracked by Ray Train.
      See :ref:`the metrics logging guide <train-monitoring-and-logging>`.
    - A checkpoint will be registered as long as one or more workers report a
      checkpoint that is not None.
      See the :ref:`checkpointing guide <train-dl-saving-checkpoints>`.
    - Checkpoints from multiple workers will be merged into one directory
      in persistent storage.
      See :ref:`the distributed checkpointing guide <train-distributed-checkpointing>`.
    .. note::
        Each invocation of this method will automatically increment the underlying
        ``training_iteration`` number. The physical meaning of this "iteration" is
        defined by the user, depending on how often they call ``report``.
        It does not necessarily map to one epoch.
    .. warning::
        All workers must call `ray.train.report` the same number of times
        so that Ray Train can properly synchronize the training state across
        workers. Otherwise, your training will hang.
    .. warning::
        This method does NOT act as a barrier for distributed training workers.
        Workers will upload their checkpoint, then continue training immediately.
        If you need to synchronize workers, you can use a framework-native barrier
        such as `torch.distributed.barrier()`.
    Example:
        .. testcode::
            import tempfile
            from ray import train
            from ray.train import Checkpoint
            from ray.train.torch import TorchTrainer
            def train_func(config):
                start_epoch = 0
                checkpoint = train.get_checkpoint()
                if checkpoint:
                    with checkpoint.as_directory() as checkpoint_dir:
                        # Load back training state
                        ...
                for epoch in range(start_epoch, config.get("num_epochs", 10)):
                    # Do training...
                    metrics = {"loss": ...}
                    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                        # Save the checkpoint...
                        # torch.save(...)
                        checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
                        # Example: Only the rank 0 worker uploads the checkpoint.
                        if train.get_context().get_world_rank() == 0:
                            train.report(metrics, checkpoint=checkpoint)
                        else:
                            train.report(metrics, checkpoint=None)
            trainer = TorchTrainer(
                train_func, scaling_config=train.ScalingConfig(num_workers=2)
            )
    Args:
        metrics: The metrics you want to report.
        checkpoint: The optional checkpoint you want to report.
    """
    if checkpoint_dir_name is not None:
        logger.warning(
            "`checkpoint_dir_name` is only supported in the new Ray Train "
            "implementation, which can be enabled with `RAY_TRAIN_V2_ENABLED=1`. "
            "This argument will be ignored."
        )
    # If we are running in a Tune function, switch to `ray.tune.report`.
    from ray.tune.trainable.trainable_fn_utils import _in_tune_session
    if _in_tune_session():
        import ray.tune
        if _v2_migration_warnings_enabled():
            _log_deprecation_warning(
                "`ray.train.report` should be switched to "
                "`ray.tune.report` when running in a function "
                "passed to Ray Tune. This will be an error in the future. "
                "See this issue for more context: "
                "https://github.com/ray-project/ray/issues/49454"
            )
        return ray.tune.report(metrics, checkpoint=checkpoint)
    get_session().report(metrics, checkpoint=checkpoint) 
@PublicAPI(stability="stable")
@_warn_session_misuse()
def get_checkpoint() -> Optional[Checkpoint]:
    """Access the latest reported checkpoint to resume from if one exists.
    Example:
        .. testcode::
            import tempfile
            from ray import train
            from ray.train import Checkpoint
            from ray.train.torch import TorchTrainer
            def train_func(config):
                start_epoch = 0
                checkpoint = train.get_checkpoint()
                if checkpoint:
                    with checkpoint.as_directory() as checkpoint_dir:
                        # Load back training state
                        ...
                for epoch in range(start_epoch, config.get("num_epochs", 10)):
                    # Do training...
                    metrics = {"loss": ...}
                    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                        # Save the checkpoint...
                        checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
                        train.report(metrics, checkpoint=checkpoint)
            trainer = TorchTrainer(
                train_func, scaling_config=train.ScalingConfig(num_workers=2)
            )
    Returns:
        Checkpoint object if the session is currently being resumed.
            Otherwise, return None.
    """
    # If we are running in a Tune function, switch to `ray.tune.get_checkpoint`.
    from ray.tune.trainable.trainable_fn_utils import _in_tune_session
    if _in_tune_session():
        import ray.tune
        if _v2_migration_warnings_enabled():
            _log_deprecation_warning(
                "`ray.train.get_checkpoint` should be switched to "
                "`ray.tune.get_checkpoint` when running in a function "
                "passed to Ray Tune. This will be an error in the future. "
                "See this issue for more context: "
                "https://github.com/ray-project/ray/issues/49454"
            )
        return ray.tune.get_checkpoint()
    return get_session().loaded_checkpoint 
@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_metadata() -> Dict[str, Any]:
    """User metadata dict passed to the Trainer constructor."""
    return get_session().metadata
@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_experiment_name() -> str:
    """Experiment name for the corresponding trial."""
    return get_session().experiment_name
@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_name() -> str:
    """Trial name for the corresponding trial."""
    return get_session().trial_name
@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_id() -> str:
    """Trial id for the corresponding trial."""
    return get_session().trial_id
@PublicAPI(stability="alpha")
@_warn_session_misuse()
def get_run_id() -> str:
    """Unique Train Run id for the corresponding trial."""
    return get_session().run_id
@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_resources() -> "PlacementGroupFactory":
    """Trial resources for the corresponding trial."""
    return get_session().trial_resources
@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_dir() -> str:
    """Log directory corresponding to the trial directory for a Tune session.
    If calling from a Train session, this will give the trial directory of its parent
    Tune session.
    .. testcode::
        import ray.tune
        def train_func(config):
            print(ray.tune.get_context().get_trial_dir())
        tuner = ray.tune.Tuner(train_func)
        tuner.fit()
    .. testoutput::
        :options: +MOCK
        /Users/root/ray_results/train_func_2023-07-19_15-01-37/train_func_d620c_00000_0_2023-07-19_15-01-40
    """
    return get_session().trial_dir
@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=1)
def get_world_size() -> int:
    """Get the current world size (i.e. total number of workers) for this run.
    .. testcode::
        import ray
        from ray import train
        from ray.train import ScalingConfig
        from ray.train.tensorflow import TensorflowTrainer
        NUM_WORKERS = 2
        def train_loop_per_worker(config):
            assert train.get_context().get_world_size() == NUM_WORKERS
        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
        trainer = TensorflowTrainer(
            train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=NUM_WORKERS),
            datasets={"train": train_dataset}
        )
        trainer.fit()
    .. testoutput::
        :hide:
        ...
    """
    session = get_session()
    if not hasattr(session, "world_size"):
        raise RuntimeError(
            "`get_world_size` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function"
            "that is passed into `DataParallelTrainer`."
        )
    return session.world_size
@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_world_rank() -> int:
    """Get the world rank of this worker.
    .. testcode::
        import ray
        from ray import train
        from ray.train import ScalingConfig
        from ray.train.tensorflow import TensorflowTrainer
        def train_loop_per_worker(config):
            if train.get_context().get_world_rank() == 0:
                print("Worker 0")
        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
        trainer = TensorflowTrainer(
            train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=2),
            datasets={"train": train_dataset}
        )
        trainer.fit()
    .. testoutput::
        :hide:
        ...
    """
    session = get_session()
    if not hasattr(session, "world_rank"):
        raise RuntimeError(
            "`get_world_rank` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function"
            "that is passed into `DataParallelTrainer`."
        )
    return session.world_rank
@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_local_rank() -> int:
    """Get the local rank of this worker (rank of the worker on its node).
    .. testcode::
        import torch
        import ray
        from ray import train
        from ray.train import ScalingConfig
        from ray.train.torch import TorchTrainer
        def train_loop_per_worker(config):
            if torch.cuda.is_available():
                torch.cuda.set_device(train.get_context().get_local_rank())
            ...
        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
        trainer = TorchTrainer(
            train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
            datasets={"train": train_dataset}
        )
        trainer.fit()
    .. testoutput::
        :hide:
        ...
    """
    session = get_session()
    if not hasattr(session, "local_rank"):
        raise RuntimeError(
            "`get_local_rank` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function"
            "that is passed into `DataParallelTrainer`."
        )
    return session.local_rank
@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_local_world_size() -> int:
    """Get the local world size of this node (i.e. number of workers on this node).
    Example:
        .. testcode::
            import ray
            from ray import train
            from ray.train import ScalingConfig
            from ray.train.torch import TorchTrainer
            def train_loop_per_worker():
                print(train.get_context().get_local_world_size())
            train_dataset = ray.data.from_items(
                [{"x": x, "y": x + 1} for x in range(32)])
            trainer = TorchTrainer(train_loop_per_worker,
                scaling_config=ScalingConfig(num_workers=1),
                datasets={"train": train_dataset})
            trainer.fit()
        .. testoutput::
            :hide:
            ...
    """
    session = get_session()
    if not hasattr(session, "local_world_size"):
        raise RuntimeError(
            "`get_local_world_size` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function"
            "that is passed into `DataParallelTrainer`."
        )
    return session.local_world_size
@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_node_rank() -> int:
    """Get the rank of this node.
    Example:
        .. testcode::
            import ray
            from ray import train
            from ray.train import ScalingConfig
            from ray.train.torch import TorchTrainer
            def train_loop_per_worker():
                print(train.get_context().get_node_rank())
            train_dataset = ray.data.from_items(
                [{"x": x, "y": x + 1} for x in range(32)])
            trainer = TorchTrainer(train_loop_per_worker,
                scaling_config=ScalingConfig(num_workers=1),
                datasets={"train": train_dataset})
            trainer.fit()
        .. testoutput::
            :hide:
            ...
    """
    session = get_session()
    if not hasattr(session, "node_rank"):
        raise RuntimeError(
            "`get_node_rank` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function"
            "that is passed into `DataParallelTrainer`."
        )
    return session.node_rank
@PublicAPI(stability="stable")
@_warn_session_misuse()
def get_dataset_shard(
    dataset_name: Optional[str] = None,
) -> Optional["DataIterator"]:
    """Returns the :class:`ray.data.DataIterator` shard for this worker.
    Call :meth:`~ray.data.DataIterator.iter_torch_batches` or
    :meth:`~ray.data.DataIterator.to_tf` on this shard to convert it to the
    appropriate framework-specific data type.
    .. testcode::
        import ray
        from ray import train
        from ray.train import ScalingConfig
        from ray.train.torch import TorchTrainer
        def train_loop_per_worker(config):
            ...
            for epoch in range(2):
                # Trainer will automatically handle sharding.
                data_shard = train.get_dataset_shard("train")
                for batch in data_shard.iter_torch_batches():
                    ...
        train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
        trainer = TorchTrainer(
            train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=2),
            datasets={"train": train_dataset}
        )
        trainer.fit()
    .. testoutput::
        :hide:
        ...
    Args:
        dataset_name: If a Dictionary of Datasets was passed to ``Trainer``, then
            specifies which dataset shard to return.
    Returns:
        The ``DataIterator`` shard to use for this worker.
        If no dataset is passed into Trainer, then return None.
    """
    session = get_session()
    if not hasattr(session, "get_dataset_shard"):
        raise RuntimeError(
            "`get_dataset_shard` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function"
            "that is passed into `DataParallelTrainer`."
        )
    return session.get_dataset_shard(dataset_name) 
@DeveloperAPI
@_warn_session_misuse()
def get_storage() -> StorageContext:
    """Returns the :class:`~ray.train._internal.storage.StorageContext` storage
    context which gives advanced access to the filesystem and paths
    configured through `RunConfig`.
    NOTE: This is a developer API, and the `StorageContext` interface may change
    without notice between minor versions.
    """
    return get_session().storage
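# Illustrative sketch (not executed here) using attributes of `StorageContext`
# that are already referenced elsewhere in this module:
#
#   storage = get_storage()
#   print(storage.trial_working_directory)
#   print(storage.checkpoint_dir_name)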
def _in_ray_train_worker() -> bool:
    """Check if the current process is a Ray Train V1 worker."""
    return bool(get_session()) and get_session().world_rank is not None