Training (tune.Trainable, tune.report)¶
Training can be done with either a Class API (tune.Trainable) or function API (tune.report).
For the sake of example, let’s maximize this objective function:
def objective(x, a, b):
    return a * (x ** 0.5) + b
Function API¶
With the Function API, you can report intermediate metrics by simply calling tune.report
within the provided function.
def trainable(config):
    # config (dict): A dict of hyperparameters.

    for x in range(20):
        intermediate_score = objective(x, config["a"], config["b"])

        tune.report(score=intermediate_score)  # This sends the score to Tune.

analysis = tune.run(
    trainable,
    config={"a": 2, "b": 4}
)

print("best config: ", analysis.get_best_config(metric="score", mode="max"))
Tip
Do not use tune.report
within a Trainable
class.
Tune will run this function on a separate thread in a Ray actor process.
You’ll notice that Ray Tune will output extra values in addition to the user reported metrics,
such as iterations_since_restore
. See How to use log metrics in Tune? for an explanation/glossary of these values.
Tip
If you want to leverage multi-node data parallel training with PyTorch while using parallel hyperparameter tuning, check out our PyTorch user guide and Tune’s distributed pytorch integrations.
Function API return and yield values¶
Instead of using tune.report()
, you can also use Python’s yield
statement to report metrics to Ray Tune:
def trainable(config):
    # config (dict): A dict of hyperparameters.

    for x in range(20):
        intermediate_score = objective(x, config["a"], config["b"])

        yield {"score": intermediate_score}  # This sends the score to Tune.

analysis = tune.run(
    trainable,
    config={"a": 2, "b": 4}
)

print("best config: ", analysis.get_best_config(metric="score", mode="max"))
If you yield a dictionary object, this works just like tune.report().
If you yield a number, it will be reported to Ray Tune with the key _metric, i.e.
as if you had called tune.report(_metric=value).
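For example, here is a minimal sketch (reusing the objective function from above) that yields a bare number, which Tune then records under _metric:
def trainable(config):
    # config (dict): A dict of hyperparameters.
    for x in range(20):
        # Yielding a plain number is equivalent to tune.report(_metric=value).
        yield objective(x, config["a"], config["b"])

analysis = tune.run(trainable, config={"a": 2, "b": 4})
print("best config: ", analysis.get_best_config(metric="_metric", mode="max"))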
Ray Tune supports the same functionality for return values if you only report metrics at the end of each run:
def trainable(config):
    # config (dict): A dict of hyperparameters.

    final_score = 0
    for x in range(20):
        final_score = objective(x, config["a"], config["b"])

    return {"score": final_score}  # This sends the score to Tune.

analysis = tune.run(
    trainable,
    config={"a": 2, "b": 4}
)

print("best config: ", analysis.get_best_config(metric="score", mode="max"))
Function API Checkpointing¶
Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance.
To use Tune’s checkpointing features, you must expose a checkpoint_dir
argument in the function signature,
and call tune.checkpoint_dir
:
import json
import os
import time

from ray import tune

def train_func(config, checkpoint_dir=None):
    start = 0
    if checkpoint_dir:
        with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
            state = json.loads(f.read())
            start = state["step"] + 1

    for iter in range(start, 100):
        time.sleep(1)

        with tune.checkpoint_dir(step=iter) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            with open(path, "w") as f:
                f.write(json.dumps({"step": iter}))

        tune.report(hello="world", ray="tune")

tune.run(train_func)
Note
checkpoint_freq
and checkpoint_at_end
will not work with Function API checkpointing.
In this example, checkpoints will be saved by training iteration to local_dir/exp_name/trial_name/checkpoint_<step>
.
You can restore a single trial checkpoint by using tune.run(restore=<checkpoint_dir>)
:
analysis = tune.run(
    train_func,
    config={"max_iter": 5},
)
last_ckpt = analysis.trials[0].checkpoint.value
analysis = tune.run(train_func, config={"max_iter": 10}, restore=last_ckpt)
Tune may also copy or move checkpoints during the course of tuning. For this reason,
it is important not to depend on absolute paths when saving and loading checkpoints.
Trainable Class API¶
Caution
Do not use tune.report
within a Trainable
class.
The Trainable class API requires users to subclass ray.tune.Trainable. Here’s a naive example of this API:
from ray import tune
class Trainable(tune.Trainable):
    def setup(self, config):
        # config (dict): A dict of hyperparameters
        self.x = 0
        self.a = config["a"]
        self.b = config["b"]

    def step(self):  # This is called iteratively.
        score = objective(self.x, self.a, self.b)
        self.x += 1
        return {"score": score}

analysis = tune.run(
    Trainable,
    stop={"training_iteration": 20},
    config={
        "a": 2,
        "b": 4
    })

print('best config: ', analysis.get_best_config(metric="score", mode="max"))
When you subclass tune.Trainable, Tune will create your Trainable object on a
separate process (using the Ray Actor API).
The setup function is invoked once training starts.
step is invoked multiple times. Each time, the Trainable object executes one logical iteration of training in the tuning process, which may include one or more iterations of actual training.
cleanup is invoked when training is finished.
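To illustrate this lifecycle, here is a minimal sketch of a Trainable that overrides cleanup; the scratch file is only a stand-in for whatever resource your trainable actually holds:
from ray import tune

class LifecycleTrainable(tune.Trainable):
    def setup(self, config):
        # Invoked once when the trial starts: acquire resources here.
        self.scratch = open("/tmp/lifecycle_scratch.txt", "w")
        self.x = 0

    def step(self):
        # Invoked repeatedly: one logical iteration of training.
        self.x += 1
        self.scratch.write(f"iteration {self.x}\n")
        return {"score": self.x}

    def cleanup(self):
        # Invoked once when training is finished: release resources here.
        self.scratch.close()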
Tip
As a rule of thumb, the execution time of step
should be large enough to avoid overheads
(i.e. more than a few seconds), but short enough to report progress periodically (i.e. at most a few minutes).
You’ll notice that Ray Tune will output extra values in addition to the user reported metrics,
such as iterations_since_restore
.
See How to use log metrics in Tune? for an explanation/glossary of these values.
Class API Checkpointing¶
You can also implement checkpoint/restore using the Trainable Class API:
class MyTrainableClass(Trainable):
    def save_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return tmp_checkpoint_dir

    def load_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        self.model.load_state_dict(torch.load(checkpoint_path))

tune.run(MyTrainableClass, checkpoint_freq=2)
You can checkpoint with three different mechanisms: manually, periodically, and at termination.
Manual Checkpointing: A custom Trainable can manually trigger checkpointing by returning should_checkpoint: True (or tune.result.SHOULD_CHECKPOINT: True) in the result dictionary of step.
This can be especially helpful when running on spot instances:
def step(self):
    # training code
    result = {"mean_accuracy": accuracy}
    if detect_instance_preemption():
        result.update(should_checkpoint=True)
    return result
Periodic Checkpointing: periodic checkpointing can be used to provide fault-tolerance for experiments.
This can be enabled by setting checkpoint_freq=<int>
and max_failures=<int>
to checkpoint trials
every N iterations and recover from up to M crashes per trial, e.g.:
tune.run(
my_trainable,
checkpoint_freq=10,
max_failures=5,
)
Checkpointing at Termination: The checkpoint_freq may not coincide with the exact end of an experiment.
If you want a checkpoint to be created at the end of a trial, you can additionally set checkpoint_at_end=True:
tune.run(
my_trainable,
checkpoint_freq=10,
checkpoint_at_end=True,
max_failures=5,
)
Use validate_save_restore
to catch save_checkpoint
/load_checkpoint
errors before execution.
from ray.tune.utils import validate_save_restore
# both of these should return
validate_save_restore(MyTrainableClass)
validate_save_restore(MyTrainableClass, use_object_store=True)
Advanced: Reusing Actors¶
Note
This feature is only for the Trainable Class API.
Your Trainable can often take a long time to start.
To avoid this, you can set reuse_actors=True in tune.run() to reuse the same Trainable Python process and
object for multiple hyperparameter configurations.
This requires you to implement Trainable.reset_config, which provides a new set of hyperparameters.
It is up to you to correctly update the hyperparameters of your trainable; a usage sketch follows the class below.
class PytorchTrainable(tune.Trainable):
    """Train a PyTorch ConvNet."""

    def setup(self, config):
        self.train_loader, self.test_loader = get_data_loaders()
        self.model = ConvNet()
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=config.get("lr", 0.01),
            momentum=config.get("momentum", 0.9))

    def reset_config(self, new_config):
        for param_group in self.optimizer.param_groups:
            if "lr" in new_config:
                param_group["lr"] = new_config["lr"]
            if "momentum" in new_config:
                param_group["momentum"] = new_config["momentum"]

        self.model = ConvNet()
        self.config = new_config
        return True
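A usage sketch, assuming the class above also implements step and the usual checkpointing methods (the search space is illustrative):
analysis = tune.run(
    PytorchTrainable,
    reuse_actors=True,  # keep actors alive and reset them via reset_config
    config={
        "lr": tune.grid_search([0.01, 0.1]),
        "momentum": tune.grid_search([0.8, 0.9]),
    },
    stop={"training_iteration": 10},
)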
Advanced Resource Allocation¶
Trainables can themselves be distributed. If your trainable function / class creates further Ray actors or tasks
that also consume CPU / GPU resources, you will want to add more bundles to the PlacementGroupFactory
to reserve extra resource slots.
For example, if a trainable class requires 1 GPU itself, but also launches 4 actors, each using another GPU,
then you should use this:
tune.run(
my_trainable,
name="my_trainable",
resources_per_trial=tune.PlacementGroupFactory([
{"CPU": 1, "GPU": 1},
{"GPU": 1},
{"GPU": 1},
{"GPU": 1},
{"GPU": 1}
])
)
The Trainable also provides the default_resource_request interface to automatically
declare the resources_per_trial based on the given configuration.
It is also possible to specify memory ("memory", in bytes) and custom resource requirements.
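A hedged sketch of both options; the bundle sizes and the num_workers config key are illustrative only:
from ray import tune

class MyDistributedTrainable(tune.Trainable):
    @classmethod
    def default_resource_request(cls, config):
        # One bundle for the trainable itself plus one GPU bundle per worker
        # it will launch (num_workers is a made-up config key here).
        num_workers = config.get("num_workers", 2)
        return tune.PlacementGroupFactory(
            [{"CPU": 1, "GPU": 1}] + [{"GPU": 1}] * num_workers)

# Memory is specified in bytes inside a bundle, alongside CPU/GPU:
tune.run(
    my_trainable,
    resources_per_trial=tune.PlacementGroupFactory([
        {"CPU": 1, "memory": 2 * 1024 ** 3},
    ]),
)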
tune.report / tune.checkpoint (Function API)¶
- ray.tune.report(_metric=None, **kwargs)[source]¶
Logs all keyword arguments.
import time
from ray import tune

def run_me(config):
    for iter in range(100):
        time.sleep(1)
        tune.report(hello="world", ray="tune")

analysis = tune.run(run_me)
- Parameters
_metric – Optional default anonymous metric for
tune.report(value)
**kwargs – Any key value pair to be logged by Tune. Any of these metrics can be used for early stopping or optimization.
PublicAPI: This API is stable across Ray releases.
- ray.tune.checkpoint_dir(step)[source]¶
Returns a checkpoint dir inside a context.
Store any files related to restoring state within the provided checkpoint dir.
You should call this before calling tune.report. The reason is that we want checkpoints to be correlated with the result (i.e., be able to retrieve the best checkpoint, etc.). Many algorithms depend on this behavior too. Calling checkpoint_dir after report could introduce inconsistencies.
- Parameters
step (int) – Index for the checkpoint. Expected to be a monotonically increasing quantity.
import os
import json
import time
from ray import tune

def func(config, checkpoint_dir=None):
    start = 0
    if checkpoint_dir:
        with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
            state = json.loads(f.read())
            start = state["step"] + 1

    for iter in range(start, 10):
        time.sleep(1)

        with tune.checkpoint_dir(step=iter) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            with open(path, "w") as f:
                f.write(json.dumps({"step": iter}))

        tune.report(hello="world", ray="tune")
- Yields
checkpoint_dir (str) – Directory for checkpointing.
New in version 0.8.7.
PublicAPI: This API is stable across Ray releases.
- ray.tune.get_trial_dir()[source]¶
Returns the directory where trial results are saved.
For function API use only.
DeveloperAPI: This API may change across minor Ray releases.
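For example, a function trainable might write an auxiliary artifact into its trial directory; this is only a sketch:
import os
from ray import tune

def trainable(config):
    # tune.get_trial_dir() points at this trial's result directory.
    notes_path = os.path.join(tune.get_trial_dir(), "notes.txt")
    with open(notes_path, "w") as f:
        f.write("custom artifact for this trial\n")
    tune.report(score=1)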
tune.Trainable (Class API)¶
- class ray.tune.Trainable(config: Optional[Dict[str, Any]] = None, logger_creator: Optional[Callable[[Dict[str, Any]], ray.tune.logger.Logger]] = None, remote_checkpoint_dir: Optional[str] = None, sync_function_tpl: Optional[str] = None)[source]¶
Abstract class for trainable models, functions, etc.
A call to train() on a trainable will execute one logical iteration of training. As a rule of thumb, the execution time of one train call should be large enough to avoid overheads (i.e. more than a few seconds), but short enough to report progress periodically (i.e. at most a few minutes).
Calling save() should save the training state of a trainable to disk, and restore(path) should restore a trainable to the given state.
Generally you only need to implement setup, step, save_checkpoint, and load_checkpoint when subclassing Trainable.
Other implementation methods that may be helpful to override are log_result, reset_config, cleanup, and _export_model.
When using Tune, Tune will convert this class into a Ray actor, which runs on a separate process. Tune will also change the current working directory of this process to self.logdir. This is designed so that different trials that run on the same physical node won’t accidentally write to the same location and overwrite each other’s outputs.
If you want to know the original working directory path on the driver node, you can look it up through the environment variable “TUNE_ORIG_WORKING_DIR”. It is advised that you access this path for read-only purposes, and you need to make sure that the path exists on the remote nodes.
This class supports checkpointing to and restoring from remote storage.
PublicAPI: This API is stable across Ray releases.
- _create_logger(config: Dict[str, Any], logger_creator: Optional[Callable[[Dict[str, Any]], ray.tune.logger.Logger]] = None)[source]¶
Create logger from logger creator.
Sets _logdir and _result_logger.
_logdir is the per trial directory for the Trainable.
- _export_model(export_formats, export_dir)[source]¶
Subclasses should override this to export model.
- Parameters
export_formats (list) – List of formats that should be exported.
export_dir (str) – Directory to place exported models.
- Returns
A dict that maps ExportFormats to successfully exported models.
- _postprocess_checkpoint(checkpoint_path: str)[source]¶
Run extra postprocessing before the checkpoint is saved to cloud.
- _storage_path(local_path)[source]¶
Converts a local_path to be based off of self.remote_checkpoint_dir.
- cleanup()[source]¶
Subclasses should override this for any cleanup on stop.
If any Ray actors are launched in the Trainable (e.g., with an RLlib trainer), be sure to kill the Ray actor process here.
You can kill a Ray actor by calling actor.__ray_terminate__.remote() on the actor.
New in version 0.8.7.
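A hedged sketch of such a cleanup override; the helper actor is made up for illustration:
import ray
from ray import tune

@ray.remote
class HelperActor:
    def ping(self):
        return "pong"

class MyTrainable(tune.Trainable):
    def setup(self, config):
        # Launch a helper actor owned by this trainable.
        self.worker = HelperActor.remote()

    def step(self):
        return {"worker_alive": ray.get(self.worker.ping.remote()) == "pong"}

    def cleanup(self):
        # Kill the helper actor so its resources are freed with the trial.
        self.worker.__ray_terminate__.remote()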
- classmethod default_resource_request(config: Dict[str, Any]) Union[ray.tune.resources.Resources, ray.tune.utils.placement_groups.PlacementGroupFactory] [source]¶
Provides a static resource requirement for the given configuration.
This can be overridden by sub-classes to set the correct trial resource allocation, so the user does not need to.
@classmethod
def default_resource_request(cls, config):
    return PlacementGroupFactory([{"CPU": 1}, {"CPU": 1}])
- Parameters
config (Dict[str, Any]) – The Trainable’s config dict.
- Returns
- A Resources object or
PlacementGroupFactory consumed by Tune for queueing.
- Return type
Union[Resources, PlacementGroupFactory]
- delete_checkpoint(checkpoint_path)[source]¶
Deletes local copy of checkpoint.
- Parameters
checkpoint_path (str) – Path to checkpoint.
- export_model(export_formats, export_dir=None)[source]¶
Exports model based on export_formats.
Subclasses should override _export_model() to actually export model to local directory.
- Parameters
export_formats (Union[list,str]) – Format or list of (str) formats that should be exported.
export_dir (str) – Optional dir to place the exported model. Defaults to self.logdir.
- Returns
A dict that maps ExportFormats to successfully exported models.
- get_auto_filled_metrics(now: Optional[datetime.datetime] = None, time_this_iter: Optional[float] = None, debug_metrics_only: bool = False) dict [source]¶
Return a dict with metrics auto-filled by the trainable.
If
debug_metrics_only
is True, only metrics that don’t require at least one iteration will be returned (ray.tune.result.DEBUG_METRICS
).
- load_checkpoint(checkpoint)[source]¶
Subclasses should override this to implement restore().
Warning
In this method, do not rely on absolute paths. The absolute path of the checkpoint_dir used in Trainable.save_checkpoint may be changed.
If Trainable.save_checkpoint returned a prefixed string, the prefix of the checkpoint string returned by Trainable.save_checkpoint may be changed. This is because trial pausing depends on temporary directories.
The directory structure under the checkpoint_dir provided to Trainable.save_checkpoint is preserved.
See the example below.
class Example(Trainable):
    def save_checkpoint(self, checkpoint_path):
        print(checkpoint_path)
        return os.path.join(checkpoint_path, "my/check/point")

    def load_checkpoint(self, checkpoint):
        print(checkpoint)

>>> trainer = Example()
>>> obj = trainer.save_to_object()  # This is used when PAUSED.
<logdir>/tmpc8k_c_6hsave_to_object/checkpoint_0/my/check/point
>>> trainer.restore_from_object(obj)  # Note the different prefix.
<logdir>/tmpb87b5axfrestore_from_object/checkpoint_0/my/check/point
New in version 0.8.7.
- Parameters
checkpoint (str|dict) – If dict, the return value is as returned by save_checkpoint. If a string, then it is a checkpoint path that may have a different prefix than that returned by save_checkpoint. The directory structure underneath the checkpoint_dir save_checkpoint is preserved.
- log_result(result)[source]¶
Subclasses can optionally override this to customize logging.
The logging here is done on the worker process rather than the driver. You may want to turn off driver logging via the loggers parameter in tune.run when overriding this function.
New in version 0.8.7.
- Parameters
result (dict) – Training result returned by step().
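A minimal sketch of such an override, using print as a stand-in for a custom sink:
class MyTrainable(tune.Trainable):
    def log_result(self, result):
        # Runs on the worker process; filter or forward results as needed.
        print("iteration", result.get("training_iteration"), result)
        # Keep the default worker-side logging behavior as well.
        super().log_result(result)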
- reset(new_config, logger_creator=None)[source]¶
Resets trial for use with new config.
Subclasses should override reset_config() to actually reset actor behavior for the new config.
- reset_config(new_config)[source]¶
Resets configuration without restarting the trial.
This method is optional, but can be implemented to speed up algorithms such as PBT, and to allow performance optimizations such as running experiments with reuse_actors=True.
- Parameters
new_config (dict) – Updated hyperparameter configuration for the trainable.
- Returns
True if reset was successful else False.
- classmethod resource_help(config)[source]¶
Returns a help string for configuring this trainable’s resources.
- Parameters
config (dict) – The Trainable’s config dict.
- restore(checkpoint_path)[source]¶
Restores training state from a given model checkpoint.
These checkpoints are returned from calls to save().
Subclasses should override load_checkpoint() instead to restore state. This method restores additional metadata saved with the checkpoint.
checkpoint_path should match the return value of save(). checkpoint_path can be ~/ray_results/exp/MyTrainable_abc/checkpoint_00000/checkpoint, or ~/ray_results/exp/MyTrainable_abc/checkpoint_00000.
self.logdir should generally correspond to checkpoint_path, for example, ~/ray_results/exp/MyTrainable_abc.
self.remote_checkpoint_dir in this case is something like REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc.
- restore_from_object(obj)[source]¶
Restores training state from a checkpoint object.
These checkpoints are returned from calls to save_to_object().
- save(checkpoint_dir=None) str [source]¶
Saves the current model state to a checkpoint.
Subclasses should override save_checkpoint() instead to save state. This method dumps additional metadata alongside the saved path.
If a remote checkpoint dir is given, this will also sync up to remote storage.
- Parameters
checkpoint_dir (str) – Optional dir to place the checkpoint.
- Returns
path that points to xxx.pkl file.
- Return type
str
Note the return path should match up with what is expected of restore().
- save_checkpoint(tmp_checkpoint_dir)[source]¶
Subclasses should override this to implement save().
Warning
Do not rely on absolute paths in the implementation of Trainable.save_checkpoint and Trainable.load_checkpoint.
Use validate_save_restore to catch Trainable.save_checkpoint/Trainable.load_checkpoint errors before execution.
>>> from ray.tune.utils import validate_save_restore
>>> validate_save_restore(MyTrainableClass)
>>> validate_save_restore(MyTrainableClass, use_object_store=True)
New in version 0.8.7.
- Parameters
tmp_checkpoint_dir (str) – The directory where the checkpoint file must be stored. In a Tune run, if the trial is paused, the provided path may be temporary and moved.
- Returns
A dict or string. If string, the return value is expected to be prefixed by tmp_checkpoint_dir. If dict, the return value will be automatically serialized by Tune and passed to
Trainable.load_checkpoint()
.
Examples
>>> print(trainable1.save_checkpoint("/tmp/checkpoint_1"))
"/tmp/checkpoint_1/my_checkpoint_file"
>>> print(trainable2.save_checkpoint("/tmp/checkpoint_2"))
{"some": "data"}
>>> trainable.save_checkpoint("/tmp/bad_example")
"/tmp/NEW_CHECKPOINT_PATH/my_checkpoint_file" # This will error.
- save_to_object()[source]¶
Saves the current model state to a Python object.
It also saves to disk but does not return the checkpoint path.
- Returns
Object holding checkpoint data.
- setup(config)[source]¶
Subclasses should override this for custom initialization.
New in version 0.8.7.
- Parameters
config (dict) – Hyperparameters and other configs given. Copy of self.config.
- step()[source]¶
Subclasses should override this to implement train().
The return value will be automatically passed to the loggers. Users can also return tune.result.DONE or tune.result.SHOULD_CHECKPOINT as a key to manually trigger termination or checkpointing of this trial. Note that manual checkpointing only works when subclassing Trainables.
New in version 0.8.7.
- Returns
A dict that describes training progress.
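For example, a hedged sketch of a step() that signals termination once a target accuracy is reached (_run_one_epoch is a hypothetical helper):
def step(self):
    accuracy = self._run_one_epoch()  # hypothetical training helper
    result = {"mean_accuracy": accuracy}
    if accuracy >= 0.95:
        result["done"] = True  # same key as tune.result.DONE
    return result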
- stop()[source]¶
Releases all resources used by this trainable.
Calls Trainable.cleanup internally. Subclasses should override Trainable.cleanup for custom cleanup procedures.
- train()[source]¶
Runs one logical iteration of training.
Calls step() internally. Subclasses should override step() instead to return results. This method automatically fills the following fields in the result:
done (bool): training is terminated. Filled only if not provided.
time_this_iter_s (float): Time in seconds this iteration took to run. This may be overridden in order to override the system-computed time difference.
time_total_s (float): Accumulated time in seconds for this entire experiment.
experiment_id (str): Unique string identifier for this experiment. This id is preserved across checkpoint / restore calls.
training_iteration (int): The index of this training iteration, e.g. call to train(). This is incremented after step() is called.
pid (str): The pid of the training process.
date (str): A formatted date of when the result was processed.
timestamp (str): A UNIX timestamp of when the result was processed.
hostname (str): Hostname of the machine hosting the training process.
node_ip (str): Node ip of the machine hosting the training process.
- Returns
A dict that describes training progress.
- train_buffered(buffer_time_s: float, max_buffer_length: int = 1000)[source]¶
Runs multiple iterations of training.
Calls train() internally. Collects and combines multiple results. This function will run self.train() repeatedly until one of the following conditions is met: 1) the maximum buffer length is reached, 2) the maximum buffer time is reached, or 3) a checkpoint was created. Even if the maximum time is reached, it will always block until at least one result is received.
- Parameters
buffer_time_s (float) – Maximum time to buffer. The next result received after this amount of time has passed will return the whole buffer.
max_buffer_length (int) – Maximum number of results to buffer.
- property iteration¶
Current training iteration.
This value is automatically incremented every time train() is called and is automatically inserted into the training result dict.
- property logdir¶
Directory of the results and checkpoints for this Trainable.
Tune will automatically sync this folder with the driver if execution is distributed.
Note that the current working directory will also be changed to this.
- property training_iteration¶
Current training iteration (same as self.iteration).
This value is automatically incremented every time train() is called and is automatically inserted into the training result dict.
- property trial_id¶
Trial ID for the corresponding trial of this Trainable.
This is not set if not using Tune.
trial_id = self.trial_id
- property trial_name¶
Trial name for the corresponding trial of this Trainable.
This is not set if not using Tune.
name = self.trial_name
- property trial_resources: Union[ray.tune.resources.Resources, ray.tune.utils.placement_groups.PlacementGroupFactory]¶
Resources currently assigned to the trial of this Trainable.
This is not set if not using Tune.
trial_resources = self.trial_resources
Utilities¶
- ray.tune.utils.wait_for_gpu(gpu_id=None, target_util=0.01, retry=20, delay_s=5, gpu_memory_limit=None)[source]¶
Checks if a given GPU has freed memory.
Requires gputil to be installed: pip install gputil.
- Parameters
gpu_id (Optional[Union[int, str]]) – GPU id or uuid to check. Must be found within GPUtil.getGPUs(). If none, resorts to the first item returned from ray.get_gpu_ids().
target_util (float) – The utilization threshold to reach to unblock. Set this to 0 to block until the GPU is completely free.
retry (int) – Number of times to check GPU limit. Sleeps delay_s seconds between checks.
delay_s (int) – Seconds to wait before check.
gpu_memory_limit (float) – Deprecated.
- Returns
True if free.
- Return type
bool
- Raises
RuntimeError – If GPUtil is not found, if no GPUs are detected or if the check fails.
Example:
def tune_func(config):
    tune.util.wait_for_gpu()
    train()

tune.run(tune_func, resources_per_trial={"GPU": 1}, num_samples=10)
- ray.tune.utils.diagnose_serialization(trainable)[source]¶
Utility for detecting why your trainable function isn’t serializing.
- Parameters
trainable (func) – The trainable object passed to tune.run(trainable). Currently only supports Function API.
- Returns
bool | set of unserializable objects.
Example:
import threading

# this is not serializable
e = threading.Event()

def test():
    print(e)

diagnose_serialization(test)
# should help identify that 'e' should be moved into
# the `test` scope.

# correct implementation
def test():
    e = threading.Event()
    print(e)

assert diagnose_serialization(test) is True
- ray.tune.utils.validate_save_restore(trainable_cls, config=None, num_gpus=0, use_object_store=False)[source]¶
Helper method to check if your Trainable class will resume correctly.
- Parameters
trainable_cls – Trainable class for evaluation.
config (dict) – Config to pass to Trainable when testing.
num_gpus (int) – GPU resources to allocate when testing.
use_object_store (bool) – Whether to save and restore to Ray’s object store. Recommended to set this to True if planning to use algorithms that pause training (i.e., PBT, HyperBand).
Distributed Torch¶
Ray offers lightweight integrations to distribute your PyTorch training on Ray Tune.
- ray.tune.integration.torch.DistributedTrainableCreator(func: Callable, num_workers: int = 1, num_cpus_per_worker: int = 1, num_gpus_per_worker: int = 0, num_workers_per_host: Optional[int] = None, backend: str = 'gloo', timeout_s: int = 1800, use_gpu=None) Type[ray.tune.integration.torch._TorchTrainable] [source]
Creates a class that executes distributed training.
Similar to running torch.distributed.launch.
Note that you typically should not instantiate the object created.
- Parameters
func (callable) – This function is a Tune trainable function. This function must have 2 args in the signature, and the latter arg must contain checkpoint_dir. For example: func(config, checkpoint_dir=None).
num_workers (int) – Number of training workers to include in world.
num_cpus_per_worker (int) – Number of CPU resources to reserve per training worker.
num_gpus_per_worker (int) – Number of GPU resources to reserve per training worker.
num_workers_per_host – Optional[int]: Number of workers to colocate per host.
backend (str) – One of “gloo”, “nccl”.
timeout_s (float) – Seconds before the torch process group times out. Useful when machines are unreliable. Defaults to 1800 seconds. This value is also reused for triggering placement timeouts if forcing colocation.
- Returns
A trainable class object that can be passed to Tune. Resources are automatically set within the object, so users do not need to set resources_per_trial.
- Return type
type(Trainable)
Example:
trainable_cls = DistributedTrainableCreator(
    train_func, num_workers=2)
analysis = tune.run(trainable_cls)
- ray.tune.integration.torch.distributed_checkpoint_dir(step: int, disable: bool = False) Generator[str, None, None] [source]
ContextManager for creating a distributed checkpoint.
Only checkpoints a file on the “main” training actor, avoiding redundant work.
- Parameters
step (int) – Used to label the checkpoint
disable (bool) – Disable for prototyping.
- Yields
str – A path to a directory. This path will be used again when invoking the training_function.
Example:
def train_func(config, checkpoint_dir):
    if checkpoint_dir:
        path = os.path.join(checkpoint_dir, "checkpoint")
        model_state_dict = torch.load(path)

    if epoch % 3 == 0:
        with distributed_checkpoint_dir(step=epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save(model.state_dict(), path)
- ray.tune.integration.torch.is_distributed_trainable()[source]
Returns True if executing within a DistributedTrainable.
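A small usage sketch inside a training function that may or may not be wrapped by DistributedTrainableCreator:
from ray import tune
from ray.tune.integration.torch import is_distributed_trainable

def train_func(config, checkpoint_dir=None):
    # Branch on whether Tune wrapped this function in a DistributedTrainable.
    distributed = is_distributed_trainable()
    tune.report(distributed=int(distributed))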
Distributed TensorFlow¶
Ray also offers lightweight integrations to distribute your TensorFlow training on Ray Tune.
- ray.tune.integration.tensorflow.DistributedTrainableCreator(func: Callable, num_workers: int = 2, num_gpus_per_worker: int = 0, num_cpus_per_worker: int = 1, num_workers_per_host: Optional[int] = None, timeout_s: int = 900) Type[ray.tune.integration.tensorflow._TensorFlowTrainable] [source]
Converts TensorFlow MultiWorkerMirror training to be executable by Tune.
Requires TensorFlow > 2.0 to work, recommends TensorFlow > 2.2.
This function wraps and sets resources for a TF distributed training function to be used with Tune. It generates a TensorFlow Trainable which can be a distributed training job.
Note: there is no fault tolerance at the moment.
- Parameters
func (Callable[[dict], None]) – A training function that takes in a config dict for hyperparameters.
num_gpus_per_worker (int) – Number of GPU resources to request from Ray per worker.
num_cpus_per_worker (int) – Number of CPUs to request from Ray per worker.
num_workers (int) – Number of hosts that each trial is expected to use.
num_workers_per_host (Optional[int]) – Number of workers to colocate per host. None if not specified.
timeout_s (float) – Seconds before triggering placement timeouts if forcing colocation. Default to 15 minutes.
- Returns
Trainable class that can be passed into tune.run.
New in version 1.1.0.
Example:
# Please refer to full example in tf_distributed_keras_example.py
tf_trainable = DistributedTrainableCreator(
    train_mnist,
    num_workers=2)
tune.run(tf_trainable,
         num_samples=1)
tune.with_parameters¶
- ray.tune.with_parameters(trainable, **kwargs)[source]¶
Wrapper for trainables to pass arbitrary large data objects.
This wrapper function will store all passed parameters in the Ray object store and retrieve them when calling the function. It can thus be used to pass arbitrary data, even datasets, to Tune trainables.
This can also be used as an alternative to functools.partial to pass default arguments to trainables.
When used with the function API, the trainable function is called with the passed parameters as keyword arguments. When used with the class API, the Trainable.setup() method is called with the respective kwargs.
If the data already exists in the object store (i.e., the values are instances of ObjectRef), using
tune.with_parameters()
is not necessary. You can instead pass the object refs to the training function via theconfig
or use Python partials.- Parameters
trainable – Trainable to wrap.
**kwargs – parameters to store in object store.
Function API example:
from ray import tune

def train(config, data=None):
    for sample in data:
        loss = update_model(sample)
        tune.report(loss=loss)

data = HugeDataset(download=True)

tune.run(
    tune.with_parameters(train, data=data),
    # ...
)
Class API example:
from ray import tune

class MyTrainable(tune.Trainable):
    def setup(self, config, data=None):
        self.data = data
        self.iter = iter(self.data)
        self.next_sample = next(self.iter)

    def step(self):
        loss = update_model(self.next_sample)
        try:
            self.next_sample = next(self.iter)
        except StopIteration:
            return {"loss": loss, "done": True}
        return {"loss": loss}

data = HugeDataset(download=True)

tune.run(
    tune.with_parameters(MyTrainable, data=data),
    # ...
)
StatusReporter¶
- class ray.tune.function_runner.StatusReporter(result_queue, continue_semaphore, end_event, trial_name=None, trial_id=None, logdir=None, trial_resources=None)[source]¶
Object passed into your function that you can report status through.
Example
>>> def trainable_function(config, reporter):
>>>     assert isinstance(reporter, StatusReporter)
>>>     reporter(timesteps_this_iter=1)
- __call__(_metric=None, **kwargs)[source]¶
Report updated training status.
Pass in done=True when the training job is completed.
- Parameters
kwargs – Latest training result status.
Example
>>> reporter(mean_accuracy=1, training_iteration=4)
>>> reporter(mean_accuracy=1, training_iteration=4, done=True)
- Raises
StopIteration – A StopIteration exception is raised if the trial has been signaled to stop.