# Source code for ray.tune.trainable.util

import inspect
import logging
import types
from typing import Any, Callable, Dict, Optional, Type, Union, TYPE_CHECKING

import ray
from ray.tune.execution.placement_groups import (
    PlacementGroupFactory,
    resource_dict_to_pg_factory,
)
from ray.tune.registry import _ParameterRegistry
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.tune.trainable import Trainable

logger = logging.getLogger(__name__)


@PublicAPI(stability="beta")
def with_parameters(trainable: Union[Type["Trainable"], Callable], **kwargs):
    """Wrapper for trainables to pass arbitrary large data objects.

    This wrapper function will store all passed parameters in the Ray
    object store and retrieve them when calling the function. It can thus
    be used to pass arbitrary data, even datasets, to Tune trainables.

    This can also be used as an alternative to ``functools.partial`` to pass
    default arguments to trainables.

    When used with the function API, the trainable function is called with
    the passed parameters as keyword arguments. When used with the class API,
    the ``Trainable.setup()`` method is called with the respective kwargs.

    If the data already exists in the object store (are instances of
    ObjectRef), using ``tune.with_parameters()`` is not necessary. You can
    instead pass the object refs to the training function via the ``config``
    or use Python partials.

    If ``trainable`` carries a ``_resources`` attribute (as set by
    ``tune.with_resources()``) or a ``__mixins__`` attribute, it is copied
    onto the returned trainable.

    Args:
        trainable: Trainable to wrap.
        **kwargs: parameters to store in object store.

    Returns:
        A new trainable of the same kind as the input: a function taking
        ``config`` for function trainables, or a subclass of the given class
        whose ``setup(config)`` forwards the stored parameters. Its
        ``__name__`` is copied from ``trainable`` (or set to
        ``"tune_with_parameters"`` if ``trainable`` has no ``__name__``).

    Raises:
        ValueError: If ``trainable`` is not callable, or is a class that does
            not inherit from ``tune.Trainable``.

    Function API example:

    .. code-block:: python

        from ray import train, tune

        def train_fn(config, data=None):
            for sample in data:
                loss = update_model(sample)
                train.report({"loss": loss})

        data = HugeDataset(download=True)

        tuner = tune.Tuner(
            tune.with_parameters(train_fn, data=data),
            # ...
        )
        tuner.fit()

    Class API example:

    .. code-block:: python

        from ray import tune

        class MyTrainable(tune.Trainable):
            def setup(self, config, data=None):
                self.data = data
                self.iter = iter(self.data)
                self.next_sample = next(self.iter)

            def step(self):
                loss = update_model(self.next_sample)
                try:
                    self.next_sample = next(self.iter)
                except StopIteration:
                    return {"loss": loss, "done": True}
                return {"loss": loss}

        data = HugeDataset(download=True)

        tuner = tune.Tuner(
            tune.with_parameters(MyTrainable, data=data),
            # ...
        )
    """
    # Local import: presumably avoids a circular import at module load time.
    from ray.tune.trainable import Trainable

    if not callable(trainable) or (
        inspect.isclass(trainable) and not issubclass(trainable, Trainable)
    ):
        raise ValueError(
            "`tune.with_parameters()` only works with function trainables "
            "or classes that inherit from `tune.Trainable()`. Got type: "
            f"{type(trainable)}."
        )

    parameter_registry = _ParameterRegistry()
    # Register the registry's flush as a Ray worker post-init hook
    # (presumably so values are pushed to the object store once Ray is up).
    ray._private.worker._post_init_hooks.append(parameter_registry.flush)

    # Objects are moved into the object store, namespaced by the trainable's
    # string representation so different wrapped trainables don't collide.
    prefix = f"{str(trainable)}_"
    for k, v in kwargs.items():
        parameter_registry.put(prefix + k, v)

    trainable_name = getattr(trainable, "__name__", "tune_with_parameters")
    keys = set(kwargs.keys())

    if inspect.isclass(trainable):
        # Class trainable: fetch the stored parameters and hand them to setup().
        class _Inner(trainable):
            def setup(self, config):
                setup_kwargs = {k: parameter_registry.get(prefix + k) for k in keys}
                super().setup(config, **setup_kwargs)

        trainable_with_params = _Inner
    else:
        # Function trainable: fetch the stored parameters and pass them as
        # keyword arguments alongside ``config``.
        def inner(config):
            fn_kwargs = {k: parameter_registry.get(prefix + k) for k in keys}
            return trainable(config, **fn_kwargs)

        trainable_with_params = inner

    if hasattr(trainable, "__mixins__"):
        trainable_with_params.__mixins__ = trainable.__mixins__

    # If the trainable has been wrapped with `tune.with_resources`, we should
    # keep the `_resources` attribute around
    if hasattr(trainable, "_resources"):
        trainable_with_params._resources = trainable._resources

    trainable_with_params.__name__ = trainable_name
    return trainable_with_params
