ray.train.v2.jax.JaxTrainer
- class ray.train.v2.jax.JaxTrainer(train_loop_per_worker: Callable[[], None] | Callable[[Dict], None], *, train_loop_config: Dict | None = None, jax_config: JaxConfig | None = None, scaling_config: ScalingConfig | None = None, dataset_config: Dict[str, DataConfig] | None = None, run_config: RunConfig | None = None, datasets: Dict[str, Dataset | Callable[[], Dataset]] | None = None, validation_config: ValidationConfig | None = None)
Bases: DataParallelTrainer

A Trainer for Single-Program Multi-Data (SPMD) JAX training.
At a high level, this Trainer does the following:
1. Launches multiple workers as defined by the `scaling_config`.
2. Sets up a distributed JAX environment for TPUs or GPUs on these workers as defined by the `jax_config`.
3. Ingests the input `datasets` based on the `dataset_config`.
4. Runs the input `train_loop_per_worker(train_loop_config)` on all workers.
For more details, see the following example:
```python
import os
import logging
from typing import Sequence

from absl import app

import ray
from ray.train import RunConfig, ScalingConfig
from ray.train.v2.jax import JaxTrainer
from MaxText.train import main as maxtext_main


def train_loop_per_worker(config):
    argv = config["argv"]
    maxtext_main(argv)


def main(argv: Sequence[str]):
    ray.init()

    # If you want to use TPUs, specify the TPU topology and accelerator type.
    tpu_scaling_config = ScalingConfig(
        use_tpu=True,
        num_workers=4,
        topology="4x4",
        accelerator_type="TPU-V6E",
        placement_strategy="SPREAD",
        resources_per_worker={"TPU": 4},
    )

    # If you want to use GPUs, specify the GPU scaling config like below.
    # gpu_scaling_config = ScalingConfig(
    #     use_gpu=True,
    #     num_workers=4,
    #     resources_per_worker={"GPU": 1},
    # )

    trainer = JaxTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"argv": list(argv)},
        scaling_config=tpu_scaling_config,
        run_config=RunConfig(
            name="maxtext_jaxtrainer",
            worker_runtime_env={
                "env_vars": {
                    "JAX_PLATFORMS": "tpu",
                    # If you want to use GPUs, set JAX_PLATFORMS to "cuda".
                    # "JAX_PLATFORMS": "cuda",
                }
            },
        ),
    )
    result = trainer.fit()


if __name__ == "__main__":
    app.run(main)
```
If the `datasets` dict contains datasets (e.g. "train" and "val"), each dataset is split into multiple shards that can then be accessed by `ray.train.get_dataset_shard("train")` and `ray.train.get_dataset_shard("val")` inside the training loop.
Note

If you are using TPUs, importing `jax` should occur within `train_loop_per_worker` to avoid driver-side TPU lock issues.
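The following is an illustrative sketch (not taken from the Ray documentation) of both points: each worker reads its shard of the datasets passed to the trainer with `ray.train.get_dataset_shard`, and `jax` is imported inside `train_loop_per_worker` rather than on the driver. The dataset, batch size, and worker count are placeholder values, and the sketch assumes a CPU-only run (e.g., `JAX_PLATFORMS=cpu`); real TPU or GPU jobs would use the scaling options shown above.

```python
import ray
from ray.train import ScalingConfig
from ray.train.v2.jax import JaxTrainer


def train_loop_per_worker():
    # Import jax inside the worker so the driver process never initializes the TPU.
    import jax

    print("Worker sees devices:", jax.devices())

    # Read this worker's shard of the "train" dataset passed to the trainer.
    train_shard = ray.train.get_dataset_shard("train")
    for batch in train_shard.iter_batches(batch_size=32):
        ...  # run one JAX training step on `batch`


trainer = JaxTrainer(
    train_loop_per_worker=train_loop_per_worker,
    datasets={"train": ray.data.range(1024)},  # placeholder dataset
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
```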
- Parameters:
- train_loop_per_worker – The training function to execute on each worker. This function can either take in zero arguments or a single `Dict` argument which is set by defining `train_loop_config`. Within this function you can use any of the Ray Train Loop utilities.
- train_loop_config – A configuration `Dict` to pass in as an argument to `train_loop_per_worker`. This is typically used for specifying hyperparameters. Passing large datasets via `train_loop_config` is not recommended and may introduce large overhead and unknown issues with serialization and deserialization.
- jax_config – The configuration for setting up the JAX backend. If set to None, a default configuration will be used based on the `scaling_config` and the `JAX_PLATFORMS` environment variable.
- scaling_config – Configuration for how to scale data parallel training with SPMD. `num_workers` should be set to the number of TPU hosts or GPU workers. If using TPUs, `topology` should be set to the TPU topology. See `ScalingConfig` for more info.
- dataset_config – The configuration for ingesting the input `datasets`. By default, all the Ray Datasets are split equally across workers. See `DataConfig` for more details.
- run_config – The configuration for the execution of the training run. See `RunConfig` for more info.
- datasets – The Ray Datasets to ingest for training. Datasets are keyed by name (`{name: dataset}`). Each dataset can be accessed from within the `train_loop_per_worker` by calling `ray.train.get_dataset_shard(name)`. Sharding and additional configuration can be done by passing in a `dataset_config`.
- validation_config – [Alpha] Configuration for checkpoint validation. If provided and `ray.train.report` is called with the `validation` argument, Ray Train will validate the reported checkpoint using the validation function specified in this config.
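To make the `train_loop_per_worker` / `train_loop_config` contract concrete, here is a minimal sketch (the hyperparameter names, values, and dummy loss are made up for illustration, and it assumes a small CPU-only run): hyperparameters are read from the config dict and per-step metrics are sent back with `ray.train.report`.

```python
import ray
from ray.train import ScalingConfig
from ray.train.v2.jax import JaxTrainer


def train_loop_per_worker(config):
    # Hyperparameters arrive through `train_loop_config`.
    lr = config["lr"]
    num_steps = config["num_steps"]

    for step in range(num_steps):
        loss = 1.0 / (step + 1)  # placeholder for a real JAX training step
        # Report this step's metrics back to the Ray Train controller.
        ray.train.report({"step": step, "loss": loss, "lr": lr})


trainer = JaxTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"lr": 3e-4, "num_steps": 10},
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
print(result.metrics)
```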
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
Methods
- can_restore: [Deprecated] Checks if a Train experiment can be restored from a previously interrupted/failed run.
- fit: Launches the Ray Train controller to run training on workers.
- restore: [Deprecated] Restores a Train experiment from a previously interrupted/failed run.