ResNet Model Training with Intel Gaudi

In this Jupyter notebook, we will train a ResNet-50 model to classify images of ants and bees using HPUs. We will use PyTorch for model training and Ray Train for distributed training. The dataset is downloaded with wget and processed using torchvision’s datasets and transforms.

Intel Gaudi AI Processors (HPUs) are AI hardware accelerators designed by Habana Labs, an Intel company. For more information, see Gaudi Architecture and Gaudi Developer Docs.

Configuration

Running this example requires a node with Gaudi or Gaudi2 accelerators installed. Gaudi and Gaudi2 nodes typically have 8 HPUs each. We will use 2 workers to train the model, each using 1 HPU.

We recommend using a prebuilt container to run these examples. To run a container, you need Docker. See Install Docker Engine for installation instructions.

Next, follow Run Using Containers to install the Gaudi drivers and container runtime.

Next, start the Gaudi container:

docker pull vault.habana.ai/gaudi-docker/1.14.0/ubuntu22.04/habanalabs/pytorch-installer-2.1.1:latest
docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.14.0/ubuntu22.04/habanalabs/pytorch-installer-2.1.1:latest

Inside the container, install Ray and Jupyter to run this notebook.

pip install ray[train] notebook

import os
from typing import Dict
from tempfile import TemporaryDirectory

import torch
from filelock import FileLock
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm

import ray
import ray.train as train
from ray.train import ScalingConfig, Checkpoint
from ray.train.torch import TorchTrainer
from ray.train.torch import TorchConfig
from ray.runtime_env import RuntimeEnv

import habana_frameworks.torch.core as htcore

Define Data Transforms

We will set up the data transforms for preprocessing images for training and validation. This includes random cropping, flipping, and normalization for the training set, and resizing and normalization for the validation set.

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    "train": transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    "val": transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}
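As a rough illustration of the arithmetic in the transforms above, the sketch below applies the ToTensor() scaling (8-bit values divided by 255) and then Normalize()'s per-channel (value - mean) / std to a single RGB pixel in plain Python. The helper names are hypothetical, and this is not how torchvision implements the transforms (they operate on whole tensors); it only shows what the mean and std lists do to a pixel.

```python
# Per-channel ImageNet statistics, copied from data_transforms above
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def to_unit_range(rgb_8bit):
    """Hypothetical helper: mimic ToTensor()'s scaling of one 8-bit RGB pixel to [0, 1]."""
    return tuple(v / 255.0 for v in rgb_8bit)


def normalize_pixel(rgb):
    """Hypothetical helper: mimic Normalize(MEAN, STD) on one pixel already in [0, 1].

    Each channel is shifted by its mean and divided by its standard deviation,
    independently of the other channels.
    """
    return tuple((v - m) / s for v, m, s in zip(rgb, MEAN, STD))


if __name__ == "__main__":
    # A pixel equal to the channel means maps to zero in every channel,
    # while a pure-white pixel lands well above zero.
    print(normalize_pixel(MEAN))
    print(normalize_pixel(to_unit_range((255, 255, 255))))
```

Because the same statistics are used for "train" and "val", the model sees inputs on a consistent scale in both phases; only the cropping and flipping differ.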

Dataset Download Function

We will define a function to download the Hymenoptera dataset. This dataset contains images of ants and bees for a binary classification problem.

def download_datasets():
    os.system("wget https://download.pytorch.org/tutorial/hymenoptera_data.zip >/dev/null 2>&1")
    os.system("unzip hymenoptera_data.zip >/dev/null 2>&1")
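download_datasets() shells out to wget and unzip, which assumes both tools are installed in the container. A minimal standard-library sketch of the same step, assuming you would rather not depend on those binaries, could look like the following. download_and_extract is a hypothetical name, and like the original it should be called only on local rank 0 inside the training loop.

```python
import urllib.parse
import urllib.request
import zipfile
from pathlib import Path

# URL taken from download_datasets() above
DATA_URL = "https://download.pytorch.org/tutorial/hymenoptera_data.zip"


def download_and_extract(url=DATA_URL, dest="."):
    """Hypothetical helper: download a zip archive into dest and extract it there.

    The download is skipped if the archive already exists locally.
    Returns the path of the local zip file.
    """
    dest = Path(dest)
    dest.mkdir(parents=True, exist_ok=True)
    # Name the local archive after the last component of the URL path
    archive = dest / Path(urllib.parse.urlparse(url).path).name
    if not archive.exists():
        urllib.request.urlretrieve(url, archive)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest)
    return archive
```

Skipping the download when the archive is already present makes repeated runs cheaper; extraction still runs each time and simply overwrites the same files.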

Dataset Preparation Function

After downloading the dataset, we need to build PyTorch datasets for training and validation. The build_datasets function will apply the previously defined transforms and create the datasets.

def build_datasets():
    torch_datasets = {}
    for split in ["train", "val"]:
        torch_datasets[split] = datasets.ImageFolder(
            os.path.join("./hymenoptera_data", split), data_transforms[split]
        )
    return torch_datasets
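ImageFolder derives labels from the directory layout rather than from a label file: each subdirectory of a split is one class, and class indices follow the sorted folder names. The hypothetical helper below reproduces that convention in plain Python so you can check the label mapping without loading any images. Assuming the archive unpacks into ants/ and bees/ subfolders, label 0 would correspond to ants and label 1 to bees.

```python
import os


def list_classes(root):
    """Hypothetical helper reproducing ImageFolder's class-index convention.

    Each immediate subdirectory of root is treated as one class, and class
    indices follow the sorted (alphabetical) order of the folder names.
    Plain files directly under root are ignored.
    """
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"No class folders found in {root!r}")
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


if __name__ == "__main__":
    # Assumes download_datasets() has already unpacked the archive here
    train_root = os.path.join("./hymenoptera_data", "train")
    if os.path.isdir(train_root):
        print(list_classes(train_root))
```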

Model Initialization Function

We define an initialize_model function that loads a ResNet-50 model pretrained on ImageNet and replaces its final classification layer with a new linear layer for our two-class (ants vs. bees) task.

def initialize_model():
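    # Load ResNet-50 with ImageNet-pretrained weights
    # (the weights enum replaces the deprecated pretrained=True argument)
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)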

    # Replace the original classifier with a new Linear layer
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, 2)

    # Ensure all params get updated during finetuning
    for param in model.parameters():
        param.requires_grad = True
    return model
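Replacing model.fc is cheap: the new head is much smaller than the 1000-class ImageNet head it replaces. The sketch below does the parameter arithmetic for an nn.Linear layer (weights plus bias), assuming ResNet-50's fc.in_features of 2048. Because initialize_model leaves requires_grad=True on every parameter, the optimizer still updates the whole backbone, not just this head.

```python
def linear_param_count(in_features, out_features, bias=True):
    """Number of trainable parameters in nn.Linear(in_features, out_features)."""
    return in_features * out_features + (out_features if bias else 0)


# ResNet-50's final layer receives 2048 features (model.fc.in_features)
RESNET50_FC_IN = 2048

imagenet_head = linear_param_count(RESNET50_FC_IN, 1000)  # original 1000-class head
binary_head = linear_param_count(RESNET50_FC_IN, 2)       # new ants-vs-bees head
print(imagenet_head, binary_head)  # 2049000 4098
```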

Evaluation Function

To assess the performance of our model during training, we define an evaluate function. This function computes the number of correct predictions by comparing the predicted labels with the true labels.

def evaluate(logits, labels):
    _, preds = torch.max(logits, 1)
    corrects = torch.sum(preds == labels).item()
    return corrects
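To show what evaluate computes without needing a tensor library, here is a pure-Python analogue: take the index of the largest logit in each row as the prediction, then count how many predictions match the labels. The name count_correct is hypothetical; ties are resolved to the lowest index here, which is a choice of this sketch rather than a statement about torch.max.

```python
def count_correct(logits, labels):
    """Pure-Python analogue of evaluate(): argmax per row, then count matches.

    logits: list of per-sample score lists, shape [batch][num_classes]
    labels: list of integer class indices, length batch
    """
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same batch size")
    correct = 0
    for row, label in zip(logits, labels):
        pred = max(range(len(row)), key=row.__getitem__)  # index of the largest score
        correct += int(pred == label)
    return correct


# Predictions are [1, 0, 0]; the first two match the labels
print(count_correct([[0.1, 0.9], [2.0, -1.0], [0.3, 0.2]], [1, 0, 1]))  # 2
```

Returning a count rather than an accuracy lets the training loop accumulate correct predictions across batches of different sizes and divide once per epoch.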

Training Loop Function

This function defines the training loop that each worker executes. It downloads the dataset, prepares the data loaders, initializes the model, and runs the training and validation phases. The function is the same as one you would write for GPU training; no changes are needed to port it to HPU. Internally, Ray Train does the following:

  • Detects the HPU and sets the device.

  • Initializes the Habana PyTorch backend.

  • Initializes the Habana distributed backend.

def train_loop_per_worker(configs):
    import warnings

    warnings.filterwarnings("ignore")

    # Calculate the batch size for a single worker
    worker_batch_size = configs["batch_size"] // train.get_context().get_world_size()

    # Download dataset once on local rank 0 worker
    if train.get_context().get_local_rank() == 0:
        download_datasets()
    torch.distributed.barrier()

    # Build datasets on each worker
    torch_datasets = build_datasets()

    # Prepare dataloader for each worker
    dataloaders = dict()
    dataloaders["train"] = DataLoader(
        torch_datasets["train"], batch_size=worker_batch_size, shuffle=True
    )
    dataloaders["val"] = DataLoader(
        torch_datasets["val"], batch_size=worker_batch_size, shuffle=False
    )

    # Distribute
    dataloaders["train"] = train.torch.prepare_data_loader(dataloaders["train"])
    dataloaders["val"] = train.torch.prepare_data_loader(dataloaders["val"])

    # Obtain HPU device automatically
    device = train.torch.get_device()

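    # Prepare the DDP model, optimizer, and loss function.
    # prepare_model moves the model to the worker's device and wraps it in
    # DistributedDataParallel so gradients are synchronized across workers.
    model = initialize_model()
    model = train.torch.prepare_model(model)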

    optimizer = optim.SGD(
        model.parameters(), lr=configs["lr"], momentum=configs["momentum"]
    )
    criterion = nn.CrossEntropyLoss()

    # Start training loops
    for epoch in range(configs["num_epochs"]):
        # Each epoch has a training and validation phase
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                with torch.set_grad_enabled(phase == "train"):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # calculate statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += evaluate(outputs, labels)

            size = len(torch_datasets[phase]) // train.get_context().get_world_size()
            epoch_loss = running_loss / size
            epoch_acc = running_corrects / size

            if train.get_context().get_world_rank() == 0:
                print(
                    "Epoch {}-{} Loss: {:.4f} Acc: {:.4f}".format(
                        epoch, phase, epoch_loss, epoch_acc
                    )
                )

            # Report metrics to Ray Train once per epoch, after validation
            if phase == "val":
                train.report(
                    metrics={"loss": epoch_loss, "acc": epoch_acc},
                )
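The loss and accuracy printed by rank 0 are per-worker estimates. The loop divides by len(dataset) // world_size, while, assuming prepare_data_loader attaches a torch DistributedSampler with its default drop_last=False, each worker actually iterates over ceil(N / world_size) samples because the sampler pads the last shard. The hypothetical helpers below spell out both numbers, along with the per-worker batch size computed at the top of the loop.

```python
import math


def per_worker_batch_size(global_batch_size, world_size):
    """Same integer split as train_loop_per_worker: batch_size // world_size."""
    return global_batch_size // world_size


def samples_per_worker(dataset_len, world_size, drop_last=False):
    """Samples each replica sees under torch's DistributedSampler.

    With drop_last=False the sampler repeats indices so every replica gets
    ceil(N / W) samples; with drop_last=True it drops the tail so every
    replica gets floor(N / W).
    """
    if drop_last:
        return dataset_len // world_size
    return math.ceil(dataset_len / world_size)


# Illustrative numbers only (not the real dataset size): 101 samples, 2 workers
n, world = 101, 2
print(per_worker_batch_size(32, world))    # 16
print(n // world)                          # 50, the divisor used in the loop
print(samples_per_worker(n, world))        # 51, samples actually iterated
```

The gap is at most one sample per worker, so it barely affects the reported numbers, but it explains why per-worker accuracy is only an approximation of accuracy on the full split.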

Main Training Function

The train_resnet function sets up the distributed training environment using Ray and starts the training process. It specifies the batch size, number of epochs, learning rate, and momentum for the SGD optimizer. To enable training using HPU, we only need to make the following changes:

  • Require an HPU for each worker in ScalingConfig

  • Set backend to “hccl” in TorchConfig

def train_resnet(num_workers=2):
    train_loop_config = {
        "input_size": 224,  # Input image size (224 x 224)
        "batch_size": 32,  # Global batch size, split evenly across workers
        "num_epochs": 10,  # Number of epochs to train for
        "lr": 0.001,  # Learning rate
        "momentum": 0.9,  # SGD optimizer momentum
    }
    # Configure computation resources
    # In ScalingConfig, require an HPU for each worker
    scaling_config = ScalingConfig(num_workers=num_workers, resources_per_worker={"CPU": 1, "HPU": 1})
    # Set backend to hccl in TorchConfig
    torch_config = TorchConfig(backend="hccl")

    # ignore_reinit_error lets this cell be re-run in the same notebook session
    ray.init(ignore_reinit_error=True)

    # Initialize a Ray TorchTrainer
    trainer = TorchTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config=train_loop_config,
        torch_config=torch_config,
        scaling_config=scaling_config,
    )

    result = trainer.fit()
    print(f"Training result: {result}")
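Before changing num_workers, it can help to check that the request fits the node and that the global batch size splits evenly. The helper below is a hypothetical sketch that only computes values and does not call Ray: it assumes a single node with 8 HPUs, as in this example, and mirrors the per-worker resources and "hccl" backend used in train_resnet. The "nccl" and "gloo" entries are the standard torch.distributed backends for GPU and CPU, included only for contrast.

```python
# Distributed backend per device type; "hccl" is the Habana backend used above
BACKEND_FOR_DEVICE = {"HPU": "hccl", "GPU": "nccl", "CPU": "gloo"}


def plan_workers(num_workers, global_batch_size=32, device="HPU", devices_per_node=8):
    """Hypothetical helper: validate a worker count before calling train_resnet.

    Returns (resources_per_worker, backend, per_worker_batch_size).
    Assumes a single node with devices_per_node accelerators.
    """
    if device not in BACKEND_FOR_DEVICE:
        raise ValueError(f"Unsupported device type: {device!r}")
    if num_workers < 1:
        raise ValueError("num_workers must be at least 1")
    if device != "CPU" and num_workers > devices_per_node:
        raise ValueError(f"At most {devices_per_node} {device} workers fit on one node")
    if global_batch_size % num_workers:
        raise ValueError("global_batch_size should divide evenly across workers")
    resources = {"CPU": 1}
    if device != "CPU":
        resources[device] = 1  # one accelerator per worker, as in ScalingConfig above
    return resources, BACKEND_FOR_DEVICE[device], global_batch_size // num_workers


print(plan_workers(2))  # ({'CPU': 1, 'HPU': 1}, 'hccl', 16)
```

Failing fast here avoids waiting for Ray to report that the requested HPU resources cannot be scheduled.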

Start Training

Finally, we call the train_resnet function to start the training process. You can adjust the number of workers to use. Before running this cell, ensure that Ray is properly set up in your environment to handle distributed training.

Note: the following warning is harmless and can be safely ignored; it is resolved in SynapseAI version 1.14.0+:

/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py:252: UserWarning: Device capability of hccl unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.

train_resnet(num_workers=2)

Tune Status

Current time: 2024-02-28 07:31:55
Running for: 00:00:55.04
Memory: 389.2/1007.5 GiB

System Info

Using FIFO scheduling algorithm.
Logical resource usage: 3.0/160 CPUs, 0/0 GPUs (2.0/8.0 HPU, 0.0/1.0 TPU)

Trial Status

Trial name                status      loc                iter  total time (s)  loss      acc
TorchTrainer_521db_00000  TERMINATED  172.17.0.3:109080  10    49.3096         0.154648  0.986842
(pid=109080) /usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py:252: UserWarning: Device capability of hccl unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.
(pid=109080)   warnings.warn(
(RayTrainWorker pid=115673) Setting up process group for: env:// [rank=0, world_size=2]
(TorchTrainer pid=109080) Started distributed worker processes: 
(TorchTrainer pid=109080) - (ip=172.17.0.3, pid=115673) world_rank=0, local_rank=0, node_rank=0
(TorchTrainer pid=109080) - (ip=172.17.0.3, pid=115678) world_rank=1, local_rank=1, node_rank=0
(RayTrainWorker pid=115673) /usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py:252: UserWarning: Device capability of hccl unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`. [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
(RayTrainWorker pid=115673)   warnings.warn( [repeated 2x across cluster]
(RayTrainWorker pid=115673) ============================= HABANA PT BRIDGE CONFIGURATION =========================== 
(RayTrainWorker pid=115673)  PT_HPU_LAZY_MODE = 1
(RayTrainWorker pid=115673)  PT_RECIPE_CACHE_PATH = 
(RayTrainWorker pid=115673)  PT_CACHE_FOLDER_DELETE = 0
(RayTrainWorker pid=115673)  PT_HPU_RECIPE_CACHE_CONFIG = 
(RayTrainWorker pid=115673)  PT_HPU_MAX_COMPOUND_OP_SIZE = 9223372036854775807
(RayTrainWorker pid=115673)  PT_HPU_LAZY_ACC_PAR_MODE = 1
(RayTrainWorker pid=115673)  PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES = 0
(RayTrainWorker pid=115673) ---------------------------: System Configuration :---------------------------
(RayTrainWorker pid=115673) Num CPU Cores : 160
(RayTrainWorker pid=115673) CPU RAM       : 1056389756 KB
(RayTrainWorker pid=115673) ------------------------------------------------------------------------------
(RayTrainWorker pid=115673) Epoch 0-train Loss: 0.6667 Acc: 0.6148
(RayTrainWorker pid=115673) Epoch 0-val Loss: 0.5717 Acc: 0.6053
(RayTrainWorker pid=115673) Epoch 1-train Loss: 0.5248 Acc: 0.7295
(RayTrainWorker pid=115673) Epoch 1-val Loss: 0.3194 Acc: 0.9605
(RayTrainWorker pid=115673) Epoch 2-train Loss: 0.3100 Acc: 0.9016
(RayTrainWorker pid=115673) Epoch 2-val Loss: 0.2336 Acc: 0.9474
(RayTrainWorker pid=115673) Epoch 3-train Loss: 0.2391 Acc: 0.9180
(RayTrainWorker pid=115673) Epoch 3-val Loss: 0.1789 Acc: 0.9737
(RayTrainWorker pid=115673) Epoch 4-train Loss: 0.1780 Acc: 0.9508
(RayTrainWorker pid=115673) Epoch 4-val Loss: 0.1696 Acc: 0.9605
(RayTrainWorker pid=115673) Epoch 5-train Loss: 0.1447 Acc: 0.9754
(RayTrainWorker pid=115673) Epoch 5-val Loss: 0.1534 Acc: 0.9737
(RayTrainWorker pid=115673) Epoch 6-train Loss: 0.1398 Acc: 0.9426
(RayTrainWorker pid=115673) Epoch 6-val Loss: 0.1606 Acc: 0.9605
(RayTrainWorker pid=115673) Epoch 7-train Loss: 0.1398 Acc: 0.9590
(RayTrainWorker pid=115673) Epoch 7-val Loss: 0.1582 Acc: 0.9605
(RayTrainWorker pid=115673) Epoch 8-train Loss: 0.0856 Acc: 0.9754
(RayTrainWorker pid=115673) Epoch 8-val Loss: 0.1552 Acc: 0.9605
(RayTrainWorker pid=115673) Epoch 9-train Loss: 0.0602 Acc: 0.9836
(RayTrainWorker pid=115673) Epoch 9-val Loss: 0.1546 Acc: 0.9868
2024-02-28 07:31:55,645	INFO tune.py:1042 -- Total run time: 55.08 seconds (55.04 seconds for the tuning loop).
Training result: Result(
  metrics={'loss': 0.15464812321098229, 'acc': 0.9868421052631579},
  path='/root/ray_results/TorchTrainer_2024-02-28_07-31-00/TorchTrainer_521db_00000_0_2024-02-28_07-31-00',
  filesystem='local',
  checkpoint=None
)