Serve Llama2-7b/70b on one or more Intel Gaudi Accelerators#

Intel Gaudi AI Processors (HPUs) are AI hardware accelerators designed by Intel Habana Labs. See Gaudi Architecture and Gaudi Developer Docs for more details.

This tutorial has two examples:

  1. Deployment of Llama2-7b using a single HPU:

    • Load a model onto an HPU.

    • Perform generation on an HPU.

    • Enable HPU Graph optimizations.

  2. Deployment of Llama2-70b using multiple HPUs on a single node:

    • Initialize a distributed backend.

    • Load a sharded model onto DeepSpeed workers.

    • Stream responses from DeepSpeed workers.

Both examples use Ray Serve to serve a large language model (LLM) on HPUs.

Environment setup#

Use a prebuilt container to run these examples. To run a container, you need Docker. See Install Docker Engine for installation instructions.

Next, follow Run Using Containers to install the Gaudi drivers and container runtime. To verify your installation, start a shell and run hl-smi. It should print status information about the HPUs on the machine:

+-----------------------------------------------------------------------------+
| HL-SMI Version:                              hl-1.14.0-fw-48.0.1.0          |
| Driver Version:                                     1.15.0-c43dc7b          |
|-------------------------------+----------------------+----------------------+
| AIP  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | AIP-Util  Compute M. |
|===============================+======================+======================|
|   0  HL-225              N/A  | 0000:09:00.0     N/A |                   0  |
| N/A   26C   N/A    87W / 600W |    768MiB / 98304MiB |     0%           N/A |
|-------------------------------+----------------------+----------------------+
|   1  HL-225              N/A  | 0000:08:00.0     N/A |                   0  |
| N/A   28C   N/A    99W / 600W |    768MiB / 98304MiB |     0%           N/A |
|-------------------------------+----------------------+----------------------+
|   2  HL-225              N/A  | 0000:0a:00.0     N/A |                   0  |
| N/A   24C   N/A    98W / 600W |    768MiB / 98304MiB |     0%           N/A |
|-------------------------------+----------------------+----------------------+
|   3  HL-225              N/A  | 0000:0c:00.0     N/A |                   0  |
| N/A   27C   N/A    87W / 600W |    768MiB / 98304MiB |     0%           N/A |
|-------------------------------+----------------------+----------------------+
|   4  HL-225              N/A  | 0000:0b:00.0     N/A |                   0  |
| N/A   25C   N/A   112W / 600W |    768MiB / 98304MiB |     0%           N/A |
|-------------------------------+----------------------+----------------------+
|   5  HL-225              N/A  | 0000:0d:00.0     N/A |                   0  |
| N/A   26C   N/A   111W / 600W |  26835MiB / 98304MiB |     0%           N/A |
|-------------------------------+----------------------+----------------------+
|   6  HL-225              N/A  | 0000:0f:00.0     N/A |                   0  |
| N/A   24C   N/A    93W / 600W |    768MiB / 98304MiB |     0%           N/A |
|-------------------------------+----------------------+----------------------+
|   7  HL-225              N/A  | 0000:0e:00.0     N/A |                   0  |
| N/A   25C   N/A    86W / 600W |    768MiB / 98304MiB |     0%           N/A |
|-------------------------------+----------------------+----------------------+
| Compute Processes:                                               AIP Memory |
|  AIP       PID   Type   Process name                             Usage      |
|=============================================================================|
|   0        N/A   N/A    N/A                                      N/A        |
|   1        N/A   N/A    N/A                                      N/A        |
|   2        N/A   N/A    N/A                                      N/A        |
|   3        N/A   N/A    N/A                                      N/A        |
|   4        N/A   N/A    N/A                                      N/A        |
|   5        N/A   N/A    N/A                                      N/A        |
|   6        N/A   N/A    N/A                                      N/A        |
|   7        N/A   N/A    N/A                                      N/A        |
+=============================================================================+

Next, start the Gaudi container:

docker pull vault.habana.ai/gaudi-docker/1.14.0/ubuntu22.04/habanalabs/pytorch-installer-2.1.1:latest
docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.14.0/ubuntu22.04/habanalabs/pytorch-installer-2.1.1:latest

To follow the examples in this tutorial, mount the directory containing the examples and models into the container. Inside the container, run:

pip install "ray[tune,serve]"
pip install git+https://github.com/huggingface/optimum-habana.git
# Replace 1.14.0 with the Habana software (SynapseAI) version of the container.
pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.14.0
# Only needed by the DeepSpeed example.
export RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES=1

Start Ray in the container with ray start --head. You are now ready to run the examples.
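
Optionally, confirm that Ray schedules the HPUs as resources. The following is a minimal sketch, assuming Ray is already running in this container; the printed cluster resources should include an HPU entry, because the deployments below request resources={"HPU": 1}:

import ray

# Connect to the Ray cluster started with `ray start --head`.
ray.init(address="auto")

# On a Gaudi node, the printed dictionary should contain an "HPU" entry,
# for example {"HPU": 8.0, "CPU": ..., ...}.
print(ray.cluster_resources())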

Running a model on a single HPU#

This example shows how to deploy a Llama2-7b model on an HPU for inference.

First, define a deployment that serves a Llama2-7b model using an HPU. Note that we enable HPU graph optimizations for better performance.

import asyncio
from functools import partial
from queue import Empty
from typing import Dict, Any

from starlette.requests import Request
from starlette.responses import StreamingResponse
import torch

from ray import serve


# Define the Ray Serve deployment
@serve.deployment(ray_actor_options={"num_cpus": 10, "resources": {"HPU": 1}})
class LlamaModel:
    def __init__(self, model_id_or_path: str):
        from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
        from optimum.habana.transformers.modeling_utils import (
            adapt_transformers_to_gaudi,
        )

        # Tweak transformers to optimize performance
        adapt_transformers_to_gaudi()

        self.device = torch.device("hpu")

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id_or_path, use_fast=False, use_auth_token=""
        )
        hf_config = AutoConfig.from_pretrained(
            model_id_or_path,
            torchscript=True,
            use_auth_token="",
            trust_remote_code=False,
        )
        # Load the model in Gaudi
        model = AutoModelForCausalLM.from_pretrained(
            model_id_or_path,
            config=hf_config,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            use_auth_token="",
        )
        model = model.eval().to(self.device)

        from habana_frameworks.torch.hpu import wrap_in_hpu_graph

        # Enable hpu graph runtime
        self.model = wrap_in_hpu_graph(model)

        # Set pad token, etc.
        self.tokenizer.pad_token_id = self.model.generation_config.pad_token_id
        self.tokenizer.padding_side = "left"

        # Use async loop in streaming
        self.loop = asyncio.get_running_loop()

    def tokenize(self, prompt: str):
        """Tokenize the input and move to HPU."""

        input_tokens = self.tokenizer(prompt, return_tensors="pt", padding=True)
        return input_tokens.input_ids.to(device=self.device)

    def generate(self, prompt: str, **config: Dict[str, Any]):
        """Take a prompt and generate a response."""

        input_ids = self.tokenize(prompt)
        gen_tokens = self.model.generate(input_ids, **config)
        return self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0]

    async def consume_streamer_async(self, streamer):
        """Consume the streamer asynchronously."""

        while True:
            try:
                for token in streamer:
                    yield token
                break
            except Empty:
                await asyncio.sleep(0.001)

    def streaming_generate(self, prompt: str, streamer, **config: Dict[str, Any]):
        """Generate a streamed response given an input."""

        input_ids = self.tokenize(prompt)
        self.model.generate(input_ids, streamer=streamer, **config)

    async def __call__(self, http_request: Request):
        """Handle HTTP requests."""

        # Load fields from the request
        json_request: Dict[str, Any] = await http_request.json()
        text = json_request["text"]
        # Config used in generation
        config = json_request.get("config", {})
        streaming_response = json_request["stream"]

        # Prepare prompts
        prompts = []
        if isinstance(text, list):
            prompts.extend(text)
        else:
            prompts.append(text)

        # Process config
        config.setdefault("max_new_tokens", 128)

        # Enable HPU graph runtime
        config["hpu_graphs"] = True
        # Lazy mode should be True when using HPU graphs
        config["lazy_mode"] = True

        # Non-streaming case
        if not streaming_response:
            return self.generate(prompts, **config)

        # Streaming case
        from transformers import TextIteratorStreamer

        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, timeout=0, skip_special_tokens=True
        )
        # Convert the streamer into a generator
        self.loop.run_in_executor(
            None, partial(self.streaming_generate, prompts, streamer, **config)
        )
        return StreamingResponse(
            self.consume_streamer_async(streamer),
            status_code=200,
            media_type="text/plain",
        )


# Replace the model ID with a path if necessary.
entrypoint = LlamaModel.bind("meta-llama/Llama-2-7b-chat-hf")

Copy the code above and save it as intel_gaudi_inference_serve.py. Start the deployment like this:

serve run intel_gaudi_inference_serve:entrypoint
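
If you prefer starting the application from Python instead of the Serve CLI, the following is a rough equivalent, assuming intel_gaudi_inference_serve.py is importable from the current directory:

from ray import serve

from intel_gaudi_inference_serve import entrypoint

# Deploys the application and returns once it is running.
serve.run(entrypoint)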

The terminal should print logs as the deployment starts up:

2024-02-01 05:38:34,021 INFO scripts.py:438 -- Running import path: 'ray_serve_7b:entrypoint'.
2024-02-01 05:38:36,112 INFO worker.py:1540 -- Connecting to existing Ray cluster at address: 10.111.128.177:6379...
2024-02-01 05:38:36,124 INFO worker.py:1715 -- Connected to Ray cluster. View the dashboard at 127.0.0.1:8265 
(ProxyActor pid=17179) INFO 2024-02-01 05:38:39,573 proxy 10.111.128.177 proxy.py:1141 - Proxy actor b0c697edb66f42a46f802f4603000000 starting on node 7776cd4634f69216c8354355018195b290314ad24fd9565404a2ed12.
(ProxyActor pid=17179) INFO 2024-02-01 05:38:39,580 proxy 10.111.128.177 proxy.py:1346 - Starting HTTP server on node: 7776cd4634f69216c8354355018195b290314ad24fd9565404a2ed12 listening on port 8000
(ProxyActor pid=17179) INFO:     Started server process [17179]
(ServeController pid=17084) INFO 2024-02-01 05:38:39,677 controller 17084 deployment_state.py:1545 - Deploying new version of deployment LlamaModel in application 'default'. Setting initial target number of replicas to 1.
(ServeController pid=17084) INFO 2024-02-01 05:38:39,780 controller 17084 deployment_state.py:1829 - Adding 1 replica to deployment LlamaModel in application 'default'.
(ServeReplica:default:LlamaModel pid=17272) [WARNING|utils.py:198] 2024-02-01 05:38:48,700 >> optimum-habana v1.11.0.dev0 has been validated for SynapseAI v1.14.0 but the driver version is v1.15.0, this could lead to undefined behavior!
(ServeReplica:default:LlamaModel pid=17272) /usr/local/lib/python3.10/dist-packages/transformers/models/auto/tokenization_auto.py:655: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.
(ServeReplica:default:LlamaModel pid=17272)   warnings.warn(
(ServeReplica:default:LlamaModel pid=17272) /usr/local/lib/python3.10/dist-packages/transformers/models/auto/configuration_auto.py:1020: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.
(ServeReplica:default:LlamaModel pid=17272)   warnings.warn(
(ServeReplica:default:LlamaModel pid=17272) /usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py:472: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.
(ServeReplica:default:LlamaModel pid=17272)   warnings.warn(
Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
Loading checkpoint shards:  50%|█████     | 1/2 [00:17<00:17, 17.90s/it]
(ServeController pid=17084) WARNING 2024-02-01 05:39:09,835 controller 17084 deployment_state.py:2171 - Deployment 'LlamaModel' in application 'default' has 1 replicas that have taken more than 30s to initialize. This may be caused by a slow __init__ or reconfigure method.
Loading checkpoint shards: 100%|██████████| 2/2 [00:24<00:00, 12.36s/it]
(ServeReplica:default:LlamaModel pid=17272) /usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:362: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
(ServeReplica:default:LlamaModel pid=17272)   warnings.warn(
(ServeReplica:default:LlamaModel pid=17272) /usr/local/lib/python3.10/dist-packages/transformers/generation/configuration_utils.py:367: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
(ServeReplica:default:LlamaModel pid=17272)   warnings.warn(
(ServeReplica:default:LlamaModel pid=17272) ============================= HABANA PT BRIDGE CONFIGURATION =========================== 
(ServeReplica:default:LlamaModel pid=17272)  PT_HPU_LAZY_MODE = 1
(ServeReplica:default:LlamaModel pid=17272)  PT_RECIPE_CACHE_PATH = 
(ServeReplica:default:LlamaModel pid=17272)  PT_CACHE_FOLDER_DELETE = 0
(ServeReplica:default:LlamaModel pid=17272)  PT_HPU_RECIPE_CACHE_CONFIG = 
(ServeReplica:default:LlamaModel pid=17272)  PT_HPU_MAX_COMPOUND_OP_SIZE = 9223372036854775807
(ServeReplica:default:LlamaModel pid=17272)  PT_HPU_LAZY_ACC_PAR_MODE = 1
(ServeReplica:default:LlamaModel pid=17272)  PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES = 0
(ServeReplica:default:LlamaModel pid=17272) ---------------------------: System Configuration :---------------------------
(ServeReplica:default:LlamaModel pid=17272) Num CPU Cores : 156
(ServeReplica:default:LlamaModel pid=17272) CPU RAM       : 495094196 KB
(ServeReplica:default:LlamaModel pid=17272) ------------------------------------------------------------------------------
2024-02-01 05:39:25,873 SUCC scripts.py:483 -- Deployed Serve app successfully.

In another shell, run the following code to send generation requests to the deployment.

import requests

# Prompt for the model
prompt = "Once upon a time,"

# Add generation config here
config = {}

# Non-streaming response
sample_input = {"text": prompt, "config": config, "stream": False}
outputs = requests.post("http://127.0.0.1:8000/", json=sample_input, stream=False)
print(outputs.text, flush=True)

# Streaming response
sample_input["stream"] = True
outputs = requests.post("http://127.0.0.1:8000/", json=sample_input, stream=True)
outputs.raise_for_status()
for output in outputs.iter_content(chunk_size=None, decode_unicode=True):
    print(output, end="", flush=True)
print()

Here is an example output:

Once upon a time, in a far-off land, there was a magical kingdom called "Happily Ever Laughter." It was a place where laughter was the key to unlocking all the joys of life, and where everyone lived in perfect harmony.
In this kingdom, there was a beautiful princess named Lily. She was kind, gentle, and had a heart full of laughter. Every day, she would wake up with a smile on her face, ready to face whatever adventures the day might bring.
One day, a wicked sorcerer cast a spell on the kingdom, causing all
in a far-off land, there was a magical kingdom called "Happily Ever Laughter." It was a place where laughter was the key to unlocking all the joys of life, and where everyone lived in perfect harmony.
In this kingdom, there was a beautiful princess named Lily. She was kind, gentle, and had a heart full of laughter. Every day, she would wake up with a smile on her face, ready to face whatever adventures the day might bring.
One day, a wicked sorcerer cast a spell on the kingdom, causing all
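
The config field in the request is forwarded to model.generate on the server, so you can pass standard Hugging Face generation arguments through it. For example, a hypothetical sampling configuration might look like this:

import requests

# Hypothetical sampling settings, forwarded as keyword arguments to `model.generate`.
config = {
    "do_sample": True,
    "temperature": 0.7,
    "top_p": 0.9,
    "max_new_tokens": 64,  # Overrides the server-side default of 128.
}

sample_input = {"text": "Once upon a time,", "config": config, "stream": False}
print(requests.post("http://127.0.0.1:8000/", json=sample_input).text)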

Running a sharded model on multiple HPUs#

This example deploys a Llama2-70b model using 8 HPUs orchestrated by DeepSpeed.

This example requires the Llama2-70b model weights to be cached locally. Run the following Python code in the Gaudi container to download and cache the model.

import os

from huggingface_hub import snapshot_download

snapshot_download(
    "meta-llama/Llama-2-70b-chat-hf",
    # Replace the path if necessary.
    cache_dir=os.getenv("TRANSFORMERS_CACHE", None),
    # Specify your Hugging Face token.
    token="",
)

In this example, the deployment replica sends prompts to the DeepSpeed workers, which are running in Ray actors:

import tempfile
from typing import Any, Dict, Optional
from starlette.requests import Request
from starlette.responses import StreamingResponse

import torch
from transformers import TextStreamer

import ray
from ray import serve
from ray.util.queue import Queue
from ray.runtime_env import RuntimeEnv


@ray.remote(resources={"HPU": 1})
class DeepSpeedInferenceWorker:
    def __init__(self, model_id_or_path: str, world_size: int, local_rank: int):
        """An actor that runs a DeepSpeed inference engine.

        Arguments:
            model_id_or_path: Either a Hugging Face model ID
                or a path to a cached model.
            world_size: Total number of worker processes.
            local_rank: Rank of this worker process.
                The rank 0 worker is the head worker.
        """
        from transformers import AutoTokenizer, AutoConfig
        from optimum.habana.transformers.modeling_utils import (
            adapt_transformers_to_gaudi,
        )

        # Tweak transformers for better performance on Gaudi.
        adapt_transformers_to_gaudi()

        self.model_id_or_path = model_id_or_path
        self._world_size = world_size
        self._local_rank = local_rank
        self.device = torch.device("hpu")

        self.model_config = AutoConfig.from_pretrained(
            model_id_or_path,
            torch_dtype=torch.bfloat16,
            token="",
            trust_remote_code=False,
        )

        # Load and configure the tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id_or_path, use_fast=False, token=""
        )
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        import habana_frameworks.torch.distributed.hccl as hccl

        # Initialize the distributed backend.
        hccl.initialize_distributed_hpu(
            world_size=world_size, rank=local_rank, local_rank=local_rank
        )
        torch.distributed.init_process_group(backend="hccl")

    def load_model(self):
        """Load the model to HPU and initialize the DeepSpeed inference engine."""

        import deepspeed
        from transformers import AutoModelForCausalLM
        from optimum.habana.checkpoint_utils import (
            get_ds_injection_policy,
            write_checkpoints_json,
        )

        # Construct the model with fake meta Tensors.
        # Loads the model weights from the checkpoint later.
        with deepspeed.OnDevice(dtype=torch.bfloat16, device="meta"):
            model = AutoModelForCausalLM.from_config(
                self.model_config, torch_dtype=torch.bfloat16
            )
        model = model.eval()

        # Create a file to indicate where the checkpoint is.
        checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="w+")
        write_checkpoints_json(
            self.model_id_or_path, self._local_rank, checkpoints_json, token=""
        )

        # Prepare the DeepSpeed inference configuration.
        kwargs = {"dtype": torch.bfloat16}
        kwargs["checkpoint"] = checkpoints_json.name
        kwargs["tensor_parallel"] = {"tp_size": self._world_size}
        # Enable HPU Graphs (analogous to CUDA Graphs).
        kwargs["enable_cuda_graph"] = True
        # Specify the injection policy, required by DeepSpeed Tensor parallelism.
        kwargs["injection_policy"] = get_ds_injection_policy(self.model_config)

        # Initialize the inference engine.
        self.model = deepspeed.init_inference(model, **kwargs).module

    def tokenize(self, prompt: str):
        """Tokenize the input and move it to HPU."""

        input_tokens = self.tokenizer(prompt, return_tensors="pt", padding=True)
        return input_tokens.input_ids.to(device=self.device)

    def generate(self, prompt: str, **config: Dict[str, Any]):
        """Take in a prompt and generate a response."""

        input_ids = self.tokenize(prompt)
        gen_tokens = self.model.generate(input_ids, **config)
        return self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0]

    def streaming_generate(self, prompt: str, streamer, **config: Dict[str, Any]):
        """Generate a streamed response given an input."""

        input_ids = self.tokenize(prompt)
        self.model.generate(input_ids, streamer=streamer, **config)

    def get_streamer(self):
        """Return a streamer.

        We only need the rank 0 worker's result.
        Other workers return a fake streamer.
        """

        if self._local_rank == 0:
            return RayTextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
        else:

            class FakeStreamer:
                def put(self, value):
                    pass

                def end(self):
                    pass

            return FakeStreamer()


class RayTextIteratorStreamer(TextStreamer):
    def __init__(
        self,
        tokenizer,
        skip_prompt: bool = False,
        timeout: Optional[int] = None,
        **decode_kwargs: Dict[str, Any],
    ):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.text_queue = Queue()
        self.stop_signal = None
        self.timeout = timeout

    def on_finalized_text(self, text: str, stream_end: bool = False):
        self.text_queue.put(text, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.text_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        else:
            return value


Next, define a deployment:

# We need to set these variables for this example.
HABANA_ENVS = {
    "PT_HPU_LAZY_ACC_PAR_MODE": "0",
    "PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES": "0",
    "PT_HPU_ENABLE_WEIGHT_CPU_PERMUTE": "0",
    "PT_HPU_ENABLE_LAZY_COLLECTIVES": "true",
    "HABANA_VISIBLE_MODULES": "0,1,2,3,4,5,6,7",
}


# Define the Ray Serve deployment.
@serve.deployment
class DeepSpeedLlamaModel:
    def __init__(self, world_size: int, model_id_or_path: str):
        self._world_size = world_size

        # Create the DeepSpeed workers
        self.deepspeed_workers = []
        for i in range(world_size):
            self.deepspeed_workers.append(
                DeepSpeedInferenceWorker.options(
                    runtime_env=RuntimeEnv(env_vars=HABANA_ENVS)
                ).remote(model_id_or_path, world_size, i)
            )

        # Load the model to all workers.
        for worker in self.deepspeed_workers:
            worker.load_model.remote()

        # Get the workers' streamers.
        self.streamers = ray.get(
            [worker.get_streamer.remote() for worker in self.deepspeed_workers]
        )

    def generate(self, prompt: str, **config: Dict[str, Any]):
        """Send the prompt to workers for generation.

        Return after all workers finish the generation.
        Only return the rank 0 worker's result.
        """

        futures = [
            worker.generate.remote(prompt, **config)
            for worker in self.deepspeed_workers
        ]
        return ray.get(futures)[0]

    def streaming_generate(self, prompt: str, **config: Dict[str, Any]):
        """Send the prompt to workers for streaming generation.

        Only use the rank 0 worker's result.
        """

        for worker, streamer in zip(self.deepspeed_workers, self.streamers):
            worker.streaming_generate.remote(prompt, streamer, **config)

    def consume_streamer(self, streamer):
        """Consume the streamer and return a generator."""
        for token in streamer:
            yield token

    async def __call__(self, http_request: Request):
        """Handle received HTTP requests."""

        # Load fields from the request
        json_request: Dict[str, Any] = await http_request.json()
        text = json_request["text"]
        # Config used in generation
        config = json_request.get("config", {})
        streaming_response = json_request["stream"]

        # Prepare prompts
        prompts = []
        if isinstance(text, list):
            prompts.extend(text)
        else:
            prompts.append(text)

        # Process the configuration.
        config.setdefault("max_new_tokens", 128)

        # Enable HPU graph runtime.
        config["hpu_graphs"] = True
        # Lazy mode should be True when using HPU graphs.
        config["lazy_mode"] = True

        # Non-streaming case
        if not streaming_response:
            return self.generate(prompts, **config)

        # Streaming case
        self.streaming_generate(prompts, **config)
        return StreamingResponse(
            self.consume_streamer(self.streamers[0]),
            status_code=200,
            media_type="text/plain",
        )


# Replace the model ID with a path if necessary.
entrypoint = DeepSpeedLlamaModel.bind(8, "meta-llama/Llama-2-70b-chat-hf")

Copy both blocks of the preceding code and save them into intel_gaudi_inference_serve_deepspeed.py. Run this example using serve run intel_gaudi_inference_serve_deepspeed:entrypoint.

The terminal should print logs as the deployment starts up:

2024-02-01 06:08:51,170 INFO scripts.py:438 -- Running import path: 'deepspeed_demo:entrypoint'.
2024-02-01 06:08:54,143 INFO worker.py:1540 -- Connecting to existing Ray cluster at address: 10.111.128.177:6379...
2024-02-01 06:08:54,154 INFO worker.py:1715 -- Connected to Ray cluster. View the dashboard at 127.0.0.1:8265 
(ServeController pid=44317) INFO 2024-02-01 06:08:54,348 controller 44317 deployment_state.py:1545 - Deploying new version of deployment DeepSpeedLlamaModel in application 'default'. Setting initial target number of replicas to 1.
(ServeController pid=44317) INFO 2024-02-01 06:08:54,457 controller 44317 deployment_state.py:1708 - Stopping 1 replicas of deployment 'DeepSpeedLlamaModel' in application 'default' with outdated versions.
(ServeController pid=44317) INFO 2024-02-01 06:08:57,326 controller 44317 deployment_state.py:2187 - Replica default#DeepSpeedLlamaModel#ToJmHV is stopped.
(ServeController pid=44317) INFO 2024-02-01 06:08:57,327 controller 44317 deployment_state.py:1829 - Adding 1 replica to deployment DeepSpeedLlamaModel in application 'default'.
(DeepSpeedInferenceWorker pid=48021) [WARNING|utils.py:198] 2024-02-01 06:09:12,355 >> optimum-habana v1.11.0.dev0 has been validated for SynapseAI v1.14.0 but the driver version is v1.15.0, this could lead to undefined behavior!
(DeepSpeedInferenceWorker pid=48016) /usr/local/lib/python3.10/dist-packages/habana_frameworks/torch/hpu/__init__.py:158: UserWarning: torch.hpu.setDeterministic is deprecated and will be removed in next release. Please use torch.use_deterministic_algorithms instead.
(DeepSpeedInferenceWorker pid=48016)   warnings.warn(
(DeepSpeedInferenceWorker pid=48019) [2024-02-01 06:09:14,005] [INFO] [real_accelerator.py:178:get_accelerator] Setting ds_accelerator to hpu (auto detect)
(DeepSpeedInferenceWorker pid=48019) [2024-02-01 06:09:16,908] [INFO] [logging.py:96:log_dist] [Rank -1] DeepSpeed info: version=0.12.4+hpu.synapse.v1.14.0, git-hash=fad45b2, git-branch=1.14.0
(DeepSpeedInferenceWorker pid=48019) [2024-02-01 06:09:16,910] [INFO] [logging.py:96:log_dist] [Rank -1] quantize_bits = 8 mlp_extra_grouping = False, quantize_groups = 1
Loading 15 checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]
(DeepSpeedInferenceWorker pid=48019) [2024-02-01 06:09:16,955] [WARNING] [comm.py:163:init_deepspeed_backend] HCCL backend in DeepSpeed not yet implemented
(DeepSpeedInferenceWorker pid=48019) [2024-02-01 06:09:16,955] [INFO] [comm.py:637:init_distributed] cdb=None
(DeepSpeedInferenceWorker pid=48018) [WARNING|utils.py:198] 2024-02-01 06:09:13,528 >> optimum-habana v1.11.0.dev0 has been validated for SynapseAI v1.14.0 but the driver version is v1.15.0, this could lead to undefined behavior! [repeated 7x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
(ServeController pid=44317) WARNING 2024-02-01 06:09:27,403 controller 44317 deployment_state.py:2171 - Deployment 'DeepSpeedLlamaModel' in application 'default' has 1 replicas that have taken more than 30s to initialize. This may be caused by a slow __init__ or reconfigure method.
(DeepSpeedInferenceWorker pid=48018) /usr/local/lib/python3.10/dist-packages/habana_frameworks/torch/hpu/__init__.py:158: UserWarning: torch.hpu.setDeterministic is deprecated and will be removed in next release. Please use torch.use_deterministic_algorithms instead. [repeated 7x across cluster]
(DeepSpeedInferenceWorker pid=48018)   warnings.warn( [repeated 7x across cluster]
Loading 15 checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s] [repeated 7x across cluster]
(ServeController pid=44317) WARNING 2024-02-01 06:09:57,475 controller 44317 deployment_state.py:2171 - Deployment 'DeepSpeedLlamaModel' in application 'default' has 1 replicas that have taken more than 30s to initialize. This may be caused by a slow __init__ or reconfigure method.
Loading 15 checkpoint shards:   7%|▋         | 1/15 [00:52<12:15, 52.53s/it]
(DeepSpeedInferenceWorker pid=48014) ============================= HABANA PT BRIDGE CONFIGURATION =========================== 
(DeepSpeedInferenceWorker pid=48014)  PT_HPU_LAZY_MODE = 1
(DeepSpeedInferenceWorker pid=48014)  PT_RECIPE_CACHE_PATH = 
(DeepSpeedInferenceWorker pid=48014)  PT_CACHE_FOLDER_DELETE = 0
(DeepSpeedInferenceWorker pid=48014)  PT_HPU_RECIPE_CACHE_CONFIG = 
(DeepSpeedInferenceWorker pid=48014)  PT_HPU_MAX_COMPOUND_OP_SIZE = 9223372036854775807
(DeepSpeedInferenceWorker pid=48014)  PT_HPU_LAZY_ACC_PAR_MODE = 0
(DeepSpeedInferenceWorker pid=48014)  PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES = 0
(DeepSpeedInferenceWorker pid=48014) ---------------------------: System Configuration :---------------------------
(DeepSpeedInferenceWorker pid=48014) Num CPU Cores : 156
(DeepSpeedInferenceWorker pid=48014) CPU RAM       : 495094196 KB
(DeepSpeedInferenceWorker pid=48014) ------------------------------------------------------------------------------
Loading 15 checkpoint shards:   7%|▋         | 1/15 [00:57<13:28, 57.75s/it] [repeated 2x across cluster]
(ServeController pid=44317) WARNING 2024-02-01 06:10:27,504 controller 44317 deployment_state.py:2171 - Deployment 'DeepSpeedLlamaModel' in application 'default' has 1 replicas that have taken more than 30s to initialize. This may be caused by a slow __init__ or reconfigure method.
Loading 15 checkpoint shards:   7%|▋         | 1/15 [00:58<13:42, 58.75s/it] [repeated 5x across cluster]
Loading 15 checkpoint shards:  13%|█▎        | 2/15 [01:15<07:21, 33.98s/it]
Loading 15 checkpoint shards:  13%|█▎        | 2/15 [01:16<07:31, 34.70s/it]
Loading 15 checkpoint shards:  20%|██        | 3/15 [01:35<05:34, 27.88s/it] [repeated 7x across cluster]
(ServeController pid=44317) WARNING 2024-02-01 06:10:57,547 controller 44317 deployment_state.py:2171 - Deployment 'DeepSpeedLlamaModel' in application 'default' has 1 replicas that have taken more than 30s to initialize. This may be caused by a slow __init__ or reconfigure method.
Loading 15 checkpoint shards:  27%|██▋       | 4/15 [01:53<04:24, 24.03s/it] [repeated 8x across cluster]
(ServeController pid=44317) WARNING 2024-02-01 06:11:27,625 controller 44317 deployment_state.py:2171 - Deployment 'DeepSpeedLlamaModel' in application 'default' has 1 replicas that have taken more than 30s to initialize. This may be caused by a slow __init__ or reconfigure method.
Loading 15 checkpoint shards:  27%|██▋       | 4/15 [01:54<04:21, 23.79s/it] [repeated 7x across cluster]
Loading 15 checkpoint shards:  40%|████      | 6/15 [02:30<03:06, 20.76s/it] [repeated 9x across cluster]
(ServeController pid=44317) WARNING 2024-02-01 06:11:57,657 controller 44317 deployment_state.py:2171 - Deployment 'DeepSpeedLlamaModel' in application 'default' has 1 replicas that have taken more than 30s to initialize. This may be caused by a slow __init__ or reconfigure method.
Loading 15 checkpoint shards:  40%|████      | 6/15 [02:29<03:05, 20.61s/it] [repeated 7x across cluster]
Loading 15 checkpoint shards:  47%|████▋     | 7/15 [02:47<02:39, 19.88s/it]
Loading 15 checkpoint shards:  47%|████▋     | 7/15 [02:48<02:39, 19.90s/it]
Loading 15 checkpoint shards:  53%|█████▎    | 8/15 [03:06<02:17, 19.60s/it] [repeated 7x across cluster]
(ServeController pid=44317) WARNING 2024-02-01 06:12:27,721 controller 44317 deployment_state.py:2171 - Deployment 'DeepSpeedLlamaModel' in application 'default' has 1 replicas that have taken more than 30s to initialize. This may be caused by a slow __init__ or reconfigure method.
Loading 15 checkpoint shards:  60%|██████    | 9/15 [03:26<01:56, 19.46s/it] [repeated 8x across cluster]
(ServeController pid=44317) WARNING 2024-02-01 06:12:57,725 controller 44317 deployment_state.py:2171 - Deployment 'DeepSpeedLlamaModel' in application 'default' has 1 replicas that have taken more than 30s to initialize. This may be caused by a slow __init__ or reconfigure method.
Loading 15 checkpoint shards:  67%|██████▋   | 10/15 [03:27<01:09, 13.80s/it] [repeated 15x across cluster]
Loading 15 checkpoint shards:  73%|███████▎  | 11/15 [03:46<01:00, 15.14s/it]
Loading 15 checkpoint shards:  73%|███████▎  | 11/15 [03:45<01:00, 15.15s/it]
Loading 15 checkpoint shards:  80%|████████  | 12/15 [04:05<00:49, 16.47s/it] [repeated 7x across cluster]
(ServeController pid=44317) WARNING 2024-02-01 06:13:27,770 controller 44317 deployment_state.py:2171 - Deployment 'DeepSpeedLlamaModel' in application 'default' has 1 replicas that have taken more than 30s to initialize. This may be caused by a slow __init__ or reconfigure method.
Loading 15 checkpoint shards:  87%|████████▋ | 13/15 [04:24<00:34, 17.26s/it] [repeated 8x across cluster]
(ServeController pid=44317) WARNING 2024-02-01 06:13:57,873 controller 44317 deployment_state.py:2171 - Deployment 'DeepSpeedLlamaModel' in application 'default' has 1 replicas that have taken more than 30s to initialize. This may be caused by a slow __init__ or reconfigure method.
Loading 15 checkpoint shards:  87%|████████▋ | 13/15 [04:25<00:34, 17.35s/it] [repeated 7x across cluster]
Loading 15 checkpoint shards:  93%|█████████▎| 14/15 [04:44<00:17, 17.55s/it]
Loading 15 checkpoint shards: 100%|██████████| 15/15 [05:02<00:00, 18.30s/it] [repeated 8x across cluster]
2024-02-01 06:14:24,054 SUCC scripts.py:483 -- Deployed Serve app successfully.

Use the same code snippet introduced in the single HPU example to send generation requests. Here’s an example output:

Once upon a time, there was a young woman named Sophia who lived in a small village nestled in the rolling hills of Tuscany. Sophia was a curious and adventurous soul, always eager to explore the world around her. One day, while wandering through the village, she stumbled upon a hidden path she had never seen before.
The path was overgrown with weeds and vines, and it looked as though it hadn't been traversed in years. But Sophia was intrigued, and she decided to follow it to see where it led. She pushed aside the branches and stepped onto the path
Once upon a time, there was a young woman named Sophia who lived in a small village nestled in the rolling hills of Tuscany. Sophia was a curious and adventurous soul, always eager to explore the world around her. One day, while wandering through the village, she stumbled upon a hidden path she had never seen before.
The path was overgrown with weeds and vines, and it looked as though it hadn't been traversed in years. But Sophia was intrigued, and she decided to follow it to see where it led. She pushed aside the branches and stepped onto the path
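
When you are done experimenting, you can remove the Serve application and its replicas. The following is a minimal cleanup sketch, assuming you run it inside the container that hosts the Ray cluster:

import ray
from ray import serve

# Connect to the existing Ray cluster and shut down all Serve applications.
ray.init(address="auto")
serve.shutdown()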

Next Steps#

See llm-on-ray for more ways to customize and deploy LLMs at scale.