{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GPT-J-6B Fine-Tuning with Ray Train and DeepSpeed\n",
"\n",
"\n",
" \n",
"\n",
" \n",
"\n",
"This example showcases how to use Ray Train for **GPT-J fine-tuning**. GPT-J is a GPT-2-like causal language model trained on the Pile dataset. This particular model has 6 billion parameters. For more information, see [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj).\n",
"\n",
"This example uses the Ray Train 🤗 Transformers integration and a pre-trained model from the Hugging Face Hub. Note that this example is adaptable to other similar models.\n",
"\n",
"This is an advanced example that focuses on the performance and distributed computing aspects of Ray Train. For a beginner-friendly introduction to the Ray Train 🤗 Transformers integration, see {ref}`Basic Example for HuggingFace Transformers `.\n",
"\n",
"Read [Ray Train Key Concepts](train-key-concepts) and [Ray Data Integration User Guides](data-ingest-torch) before starting this example.\n",
"\n",
"```{note}\n",
"To run this example, make sure your Ray cluster has access to at least one GPU with 16 or more GBs of memory. The required amount of memory depends on the model. This notebook is tested with 16 g4dn.4xlarge instances (including the head node).\n",
"```\n",
"\n",
"This notebook has the following steps:\n",
"1. [Set up Ray](#gptj-setup)\n",
"2. [Load the dataset](#gptj-load)\n",
"3. [Preprocess the dataset with Ray Data](#gptj-preprocess)\n",
"4. [Run the training with Ray Train](#gptj-train)\n",
"5. [Generate text from prompt](#gptj-predict)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Uncomment and run the following line in order to install all the necessary dependencies (this notebook was tested with `accelerate=0.18.0`, `transformers==4.26.0`, `deepspeed==0.12.3`):"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"! pip install -q \"datasets\" \"evaluate\" \"accelerate==0.18.0\" \"transformers==4.26.0\" \"torch>=1.12.0\" \"deepspeed==0.12.3\""
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import os"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"(gptj-setup)=\n",
"## Set up Ray\n",
"\n",
"First, let's set some global variables. We will use 16 workers, each being assigned 1 GPU and 8 CPUs."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"model_name = \"EleutherAI/gpt-j-6B\"\n",
"use_gpu = True\n",
"num_workers = 16\n",
"cpus_per_worker = 8"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use `ray.init()` to initialize a local cluster. By default, this cluster will be comprised of only the machine you are running this notebook on. You can also run this notebook on an Anyscale cluster.\n",
"\n",
"We define a {ref}`runtime environment ` to ensure that the Ray workers have access to all the necessary packages. You can omit the `runtime_env` argument if you have all of the packages already installed on each node in your cluster."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import ray\n",
"\n",
"ray.init(\n",
" runtime_env={\n",
" \"pip\": [\n",
" \"datasets\",\n",
" \"evaluate\",\n",
" # The latest combination accelerate==0.25.0, transformers==4.36.0, deepspeed==0.12.4\n",
" # has issues with DeepSpeed process group initialization,\n",
" # and will result in a batch_size validation problem.\n",
" # TODO(ml-team): get rid of the pins once the issue is fixed.\n",
" \"accelerate==0.18.0\",\n",
" \"transformers==4.26.0\",\n",
" \"torch>=1.12.0\",\n",
" \"deepspeed==0.12.3\",\n",
" ],\n",
" },\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"hide-cell"
]
},
"outputs": [],
"source": [
"# THIS SHOULD BE HIDDEN IN DOCS AND ONLY RAN IN CI\n",
"# Download the model from our S3 mirror as it's faster\n",
"\n",
"import ray\n",
"import subprocess\n",
"import ray.util.scheduling_strategies\n",
"\n",
"\n",
"def force_on_node(node_id: str, remote_func_or_actor_class):\n",
" scheduling_strategy = ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(\n",
" node_id=node_id, soft=False\n",
" )\n",
" options = {\"scheduling_strategy\": scheduling_strategy}\n",
" return remote_func_or_actor_class.options(**options)\n",
"\n",
"\n",
"def run_on_every_node(remote_func_or_actor_class, **remote_kwargs):\n",
" refs = []\n",
" for node in ray.nodes():\n",
" if node[\"Alive\"] and node[\"Resources\"].get(\"GPU\", None):\n",
" refs.append(\n",
" force_on_node(node[\"NodeID\"], remote_func_or_actor_class).remote(\n",
" **remote_kwargs\n",
" )\n",
" )\n",
" return ray.get(refs)\n",
"\n",
"\n",
"@ray.remote(num_gpus=1)\n",
"def download_model():\n",
" from transformers.utils.hub import TRANSFORMERS_CACHE\n",
"\n",
" path = os.path.expanduser(\n",
" os.path.join(TRANSFORMERS_CACHE, \"models--EleutherAI--gpt-j-6B\")\n",
" )\n",
" subprocess.run([\"mkdir\", \"-p\", os.path.join(path, \"snapshots\", \"main\")])\n",
" subprocess.run([\"mkdir\", \"-p\", os.path.join(path, \"refs\")])\n",
" if os.path.exists(os.path.join(path, \"refs\", \"main\")):\n",
" return\n",
" subprocess.run(\n",
" [\n",
" \"aws\",\n",
" \"s3\",\n",
" \"sync\",\n",
" \"--no-sign-request\",\n",
" \"s3://large-dl-models-mirror/models--EleutherAI--gpt-j-6B/main/\",\n",
" os.path.join(path, \"snapshots\", \"main\"),\n",
" ]\n",
" )\n",
" with open(os.path.join(path, \"snapshots\", \"main\", \"hash\"), \"r\") as f:\n",
" f_hash = f.read().strip()\n",
" with open(os.path.join(path, \"refs\", \"main\"), \"w\") as f:\n",
" f.write(f_hash)\n",
" os.rename(\n",
" os.path.join(path, \"snapshots\", \"main\"), os.path.join(path, \"snapshots\", f_hash)\n",
" )\n",
"\n",
"\n",
"_ = run_on_every_node(download_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"(gptj-load)=\n",
"## Loading the dataset\n",
"\n",
"We will be fine-tuning the model on the [`tiny_shakespeare` dataset](https://huggingface.co/datasets/tiny_shakespeare), comprised of 40,000 lines of Shakespeare from a variety of Shakespeare's plays. The aim will be to make the GPT-J model better at generating text in the style of Shakespeare."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"\n",
"print(\"Loading tiny_shakespeare dataset\")\n",
"current_dataset = load_dataset(\"tiny_shakespeare\")\n",
"current_dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use [Ray Data](data) for distributed preprocessing and data ingestion. We can easily convert the dataset obtained from Hugging Face Hub to Ray Data by using {meth}`ray.data.from_huggingface`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"{'train': MaterializedDataset(num_blocks=1, num_rows=1, schema={text: string}),\n",
" 'validation': MaterializedDataset(num_blocks=1, num_rows=1, schema={text: string})}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import ray.data\n",
"\n",
"ray_datasets = {\n",
" \"train\": ray.data.from_huggingface(current_dataset[\"train\"]),\n",
" \"validation\": ray.data.from_huggingface(current_dataset[\"validation\"]),\n",
"}\n",
"\n",
"ray_datasets"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"(gptj-preprocess)=\n",
"Note that the dataset is represented by a single line of large string, and needs some preprocessing. To do this, use the {meth}`~ray.data.Dataset.map_batches` API to apply transformation functions to batches of data.\n",
"\n",
"The `split_text` function takes the single string and splits it into separate lines, removing empty lines and character names ending with ':' (eg. 'ROMEO:'). The `tokenize` function takes the lines and tokenizes them using the 🤗 Tokenizer associated with the model, ensuring each entry has the same length (`block_size`) by padding and truncating. This preprocessing is necessary for training.\n",
"\n",
"```{note}\n",
"This preprocessing can be done in other ways. A common pattern is to tokenize first, and then split the obtained tokens into equally-sized blocks.\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"block_size = 512"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"{'train': MapBatches(tokenize)\n",
" +- MapBatches(split_text)\n",
" +- Dataset(num_blocks=1, num_rows=1, schema={text: string}),\n",
" 'validation': MapBatches(tokenize)\n",
" +- MapBatches(split_text)\n",
" +- Dataset(num_blocks=1, num_rows=1, schema={text: string})}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"\n",
"def split_text(batch: pd.DataFrame) -> pd.DataFrame:\n",
" text = list(batch[\"text\"])\n",
" flat_text = \"\".join(text)\n",
" split_text = [\n",
" x.strip()\n",
" for x in flat_text.split(\"\\n\")\n",
" if x.strip() and not x.strip()[-1] == \":\"\n",
" ]\n",
" return pd.DataFrame(split_text, columns=[\"text\"])\n",
"\n",
"\n",
"def tokenize(batch: pd.DataFrame) -> dict:\n",
" tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)\n",
" tokenizer.pad_token = tokenizer.eos_token\n",
" ret = tokenizer(\n",
" list(batch[\"text\"]),\n",
" truncation=True,\n",
" max_length=block_size,\n",
" padding=\"max_length\",\n",
" return_tensors=\"np\",\n",
" )\n",
" ret[\"labels\"] = ret[\"input_ids\"].copy()\n",
" return dict(ret)\n",
"\n",
"\n",
"processed_datasets = {\n",
" key: (\n",
" ds.map_batches(split_text, batch_format=\"pandas\")\n",
" .map_batches(tokenize, batch_format=\"pandas\")\n",
" )\n",
" for key, ds in ray_datasets.items()\n",
"}\n",
"processed_datasets"
]
},
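{
"cell_type": "markdown",
"metadata": {},
"source": [
"As the note above mentions, a common alternative is to tokenize the raw text first and then group the resulting tokens into fixed-size blocks. The following is a minimal sketch of that pattern; this example doesn't use it, and the helper name `tokenize_then_chunk` is purely illustrative.\n",
"\n",
"```python\n",
"def tokenize_then_chunk(batch: pd.DataFrame) -> dict:\n",
"    # Sketch only: tokenize without padding or truncation, then concatenate\n",
"    # all token ids and split them into `block_size`-sized blocks.\n",
"    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)\n",
"    tokenizer.pad_token = tokenizer.eos_token\n",
"    tokenized = tokenizer(list(batch[\"text\"]))\n",
"    all_ids = [token for ids in tokenized[\"input_ids\"] for token in ids]\n",
"    total_length = (len(all_ids) // block_size) * block_size\n",
"    input_ids = [\n",
"        all_ids[i : i + block_size] for i in range(0, total_length, block_size)\n",
"    ]\n",
"    return {\n",
"        \"input_ids\": np.array(input_ids),\n",
"        \"attention_mask\": np.ones((len(input_ids), block_size), dtype=np.int64),\n",
"        \"labels\": np.array(input_ids),\n",
"    }\n",
"\n",
"\n",
"# Usage would mirror the cell above, for example:\n",
"# ray_datasets[\"train\"].map_batches(tokenize_then_chunk, batch_format=\"pandas\")\n",
"```"
]
},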
{
"cell_type": "markdown",
"metadata": {},
"source": [
"(gptj-train)=\n",
"### Fine-tuning the model with Ray Train\n",
"\n",
"Configure Ray Train's {class}`~ray.train.torch.TorchTrainer` to perform distributed fine-tuning of the model. Specify a `train_loop_per_worker` function, which defines the training logic to be distributed by Ray using Distributed Data Parallelism, which uses the PyTorch Distributed backend internally. Each worker has its own copy of the model, but operates on different data. At the end of each step, all the workers sync gradients.\n",
"\n",
"Because GPT-J is a relatively large model, it may not be possible to fit it on smaller GPU types (<=16 GB GRAM). To deal with that issue, this example uses [DeepSpeed](https://github.com/microsoft/DeepSpeed), a library to optimize the training process and to offload and partition optimizer and parameter states, reducing GRAM usage. Furthermore, DeepSpeed ZeRO Stage 3 can load large models without running out of memory.\n",
"\n",
"🤗 Transformers and Ray Train's {ref}`integrations ` allow you to easily configure and use DDP and DeepSpeed. All you need to do is specify the DeepSpeed configuration in the [`TrainingArguments`](https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments) object.\n",
"\n",
"```{tip}\n",
"There are many DeepSpeed settings that allow you to trade-off speed for memory usage. The settings used below are tailored to the cluster setup used (16 g4dn.4xlarge nodes) and per device batch size of 16. Some things to keep in mind:\n",
"- If your GPUs support bfloat16, use that instead of float16 mixed precision to get better performance and prevent overflows. Replace `fp16=True` with `bf16=True` in `TrainingArguments`.\n",
"- If you are running out of GRAM: try reducing batch size (defined in the cell below the next one), set `\"overlap_comm\": False` in DeepSpeed config.\n",
"- If you are running out of RAM, add more nodes to your cluster, use nodes with more RAM, set `\"pin_memory\": False` in the DeepSpeed config, reduce the batch size, and remove `\"offload_param\"` from the DeepSpeed config.\n",
"\n",
"For more information on DeepSpeed configuration, refer to [Hugging Face documentation](https://huggingface.co/docs/transformers/main_classes/deepspeed) and [DeepSpeed documentation](https://www.deepspeed.ai/docs/config-json/).\n",
"\n",
"Additionally, if you prefer a lower-level API, the logic below can be expressed as an [Accelerate training loop](https://github.com/huggingface/accelerate/blob/main/examples/by_feature/deepspeed_with_config_support.py) distributed by a Ray Train {class}`~ray.train.torch.torch_trainer.TorchTrainer`.\n",
"```\n",
"\n",
"#### Training speed\n",
"\n",
"As this example uses data parallelism, each worker operates on its own shard of the data. The batch size set in `train_ds.iter_torch_batches` is the **per device batch size** (per worker batch size). By changing the number of workers, you can change the **effective batch size** and thus the time needed for training to complete. Calculate the effective batch size as `per device batch size * number of workers * number of gradient accumulation steps`. As you add more workers, the effective batch size rises and thus less time is needed to complete a full epoch. While the speedup is not exactly linear due to extra communication overheads, in many cases it can be close to linear.\n",
"\n",
"The preprocessed dataset has 1348 examples. We have set per device batch size to 16.\n",
"\n",
"* With 16 g4dn.4xlarge nodes, the effective batch size was 256, which equals to 85 steps per epoch. One epoch took **~2440 seconds** (including initialization time).\n",
"\n",
"* With 32 g4dn.4xlarge nodes, the effective batch size was 512, which equals to 43 steps per epoch. One epoch took **~1280 seconds** (including initialization time)."
]
},
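{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, here is the effective batch size arithmetic for the 16-node setup described above (values taken from this example):\n",
"\n",
"```python\n",
"# Effective batch size = per-device batch size * number of workers\n",
"#                        * gradient accumulation steps.\n",
"per_device_batch_size = 16\n",
"num_workers = 16\n",
"gradient_accumulation_steps = 1  # set in TrainingArguments below\n",
"effective_batch_size = (\n",
"    per_device_batch_size * num_workers * gradient_accumulation_steps\n",
")  # 16 * 16 * 1 = 256\n",
"```"
]
},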
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import evaluate\n",
"import torch\n",
"from transformers import (\n",
" Trainer,\n",
" TrainingArguments,\n",
" GPTJForCausalLM,\n",
" AutoTokenizer,\n",
" default_data_collator,\n",
")\n",
"from transformers.utils.logging import disable_progress_bar, enable_progress_bar\n",
"\n",
"from ray import train\n",
"from ray.train.huggingface.transformers import prepare_trainer, RayTrainReportCallback\n",
"\n",
"\n",
"def train_func(config):\n",
" # Use the actual number of CPUs assigned by Ray\n",
" os.environ[\"OMP_NUM_THREADS\"] = str(\n",
" train.get_context().get_trial_resources().bundles[-1].get(\"CPU\", 1)\n",
" )\n",
" # Enable tf32 for better performance\n",
" torch.backends.cuda.matmul.allow_tf32 = True\n",
"\n",
" batch_size = config.get(\"batch_size\", 4)\n",
" epochs = config.get(\"epochs\", 2)\n",
" warmup_steps = config.get(\"warmup_steps\", 0)\n",
" learning_rate = config.get(\"learning_rate\", 0.00002)\n",
" weight_decay = config.get(\"weight_decay\", 0.01)\n",
" steps_per_epoch = config.get(\"steps_per_epoch\")\n",
"\n",
" deepspeed = {\n",
" \"fp16\": {\n",
" \"enabled\": \"auto\",\n",
" \"initial_scale_power\": 8,\n",
" \"hysteresis\": 4,\n",
" \"consecutive_hysteresis\": True,\n",
" },\n",
" \"bf16\": {\"enabled\": \"auto\"},\n",
" \"optimizer\": {\n",
" \"type\": \"AdamW\",\n",
" \"params\": {\n",
" \"lr\": \"auto\",\n",
" \"betas\": \"auto\",\n",
" \"eps\": \"auto\",\n",
" },\n",
" },\n",
" \"zero_optimization\": {\n",
" \"stage\": 3,\n",
" \"offload_optimizer\": {\n",
" \"device\": \"cpu\",\n",
" \"pin_memory\": True,\n",
" },\n",
" \"overlap_comm\": True,\n",
" \"contiguous_gradients\": True,\n",
" \"reduce_bucket_size\": \"auto\",\n",
" \"stage3_prefetch_bucket_size\": \"auto\",\n",
" \"stage3_param_persistence_threshold\": \"auto\",\n",
" \"gather_16bit_weights_on_model_save\": True,\n",
" \"round_robin_gradients\": True,\n",
" },\n",
" \"gradient_accumulation_steps\": \"auto\",\n",
" \"gradient_clipping\": \"auto\",\n",
" \"steps_per_print\": 10,\n",
" \"train_batch_size\": \"auto\",\n",
" \"train_micro_batch_size_per_gpu\": \"auto\",\n",
" \"wall_clock_breakdown\": False,\n",
" }\n",
"\n",
" print(\"Preparing training arguments\")\n",
" training_args = TrainingArguments(\n",
" \"output\",\n",
" logging_steps=1,\n",
" save_strategy=\"steps\",\n",
" save_steps=steps_per_epoch,\n",
" max_steps=steps_per_epoch * epochs,\n",
" per_device_train_batch_size=batch_size,\n",
" gradient_accumulation_steps=1,\n",
" learning_rate=learning_rate,\n",
" weight_decay=weight_decay,\n",
" warmup_steps=warmup_steps,\n",
" label_names=[\"input_ids\", \"attention_mask\"],\n",
" push_to_hub=False,\n",
" report_to=\"none\",\n",
" disable_tqdm=True, # declutter the output a little\n",
" fp16=True,\n",
" gradient_checkpointing=True,\n",
" deepspeed=deepspeed,\n",
" )\n",
" disable_progress_bar()\n",
"\n",
" tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
" tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
" print(\"Loading model\")\n",
"\n",
" model = GPTJForCausalLM.from_pretrained(model_name, use_cache=False)\n",
" model.resize_token_embeddings(len(tokenizer))\n",
"\n",
" print(\"Model loaded\")\n",
"\n",
" enable_progress_bar()\n",
"\n",
" metric = evaluate.load(\"accuracy\")\n",
"\n",
" train_ds = train.get_dataset_shard(\"train\")\n",
" eval_ds = train.get_dataset_shard(\"validation\")\n",
"\n",
" train_ds_iterable = train_ds.iter_torch_batches(\n",
" batch_size=batch_size,\n",
" local_shuffle_buffer_size=train.get_context().get_world_size() * batch_size,\n",
" )\n",
" eval_ds_iterable = eval_ds.iter_torch_batches(batch_size=batch_size)\n",
"\n",
" def compute_metrics(eval_pred):\n",
" logits, labels = eval_pred\n",
" predictions = np.argmax(logits, axis=-1)\n",
" return metric.compute(predictions=predictions, references=labels)\n",
"\n",
" trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=train_ds_iterable,\n",
" eval_dataset=eval_ds_iterable,\n",
" compute_metrics=compute_metrics,\n",
" tokenizer=tokenizer,\n",
" data_collator=default_data_collator,\n",
" )\n",
"\n",
" # Add callback to report checkpoints to Ray Train\n",
" trainer.add_callback(RayTrainReportCallback())\n",
" trainer = prepare_trainer(trainer)\n",
" trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After defining the training function, instantiate the {class}`~ray.train.torch.TorchTrainer`. Aside from the function, set the `scaling_config` to control the number of workers and amount of resources to use, and `datasets`(the preprocessed Ray Datasets) to use for training and evaluation.\n",
"\n",
"```{note}\n",
"Running with multiple nodes necessitates the persistence of checkpoints\n",
"and other outputs to some external storage for access after training has completed.\n",
"**You should set up cloud storage or NFS, then replace `storage_path` with your own cloud bucket URI or NFS path.**\n",
"\n",
"See {ref}`Configuration and Persistent Storage` for more details.\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"storage_path = \"s3://your-bucket-here\" # TODO: Set up cloud storage\n",
"# storage_path=\"/mnt/path/to/nfs\" # TODO: Alternatively, set up NFS"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"import os, re\n",
"\n",
"artifact_storage = os.environ.get(\"ANYSCALE_ARTIFACT_STORAGE\", \"artifact_storage\")\n",
"user_name = re.sub(r\"\\s+\", \"__\", os.environ.get(\"ANYSCALE_USERNAME\", \"user\"))\n",
"storage_path = f\"{artifact_storage}/{user_name}/gptj-deepspeed-finetune\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 16\n",
"train_ds_size = processed_datasets[\"train\"].count()\n",
"steps_per_epoch = train_ds_size // (batch_size * num_workers)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"# SMOKE TEST SETTINGS FOR CI\n",
"steps_per_epoch = 10\n",
"num_workers = 8\n",
"batch_size = 1"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from ray.train.torch import TorchTrainer\n",
"from ray.train import RunConfig, ScalingConfig\n",
"\n",
"trainer = TorchTrainer(\n",
" train_loop_per_worker=train_func,\n",
" train_loop_config={\n",
" \"epochs\": 1,\n",
" \"batch_size\": batch_size, # per device\n",
" \"steps_per_epoch\": steps_per_epoch,\n",
" },\n",
" scaling_config=ScalingConfig(\n",
" num_workers=num_workers,\n",
" use_gpu=use_gpu,\n",
" resources_per_worker={\"GPU\": 1, \"CPU\": cpus_per_worker},\n",
" ),\n",
" datasets=processed_datasets,\n",
" run_config=RunConfig(storage_path=storage_path),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, call the {meth}`~ray.train.torch.TorchTrainer.fit` method to start training with Ray Train. Save the {class}`~ray.train.Result` object to a variable to access metrics and checkpoints."
]
},
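{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, assuming you store the return value of `trainer.fit()` in a `results` variable, you might inspect it like this after training finishes (a sketch, not run here):\n",
"\n",
"```python\n",
"print(results.metrics)     # final reported training metrics\n",
"print(results.checkpoint)  # last reported checkpoint, persisted under `storage_path`\n",
"```"
]
},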
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"tags": [
"hide-output"
]
},
"outputs": [
{
"data": {
"text/html": [
"