import functools
import logging
import os
import platform
import queue
import sys
import threading
import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type
import ray
from ray.air._internal.util import RunnerThread, StartTraceback
from ray.air.constants import (
_ERROR_FETCH_TIMEOUT,
_RESULT_FETCH_TIMEOUT,
SESSION_MISUSE_LOG_ONCE_KEY,
TIME_THIS_ITER_S,
TIMESTAMP,
)
from ray.data import Dataset
from ray.train import Checkpoint
from ray.train._internal.accelerator import Accelerator
from ray.train._internal.storage import StorageContext
from ray.train.constants import (
CHECKPOINT_DIR_NAME,
DETAILED_AUTOFILLED_KEYS,
RAY_CHDIR_TO_TRIAL_DIR,
TIME_TOTAL_S,
WORKER_HOSTNAME,
WORKER_NODE_IP,
WORKER_PID,
_v2_migration_warnings_enabled,
)
from ray.train.error import SessionMisuseError
from ray.train.utils import _log_deprecation_warning
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.util.debug import log_once
from ray.util.placement_group import _valid_resource_shape
from ray.util.scheduling_strategies import (
PlacementGroupSchedulingStrategy,
SchedulingStrategyT,
)
if TYPE_CHECKING:
from ray.data import DataIterator
from ray.tune.execution.placement_groups import PlacementGroupFactory
logger = logging.getLogger(__name__)
@dataclass
class TrialInfo:
"""The trial information to propagate to TrainSession."""
name: str
id: str
resources: Dict[str, float]
logdir: str
driver_ip: str
driver_node_id: str
experiment_name: Optional[str] = None
run_id: Optional[str] = None
class _FutureTrainingResult:
"""A future that will be resolved to a `_TrainingResult`.
This is needed for specific schedulers such as PBT that schedule saves.
This wrapper should be removed after refactoring PBT to not schedule saves anymore.
"""
def __init__(self, future: ray.ObjectRef):
self.future = future
def resolve(self, block: bool = True) -> Optional["_TrainingResult"]:
"""Resolve into ``_TrainingResult``.
This will return None for function trainables if no checkpoint has been
saved before.
"""
if block:
timeout = None
else:
timeout = 1e-9
try:
return ray.get(self.future, timeout=timeout)
except TimeoutError:
# Not ready, yet
pass
except Exception as exc:
logger.error(f"Error resolving result: {exc}")
class _TrainingResult:
"""A (checkpoint, metrics) result reported by the user."""
def __init__(self, checkpoint: Optional[Checkpoint], metrics: Dict[str, Any]):
self.checkpoint = checkpoint
self.metrics = metrics
def __repr__(self) -> str:
return f"TrainingResult(checkpoint={self.checkpoint}, metrics={self.metrics})"
# TODO(xwjiang): This needs a better name.
@DeveloperAPI
class _TrainSession:
"""Holds information for training on each worker."""
def __init__(
self,
training_func: Callable,
world_rank: Optional[int],
local_rank: Optional[int],
node_rank: Optional[int],
local_world_size: Optional[int],
world_size: Optional[int],
trial_info: Optional[TrialInfo] = None,
dataset_shard: Optional[Dict[str, Dataset]] = None,
metadata: Optional[Dict[str, Any]] = None,
checkpoint: Optional[Checkpoint] = None,
detailed_autofilled_metrics: bool = False,
storage: Optional[StorageContext] = None,
synchronous_result_reporting: bool = False,
):
# `synchronous_result_reporting` refers to whether or not the
# training function is immediately unblocked to continue running
# after the main thread receives its result.
# Ex 1: For 2 Ray Train workers with `synchronous_result_reporting=False`,
# the worker that produces a result first will immediately continue
# onto the next iteration.
# Ex 2: For a Tune function Trainable with `synchronous_result_reporting=True`,
# training will only continue with an explicit call to `session.get_next`.
# Synchronous reporting in example 2 is needed for Tune schedulers to
# be able to stop the execution of the training function at will,
# for advanced pausing schedulers (PBT, BOHB) and actor reuse.
self.synchronous_result_reporting = synchronous_result_reporting
# Ray Train worker properties
# Note: These are set to None for Tune function Trainables.
self.dataset_shard = dataset_shard
self.metadata = metadata
self.world_rank = world_rank
self.local_rank = local_rank
self.node_rank = node_rank
self.local_world_size = local_world_size
self.world_size = world_size
assert storage
logger.debug(f"StorageContext on SESSION (rank={world_rank}):\n{storage}")
# NOTE: `reset` will initialize many properties needed to start running the
# training_func as a thread.
self.reset(
training_func=training_func,
trial_info=trial_info,
storage=storage,
loaded_checkpoint=checkpoint,
)
# Autofilled metrics attributes.
self.detailed_autofilled_metrics = detailed_autofilled_metrics
self.last_report_time = time.time()
self.iteration = 0
self.time_total = 0.0
self.local_ip = self.get_current_ip()
self.accelerator = None
self._state = {}
def get_state(self, key: str) -> Any:
return self._state.get(key)
def set_state(self, key: str, value: Any):
self._state[key] = value
def get_current_ip(self):
self.local_ip = ray.util.get_node_ip_address()
return self.local_ip
def start(self):
"""Starts the training thread."""
self.training_started = True
self.training_thread.start()
def reset(
self,
training_func: Callable,
trial_info: TrialInfo,
storage: StorageContext,
loaded_checkpoint=None,
):
# This lock is used to control the execution of the training thread.
self.continue_lock = threading.Semaphore(0)
# This event is used to signal the training thread to stop.
self.stop_event = threading.Event()
# Queue for sending results across threads.
self.result_queue = queue.Queue(1)
# Queue for raising exceptions from runner thread to main thread.
# The error queue has a max size of one to prevent errors from stacking up and
# to force error reporting to block until finished.
self.error_queue = queue.Queue(1)
# The Thread object that is running the training function.
self.training_thread = RunnerThread(
target=training_func, daemon=True, error_queue=self.error_queue
)
# Possibly override with new state
self.trial_info = trial_info
self.storage = storage
self.loaded_checkpoint = loaded_checkpoint
# Reset state
self._state = {}
self.ignore_report = False
self.training_started = False
self._first_report = True
# Change the working directory to a special trial folder.
# This is to ensure that all Ray Train workers have a common working directory.
os.makedirs(storage.trial_working_directory, exist_ok=True)
if bool(int(os.environ.get(RAY_CHDIR_TO_TRIAL_DIR, "1"))):
logger.debug(
f"Changing the working directory to: {storage.trial_working_directory}"
)
os.chdir(storage.trial_working_directory)
def pause_reporting(self):
"""Ignore all future ``session.report()`` calls."""
self.ignore_report = True
def finish(self, timeout: Optional[float] = None) -> Optional[Any]:
"""Finishes the training thread.
Raises any Exception from training.
"""
# Set the stop event for the training thread to gracefully exit.
self.stop_event.set()
# Release the lock so that training thread can process this event.
self.continue_lock.release()
# Force a final (blocking) sync of artifacts in the trial path to storage.
self.storage.persist_artifacts(force=True)
# Wait for training to finish.
# This will raise any errors that occur during training, including SystemError
# This returns the result of the training function.
output = None
if self.training_started:
output = self.training_thread.join(timeout=timeout)
return output
def get_next(self) -> Optional[_TrainingResult]:
"""Gets the next ``_TrainingResult`` from the result queue.
If the result queue is empty, then this function returns ``None``.
"""
if not self.training_started:
raise RuntimeError("Please call start before calling get_next.")
if self.synchronous_result_reporting:
# There's no need to release the lock on the first report
# since `start` already started the training thread.
if not self._first_report:
# Release the lock to trigger training to continue,
# until the next call to report.
self.continue_lock.release()
self._first_report = False
result = None
# While training is still ongoing, attempt to get the result.
while result is None and self.training_thread.is_alive():
try:
result = self.result_queue.get(
block=True, timeout=_RESULT_FETCH_TIMEOUT
)
except queue.Empty:
pass
# If no result was found, then the runner must no longer be alive.
if result is None:
# Try one last time to fetch results in case results were
# reported in between the time of the last check and the
# termination of the thread runner.
try:
result = self.result_queue.get(
block=False, timeout=_RESULT_FETCH_TIMEOUT
)
except queue.Empty:
pass
# check if error occurred inside the thread runner.
if result is None:
# only raise an error from the runner if all results are consumed
self._report_thread_runner_error(block=True)
else:
if not self.error_queue.empty():
logger.debug(
(
"Runner error waiting to be raised in main thread. "
"Logging all available results first."
)
)
if not self.synchronous_result_reporting:
# At this point, the training thread has reached
# the `train.report` and is blocked there.
# If performing asynchronous result reporting,
# release the lock to allow each worker to keep training
# immediately after the coordinator fetches their result.
self.continue_lock.release()
# Return None if there are no more results to fetch.
return result
def _auto_fill_metrics(self, result: dict) -> dict:
"""Add autofilled metrics and update attributes."""
current_time = time.time()
current_datetime = datetime.now()
if TIME_THIS_ITER_S in result:
time_this_iter = result[TIME_THIS_ITER_S]
else:
time_this_iter = current_time - self.last_report_time
self.iteration += 1
self.time_total += time_this_iter
self.last_report_time = current_time
auto_filled_metrics = {
TIMESTAMP: int(time.mktime(current_datetime.timetuple())),
TIME_TOTAL_S: self.time_total,
WORKER_PID: os.getpid(),
WORKER_HOSTNAME: platform.node(),
WORKER_NODE_IP: self.local_ip,
}
if not self.detailed_autofilled_metrics:
auto_filled_metrics = {
k: v
for k, v in auto_filled_metrics.items()
if k not in DETAILED_AUTOFILLED_KEYS
}
result = result.copy()
result.update(auto_filled_metrics)
return result
def _auto_fill_checkpoint_metrics(self, result: dict) -> dict:
"""Add autofilled metrics and update attributes."""
current_datetime = datetime.now()
auto_filled_metrics = {
TIMESTAMP: int(time.mktime(current_datetime.timetuple()))
}
result = result.copy()
result.update(auto_filled_metrics)
return result
def _report_thread_runner_error(self, block=False):
try:
e = self.error_queue.get(block=block, timeout=_ERROR_FETCH_TIMEOUT)
raise StartTraceback from e
except queue.Empty:
pass
def _report_training_result(self, training_result: _TrainingResult) -> None:
"""Place a training result on the result queue for the main thread to process,
then block until the main thread signals that training should continue.
NOTE: This is used internally to report results from Train to Tune
without persisting checkpoints to storage twice.
`report` is the public API that persists directly to storage; it
should only be called by user code.
"""
if training_result.checkpoint:
# NOTE: This populates `train.get_checkpoint`
self.loaded_checkpoint = training_result.checkpoint
# Add result to a thread-safe queue.
self.result_queue.put(training_result, block=True)
# Acquire lock to stop the training thread until main thread
# triggers resume.
self.continue_lock.acquire()
# If the trial should be terminated, exit gracefully.
# NOTE: This is only really useful if `synchronous_result_reporting=True`.
# Otherwise, the lock is immediately released on reporting, and this
# check is skipped before the main thread decides to set the stop event.
if self.stop_event.is_set():
self.stop_event.clear()
sys.exit(0)
def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
# Special case: early fail for Torch tensors
if "torch" in sys.modules:
from ray.air._internal.torch_utils import contains_tensor
if contains_tensor(metrics):
raise ValueError(
"Passing objects containg Torch tensors as metrics "
"is not supported as it will throw an exception on "
"deserialization. You can either convert the tensors "
"to Python objects or report a `train.Checkpoint` "
"with `ray.train.report` to store your Torch objects."
)
if self.ignore_report:
return
metrics = self._auto_fill_metrics(metrics)
persisted_checkpoint = None
if checkpoint:
self.storage._update_checkpoint_index(metrics)
# Persist the reported checkpoint files to storage.
persisted_checkpoint = self.storage.persist_current_checkpoint(checkpoint)
metrics[CHECKPOINT_DIR_NAME] = self.storage.checkpoint_dir_name
else:
metrics[CHECKPOINT_DIR_NAME] = None
# Persist trial artifacts to storage.
force_artifact_sync = (
persisted_checkpoint
and self.storage.sync_config.sync_artifacts_on_checkpoint
)
self.storage.persist_artifacts(force=force_artifact_sync)
# Set additional user metadata from the Trainer.
if persisted_checkpoint and self.metadata:
user_metadata = persisted_checkpoint.get_metadata()
for k, v in self.metadata.items():
# Update keys not already set by the user. This gives user-set keys
# precedence over keys set at the Trainer level.
if k not in user_metadata:
user_metadata[k] = v
persisted_checkpoint.set_metadata(user_metadata)
result = _TrainingResult(checkpoint=persisted_checkpoint, metrics=metrics)
self._report_training_result(result)
@property
def experiment_name(self) -> str:
return self.trial_info.experiment_name
@property
def trial_name(self) -> str:
return self.trial_info.name
@property
def trial_id(self) -> str:
return self.trial_info.id
@property
def run_id(self) -> str:
return self.trial_info.run_id
@property
def trial_resources(self) -> "PlacementGroupFactory":
return self.trial_info.resources
@property
def trial_dir(self) -> str:
return self.trial_info.logdir
def get_dataset_shard(
self,
dataset_name: Optional[str] = None,
) -> Optional["DataIterator"]:
shard = self.dataset_shard
if shard is None:
warnings.warn(
"No dataset passed in. Returning None. Make sure to "
"pass in a Dataset to Trainer.run to use this "
"function."
)
elif isinstance(shard, dict):
if not dataset_name:
raise RuntimeError(
"Multiple datasets were passed into ``Trainer``, "
"but no ``dataset_name`` is passed into "
"``get_dataset_shard``. Please specify which "
"dataset shard to retrieve."
)
return shard.get(dataset_name)
return shard
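# Illustrative sketch (not from the Ray sources) of the coordinator-side protocol
# implemented by `start`/`get_next`/`finish`, assuming a `session` constructed with
# `synchronous_result_reporting=False` and a `training_func` that calls
# `session.report(...)` in a loop:
#
#     session.start()                  # run training_func in a RunnerThread
#     while True:
#         result = session.get_next()  # next _TrainingResult, or None when done
#         if result is None:
#             break
#         consume(result.metrics, result.checkpoint)  # `consume` is hypothetical
#     session.finish()                 # join the thread and re-raise any error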
# Cache of resource dicts that have been checked by the launch hook already.
_checked_resources: Set[frozenset] = set()
# Global _TrainSession object initialized by Ray Tune function trainables
# and Ray Train V1 workers.
_session: Optional[_TrainSession] = None
def _tune_task_and_actor_launch_hook(
fn, resources: Dict[str, float], strategy: Optional[SchedulingStrategyT]
):
"""Launch hook to catch nested tasks that can't fit in the placement group.
This gives users a nice warning in case they launch a nested task in a Tune trial
without reserving resources in the trial placement group to fit it.
"""
# Already checked, skip for performance reasons.
key = frozenset({(k, v) for k, v in resources.items() if v > 0})
if not key or key in _checked_resources:
return
# No need to check if placement group is None.
if (
not isinstance(strategy, PlacementGroupSchedulingStrategy)
or strategy.placement_group is None
):
return
# Check if the resource request is targeting the current placement group.
cur_pg = ray.util.get_current_placement_group()
if not cur_pg or strategy.placement_group.id != cur_pg.id:
return
_checked_resources.add(key)
# Check if the request can be fulfilled by the current placement group.
pgf = get_trial_resources()
if pgf.head_bundle_is_empty:
available_bundles = cur_pg.bundle_specs[0:]
else:
available_bundles = cur_pg.bundle_specs[1:]
# Check if the request can be fulfilled by the current placement group.
if _valid_resource_shape(resources, available_bundles):
return
if fn.class_name:
submitted = "actor"
name = fn.module_name + "." + fn.class_name + "." + fn.function_name
else:
submitted = "task"
name = fn.module_name + "." + fn.function_name
# Normalize the resource spec so it looks the same as the placement group bundle.
main_resources = cur_pg.bundle_specs[0]
resources = {k: float(v) for k, v in resources.items() if v > 0}
raise RuntimeError(
f"No trial resources are available for launching the {submitted} `{name}`. "
"To resolve this, specify the Tune option:\n\n"
"> resources_per_trial=tune.PlacementGroupFactory(\n"
f"> [{main_resources}] + [{resources}] * N\n"
"> )\n\n"
f"Where `N` is the number of slots to reserve for trial {submitted}s. "
"If you are using a Ray training library, there might be a utility function "
"to set this automatically for you. For more information, refer to "
"https://docs.ray.io/en/latest/tune/tutorials/tune-resources.html"
)
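# Illustrative sketch (not from the Ray sources) of the fix suggested by the error
# message above, with placeholder resource values: a head bundle of {"CPU": 1} plus
# N = 2 bundles reserved for nested tasks/actors that each request {"CPU": 1}:
#
#     from ray import tune
#
#     pgf = tune.PlacementGroupFactory([{"CPU": 1}] + [{"CPU": 1}] * 2)
#     # e.g. with the legacy API: tune.run(trainable, resources_per_trial=pgf)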
def init_session(*args, **kwargs) -> None:
global _session
if _session:
raise ValueError(
"A Train session is already in use. Do not call "
"`init_session()` manually."
)
# Setup hooks for generating placement group resource deadlock warnings.
from ray import actor, remote_function
if "TUNE_DISABLE_RESOURCE_CHECKS" not in os.environ:
actor._actor_launch_hook = _tune_task_and_actor_launch_hook
remote_function._task_launch_hook = _tune_task_and_actor_launch_hook
_session = _TrainSession(*args, **kwargs)
def get_session() -> Optional[_TrainSession]:
return _session
def shutdown_session():
"""Shuts down the initialized session."""
global _session
_session = None
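# Illustrative sketch (not from the Ray sources) of the lifecycle of the
# module-level session singleton managed by the three functions above
# (`init_session` arguments mirror `_TrainSession.__init__` and are elided):
#
#     init_session(training_func=..., world_rank=0, ...)  # create the singleton
#     session = get_session()                             # fetch it anywhere
#     ...                                                 # start/get_next/finish
#     shutdown_session()                                  # clear the singleton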
def _raise_accelerator_session_misuse():
"""Raises a SessionMisuseError because a utility function was used improperly."""
raise SessionMisuseError(
"prepare/accelerate utility functions should be called inside a training "
"function executed by `Trainer.run`"
)
def get_accelerator(default_accelerator_cls: Type[Accelerator]) -> Accelerator:
"""The accelerator for this training session.
If an accelerator has not been set, then this method will construct an
accelerator using the provided accelerator class.
Raises:
SessionMisuseError: if the session is uninitialized.
"""
session = get_session()
if session is None:
_raise_accelerator_session_misuse()
if session.accelerator is None:
session.accelerator = default_accelerator_cls()
return session.accelerator
def set_accelerator(accelerator: Accelerator) -> None:
"""Sets the accelerator for this training session.
Args:
accelerator: The accelerator to use for training.
Raises:
SessionMisuseError: if the session is uninitialized.
RuntimeError: if the accelerator has already been set.
"""
session = get_session()
if session is None:
_raise_accelerator_session_misuse()
if session.accelerator is not None:
raise RuntimeError("Cannot change accelerator once set.")
session.accelerator = accelerator
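# Illustrative sketch (not from the Ray sources) of the accelerator helpers above,
# assuming a hypothetical `MyAccelerator(Accelerator)` subclass. Inside a training
# function executed by a Trainer, either set an accelerator explicitly (at most
# once), or lazily construct a default on first access:
#
#     set_accelerator(MyAccelerator())
#     # ...or:
#     accelerator = get_accelerator(MyAccelerator)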
def _warn_session_misuse(default_value: Any = None):
"""Warns if fn is being used outside of session and returns ``default_value``."""
def inner(fn: Callable):
fn_name = fn.__name__
@functools.wraps(fn)
def wrapper(*args, **kwargs):
session = get_session()
if not session:
if log_once(f"{SESSION_MISUSE_LOG_ONCE_KEY}-{fn_name}"):
warnings.warn(
f"`{fn_name}` is meant to only be "
"called inside a function that is executed by a Tuner"
f" or Trainer. Returning `{default_value}`."
)
return default_value
return fn(*args, **kwargs)
return wrapper
return inner
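# Illustrative sketch (not from the Ray sources) of how `_warn_session_misuse` is
# applied, mirroring the public functions below (`_example_iteration` is
# hypothetical):
#
#     @_warn_session_misuse(default_value=0)
#     def _example_iteration() -> int:
#         # Only reached when a session exists; otherwise the wrapper warns once
#         # and returns `default_value` (0).
#         return get_session().iteration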
@PublicAPI(stability="stable")
@_warn_session_misuse()
def report(
metrics: Dict,
*,
checkpoint: Optional[Checkpoint] = None,
checkpoint_dir_name: Optional[str] = None,
) -> None:
"""Report metrics and optionally save a checkpoint.
If a checkpoint is provided, it will be
:ref:`persisted to storage <persistent-storage-guide>`.
If this is called in multiple distributed training workers:
- Only the metrics reported by the rank 0 worker will be tracked by Ray Train.
See :ref:`the metrics logging guide <train-monitoring-and-logging>`.
- A checkpoint will be registered as long as one or more workers report
a checkpoint that is not None.
See the :ref:`checkpointing guide <train-dl-saving-checkpoints>`.
- Checkpoints from multiple workers will be merged into one directory
in persistent storage.
See :ref:`the distributed checkpointing guide <train-distributed-checkpointing>`.
.. note::
Each invocation of this method will automatically increment the underlying
``training_iteration`` number. The physical meaning of this "iteration" is
defined by the user, depending on how often they call ``report``.
It does not necessarily map to one epoch.
.. warning::
All workers must call `ray.train.report` the same number of times
so that Ray Train can properly synchronize the training state across
workers. Otherwise, your training will hang.
.. warning::
This method does NOT act as a barrier for distributed training workers.
Workers will upload their checkpoint, then continue training immediately.
If you need to synchronize workers, you can use a framework-native barrier
such as `torch.distributed.barrier()`.
Example:
.. testcode::
import tempfile
from ray import train
from ray.train import Checkpoint
from ray.train.torch import TorchTrainer
def train_func(config):
start_epoch = 0
checkpoint = train.get_checkpoint()
if checkpoint:
with checkpoint.as_directory() as checkpoint_dir:
# Load back training state
...
for epoch in range(start_epoch, config.get("num_epochs", 10)):
# Do training...
metrics = {"loss": ...}
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
# Save the checkpoint...
# torch.save(...)
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
# Example: Only the rank 0 worker uploads the checkpoint.
if ray.train.get_context().get_world_rank() == 0:
train.report(metrics, checkpoint=checkpoint)
else:
train.report(metrics, checkpoint=None)
trainer = TorchTrainer(
train_func, scaling_config=train.ScalingConfig(num_workers=2)
)
Args:
metrics: The metrics you want to report.
checkpoint: The optional checkpoint you want to report.
checkpoint_dir_name: Custom name of the checkpoint directory. This is only
supported by the new Ray Train implementation (``RAY_TRAIN_V2_ENABLED=1``)
and is ignored otherwise.
"""
if checkpoint_dir_name is not None:
logger.warning(
"`checkpoint_dir_name` is only supported in the new Ray Train "
"implementation, which can be enabled with `RAY_TRAIN_V2_ENABLED=1`. "
"This argument will be ignored."
)
# If we are running in a Tune function, switch to `ray.tune.report`.
from ray.tune.trainable.trainable_fn_utils import _in_tune_session
if _in_tune_session():
import ray.tune
if _v2_migration_warnings_enabled():
_log_deprecation_warning(
"`ray.train.report` should be switched to "
"`ray.tune.report` when running in a function "
"passed to Ray Tune. This will be an error in the future. "
"See this issue for more context: "
"https://github.com/ray-project/ray/issues/49454"
)
return ray.tune.report(metrics, checkpoint=checkpoint)
get_session().report(metrics, checkpoint=checkpoint)
@PublicAPI(stability="stable")
@_warn_session_misuse()
def get_checkpoint() -> Optional[Checkpoint]:
"""Access the latest reported checkpoint to resume from if one exists.
Example:
.. testcode::
import tempfile
from ray import train
from ray.train import Checkpoint
from ray.train.torch import TorchTrainer
def train_func(config):
start_epoch = 0
checkpoint = train.get_checkpoint()
if checkpoint:
with checkpoint.as_directory() as checkpoint_dir:
# Load back training state
...
for epoch in range(start_epoch, config.get("num_epochs", 10)):
# Do training...
metrics = {"loss": ...}
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
# Save the checkpoint...
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
train.report(metrics, checkpoint=checkpoint)
trainer = TorchTrainer(
train_func, scaling_config=train.ScalingConfig(num_workers=2)
)
Returns:
Checkpoint object if the session is currently being resumed.
Otherwise, return None.
"""
# If we are running in a Tune function, switch to `ray.tune.get_checkpoint`.
from ray.tune.trainable.trainable_fn_utils import _in_tune_session
if _in_tune_session():
import ray.tune
if _v2_migration_warnings_enabled():
_log_deprecation_warning(
"`ray.train.get_checkpoint` should be switched to "
"`ray.tune.get_checkpoint` when running in a function "
"passed to Ray Tune. This will be an error in the future. "
"See this issue for more context: "
"https://github.com/ray-project/ray/issues/49454"
)
return ray.tune.get_checkpoint()
return get_session().loaded_checkpoint
@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_metadata() -> Dict[str, Any]:
"""User metadata dict passed to the Trainer constructor."""
return get_session().metadata
@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_experiment_name() -> str:
"""Experiment name for the corresponding trial."""
return get_session().experiment_name
@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_name() -> str:
"""Trial name for the corresponding trial."""
return get_session().trial_name
@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_id() -> str:
"""Trial id for the corresponding trial."""
return get_session().trial_id
@PublicAPI(stability="alpha")
@_warn_session_misuse()
def get_run_id() -> str:
"""Unique Train Run id for the corresponding trial."""
return get_session().run_id
@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_resources() -> "PlacementGroupFactory":
"""Trial resources for the corresponding trial."""
return get_session().trial_resources
@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_dir() -> str:
"""Log directory corresponding to the trial directory for a Tune session.
If calling from a Train session, this will give the trial directory of its parent
Tune session.
.. testcode::
import ray.tune
def train_func(config):
print(ray.tune.get_context().get_trial_dir())
tuner = ray.tune.Tuner(train_func)
tuner.fit()
.. testoutput::
:options: +MOCK
/Users/root/ray_results/train_func_2023-07-19_15-01-37/train_func_d620c_00000_0_2023-07-19_15-01-40
"""
return get_session().trial_dir
@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=1)
def get_world_size() -> int:
"""Get the current world size (i.e. total number of workers) for this run.
.. testcode::
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer
NUM_WORKERS = 2
def train_loop_per_worker(config):
assert train.get_context().get_world_size() == NUM_WORKERS
train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
trainer = TensorflowTrainer(
train_loop_per_worker,
scaling_config=ScalingConfig(num_workers=NUM_WORKERS),
datasets={"train": train_dataset}
)
trainer.fit()
.. testoutput::
:hide:
...
"""
session = get_session()
if not hasattr(session, "world_size"):
raise RuntimeError(
"`get_world_size` can only be called for TrainSession! "
"Make sure you only use that in `train_loop_per_worker` function"
"that is passed into `DataParallelTrainer`."
)
return session.world_size
@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_world_rank() -> int:
"""Get the world rank of this worker.
.. testcode::
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer
def train_loop_per_worker(config):
if train.get_context().get_world_rank() == 0:
print("Worker 0")
train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
trainer = TensorflowTrainer(
train_loop_per_worker,
scaling_config=ScalingConfig(num_workers=2),
datasets={"train": train_dataset}
)
trainer.fit()
.. testoutput::
:hide:
...
"""
session = get_session()
if not hasattr(session, "world_rank"):
raise RuntimeError(
"`get_world_rank` can only be called for TrainSession! "
"Make sure you only use that in `train_loop_per_worker` function"
"that is passed into `DataParallelTrainer`."
)
return session.world_rank
@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_local_rank() -> int:
"""Get the local rank of this worker (rank of the worker on its node).
.. testcode::
import torch
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
def train_loop_per_worker(config):
if torch.cuda.is_available():
torch.cuda.set_device(train.get_context().get_local_rank())
...
train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
trainer = TorchTrainer(
train_loop_per_worker,
scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
datasets={"train": train_dataset}
)
trainer.fit()
.. testoutput::
:hide:
...
"""
session = get_session()
if not hasattr(session, "local_rank"):
raise RuntimeError(
"`get_local_rank` can only be called for TrainSession! "
"Make sure you only use that in `train_loop_per_worker` function"
"that is passed into `DataParallelTrainer`."
)
return session.local_rank
@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_local_world_size() -> int:
"""Get the local world size of this node (i.e. number of workers on this node).
Example:
.. testcode::
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
def train_loop_per_worker():
print(train.get_context().get_local_world_size())
train_dataset = ray.data.from_items(
[{"x": x, "y": x + 1} for x in range(32)])
trainer = TorchTrainer(train_loop_per_worker,
scaling_config=ScalingConfig(num_workers=1),
datasets={"train": train_dataset})
trainer.fit()
.. testoutput::
:hide:
...
"""
session = get_session()
if not hasattr(session, "local_world_size"):
raise RuntimeError(
"`get_local_world_size` can only be called for TrainSession! "
"Make sure you only use that in `train_loop_per_worker` function"
"that is passed into `DataParallelTrainer`."
)
return session.local_world_size
@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_node_rank() -> int:
"""Get the rank of this node.
Example:
.. testcode::
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
def train_loop_per_worker():
print(train.get_context().get_node_rank())
train_dataset = ray.data.from_items(
[{"x": x, "y": x + 1} for x in range(32)])
trainer = TorchTrainer(train_loop_per_worker,
scaling_config=ScalingConfig(num_workers=1),
datasets={"train": train_dataset})
trainer.fit()
.. testoutput::
:hide:
...
"""
session = get_session()
if not hasattr(session, "node_rank"):
raise RuntimeError(
"`get_node_rank` can only be called for TrainSession! "
"Make sure you only use that in `train_loop_per_worker` function"
"that is passed into `DataParallelTrainer`."
)
return session.node_rank
@PublicAPI(stability="stable")
@_warn_session_misuse()
def get_dataset_shard(
dataset_name: Optional[str] = None,
) -> Optional["DataIterator"]:
"""Returns the :class:`ray.data.DataIterator` shard for this worker.
Call :meth:`~ray.data.DataIterator.iter_torch_batches` or
:meth:`~ray.data.DataIterator.to_tf` on this shard to convert it to the
appropriate framework-specific data type.
.. testcode::
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
def train_loop_per_worker(config):
...
for epoch in range(2):
# Trainer will automatically handle sharding.
data_shard = train.get_dataset_shard("train")
for batch in data_shard.iter_torch_batches():
...
train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
trainer = TorchTrainer(
train_loop_per_worker,
scaling_config=ScalingConfig(num_workers=2),
datasets={"train": train_dataset}
)
trainer.fit()
.. testoutput::
:hide:
...
Args:
dataset_name: If a Dictionary of Datasets was passed to ``Trainer``, then
specifies which dataset shard to return.
Returns:
The ``DataIterator`` shard to use for this worker.
If no dataset is passed into Trainer, then return None.
"""
session = get_session()
if not hasattr(session, "get_dataset_shard"):
raise RuntimeError(
"`get_dataset_shard` can only be called for TrainSession! "
"Make sure you only use that in `train_loop_per_worker` function"
"that is passed into `DataParallelTrainer`."
)
return session.get_dataset_shard(dataset_name)
@DeveloperAPI
@_warn_session_misuse()
def get_storage() -> StorageContext:
"""Returns the :class:`~ray.train._internal.storage.StorageContext` storage
context which gives advanced access to the filesystem and paths
configured through `RunConfig`.
NOTE: This is a developer API, and the `StorageContext` interface may change
without notice between minor versions.
"""
return get_session().storage
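# Illustrative sketch (not from the Ray sources) for `get_storage`, using only
# StorageContext attributes already referenced in this module
# (`trial_working_directory`, `checkpoint_dir_name`):
#
#     # Inside a training function executed by a Trainer:
#     storage = get_storage()
#     print(storage.trial_working_directory)  # shared working dir for the trial
#     print(storage.checkpoint_dir_name)      # latest checkpoint directory name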