BERT Model Training with Intel Gaudi
In this notebook, we train a BERT model for sequence classification on the Yelp Review Full dataset. We use the `transformers` and `datasets` libraries from Hugging Face, along with `ray.train` for distributed training.
Intel Gaudi AI Processors (HPUs) are AI hardware accelerators designed by Intel Habana Labs. For more information, see Gaudi Architecture and Gaudi Developer Docs.
Configuration

Running this example requires a node with Gaudi or Gaudi2 accelerators installed. Both Gaudi and Gaudi2 nodes have 8 HPUs. We use 2 workers to train the model, each using 1 HPU.
We recommend using a prebuilt container to run these examples. To run a container, you need Docker. See Install Docker Engine for installation instructions.
Then follow Run Using Containers to install the Gaudi drivers and container runtime.
Finally, pull and start the Gaudi container:
```bash
docker pull vault.habana.ai/gaudi-docker/1.14.0/ubuntu22.04/habanalabs/pytorch-installer-2.1.1:latest
docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.14.0/ubuntu22.04/habanalabs/pytorch-installer-2.1.1:latest
```
Inside the container, install the following dependencies to run this notebook:

```bash
pip install "ray[train]" notebook transformers datasets evaluate
```
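Before opening the notebook, it can help to confirm that the required packages are importable inside the container. The snippet below is a minimal standard-library sketch; the helper name `check_modules` is hypothetical, and it only checks that each top-level module can be located, not which version is installed. `habana_frameworks` ships with the Gaudi container image rather than being installed by the `pip` command above.

```python
import importlib.util
from typing import Dict, Iterable


def check_modules(names: Iterable[str]) -> Dict[str, bool]:
    """Map each top-level module name to whether Python can locate it.

    Hypothetical helper: importlib.util.find_spec only looks the module up,
    so nothing is imported and installed versions are not checked.
    """
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    # Top-level modules this notebook imports (habana_frameworks comes with the Gaudi image).
    required = ["ray", "torch", "numpy", "tqdm", "transformers",
                "datasets", "evaluate", "habana_frameworks"]
    for name, found in check_modules(required).items():
        print(f"{name:<20} {'ok' if found else 'MISSING'}")
```

Any module reported as `MISSING` points to an incomplete install or to running outside the Gaudi container.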
```python
# Import necessary libraries
import os
from typing import Dict

import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

import numpy as np
import evaluate
from datasets import load_dataset
import transformers
from transformers import (
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    AutoModelForSequenceClassification,
)

import ray.train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.torch import TorchConfig
from ray.runtime_env import RuntimeEnv

# Loads the Habana PyTorch bridge
import habana_frameworks.torch.core as htcore
```
/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py:252: UserWarning: Device capability of hccl unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.
warnings.warn(
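Metrics Setup

We use accuracy as the evaluation metric. The `compute_metrics` function takes a `(logits, labels)` pair, converts the logits to predicted class indices with `argmax`, and computes the accuracy of those predictions. The custom training loop in this notebook logs only the training loss, so `compute_metrics` is not called during training.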
```python
# Metrics
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # Pick the highest-scoring class for each example
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
```
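To make the metric concrete, here is a dependency-free sketch of what `compute_metrics` computes: take the arg-max over each example's logits and compare it with the label. It mirrors `np.argmax(logits, axis=-1)` followed by an accuracy calculation; it is an illustration, not the `evaluate` library's implementation, and the function names are hypothetical.

```python
from typing import Dict, List, Sequence


def argmax(row: Sequence[float]) -> int:
    # Index of the largest logit; max() keeps the first maximum, like np.argmax.
    return max(range(len(row)), key=lambda i: row[i])


def accuracy_from_logits(logits: Sequence[Sequence[float]],
                         labels: Sequence[int]) -> Dict[str, float]:
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    if not labels:
        raise ValueError("need at least one example")
    predictions: List[int] = [argmax(row) for row in logits]
    correct = sum(p == y for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}


# Three examples with 5 classes (the 5 Yelp star ratings, labeled 0-4).
logits = [
    [0.1, 2.0, 0.3, 0.0, -1.0],  # predicts class 1 (correct)
    [1.5, 0.2, 0.1, 0.0, 0.0],   # predicts class 0 (label is 2)
    [0.0, 0.0, 0.0, 0.5, 3.0],   # predicts class 4 (correct)
]
labels = [1, 2, 4]
print(accuracy_from_logits(logits, labels))  # {'accuracy': 0.6666666666666666}
```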
Training Function

Each worker executes this function during training. It handles data loading, tokenization, model initialization, and the training loop. The function is the same as one written for GPUs; no changes are needed to port it to HPUs. Internally, Ray Train does the following:
- Detects HPUs and sets the device.
- Initializes the Habana PyTorch backend.
- Initializes the Habana distributed backend.
```python
def train_func_per_worker(config: Dict):
    # Datasets
    dataset = load_dataset("yelp_review_full")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    def tokenize_function(examples):
        # Pad and truncate every review to the tokenizer's maximum length
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    lr = config["lr"]
    epochs = config["epochs"]
    batch_size = config["batch_size_per_worker"]

    # Use a 1,000-example subset of each split to keep the example short
    train_dataset = dataset["train"].select(range(1000)).map(tokenize_function, batched=True)
    eval_dataset = dataset["test"].select(range(1000)).map(tokenize_function, batched=True)

    # Prepare dataloaders for each worker.
    # Note: every worker iterates over the same 1,000 examples; the data is not sharded.
    dataloaders = {}
    dataloaders["train"] = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=transformers.default_data_collator,
        batch_size=batch_size,
    )
    dataloaders["test"] = torch.utils.data.DataLoader(
        eval_dataset,
        shuffle=False,  # no need to shuffle for evaluation
        collate_fn=transformers.default_data_collator,
        batch_size=batch_size,
    )

    # Obtain the HPU device automatically
    device = ray.train.torch.get_device()

    # Prepare model and optimizer
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-cased", num_labels=5
    )
    model = model.to(device)
    # Note: the model is not passed through ray.train.torch.prepare_model(), so it is
    # not wrapped in DistributedDataParallel and gradients are not averaged across workers.
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # Start training loops
    for epoch in range(epochs):
        # Each epoch has a training and a validation phase
        for phase in ["train", "test"]:
            if phase == "train":
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluation mode

            for batch in dataloaders[phase]:
                batch = {k: v.to(device) for k, v in batch.items()}

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass; track gradients only in the training phase
                with torch.set_grad_enabled(phase == "train"):
                    # Get model outputs and calculate loss
                    outputs = model(**batch)
                    loss = outputs.loss

                    # Backward pass + optimizer step only in the training phase
                    if phase == "train":
                        loss.backward()
                        optimizer.step()
                        print(f"train epoch:[{epoch}]\tloss:{loss:.6f}")
```
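`tokenize_function` calls the tokenizer with `padding="max_length"` and `truncation=True`, so every review becomes a sequence of exactly the tokenizer's maximum length (512 tokens for `bert-base-cased`). The following standard-library sketch shows the effect on an already-tokenized list of IDs. It is an illustration, not the Hugging Face implementation: the function name is hypothetical, the pad ID of 0 and the short `max_length` are assumptions chosen for readability, and special tokens such as `[CLS]` and `[SEP]` are ignored.

```python
from typing import Dict, List


def pad_or_truncate(token_ids: List[int], max_length: int,
                    pad_id: int = 0) -> Dict[str, List[int]]:
    """Hypothetical sketch of fixed-length padding plus truncation."""
    ids = token_ids[:max_length]     # truncation=True: drop tokens past max_length
    num_pad = max_length - len(ids)  # number of pad tokens still needed
    return {
        "input_ids": ids + [pad_id] * num_pad,             # padding="max_length"
        "attention_mask": [1] * len(ids) + [0] * num_pad,  # 0 = ignore padded position
    }


short_review = pad_or_truncate([5, 6, 7], max_length=6)
long_review = pad_or_truncate(list(range(1, 10)), max_length=6)
print(short_review)  # {'input_ids': [5, 6, 7, 0, 0, 0], 'attention_mask': [1, 1, 1, 0, 0, 0]}
print(long_review)   # {'input_ids': [1, 2, 3, 4, 5, 6], 'attention_mask': [1, 1, 1, 1, 1, 1]}
```

Because every batch then has the same shape, the computation does not change from step to step. Intel's Gaudi documentation generally recommends static shapes like this to avoid graph recompilation in lazy mode, which the run log below reports as enabled (`PT_HPU_LAZY_MODE = 1`).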
Main Training Function

The `train_bert` function sets up the distributed training environment using Ray and starts the training process. To enable training on HPUs, only two changes are needed:
- Request an HPU for each worker in `ScalingConfig`.
- Set the backend to `"hccl"` in `TorchConfig`.
```python
def train_bert(num_workers=2):
    global_batch_size = 8

    train_config = {
        "lr": 1e-3,
        "epochs": 10,
        "batch_size_per_worker": global_batch_size // num_workers,
    }

    # Configure computation resources
    # In ScalingConfig, require an HPU for each worker
    scaling_config = ScalingConfig(num_workers=num_workers, resources_per_worker={"CPU": 1, "HPU": 1})
    # Set backend to hccl in TorchConfig
    torch_config = TorchConfig(backend="hccl")

    # Start a local Ray instance (or connect to an existing cluster)
    ray.init()

    # Initialize a Ray TorchTrainer
    trainer = TorchTrainer(
        train_loop_per_worker=train_func_per_worker,
        train_loop_config=train_config,
        torch_config=torch_config,
        scaling_config=scaling_config,
    )

    result = trainer.fit()
    print(f"Training result: {result}")
```
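The resources that Ray reserves for this run can be tallied by hand: each worker asks for `{"CPU": 1, "HPU": 1}`, and the trainer's coordinating process also holds resources. The sketch below adds these up. It is not a Ray API; the function name is hypothetical, and the assumption that the trainer reserves 1 CPU is inferred from the run log below, which reports `3.0/160 CPUs` and `2.0/8.0 HPU` for `num_workers=2`.

```python
from typing import Dict

HPUS_PER_NODE = 8  # Gaudi and Gaudi2 nodes each have 8 HPUs


def total_resources(
    num_workers: int,
    resources_per_worker: Dict[str, float],
    trainer_resources: Dict[str, float],
) -> Dict[str, float]:
    """Hypothetical tally of the logical resources for one TorchTrainer run."""
    total = {name: float(amount) for name, amount in trainer_resources.items()}
    for name, amount in resources_per_worker.items():
        total[name] = total.get(name, 0.0) + num_workers * float(amount)
    return total


# Assumption: the trainer itself reserves 1 CPU (consistent with the run log).
usage = total_resources(2, {"CPU": 1, "HPU": 1}, trainer_resources={"CPU": 1})
print(usage)  # {'CPU': 3.0, 'HPU': 2.0}
if usage["HPU"] > HPUS_PER_NODE:
    print("This many workers needs more HPUs than a single Gaudi node provides.")
```

On a single node, `num_workers` can therefore go up to 8; beyond that the HPU request cannot be satisfied, and the trial would likely remain pending.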
Start Training

Finally, we call the `train_bert` function to start training. You can change the number of workers through the `num_workers` argument.
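When you change `num_workers`, keep in mind how `train_bert` derives the per-worker batch size: it floor-divides the global batch size of 8, so a worker count that does not divide 8 silently shrinks the effective global batch (for example, 3 workers get 2 examples each, 6 in total). Because each worker builds its DataLoader over the same 1,000-example subsets, the number of batches per phase depends only on the per-worker batch size. The helper below is a hypothetical sketch of that arithmetic, not part of Ray.

```python
import math
from typing import Dict


def batch_plan(global_batch_size: int, num_workers: int, num_examples: int) -> Dict[str, int]:
    """Hypothetical helper mirroring the batch-size arithmetic in train_bert."""
    per_worker = global_batch_size // num_workers  # same floor division as train_bert
    if per_worker == 0:
        raise ValueError("more workers than examples in the global batch")
    return {
        "batch_size_per_worker": per_worker,
        "effective_global_batch": per_worker * num_workers,
        # Per worker, per epoch; DataLoader keeps a final partial batch (drop_last=False).
        "batches_per_phase": math.ceil(num_examples / per_worker),
    }


print(batch_plan(8, 2, 1000))
# {'batch_size_per_worker': 4, 'effective_global_batch': 8, 'batches_per_phase': 250}
print(batch_plan(8, 3, 1000))
# {'batch_size_per_worker': 2, 'effective_global_batch': 6, 'batches_per_phase': 500}
```

With the default of 2 workers, each worker therefore runs 250 training and 250 evaluation steps per epoch and prints one loss line per training step.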
Note: the following warning is harmless and is resolved in SynapseAI version 1.14.0+:
/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py:252: UserWarning: Device capability of hccl unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.
```python
train_bert(num_workers=2)
```
Tune Status

| Current time | Running for | Memory |
|---|---|---|
| 2024-02-28 07:05:06 | 00:05:09.32 | 389.1/1007.5 GiB |

System Info

Using FIFO scheduling algorithm.
Logical resource usage: 3.0/160 CPUs, 0/0 GPUs (0.0/1.0 TPU, 2.0/8.0 HPU)

Trial Status

| Trial name | status | loc |
|---|---|---|
| TorchTrainer_fb74f_00000 | TERMINATED | 172.17.0.3:59382 |
(pid=59382) /usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py:252: UserWarning: Device capability of hccl unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.
(pid=59382) warnings.warn(
(RayTrainWorker pid=66009) Setting up process group for: env:// [rank=0, world_size=2]
(RayTrainWorker pid=66010) /usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py:252: UserWarning: Device capability of hccl unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`. [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
(RayTrainWorker pid=66010) warnings.warn( [repeated 2x across cluster]
(TorchTrainer pid=59382) Started distributed worker processes:
(TorchTrainer pid=59382) - (ip=172.17.0.3, pid=66009) world_rank=0, local_rank=0, node_rank=0
(TorchTrainer pid=59382) - (ip=172.17.0.3, pid=66010) world_rank=1, local_rank=1, node_rank=0
Downloading readme: 100%|██████████| 6.72k/6.72k [00:00<00:00, 21.0MB/s]
Downloading data: 100%|██████████| 299M/299M [00:06<00:00, 43.0MB/s]
Downloading data: 100%|██████████| 23.5M/23.5M [00:00<00:00, 38.5MB/s]
Generating train split: 100%|██████████| 650000/650000 [00:02<00:00, 237989.20 examples/s]
Generating test split: 100%|██████████| 50000/50000 [00:00<00:00, 247162.55 examples/s]
Map: 100%|██████████| 1000/1000 [00:00<00:00, 2898.10 examples/s]
(RayTrainWorker pid=66009) Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
(RayTrainWorker pid=66009) You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
(RayTrainWorker pid=66009) ============================= HABANA PT BRIDGE CONFIGURATION ===========================
(RayTrainWorker pid=66009) PT_HPU_LAZY_MODE = 1
(RayTrainWorker pid=66009) PT_RECIPE_CACHE_PATH =
(RayTrainWorker pid=66009) PT_CACHE_FOLDER_DELETE = 0
(RayTrainWorker pid=66009) PT_HPU_RECIPE_CACHE_CONFIG =
(RayTrainWorker pid=66009) PT_HPU_MAX_COMPOUND_OP_SIZE = 9223372036854775807
(RayTrainWorker pid=66009) PT_HPU_LAZY_ACC_PAR_MODE = 1
(RayTrainWorker pid=66009) PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES = 0
(RayTrainWorker pid=66009) ---------------------------: System Configuration :---------------------------
(RayTrainWorker pid=66009) Num CPU Cores : 160
(RayTrainWorker pid=66009) CPU RAM : 1056389756 KB
(RayTrainWorker pid=66009) ------------------------------------------------------------------------------
Map: 100%|██████████| 1000/1000 [00:00<00:00, 3179.11 examples/s] [repeated 3x across cluster]
(RayTrainWorker pid=66010) Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
(RayTrainWorker pid=66010) You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
(RayTrainWorker pid=66010) train epoch:[0] loss:1.782888
(RayTrainWorker pid=66010) train epoch:[0] loss:2.250521 [repeated 2x across cluster]
(RayTrainWorker pid=66010) train epoch:[0] loss:2.005397 [repeated 114x across cluster]
(RayTrainWorker pid=66010) train epoch:[0] loss:1.583421 [repeated 122x across cluster]
(RayTrainWorker pid=66010) train epoch:[0] loss:1.873015 [repeated 117x across cluster]
(RayTrainWorker pid=66010) train epoch:[0] loss:1.287454 [repeated 111x across cluster]
(RayTrainWorker pid=66010) train epoch:[1] loss:1.256705 [repeated 35x across cluster]
(RayTrainWorker pid=66010) train epoch:[1] loss:1.783350 [repeated 112x across cluster]
(RayTrainWorker pid=66009) train epoch:[1] loss:1.161693 [repeated 117x across cluster]
(RayTrainWorker pid=66010) train epoch:[1] loss:1.083962 [repeated 121x across cluster]
(RayTrainWorker pid=66010) train epoch:[1] loss:1.452244 [repeated 126x across cluster]
(RayTrainWorker pid=66010) train epoch:[2] loss:0.848569 [repeated 23x across cluster]
(RayTrainWorker pid=66010) train epoch:[2] loss:0.935847 [repeated 104x across cluster]
(RayTrainWorker pid=66010) train epoch:[2] loss:2.003910 [repeated 133x across cluster]
(RayTrainWorker pid=66010) train epoch:[2] loss:0.719678 [repeated 119x across cluster]
(RayTrainWorker pid=66009) train epoch:[2] loss:1.115227 [repeated 128x across cluster]
(RayTrainWorker pid=66010) train epoch:[3] loss:1.476088 [repeated 16x across cluster]
(RayTrainWorker pid=66010) train epoch:[3] loss:0.938356 [repeated 95x across cluster]
(RayTrainWorker pid=66010) train epoch:[3] loss:0.880045 [repeated 124x across cluster]
(RayTrainWorker pid=66010) train epoch:[3] loss:0.906078 [repeated 126x across cluster]
(RayTrainWorker pid=66010) train epoch:[3] loss:0.977447 [repeated 121x across cluster]
(RayTrainWorker pid=66010) train epoch:[4] loss:0.545720 [repeated 34x across cluster]
(RayTrainWorker pid=66010) train epoch:[4] loss:0.733710 [repeated 114x across cluster]
(RayTrainWorker pid=66010) train epoch:[4] loss:0.894966 [repeated 121x across cluster]
(RayTrainWorker pid=66010) train epoch:[4] loss:1.428036 [repeated 122x across cluster]
(RayTrainWorker pid=66010) train epoch:[4] loss:1.482066 [repeated 122x across cluster]
(RayTrainWorker pid=66010) train epoch:[5] loss:1.564706 [repeated 22x across cluster]
(RayTrainWorker pid=66010) train epoch:[5] loss:1.853072 [repeated 121x across cluster]
(RayTrainWorker pid=66010) train epoch:[5] loss:2.260058 [repeated 129x across cluster]
(RayTrainWorker pid=66010) train epoch:[5] loss:1.414144 [repeated 128x across cluster]
(RayTrainWorker pid=66009) train epoch:[5] loss:0.980207 [repeated 118x across cluster]
(RayTrainWorker pid=66010) train epoch:[6] loss:1.559380 [repeated 7x across cluster]
(RayTrainWorker pid=66010) train epoch:[6] loss:1.634878 [repeated 123x across cluster]
(RayTrainWorker pid=66010) train epoch:[6] loss:1.564483 [repeated 132x across cluster]
(RayTrainWorker pid=66010) train epoch:[6] loss:1.733673 [repeated 136x across cluster]
(RayTrainWorker pid=66010) train epoch:[7] loss:1.582968 [repeated 105x across cluster]
(RayTrainWorker pid=66010) train epoch:[7] loss:1.486512 [repeated 133x across cluster]
(RayTrainWorker pid=66010) train epoch:[7] loss:1.723742 [repeated 134x across cluster]
(RayTrainWorker pid=66010) train epoch:[7] loss:1.556943 [repeated 137x across cluster]
(RayTrainWorker pid=66010) train epoch:[8] loss:1.613637 [repeated 96x across cluster]
(RayTrainWorker pid=66010) train epoch:[8] loss:1.744777 [repeated 132x across cluster]
(RayTrainWorker pid=66010) train epoch:[8] loss:1.816669 [repeated 131x across cluster]
(RayTrainWorker pid=66010) train epoch:[8] loss:1.313460 [repeated 128x across cluster]
(RayTrainWorker pid=66009) train epoch:[9] loss:1.920412 [repeated 109x across cluster]
(RayTrainWorker pid=66010) train epoch:[9] loss:1.687392 [repeated 131x across cluster]
(RayTrainWorker pid=66010) train epoch:[9] loss:1.714871 [repeated 126x across cluster]
(RayTrainWorker pid=66010) train epoch:[9] loss:1.679613 [repeated 139x across cluster]
Trial TorchTrainer_fb74f_00000 completed. Last result:
2024-02-28 07:05:06,559 INFO tune.py:1042 -- Total run time: 309.37 seconds (309.32 seconds for the tuning loop).
Training result: Result(
metrics={},
path='/root/ray_results/TorchTrainer_2024-02-28_06-59-57/TorchTrainer_fb74f_00000_0_2024-02-28_06-59-57',
filesystem='local',
checkpoint=None
)