ray.train.v2.jax.JaxTrainer#

class ray.train.v2.jax.JaxTrainer(train_loop_per_worker: Callable[[], None] | Callable[[Dict], None], *, train_loop_config: Dict | None = None, jax_config: JaxConfig | None = None, scaling_config: ScalingConfig | None = None, dataset_config: Dict[str, DataConfig] | None = None, run_config: RunConfig | None = None, datasets: Dict[str, Dataset | Callable[[], Dataset]] | None = None, resume_from_checkpoint: Checkpoint | None = None)[source]#

Bases: DataParallelTrainer

A Trainer for Single-Program Multi-Data (SPMD) JAX training. It currently supports only TPUs; GPU support will be added in a future version.

This Trainer runs the function train_loop_per_worker on multiple Ray Actors. These actors are expected to be scheduled on TPU VMs within the same TPU slice and connected via inter-chip interconnects (ICI). The train_loop_per_worker function is expected to take either zero arguments or a single Dict argument, as in the following example:

from typing import Sequence

from absl import app

import ray
from ray.train.v2.api.config import ScalingConfig, RunConfig
from ray.train.v2.jax import JaxTrainer
from MaxText.train import main as maxtext_main

def train_loop_per_worker(config):
    # Each worker runs the MaxText training entry point with the forwarded CLI arguments.
    argv = config["argv"]
    maxtext_main(argv)

def main(argv: Sequence[str]):
    ray.init()

    # Forward the MaxText CLI arguments (config path and overrides) to each
    # worker; resolve any relative paths to absolute ones first if needed.
    absolute_argv = list(argv)

    trainer = JaxTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"argv": absolute_argv},
        scaling_config=ScalingConfig(
            use_tpu=True,
            num_workers=4,
            topology="4x4",
            accelerator_type="TPU-V6E",
            resources_per_worker={"TPU": 4},
            placement_strategy="SPREAD",
        ),
        run_config=RunConfig(
            name="maxtext_jaxtrainer",
            worker_runtime_env={
                "env_vars": {
                    "JAX_PLATFORMS": "tpu",
                    "ENABLE_PJRT_COMPATIBILITY": "true",
                    "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
                    "TPU_SLICE_BUILDER_DUMP_ICI": "true",
                    "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",
                }
            },
        ),
    )

    result = trainer.fit()


if __name__ == "__main__":
    app.run(main)

If train_loop_per_worker accepts an argument, then train_loop_config will be passed in as the argument.
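
A minimal sketch of the two accepted shapes of the training function (the bodies are elided; the config key mirrors the example above):

def train_func_no_config():
    # Zero-argument form: the trainer calls this with no arguments.
    ...

def train_func_with_config(config: dict):
    # One-argument form: `config` is exactly the dict given as train_loop_config.
    argv = config["argv"]  # e.g. the "argv" key from the example above
    ...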

If the datasets dict contains a training dataset (denoted by the “train” key), it is split into multiple dataset shards, each of which can be accessed inside train_loop_per_worker by calling ray.train.get_dataset_shard("train").
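
For example, a training loop might consume its shard of the “train” dataset roughly as follows; the batch size, the ray.data.range placeholder dataset, and the ScalingConfig values are illustrative assumptions:

import ray
import ray.data
import ray.train
from ray.train.v2.api.config import ScalingConfig
from ray.train.v2.jax import JaxTrainer

def train_loop_per_worker():
    # Each worker receives only its own shard of the "train" dataset.
    train_shard = ray.train.get_dataset_shard("train")
    for batch in train_shard.iter_batches(batch_size=128):
        ...  # convert the batch to arrays and run a JAX train step

trainer = JaxTrainer(
    train_loop_per_worker=train_loop_per_worker,
    datasets={"train": ray.data.range(10_000)},
    scaling_config=ScalingConfig(
        use_tpu=True, num_workers=4, topology="4x4",
        accelerator_type="TPU-V6E", resources_per_worker={"TPU": 4},
    ),
)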

Note

  • Only TPU-based distributed training is supported.

  • Each worker must be assigned the TPU chips on its host via resources_per_worker (for example, {"TPU": 4} for a host with 4 chips, as in the example above).

  • Placement strategy is automatically set to SPREAD to ensure TPU workers are placed on separate VMs.

  • Importing jax should occur within train_loop_per_worker to avoid driver-side TPU lock issues.
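
A sketch of the last point: defer the jax import into the worker function. The device-count logging is only illustrative.

def train_loop_per_worker(config):
    # Importing jax here, rather than at module scope on the driver, avoids
    # initializing and locking the TPU runtime outside the worker processes.
    import jax

    print(f"local TPU devices: {jax.local_device_count()} of {jax.device_count()} total")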

Parameters:
  • train_loop_per_worker – The training function to execute on each worker. This function can either take in zero arguments or a single Dict argument which is set by defining train_loop_config. Within this function you can use any of the Ray Train Loop utilities.

  • train_loop_config – A configuration Dict to pass in as an argument to train_loop_per_worker. This is typically used for specifying hyperparameters. Passing large datasets via train_loop_config is not recommended; it can introduce significant serialization and deserialization overhead and other unexpected issues.

  • jax_config – The configuration for setting up the JAX backend. If set to None, a default configuration with TPUs will be used.

  • scaling_config – Configuration for how to scale data parallel training with SPMD. num_workers should be set to the number of TPU hosts and topology should be set to the TPU topology. See ScalingConfig for more info.

  • dataset_config – The configuration for ingesting the input datasets. By default, all the Ray Datasets are split equally across workers. See DataConfig for more details.

  • run_config – The configuration for the execution of the training run. See RunConfig for more info.

  • datasets – The Ray Datasets to ingest for training. Datasets are keyed by name ({name: dataset}). Each dataset can be accessed from within the train_loop_per_worker by calling ray.train.get_dataset_shard(name). Sharding and additional configuration can be done by passing in a dataset_config.

  • resume_from_checkpoint – A checkpoint to resume training from. This checkpoint can be accessed from within train_loop_per_worker by calling ray.train.get_checkpoint().
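
A sketch of how resume_from_checkpoint pairs with ray.train.get_checkpoint(); the checkpoint path, restore logic, and ScalingConfig values are placeholders:

import ray.train
from ray.train import Checkpoint
from ray.train.v2.api.config import ScalingConfig
from ray.train.v2.jax import JaxTrainer

def train_loop_per_worker(config):
    checkpoint = ray.train.get_checkpoint()
    if checkpoint is not None:
        with checkpoint.as_directory() as ckpt_dir:
            ...  # restore model/optimizer state from ckpt_dir (application-specific)

trainer = JaxTrainer(
    train_loop_per_worker=train_loop_per_worker,
    resume_from_checkpoint=Checkpoint.from_directory("/path/to/previous/checkpoint"),
    scaling_config=ScalingConfig(
        use_tpu=True, num_workers=4, topology="4x4",
        accelerator_type="TPU-V6E", resources_per_worker={"TPU": 4},
    ),
)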

PublicAPI (alpha): This API is in alpha and may change before becoming stable.

Methods

can_restore

[Deprecated] Checks if a Train experiment can be restored from a previously interrupted/failed run.

fit

Launches the Ray Train controller to run training on workers.

restore

[Deprecated] Restores a Train experiment from a previously interrupted/failed run.