@PublicAPI(stability="beta")
def with_resources(
    trainable: Union[Type["Trainable"], Callable],
    resources: Union[
        Dict[str, float],
        PlacementGroupFactory,
        Callable[[dict], PlacementGroupFactory],
    ],
):
    """Wrapper for trainables to specify resource requests.

    This wrapper allows specification of resource requirements for a specific
    trainable. It will override potential existing resource requests (use
    with caution!).

    The main use case is to request resources for function trainables when
    used with the Tuner() API.

    Class trainables should usually just implement the
    ``default_resource_request()`` method.

    Args:
        trainable: Trainable to wrap.
        resources: Resource dict, placement group factory, or callable that
            takes in a config dict and returns a placement group factory.

    Returns:
        For a class trainable, a new subclass whose
        ``default_resource_request(config)`` returns the requested resources.
        For a bound method, a new wrapper function carrying the request in a
        ``_resources`` attribute. For any other function, the same function
        object, with its ``_resources`` attribute set in place.

    Raises:
        ValueError: If ``trainable`` is not callable or is a class that does
            not inherit from ``tune.Trainable``, or if ``resources`` is not a
            dict, ``PlacementGroupFactory`` or callable.
        RuntimeError: If the ``_resources`` attribute cannot be set on the
            supplied function trainable.

    Example:

    .. code-block:: python

        import ray
        from ray import tune
        from ray.tune.tuner import Tuner

        def train_fn(config):
            return len(ray.get_gpu_ids())  # Returns 2

        tuner = Tuner(
            tune.with_resources(train_fn, resources={"gpu": 2}),
            # ...
        )
        results = tuner.fit()
    """
    # Local import: presumably avoids a circular import at module load time.
    from ray.tune.trainable import Trainable

    if not callable(trainable) or (
        inspect.isclass(trainable) and not issubclass(trainable, Trainable)
    ):
        raise ValueError(
            "`tune.with_resources()` only works with function trainables "
            "or classes that inherit from `tune.Trainable()`. Got type: "
            f"{type(trainable)}."
        )

    # The PlacementGroupFactory check must come before the generic
    # ``callable`` check so a factory is never treated as a config callback.
    if isinstance(resources, PlacementGroupFactory):
        pgf = resources
    elif isinstance(resources, dict):
        pgf = resource_dict_to_pg_factory(resources)
    elif callable(resources):
        # Stored as-is; evaluated later with the trial config.
        pgf = resources
    else:
        raise ValueError(
            f"Invalid resource type for `with_resources()`: {type(resources)}"
        )

    if inspect.isclass(trainable):

        class ResourceTrainable(trainable):
            @classmethod
            def default_resource_request(
                cls, config: Dict[str, Any]
            ) -> Optional[PlacementGroupFactory]:
                if not isinstance(pgf, PlacementGroupFactory) and callable(pgf):
                    return pgf(config)
                return pgf

        ResourceTrainable.__name__ = trainable.__name__
        return ResourceTrainable

    if isinstance(trainable, types.MethodType):
        # Methods cannot set arbitrary attributes, so we have to wrap them
        def _trainable(config):
            return trainable(config)

        _trainable._resources = pgf
        return _trainable

    # Just set an attribute. This will be resolved later in `wrap_function()`.
    try:
        trainable._resources = pgf
    except AttributeError as e:
        raise RuntimeError(
            "Could not use `tune.with_resources()` on the supplied trainable. "
            "Wrap your trainable in a regular function before passing it "
            "to Ray Tune."
        ) from e
    return trainable