Distributed PyTorch Lightning Training on Ray¶
This library adds new PyTorch Lightning plugins for distributed training using the Ray distributed computing framework.
These PyTorch Lightning Plugins on Ray enable quick and easy parallel training while still leveraging all the benefits of PyTorch Lightning and using your desired training protocol, either PyTorch Distributed Data Parallel or Horovod.
Once you add your plugin to the PyTorch Lightning Trainer, you can parallelize training to all the cores in your laptop, or across a massive multi-node, multi-GPU cluster with no additional code changes.
This library also comes with an integration with Ray Tune for distributed hyperparameter tuning experiments.
Installation¶
You can install Ray Lightning via pip:
pip install ray_lightning
Or to install master:
pip install git+https://github.com/ray-project/ray_lightning#ray_lightning
PyTorch Distributed Data Parallel Plugin on Ray¶
The RayPlugin provides Distributed Data Parallel training on a Ray cluster. PyTorch DDP is used as the distributed training protocol, and Ray is used to launch and manage the training worker processes.
Here is a simplified example:
import pytorch_lightning as pl
from ray_lightning import RayPlugin
# Create your PyTorch Lightning model here.
ptl_model = MNISTClassifier(...)
plugin = RayPlugin(num_workers=4, num_cpus_per_worker=1, use_gpu=True)
# Don't set ``gpus`` in the ``Trainer``.
# The actual number of GPUs is determined by ``num_workers``.
trainer = pl.Trainer(..., plugins=[plugin])
trainer.fit(ptl_model)
Because Ray is used to launch processes (instead of the same script being called multiple times), you CAN use this plugin even in cases where you cannot use the standard DDPPlugin, such as:
- Jupyter Notebooks, Google Colab, Kaggle
- Calling fit or test multiple times in the same script (see the sketch below)
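For example, a single script or notebook cell can fit and then evaluate the same model, something the standard DDPPlugin cannot do. A minimal sketch; MNISTClassifier and the worker count are placeholders:
import pytorch_lightning as pl
from ray_lightning import RayPlugin

ptl_model = MNISTClassifier(...)
plugin = RayPlugin(num_workers=4, num_cpus_per_worker=1, use_gpu=False)
trainer = pl.Trainer(max_epochs=2, plugins=[plugin])

# Because Ray launches the worker processes, repeated fit/test calls
# in the same script or notebook session are fine.
trainer.fit(ptl_model)
trainer.test(ptl_model)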
Multi-node Distributed Training¶
Using the same examples above, you can run distributed training on a multi-node cluster with just 2 simple steps.
1) Use Ray's cluster launcher to start a Ray cluster: ray up my_cluster_config.yaml.
2) Execute your Python script on the Ray cluster: ray submit my_cluster_config.yaml train.py. This will rsync your training script to the head node and execute it on the Ray cluster. (Note: The training script can also be executed using Ray Job Submission, which is in beta starting with Ray 1.12. Try it out!)
You no longer have to set environment variables or configurations and run your training script on every single node.
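For reference, the train.py submitted in step 2 is just an ordinary training script like the one shown earlier. A minimal sketch; the model class and resource counts are placeholders for your own:
import pytorch_lightning as pl
from ray_lightning import RayPlugin

# train.py -- submitted to the cluster with ``ray submit``.
ptl_model = MNISTClassifier(...)

# The plugin uses the Ray cluster it is launched on; no per-node
# environment variables or launch commands are needed.
plugin = RayPlugin(num_workers=8, num_cpus_per_worker=1, use_gpu=True)
trainer = pl.Trainer(max_epochs=10, plugins=[plugin])
trainer.fit(ptl_model)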
Multi-node Interactive Training from your Laptop¶
Ray provides capabilities to run multi-node and GPU training all from your laptop through Ray Client.
You can follow the instructions here to set up the cluster. Then, add this line to the beginning of your script to connect to the cluster:
# replace with the appropriate host and port
ray.init("ray://<head_node_host>:10001")
Now you can run your training script on your laptop, but have it execute as if your laptop had all the resources of the cluster, essentially providing you with an infinite laptop.
Note: When using Ray Client, you must disable checkpointing and logging for your Trainer by setting checkpoint_callback and logger to False.
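Putting the pieces together, a minimal Ray Client sketch looks like this (the host, model class, and worker count are placeholders):
import ray
import pytorch_lightning as pl
from ray_lightning import RayPlugin

# Connect to the remote cluster before doing anything else;
# replace with the appropriate host and port.
ray.init("ray://<head_node_host>:10001")

ptl_model = MNISTClassifier(...)
plugin = RayPlugin(num_workers=4, use_gpu=True)

# Checkpointing and logging must be disabled when using Ray Client.
trainer = pl.Trainer(
    max_epochs=4,
    plugins=[plugin],
    checkpoint_callback=False,
    logger=False)
trainer.fit(ptl_model)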
Horovod Plugin on Ray¶
Or, if you prefer to use Horovod as the distributed training protocol, use the HorovodRayPlugin instead.
import pytorch_lightning as pl
from ray_lightning import HorovodRayPlugin
# Create your PyTorch Lightning model here.
ptl_model = MNISTClassifier(...)
# 2 nodes, 4 workers per node, each using 1 CPU and 1 GPU.
plugin = HorovodRayPlugin(num_hosts=2, num_slots=4, use_gpu=True)
# Don't set ``gpus`` in the ``Trainer``.
# The actual number of GPUs is determined by ``num_slots``.
trainer = pl.Trainer(..., plugins=[plugin])
trainer.fit(ptl_model)
Model Parallel Sharded Training on Ray¶
The RayShardedPlugin integrates with FairScale to provide sharded DDP training on a Ray cluster.
With sharded training, you can leverage the scalability of data parallel training while drastically reducing memory usage when training large models.
import pytorch_lightning as pl
from ray_lightning import RayShardedPlugin
# Create your PyTorch Lightning model here.
ptl_model = MNISTClassifier(...)
plugin = RayShardedPlugin(num_workers=4, num_cpus_per_worker=1, use_gpu=True)
# Don't set ``gpus`` in the ``Trainer``.
# The actual number of GPUs is determined by ``num_workers``.
trainer = pl.Trainer(..., plugins=[plugin])
trainer.fit(ptl_model)
See the Pytorch Lightning docs for more information on sharded training.
Hyperparameter Tuning with Ray Tune¶
ray_lightning also integrates with Ray Tune to provide distributed hyperparameter tuning for your distributed model training. You can run multiple PyTorch Lightning training runs in parallel, each with a different hyperparameter configuration, and each training run parallelized by itself. All you have to do is move your training code into a function, pass the function to tune.run, and make sure to add the appropriate callback (either TuneReportCallback or TuneReportCheckpointCallback) to your PyTorch Lightning Trainer.
Example using ray_lightning with Tune:
import pytorch_lightning as pl
from ray import tune
from ray_lightning import RayPlugin
from ray_lightning.tune import TuneReportCallback, get_tune_ddp_resources

def train_mnist(config):
    # Create your PTL model.
    model = MNISTClassifier(config)

    # Create the Tune Reporting Callback
    metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
    callbacks = [TuneReportCallback(metrics, on="validation_end")]

    trainer = pl.Trainer(
        max_epochs=4,
        callbacks=callbacks,
        plugins=[RayPlugin(num_workers=4, use_gpu=False)])
    trainer.fit(model)

config = {
    "layer_1": tune.choice([32, 64, 128]),
    "layer_2": tune.choice([64, 128, 256]),
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([32, 64, 128]),
}

# Number of hyperparameter configurations (trials) to sample.
num_samples = 10

# Make sure to pass in ``resources_per_trial`` using the
# ``get_tune_ddp_resources`` utility.
analysis = tune.run(
    train_mnist,
    metric="loss",
    mode="min",
    config=config,
    num_samples=num_samples,
    resources_per_trial=get_tune_ddp_resources(num_workers=4),
    name="tune_mnist")

print("Best hyperparameters found were: ", analysis.best_config)
FAQ¶
RaySGD already has a Pytorch Lightning integration. What’s the difference between this integration and that?
The key difference is which Trainer you'll be interacting with. In this library, you will still be using PyTorch Lightning's Trainer. You'll be able to leverage all the features of PyTorch Lightning, and Ray is used just as the backend to handle distributed training.
With RaySGD's integration, you'll be converting your LightningModule to be RaySGD compatible, and will be interacting with RaySGD's TorchTrainer. RaySGD's TorchTrainer is not as feature-rich nor as easy to use as PyTorch Lightning's Trainer (no built-in support for logging, early stopping, etc.). However, it does have built-in support for fault-tolerant and elastic training. If these are hard requirements for you, then RaySGD's integration with PTL might be a better option.
I see that RayPlugin is based off of PyTorch Lightning's DDPSpawnPlugin. However, doesn't the PTL team discourage the use of spawn?
As discussed here, using a spawn approach instead of launch is not all that detrimental. The original factors for discouraging spawn were:
- not being able to use 'spawn' in a Jupyter or Colab notebook, and
- not being able to use multiple workers for data loading.
Neither of these should be an issue with the RayPlugin due to Ray's serialization mechanisms. The only thing to keep in mind is that when using this plugin, your model does have to be serializable/pickleable.
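A quick way to sanity-check this before launching a run is to round-trip the model through Ray's bundled cloudpickle, which Ray uses to ship objects to its workers (a sketch; MNISTClassifier is the placeholder model from the examples above):
from ray import cloudpickle

ptl_model = MNISTClassifier(...)

# If this raises, the model cannot be sent to the Ray workers and
# needs to be made pickleable (e.g. avoid lambdas or open file
# handles as attributes).
cloudpickle.loads(cloudpickle.dumps(ptl_model))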
API Reference¶
- class ray_lightning.RayPlugin(*args, **kw)[source]¶
Pytorch Lightning plugin for DDP training on a Ray cluster.
This plugin is used to manage distributed training using DDP and Ray for process launching. Internally, the specified number of Ray actors are launched in the cluster and are registered as part of a Pytorch DDP process group. The Pytorch Lightning trainer is instantiated on the driver and sent to each of these training workers where training is executed. The distributed training protocol is handled by Pytorch DDP.
Each training worker is configured to reserve num_cpus_per_worker CPUs and 1 GPU if use_gpu is set to True.
If using this plugin, you should run your code like a normal Python script: python train.py, and only on the head node if running in a distributed Ray cluster. There is no need to run this script on every single node.
- Parameters
num_workers (int) – Number of training workers to use.
num_cpus_per_worker (int) – Number of CPUs per worker.
use_gpu (bool) – Whether to use GPU for allocation. For GPU to be used, you must also set the gpus arg in your PyTorch Lightning Trainer to a value > 0.
init_hook (Callable) – A function to run on each worker upon instantiation.
resources_per_worker (Optional[Dict]) – If specified, the resources defined in this Dict will be reserved for each worker. The CPU and GPU keys (case-sensitive) can be defined to override the number of CPUs/GPUs used by each worker.
**ddp_kwargs – Additional arguments to pass into DistributedDataParallel initialization.
Example
import pytorch_lightning as pl
from ray_lightning import RayPlugin

ptl_model = MNISTClassifier(...)
plugin = RayPlugin(num_workers=4, num_cpus_per_worker=1, use_gpu=True)

# Don't set ``gpus`` in ``Trainer``.
# The actual number of GPUs is determined by ``num_workers``.
trainer = pl.Trainer(..., plugins=[plugin])
trainer.fit(ptl_model)
PublicAPI (beta): This API is in beta and may change before becoming stable.
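As a hedged sketch of the less common arguments described above (the hook body and resource numbers are illustrative assumptions, not recommendations):
import os
import pytorch_lightning as pl
from ray_lightning import RayPlugin

def limit_threads():
    # init_hook runs once on each training worker when it is created.
    os.environ["OMP_NUM_THREADS"] = "1"

ptl_model = MNISTClassifier(...)

plugin = RayPlugin(
    num_workers=4,
    use_gpu=True,
    init_hook=limit_threads,
    # Override the default per-worker reservation via the CPU/GPU keys.
    resources_per_worker={"CPU": 2},
    # Any extra keyword arguments are forwarded to DistributedDataParallel.
    find_unused_parameters=False)

trainer = pl.Trainer(..., plugins=[plugin])
trainer.fit(ptl_model)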
- class ray_lightning.HorovodRayPlugin(*args, **kw)[source]¶
Pytorch Lightning Plugin for Horovod training on a Ray cluster.
This plugin is used to manage distributed training on a Ray cluster via the Horovod training framework. Internally, the specified number of Ray actors are launched in the cluster and are configured as part of the Horovod ring. The Pytorch Lightning trainer is instantiated on the driver and sent to each of these training workers where training is executed. The distributed training protocol is handled by Horovod.
Each training worker is configured to reserve 1 CPU and 1 GPU if use_gpu is set to True.
If using this plugin, you should run your code like a normal Python script: python train.py, and not with horovodrun.
- Parameters
num_workers (int) – Number of training workers to use.
num_cpus_per_worker (int) – Number of CPUs per worker.
use_gpu (bool) – Whether to use GPU for allocation. For GPU to be used, you must also set the gpus arg in your PyTorch Lightning Trainer to a value > 0.
Example
import pytorch_lightning as pl
from ray_lightning import HorovodRayPlugin

ptl_model = MNISTClassifier(...)
plugin = HorovodRayPlugin(num_workers=2, use_gpu=True)

# Don't set ``gpus`` in ``Trainer``.
# The actual number of GPUs is determined by ``num_workers``.
trainer = pl.Trainer(..., plugins=[plugin])
trainer.fit(ptl_model)
PublicAPI (beta): This API is in beta and may change before becoming stable.
- class ray_lightning.RayShardedPlugin(*args, **kw)[source]¶
PublicAPI (beta): This API is in beta and may change before becoming stable.
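Usage mirrors the sharded-training example earlier on this page (MNISTClassifier is the usual placeholder model):
import pytorch_lightning as pl
from ray_lightning import RayShardedPlugin

ptl_model = MNISTClassifier(...)
plugin = RayShardedPlugin(num_workers=4, num_cpus_per_worker=1, use_gpu=True)

# Don't set ``gpus`` in ``Trainer``.
# The actual number of GPUs is determined by ``num_workers``.
trainer = pl.Trainer(..., plugins=[plugin])
trainer.fit(ptl_model)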
Tune Integration¶
- class ray_lightning.tune.TuneReportCallback(*args, **kw)[source]¶
Distributed PyTorch Lightning to Ray Tune reporting callback
Reports metrics to Ray Tune, specifically when training is done remotely with Ray via the accelerators in this library.
- Args:
- metrics (str|list|dict): Metrics to report to Tune. If this is a list, each item describes the metric key reported to PyTorch Lightning, and it will be reported under the same name to Tune. If this is a dict, each key will be the name reported to Tune and the respective value will be the metric key reported to PyTorch Lightning.
- on (str|list): When to trigger metric reporting. Must be one of the PyTorch Lightning event hooks (less the on_ prefix), e.g. "batch_start" or "train_end". Defaults to "validation_end".
Example:
import pytorch_lightning as pl
from ray_lightning import RayPlugin
from ray_lightning.tune import TuneReportCallback

# Create plugin.
ray_plugin = RayPlugin(num_workers=4, use_gpu=True)

# Report loss and accuracy to Tune after each validation epoch:
trainer = pl.Trainer(plugins=[ray_plugin], callbacks=[
    TuneReportCallback(["val_loss", "val_acc"], on="validation_end")])

# Same as above, but report as `loss` and `mean_accuracy`:
trainer = pl.Trainer(plugins=[ray_plugin], callbacks=[
    TuneReportCallback(
        {"loss": "val_loss", "mean_accuracy": "val_acc"},
        on="validation_end")])
PublicAPI (beta): This API is in beta and may change before becoming stable.
- class ray_lightning.tune.TuneReportCheckpointCallback(*args, **kw)[source]¶
PyTorch Lightning to Tune reporting and checkpointing callback.
Saves checkpoints after each validation step. Also reports metrics to Tune, which is needed for checkpoint registration. To be used specifically with the plugins in this library.
- Args:
- metrics (str|list|dict): Metrics to report to Tune. If this is a list, each item describes the metric key reported to PyTorch Lightning, and it will be reported under the same name to Tune. If this is a dict, each key will be the name reported to Tune and the respective value will be the metric key reported to PyTorch Lightning.
- filename (str): Filename of the checkpoint within the checkpoint directory. Defaults to "checkpoint".
- on (str|list): When to trigger checkpoint creations. Must be one of the PyTorch Lightning event hooks (less the on_ prefix), e.g. "batch_start" or "train_end". Defaults to "validation_end".
Example:
import pytorch_lightning as pl
from ray_lightning import RayPlugin
from ray_lightning.tune import TuneReportCheckpointCallback

# Create the Ray plugin.
ray_plugin = RayPlugin()

# Save a checkpoint and report metrics after each validation epoch.
trainer = pl.Trainer(plugins=[ray_plugin], callbacks=[
    TuneReportCheckpointCallback(
        metrics={"loss": "val_loss", "mean_accuracy": "val_acc"},
        filename="trainer.ckpt",
        on="validation_end")])
PublicAPI (beta): This API is in beta and may change before becoming stable.
- ray_lightning.tune.get_tune_resources(num_workers: int = 1, num_cpus_per_worker: int = 1, use_gpu: bool = False, cpus_per_worker: Optional[int] = None) -> Dict[str, int][source]¶
Returns the PlacementGroupFactory to use for Ray Tune.
PublicAPI (beta): This API is in beta and may change before becoming stable.
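A minimal sketch of how this utility plugs into tune.run, assuming the train_mnist function and config from the Tune example above:
from ray import tune
from ray_lightning.tune import get_tune_resources

# Reserve enough resources for each trial so that the trial's
# RayPlugin workers (num_workers=4 here) can all be scheduled.
analysis = tune.run(
    train_mnist,
    metric="loss",
    mode="min",
    config=config,
    num_samples=4,
    resources_per_trial=get_tune_resources(num_workers=4),
    name="tune_mnist")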