# Source code for ray.train.v2.api.context

from ray.util.annotations import PublicAPI, Deprecated, DeveloperAPI
from ray.train.v2._internal.execution.context import (
    get_train_context as get_internal_train_context,
)
from typing import Any, Dict


def _tune_specific_deprecation(method_name: str) -> DeprecationWarning:
    """Build the error raised by the Tune-specific ``TrainContext`` accessors.

    Args:
        method_name: Name of the deprecated accessor. It is substituted into
            the shared deprecation message template.

    Returns:
        A ``DeprecationWarning`` instance for the calling accessor to raise.
    """
    # Deferred import: ``ray.train.context`` is only imported from here when
    # one of the deprecated accessors is actually called.
    from ray.train.context import _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE

    message = _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format(method_name)
    return DeprecationWarning(message)


@PublicAPI(stability="stable")
class TrainContext:
    """Per-worker view of the current Ray Train run.

    The supported accessors (experiment name, world/local/node ranks and
    sizes, storage) all delegate to the internal train context returned by
    ``get_internal_train_context()``. The Tune-era accessors kept for
    backwards compatibility are deprecated and raise ``DeprecationWarning``
    whenever they are called.
    """

    @Deprecated
    def get_metadata(self) -> Dict[str, Any]:
        """[Deprecated] User metadata dict that was passed to the Trainer constructor.

        Calling this always raises ``DeprecationWarning``.
        """
        from ray.train.context import (
            _GET_METADATA_DEPRECATION_MESSAGE as deprecation_message,
        )

        raise DeprecationWarning(deprecation_message)

    @Deprecated
    def get_trial_name(self) -> str:
        """[Deprecated] Name of the Tune trial this worker belongs to.

        Calling this always raises ``DeprecationWarning``.
        """
        raise _tune_specific_deprecation("get_trial_name")

    @Deprecated
    def get_trial_id(self) -> str:
        """[Deprecated] ID of the Tune trial this worker belongs to.

        Calling this always raises ``DeprecationWarning``.
        """
        raise _tune_specific_deprecation("get_trial_id")

    @Deprecated
    def get_trial_resources(self):
        """[Deprecated] Resources allocated to the Tune trial this worker belongs to.

        Calling this always raises ``DeprecationWarning``.
        """
        raise _tune_specific_deprecation("get_trial_resources")

    @Deprecated
    def get_trial_dir(self) -> str:
        """[Deprecated] Log directory of the trial in a Tune session.

        Ray Train workers should no longer call this. If the directory is
        needed, pass it into the `train_loop_config` directly instead.
        Calling this always raises ``DeprecationWarning``.
        """
        raise _tune_specific_deprecation("get_trial_dir")

    def get_experiment_name(self) -> str:
        """Return the experiment name for the corresponding trial."""
        internal_context = get_internal_train_context()
        return internal_context.get_experiment_name()

    def get_world_size(self) -> int:
        """Return the world size, i.e. the total number of workers in this run.

        .. testcode::

            import ray
            from ray import train
            from ray.train import ScalingConfig
            from ray.train.tensorflow import TensorflowTrainer

            NUM_WORKERS = 2

            def train_loop_per_worker(config):
                assert train.get_context().get_world_size() == NUM_WORKERS

            trainer = TensorflowTrainer(
                train_loop_per_worker,
                scaling_config=ScalingConfig(num_workers=NUM_WORKERS),
            )
            trainer.fit()

        .. testoutput::
            :hide:

            ...
        """
        internal_context = get_internal_train_context()
        return internal_context.get_world_size()

    def get_world_rank(self) -> int:
        """Return the world rank of this worker.

        .. testcode::

            import ray
            from ray import train
            from ray.train import ScalingConfig
            from ray.train.tensorflow import TensorflowTrainer

            def train_loop_per_worker(config):
                if train.get_context().get_world_rank() == 0:
                    print("Worker 0")

            trainer = TensorflowTrainer(
                train_loop_per_worker,
                scaling_config=ScalingConfig(num_workers=2),
            )
            trainer.fit()

        .. testoutput::
            :hide:

            ...
        """
        internal_context = get_internal_train_context()
        return internal_context.get_world_rank()

    def get_local_rank(self) -> int:
        """Return the local rank of this worker, i.e. its rank on its own node.

        .. testcode::

            import torch

            import ray
            from ray import train
            from ray.train import ScalingConfig
            from ray.train.torch import TorchTrainer

            def train_loop_per_worker(config):
                if torch.cuda.is_available():
                    torch.cuda.set_device(train.get_context().get_local_rank())
                ...

            trainer = TorchTrainer(
                train_loop_per_worker,
                scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
            )
            trainer.fit()

        .. testoutput::
            :hide:

            ...
        """
        internal_context = get_internal_train_context()
        return internal_context.get_local_rank()

    def get_local_world_size(self) -> int:
        """Return the local world size, i.e. the number of workers on this node.

        Example:

            .. testcode::

                import ray
                from ray import train
                from ray.train import ScalingConfig
                from ray.train.torch import TorchTrainer

                def train_loop_per_worker():
                    print(train.get_context().get_local_world_size())

                trainer = TorchTrainer(
                    train_loop_per_worker,
                    scaling_config=ScalingConfig(num_workers=1),
                )
                trainer.fit()

            .. testoutput::
                :hide:

                ...
        """
        internal_context = get_internal_train_context()
        return internal_context.get_local_world_size()

    def get_node_rank(self) -> int:
        """Return the rank of the node this worker runs on.

        Example:

            .. testcode::

                import ray
                from ray import train
                from ray.train import ScalingConfig
                from ray.train.torch import TorchTrainer

                def train_loop_per_worker():
                    print(train.get_context().get_node_rank())

                trainer = TorchTrainer(
                    train_loop_per_worker,
                    scaling_config=ScalingConfig(num_workers=1),
                )
                trainer.fit()

            .. testoutput::
                :hide:

                ...
        """
        internal_context = get_internal_train_context()
        return internal_context.get_node_rank()

    @DeveloperAPI
    def get_storage(self):
        """Return the :class:`~ray.train._internal.storage.StorageContext`.

        The storage context gives advanced access to the filesystem and
        paths configured through `RunConfig`.

        NOTE: This is a developer API. The `StorageContext` interface may
        change without notice between minor versions.
        """
        internal_context = get_internal_train_context()
        return internal_context.get_storage()