Stable Diffusion Batch Prediction with Ray Data

In this example, we’ll showcase how to use Ray Data for Stable Diffusion batch inference. Stable Diffusion is a text-to-image latent diffusion model created by the researchers and engineers from CompVis, Stability AI, and LAION. It’s trained on 512x512 images from a subset of the LAION-5B database. LAION-5B is the largest, freely accessible multi-modal dataset that currently exists. For more information on Stable Diffusion, see Stable Diffusion with 🧨 Diffusers.

We’ll use Ray Data and a pretrained model from the Hugging Face Hub. You can easily adapt this example to use other similar models.

We recommend reading Ray Train Key Concepts and Ray Data Quickstart before starting this example.

Note

To run this example, make sure your Ray cluster has access to at least one GPU with 16 GB or more of memory. The exact amount of memory needed depends on the model.

model_id = "stabilityai/stable-diffusion-2-1"
prompt = "a photo of an astronaut riding a horse on mars"
import ray

We define a runtime environment to ensure that the Ray workers have access to all the necessary packages. You can omit the runtime_env argument if you have all of the packages already installed on each node in your cluster.

ray.init(
    runtime_env={
        "pip": [
            "accelerate>=0.16.0",
            "transformers>=4.26.0",
            "diffusers>=0.13.1",
            "xformers>=0.0.16",
            "torch<2",
        ]
    }
)
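
Optionally, you can sanity-check that the cluster actually exposes at least one GPU before going further. A minimal sketch using ray.cluster_resources():

# Confirm the cluster reports at least one GPU before running inference.
num_gpus = ray.cluster_resources().get("GPU", 0)
assert num_gpus >= 1, f"Expected at least 1 GPU in the cluster, found {num_gpus}"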

For the purposes of this example, we’ll use a very small toy dataset composed of multiple copies of our prompt. Ray Data can handle much bigger datasets with ease.

import ray.data
import pandas as pd

ds = ray.data.from_pandas(pd.DataFrame([prompt] * 4, columns=["prompt"]))
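
If your prompts live in a file instead, you could build the dataset with one of Ray Data’s read APIs. A sketch, assuming a hypothetical prompts.txt with one prompt per line:

# Alternative to the toy dataset above: read prompts from a text file.
# ray.data.read_text yields one row per line, with the line under a "text" key.
ds = ray.data.read_text("prompts.txt").map(lambda row: {"prompt": row["text"]})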

Since we will be using a pretrained model from the Hugging Face Hub, the simplest approach is to use map_batches with a callable class UDF. This lets us initialize the model just once per worker and then feed it multiple batches of data, instead of reloading the model for every batch.

from typing import Optional


class PredictCallable:
    def __init__(self, model_id: str, revision: Optional[str] = None):
        from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

        # Use xformers for better memory usage
        from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
        import torch

        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id, revision=revision, torch_dtype=torch.float16
        )
        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipe.scheduler.config
        )
        self.pipe.enable_xformers_memory_efficient_attention(
            attention_op=MemoryEfficientAttentionFlashAttentionOp
        )
        # Workaround: Flash Attention doesn't support the VAE's attention shape,
        # so let xformers choose the attention op for the VAE instead.
        self.pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)
        self.pipe = self.pipe.to("cuda")

    def __call__(self, batch: pd.DataFrame) -> dict:
        import torch
        import numpy as np

        # Use a different seed for every image in the batch so that
        # identical prompts don't produce identical images.
        generators = [
            torch.Generator(device="cuda").manual_seed(i) for i in range(len(batch))
        ]
        images = self.pipe(list(batch["prompt"]), generator=generators).images
        return {"images": np.array(images, dtype=object)}

All that’s left is to run the map_batches method on the dataset. We specify num_gpus=1 so that each Ray actor running our callable class gets a dedicated GPU.

Tip

If you have access to large GPUs, you may want to increase the batch size to better saturate them.

preds = ds.map_batches(
    PredictCallable,
    batch_size=1,
    fn_constructor_kwargs=dict(model_id=model_id),
    concurrency=1,
    batch_format="pandas",
    num_gpus=1,
)
results = preds.take_all()
2023-02-28 10:38:32,723	INFO bulk_executor.py:41 -- Executing DAG InputDataBuffer[Input] -> ActorPoolMapOperator[MapBatches(PredictCallable)]
MapBatches(PredictCallable), 0 actors [0 locality hits, 1 misses]: 100%|██████████| 1/1 [01:46<00:00, 106.33s/it]
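
The run above uses a single actor and batch_size=1. On a cluster with several GPUs you could raise both numbers; the values below are hypothetical, so tune batch_size to what fits in your GPU memory:

# Hypothetical scaled-up run: 4 model replicas, one GPU each,
# each processing 8 prompts per batch.
preds = ds.map_batches(
    PredictCallable,
    batch_size=8,
    fn_constructor_kwargs=dict(model_id=model_id),
    concurrency=4,
    batch_format="pandas",
    num_gpus=1,
)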

After map_batches is done, we can view our images.

results[0]["images"]

[generated image: an astronaut riding a horse on Mars]

results[1]["images"]

[generated image: an astronaut riding a horse on Mars]
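
Instead of only displaying the images inline, you may want to persist them. A minimal sketch, assuming each row's "images" entry is the PIL.Image object the pipeline returns by default:

import os

os.makedirs("outputs", exist_ok=True)
for i, row in enumerate(results):
    # Each row holds one generated PIL.Image; save it as a PNG.
    row["images"].save(f"outputs/astronaut_{i}.png")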