import inspect
import logging
import os
import queue
from functools import partial
from numbers import Number
from typing import Any, Callable, Dict, Optional, Type
from ray.air._internal.util import RunnerThread, StartTraceback
from ray.air.constants import _ERROR_FETCH_TIMEOUT
from ray.train._internal.checkpoint_manager import _TrainingResult
from ray.train._internal.session import (
TrialInfo,
_TrainSession,
get_session,
init_session,
shutdown_session,
)
from ray.train.v2._internal.constants import RUN_CONTROLLER_AS_ACTOR_ENV_VAR
from ray.tune.execution.placement_groups import PlacementGroupFactory
from ray.tune.result import DEFAULT_METRIC, RESULT_DUPLICATE, SHOULD_CHECKPOINT
from ray.tune.trainable.trainable import Trainable
from ray.tune.utils import _detect_config_single
from ray.util.annotations import DeveloperAPI
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): the comment below describes a polling interval, but no such
# constant is defined in the visible part of this module -- it looks like a
# leftover from an earlier implementation; confirm before removing.
# Time between FunctionTrainable checks when fetching
# new results after signaling the reporter to continue
# NOTE(review): presumably marker file names. Neither constant is referenced
# in the visible part of this module; confirm whether external code still
# imports them.
NULL_MARKER = ".null_marker"
TEMP_MARKER = ".temp_marker"
[docs]
@DeveloperAPI
class FunctionTrainable(Trainable):
    """Trainable that runs a user function reporting results.

    The user function is executed through a train session that is created in
    ``setup``. Each call to ``step`` returns the metrics of the next result
    produced by that session.

    Checkpoints are not written by this class itself: ``save_checkpoint``
    returns the most recent ``_TrainingResult`` seen by ``step``, and
    ``load_checkpoint`` hands the checkpoint of a ``_TrainingResult`` to the
    session.
    """

    _name = "func"

    def setup(self, config):
        """Initialize the train session that will run the training function.

        Args:
            config: Trial configuration. Not read here; the training function
                is invoked with ``self.config``.
        """
        init_session(
            training_func=lambda: self._trainable_func(self.config),
            trial_info=TrialInfo(
                name=self.trial_name,
                id=self.trial_id,
                resources=self.trial_resources,
                logdir=self._storage.trial_driver_staging_path,
                driver_ip=None,
                driver_node_id=None,
                experiment_name=self._storage.experiment_dir_name,
            ),
            storage=self._storage,
            synchronous_result_reporting=True,
            # Set all Train-specific properties to None.
            world_rank=None,
            local_rank=None,
            node_rank=None,
            local_world_size=None,
            world_size=None,
            dataset_shard=None,
            checkpoint=None,
        )
        self._last_training_result: Optional[_TrainingResult] = None

        # NOTE: This environment variable is used to disable the
        # spawning a new actor for Ray Train drivers being launched
        # within Tune functions.
        # There are 2 reasons for this:
        # 1. Ray Tune already spawns an actor, so we can run the Ray Train
        #    driver directly in the same actor.
        # 2. This allows `ray.tune.report` to be called within Ray Train driver
        #    callbacks, since it needs to be called on the same process as the
        #    Tune FunctionTrainable actor.
        os.environ[RUN_CONTROLLER_AS_ACTOR_ENV_VAR] = "0"

    def _trainable_func(self, config: Dict[str, Any]):
        """Subclasses can override this to set the trainable func."""
        raise NotImplementedError

    def _start(self):
        """Create a ``RunnerThread`` for the training function and start it.

        NOTE(review): ``self._error_queue`` and ``self._status_reporter`` are
        not assigned anywhere in this class; this method looks like a leftover
        of an older execution path -- confirm before relying on it.
        """

        def entrypoint():
            try:
                return self._trainable_func(self.config)
            except Exception as e:
                # Wrap the failure so it is chained to a StartTraceback.
                raise StartTraceback from e

        # the runner thread is not started until the first call to _train
        self._runner = RunnerThread(
            target=entrypoint, error_queue=self._error_queue, daemon=True
        )
        # if not alive, try to start
        self._status_reporter._start()
        try:
            self._runner.start()
        except RuntimeError:
            # If this is reached, it means the thread was started and is
            # now done or has raised an exception.
            pass

    def step(self):
        """Implements train() for a Function API.

        If the RunnerThread finishes without reporting "done",
        Tune will automatically provide a magic keyword __duplicate__
        along with a result with "done=True". The TrialRunner will handle the
        result accordingly (see tune/tune_controller.py).

        Returns:
            The metrics dict of the next training result, with
            ``SHOULD_CHECKPOINT`` set when applicable.

        Raises:
            RuntimeError: If the session has no further result to return.
        """
        session: _TrainSession = get_session()
        if not session.training_started:
            session.start()

        training_result: Optional[_TrainingResult] = session.get_next()

        if not training_result:
            # The `RESULT_DUPLICATE` result should have been the last
            # result reported by the session, which triggers cleanup.
            raise RuntimeError(
                "Should not have reached here. The TuneController should not "
                "have scheduled another `train` remote call. "
                "It should have scheduled a `stop` instead "
                "after the training function exits."
            )

        metrics = training_result.metrics

        # This keyword appears if the train_func using the Function API
        # finishes without "done=True". This duplicates the last result, but
        # the TuneController will not log this result again.
        # TuneController will also inject done=True to the result,
        # and proceed to queue up a STOP decision for the trial.
        if RESULT_DUPLICATE in metrics:
            metrics[SHOULD_CHECKPOINT] = False

        self._last_training_result = training_result
        if training_result.checkpoint is not None:
            # TODO(justinvyu): Result/checkpoint reporting can be combined.
            # For now, since result/checkpoint reporting is separate, this
            # special key will tell Tune to pull the checkpoint from
            # the `last_training_result`.
            metrics[SHOULD_CHECKPOINT] = True

        return metrics

    def execute(self, fn):
        """Call ``fn`` with this trainable as its only argument and return the result."""
        return fn(self)

    def save_checkpoint(self, checkpoint_dir: str = ""):
        """Return the most recent training result recorded by ``step``.

        Args:
            checkpoint_dir: Must be empty; a non-empty value is rejected.

        Returns:
            The last ``_TrainingResult`` stored by ``step``, or ``None`` if
            ``step`` has not stored one yet.

        Raises:
            ValueError: If ``checkpoint_dir`` is non-empty.
        """
        if checkpoint_dir:
            raise ValueError("Checkpoint dir should not be used with function API.")

        # TODO(justinvyu): This currently breaks the `save_checkpoint` interface.
        # TRAIN -> SAVE remote calls get processed sequentially,
        # so `_last_training_result.checkpoint` holds onto the latest ckpt.
        return self._last_training_result

    def load_checkpoint(self, checkpoint_result: _TrainingResult):
        """Store the checkpoint of ``checkpoint_result`` on the session.

        Args:
            checkpoint_result: Training result whose ``checkpoint`` attribute
                is assigned to ``session.loaded_checkpoint``.
        """
        # TODO(justinvyu): This currently breaks the `load_checkpoint` interface.
        session = get_session()
        session.loaded_checkpoint = checkpoint_result.checkpoint

    def cleanup(self):
        """Finish the session, surface pending errors, and shut the session down.

        ``shutdown_session`` is always called, even when finishing the session
        or the follow-up error check raises.
        """
        session = get_session()
        try:
            # session.finish raises any Exceptions from training.
            # Do not wait for thread termination here (timeout=0).
            session.finish(timeout=0)
        finally:
            try:
                # Check for any errors that might have been missed.
                session._report_thread_runner_error()
            finally:
                # Shutdown session even if session.finish() or the error
                # check above raises an Exception.
                shutdown_session()

    def reset_config(self, new_config):
        """Reset the session so this actor can be re-used for another trial.

        Waits for the training thread to finish, for at most the number given
        by the ``TUNE_FUNCTION_THREAD_TIMEOUT_S`` environment variable
        (default ``2``).

        Args:
            new_config: Not read here; the training function of the reset
                session is invoked with ``self.config``.

        Returns:
            ``False`` if the training thread is still alive after the timeout
            (the session is left untouched), ``True`` after a successful reset.
        """
        session = get_session()

        # Wait for thread termination so it is safe to re-use the same actor.
        thread_timeout = int(os.environ.get("TUNE_FUNCTION_THREAD_TIMEOUT_S", 2))
        session.finish(timeout=thread_timeout)
        if session.training_thread.is_alive():
            # Did not finish within timeout, reset unsuccessful.
            return False

        # NOTE(review): `setup` passes `trial_driver_staging_path` as logdir,
        # while this reset passes `trial_working_directory` -- confirm that
        # the difference is intended.
        session.reset(
            training_func=lambda: self._trainable_func(self.config),
            trial_info=TrialInfo(
                name=self.trial_name,
                id=self.trial_id,
                resources=self.trial_resources,
                logdir=self._storage.trial_working_directory,
                driver_ip=None,
                driver_node_id=None,
                experiment_name=self._storage.experiment_dir_name,
            ),
            storage=self._storage,
        )

        self._last_result = {}
        return True

    def _report_thread_runner_error(self, block=False):
        """Raise an error taken from ``self._error_queue``, if one is queued.

        Args:
            block: Passed to ``queue.get``; when true, waits up to
                ``_ERROR_FETCH_TIMEOUT`` for an error to arrive.

        Raises:
            StartTraceback: Chained from the queued error, if there is one.
        """
        try:
            e = self._error_queue.get(block=block, timeout=_ERROR_FETCH_TIMEOUT)
            raise StartTraceback from e
        except queue.Empty:
            # Nothing queued: no error to report.
            pass
