import logging
from typing import Callable, Dict, Optional, Union

from ray.air._internal.config import ensure_only_allowed_dataclass_keys_updated
from ray.train import DataConfig
from ray.train.trainer import GenDataset
from ray.train.v2.api.config import RunConfig, ScalingConfig
from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer
from ray.train.v2.api.validation_config import ValidationConfig
from ray.train.v2.jax.config import JaxConfig
from ray.util import PublicAPI

logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
class JaxTrainer(DataParallelTrainer):
"""A Trainer for Single-Program Multi-Data (SPMD) JAX training.
At a high level, this Trainer does the following:
1. Launches multiple workers as defined by the ``scaling_config``.
2. Sets up a distributed JAX environment for TPUs or GPUs
on these workers as defined by the ``jax_config``.
3. Ingests the input ``datasets`` based on the ``dataset_config``.
4. Runs the input ``train_loop_per_worker(train_loop_config)``
on all workers.
For more details, see:
* :ref:`Jax Guide <train-jax>`
.. testcode::
:skipif: True
        from typing import Sequence

        from absl import app

        import ray
        from ray.train import RunConfig, ScalingConfig
        from ray.train.v2.jax import JaxTrainer

        from MaxText.train import main as maxtext_main


        def train_loop_per_worker(config):
            argv = config["argv"]
            maxtext_main(argv)


        def main(argv: Sequence[str]):
            ray.init()

            # If you want to use TPUs, specify the TPU topology and
            # accelerator type.
            tpu_scaling_config = ScalingConfig(
                use_tpu=True,
                num_workers=4,
                topology="4x4",
                accelerator_type="TPU-V6E",
                placement_strategy="SPREAD",
                resources_per_worker={"TPU": 4},
            )
            # If you want to use GPUs, specify a GPU scaling config instead:
            # gpu_scaling_config = ScalingConfig(
            #     use_gpu=True,
            #     num_workers=4,
            #     resources_per_worker={"GPU": 1},
            # )
            trainer = JaxTrainer(
                train_loop_per_worker=train_loop_per_worker,
                train_loop_config={"argv": argv},
                scaling_config=tpu_scaling_config,
                run_config=RunConfig(
                    name="maxtext_jaxtrainer",
                    worker_runtime_env={
                        "env_vars": {
                            "JAX_PLATFORMS": "tpu",
                            # If you want to use GPUs, set JAX_PLATFORMS to "cuda".
                            # "JAX_PLATFORMS": "cuda",
                        }
                    },
                ),
            )
            result = trainer.fit()


        if __name__ == "__main__":
            app.run(main)

    If the ``datasets`` dict contains datasets (e.g. ``"train"`` and
    ``"val"``), each dataset is split into multiple shards that workers can
    access via ``ray.train.get_dataset_shard("train")`` and
    ``ray.train.get_dataset_shard("val")``, as sketched below.
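
    A minimal sketch of a worker loop consuming its shard (the dataset key
    ``"train"`` and the batch size are illustrative assumptions):

    .. testcode::
        :skipif: True

        import ray.train

        def train_loop_per_worker(config):
            # Each worker receives only its own shard of the "train" dataset.
            train_shard = ray.train.get_dataset_shard("train")
            # Iterate over the shard as NumPy batches.
            for batch in train_shard.iter_batches(
                batch_size=128, batch_format="numpy"
            ):
                ...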

    Note:
        * If you are using TPUs, importing ``jax`` should occur within
          ``train_loop_per_worker`` to avoid driver-side TPU lock issues;
          see the sketch after this note.
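
    A minimal sketch of the deferred import (the loop body is illustrative):

    .. testcode::
        :skipif: True

        def train_loop_per_worker(config):
            # Importing jax here keeps the driver process from initializing
            # the TPU runtime.
            import jax

            print(jax.device_count())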

    Args:
        train_loop_per_worker: The training function to execute on each worker.
            This function can either take in zero arguments or a single ``Dict``
            argument which is set by defining ``train_loop_config``.
            Within this function you can use any of the
            :ref:`Ray Train Loop utilities <train-loop-api>`.
        train_loop_config: A configuration ``Dict`` to pass in as an argument to
            ``train_loop_per_worker``.
            This is typically used for specifying hyperparameters. Passing large
            datasets via ``train_loop_config`` is not recommended, since it can
            introduce significant serialization and deserialization overhead.
        jax_config: The configuration for setting up the JAX backend.
            If set to None, a default configuration is derived from the
            ``scaling_config`` and the ``JAX_PLATFORMS`` environment variable.
        scaling_config: Configuration for how to scale data parallel training
            with SPMD. ``num_workers`` should be set to the number of TPU hosts
            or GPU workers. If using TPUs, ``topology`` should be set to the TPU
            topology. See :class:`~ray.train.ScalingConfig` for more info.
        dataset_config: The configuration for ingesting the input ``datasets``.
            By default, all the Ray Datasets are split equally across workers.
            See :class:`~ray.train.DataConfig` for more details.
        run_config: The configuration for the execution of the training run.
            See :class:`~ray.train.RunConfig` for more info.
        datasets: The Ray Datasets to ingest for training.
            Datasets are keyed by name (``{name: dataset}``).
            Each dataset can be accessed from within the ``train_loop_per_worker``
            by calling ``ray.train.get_dataset_shard(name)``.
            Sharding and additional configuration can be done by
            passing in a ``dataset_config``.
        validation_config: [Alpha] Configuration for checkpoint validation.
            If provided and ``ray.train.report`` is called with the
            ``validation`` argument, Ray Train validates the reported checkpoint
            using the validation function specified in this config.
    """

    def __init__(
        self,
        train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
        *,
        train_loop_config: Optional[Dict] = None,
        jax_config: Optional[JaxConfig] = None,
        scaling_config: Optional[ScalingConfig] = None,
        dataset_config: Optional[Dict[str, DataConfig]] = None,
        run_config: Optional[RunConfig] = None,
        datasets: Optional[Dict[str, GenDataset]] = None,
        validation_config: Optional[ValidationConfig] = None,
    ):
        if not jax_config:
            # Derive the backend config from the scaling config. Guard against a
            # missing scaling_config so the default path does not raise an
            # AttributeError.
            use_tpu = scaling_config.use_tpu if scaling_config else False
            use_gpu = scaling_config.use_gpu if scaling_config else False
            jax_config = JaxConfig(use_tpu=use_tpu, use_gpu=use_gpu)
        super().__init__(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config=train_loop_config,
            backend_config=jax_config,
            scaling_config=scaling_config,
            dataset_config=dataset_config,
            run_config=run_config,
            datasets=datasets,
            validation_config=validation_config,
        )

    @classmethod
    def _validate_scaling_config(cls, scaling_config: ScalingConfig) -> ScalingConfig:
        """Return the scaling config dataclass after validating updated keys."""
        ensure_only_allowed_dataclass_keys_updated(
            dataclass=scaling_config,
            allowed_keys=cls._scaling_config_allowed_keys,
        )
        return scaling_config