Source code for ray.air.session

from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
import warnings
import functools

from ray.air._internal.session import _get_session
from ray.air.checkpoint import Checkpoint
from ray.air.constants import SESSION_MISUSE_LOG_ONCE_KEY
from ray.train.session import _TrainSessionImpl
from ray.util import log_once
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.data import DatasetIterator
    from ray.tune.execution.placement_groups import PlacementGroupFactory


def _warn_session_misuse(default_value: Any = None):
    """Warns if fn is being used outside of session and returns ``default_value``."""

    def inner(fn: Callable):
        fn_name = fn.__name__

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            session = _get_session()
            if not session:
                if log_once(f"{SESSION_MISUSE_LOG_ONCE_KEY}-{fn_name}"):
                    warnings.warn(
                        f"`{fn_name}` is meant to only be "
                        "called inside a function that is executed by a Tuner"
                        f" or Trainer. Returning `{default_value}`."
                    )
                return default_value
            return fn(*args, **kwargs)

        return wrapper

    return inner
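

# A minimal usage sketch of the decorator above (illustrative only, not part
# of this module): outside of a Tuner/Trainer there is no active session, so
# the wrapped function warns once and returns the default value. The helper
# name below is hypothetical.
#
#     @_warn_session_misuse(default_value=-1)
#     def example_rank() -> int:
#         return _get_session().world_rank
#
#     example_rank()  # outside a session: warns once, returns -1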


@PublicAPI(stability="beta")
@_warn_session_misuse()
def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
    """Report metrics and optionally save a checkpoint.

    Each invocation of this method will automatically increment the underlying
    iteration number. The physical meaning of this "iteration" is defined by
    the user (or, more specifically, by the way they call ``report``). It does
    not necessarily map to one epoch.

    This API is the canonical way to report metrics from Tune and Train, and
    replaces the legacy ``tune.report``, ``with tune.checkpoint_dir``,
    ``train.report`` and ``train.save_checkpoint`` calls.

    Note on directory checkpoints: AIR will take ownership of checkpoints
    passed to ``report()`` by moving them to a new path. The original
    directory will no longer be accessible to the caller after the report
    call.

    Example:
        .. code-block:: python

            from ray.air import session
            from ray.air.checkpoint import Checkpoint
            from ray.air.config import ScalingConfig
            from ray.train.tensorflow import TensorflowTrainer

            ######## Using it in the *per worker* train loop (TrainSession) #######
            def train_func():
                model = build_model()
                model.save("my_model", overwrite=True)
                session.report(
                    metrics={"foo": "bar"},
                    checkpoint=Checkpoint.from_directory("my_model")
                )
                # AIR guarantees that by this point, you can safely write new
                # content to the "my_model" directory.

            scaling_config = ScalingConfig(num_workers=2)
            trainer = TensorflowTrainer(
                train_loop_per_worker=train_func, scaling_config=scaling_config
            )
            result = trainer.fit()
            # If you navigate to result.checkpoint's path, you will find the
            # content of ``model.save()`` under it.
            # If you have `SyncConfig` configured, the content should also
            # show up in the corresponding cloud storage path.

    Args:
        metrics: The metrics you want to report.
        checkpoint: The optional checkpoint you want to report.
    """
    _get_session().report(metrics, checkpoint=checkpoint)
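

# A minimal sketch (hypothetical objective function) of ``report`` used from
# the Tune side, where it replaces the legacy ``tune.report`` call:
#
#     from ray import tune
#     from ray.air import session
#
#     def objective(config):
#         for step in range(10):
#             # Each call increments the session's iteration counter.
#             session.report({"loss": config["lr"] * step})
#
#     tune.Tuner(objective, param_space={"lr": 0.1}).fit()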


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_checkpoint() -> Optional[Checkpoint]:
    """Access the session's last checkpoint to resume from if applicable.

    Returns:
        Checkpoint object if the session is currently being resumed.
        Otherwise, return None.

    .. code-block:: python

        ######## Using it in the *per worker* train loop (TrainSession) ######
        from ray.air import session
        from ray.air.checkpoint import Checkpoint
        from ray.air.config import ScalingConfig
        from ray.train.tensorflow import TensorflowTrainer

        def train_func():
            ckpt = session.get_checkpoint()
            if ckpt:
                with ckpt.as_directory() as loaded_checkpoint_dir:
                    import tensorflow as tf

                    model = tf.keras.models.load_model(loaded_checkpoint_dir)
            else:
                model = build_model()

            model.save("my_model", overwrite=True)
            session.report(
                metrics={"iter": 1},
                checkpoint=Checkpoint.from_directory("my_model")
            )

        scaling_config = ScalingConfig(num_workers=2)
        trainer = TensorflowTrainer(
            train_loop_per_worker=train_func, scaling_config=scaling_config
        )
        result = trainer.fit()

        # trainer2 will pick up from the checkpoint saved by the first trainer.
        trainer2 = TensorflowTrainer(
            train_loop_per_worker=train_func,
            scaling_config=scaling_config,
            # this is ultimately what is accessed through
            # ``Session.get_checkpoint()``
            resume_from_checkpoint=result.checkpoint,
        )
        result2 = trainer2.fit()
    """
    return _get_session().loaded_checkpoint


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_experiment_name() -> str:
    """Experiment name for the corresponding trial."""
    return _get_session().experiment_name


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_name() -> str:
    """Trial name for the corresponding trial."""
    return _get_session().trial_name


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_id() -> str:
    """Trial id for the corresponding trial."""
    return _get_session().trial_id


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_resources() -> "PlacementGroupFactory":
    """Trial resources for the corresponding trial."""
    return _get_session().trial_resources
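

# A minimal sketch (hypothetical trainable) combining the four trial-metadata
# getters above; all of them require an active Tune/Train session, and the
# values shown in the comments are illustrative:
#
#     from ray import tune
#     from ray.air import session
#
#     def trainable(config):
#         print(session.get_experiment_name())  # e.g. "trainable_2023-01-01_00-00-00"
#         print(session.get_trial_name())       # e.g. "trainable_abc12_00000"
#         print(session.get_trial_id())         # e.g. "abc12_00000"
#         print(session.get_trial_resources())  # a PlacementGroupFactory
#         session.report({"score": 1.0})
#
#     tune.Tuner(trainable).fit()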


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_dir() -> str:
    """Log directory corresponding to the trial directory for a Tune session.

    If calling from a Train session, this will give the trial directory of its
    parent Tune session.

    .. code-block:: python

        from ray import tune
        from ray.air import session

        def train_func():
            # Example:
            # >>> session.get_trial_dir()
            # ~/ray_results/<exp-name>/<trial-dir>
            ...

        tuner = tune.Tuner(train_func)
        tuner.fit()
    """
    return _get_session().trial_dir


@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=1)
def get_world_size() -> int:
    """Get the current world size (i.e. total number of workers) for this run.

    .. code-block:: python

        import ray
        from ray.air import session
        from ray.air.config import ScalingConfig
        from ray.train.tensorflow import TensorflowTrainer

        def train_loop_per_worker(config):
            assert session.get_world_size() == 4

        train_dataset = ray.data.from_items(
            [{"x": x, "y": x + 1} for x in range(32)])
        trainer = TensorflowTrainer(train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=4),
            datasets={"train": train_dataset})
        trainer.fit()
    """
    session = _get_session()
    if not isinstance(session, _TrainSessionImpl):
        raise RuntimeError(
            "`get_world_size` can only be called for TrainSession! "
            "Make sure you only use that in the `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.world_size


@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_world_rank() -> int:
    """Get the world rank of this worker.

    .. code-block:: python

        import time

        import ray
        from ray.air import session
        from ray.air.config import ScalingConfig
        from ray.train.tensorflow import TensorflowTrainer

        def train_loop_per_worker():
            for iter in range(100):
                time.sleep(1)
                if session.get_world_rank() == 0:
                    print("Worker 0")

        train_dataset = ray.data.from_items(
            [{"x": x, "y": x + 1} for x in range(32)])
        trainer = TensorflowTrainer(train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=1),
            datasets={"train": train_dataset})
        trainer.fit()
    """
    session = _get_session()
    if not isinstance(session, _TrainSessionImpl):
        raise RuntimeError(
            "`get_world_rank` can only be called for TrainSession! "
            "Make sure you only use that in the `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.world_rank


@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_local_rank() -> int:
    """Get the local rank of this worker (rank of the worker on its node).

    .. code-block:: python

        import ray
        import torch
        from ray.air import session
        from ray.air.config import ScalingConfig
        from ray.train.torch import TorchTrainer

        def train_loop_per_worker():
            if torch.cuda.is_available():
                torch.cuda.set_device(session.get_local_rank())
            ...

        train_dataset = ray.data.from_items(
            [{"x": x, "y": x + 1} for x in range(32)])
        trainer = TorchTrainer(train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=1),
            datasets={"train": train_dataset})
        trainer.fit()
    """
    session = _get_session()
    if not isinstance(session, _TrainSessionImpl):
        raise RuntimeError(
            "`get_local_rank` can only be called for TrainSession! "
            "Make sure you only use that in the `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.local_rank


@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_local_world_size() -> int:
    """Get the local world size of this node (i.e. number of workers on this node).

    Example:
        >>> import ray
        >>> from ray.air import session
        >>> from ray.air.config import ScalingConfig
        >>> from ray.train.torch import TorchTrainer
        >>>
        >>> def train_loop_per_worker():
        ...     return session.get_local_world_size()
        >>>
        >>> train_dataset = ray.data.from_items(
        ...     [{"x": x, "y": x + 1} for x in range(32)])
        >>> trainer = TorchTrainer(train_loop_per_worker,
        ...     scaling_config=ScalingConfig(num_workers=1),
        ...     datasets={"train": train_dataset})
        >>> trainer.fit()  # doctest: +SKIP
    """
    session = _get_session()
    if not isinstance(session, _TrainSessionImpl):
        raise RuntimeError(
            "`get_local_world_size` can only be called for TrainSession! "
            "Make sure you only use that in the `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.local_world_size


@PublicAPI(stability="beta")
@_warn_session_misuse(default_value=0)
def get_node_rank() -> int:
    """Get the rank of the node this worker runs on.

    Example:
        >>> import ray
        >>> from ray.air import session
        >>> from ray.air.config import ScalingConfig
        >>> from ray.train.torch import TorchTrainer
        >>>
        >>> def train_loop_per_worker():
        ...     return session.get_node_rank()
        >>>
        >>> train_dataset = ray.data.from_items(
        ...     [{"x": x, "y": x + 1} for x in range(32)])
        >>> trainer = TorchTrainer(train_loop_per_worker,
        ...     scaling_config=ScalingConfig(num_workers=1),
        ...     datasets={"train": train_dataset})
        >>> trainer.fit()  # doctest: +SKIP
    """
    session = _get_session()
    if not isinstance(session, _TrainSessionImpl):
        raise RuntimeError(
            "`get_node_rank` can only be called for TrainSession! "
            "Make sure you only use that in the `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.node_rank


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_dataset_shard(
    dataset_name: Optional[str] = None,
) -> Optional["DatasetIterator"]:
    """Returns the :class:`ray.data.DatasetIterator` shard for this worker.

    Call :meth:`~ray.data.DatasetIterator.iter_torch_batches` or
    :meth:`~ray.data.DatasetIterator.to_tf` on this shard to convert it to the
    appropriate framework-specific data type.

    .. code-block:: python

        import ray
        from ray.air import session
        from ray.air.config import ScalingConfig
        from ray.train.torch import TorchTrainer

        def train_loop_per_worker():
            model = Net()
            for iter in range(100):
                # Trainer will automatically handle sharding.
                data_shard = session.get_dataset_shard("train")
                for batch in data_shard.iter_torch_batches():
                    ...  # train on the batch
            return model

        train_dataset = ray.data.from_items(
            [{"x": x, "y": x + 1} for x in range(32)])
        trainer = TorchTrainer(train_loop_per_worker,
            scaling_config=ScalingConfig(num_workers=2),
            datasets={"train": train_dataset})
        trainer.fit()

    Args:
        dataset_name: If a Dictionary of Datasets was passed to ``Trainer``,
            then specifies which dataset shard to return.

    Returns:
        The ``DatasetIterator`` shard to use for this worker.
        If no dataset is passed into Trainer, then return None.
    """
    session = _get_session()
    if not isinstance(session, _TrainSessionImpl):
        raise RuntimeError(
            "`get_dataset_shard` can only be called for TrainSession! "
            "Make sure you only use that in the `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.get_dataset_shard(dataset_name)
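

# A minimal sketch (hypothetical train loop) of guarding against the ``None``
# return value when the Trainer was constructed without any datasets:
#
#     def train_loop_per_worker():
#         shard = session.get_dataset_shard("train")
#         if shard is None:
#             raise ValueError("Expected a 'train' dataset on the Trainer.")
#         for batch in shard.iter_batches(batch_size=32):
#             ...  # batches arrive as framework-agnostic dicts of arrays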