[docs]
@DeveloperAPI
def wrap_function(
    train_func: Callable[[Any], Any], name: Optional[str] = None
) -> Type["FunctionTrainable"]:
    """Wrap a training function in a ``FunctionTrainable`` subclass.

    Args:
        train_func: Function or generator function taking a single ``config``
            argument. Two optional attributes are honored: ``__mixins__``
            (a tuple of extra base classes, placed before
            ``FunctionTrainable``) and ``_resources`` (a resource request, or
            a callable that produces one from the config).
        name: Name for the generated class. Defaults to
            ``train_func.__name__``, or ``"func"`` if that is unavailable.

    Returns:
        A new trainable class that runs ``train_func`` and reports every
        returned or yielded value through the session.

    Raises:
        ValueError: If ``_detect_config_single`` rejects the signature of
            ``train_func``.
    """
    inherit_from = (FunctionTrainable,)

    if hasattr(train_func, "__mixins__"):
        inherit_from = train_func.__mixins__ + inherit_from

    if not _detect_config_single(train_func):
        # The argument names are only needed for the error message, so they
        # are looked up here rather than for every wrapped function.
        func_args = inspect.getfullargspec(train_func).args
        raise ValueError(
            "Unknown argument found in the Trainable function. "
            "The function args must include a single 'config' positional parameter.\n"
            "Found: {}".format(func_args)
        )

    resources = getattr(train_func, "_resources", None)

    class ImplicitFunc(*inherit_from):
        _name = name or getattr(train_func, "__name__", "func")

        def __repr__(self):
            return self._name

        def _trainable_func(self, config):
            """Run ``train_func(config)`` and report its output(s)."""
            fn = partial(train_func, config)

            def handle_output(output):
                """Report one returned/yielded value, if there is one."""
                if isinstance(output, Number):
                    # Checked before the emptiness test so that a metric
                    # value of 0 is reported instead of silently dropped.
                    get_session().report({DEFAULT_METRIC: output})
                elif not output:
                    # None, an empty dict, or another empty value:
                    # nothing to report.
                    return
                elif isinstance(output, dict):
                    get_session().report(output)
                else:
                    raise ValueError(
                        "Invalid return or yield value. Either return/yield "
                        "a single number or a dictionary object in your "
                        "trainable function."
                    )

            output = None
            if inspect.isgeneratorfunction(train_func):
                for output in fn():
                    handle_output(output)
            else:
                output = fn()
                handle_output(output)

            # If train_func returns, we need to notify the main event loop
            # of the last result while avoiding double logging. This is done
            # with the keyword RESULT_DUPLICATE -- see tune/tune_controller.py.
            get_session().report({RESULT_DUPLICATE: True})
            return output

        @classmethod
        def default_resource_request(
            cls, config: Dict[str, Any]
        ) -> Optional[PlacementGroupFactory]:
            """Return the resource request attached to ``train_func``, if any."""
            # A non-PlacementGroupFactory callable is treated as a factory
            # that builds the request from the config.
            if not isinstance(resources, PlacementGroupFactory) and callable(resources):
                return resources(config)
            return resources

    return ImplicitFunc