Ray Train Overview
To use Ray Train effectively, you need to understand four main concepts:
Training function: A Python function that contains your model training logic.
Worker: A process that runs the training function.
Scaling configuration: A configuration of the number of workers and compute resources (for example, CPUs or GPUs).
Trainer: A Python class that ties together the training function, workers, and scaling configuration to execute a distributed training job.
The training function is a user-defined Python function that contains the end-to-end model training loop logic. When launching a distributed training job, each worker executes this training function.
Ray Train documentation uses the following conventions:
train_func is a user-defined function that contains the training code.
train_func is passed into the Trainer’s constructor as the function to run on each worker.
    def train_func():
        """User-defined training function that runs on each distributed worker process.

        This function typically contains logic for loading the model,
        loading the dataset, training the model, saving checkpoints,
        and logging metrics.
        """
        ...
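As a rough illustration of what such a function might contain, the sketch below fits a one-variable linear model with plain gradient descent, using only the standard library. The synthetic data, the learning rate, and the returned dictionary are assumptions made for this example. A real train_func would load an actual model and dataset, and the return value here is only a convenience for inspecting the result locally.

```python
def train_func():
    """Toy training function: fit y = w * x + b with gradient descent."""
    # "Load" a tiny synthetic dataset generated from y = 2x + 1.
    xs = [0.0, 1.0, 2.0, 3.0, 4.0]
    ys = [2.0 * x + 1.0 for x in xs]
    n = len(xs)

    # "Load" the model: two scalar parameters, initialised to zero.
    w, b = 0.0, 0.0
    learning_rate = 0.05

    loss = float("inf")
    for epoch in range(1000):
        # Forward pass and mean squared error.
        errors = [w * x + b - y for x, y in zip(xs, ys)]
        loss = sum(e * e for e in errors) / n

        # Gradients of the loss with respect to w and b.
        grad_w = 2.0 * sum(e * x for e, x in zip(errors, xs)) / n
        grad_b = 2.0 * sum(errors) / n

        # Parameter update.
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b

        # Log a metric occasionally.
        if epoch % 250 == 0:
            print(f"epoch={epoch} loss={loss:.6f}")

    return {"w": w, "b": b, "loss": loss}
```

Calling train_func() directly in a single process is a quick way to check the loop before handing it to a Trainer.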
Ray Train distributes model training compute to individual worker processes across the cluster.
Each worker is a process that executes the train_func.
The number of workers determines the parallelism of the training job and is configured in the ScalingConfig.
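One common way that worker count translates into parallelism is data parallelism, where each worker processes a different slice of the dataset. The sketch below is illustrative only: shard_for_worker, rank, and world_size are hypothetical names chosen for this example and are not part of the Ray Train API.

```python
def shard_for_worker(dataset, rank, world_size):
    """Return the slice of `dataset` handled by worker `rank`.

    Hypothetical helper: worker `rank` takes every `world_size`-th item,
    starting at index `rank`.
    """
    if not 0 <= rank < world_size:
        raise ValueError("rank must satisfy 0 <= rank < world_size")
    return dataset[rank::world_size]


dataset = list(range(10))
num_workers = 4

# Each worker sees a disjoint shard; together the shards cover the dataset.
shards = [shard_for_worker(dataset, rank, num_workers) for rank in range(num_workers)]
for rank, shard in enumerate(shards):
    print(f"worker {rank}: {shard}")
```

Doubling num_workers roughly halves the number of items each worker handles, which is the sense in which the worker count sets the parallelism in this sketch.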
ScalingConfig is the mechanism for defining the scale of the training job.
Specify two basic parameters for worker parallelism and compute resources:
num_workers: The number of workers to launch for a distributed training job.
use_gpu: Whether each worker should use a GPU or CPU.
    from ray.train import ScalingConfig

    # Single worker with a CPU
    scaling_config = ScalingConfig(num_workers=1, use_gpu=False)

    # Single worker with a GPU
    scaling_config = ScalingConfig(num_workers=1, use_gpu=True)

    # Multiple workers, each with a GPU
    scaling_config = ScalingConfig(num_workers=4, use_gpu=True)
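When the same script has to run on machines with and without GPUs, the two parameters can be derived from the available hardware. The helper below is a hypothetical sketch, not a Ray API: scaling_kwargs and its policy (one worker per GPU, otherwise a fixed number of CPU workers) are assumptions for illustration. It returns a plain dictionary so that it runs without Ray installed.

```python
def scaling_kwargs(num_gpus_available, cpu_workers=1):
    """Choose `num_workers` and `use_gpu` from the available hardware.

    Hypothetical policy: one worker per GPU when GPUs are present,
    otherwise `cpu_workers` CPU-only workers.
    """
    if num_gpus_available < 0 or cpu_workers < 1:
        raise ValueError("expected num_gpus_available >= 0 and cpu_workers >= 1")
    if num_gpus_available > 0:
        return {"num_workers": num_gpus_available, "use_gpu": True}
    return {"num_workers": cpu_workers, "use_gpu": False}


print(scaling_kwargs(0))
print(scaling_kwargs(4))
```

With Ray installed, the result could be unpacked into the real class, for example ScalingConfig(**scaling_kwargs(4)).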
The Trainer ties the previous three concepts together to launch distributed training jobs.
Ray Train provides Trainer classes for different frameworks.
Calling the Trainer’s fit() method executes the training job by:
Launching workers as defined by the scaling_config.
Setting up the framework’s distributed environment on all workers.
Running the train_func on all workers.
    from ray.train.torch import TorchTrainer

    trainer = TorchTrainer(train_func, scaling_config=scaling_config)
    trainer.fit()
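To make the three steps of fit() concrete, here is a toy stand-in written with the standard library only. It is not how Ray Train is implemented: ToyTrainer, ToyScalingConfig, toy_train_func, and the rank and world_size arguments are all hypothetical, threads stand in for worker processes, and the "distributed environment" is reduced to telling each worker its rank.

```python
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass


@dataclass
class ToyScalingConfig:
    """Hypothetical stand-in that only records the number of workers."""
    num_workers: int = 1


class ToyTrainer:
    """Hypothetical stand-in that runs one function on several workers."""

    def __init__(self, train_func, scaling_config):
        self.train_func = train_func
        self.scaling_config = scaling_config

    def fit(self):
        world_size = self.scaling_config.num_workers
        # Step 1: launch the workers (threads here, processes in a real system).
        with ThreadPoolExecutor(max_workers=world_size) as pool:
            # Steps 2 and 3: give each worker its rank, then run train_func on it.
            futures = [
                pool.submit(self.train_func, rank, world_size)
                for rank in range(world_size)
            ]
            # Collect results in rank order; a worker's exception re-raises here.
            return [future.result() for future in futures]


def toy_train_func(rank, world_size):
    # Each worker sums its own shard of the numbers 0..99.
    return sum(range(rank, 100, world_size))


trainer = ToyTrainer(toy_train_func, scaling_config=ToyScalingConfig(num_workers=4))
partial_sums = trainer.fit()
print(partial_sums, sum(partial_sums))
```

The shape mirrors the TorchTrainer call above: the function and the scaling configuration go into the constructor, and fit() does the launching and running.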