ray.tune.with_parameters

ray.tune.with_parameters(trainable: Type[Trainable] | Callable, **kwargs)

Wrapper for trainables to pass arbitrarily large data objects.

This wrapper function will store all passed parameters in the Ray object store and retrieve them when calling the function. It can thus be used to pass arbitrary data, even datasets, to Tune trainables.

This can also be used as an alternative to functools.partial to pass default arguments to trainables.
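For illustration, a rough sketch with a hypothetical train_fn: both calls below bind lr as a default argument, but tune.with_parameters additionally routes the value through the Ray object store, which matters for large objects:

import functools

from ray import tune

def train_fn(config, lr=None):  # hypothetical trainable
    ...

# Both produce a trainable with lr pre-bound; with_parameters also
# stores the value in the Ray object store (useful for large objects).
trainable_a = tune.with_parameters(train_fn, lr=0.01)
trainable_b = functools.partial(train_fn, lr=0.01)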

When used with the function API, the trainable function is called with the passed parameters as keyword arguments. When used with the class API, the Trainable.setup() method is called with the respective kwargs.

If the data already exists in the object store (i.e., the values are instances of ObjectRef), using tune.with_parameters() is not necessary. You can instead pass the object refs to the training function via the config, or use Python partials.
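For example, a minimal sketch of that alternative, assuming a user-defined load_data() helper:

import ray
from ray import tune

data_ref = ray.put(load_data())  # data is already an ObjectRef in the object store

def train_fn(config):
    data = ray.get(config["data_ref"])  # fetch the dataset inside the trainable
    ...

tuner = tune.Tuner(
    train_fn,
    param_space={"data_ref": data_ref},
)
tuner.fit()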

Parameters:
  • trainable – Trainable to wrap.

  • **kwargs – Parameters to store in the object store.

Function API example:

from ray import train, tune

def train_fn(config, data=None):
    for sample in data:
        loss = update_model(sample)
        train.report({"loss": loss})

data = HugeDataset(download=True)

tuner = tune.Tuner(
    tune.with_parameters(train_fn, data=data),
    # ...
)
tuner.fit()

Class API example:

from ray import tune

class MyTrainable(tune.Trainable):
    def setup(self, config, data=None):
        self.data = data
        self.iter = iter(self.data)
        self.next_sample = next(self.iter)

    def step(self):
        loss = update_model(self.next_sample)
        try:
            self.next_sample = next(self.iter)
        except StopIteration:
            return {"loss": loss, done: True}
        return {"loss": loss}

data = HugeDataset(download=True)

tuner = tune.Tuner(
    tune.with_parameters(MyTrainable, data=data),
    # ...
)
tuner.fit()

PublicAPI (beta): This API is in beta and may change before becoming stable.