ray.train.v2.jax.JaxTrainer

class ray.train.v2.jax.JaxTrainer(train_loop_per_worker: Callable[[], None] | Callable[[Dict], None], *, train_loop_config: Dict | None = None, jax_config: JaxConfig | None = None, scaling_config: ScalingConfig | None = None, dataset_config: Dict[str, DataConfig] | None = None, run_config: RunConfig | None = None, datasets: Dict[str, Dataset | Callable[[], Dataset]] | None = None, validation_config: ValidationConfig | None = None)

Bases: DataParallelTrainer

A Trainer for Single-Program Multi-Data (SPMD) JAX training.

At a high level, this Trainer does the following:

  1. Launches multiple workers as defined by the scaling_config.

  2. Sets up a distributed JAX environment for TPUs or GPUs on these workers as defined by the jax_config.

  3. Ingests the input datasets based on the dataset_config.

  4. Runs the input train_loop_per_worker(train_loop_config) on all workers.

For example, the following script runs distributed MaxText training on TPUs with JaxTrainer:

import os
from absl import app
import logging
from typing import Sequence

import ray
from ray.train import ScalingConfig, RunConfig
from ray.train.v2.jax import JaxTrainer
from MaxText.train import main as maxtext_main

def train_loop_per_worker(config):
    argv = config["argv"]
    maxtext_main(argv)

def main(argv: Sequence[str]):
    ray.init()

    # If you want to use TPUs, specify the TPU topology and accelerator type.
    tpu_scaling_config = ScalingConfig(
        use_tpu=True,
        num_workers=4,
        topology="4x4",
        accelerator_type="TPU-V6E",
        placement_strategy="SPREAD",
        resources_per_worker={"TPU": 4},
    )

    # If you want to use GPUs, specify the GPU scaling config like below.
    # gpu_scaling_config = ScalingConfig(
    #     use_gpu=True,
    #     num_workers=4,
    #     resources_per_worker={"GPU": 1},
    # )


    trainer = JaxTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"argv": absolute_argv},
        scaling_config=tpu_scaling_config,
        run_config=RunConfig(
            name="maxtext_jaxtrainer",
            worker_runtime_env={
                "env_vars": {
                    "JAX_PLATFORMS": "tpu",
                    # If you want to use GPUs, set the JAX_PLATFORMS to "cuda".
                    # "JAX_PLATFORMS": "cuda",
                }
            },
        ),
    )

    result = trainer.fit()


if __name__ == "__main__":
    app.run(main)

If the datasets dict contains multiple datasets (e.g. “train” and “val”), each dataset is split into multiple shards, which can then be accessed inside the training loop with ray.train.get_dataset_shard("train") and ray.train.get_dataset_shard("val"), as sketched below.
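A minimal sketch of that pattern. The dataset contents (ray.data.range), the batch size, and the GPU scaling settings below are illustrative assumptions, not prescribed by this API:

import ray.data
import ray.train
from ray.train import ScalingConfig
from ray.train.v2.jax import JaxTrainer

def train_loop_per_worker(config):
    # Each worker receives its own shard of every dataset passed to the trainer.
    train_shard = ray.train.get_dataset_shard("train")
    val_shard = ray.train.get_dataset_shard("val")  # validation loop omitted in this sketch

    # Batches arrive as dicts of column -> NumPy array by default.
    for batch in train_shard.iter_batches(batch_size=256):
        ...  # run a JAX training step on `batch`

trainer = JaxTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    datasets={
        "train": ray.data.range(10_000),
        "val": ray.data.range(1_000),
    },
)
result = trainer.fit()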

Note

  • If you are using TPUs, import jax inside train_loop_per_worker (as in the sketch below) to avoid driver-side TPU lock issues.
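A minimal sketch of that pattern; the device inspection is only there to show that accelerators are visible on the worker:

def train_loop_per_worker(config):
    # Deferring the import keeps TPU initialization on the worker process
    # instead of the Ray driver.
    import jax

    # Each worker sees only the accelerators attached to its own host.
    print(jax.device_count(), jax.devices())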

Parameters:
  • train_loop_per_worker – The training function to execute on each worker. This function can either take in zero arguments or a single Dict argument which is set by defining train_loop_config. Within this function you can use any of the Ray Train Loop utilities. Both accepted signatures are sketched after this parameter list.

  • train_loop_config – A configuration Dict to pass in as an argument to train_loop_per_worker. This is typically used for specifying hyperparameters. Passing large datasets via train_loop_config is not recommended; it can introduce significant serialization and deserialization overhead.

  • jax_config – The configuration for setting up the JAX backend. If set to None, a default configuration will be used based on the scaling_config and JAX_PLATFORMS environment variable.

  • scaling_config – Configuration for how to scale data parallel training with SPMD. num_workers should be set to the number of TPU hosts or GPU workers. If using TPUs, topology should be set to the TPU topology. See ScalingConfig for more info.

  • dataset_config – The configuration for ingesting the input datasets. By default, all the Ray Datasets are split equally across workers. See DataConfig for more details.

  • run_config – The configuration for the execution of the training run. See RunConfig for more info.

  • datasets – The Ray Datasets to ingest for training. Datasets are keyed by name ({name: dataset}). Each dataset can be accessed from within the train_loop_per_worker by calling ray.train.get_dataset_shard(name). Sharding and additional configuration can be done by passing in a dataset_config.

  • validation_config – [Alpha] Configuration for checkpoint validation. If provided and ray.train.report is called with the validation argument, Ray Train will validate the reported checkpoint using the validation function specified in this config.
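As noted under train_loop_per_worker above, the training function may take either zero arguments or a single dict. A short sketch of both forms; the hyperparameter names are illustrative assumptions, not part of the API:

from ray.train import ScalingConfig
from ray.train.v2.jax import JaxTrainer

# Zero-argument form: use when no train_loop_config is needed.
def train_loop_no_args():
    ...

# Single-dict form: the dict is populated from train_loop_config.
def train_loop_with_config(config):
    learning_rate = config["learning_rate"]  # illustrative hyperparameter
    num_steps = config["num_steps"]          # illustrative hyperparameter
    ...

trainer = JaxTrainer(
    train_loop_per_worker=train_loop_with_config,
    train_loop_config={"learning_rate": 3e-4, "num_steps": 1_000},
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
)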

PublicAPI (alpha): This API is in alpha and may change before becoming stable.

Methods

can_restore

[Deprecated] Checks if a Train experiment can be restored from a previously interrupted/failed run.

fit

Launches the Ray Train controller to run training on workers.

restore

[Deprecated] Restores a Train experiment from a previously interrupted/failed run.