Ray Train Overview#

To use Ray Train effectively, you need to understand four main concepts:

  1. Training function: A Python function that contains your model training logic.

  2. Worker: A process that runs the training function.

  3. Scaling configuration: A configuration of the number of workers and compute resources (for example, CPUs or GPUs).

  4. Trainer: A Python class that ties together the training function, workers, and scaling configuration to execute a distributed training job.


Training function#

The training function is a user-defined Python function that contains the end-to-end model training loop logic. When launching a distributed training job, each worker executes this training function.

Ray Train documentation uses the following conventions:

  1. train_func is a user-defined function that contains the training code.

  2. train_func is passed into the Trainer’s train_loop_per_worker parameter.

def train_func():
    """User-defined training function that runs on each distributed worker process.

    This function typically contains logic for loading the model,
    loading the dataset, training the model, saving checkpoints,
    and logging metrics.
    """
    ...


Worker#

Ray Train distributes model training compute to individual worker processes across the cluster. Each worker is a process that executes the train_func. The number of workers determines the parallelism of the training job and is configured in the ScalingConfig.
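
As a conceptual illustration only, the following minimal sketch mimics the idea that every worker executes the same training function on its own share of the data. It does not use Ray: `launch_workers` is a hypothetical stand-in for a trainer, threads stand in for worker processes, and passing `rank` and `world_size` as arguments is an assumption made for the sketch. Unlike real distributed training, the workers here do not synchronize gradients.

```python
from concurrent.futures import ThreadPoolExecutor


def train_func(rank, world_size):
    """Toy training function: fit y = 2x by gradient descent on this worker's shard."""
    # Each worker takes every world_size-th sample, starting at its own rank.
    xs = [float(x) for x in range(1, 9)][rank::world_size]
    ys = [2.0 * x for x in xs]

    w = 0.0  # the single model parameter
    for _ in range(200):
        # Gradient of the mean squared error with respect to w.
        grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        w -= 0.01 * grad
    return {"rank": rank, "weight": w}


def launch_workers(num_workers):
    """Hypothetical stand-in for a trainer: run train_func once per worker."""
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = [
            pool.submit(train_func, rank, num_workers)
            for rank in range(num_workers)
        ]
        return [future.result() for future in futures]


results = launch_workers(num_workers=4)
for result in results:
    print(result)
```

Raising `num_workers` increases how many copies of `train_func` run at once, which is the sense in which the number of workers sets the parallelism of the job.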

Scaling configuration#

The ScalingConfig is the mechanism for defining the scale of the training job. Specify two basic parameters for worker parallelism and compute resources:

  • num_workers: The number of workers to launch for a distributed training job.

  • use_gpu: Whether each worker uses a GPU (True) or only a CPU (False).

from ray.train import ScalingConfig

# Single worker with a CPU
scaling_config = ScalingConfig(num_workers=1, use_gpu=False)

# Single worker with a GPU
scaling_config = ScalingConfig(num_workers=1, use_gpu=True)

# Multiple workers, each with a GPU
scaling_config = ScalingConfig(num_workers=4, use_gpu=True)
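
One way to choose these two parameters programmatically is sketched below. The helper `scaling_kwargs` and its one-worker-per-GPU policy are hypothetical and not part of Ray Train; the function only builds the keyword arguments, so it runs without Ray installed.

```python
def scaling_kwargs(available_gpus):
    """Hypothetical helper: pick num_workers and use_gpu from a GPU count."""
    if available_gpus < 0:
        raise ValueError("available_gpus must be non-negative")

    use_gpu = available_gpus > 0
    # Assumed policy: one worker per GPU, or a single CPU worker if there are none.
    num_workers = available_gpus if use_gpu else 1
    return {"num_workers": num_workers, "use_gpu": use_gpu}


cpu_only = scaling_kwargs(0)
four_gpus = scaling_kwargs(4)
print(cpu_only)
print(four_gpus)
```

The resulting dictionary could then be unpacked into the constructor, for example `ScalingConfig(**scaling_kwargs(4))`, which matches the multi-worker GPU configuration above.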


Trainer#

The Trainer ties the previous three concepts together to launch distributed training jobs. Ray Train provides Trainer classes for different frameworks. Calling the fit() method executes the training job by:

  1. Launching workers as defined by the scaling_config.

  2. Setting up the framework’s distributed environment on all workers.

  3. Running the train_func on all workers.

from ray.train.torch import TorchTrainer

trainer = TorchTrainer(train_func, scaling_config=scaling_config)
trainer.fit()
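
Putting the four concepts together, an end-to-end job might look like the sketch below. It assumes Ray Train 2.x with PyTorch installed. The wrapper `run_training`, the `num_epochs` config key, and the reported metric names are illustrative choices rather than anything Ray requires. The Ray imports sit inside the functions so that the file can be loaded on a machine without Ray; nothing is launched until `run_training()` is called.

```python
def train_func(config):
    """Runs on every worker; config is the dict passed as train_loop_config."""
    import ray.train  # imported lazily so the module loads without Ray

    context = ray.train.get_context()
    rank = context.get_world_rank()
    world_size = context.get_world_size()

    for epoch in range(config["num_epochs"]):
        # Placeholder for a real training step (forward, backward, optimizer).
        ray.train.report({"epoch": epoch, "rank": rank, "world_size": world_size})


def run_training(num_workers=2, use_gpu=False):
    """Illustrative wrapper: build a TorchTrainer and run it to completion."""
    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer

    trainer = TorchTrainer(
        train_func,
        train_loop_config={"num_epochs": 2},
        scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
    )
    # fit() launches the workers, sets up the distributed environment,
    # and runs train_func on each worker.
    return trainer.fit()
```

Calling `run_training()` on a machine with Ray and PyTorch available returns the result object produced by `fit()`, from which the last reported metrics can be inspected.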