ray.train.v2.jax.JaxTrainer
- class ray.train.v2.jax.JaxTrainer(train_loop_per_worker: Callable[[], None] | Callable[[Dict], None], *, train_loop_config: Dict | None = None, jax_config: JaxConfig | None = None, scaling_config: ScalingConfig | None = None, dataset_config: Dict[str, DataConfig] | None = None, run_config: RunConfig | None = None, datasets: Dict[str, Dataset | Callable[[], Dataset]] | None = None, resume_from_checkpoint: Checkpoint | None = None)
Bases: DataParallelTrainer
A Trainer for Single-Program Multi-Data (SPMD) JAX training. Currently only supports TPUs. GPUs will be supported in a future version.
This Trainer runs the function train_loop_per_worker on multiple Ray Actors. These actors are expected to be scheduled on TPU VMs within the same TPU slice, connected via inter-chip interconnects (ICI). The train_loop_per_worker function is expected to take in either 0 or 1 arguments:

    from typing import Sequence

    from absl import app

    import ray
    from ray.train.v2.api.config import RunConfig, ScalingConfig
    from ray.train.v2.jax import JaxTrainer
    from MaxText.train import main as maxtext_main


    def train_loop_per_worker(config):
        argv = config["argv"]
        maxtext_main(argv)


    def main(argv: Sequence[str]):
        ray.init()

        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                num_workers=4,
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                worker_runtime_env={
                    "env_vars": {
                        "JAX_PLATFORMS": "tpu",
                        "ENABLE_PJRT_COMPATIBILITY": "true",
                        "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
                        "TPU_SLICE_BUILDER_DUMP_ICI": "true",
                        "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",
                    }
                },
            ),
        )
        result = trainer.fit()


    if __name__ == "__main__":
        app.run(main)
If train_loop_per_worker accepts an argument, then train_loop_config will be passed in as the argument.

If the datasets dict contains a training dataset (denoted by the “train” key), then it will be split into multiple dataset shards that can then be accessed by ray.train.get_dataset_shard("train").
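For example, here is a minimal sketch of a training function that consumes both its train_loop_config argument and its shard of the “train” dataset (the "batch_size" key and the training-step body are assumptions for illustration):

    import ray.train


    def train_loop_per_worker(config):
        # The "train" dataset passed to the trainer arrives pre-sharded;
        # each worker fetches only its own shard.
        train_shard = ray.train.get_dataset_shard("train")
        for batch in train_shard.iter_batches(batch_size=config["batch_size"]):
            ...  # run one training step on this worker's batch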
Note
- Only TPU-based distributed training is supported.
- Each worker must be assigned one TPU device via resources_per_worker={"TPU": 1}.
- Placement strategy is automatically set to SPREAD to ensure TPU workers are placed on separate VMs.
- Importing jax should occur within train_loop_per_worker to avoid driver-side TPU lock issues (see the sketch after this note).
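A minimal sketch of a worker function that defers the jax import as recommended (it only reports the devices each worker sees):

    def train_loop_per_worker():
        # Import jax inside the worker so the TPU runtime is initialized
        # on the TPU VM rather than on the Ray driver process.
        import jax

        print(
            f"Worker {jax.process_index()} sees "
            f"{jax.local_device_count()} local TPU device(s)"
        )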
- Parameters:
- train_loop_per_worker – The training function to execute on each worker. This function can either take in zero arguments or a single Dict argument which is set by defining train_loop_config. Within this function you can use any of the Ray Train Loop utilities.
- train_loop_config – A configuration Dict to pass in as an argument to train_loop_per_worker. This is typically used for specifying hyperparameters. Passing large datasets via train_loop_config is not recommended, as it may introduce large overhead and unknown issues with serialization and deserialization.
- jax_config – The configuration for setting up the JAX backend. If set to None, a default configuration with TPUs will be used.
- scaling_config – Configuration for how to scale data parallel training with SPMD. num_workers should be set to the number of TPU hosts and topology should be set to the TPU topology. See ScalingConfig for more info.
- dataset_config – The configuration for ingesting the input datasets. By default, all the Ray Datasets are split equally across workers. See DataConfig for more details.
- run_config – The configuration for the execution of the training run. See RunConfig for more info.
- datasets – The Ray Datasets to ingest for training. Datasets are keyed by name ({name: dataset}). Each dataset can be accessed from within the train_loop_per_worker by calling ray.train.get_dataset_shard(name). Sharding and additional configuration can be done by passing in a dataset_config.
- resume_from_checkpoint – A checkpoint to resume training from. This checkpoint can be accessed from within train_loop_per_worker by calling ray.train.get_checkpoint() (see the sketch after this list).
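For resume_from_checkpoint, a sketch of how a restored checkpoint might be consumed inside the training function (the checkpoint layout and restore logic here are assumptions):

    import ray.train


    def train_loop_per_worker(config):
        checkpoint = ray.train.get_checkpoint()
        if checkpoint is not None:
            with checkpoint.as_directory() as ckpt_dir:
                ...  # e.g., restore model and optimizer state from ckpt_dir
        ...  # continue (or start) the training loop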
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
Methods
- can_restore – [Deprecated] Checks if a Train experiment can be restored from a previously interrupted/failed run.
- fit – Launches the Ray Train controller to run training on workers.
- restore – [Deprecated] Restores a Train experiment from a previously interrupted/failed run.