ray.train.huggingface.HuggingFaceTrainer
- class ray.train.huggingface.HuggingFaceTrainer(*args, **kwargs)
Bases: ray.train.torch.torch_trainer.TorchTrainer
A Trainer for data-parallel HuggingFace Transformers training on PyTorch.
This Trainer runs the transformers.Trainer.train() method on multiple Ray Actors. The training is carried out in a distributed fashion through PyTorch DDP. These actors already have the necessary torch process group configured for distributed PyTorch training. If you have PyTorch >= 1.12.0 installed, you can also run FSDP training by specifying the fsdp argument in TrainingArguments. For more information on configuring FSDP, refer to the Hugging Face documentation.
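For instance, a minimal sketch of enabling FSDP through TrainingArguments could look like this; the specific fsdp value is an illustrative assumption, and the available options depend on your transformers version:

from transformers import TrainingArguments

# Sketch only: enabling FSDP via the fsdp argument. Requires PyTorch >= 1.12.0
# and GPU workers; "full_shard auto_wrap" is one possible setting. Check the
# Hugging Face documentation for the options supported by your version.
args = TrainingArguments(
    output_dir="output",
    fsdp="full_shard auto_wrap",
    per_device_train_batch_size=8,
    num_train_epochs=1,
)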
The training function run on every Actor will first execute the specified trainer_init_per_worker function to obtain an instantiated transformers.Trainer object. The trainer_init_per_worker function will have access to preprocessed train and evaluation datasets.

If the datasets dict contains a training dataset (denoted by the "train" key), then it will be split into multiple dataset shards, with each Actor training on a single shard. All the other datasets will not be split.

Please note that if you use a custom transformers.Trainer subclass, the get_train_dataloader method will be wrapped to disable sharding by transformers.IterableDatasetShard, as the dataset will already be sharded on the Ray AIR side.

HuggingFace loggers will be automatically disabled, and the local_rank argument in TrainingArguments will be automatically set. Please note that if you want to use CPU training, you will need to set the no_cuda argument in TrainingArguments manually - otherwise, an exception (segfault) may be thrown.

This Trainer requires the transformers>=4.19.0 package.

Example
# Based on
# huggingface/notebooks/examples/language_modeling_from_scratch.ipynb

# Hugging Face imports
from datasets import load_dataset
import transformers
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

import ray
from ray.train.huggingface import HuggingFaceTrainer
from ray.air.config import ScalingConfig

# If using GPUs, set this to True.
use_gpu = False

model_checkpoint = "gpt2"
tokenizer_checkpoint = "sgugger/gpt2-like-tokenizer"
block_size = 128

datasets = load_dataset("wikitext", "wikitext-2-raw-v1")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)


def tokenize_function(examples):
    return tokenizer(examples["text"])


tokenized_datasets = datasets.map(
    tokenize_function, batched=True, num_proc=1, remove_columns=["text"]
)


def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {
        k: sum(examples[k], []) for k in examples.keys()
    }
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model
    # supported it.
    # instead of this drop, you can customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [
            t[i : i + block_size]
            for i in range(0, total_length, block_size)
        ]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result


lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=1000,
    num_proc=1,
)
ray_train_ds = ray.data.from_huggingface(lm_datasets["train"])
ray_evaluation_ds = ray.data.from_huggingface(
    lm_datasets["validation"]
)


def trainer_init_per_worker(train_dataset, eval_dataset, **config):
    model_config = AutoConfig.from_pretrained(model_checkpoint)
    model = AutoModelForCausalLM.from_config(model_config)
    args = transformers.TrainingArguments(
        output_dir=f"{model_checkpoint}-wikitext2",
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="epoch",
        learning_rate=2e-5,
        weight_decay=0.01,
        no_cuda=(not use_gpu),
    )
    return transformers.Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )


scaling_config = ScalingConfig(num_workers=3, use_gpu=use_gpu)
trainer = HuggingFaceTrainer(
    trainer_init_per_worker=trainer_init_per_worker,
    scaling_config=scaling_config,
    datasets={"train": ray_train_ds, "evaluation": ray_evaluation_ds},
)
result = trainer.fit()
- Parameters
  - trainer_init_per_worker – The function that returns an instantiated transformers.Trainer object and takes in the following arguments: train Torch.Dataset, optional evaluation Torch.Dataset and config as kwargs. The Torch Datasets are automatically created by converting the Ray Datasets internally before they are passed into the function.
  - datasets – Any Ray Datasets to use for training. Use the key "train" to denote which dataset is the training dataset and (optionally) key "evaluation" to denote the evaluation dataset. Can only contain a training dataset and up to one extra dataset to be used for evaluation. If a preprocessor is provided and has not already been fit, it will be fit on the training dataset. All datasets will be transformed by the preprocessor if one is provided.
  - trainer_init_config – Configurations to pass into trainer_init_per_worker as kwargs (see the sketch after this list).
  - torch_config – Configuration for setting up the PyTorch backend. If set to None, use the default configuration. This replaces the backend_config arg of DataParallelTrainer. Same as in TorchTrainer.
  - scaling_config – Configuration for how to scale data parallel training.
  - dataset_config – Configuration for dataset ingest.
  - run_config – Configuration for the execution of the training run.
  - preprocessor – A ray.data.Preprocessor to preprocess the provided datasets.
  - resume_from_checkpoint – A checkpoint to resume training from.
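A minimal sketch of how trainer_init_config values reach trainer_init_per_worker, reusing the Ray Datasets from the example above; the batch_size and learning_rate keys are illustrative names, not a required schema:

import transformers
from transformers import AutoConfig, AutoModelForCausalLM

from ray.train.huggingface import HuggingFaceTrainer
from ray.air.config import ScalingConfig


def trainer_init_per_worker(train_dataset, eval_dataset, **config):
    # Keys passed via trainer_init_config arrive here as **config.
    args = transformers.TrainingArguments(
        output_dir="output",
        per_device_train_batch_size=config.get("batch_size", 8),
        learning_rate=config.get("learning_rate", 2e-5),
        no_cuda=True,  # CPU training assumed for this sketch
    )
    model = AutoModelForCausalLM.from_config(
        AutoConfig.from_pretrained("gpt2")
    )
    return transformers.Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )


trainer = HuggingFaceTrainer(
    trainer_init_per_worker=trainer_init_per_worker,
    trainer_init_config={"batch_size": 16, "learning_rate": 5e-5},
    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
    # ray_train_ds / ray_evaluation_ds as constructed in the example above
    datasets={"train": ray_train_ds, "evaluation": ray_evaluation_ds},
)
result = trainer.fit()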
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
- classmethod restore(path: str, trainer_init_per_worker: Optional[Callable[[torch.utils.data.dataset.Dataset, Optional[torch.utils.data.dataset.Dataset], Any], transformers.trainer.Trainer]] = None, trainer_init_config: Optional[Dict] = None, datasets: Optional[Dict[str, Union[Dataset, Callable[[], Dataset]]]] = None, preprocessor: Optional[Preprocessor] = None, scaling_config: Optional[ray.air.config.ScalingConfig] = None) -> HuggingFaceTrainer
Restores a HuggingFaceTrainer from a previously interrupted/failed run.
- Parameters
  - trainer_init_per_worker – Optionally re-specified trainer init function. This should be used to re-specify a function that is not restorable in a new Ray cluster (e.g., it holds onto outdated object references). This should be the same trainer init that was passed to the original trainer constructor.
  - trainer_init_config – Optionally re-specified trainer init config. This should similarly be used if the original trainer_init_config contained outdated object references, and it should not be modified from what was originally passed in.

See BaseTrainer.restore() for descriptions of the other arguments.

- Returns
  A restored instance of HuggingFaceTrainer
- Return type
  HuggingFaceTrainer
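For example, a minimal restore sketch; the experiment directory path below is a placeholder, and trainer_init_per_worker is the same function that was passed to the original constructor:

from ray.train.huggingface import HuggingFaceTrainer

# Sketch only: "~/ray_results/my_experiment" stands in for the directory of
# the interrupted run, and trainer_init_per_worker is the same init function
# used when the trainer was first constructed.
restored_trainer = HuggingFaceTrainer.restore(
    path="~/ray_results/my_experiment",
    trainer_init_per_worker=trainer_init_per_worker,
)
result = restored_trainer.fit()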