Serve Llama2-7b/70b on one or more Intel Gaudi accelerators#

Intel Gaudi AI Processors (HPUs) are AI hardware accelerators designed by Intel Habana Labs. See Gaudi Architecture and Gaudi Developer Docs for more details.

This tutorial has two examples:

  1. Deployment of Llama2-7b using a single HPU:

    • Load a model onto an HPU.

    • Perform generation on an HPU.

    • Enable HPU Graph optimizations.

  2. Deployment of Llama2-70b using multiple HPUs on a single node:

    • Initialize a distributed backend.

    • Load a sharded model onto DeepSpeed workers.

    • Stream responses from DeepSpeed workers.

This tutorial serves a large language model (LLM) on HPUs.

Environment setup#

Use a prebuilt container to run these examples. To run a container, you need Docker. See Install Docker Engine for installation instructions.

Next, follow Run Using Containers to install the Gaudi drivers and container runtime. To verify your installation, start a shell and run hl-smi. It should print status information about the HPUs on the machine:

+-----------------------------------------------------------------------------+
| HL-SMI Version:                              hl-1.20.0-fw-58.1.1.1          |
| Driver Version:                                     1.19.1-6f47ddd          |
| Nic Driver Version:                                 1.19.1-f071c23          |
|-------------------------------+----------------------+----------------------+
| AIP  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncor-Events|
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | AIP-Util  Compute M. |
|===============================+======================+======================|
|   0  HL-225              N/A  | 0000:9a:00.0     N/A |                   0  |
| N/A   22C   N/A  96W /  600W  |   768MiB /  98304MiB |     0%            0% |
|-------------------------------+----------------------+----------------------+
|   1  HL-225              N/A  | 0000:9b:00.0     N/A |                   0  |
| N/A   24C   N/A  78W /  600W  |   768MiB /  98304MiB |     0%            0% |
|-------------------------------+----------------------+----------------------+
|   2  HL-225              N/A  | 0000:b3:00.0     N/A |                   0  |
| N/A   25C   N/A  81W /  600W  |   768MiB /  98304MiB |     0%            0% |
|-------------------------------+----------------------+----------------------+
|   3  HL-225              N/A  | 0000:b4:00.0     N/A |                   0  |
| N/A   22C   N/A  92W /  600W  | 96565MiB /  98304MiB |     0%           98% |
|-------------------------------+----------------------+----------------------+
|   4  HL-225              N/A  | 0000:33:00.0     N/A |                   0  |
| N/A   22C   N/A  83W /  600W  |   768MiB /  98304MiB |     0%            0% |
|-------------------------------+----------------------+----------------------+
|   5  HL-225              N/A  | 0000:4e:00.0     N/A |                   0  |
| N/A   21C   N/A  80W /  600W  | 96564MiB /  98304MiB |     0%           98% |
|-------------------------------+----------------------+----------------------+
|   6  HL-225              N/A  | 0000:34:00.0     N/A |                   0  |
| N/A   25C   N/A  86W /  600W  |   768MiB /  98304MiB |     0%            0% |
|-------------------------------+----------------------+----------------------+
|   7  HL-225              N/A  | 0000:4d:00.0     N/A |                   0  |
| N/A   30C   N/A 100W /  600W  | 17538MiB /  98304MiB |     0%           17% |
|-------------------------------+----------------------+----------------------+
| Compute Processes:                                               AIP Memory |
|  AIP       PID   Type   Process name                             Usage      |
|=============================================================================|
|   0        N/A   N/A    N/A                                      N/A        |
|   1        N/A   N/A    N/A                                      N/A        |
|   2        N/A   N/A    N/A                                      N/A        |
|   3        N/A   N/A    N/A                                      N/A        |
|   4        N/A   N/A    N/A                                      N/A        |
|   5        N/A   N/A    N/A                                      N/A        |
|   6        N/A   N/A    N/A                                      N/A        |
|   7       107684     C   ray::_RayTrainW                         16770MiB   |
+=============================================================================+
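If you want to check programmatically which HPUs are idle before starting a deployment, the table above can be parsed with a few regular expressions. The following is a minimal sketch, assuming the row layout shown in this tutorial; the helper name parse_hl_smi and the 50% "busy" threshold are illustrative choices, and other hl-smi versions may format the table differently.

```python
import re
from typing import Dict, List

# Matches a device header row such as "|   0  HL-225   N/A  | 0000:9a:00.0 ...".
DEVICE_ROW = re.compile(r"^\|\s+(\d+)\s+(HL-\d+)\s")
# Matches the memory column such as "768MiB /  98304MiB".
MEMORY_COL = re.compile(r"(\d+)MiB\s*/\s*(\d+)MiB")


def parse_hl_smi(text: str) -> List[Dict[str, int]]:
    """Return one record per HPU with its index and memory usage in MiB."""
    devices = []
    pending_index = None
    for line in text.splitlines():
        header = DEVICE_ROW.match(line)
        if header:
            # Remember the device index; its memory is on the next row.
            pending_index = int(header.group(1))
            continue
        memory = MEMORY_COL.search(line)
        if memory and pending_index is not None:
            devices.append(
                {
                    "index": pending_index,
                    "used_mib": int(memory.group(1)),
                    "total_mib": int(memory.group(2)),
                }
            )
            pending_index = None
    return devices


# Two devices copied from the sample output above.
sample = """\
|   0  HL-225              N/A  | 0000:9a:00.0     N/A |                   0  |
| N/A   22C   N/A  96W /  600W  |   768MiB /  98304MiB |     0%            0% |
|   3  HL-225              N/A  | 0000:b4:00.0     N/A |                   0  |
| N/A   22C   N/A  92W /  600W  | 96565MiB /  98304MiB |     0%           98% |
"""
devices = parse_hl_smi(sample)
# Treat a device as busy when more than half of its memory is in use
# (an arbitrary threshold chosen for this sketch).
idle = [d["index"] for d in devices if d["used_mib"] < d["total_mib"] / 2]
print(idle)
```

On a real machine, the text could come from something like subprocess.run(["hl-smi"], capture_output=True, text=True).stdout instead of the hard-coded sample.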

Next, start the Gaudi container:

docker pull vault.habana.ai/gaudi-docker/1.20.0/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest
docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.20.0/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest

To follow the examples in this tutorial, mount the directory containing the examples and models into the container, for example by adding a -v <host_dir>:<container_dir> option to the docker run command. Inside the container, run:

pip install "ray[tune,serve]"
pip install git+https://github.com/huggingface/optimum-habana.git
# Replace 1.20.0 with the driver version of the container.
pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.20.0
# Only needed by the DeepSpeed example.
export RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES=1

Start Ray in the container with ray start --head. You are now ready to run the examples.
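Before running the examples, it can help to confirm that the packages installed above are importable and that the environment variable is set in the current shell. The sketch below uses only the standard library; the helper name check_modules is hypothetical, and the module names are assumed from the import statements used later in this tutorial.

```python
import importlib.util
import os
from typing import Dict, Iterable


def check_modules(modules: Iterable[str]) -> Dict[str, bool]:
    """Map each module name to True if it can be located, False otherwise."""
    status = {}
    for name in modules:
        try:
            status[name] = importlib.util.find_spec(name) is not None
        except ImportError:
            # Raised when a parent package (for example "optimum") is missing.
            status[name] = False
    return status


# Module names taken from the imports in the examples below.
status = check_modules(["ray", "optimum.habana", "deepspeed"])
for name, found in status.items():
    print(f"{name}: {'found' if found else 'MISSING'}")

# Only the DeepSpeed example needs this variable.
noset = os.environ.get("RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES")
print("RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES =", noset)
```

A missing entry usually means the corresponding pip install step did not run in this container.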

Running a model on a single HPU#

This example shows how to deploy a Llama2-7b model on an HPU for inference.

First, define a deployment that serves a Llama2-7b model using an HPU. Note that we enable HPU graph optimizations for better performance.

import asyncio
from functools import partial
from queue import Empty
from typing import Dict, Any

from starlette.requests import Request
from starlette.responses import StreamingResponse
import torch

from ray import serve


# Define the Ray Serve deployment
@serve.deployment(ray_actor_options={"num_cpus": 10, "resources": {"HPU": 1}})
class LlamaModel:
    def __init__(self, model_id_or_path: str):
        from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
        from optimum.habana.transformers.modeling_utils import (
            adapt_transformers_to_gaudi,
        )

        # Tweak transformers to optimize performance
        adapt_transformers_to_gaudi()

        self.device = torch.device("hpu")

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id_or_path, use_fast=False, use_auth_token=""
        )
        hf_config = AutoConfig.from_pretrained(
            model_id_or_path,
            torchscript=True,
            use_auth_token="",
            trust_remote_code=False,
        )
        # Load the model in Gaudi
        model = AutoModelForCausalLM.from_pretrained(
            model_id_or_path,
            config=hf_config,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            use_auth_token="",
        )
        model = model.eval().to(self.device)

        from habana_frameworks.torch.hpu import wrap_in_hpu_graph

        # Enable hpu graph runtime
        self.model = wrap_in_hpu_graph(model)

        # Set pad token, etc.
        self.tokenizer.pad_token_id = self.model.generation_config.pad_token_id
        self.tokenizer.padding_side = "left"

        # Use async loop in streaming
        self.loop = asyncio.get_running_loop()

    def tokenize(self, prompt: str):
        """Tokenize the input and move to HPU."""

        input_tokens = self.tokenizer(prompt, return_tensors="pt", padding=True)
        return input_tokens.input_ids.to(device=self.device)

    def generate(self, prompt: str, **config: Dict[str, Any]):
        """Take a prompt and generate a response."""

        input_ids = self.tokenize(prompt)
        gen_tokens = self.model.generate(input_ids, **config)
        return self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0]

    async def consume_streamer_async(self, streamer):
        """Consume the streamer asynchronously."""

        while True:
            try:
                for token in streamer:
                    yield token
                break
            except Empty:
                await asyncio.sleep(0.001)

    def streaming_generate(self, prompt: str, streamer, **config: Dict[str, Any]):
        """Generate a streamed response given an input."""

        input_ids = self.tokenize(prompt)
        self.model.generate(input_ids, streamer=streamer, **config)

    async def __call__(self, http_request: Request):
        """Handle HTTP requests."""

        # Load fields from the request
        json_request: Dict[str, Any] = await http_request.json()
        text = json_request["text"]
        # Config used in generation
        config = json_request.get("config", {})
        streaming_response = json_request["stream"]

        # Prepare prompts
        prompts = []
        if isinstance(text, list):
            prompts.extend(text)
        else:
            prompts.append(text)

        # Process config
        config.setdefault("max_new_tokens", 128)

        # Enable HPU graph runtime
        config["hpu_graphs"] = True
        # Lazy mode should be True when using HPU graphs
        config["lazy_mode"] = True

        # Non-streaming case
        if not streaming_response:
            return self.generate(prompts, **config)

        # Streaming case
        from transformers import TextIteratorStreamer

        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, timeout=0, skip_special_tokens=True
        )
        # Convert the streamer into a generator
        self.loop.run_in_executor(
            None, partial(self.streaming_generate, prompts, streamer, **config)
        )
        return StreamingResponse(
            self.consume_streamer_async(streamer),
            status_code=200,
            media_type="text/plain",
        )


# Replace the model ID with path if necessary
entrypoint = LlamaModel.bind("meta-llama/Llama-2-7b-chat-hf")

Copy the code above and save it as intel_gaudi_inference_serve.py. Start the deployment like this:

serve run intel_gaudi_inference_serve:entrypoint

The terminal should print logs as the deployment starts up:

2025-03-03 06:07:08,106 INFO scripts.py:494 -- Running import path: 'infer:entrypoint'.
2025-03-03 06:07:09,295 INFO worker.py:1654 -- Connecting to existing Ray cluster at address: 100.83.111.228:6379...
2025-03-03 06:07:09,304 INFO worker.py:1832 -- Connected to Ray cluster. View the dashboard at 127.0.0.1:8265 
(ProxyActor pid=147082) INFO 2025-03-03 06:07:11,096 proxy 100.83.111.228 -- Proxy starting on node b4d028b67678bfdd190b503b44780bc319c07b1df13ac5c577873861 (HTTP port: 8000).
INFO 2025-03-03 06:07:11,202 serve 162730 -- Started Serve in namespace "serve".
INFO 2025-03-03 06:07:11,203 serve 162730 -- Connecting to existing Serve app in namespace "serve". New http options will not be applied.
(ProxyActor pid=147082) INFO 2025-03-03 06:07:11,184 proxy 100.83.111.228 -- Got updated endpoints: {}.
(ServeController pid=147087) INFO 2025-03-03 06:07:11,278 controller 147087 -- Deploying new version of Deployment(name='LlamaModel', app='default') (initial target replicas: 1).
(ProxyActor pid=147082) INFO 2025-03-03 06:07:11,280 proxy 100.83.111.228 -- Got updated endpoints: {Deployment(name='LlamaModel', app='default'): EndpointInfo(route='/', app_is_cross_language=False)}.
(ProxyActor pid=147082) INFO 2025-03-03 06:07:11,286 proxy 100.83.111.228 -- Started <ray.serve._private.router.SharedRouterLongPollClient object at 0x7f74804e90c0>.
(ServeController pid=147087) INFO 2025-03-03 06:07:11,381 controller 147087 -- Adding 1 replica to Deployment(name='LlamaModel', app='default').
(ServeReplica:default:LlamaModel pid=147085) [WARNING|utils.py:212] 2025-03-03 06:07:15,251 >> optimum-habana v1.15.0 has been validated for SynapseAI v1.19.0 but habana-frameworks v1.20.0.543 was found, this could lead to undefined behavior!
(ServeReplica:default:LlamaModel pid=147085) /usr/local/lib/python3.10/dist-packages/transformers/deepspeed.py:24: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations
(ServeReplica:default:LlamaModel pid=147085)   warnings.warn(
(ServeReplica:default:LlamaModel pid=147085) /usr/local/lib/python3.10/dist-packages/transformers/models/auto/tokenization_auto.py:796: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.
(ServeReplica:default:LlamaModel pid=147085)   warnings.warn(
(ServeReplica:default:LlamaModel pid=147085) /usr/local/lib/python3.10/dist-packages/transformers/models/auto/configuration_auto.py:991: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.
(ServeReplica:default:LlamaModel pid=147085)   warnings.warn(
(ServeReplica:default:LlamaModel pid=147085) /usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py:471: FutureWarning: The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.
(ServeReplica:default:LlamaModel pid=147085)   warnings.warn(
Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
Loading checkpoint shards:  50%|█████     | 1/2 [00:01<00:01,  1.72s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.45s/it]
(ServeReplica:default:LlamaModel pid=147085) ============================= HABANA PT BRIDGE CONFIGURATION =========================== 
(ServeReplica:default:LlamaModel pid=147085)  PT_HPU_LAZY_MODE = 1
(ServeReplica:default:LlamaModel pid=147085)  PT_HPU_RECIPE_CACHE_CONFIG = ,false,1024
(ServeReplica:default:LlamaModel pid=147085)  PT_HPU_MAX_COMPOUND_OP_SIZE = 9223372036854775807
(ServeReplica:default:LlamaModel pid=147085)  PT_HPU_LAZY_ACC_PAR_MODE = 1
(ServeReplica:default:LlamaModel pid=147085)  PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES = 0
(ServeReplica:default:LlamaModel pid=147085)  PT_HPU_EAGER_PIPELINE_ENABLE = 1
(ServeReplica:default:LlamaModel pid=147085)  PT_HPU_EAGER_COLLECTIVE_PIPELINE_ENABLE = 1
(ServeReplica:default:LlamaModel pid=147085)  PT_HPU_ENABLE_LAZY_COLLECTIVES = 0
(ServeReplica:default:LlamaModel pid=147085) ---------------------------: System Configuration :---------------------------
(ServeReplica:default:LlamaModel pid=147085) Num CPU Cores : 160
(ServeReplica:default:LlamaModel pid=147085) CPU RAM       : 1056374420 KB
(ServeReplica:default:LlamaModel pid=147085) ------------------------------------------------------------------------------
INFO 2025-03-03 06:07:30,359 serve 162730 -- Application 'default' is ready at http://127.0.0.1:8000/.
INFO 2025-03-03 06:07:30,359 serve 162730 -- Deployed app 'default' successfully.
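The streaming branch of the deployment hands generation to an executor thread and polls the streamer from an async generator, so the event loop stays free while tokens are produced. The stdlib-only sketch below imitates that hand-off. PollingStreamer and produce are hypothetical stand-ins for TextIteratorStreamer (created with timeout=0, so reads raise queue.Empty instead of blocking) and for model.generate; no model or HPU is involved.

```python
import asyncio
import queue
import threading
import time


class PollingStreamer:
    """Hypothetical stand-in for a streamer created with timeout=0."""

    _END = object()

    def __init__(self):
        self._queue = queue.Queue()

    def put(self, text):
        self._queue.put(text)

    def end(self):
        self._queue.put(self._END)

    def __iter__(self):
        return self

    def __next__(self):
        # Non-blocking read: raises queue.Empty when no token is ready yet.
        item = self._queue.get_nowait()
        if item is self._END:
            raise StopIteration
        return item


def produce(streamer, tokens):
    """Simulate a slow generate() call running in a worker thread."""
    for token in tokens:
        time.sleep(0.01)
        streamer.put(token)
    streamer.end()


async def consume(streamer):
    """Yield tokens as they arrive without blocking the event loop."""
    while True:
        try:
            for token in streamer:
                yield token
            break
        except queue.Empty:
            # Nothing ready yet: give control back to the event loop.
            await asyncio.sleep(0.001)


async def main():
    streamer = PollingStreamer()
    loop = asyncio.get_running_loop()
    loop.run_in_executor(None, produce, streamer, ["Once", " upon", " a", " time"])
    return [token async for token in consume(streamer)]


result = asyncio.run(main())
print("".join(result))
```

The short sleep in consume is a trade-off: a smaller value lowers latency per token, while a larger one reduces how often the loop wakes up with nothing to do.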

In another shell, use the following code to send requests to the deployment to perform generation tasks.

import requests

# Prompt for the model
prompt = "Once upon a time,"

# Add generation config here
config = {}

# Non-streaming response
sample_input = {"text": prompt, "config": config, "stream": False}
outputs = requests.post("http://127.0.0.1:8000/", json=sample_input, stream=False)
print(outputs.text, flush=True)

# Streaming response
sample_input["stream"] = True
outputs = requests.post("http://127.0.0.1:8000/", json=sample_input, stream=True)
outputs.raise_for_status()
for output in outputs.iter_content(chunk_size=None, decode_unicode=True):
    print(output, end="", flush=True)
print()

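Here is an example output. The first response is the non-streaming one, which includes the prompt and is cut off at the max_new_tokens limit. The second is the streamed response, which omits the prompt because the streamer is created with skip_prompt=True: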

Once upon a time, in a small village nestled in the rolling hills of Tuscany, there lived a young girl named Sophia.

Sophia was a curious and adventurous child, always eager to explore the world around her. She spent her days playing in the fields and forests, chasing after butterflies and watching the clouds drift lazily across the sky.

One day, as Sophia was wandering through the village, she stumbled upon a beautiful old book hidden away in a dusty corner of the local library. The book was bound in worn leather and adorned with intr
in a small village nestled in the rolling hills of Tuscany, there lived a young girl named Luna.
Luna was a curious and adventurous child, always eager to explore the world around her. She spent her days wandering through the village, discovering new sights and sounds at every turn.

One day, as she was wandering through the village, Luna stumbled upon a hidden path she had never seen before. The path was overgrown with weeds and vines, and it seemed to disappear into the distance.

Luna's curiosity was piqued,
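The request body used above has three fields: text, config, and stream. As a rough summary of how the deployment's __call__ method treats them, the following sketch reproduces the same field handling as a pure function. The name normalize_request is hypothetical; the defaults mirror the deployment code in this section.

```python
from typing import Any, Dict, List, Tuple


def normalize_request(body: Dict[str, Any]) -> Tuple[List[str], Dict[str, Any], bool]:
    """Turn a JSON request body into (prompts, generation config, stream flag)."""
    text = body["text"]
    # Accept either a single prompt or a list of prompts.
    prompts = list(text) if isinstance(text, list) else [text]

    # Copy the config so the caller's dictionary is not modified.
    config = dict(body.get("config", {}))
    config.setdefault("max_new_tokens", 128)
    # The deployment always enables HPU graphs, which require lazy mode.
    config["hpu_graphs"] = True
    config["lazy_mode"] = True

    return prompts, config, bool(body["stream"])


prompts, config, stream = normalize_request(
    {"text": "Once upon a time,", "config": {"max_new_tokens": 64}, "stream": False}
)
print(prompts, config, stream)
```

Note that text and stream are required, while config is optional; a request without stream raises a KeyError in this sketch, as it would in the deployment.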

Running a sharded model on multiple HPUs#

This example deploys a Llama2-70b model using 8 HPUs orchestrated by DeepSpeed.
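A back-of-the-envelope calculation suggests why the 70b model is sharded. The sketch below estimates the memory needed for the weights alone and compares it with the 98304 MiB per device reported by hl-smi earlier. The parameter counts are rounded from the model names, and activations and the KV cache are ignored, so treat the numbers as rough lower bounds rather than measurements.

```python
MIB = 2**20

# Device memory reported by hl-smi in the environment setup section.
HPU_MEMORY_MIB = 98304


def weight_footprint_mib(num_params: float, bytes_per_param: int, tp_size: int = 1) -> float:
    """Approximate per-device memory for model weights only, in MiB."""
    return num_params * bytes_per_param / tp_size / MIB


# Parameter counts are rounded from the model names (assumption).
llama_7b_fp32 = weight_footprint_mib(7e9, bytes_per_param=4)
llama_70b_bf16 = weight_footprint_mib(70e9, bytes_per_param=2)
llama_70b_bf16_tp8 = weight_footprint_mib(70e9, bytes_per_param=2, tp_size=8)

for label, mib in [
    ("7b, float32, 1 HPU", llama_7b_fp32),
    ("70b, bfloat16, 1 HPU", llama_70b_bf16),
    ("70b, bfloat16, 8 HPUs", llama_70b_bf16_tp8),
]:
    verdict = "fits" if mib < HPU_MEMORY_MIB else "does not fit"
    print(f"{label}: {mib:,.0f} MiB per device ({verdict})")
```

Under these assumptions, the bfloat16 weights of the 70b model need roughly 130 GiB in total, which exceeds a single device, while an eight-way tensor-parallel split leaves ample headroom on each HPU.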

The example requires caching the Llama2-70b model. Run the following Python code in the Gaudi container to cache the model.

import os

from huggingface_hub import snapshot_download
snapshot_download(
    "meta-llama/Llama-2-70b-chat-hf",
    # Replace the path if necessary.
    cache_dir=os.getenv("TRANSFORMERS_CACHE", None),
    # Specify your Hugging Face token.
    token=""
)

In this example, the deployment replica sends prompts to the DeepSpeed workers, which are running in Ray actors:

import tempfile
from typing import Dict, Any
from starlette.requests import Request
from starlette.responses import StreamingResponse

import torch
from transformers import TextStreamer

import ray
from ray import serve
from ray.util.queue import Queue
from ray.runtime_env import RuntimeEnv


@ray.remote(resources={"HPU": 1})
class DeepSpeedInferenceWorker:
    def __init__(self, model_id_or_path: str, world_size: int, local_rank: int):
        """An actor that runs a DeepSpeed inference engine.

        Arguments:
            model_id_or_path: Either a Hugging Face model ID
                or a path to a cached model.
            world_size: Total number of worker processes.
            local_rank: Rank of this worker process.
                The rank 0 worker is the head worker.
        """
        from transformers import AutoTokenizer, AutoConfig
        from optimum.habana.transformers.modeling_utils import (
            adapt_transformers_to_gaudi,
        )

        # Tweak transformers for better performance on Gaudi.
        adapt_transformers_to_gaudi()

        self.model_id_or_path = model_id_or_path
        self._world_size = world_size
        self._local_rank = local_rank
        self.device = torch.device("hpu")

        self.model_config = AutoConfig.from_pretrained(
            model_id_or_path,
            torch_dtype=torch.bfloat16,
            token="",
            trust_remote_code=False,
        )

        # Load and configure the tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id_or_path, use_fast=False, token=""
        )
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        import habana_frameworks.torch.distributed.hccl as hccl

        # Initialize the distributed backend.
        hccl.initialize_distributed_hpu(
            world_size=world_size, rank=local_rank, local_rank=local_rank
        )
        torch.distributed.init_process_group(backend="hccl")

    def load_model(self):
        """Load the model to HPU and initialize the DeepSpeed inference engine."""

        import deepspeed
        from transformers import AutoModelForCausalLM
        from optimum.habana.checkpoint_utils import (
            get_ds_injection_policy,
            write_checkpoints_json,
        )

        # Construct the model with fake meta Tensors.
        # Loads the model weights from the checkpoint later.
        with deepspeed.OnDevice(dtype=torch.bfloat16, device="meta"):
            model = AutoModelForCausalLM.from_config(
                self.model_config, torch_dtype=torch.bfloat16
            )
        model = model.eval()

        # Create a file to indicate where the checkpoint is.
        checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="w+")
        write_checkpoints_json(
            self.model_id_or_path, self._local_rank, checkpoints_json, token=""
        )

        # Prepare the DeepSpeed inference configuration.
        kwargs = {"dtype": torch.bfloat16}
        kwargs["checkpoint"] = checkpoints_json.name
        kwargs["tensor_parallel"] = {"tp_size": self._world_size}
        # Enable HPU graphs (analogous to CUDA graphs).
        kwargs["enable_cuda_graph"] = True
        # Specify the injection policy, required by DeepSpeed Tensor parallelism.
        kwargs["injection_policy"] = get_ds_injection_policy(self.model_config)

        # Initialize the inference engine.
        self.model = deepspeed.init_inference(model, **kwargs).module

    def tokenize(self, prompt: str):
        """Tokenize the input and move it to HPU."""

        input_tokens = self.tokenizer(prompt, return_tensors="pt", padding=True)
        return input_tokens.input_ids.to(device=self.device)

    def generate(self, prompt: str, **config: Dict[str, Any]):
        """Take in a prompt and generate a response."""

        input_ids = self.tokenize(prompt)
        gen_tokens = self.model.generate(input_ids, **config)
        return self.tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0]

    def streaming_generate(self, prompt: str, streamer, **config: Dict[str, Any]):
        """Generate a streamed response given an input."""

        input_ids = self.tokenize(prompt)
        self.model.generate(input_ids, streamer=streamer, **config)

    def get_streamer(self):
        """Return a streamer.

        We only need the rank 0 worker's result.
        Other workers return a fake streamer.
        """

        if self._local_rank == 0:
            return RayTextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
        else:

            class FakeStreamer:
                def put(self, value):
                    pass

                def end(self):
                    pass

            return FakeStreamer()


class RayTextIteratorStreamer(TextStreamer):
    def __init__(
        self,
        tokenizer,
        skip_prompt: bool = False,
        timeout: int = None,
        **decode_kwargs: Dict[str, Any],
    ):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.text_queue = Queue()
        self.stop_signal = None
        self.timeout = timeout

    def on_finalized_text(self, text: str, stream_end: bool = False):
        self.text_queue.put(text, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.text_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        else:
            return value


Next, define a deployment:

# We need to set these variables for this example.
HABANA_ENVS = {
    "PT_HPU_LAZY_ACC_PAR_MODE": "0",
    "PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES": "0",
    "PT_HPU_ENABLE_WEIGHT_CPU_PERMUTE": "0",
    "PT_HPU_ENABLE_LAZY_COLLECTIVES": "true",
    "HABANA_VISIBLE_MODULES": "0,1,2,3,4,5,6,7",
}


# Define the Ray Serve deployment.
@serve.deployment
class DeepSpeedLlamaModel:
    def __init__(self, world_size: int, model_id_or_path: str):
        self._world_size = world_size

        # Create the DeepSpeed workers
        self.deepspeed_workers = []
        for i in range(world_size):
            self.deepspeed_workers.append(
                DeepSpeedInferenceWorker.options(
                    runtime_env=RuntimeEnv(env_vars=HABANA_ENVS)
                ).remote(model_id_or_path, world_size, i)
            )

        # Load the model to all workers.
        for worker in self.deepspeed_workers:
            worker.load_model.remote()

        # Get the workers' streamers.
        self.streamers = ray.get(
            [worker.get_streamer.remote() for worker in self.deepspeed_workers]
        )

    def generate(self, prompt: str, **config: Dict[str, Any]):
        """Send the prompt to workers for generation.

        Return after all workers finish the generation.
        Only return the rank 0 worker's result.
        """

        futures = [
            worker.generate.remote(prompt, **config)
            for worker in self.deepspeed_workers
        ]
        return ray.get(futures)[0]

    def streaming_generate(self, prompt: str, **config: Dict[str, Any]):
        """Send the prompt to workers for streaming generation.

        Only use the rank 0 worker's result.
        """

        for worker, streamer in zip(self.deepspeed_workers, self.streamers):
            worker.streaming_generate.remote(prompt, streamer, **config)

    def consume_streamer(self, streamer):
        """Consume the streamer and return a generator."""
        for token in streamer:
            yield token

    async def __call__(self, http_request: Request):
        """Handle received HTTP requests."""

        # Load fields from the request
        json_request: Dict[str, Any] = await http_request.json()
        text = json_request["text"]
        # Config used in generation
        config = json_request.get("config", {})
        streaming_response = json_request["stream"]

        # Prepare prompts
        prompts = []
        if isinstance(text, list):
            prompts.extend(text)
        else:
            prompts.append(text)

        # Process the configuration.
        config.setdefault("max_new_tokens", 128)

        # Enable HPU graph runtime.
        config["hpu_graphs"] = True
        # Lazy mode should be True when using HPU graphs.
        config["lazy_mode"] = True

        # Non-streaming case
        if not streaming_response:
            return self.generate(prompts, **config)

        # Streaming case
        self.streaming_generate(prompts, **config)
        return StreamingResponse(
            self.consume_streamer(self.streamers[0]),
            status_code=200,
            media_type="text/plain",
        )


# Replace the model ID with a path if necessary.
entrypoint = DeepSpeedLlamaModel.bind(8, "meta-llama/Llama-2-70b-chat-hf")

Copy both blocks of the preceding code and save them into intel_gaudi_inference_serve_deepspeed.py. Run this example using serve run intel_gaudi_inference_serve_deepspeed:entrypoint.

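Note: Set the HABANA_VISIBLE_MODULES environment variable carefully. The value in HABANA_ENVS, "0,1,2,3,4,5,6,7", assumes that all eight HPU modules on the node are free for this deployment; change it to match the modules you intend to use.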
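One way to keep HABANA_VISIBLE_MODULES consistent with the number of workers is to derive both from a single list of module IDs. The following is a minimal sketch; build_habana_envs is a hypothetical helper, the other variables are copied from HABANA_ENVS above, and deciding which module IDs are actually free is left to you.

```python
from typing import Dict, Sequence


def build_habana_envs(module_ids: Sequence[int]) -> Dict[str, str]:
    """Build the worker environment for the given HPU module IDs."""
    if not module_ids:
        raise ValueError("at least one module ID is required")
    if len(set(module_ids)) != len(module_ids):
        raise ValueError("module IDs must be unique")
    if any(module_id < 0 for module_id in module_ids):
        raise ValueError("module IDs must be non-negative")
    return {
        "PT_HPU_LAZY_ACC_PAR_MODE": "0",
        "PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES": "0",
        "PT_HPU_ENABLE_WEIGHT_CPU_PERMUTE": "0",
        "PT_HPU_ENABLE_LAZY_COLLECTIVES": "true",
        # Comma-separated list, in the same format as the example above.
        "HABANA_VISIBLE_MODULES": ",".join(str(i) for i in module_ids),
    }


module_ids = list(range(8))
envs = build_habana_envs(module_ids)
world_size = len(module_ids)
print(world_size, envs["HABANA_VISIBLE_MODULES"])
```

With this approach, the world size passed to DeepSpeedLlamaModel.bind and the list of visible modules cannot drift apart, since both come from module_ids.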

The terminal should print logs as the deployment starts up:

2025-03-03 06:21:57,692 INFO scripts.py:494 -- Running import path: 'infer-ds:entrypoint'.
2025-03-03 06:22:03,064 INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265 
INFO 2025-03-03 06:22:07,343 serve 170212 -- Started Serve in namespace "serve".
INFO 2025-03-03 06:22:07,343 serve 170212 -- Connecting to existing Serve app in namespace "serve". New http options will not be applied.
(ServeController pid=170719) INFO 2025-03-03 06:22:07,377 controller 170719 -- Deploying new version of Deployment(name='DeepSpeedLlamaModel', app='default') (initial target replicas: 1).
(ProxyActor pid=170723) INFO 2025-03-03 06:22:07,290 proxy 100.83.111.228 -- Proxy starting on node 47721c925467a877497e66104328bb72dc7bd7f900a63b2f1fdb48b2 (HTTP port: 8000).
(ProxyActor pid=170723) INFO 2025-03-03 06:22:07,325 proxy 100.83.111.228 -- Got updated endpoints: {}.
(ProxyActor pid=170723) INFO 2025-03-03 06:22:07,379 proxy 100.83.111.228 -- Got updated endpoints: {Deployment(name='DeepSpeedLlamaModel', app='default'): EndpointInfo(route='/', app_is_cross_language=False)}.
(ServeController pid=170719) INFO 2025-03-03 06:22:07,478 controller 170719 -- Adding 1 replica to Deployment(name='DeepSpeedLlamaModel', app='default').
(ProxyActor pid=170723) INFO 2025-03-03 06:22:07,422 proxy 100.83.111.228 -- Started <ray.serve._private.router.SharedRouterLongPollClient object at 0x7fa557945210>.
(DeepSpeedInferenceWorker pid=179962) [WARNING|utils.py:212] 2025-03-03 06:22:14,611 >> optimum-habana v1.15.0 has been validated for SynapseAI v1.19.0 but habana-frameworks v1.20.0.543 was found, this could lead to undefined behavior!
(DeepSpeedInferenceWorker pid=179963) /usr/local/lib/python3.10/dist-packages/transformers/deepspeed.py:24: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations
(DeepSpeedInferenceWorker pid=179963)   warnings.warn(
(DeepSpeedInferenceWorker pid=179964) [WARNING|utils.py:212] 2025-03-03 06:22:14,613 >> optimum-habana v1.15.0 has been validated for SynapseAI v1.19.0 but habana-frameworks v1.20.0.543 was found, this could lead to undefined behavior! [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(DeepSpeedInferenceWorker pid=179962) [2025-03-03 06:22:23,502] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to hpu (auto detect)
Loading 2 checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
(DeepSpeedInferenceWorker pid=179962) [2025-03-03 06:22:24,032] [INFO] [logging.py:105:log_dist] [Rank -1] DeepSpeed info: version=0.16.1+hpu.synapse.v1.20.0, git-hash=61543a96, git-branch=1.20.0
(DeepSpeedInferenceWorker pid=179962) [2025-03-03 06:22:24,035] [INFO] [logging.py:105:log_dist] [Rank -1] quantize_bits = 8 mlp_extra_grouping = False, quantize_groups = 1
(DeepSpeedInferenceWorker pid=179962) [2025-03-03 06:22:24,048] [INFO] [comm.py:652:init_distributed] cdb=None
(DeepSpeedInferenceWorker pid=179963) ============================= HABANA PT BRIDGE CONFIGURATION =========================== 
(DeepSpeedInferenceWorker pid=179963)  PT_HPU_LAZY_MODE = 1
(DeepSpeedInferenceWorker pid=179963)  PT_HPU_RECIPE_CACHE_CONFIG = ,false,1024
(DeepSpeedInferenceWorker pid=179963)  PT_HPU_MAX_COMPOUND_OP_SIZE = 9223372036854775807
(DeepSpeedInferenceWorker pid=179963)  PT_HPU_LAZY_ACC_PAR_MODE = 0
(DeepSpeedInferenceWorker pid=179963)  PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES = 0
(DeepSpeedInferenceWorker pid=179963)  PT_HPU_EAGER_PIPELINE_ENABLE = 1
(DeepSpeedInferenceWorker pid=179963)  PT_HPU_EAGER_COLLECTIVE_PIPELINE_ENABLE = 1
(DeepSpeedInferenceWorker pid=179963)  PT_HPU_ENABLE_LAZY_COLLECTIVES = 1
(DeepSpeedInferenceWorker pid=179963) ---------------------------: System Configuration :---------------------------
(DeepSpeedInferenceWorker pid=179963) Num CPU Cores : 160
(DeepSpeedInferenceWorker pid=179963) CPU RAM       : 1056374420 KB
(DeepSpeedInferenceWorker pid=179963) ------------------------------------------------------------------------------
(DeepSpeedInferenceWorker pid=179964) /usr/local/lib/python3.10/dist-packages/transformers/deepspeed.py:24: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations [repeated 3x across cluster]
(DeepSpeedInferenceWorker pid=179964)   warnings.warn( [repeated 3x across cluster]
Loading 2 checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s] [repeated 3x across cluster]
(ServeController pid=170719) WARNING 2025-03-03 06:22:37,562 controller 170719 -- Deployment 'DeepSpeedLlamaModel' in application 'default' has 1 replicas that have taken more than 30s to initialize.
(ServeController pid=170719) This may be caused by a slow __init__ or reconfigure method.
Loading 2 checkpoint shards:  50%|█████     | 1/2 [00:17<00:17, 17.51s/it]
Loading 2 checkpoint shards: 100%|██████████| 2/2 [00:21<00:00,  9.57s/it]
Loading 2 checkpoint shards: 100%|██████████| 2/2 [00:21<00:00, 10.88s/it]
Loading 2 checkpoint shards:  50%|█████     | 1/2 [00:18<00:18, 18.70s/it] [repeated 3x across cluster]
INFO 2025-03-03 06:22:48,569 serve 170212 -- Application 'default' is ready at http://127.0.0.1:8000/.
INFO 2025-03-03 06:22:48,569 serve 170212 -- Deployed app 'default' successfully.
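The deployment sends every prompt to all workers and returns only the rank 0 result, because with tensor parallelism each rank holds a slice of the model and has to take part in every forward pass. The sketch below illustrates that fan-out with threads instead of Ray actors. fake_worker_generate is a hypothetical stand-in for DeepSpeedInferenceWorker.generate, and no model is loaded.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List


def fake_worker_generate(rank: int, prompt: str) -> str:
    """Hypothetical stand-in for DeepSpeedInferenceWorker.generate."""
    return f"[rank {rank}] {prompt} ..."


def generate_on_all_ranks(prompt: str, world_size: int) -> str:
    """Run every rank, wait for all of them, and return rank 0's output."""
    with ThreadPoolExecutor(max_workers=world_size) as pool:
        futures = [
            pool.submit(fake_worker_generate, rank, prompt)
            for rank in range(world_size)
        ]
        # Collect in rank order so index 0 is always the rank 0 worker.
        results: List[str] = [future.result() for future in futures]
    return results[0]


output = generate_on_all_ranks("Once upon a time,", world_size=8)
print(output)
```

Waiting for all ranks before returning mirrors the ray.get(futures)[0] call in the deployment: the request is complete only when every worker has finished its share of the computation.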

Use the same code snippet introduced in the single HPU example to send generation requests. Here’s an example output:

Once upon a time, in a far-off land, there was a magical kingdom called "Happily Ever Laughter." It was a place where laughter was the key to unlocking all the joys of life, and where everyone lived in perfect harmony.

In this kingdom, there was a beautiful princess named Lily. She was kind, gentle, and had a heart full of laughter. Every day, she would wake up with a big smile on her face, ready to face whatever adventures the day might bring.

One day, a wicked sorcerer cast a spell on the kingdom
Once upon a time, in a far-off land, there was a magical kingdom called "Happily Ever Laughter." It was a place where laughter was the key to unlocking all the joys of life, and where everyone lived in perfect harmony.

In this kingdom, there was a beautiful princess named Lily. She was kind, gentle, and had a heart full of laughter. Every day, she would wake up with a big smile on her face, ready to face whatever adventures the day might bring.

One day, a wicked sorcerer cast a spell on the kingdom

Next Steps#

See llm-on-ray for more ways to customize and deploy LLMs at scale.