class ray.train.mosaic.MosaicTrainer(*args, **kwargs)

Bases: ray.train.torch.torch_trainer.TorchTrainer

A Trainer for data-parallel training of MosaicML Composer models with PyTorch.

This Trainer runs the composer.trainer.Trainer.fit() method on multiple Ray Actors. Training is distributed with PyTorch DDP, and each actor already has the torch process group configured for distributed PyTorch training.

The training function run on every Actor first calls the specified trainer_init_per_worker function to obtain an instantiated composer.Trainer object. Because passing Ray Datasets to this trainer is not yet supported, trainer_init_per_worker is also where the train and evaluation datasets should be loaded and preprocessed.
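The per-worker control flow described above can be sketched in plain Python. This is a simplified illustration under stated assumptions, not Ray's implementation: FakeComposerTrainer and training_function_on_worker are hypothetical stand-ins for composer.Trainer and for the function Ray runs on each actor, and process-group setup and dataset loading are not modeled.

```python
from typing import Any, Callable, Dict


class FakeComposerTrainer:
    """Hypothetical stand-in for composer.Trainer; it only records calls."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs
        self.fit_called = False

    def fit(self) -> None:
        # A real composer.Trainer would run the training loop here.
        self.fit_called = True


def training_function_on_worker(
    trainer_init_per_worker: Callable[[Dict[str, Any]], FakeComposerTrainer],
    config: Dict[str, Any],
) -> FakeComposerTrainer:
    """Hypothetical sketch of what each Ray actor does.

    Assumes the torch process group is already initialized (Ray handles
    this before the function runs); that step is not modeled here.
    """
    # Step 1: build the trainer from the user-supplied init function.
    trainer = trainer_init_per_worker(config)
    # Step 2: run training on this worker.
    trainer.fit()
    return trainer


def my_init(config: Dict[str, Any]) -> FakeComposerTrainer:
    # Forward the config dict as keyword arguments, as the doctest example does.
    return FakeComposerTrainer(**config)


trainer = training_function_on_worker(my_init, {"max_duration": "1ba"})
print(trainer.fit_called, trainer.kwargs)
```

In the real trainer, the fit() call is what runs Composer's training loop on each actor; the sketch only shows the ordering: init function first, then fit().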


>>> import torch.utils.data  
>>> import torchvision  
>>> from torchvision import transforms, datasets  
>>> from composer.models.tasks import ComposerClassifier 
>>> import composer.optim 
>>> import composer.trainer 
>>> from composer.algorithms import LabelSmoothing 
>>> import ray
>>> import ray.train.torch
>>> from ray.air.config import ScalingConfig
>>> from ray.air import session
>>> from ray.train.mosaic import MosaicTrainer 
>>> BATCH_SIZE = 32  # global batch size split across workers (illustrative value)
>>> def trainer_init_per_worker(config):
...     # prepare the model for distributed training and wrap with
...     # ComposerClassifier for Composer Trainer compatibility
...     model = torchvision.models.resnet18(num_classes=10)
...     model = ComposerClassifier(ray.train.torch.prepare_model(model))
...     # prepare train/test dataset
...     mean = (0.507, 0.487, 0.441)
...     std = (0.267, 0.256, 0.276)
...     cifar10_transforms = transforms.Compose(
...         [transforms.ToTensor(), transforms.Normalize(mean, std)]
...     )
...     data_directory = "~/data"
...     train_dataset = datasets.CIFAR10(
...         data_directory,
...         train=True,
...         download=True,
...         transform=cifar10_transforms
...     )
...     # prepare train dataloader
...     batch_size_per_worker = BATCH_SIZE // session.get_world_size()
...     train_dataloader = torch.utils.data.DataLoader(
...         train_dataset,
...         batch_size=batch_size_per_worker
...     )
...     train_dataloader = ray.train.torch.prepare_data_loader(train_dataloader)
...     # prepare optimizer
...     optimizer = composer.optim.DecoupledSGDW(
...         model.parameters(),
...         lr=0.05,
...         momentum=0.9,
...         weight_decay=2.0e-3,
...     )
...     return composer.trainer.Trainer(
...         model=model,
...         train_dataloader=train_dataloader,
...         optimizers=optimizer,
...         **config
...     )
>>> scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
>>> trainer_init_config = {
...     "max_duration": "1ba",
...     "algorithms": [LabelSmoothing()],
... } 
>>> trainer = MosaicTrainer(
...     trainer_init_per_worker=trainer_init_per_worker,
...     trainer_init_config=trainer_init_config,
...     scaling_config=scaling_config,
... ) 
>>> trainer.fit() 
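The example splits a global BATCH_SIZE across workers with floor division (BATCH_SIZE // session.get_world_size()), so any remainder is dropped and the effective global batch per step can be smaller than BATCH_SIZE. The helper below is a hypothetical illustration of that arithmetic only; it is not part of Ray's API.

```python
def per_worker_batch_size(global_batch_size: int, world_size: int) -> int:
    """Hypothetical helper mirroring BATCH_SIZE // session.get_world_size()."""
    if world_size <= 0:
        raise ValueError("world_size must be positive")
    return global_batch_size // world_size


for global_bs, workers in [(32, 2), (32, 3)]:
    local_bs = per_worker_batch_size(global_bs, workers)
    # Under DDP each worker processes local_bs samples per step.
    effective = local_bs * workers
    print(f"global={global_bs} workers={workers} local={local_bs} effective={effective}")
```

With 3 workers, each gets 10 samples per step, so 30 rather than 32 samples are processed per step; choosing a BATCH_SIZE divisible by num_workers avoids the discrepancy.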

Parameters

  • trainer_init_per_worker – The function that takes a configuration dictionary (config) as its argument and returns an instantiated composer.Trainer object. This dictionary is based on trainer_init_config and is modified for the Ray-Composer integration.

  • datasets – Any Ray Datasets to use for training. Passing datasets to the trainer and using the dataset shards in the training loop is not currently supported; instead, configure and load the datasets inside the trainer_init_per_worker function.

  • trainer_init_config – Configuration dictionary passed to trainer_init_per_worker as its config argument. Although these values could be hard-coded in trainer_init_per_worker, passing them through the config lets you reuse the same worker init function while changing the trainer arguments. For example, during hyperparameter tuning you can reuse one trainer_init_per_worker function with different hyperparameter values instead of writing several functions with different hard-coded values.

  • torch_config – Configuration for setting up the PyTorch backend. If None, the default configuration is used. This replaces the backend_config argument of DataParallelTrainer, as in TorchTrainer.

  • scaling_config – Configuration for how to scale data parallel training.

  • dataset_config – Configuration for dataset ingest.

  • run_config – Configuration for the execution of the training run.

  • preprocessor – A ray.data.Preprocessor to preprocess the provided datasets.

  • resume_from_checkpoint – A MosaicCheckpoint to resume training from.
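The reuse pattern described for trainer_init_config (one init function, several configurations) can be sketched without Ray or Composer. This is a minimal sketch under assumptions: init_per_worker_sketch and the "lr" key are hypothetical, and the function returns a plain dict of trainer arguments instead of a composer.Trainer. A hyperparameter search (for example with Ray Tune) would supply the varying dictionaries.

```python
from typing import Any, Dict, List


def init_per_worker_sketch(config: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical init function returning trainer kwargs instead of a trainer.

    Keys consumed here (the hypothetical "lr") are removed before the rest is
    forwarded, so the remainder could be passed through as **config.
    """
    config = dict(config)  # copy so the caller's dict is not mutated
    lr = config.pop("lr", 0.05)  # hypothetical key with a default value
    return {"optimizer_lr": lr, **config}


# One init function, several hyperparameter settings.
search_space: List[Dict[str, Any]] = [
    {"max_duration": "1ba", "lr": 0.05},
    {"max_duration": "1ba", "lr": 0.1},
]
results = [init_per_worker_sketch(cfg) for cfg in search_space]
for r in results:
    print(r)
```

Popping consumed keys before forwarding matters when the remaining dictionary is passed as **config to composer.trainer.Trainer, which would otherwise receive an unexpected keyword argument.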

PublicAPI (alpha): This API is in alpha and may change before becoming stable.

classmethod restore(**kwargs) → ray.train.mosaic.mosaic_trainer.MosaicTrainer

Restores a DataParallelTrainer from a previously interrupted/failed run.

Parameters

  • train_loop_per_worker – Optionally re-specified train loop function. This should be used to re-specify a function that is not restorable in a new Ray cluster (e.g., it holds onto outdated object references). This should be the same training loop that was passed to the original trainer constructor.

  • train_loop_config – Optionally re-specified train config. This should similarly be used if the original train_loop_config contained outdated object references, and it should not be modified from what was originally passed in.

See BaseTrainer.restore() for descriptions of the other arguments.


Returns

A restored instance of the DataParallelTrainer subclass that is calling this method.

Return type