Fine-tune vicuna-13b with Lightning and DeepSpeed#

In this example, we demonstrate how to perform full fine-tuning of the vicuna-13b-v1.3 model using Ray Train's PyTorch Lightning integration with the DeepSpeed ZeRO-3 strategy.

  • DeepSpeed is an open-source deep learning optimization library for PyTorch. It’s designed to reduce computing power and memory usage, and to train large distributed models by leveraging state-of-the-art innovations like ZeRO, 3D-Parallelism, DeepSpeed-MoE, and ZeRO-Infinity.

  • PyTorch Lightning offers a DeepSpeed integration, which provides a simple interface to configure the knobs for DeepSpeed and automatically trigger your training process with the DeepSpeed Engine.

  • Ray TorchTrainer allows you to easily scale your PyTorch Lightning job across multiple nodes in a Ray cluster, without worrying about the underlying cluster management, autoscaling, and distributed process group settings.

This demo illustrates how these three tools can be combined effectively to fine-tune the Vicuna-13B model, leveraging the strengths of each to create an efficient and high-performance deep learning solution.

Note

This is an advanced example of Large Language Model fine-tuning with Ray Train. If you are new to Ray Train or its Lightning integration, it would be beneficial to first explore the introductory documentation to build a foundational understanding.

Cluster Setting#

Compute instances#

In this example, we set up a Ray cluster on AWS with the following settings:

              num   instance type   GPU per node   GPU Memory   CPU Memory
Head node     1     g5.16xlarge     1 x A10G       24 GB        256 GB
Worker node   15    g5.4xlarge      1 x A10G       24 GB        64 GB

Note

In this example, we used 16 A10G GPUs for model training and tuned the DeepSpeed configurations for this setup. If you have a different cluster setup or GPUs with lower memory capacities, you may need to modify the DeepSpeed configurations and batch size to fit the model into the GPUs.

Tip

We selected a GPU instance with additional CPU memory for the head node to demonstrate single-node offline inference. If you are training only, you can still opt for the g5.4xlarge instance for the head node.

Cloud Storage#

Additionally, since the checkpoint size for this 13B parameter model can be large (~140GB), we choose to store the checkpoints in AWS S3. Thanks to the newly introduced distributed checkpointing feature in Ray 2.5, each worker can upload its own shards individually to the S3 bucket, greatly reducing the latency and network traffic of checkpoint syncing.
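As a minimal sketch (the bucket URI below is a placeholder; the full trainer configuration appears later in this example), enabling S3 checkpointing only requires pointing the RunConfig storage path at the bucket:

from ray.train import CheckpointConfig, RunConfig

# Sketch only: "s3://your-bucket/experiments" is a placeholder URI.
# With distributed checkpointing, each worker uploads its own DeepSpeed shard
# to this location instead of funneling everything through a single node.
run_config = RunConfig(
    name="vicuna-13b-finetune",
    storage_path="s3://your-bucket/experiments",
    checkpoint_config=CheckpointConfig(num_to_keep=1),
)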

Local Storage#

To demonstrate offline inference, we need to download and consolidate the model checkpoint onto the head node. This action requires around 200GB disk storage. Therefore, we mounted the NVMe SSD provided by g5 instances at /dev/nvme1n1 to /mnt/local_storage, and we will save the checkpoints in this folder.

For more details, see Amazon EBS and NVMe on Linux instances.

Setup Ray Environment#

We define a runtime environment to ensure that the Ray workers have access to all necessary packages. If you have already included these dependencies in your Docker image or installed them on each node, you can ignore the runtime_env argument.

Note

Note that the codebases of transformers, accelerate, and deepspeed are all rapidly changing, so we have pinned the package versions here to ensure testing stability. You can try other version combinations and feel free to report any issues you encounter.

import ray

NUM_WORKERS = 16
BATCH_SIZE_PER_WORKER = 8
MODEL_NAME = "lmsys/vicuna-13b-v1.3"

ray.init(
    runtime_env={
        "pip": [
            "datasets==2.13.1",
            "torch>=1.13.0",
            "deepspeed==0.12.3",
            "accelerate==0.20.3",
            "transformers==4.30.2",
            "lightning==2.0.3",
        ],
    }
)

Load and preprocess datasets#

LLMs are impressive at zero-shot text generation, but some do not perform well at code generation due to the lack of code in their training corpora. The CMU CoNaLa (Code/Natural Language Challenge) benchmark was designed to test systems that generate program snippets from natural language. Each data record contains an intent sentence and a one-line code snippet. The goal is to fine-tune the Vicuna model on this dataset so that it generates correct and runnable code snippets that fulfill the natural language intent. Here are some examples:

intent                                                 code snippet
“convert a list of integers into a single integer”     r = int(''.join(map(str, x)))
“normalize a pandas dataframe df by row”               df.div(df.sum(axis=1), axis=0)
“Convert string ‘03:55’ into datetime.time object”     datetime.datetime.strptime('03:55', '%H:%M').time()

The CoNaLa team has released a dataset crawled from Stack Overflow, automatically filtered, then curated by annotators, and split into 2379 training and 500 test examples. They also released an automatically mined dataset with 600k examples. In this demo, we use all of the curated data plus the top 5000 mined examples for fine-tuning.

Here we preprocess the CoNaLa dataset with Ray Data. You can also use HuggingFace Datasets directly and build a dataloader inside the training function.

import re
import ray
import json
from transformers import AutoTokenizer
from datasets import concatenate_datasets, load_dataset

# Combine the curated dataset and automatically-mined dataset
hf_dataset_curated = load_dataset("neulab/conala")
hf_dataset_mined = load_dataset("neulab/conala", "mined", split="train[:5000]")
hf_dataset_merged = concatenate_datasets(
    [hf_dataset_curated["train"], hf_dataset_mined]
)
print(hf_dataset_merged)

# Convert it into Ray Dataset
ray_ds = ray.data.from_huggingface(hf_dataset_merged)

# Build a prompt template for Vicuna-13b model
PROMPT_TEMPLATE = "Intent: {intent}\nOne-line code snippet: {snippet}"


def fill_prompt(batch):
    batch["input_sentence"] = batch.apply(
        lambda row: PROMPT_TEMPLATE.format(
            intent=row["rewritten_intent"]
            if row["rewritten_intent"]
            else row["intent"],
            snippet=f"`{row['snippet']}`",
        )
        + "</s>",
        axis=1,
    )
    return batch[["input_sentence"]]


# Tokenize input sentences to tensors
def tokenize(batch):
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME, padding_side="left", use_fast=False
    )
    tokenizer.pad_token = tokenizer.eos_token
    ret = tokenizer(
        list(batch["input_sentence"]),
        truncation=True,
        max_length=128,
        padding="max_length",
        return_tensors="np",
    )
    ret["labels"] = ret["input_ids"].copy()
    return dict(ret)

# Preprocess train dataset
processed_ds = ray_ds.map_batches(fill_prompt, batch_format="pandas").map_batches(
    tokenize, batch_format="pandas"
)
Dataset({
    features: ['question_id', 'intent', 'rewritten_intent', 'snippet', 'parent_answer_post_id', 'prob', 'id'],
    num_rows: 7379
})
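
As an optional sanity check (not part of the original example, and assuming Ray Data's default NumPy batch format), you can peek at one tokenized batch before training:

# Optional sanity check: inspect a small tokenized batch.
sample_batch = processed_ds.take_batch(2)
print(sample_batch.keys())              # expect input_ids, attention_mask, labels
print(sample_batch["input_ids"].shape)  # expect (2, 128) given max_length=128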

Define a Lightning Module#

Here we load the pre-trained model weights from the HuggingFace Model Hub and wrap the model in a pl.LightningModule. We adopt the efficient model initialization technique introduced in Lightning-transformers to avoid unnecessary full-weight loading.

import torch
import transformers
import lightning.pytorch as pl
from transformers import AutoTokenizer, AutoModelForCausalLM
from deepspeed.ops.adam import DeepSpeedCPUAdam


class ZeRO3Config:
    def __init__(self, pl_module):
        self.config = pl_module.trainer.strategy.config

    def __call__(self, *args, **kwargs):
        return self

    def is_zero3(self) -> bool:
        return True


def enable_transformers_pretrained_deepspeed_sharding(
    pl_module: "pl.LightningModule",
) -> None:
    transformers.deepspeed._hf_deepspeed_config_weak_ref = ZeRO3Config(pl_module)


class Vicuna13BModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Enable tf32 for better performance
        torch.backends.cuda.matmul.allow_tf32 = True

    def setup(self, stage) -> None:
        # Defer model initialization so that the DeepSpeed config is injected into HF.
        # During initialization, HF Transformers can then immediately partition
        # the model across all GPUs, avoiding the time and memory overhead of
        # first materializing it on the CPU or on each GPU.
        enable_transformers_pretrained_deepspeed_sharding(self)
        self.model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
        if self.global_rank == 0:
            print("DeepSpeed Configs: ", self.trainer.strategy.config)
            print("Model Archetecture: ", self.model)

    def forward(self, batch):
        outputs = self.model(
            batch["input_ids"],
            labels=batch["labels"],
            attention_mask=batch["attention_mask"],
        )
        return outputs.loss

    def training_step(self, batch, batch_idx):
        loss = self.forward(batch)
        self.log("train_loss", loss, prog_bar=True, on_step=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return DeepSpeedCPUAdam(self.parameters(), lr=2e-5, weight_decay=0.01)
[2023-06-30 17:39:35,109] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)

DeepSpeed Configurations#

Before training, let’s calculate the memory usage of fine-tuning a vicuna-13b model. Assume we use 16-bit mixed-precision training and the Adam optimizer with FP32 states.

  • Model parameters: 13 (billion parameters) * 2 bytes (16-bit) ≈ 26GB

  • Optimizer states: 13 (billion parameters) * 2 (moments per parameter) * 4 bytes (FP32) ≈ 104GB

As we can see, the model parameters alone require 26GB, which cannot fit in a single A10G GPU, let alone the gradients, activations, and optimizer states. Here, we use ZeRO stage-3 to partition the model parameters, gradients, and optimizer states across the 16 workers, as sketched below. Additionally, we offload the optimizer states to CPU to reduce GPU memory usage and increase throughput with larger batch sizes. We disabled parameter offloading and activation checkpointing to improve training speed.
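
To make the partitioning concrete, here is a rough back-of-the-envelope calculation (a sketch, assuming 16-bit weights and gradients, two FP32 Adam moments per parameter, and even partitioning across 16 workers):

# Rough ZeRO-3 memory math under the assumptions above.
NUM_PARAMS = 13e9
NUM_SHARDS = 16

params_gb = NUM_PARAMS * 2 / 1e9     # ~26 GB of 16-bit weights in total
grads_gb = NUM_PARAMS * 2 / 1e9      # ~26 GB of 16-bit gradients in total
optim_gb = NUM_PARAMS * 2 * 4 / 1e9  # ~104 GB of FP32 Adam moments in total

print(params_gb / NUM_SHARDS)  # ~1.6 GB of weights per GPU
print(grads_gb / NUM_SHARDS)   # ~1.6 GB of gradients per GPU
print(optim_gb / NUM_SHARDS)   # ~6.5 GB of optimizer states per shard (offloaded to CPU here)

With the optimizer states offloaded to CPU, most of each A10G's 24 GB remains available for activations and temporary buffers.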

Regarding other knobs such as reduce_bucket_size, stage3_prefetch_bucket_size, and stage3_param_persistence_threshold, we kept the HuggingFace default values. Feel free to adjust them further to speed up the training process.

from transformers import AutoConfig

config = AutoConfig.from_pretrained(MODEL_NAME)
HIDDEN_SIZE = config.hidden_size

deepspeed_configs = {
    "zero_allow_untested_optimizer": True,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "overlap_comm": True,
        "contiguous_gradients": True,
        "reduce_bucket_size": HIDDEN_SIZE * HIDDEN_SIZE,
        "stage3_prefetch_bucket_size": 0.9 * HIDDEN_SIZE * HIDDEN_SIZE,
        "stage3_param_persistence_threshold": 10 * HIDDEN_SIZE,
    },
}
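
If your GPUs have less memory than the A10Gs used here, one possible variant (not used in this example, because it slows training down) is to also offload the partitioned parameters to CPU:

# Hypothetical lower-memory variant: additionally offload parameters to CPU.
# This trades training speed for GPU memory and is not used in this example.
low_memory_deepspeed_configs = {
    **deepspeed_configs,
    "zero_optimization": {
        **deepspeed_configs["zero_optimization"],
        "offload_param": {"device": "cpu", "pin_memory": True},
    },
}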

Define your training function#

Finally, define the training function that will be launched on multiple workers. The training function is generally the same as pure PyTorch Lightning training code, with a few additional Ray Train utilities:

For Ray Data ingestion, we fetch the preprocessed and sharded dataset with get_dataset_shard() and create a dataloader with iter_torch_batches(). This returns a custom iterator that replaces the Torch DataLoader.

import ray.train
from ray.train import CheckpointConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.lightning import (
    prepare_trainer,
    RayDeepSpeedStrategy, 
    RayLightningEnvironment, 
    RayTrainReportCallback
)


def train_func(config):
    """Training function for each worker."""

    # Unpack the `train_loop_config`
    max_epochs = config["max_epochs"]
    batch_size = config["batch_size"]
    accumulate_grad_batches = config["accumulate_grad_batches"]

    model = Vicuna13BModel()
    
    # Prepare Ray Data Ingestion
    train_ds = ray.train.get_dataset_shard("train")
    train_dataloader = train_ds.iter_torch_batches(batch_size=batch_size)
    
    pl_trainer = pl.Trainer(
        devices="auto",
        accelerator="auto",
        strategy=RayDeepSpeedStrategy(config=deepspeed_configs),
        plugins=[RayLightningEnvironment()],
        callbacks=[RayTrainReportCallback()],
        enable_checkpointing=False, # RayTrainReportCallback will save the checkpoints
        max_epochs=max_epochs,
        precision="bf16-mixed",
        accumulate_grad_batches=accumulate_grad_batches,
    )
    pl_trainer = prepare_trainer(pl_trainer)

    pl_trainer.fit(model, train_dataloaders=train_dataloader)
    

trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    train_loop_config={
        "max_epochs": 1,
        "batch_size": BATCH_SIZE_PER_WORKER,
        "accumulate_grad_batches": 2
    },
    run_config=RunConfig(
        name="vicuna-13b-finetune",
        storage_path="s3://anyscale-staging-data-cld-kvedzwag2qa8i5bjxuevf5i7/air-release-tests",
        checkpoint_config=CheckpointConfig(num_to_keep=1),
    ),
    scaling_config=ScalingConfig(
        num_workers=NUM_WORKERS,
        use_gpu=True,
        resources_per_worker={"CPU": 15, "GPU": 1},
    ),
    datasets={"train": processed_ds},
)

Model Fine-tuning#

Once everything is configured in TorchTrainer, training becomes easy. Simply call trainer.fit(), and your workload will be scaled to the Ray cluster, initiating ZeRO-3 parallel training.

result = trainer.fit()

Tune Status

Current time: 2023-06-30 18:21:59
Running for: 00:42:22.75
Memory: 10.7/249.1 GiB

System Info

Using FIFO scheduling algorithm.
Logical resource usage: 241.0/304 CPUs, 16.0/16 GPUs (0.0/16.0 accelerator_type:A10G)

Trial Status

Trial name                     status       loc                 iter   total time (s)   train_loss   epoch   step
LightningTrainer_c1544_00000   TERMINATED   10.0.55.20:134103   1      2473.94          0.523438     0       29
(pid=134103) [2023-06-30 17:39:41,637] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
(LightningTrainer pid=134103) The `preprocessor` arg to Trainer is deprecated. Apply preprocessor transformations ahead of time by calling `preprocessor.transform(ds)`. Support for the preprocessor arg will be dropped in a future release.
(LightningTrainer pid=134103) Important: Ray Data requires schemas for all datasets in Ray 2.5. This means that standalone Python objects are no longer supported. In addition, the default batch format is fixed to NumPy. To revert to legacy behavior temporarily, set the environment variable RAY_DATA_STRICT_MODE=0 on all cluster processes.
(LightningTrainer pid=134103) 
(LightningTrainer pid=134103) Learn more here: https://docs.ray.io/en/master/data/faq.html#migrating-to-strict-mode
(LightningTrainer pid=134103) Starting distributed worker processes: ['134267 (10.0.55.20)', '74152 (10.0.63.141)', '75476 (10.0.51.205)', '75547 (10.0.42.158)', '74711 (10.0.45.211)', '75132 (10.0.20.140)', '74502 (10.0.60.86)', '75695 (10.0.53.69)', '74457 (10.0.47.2)', '74569 (10.0.33.23)', '74341 (10.0.29.61)', '74274 (10.0.36.152)', '74561 (10.0.35.16)', '74427 (10.0.16.236)', '74273 (10.0.54.55)', '74996 (10.0.9.249)']
(RayTrainWorker pid=134267) Setting up process group for: env:// [rank=0, world_size=16]
(LightningTrainer pid=134103) Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(BatchMapper._transform_pandas)->MapBatches(BatchMapper._transform_pandas)] -> AllToAllOperator[RandomizeBlockOrder]
(LightningTrainer pid=134103) Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)
(LightningTrainer pid=134103) Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`
Downloading (…)okenizer_config.json: 100%|██████████| 727/727 [00:00<00:00, 8.86MB/s]m_pandas) pid=74329, ip=10.0.54.55) 
Downloading tokenizer.model: 100%|██████████| 500k/500k [00:00<00:00, 18.2MB/s]ansform_pandas) pid=74329, ip=10.0.54.55) 
Downloading (…)cial_tokens_map.json: 100%|██████████| 435/435 [00:00<00:00, 3.33MB/s]m_pandas) pid=74329, ip=10.0.54.55) 
(RayTrainWorker pid=74152, ip=10.0.63.141) [2023-06-30 17:39:54,612] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Downloading (…)okenizer_config.json: 100%|██████████| 727/727 [00:00<00:00, 7.86MB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 727/727 [00:00<00:00, 7.57MB/s]
(RayTrainWorker pid=134267) GPU available: True (cuda), used: True
(RayTrainWorker pid=134267) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=134267) IPU available: False, using: 0 IPUs
(RayTrainWorker pid=134267) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=134267) `Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
Downloading tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]
Downloading tokenizer.model: 100%|██████████| 500k/500k [00:00<00:00, 14.9MB/s]
(RayTrainWorker pid=134267) initializing deepspeed distributed: GLOBAL_RANK: 0, MEMBER: 1/16
(RayTrainWorker pid=74273, ip=10.0.54.55) Missing logger folder: /home/ray/ray_results/vicuna-13b-relation-extraction/LightningTrainer_c1544_00000_0_2023-06-30_17-39-36/rank_all/lightning_logs
(RayTrainWorker pid=134267) [2023-06-30 17:39:55,589] [WARNING] [comm.py:152:init_deepspeed_backend] NCCL backend in DeepSpeed not yet implemented
Downloading tokenizer.model: 100%|██████████| 500k/500k [00:00<00:00, 18.2MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 435/435 [00:00<00:00, 6.49MB/s]
Downloading (…)lve/main/config.json:   0%|          | 0.00/585 [00:00<?, ?B/s]
Downloading (…)lve/main/config.json: 100%|██████████| 585/585 [00:00<00:00, 7.81MB/s]
Downloading (…)lve/main/config.json: 100%|██████████| 585/585 [00:00<00:00, 7.09MB/s]
Downloading (…)model.bin.index.json: 100%|██████████| 33.4k/33.4k [00:00<00:00, 35.1MB/s]
Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]
(RayTrainWorker pid=75547, ip=10.0.42.158) 
Downloading (…)l-00001-of-00003.bin:   0%|          | 0.00/9.95G [00:00<?, ?B/s]
Downloading (…)l-00001-of-00003.bin:   0%|          | 21.0M/9.95G [00:00<00:59, 167MB/s]
Downloading (…)l-00001-of-00003.bin:   0%|          | 41.9M/9.95G [00:00<00:58, 170MB/s] 
Downloading (…)okenizer_config.json: 100%|██████████| 727/727 [00:00<00:00, 8.33MB/s] [repeated 9x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
Downloading tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]
Downloading tokenizer.model: 100%|██████████| 500k/500k [00:00<00:00, 17.5MB/s] [repeated 8x across cluster]
(RayTrainWorker pid=74561, ip=10.0.35.16) initializing deepspeed distributed: GLOBAL_RANK: 12, MEMBER: 13/16 [repeated 15x across cluster]
(RayTrainWorker pid=74561, ip=10.0.35.16) Missing logger folder: /home/ray/ray_results/vicuna-13b-relation-extraction/LightningTrainer_c1544_00000_0_2023-06-30_17-39-36/rank_all/lightning_logs [repeated 15x across cluster]
Downloading tokenizer.model: 100%|██████████| 500k/500k [00:00<00:00, 8.85MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 435/435 [00:00<00:00, 5.23MB/s] [repeated 10x across cluster]
Downloading (…)lve/main/config.json: 100%|██████████| 585/585 [00:00<00:00, 7.03MB/s] [repeated 13x across cluster]
Downloading (…)model.bin.index.json: 100%|██████████| 33.4k/33.4k [00:00<00:00, 87.9MB/s] [repeated 15x across cluster]
Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s] [repeated 15x across cluster]
(RayTrainWorker pid=74341, ip=10.0.29.61)  [repeated 650x across cluster]
Downloading (…)l-00001-of-00003.bin:   0%|          | 0.00/9.95G [00:00<?, ?B/s] [repeated 15x across cluster]
Downloading (…)l-00001-of-00003.bin:  13%|█▎        | 1.31G/9.95G [00:05<00:36, 239MB/s] [repeated 636x across cluster]
Downloading (…)l-00001-of-00003.bin:   1%|          | 105M/9.95G [00:00<00:41, 239MB/s]  [repeated 17x across cluster]
(RayTrainWorker pid=74711, ip=10.0.45.211)  [repeated 640x across cluster]
Downloading (…)l-00001-of-00003.bin:  26%|██▌       | 2.58G/9.95G [00:10<00:28, 256MB/s] [repeated 635x across cluster]
(RayTrainWorker pid=74502, ip=10.0.60.86)  [repeated 638x across cluster]
Downloading (…)l-00001-of-00003.bin:  37%|███▋      | 3.70G/9.95G [00:15<00:26, 238MB/s] [repeated 638x across cluster]
(RayTrainWorker pid=74274, ip=10.0.36.152)  [repeated 643x across cluster]
Downloading (…)l-00001-of-00003.bin:  51%|█████▏    | 5.12G/9.95G [00:20<00:18, 255MB/s] [repeated 649x across cluster]
(RayTrainWorker pid=75476, ip=10.0.51.205)  [repeated 638x across cluster]
Downloading (…)l-00001-of-00003.bin:  65%|██████▌   | 6.48G/9.95G [00:25<00:14, 246MB/s] [repeated 633x across cluster]
(RayTrainWorker pid=74457, ip=10.0.47.2)  [repeated 645x across cluster]
Downloading (…)l-00001-of-00003.bin:  76%|███████▌  | 7.52G/9.95G [00:29<00:09, 247MB/s] [repeated 644x across cluster]
Downloading (…)l-00001-of-00003.bin:  91%|█████████▏| 9.10G/9.95G [00:34<00:03, 263MB/s]
Downloading (…)l-00001-of-00003.bin:  92%|█████████▏| 9.13G/9.95G [00:34<00:03, 257MB/s]
(RayTrainWorker pid=74711, ip=10.0.45.211)  [repeated 634x across cluster]
Downloading (…)l-00001-of-00003.bin:  82%|████████▏ | 8.17G/9.95G [00:35<00:07, 228MB/s] [repeated 628x across cluster]
Downloading (…)l-00001-of-00003.bin: 100%|██████████| 9.95G/9.95G [00:37<00:00, 262MB/s]
Downloading shards:  33%|███▎      | 1/3 [00:38<01:16, 38.09s/it]
Downloading (…)l-00002-of-00003.bin:   0%|          | 0.00/9.90G [00:00<?, ?B/s]
Downloading (…)l-00002-of-00003.bin:   1%|▏         | 126M/9.90G [00:00<00:35, 273MB/s] 
Downloading (…)l-00001-of-00003.bin:  93%|█████████▎| 9.27G/9.95G [00:39<00:02, 228MB/s] [repeated 394x across cluster]
(RayTrainWorker pid=75547, ip=10.0.42.158)  [repeated 633x across cluster]
Downloading (…)l-00002-of-00003.bin:   2%|▏         | 241M/9.90G [00:01<00:38, 252MB/s] [repeated 213x across cluster]
Downloading (…)l-00001-of-00003.bin: 100%|██████████| 9.95G/9.95G [00:40<00:00, 243MB/s] [repeated 8x across cluster]
Downloading shards:  33%|███▎      | 1/3 [00:42<01:25, 42.77s/it] [repeated 15x across cluster]
Downloading (…)l-00002-of-00003.bin:   0%|          | 0.00/9.90G [00:00<?, ?B/s] [repeated 15x across cluster]
Downloading (…)l-00002-of-00003.bin:   1%|          | 115M/9.90G [00:00<00:46, 209MB/s]  [repeated 16x across cluster]
Downloading (…)l-00001-of-00003.bin: 100%|██████████| 9.95G/9.95G [00:42<00:00, 233MB/s] [repeated 50x across cluster]
(RayTrainWorker pid=74341, ip=10.0.29.61)  [repeated 636x across cluster]
Downloading (…)l-00002-of-00003.bin:  19%|█▊        | 1.86G/9.90G [00:06<00:29, 275MB/s] [repeated 589x across cluster]
(RayTrainWorker pid=74996, ip=10.0.9.249)  [repeated 649x across cluster]
Downloading (…)l-00002-of-00003.bin:  18%|█▊        | 1.75G/9.90G [00:07<00:34, 234MB/s] [repeated 643x across cluster]
(RayTrainWorker pid=74502, ip=10.0.60.86)  [repeated 645x across cluster]
Downloading (…)l-00002-of-00003.bin:  41%|████▏     | 4.09G/9.90G [00:15<00:21, 271MB/s] [repeated 644x across cluster]
(RayTrainWorker pid=74273, ip=10.0.54.55)  [repeated 652x across cluster]
Downloading (…)l-00002-of-00003.bin:  53%|█████▎    | 5.25G/9.90G [00:21<00:19, 242MB/s] [repeated 656x across cluster]
(RayTrainWorker pid=74152, ip=10.0.63.141)  [repeated 647x across cluster]
Downloading (…)l-00002-of-00003.bin:  67%|██████▋   | 6.66G/9.90G [00:25<00:13, 246MB/s] [repeated 646x across cluster]
(RayTrainWorker pid=75132, ip=10.0.20.140)  [repeated 629x across cluster]
Downloading (…)l-00002-of-00003.bin:  84%|████████▍ | 8.30G/9.90G [00:31<00:06, 234MB/s] [repeated 627x across cluster]
Downloading (…)l-00002-of-00003.bin:  91%|█████████▏| 9.06G/9.90G [00:34<00:03, 241MB/s]
(RayTrainWorker pid=74457, ip=10.0.47.2)  [repeated 627x across cluster]
Downloading (…)l-00002-of-00003.bin:  89%|████████▉ | 8.84G/9.90G [00:36<00:04, 228MB/s] [repeated 567x across cluster]
Downloading (…)l-00002-of-00003.bin: 100%|██████████| 9.90G/9.90G [00:38<00:00, 257MB/s]
Downloading shards:  67%|██████▋   | 2/3 [01:16<00:38, 38.38s/it]
Downloading (…)l-00003-of-00003.bin:   0%|          | 0.00/6.18G [00:00<?, ?B/s]
Downloading (…)l-00003-of-00003.bin:   2%|▏         | 126M/6.18G [00:00<00:22, 266MB/s] 
Downloading (…)l-00002-of-00003.bin:  98%|█████████▊| 9.69G/9.90G [00:38<00:00, 236MB/s] [repeated 310x across cluster]
(RayTrainWorker pid=75476, ip=10.0.51.205)  [repeated 629x across cluster]
Downloading (…)l-00003-of-00003.bin:   2%|▏         | 94.4M/6.18G [00:00<00:24, 247MB/s] [repeated 275x across cluster]
Downloading (…)l-00002-of-00003.bin: 100%|██████████| 9.90G/9.90G [00:39<00:00, 253MB/s] [repeated 10x across cluster]
Downloading shards:  67%|██████▋   | 2/3 [01:20<00:40, 40.01s/it] [repeated 13x across cluster]
Downloading (…)l-00003-of-00003.bin:   0%|          | 0.00/6.18G [00:00<?, ?B/s] [repeated 13x across cluster]
Downloading (…)l-00003-of-00003.bin:   2%|▏         | 126M/6.18G [00:00<00:24, 243MB/s]  [repeated 13x across cluster]
Downloading (…)l-00002-of-00003.bin: 100%|█████████▉| 9.88G/9.90G [00:41<00:00, 242MB/s] [repeated 122x across cluster]
(RayTrainWorker pid=74273, ip=10.0.54.55)  [repeated 638x across cluster]
Downloading (…)l-00003-of-00003.bin:  21%|██        | 1.31G/6.18G [00:05<00:20, 243MB/s] [repeated 569x across cluster]
Downloading (…)l-00002-of-00003.bin: 100%|██████████| 9.90G/9.90G [00:40<00:00, 242MB/s] [repeated 2x across cluster]
Downloading shards:  67%|██████▋   | 2/3 [01:23<00:41, 41.78s/it] [repeated 2x across cluster]
Downloading (…)l-00003-of-00003.bin:   0%|          | 0.00/6.18G [00:00<?, ?B/s] [repeated 2x across cluster]
Downloading (…)l-00003-of-00003.bin:   2%|▏         | 105M/6.18G [00:00<00:24, 248MB/s]  [repeated 2x across cluster]
Downloading (…)l-00002-of-00003.bin: 100%|█████████▉| 9.87G/9.90G [00:40<00:00, 260MB/s] [repeated 3x across cluster]
(RayTrainWorker pid=74274, ip=10.0.36.152)  [repeated 638x across cluster]
Downloading (…)l-00003-of-00003.bin:  41%|████▏     | 2.56G/6.18G [00:10<00:14, 256MB/s] [repeated 635x across cluster]
(RayTrainWorker pid=74152, ip=10.0.63.141)  [repeated 629x across cluster]
Downloading (…)l-00003-of-00003.bin:  62%|██████▏   | 3.84G/6.18G [00:15<00:08, 279MB/s] [repeated 627x across cluster]
Downloading (…)l-00003-of-00003.bin:  92%|█████████▏| 5.66G/6.18G [00:22<00:01, 268MB/s]
Downloading (…)l-00003-of-00003.bin:  92%|█████████▏| 5.69G/6.18G [00:22<00:01, 265MB/s]
Downloading (…)l-00003-of-00003.bin:  93%|█████████▎| 5.73G/6.18G [00:22<00:01, 268MB/s]
Downloading (…)l-00003-of-00003.bin:  93%|█████████▎| 5.76G/6.18G [00:22<00:01, 270MB/s]
(RayTrainWorker pid=75547, ip=10.0.42.158)  [repeated 644x across cluster]
Downloading (…)l-00003-of-00003.bin:  85%|████████▌ | 5.25G/6.18G [00:20<00:03, 270MB/s] [repeated 618x across cluster]
Downloading (…)l-00003-of-00003.bin: 100%|██████████| 6.18G/6.18G [00:24<00:00, 257MB/s]
Downloading shards: 100%|██████████| 3/3 [01:40<00:00, 33.61s/it]
Downloading (…)l-00003-of-00003.bin:  98%|█████████▊| 6.03G/6.18G [00:23<00:00, 269MB/s] [repeated 166x across cluster]
(RayTrainWorker pid=74274, ip=10.0.36.152)  [repeated 426x across cluster]
Downloading (…)l-00003-of-00003.bin:  86%|████████▌ | 5.30G/6.18G [00:21<00:03, 246MB/s] [repeated 222x across cluster]
Downloading (…)l-00003-of-00003.bin: 100%|██████████| 6.18G/6.18G [00:25<00:00, 239MB/s] [repeated 7x across cluster]
Downloading shards: 100%|██████████| 3/3 [01:45<00:00, 35.27s/it] [repeated 11x across cluster]
Downloading (…)l-00003-of-00003.bin:  98%|█████████▊| 6.04G/6.18G [00:25<00:00, 231MB/s] [repeated 98x across cluster]
Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]
(RayTrainWorker pid=74274, ip=10.0.36.152)  [repeated 74x across cluster]
Downloading (…)l-00003-of-00003.bin:  91%|█████████ | 5.63G/6.18G [00:23<00:02, 242MB/s] [repeated 23x across cluster]
Downloading (…)l-00003-of-00003.bin: 100%|██████████| 6.18G/6.18G [00:24<00:00, 249MB/s]
Downloading shards: 100%|██████████| 3/3 [01:49<00:00, 36.47s/it] [repeated 4x across cluster]
Downloading (…)l-00003-of-00003.bin: 100%|██████████| 6.18G/6.18G [00:25<00:00, 241MB/s] [repeated 5x across cluster]
Loading checkpoint shards:  33%|███▎      | 1/3 [00:12<00:24, 12.11s/it]
Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s] [repeated 15x across cluster]
Loading checkpoint shards:  33%|███▎      | 1/3 [00:18<00:37, 18.54s/it] [repeated 15x across cluster]
Loading checkpoint shards:  67%|██████▋   | 2/3 [00:30<00:15, 15.63s/it]
Loading checkpoint shards:  67%|██████▋   | 2/3 [00:30<00:15, 15.71s/it]
Loading checkpoint shards:  67%|██████▋   | 2/3 [00:35<00:17, 17.73s/it] [repeated 14x across cluster]
Loading checkpoint shards: 100%|██████████| 3/3 [00:40<00:00, 13.47s/it]
Downloading (…)neration_config.json: 100%|██████████| 132/132 [00:00<00:00, 458kB/s]
Loading checkpoint shards: 100%|██████████| 3/3 [00:45<00:00, 15.29s/it] [repeated 15x across cluster]
(RayTrainWorker pid=74996, ip=10.0.9.249) LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Downloading (…)neration_config.json: 100%|██████████| 132/132 [00:00<00:00, 542kB/s] [repeated 14x across cluster]
(RayTrainWorker pid=134267) DeepSpeed Configs:  {'zero_allow_untested_optimizer': True, 'bf16': {'enabled': True}, 'zero_optimization': {'stage': 3, 'offload_optimizer': {'device': 'cpu', 'pin_memory': True}, 'overlap_comm': True, 'contiguous_gradients': True, 'reduce_bucket_size': 26214400, 'stage3_prefetch_bucket_size': 23592960.0, 'stage3_param_persistence_threshold': 51200}, 'gradient_accumulation_steps': 2, 'train_micro_batch_size_per_gpu': 1, 'gradient_clipping': 0.0}
(RayTrainWorker pid=134267) Model Architecture:  LlamaForCausalLM(
(RayTrainWorker pid=134267)   (model): LlamaModel(
(RayTrainWorker pid=134267)     (embed_tokens): Embedding(32000, 5120, padding_idx=0)
(RayTrainWorker pid=134267)     (layers): ModuleList(
(RayTrainWorker pid=134267)       (0-39): 40 x LlamaDecoderLayer(
(RayTrainWorker pid=134267)         (self_attn): LlamaAttention(
(RayTrainWorker pid=134267)           (q_proj): Linear(in_features=5120, out_features=5120, bias=False)
(RayTrainWorker pid=134267)           (k_proj): Linear(in_features=5120, out_features=5120, bias=False)
(RayTrainWorker pid=134267)           (v_proj): Linear(in_features=5120, out_features=5120, bias=False)
(RayTrainWorker pid=134267)           (o_proj): Linear(in_features=5120, out_features=5120, bias=False)
(RayTrainWorker pid=134267)           (rotary_emb): LlamaRotaryEmbedding()
(RayTrainWorker pid=134267)         )
(RayTrainWorker pid=134267)         (mlp): LlamaMLP(
(RayTrainWorker pid=134267)           (gate_proj): Linear(in_features=5120, out_features=13824, bias=False)
(RayTrainWorker pid=134267)           (down_proj): Linear(in_features=13824, out_features=5120, bias=False)
(RayTrainWorker pid=134267)           (up_proj): Linear(in_features=5120, out_features=13824, bias=False)
(RayTrainWorker pid=134267)           (act_fn): SiLUActivation()
(RayTrainWorker pid=134267)         )
(RayTrainWorker pid=134267)         (input_layernorm): LlamaRMSNorm()
(RayTrainWorker pid=134267)         (post_attention_layernorm): LlamaRMSNorm()
(RayTrainWorker pid=134267)       )
(RayTrainWorker pid=134267)     )
(RayTrainWorker pid=134267)     (norm): LlamaRMSNorm()
(RayTrainWorker pid=134267)   )
(RayTrainWorker pid=134267)   (lm_head): Linear(in_features=5120, out_features=32000, bias=False)
(RayTrainWorker pid=134267) )
(RayTrainWorker pid=74274, ip=10.0.36.152) [2023-06-30 17:39:54,688] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect) [repeated 15x across cluster]
(RayTrainWorker pid=74561, ip=10.0.35.16) [2023-06-30 17:39:56,220] [WARNING] [comm.py:152:init_deepspeed_backend] NCCL backend in DeepSpeed not yet implemented [repeated 15x across cluster]
(RayTrainWorker pid=134267) ninja: no work to do.
(RayTrainWorker pid=134267) Time to load cpu_adam op: 2.403524875640869 seconds
(RayTrainWorker pid=134267) Using /home/ray/.cache/torch_extensions/py310_cu118 as PyTorch extensions root...
(RayTrainWorker pid=134267) Detected CUDA files, patching ldflags
(RayTrainWorker pid=134267) Emitting ninja build file /home/ray/.cache/torch_extensions/py310_cu118/cpu_adam/build.ninja...
(RayTrainWorker pid=134267) Building extension module cpu_adam...
(RayTrainWorker pid=134267) Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
(RayTrainWorker pid=134267) Loading extension module cpu_adam...
(RayTrainWorker pid=74502, ip=10.0.60.86) LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] [repeated 15x across cluster]
Downloading (…)neration_config.json: 100%|██████████| 132/132 [00:00<00:00, 1.72MB/s]
(RayTrainWorker pid=74996, ip=10.0.9.249) Building extension module utils...
(RayTrainWorker pid=74152, ip=10.0.63.141) Loading extension module utils...
(RayTrainWorker pid=74152, ip=10.0.63.141) Time to load utils op: 0.0775597095489502 seconds
(RayTrainWorker pid=134267) Parameter Offload: Total persistent parameters: 414720 in 81 params
(RayTrainWorker pid=74152, ip=10.0.63.141) No modifications detected for re-loaded extension module utils, skipping build step...
(RayTrainWorker pid=74152, ip=10.0.63.141) Using /home/ray/.cache/torch_extensions/py310_cu118 as PyTorch extensions root... [repeated 32x across cluster]
(RayTrainWorker pid=74561, ip=10.0.35.16) Detected CUDA files, patching ldflags [repeated 15x across cluster]
(RayTrainWorker pid=134267) Emitting ninja build file /home/ray/.cache/torch_extensions/py310_cu118/utils/build.ninja... [repeated 31x across cluster]
(RayTrainWorker pid=74561, ip=10.0.35.16) Building extension module cpu_adam... [repeated 15x across cluster]
(RayTrainWorker pid=134267) Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N) [repeated 31x across cluster]
(RayTrainWorker pid=75132, ip=10.0.20.140) Loading extension module cpu_adam... [repeated 15x across cluster]
(RayTrainWorker pid=134267) Building extension module utils... [repeated 15x across cluster]
(RayTrainWorker pid=74152, ip=10.0.63.141) Loading extension module utils... [repeated 16x across cluster]
(RayTrainWorker pid=134267) ninja: no work to do. [repeated 31x across cluster]
(RayTrainWorker pid=75132, ip=10.0.20.140) Time to load cpu_adam op: 2.3851447105407715 seconds [repeated 15x across cluster]
(RayTrainWorker pid=74152, ip=10.0.63.141) Time to load utils op: 0.0005815029144287109 seconds [repeated 16x across cluster]
(RayTrainWorker pid=134267) 
(RayTrainWorker pid=134267)   | Name  | Type             | Params | Params per Device
(RayTrainWorker pid=134267) ---------------------------------------------------------------
(RayTrainWorker pid=134267) 0 | model | LlamaForCausalLM | 13.0 B | 813 M            
(RayTrainWorker pid=134267) ---------------------------------------------------------------
(RayTrainWorker pid=134267) 13.0 B    Trainable params
(RayTrainWorker pid=134267) 0         Non-trainable params
(RayTrainWorker pid=134267) 13.0 B    Total params
(RayTrainWorker pid=134267) 52,063.457Total estimated model params size (MB)
Epoch 0:   0%|          | 0/57 [00:00<?, ?it/s]
(RayTrainWorker pid=134267) /home/ray/anaconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
(RayTrainWorker pid=134267)   rank_zero_warn(
Epoch 0:   2%|▏         | 1/57 [00:38<35:42, 38.26s/it, v_num=0, train_loss=11.50]
(RayTrainWorker pid=134267) Time to load utils op: 0.00030732154846191406 seconds [repeated 15x across cluster]
(RayTrainWorker pid=134267) [2023-06-30 17:44:33,395] [WARNING] [stage3.py:1851:step] 2 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:   4%|▎         | 2/57 [01:19<36:23, 39.69s/it, v_num=0, train_loss=10.70]
Epoch 0:   5%|▌         | 3/57 [01:52<33:52, 37.65s/it, v_num=0, train_loss=1.710]
(RayTrainWorker pid=134267) [2023-06-30 17:45:48,054] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:   7%|▋         | 4/57 [02:34<34:01, 38.51s/it, v_num=0, train_loss=1.610]
Epoch 0:   9%|▉         | 5/57 [03:08<32:35, 37.60s/it, v_num=0, train_loss=0.914]
(RayTrainWorker pid=134267) [2023-06-30 17:47:03,011] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  11%|█         | 6/57 [03:49<32:26, 38.17s/it, v_num=0, train_loss=0.973]
Epoch 0:  12%|█▏        | 7/57 [04:24<31:30, 37.81s/it, v_num=0, train_loss=0.801]
(RayTrainWorker pid=134267) [2023-06-30 17:48:19,362] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  14%|█▍        | 8/57 [05:05<31:10, 38.17s/it, v_num=0, train_loss=0.844]
Epoch 0:  16%|█▌        | 9/57 [05:39<30:12, 37.75s/it, v_num=0, train_loss=0.652]
(RayTrainWorker pid=134267) [2023-06-30 17:49:36,571] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  18%|█▊        | 10/57 [06:22<29:58, 38.26s/it, v_num=0, train_loss=0.633]
Epoch 0:  19%|█▉        | 11/57 [06:59<29:13, 38.12s/it, v_num=0, train_loss=0.629]
/arrow/cpp/src/arrow/filesystem/s3fs.cc:663: CompletedMultipartUpload got error embedded in a 200 OK response: InternalError ("We encountered an internal error. Please try again."), retry = 1
(RayTrainWorker pid=134267) [2023-06-30 17:50:54,177] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  21%|██        | 12/57 [07:40<28:45, 38.35s/it, v_num=0, train_loss=0.609]
Epoch 0:  23%|██▎       | 13/57 [08:14<27:53, 38.04s/it, v_num=0, train_loss=0.680]
(RayTrainWorker pid=134267) [2023-06-30 17:52:10,002] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  25%|██▍       | 14/57 [08:55<27:26, 38.29s/it, v_num=0, train_loss=0.648]
Epoch 0:  26%|██▋       | 15/57 [09:29<26:33, 37.95s/it, v_num=0, train_loss=0.645]
(RayTrainWorker pid=134267) [2023-06-30 17:53:23,209] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  28%|██▊       | 16/57 [10:09<26:01, 38.08s/it, v_num=0, train_loss=0.664]
Epoch 0:  30%|██▉       | 17/57 [10:43<25:13, 37.83s/it, v_num=0, train_loss=0.625]
(RayTrainWorker pid=134267) [2023-06-30 17:54:36,660] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  32%|███▏      | 18/57 [11:22<24:39, 37.93s/it, v_num=0, train_loss=0.617]
Epoch 0:  33%|███▎      | 19/57 [11:56<23:53, 37.71s/it, v_num=0, train_loss=0.609]
(RayTrainWorker pid=134267) [2023-06-30 17:55:51,289] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  35%|███▌      | 20/57 [12:37<23:20, 37.86s/it, v_num=0, train_loss=0.602]
Epoch 0:  37%|███▋      | 21/57 [13:11<22:36, 37.69s/it, v_num=0, train_loss=0.590]
(RayTrainWorker pid=134267) [2023-06-30 17:57:07,919] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  39%|███▊      | 22/57 [13:53<22:06, 37.91s/it, v_num=0, train_loss=0.555]
Epoch 0:  40%|████      | 23/57 [14:27<21:22, 37.72s/it, v_num=0, train_loss=0.598]
(RayTrainWorker pid=134267) [2023-06-30 17:58:22,349] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  42%|████▏     | 24/57 [15:08<20:48, 37.85s/it, v_num=0, train_loss=0.625]
Epoch 0:  44%|████▍     | 25/57 [15:43<20:07, 37.74s/it, v_num=0, train_loss=0.625]
Epoch 0:  44%|████▍     | 25/57 [15:43<20:07, 37.74s/it, v_num=0, train_loss=0.582]
(RayTrainWorker pid=134267) [2023-06-30 17:59:40,125] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  46%|████▌     | 26/57 [16:26<19:35, 37.93s/it, v_num=0, train_loss=0.535]
Epoch 0:  47%|████▋     | 27/57 [17:02<18:56, 37.88s/it, v_num=0, train_loss=0.578]
(RayTrainWorker pid=134267) [2023-06-30 18:00:58,164] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  49%|████▉     | 28/57 [17:44<18:22, 38.01s/it, v_num=0, train_loss=0.582]
Epoch 0:  51%|█████     | 29/57 [18:20<17:42, 37.93s/it, v_num=0, train_loss=0.578]
(RayTrainWorker pid=134267) [2023-06-30 18:02:15,097] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  53%|█████▎    | 30/57 [19:01<17:06, 38.04s/it, v_num=0, train_loss=0.598]
Epoch 0:  54%|█████▍    | 31/57 [19:36<16:26, 37.95s/it, v_num=0, train_loss=0.586]
(RayTrainWorker pid=134267) [2023-06-30 18:03:30,632] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  56%|█████▌    | 32/57 [20:16<15:50, 38.02s/it, v_num=0, train_loss=0.605]
Epoch 0:  58%|█████▊    | 33/57 [20:49<15:08, 37.87s/it, v_num=0, train_loss=0.594]
(RayTrainWorker pid=134267) [2023-06-30 18:04:45,362] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  60%|█████▉    | 34/57 [21:31<14:33, 37.98s/it, v_num=0, train_loss=0.598]
Epoch 0:  61%|██████▏   | 35/57 [22:08<13:54, 37.95s/it, v_num=0, train_loss=0.574]
(RayTrainWorker pid=134267) [2023-06-30 18:06:02,727] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  63%|██████▎   | 36/57 [22:48<13:18, 38.02s/it, v_num=0, train_loss=0.586]
Epoch 0:  65%|██████▍   | 37/57 [23:23<12:38, 37.94s/it, v_num=0, train_loss=0.562]
(RayTrainWorker pid=134267) [2023-06-30 18:07:19,126] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  67%|██████▋   | 38/57 [24:05<12:02, 38.03s/it, v_num=0, train_loss=0.535]
Epoch 0:  68%|██████▊   | 39/57 [24:38<11:22, 37.91s/it, v_num=0, train_loss=0.598]
(RayTrainWorker pid=134267) [2023-06-30 18:08:36,683] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  70%|███████   | 40/57 [25:22<10:47, 38.07s/it, v_num=0, train_loss=0.562]
Epoch 0:  72%|███████▏  | 41/57 [25:57<10:07, 37.98s/it, v_num=0, train_loss=0.555]
(RayTrainWorker pid=134267) [2023-06-30 18:09:52,426] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  74%|███████▎  | 42/57 [26:38<09:30, 38.06s/it, v_num=0, train_loss=0.555]
Epoch 0:  75%|███████▌  | 43/57 [27:13<08:51, 37.99s/it, v_num=0, train_loss=0.547]
(RayTrainWorker pid=134267) [2023-06-30 18:11:08,855] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  77%|███████▋  | 44/57 [27:54<08:14, 38.06s/it, v_num=0, train_loss=0.562]
Epoch 0:  79%|███████▉  | 45/57 [28:29<07:35, 37.98s/it, v_num=0, train_loss=0.535]
(RayTrainWorker pid=134267) [2023-06-30 18:12:25,181] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  81%|████████  | 46/57 [29:11<06:58, 38.07s/it, v_num=0, train_loss=0.531]
Epoch 0:  82%|████████▏ | 47/57 [29:45<06:19, 37.99s/it, v_num=0, train_loss=0.504]
(RayTrainWorker pid=134267) [2023-06-30 18:13:40,300] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  84%|████████▍ | 48/57 [30:26<05:42, 38.05s/it, v_num=0, train_loss=0.520]
Epoch 0:  86%|████████▌ | 49/57 [31:01<05:03, 37.99s/it, v_num=0, train_loss=0.523]
(RayTrainWorker pid=134267) [2023-06-30 18:14:55,542] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  88%|████████▊ | 50/57 [31:41<04:26, 38.03s/it, v_num=0, train_loss=0.520]
Epoch 0:  89%|████████▉ | 51/57 [32:16<03:47, 37.98s/it, v_num=0, train_loss=0.527]
(RayTrainWorker pid=134267) [2023-06-30 18:16:12,131] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  91%|█████████ | 52/57 [32:58<03:10, 38.04s/it, v_num=0, train_loss=0.562]
Epoch 0:  93%|█████████▎| 53/57 [33:34<02:32, 38.00s/it, v_num=0, train_loss=0.539]
(RayTrainWorker pid=134267) [2023-06-30 18:17:29,752] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  95%|█████████▍| 54/57 [34:15<01:54, 38.07s/it, v_num=0, train_loss=0.535]
Epoch 0:  96%|█████████▋| 55/57 [34:50<01:16, 38.01s/it, v_num=0, train_loss=0.512]
(RayTrainWorker pid=134267) [2023-06-30 18:18:45,986] [WARNING] [stage3.py:1851:step] 4 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0:  98%|█████████▊| 56/57 [35:31<00:38, 38.07s/it, v_num=0, train_loss=0.516]
Epoch 0: 100%|██████████| 57/57 [36:06<00:00, 38.00s/it, v_num=0, train_loss=0.461]
(RayTrainWorker pid=134267) [2023-06-30 18:20:01,817] [WARNING] [stage3.py:1851:step] 3 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
Epoch 0: : 58it [36:47, 38.07s/it, v_num=0, train_loss=0.523]                      
(RayTrainWorker pid=74427, ip=10.0.16.236) /home/ray/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1802: UserWarning: Positional args are being deprecated, use kwargs instead. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.
(RayTrainWorker pid=74427, ip=10.0.16.236)   warnings.warn(
(RayTrainWorker pid=134267) No modifications detected for re-loaded extension module utils, skipping build step... [repeated 15x across cluster]
(RayTrainWorker pid=134267) Using /home/ray/.cache/torch_extensions/py310_cu118 as PyTorch extensions root... [repeated 15x across cluster]
(RayTrainWorker pid=134267) Loading extension module utils... [repeated 15x across cluster]
(RayTrainWorker pid=134267) Uploading checkpoint files from worker rank 0 to cloud URI s3://anyscale-staging-data-cld-kvedzwag2qa8i5bjxuevf5i7/yunxuanx-test/vicuna-13b-test/vicuna-13b-relation-extraction/LightningTrainer_c1544_00000_0_2023-06-30_17-39-36/checkpoint_000000.
(RayTrainWorker pid=134267) /home/ray/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1802: UserWarning: Positional args are being deprecated, use kwargs instead. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details. [repeated 15x across cluster]
(RayTrainWorker pid=134267)   warnings.warn( [repeated 15x across cluster]
(RayTrainWorker pid=75547, ip=10.0.42.158) Uploading checkpoint files from worker rank 3 to cloud URI s3://anyscale-staging-data-cld-kvedzwag2qa8i5bjxuevf5i7/yunxuanx-test/vicuna-13b-test/vicuna-13b-relation-extraction/LightningTrainer_c1544_00000_0_2023-06-30_17-39-36/checkpoint_000000.
(RayTrainWorker pid=74152, ip=10.0.63.141) Uploading checkpoint files from worker rank 1 to cloud URI s3://anyscale-staging-data-cld-kvedzwag2qa8i5bjxuevf5i7/yunxuanx-test/vicuna-13b-test/vicuna-13b-relation-extraction/LightningTrainer_c1544_00000_0_2023-06-30_17-39-36/checkpoint_000000.
(RayTrainWorker pid=134267) Done uploading checkpoint files.
(RayTrainWorker pid=74341, ip=10.0.29.61) Uploading checkpoint files from worker rank 10 to cloud URI s3://anyscale-staging-data-cld-kvedzwag2qa8i5bjxuevf5i7/yunxuanx-test/vicuna-13b-test/vicuna-13b-relation-extraction/LightningTrainer_c1544_00000_0_2023-06-30_17-39-36/checkpoint_000000. [repeated 13x across cluster]
(RayTrainWorker pid=74427, ip=10.0.16.236) Done uploading checkpoint files.
(RayTrainWorker pid=74152, ip=10.0.63.141) Done uploading checkpoint files.
(RayTrainWorker pid=74711, ip=10.0.45.211) Done uploading checkpoint files. [repeated 11x across cluster]
Epoch 0: : 58it [37:42, 39.00s/it, v_num=0, train_loss=0.523]
(RayTrainWorker pid=134267) `Trainer.fit` stopped: `max_epochs=1` reached.
(LightningTrainer pid=134103) Uploading trial artifacts took 26.651 s, which may be a performance bottleneck. Consider saving fewer/smaller artifacts to the trial log directory, or disable artifact syncing with `SyncConfig(sync_artifacts=False)`.
(RayTrainWorker pid=75547, ip=10.0.42.158) Done uploading checkpoint files. [repeated 2x across cluster]
2023-06-30 18:21:59,316	INFO tune.py:1148 -- Total run time: 2542.82 seconds (2511.95 seconds for the tuning loop).

In summary:

  • Training took 36:06 = 2166s.

  • Training + initialization + checkpointing took 2473s.

Model initialization and checkpoint synchronization therefore accounted for roughly 307 seconds (2473s - 2166s). This overhead is amortized as your datasets grow larger and training runs longer.

result
Result(
  metrics={'_report_on': 'train_epoch_end', 'train_loss': 0.5234375, 'epoch': 0, 'step': 29, 'should_checkpoint': True, 'done': True, 'trial_id': 'c1544_00000', 'experiment_tag': '0'},
  path='s3://anyscale-staging-data-cld-kvedzwag2qa8i5bjxuevf5i7/yunxuanx-test/vicuna-13b-test/vicuna-13b-relation-extraction/LightningTrainer_c1544_00000_0_2023-06-30_17-39-36',
  checkpoint=LightningCheckpoint(uri=s3://anyscale-staging-data-cld-kvedzwag2qa8i5bjxuevf5i7/yunxuanx-test/vicuna-13b-test/vicuna-13b-relation-extraction/LightningTrainer_c1544_00000_0_2023-06-30_17-39-36/checkpoint_000000)
)

LLM Inference#

Now, it’s time to play with our fine-tuned Vicuna code generator!

Download and Process your checkpoints#

First, download the checkpoints to your local machine using the AWS CLI.

Note that the following configurations can significantly increase the syncing throughput compared to the defaults. On a g5 instance with an NVMe SSD, the download speed improved from 200MB/s to around 1.5GB/s.

!awsv2 configure set s3.max_concurrent_requests 32
!awsv2 configure set default.s3.preferred_transfer_client crt
!awsv2 configure set default.s3.target_bandwidth 100Gb/s
!awsv2 configure set default.s3.multipart_chunksize 8MB
import os

os.system(f"awsv2 s3 sync s3://{result.checkpoint.path} /mnt/local_storage")

The DeepSpeed ZeRO-3 checkpoint is a directory containing k shards (k=16 in our case); a quick way to inspect the shard layout on disk is sketched after the list.

  • zero_pp_rank_k_mp_rank_00_model_states.pt: contains the model parameter skeleton of shard k.

  • bf16_zero_pp_rank_k_mp_rank_00_optim_states.pt: contains the actual flattened model parameters and optimizer states of shard k.
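
The snippet below lists the shard files. It is a minimal sketch that assumes the checkpoint directory path used later in this example (/mnt/local_storage/checkpoint.ckpt); adjust the path if your sync produced a different layout.

import glob
import os

# Illustrative only: point this at the directory produced by the `awsv2 s3 sync` step above.
zero_ckpt_dir = "/mnt/local_storage/checkpoint.ckpt"

# Print every ZeRO shard file (*.pt) relative to the checkpoint directory.
for shard in sorted(glob.glob(os.path.join(zero_ckpt_dir, "**", "*.pt"), recursive=True)):
    print(os.path.relpath(shard, zero_ckpt_dir))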

Next, we remove the optimizer states and consolidate the checkpoint into a single binary file using DeepSpeed utilities. Also, since we wrapped vicuna-13b within a LightningModule, we need to strip the prefix _forward_module.model. so that we can directly load the checkpoint into a HF vicuna model.

import os
import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

def extract_fp32_ckpt_from_zero(zero_ckpt_dir):
    """Consolidate a sharded ZeRO-3 checkpoint into a single fp32 state dict file."""
    # Gather the sharded ZeRO-3 weights into one fp32 state dict.
    state_dict = get_fp32_state_dict_from_zero_checkpoint(zero_ckpt_dir)
    # Strip the LightningModule prefix so the keys match the HF model's naming.
    vicuna_state_dict = {
        k.replace("_forward_module.model.", ""): v for k, v in state_dict.items()
    }
    # Save the consolidated weights as a single binary file.
    torch.save(vicuna_state_dict, os.path.join(zero_ckpt_dir, "full_model.pt"))


full_model_ckpt_path = "/mnt/local_storage/checkpoint.ckpt/full_model.pt"
extract_fp32_ckpt_from_zero("/mnt/local_storage/checkpoint.ckpt")
Processing zero checkpoint '/mnt/local_storage/checkpoint/model/checkpoint'
Detected checkpoint of type zero stage 3, world_size: 16
Parsing checkpoint created by deepspeed==0.9.4
Reconstructed Trainable fp32 state dict with 363 params 13015864320 elements
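
Optionally, you can sanity-check the consolidated file before building the inference pipeline. The snippet below is a minimal sketch; it assumes the head node has enough CPU memory to hold the full fp32 state dict (roughly 4 bytes x 13B parameters, about 52GB) and that the keys follow the standard HF Llama naming.

import torch

# Load the consolidated fp32 weights on CPU (needs roughly 52GB of host memory).
state_dict = torch.load(full_model_ckpt_path, map_location="cpu")

# After stripping the Lightning prefix, the keys should follow HF naming,
# e.g. "model.embed_tokens.weight", "model.layers.0....", "lm_head.weight".
print(f"{len(state_dict)} tensors in the consolidated checkpoint")
print(sorted(state_dict.keys())[:3])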

Initialize Generation Pipeline#

Here, we leverage the Accelerate library to efficiently load the model onto suitable devices (GPU and CPU) and build a HF text generation pipeline.

  • Initialize an empty model on the meta device

  • Create valid device mappings for the vicuna-13b model

  • Load and distribute model weights to target devices

This approach ensures that model initialization uses no more than 1x the model size in RAM.

import torch
import ray
import lightning.pytorch as pl
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from accelerate import (
    init_empty_weights,
    infer_auto_device_map,
    load_checkpoint_and_dispatch,
)

# Initialize a model on meta device
with init_empty_weights():
    config = AutoConfig.from_pretrained(MODEL_NAME)
    meta_model = AutoModelForCausalLM.from_config(config)
meta_model.tie_weights()

# Define the device mapping
device_map = infer_auto_device_map(
    meta_model,
    max_memory={0: "15GB", "cpu": "60GB"},
    no_split_module_classes=["LlamaDecoderLayer"],
)

# Load the model parameters
model = load_checkpoint_and_dispatch(
    meta_model,
    checkpoint=full_model_ckpt_path,
    device_map=device_map,
)
from transformers import pipeline

generator = pipeline(
    "text-generation",
    model=model,
    device_map=device_map,
    tokenizer=AutoTokenizer.from_pretrained(
        MODEL_NAME, padding_side="left", use_fast=False
    ),
)
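
To see how the weights were split between the GPU and CPU, you can inspect the inferred device map. This is a small sketch; the exact placement depends on the max_memory budget specified above.

from collections import Counter

# Count how many modules infer_auto_device_map assigned to each device
# (e.g., GPU index 0 vs. "cpu").
print(Counter(device_map.values()))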

Case Study#

We took 3 examples from CoNaLa's test split for the demo:

testcases = [
    {
        "intent": "replace white spaces in colunm 'col' of dataframe `df` with '_'",
    },
    {
        "intent": "search for occurrences of regex pattern '>.*<' in xml string `line`",
    },
    {
        "intent": "send a signal `signal.SIGUSR1` to the current process",
    },
]

Let’s begin by examining the generated outputs without fine-tuning. In this case study, we utilize Aviary Explorer, an open-source multi-LLM serving platform supported by Ray and Anyscale. You can easily select from a variety of open-source LLMs and compare their generation quality, cost, latency, and many other metrics.

We constructed a prompt in a zero-shot learning manner (an illustrative version of the prompt is sketched below) and fed it into 3 OSS LLMs.

  • vicuna-13b-v1.3 begins to speak Chinese.

  • mpt-7b-chat generates a reasonable code snippet, but with multiple lines.

  • falcon-7b-sft generates a one-line snippet, but it doesn't seem to work.

As we can see, none of them generate a satisfactory code snippet.
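
For reference, the zero-shot prompt fed to these baselines follows the same intent/snippet format used in the case study below. The template here is an illustrative stand-in; the actual PROMPT_TEMPLATE is defined earlier in this example and may include additional instructions.

# Illustrative stand-in for the real PROMPT_TEMPLATE defined earlier in this example.
EXAMPLE_PROMPT_TEMPLATE = "Intent: {intent}\nOne-line code snippet: {snippet}"

print(
    EXAMPLE_PROMPT_TEMPLATE.format(
        intent="send a signal `signal.SIGUSR1` to the current process",
        snippet="",
    )
)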

Now let’s check the performance of our fine-tuned vicuna-13b-v1.3 model:

for case in testcases:
    prompt = PROMPT_TEMPLATE.format(intent=case["intent"], snippet="")
    output = generator(prompt, max_new_tokens=30, do_sample=True)
    print(output[0]["generated_text"])
/home/ray/anaconda3/lib/python3.10/site-packages/transformers/pipelines/base.py:1081: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
  warnings.warn(
Intent: replace white spaces in colunm 'col' of dataframe `df` with '_'
One-line code snippet:  `df['col'] = df['col'].str.replace(' ', '_')`

Intent: search for occurrences of regex pattern '>.*<' in xml string `line`
One-line code snippet:  `re.findall('>.*<', line)``

Intent: send a signal `signal.SIGUSR1` to the current process
One-line code snippet:  `os.kill(os.getpid(), signal.SIGUSR1)``

Test the Generated Code Snippets#

The generated code snippets look quite reasonable, covering Pandas operations, regular expressions, and Linux commands. Let's test them one by one.

import pandas as pd

df = pd.DataFrame.from_dict({"col": ["abc def ghi", " 12 3 456", "     "]})
print("Before\n", df)

df["col"] = df["col"].str.replace(" ", "_")
print("After\n", df)
Before
            col
0  abc def ghi
1     12 3 456
2             
After
            col
0  abc_def_ghi
1    _12_3_456
2        _____
import re

line = """
<bookstore>
  <book category="fiction">
    <title>The Great Gatsby</title>
    <author>F. Scott Fitzgerald</author>
    <year>1925</year>
  </book>
  <book category="non-fiction">
    <title>Sapiens: A Brief History of Humankind</title>
    <author>Yuval Noah Harari</author>
    <year>2011</year>
  </book>
</bookstore>
"""
re.findall(">.*<", line)
['>The Great Gatsby<',
 '>F. Scott Fitzgerald<',
 '>1925<',
 '>Sapiens: A Brief History of Humankind<',
 '>Yuval Noah Harari<',
 '>2011<']

Finally, let's hand it over to the LLM and let it wrap up the demo:

import os, signal

os.kill(os.getpid(), signal.SIGUSR1)  # Terminate the current process~

References#