ray.tune.with_parameters
ray.tune.with_parameters(trainable: Type[Trainable] | Callable, **kwargs)
Wrapper for trainables to pass arbitrarily large data objects.
This wrapper function stores all passed parameters in the Ray object store and retrieves them when the trainable is called. It can thus be used to pass arbitrary data, even datasets, to Tune trainables.
This can also be used as an alternative to functools.partial to pass default arguments to trainables.
When used with the function API, the trainable function is called with the passed parameters as keyword arguments. When used with the class API, the Trainable.setup() method is called with the respective kwargs.
If the data already exists in the object store (that is, the values are instances of ObjectRef), using tune.with_parameters() is not necessary. You can instead pass the object refs to the training function via the config or use Python partials.

Parameters:
- trainable – Trainable to wrap.
- **kwargs – Parameters to store in the object store.
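The function-versus-class behavior described above can be sketched in plain Python as follows. This is a toy illustration under stated assumptions, not Ray's implementation: the Trainable base class here is a hypothetical stand-in whose constructor simply calls setup(config), with_parameters_sketch is a hypothetical name, and the object-store step is omitted for brevity.

```python
import functools
import inspect


class Trainable:
    """Hypothetical stand-in for tune.Trainable: the constructor calls setup(config)."""

    def __init__(self, config):
        self.setup(config)

    def setup(self, config):
        pass


def with_parameters_sketch(trainable, **kwargs):
    if inspect.isclass(trainable) and issubclass(trainable, Trainable):
        # Class API: subclass the trainable and forward the kwargs to setup().
        class _Wrapped(trainable):
            def setup(self, config):
                super().setup(config, **kwargs)

        _Wrapped.__name__ = trainable.__name__
        return _Wrapped

    # Function API: forward the kwargs to the function call.
    @functools.wraps(trainable)
    def wrapped(config):
        return trainable(config, **kwargs)

    return wrapped


class MyTrainable(Trainable):
    def setup(self, config, data=None):
        self.total = sum(data) * config["scale"]


def train_fn(config, data=None):
    return len(data)


instance = with_parameters_sketch(MyTrainable, data=[1, 2, 3])({"scale": 2})
print(instance.total)  # 12
print(with_parameters_sketch(train_fn, data=[1, 2])({}))  # 2
# For function trainables, functools.partial binds the same keyword argument,
# but the partial object holds the data itself rather than a reference to it.
print(functools.partial(train_fn, data=[1, 2])({}))  # 2
```

Subclassing keeps the wrapped class usable wherever the original class is expected, while setup() receives the extra keyword arguments.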
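Function API example:

    from ray import train, tune

    def train_fn(config, data=None):
        # `data` is retrieved from the Ray object store by tune.with_parameters()
        for sample in data:
            loss = update_model(sample)  # placeholder: your own training step
            train.report({"loss": loss})  # metrics are reported as a dict

    data = HugeDataset(download=True)  # placeholder: any large object

    tuner = tune.Tuner(
        tune.with_parameters(train_fn, data=data),
        # ...
    )
    tuner.fit()
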
Class API example:

    from ray import tune

    class MyTrainable(tune.Trainable):
        def setup(self, config, data=None):
            # `data` is retrieved from the Ray object store and passed to setup()
            self.data = data
            self.iter = iter(self.data)
            self.next_sample = next(self.iter)

        def step(self):
            loss = update_model(self.next_sample)  # placeholder: your own training step
            try:
                self.next_sample = next(self.iter)
            except StopIteration:
                # Report the last result and tell Tune the trial is finished
                return {"loss": loss, "done": True}
            return {"loss": loss}

    data = HugeDataset(download=True)  # placeholder: any large object

    tuner = tune.Tuner(
        tune.with_parameters(MyTrainable, data=data),
        # ...
    )
    tuner.fit()

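To see how the "done" flag in the class example ends a trial, the toy driver loop below calls step() until a result contains "done": True. It is a sketch that runs without Ray: ToyTrainable mirrors MyTrainable minus the tune.Trainable base class, and update_model and run_until_done are hypothetical helpers, not Tune's scheduler.

```python
class ToyTrainable:
    """Mirrors MyTrainable above, without Ray."""

    def setup(self, config, data=None):
        self.iter = iter(data)
        self.next_sample = next(self.iter)

    def step(self):
        loss = update_model(self.next_sample)
        try:
            self.next_sample = next(self.iter)
        except StopIteration:
            # Last sample consumed: mark the result as final.
            return {"loss": loss, "done": True}
        return {"loss": loss}


def update_model(sample):
    """Hypothetical training step: returns a fake loss for one sample."""
    return 1.0 / (1 + sample)


def run_until_done(trainable, max_steps=100):
    """Hypothetical driver: call step() until a result reports done."""
    results = []
    for _ in range(max_steps):
        result = trainable.step()
        results.append(result)
        if result.get("done"):
            break
    return results


toy = ToyTrainable()
toy.setup({}, data=[0, 1, 3])
results = run_until_done(toy)
print(results)  # [{'loss': 1.0}, {'loss': 0.5}, {'loss': 0.25, 'done': True}]
```

Each sample produces exactly one step() call, and only the result for the last sample carries the "done" key.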
PublicAPI (beta): This API is in beta and may change before becoming stable.