ray.util.tpu.SlicePlacementGroup#

class ray.util.tpu.SlicePlacementGroup(topology: str, accelerator_version: str, strategy: str = 'SPREAD', name: str = '', lifetime: str | None = None, num_slices=1)[source]#

A handle to a placement group reservation for a TPU slice.

The following definitions are added for clarity:

  • Accelerator type: A string describing the accelerator type and version (e.g. TPU-V2, TPU-V6E).

  • Accelerator version: The accelerator generation only (e.g. v6e, v5p, v5litepod).

  • Pod type: The TPU accelerator version and the number of chips in a topology (e.g. v6e-128, v5p-8); see the sketch after this list.

  • Accelerator topology: The physical topology describing the arrangement of chips in the slice (e.g. 2x2x2, 16x16).
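
Following these definitions, a pod type can be derived from an accelerator version and a topology string. The pod_type helper below is purely illustrative and is not part of this API:

    import math

    def pod_type(accelerator_version: str, topology: str) -> str:
        """Illustrative only: combine an accelerator version and topology into a pod type."""
        # A topology such as "4x4" describes 4 * 4 = 16 chips.
        chips = math.prod(int(dim) for dim in topology.split("x"))
        return f"{accelerator_version}-{chips}"

    print(pod_type("v6e", "4x4"))  # v6e-16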

    Args:

    topology: The TPU topology string (e.g. “2x2x2”).

    accelerator_version: The TPU accelerator generation (e.g. “v6e”, “v5p”, “v4”).

    strategy: PlacementGroup parameter. The strategy used to create the placement group. Defaults to “SPREAD”.

    • “PACK”: Packs Bundles into as few nodes as possible.

    • “SPREAD”: Places Bundles across distinct nodes as evenly as possible.

    • “STRICT_PACK”: Packs Bundles into one node. The group is not allowed to span multiple nodes.

    • “STRICT_SPREAD”: Places Bundles across distinct nodes; each bundle must be scheduled on a different node.

    lifetime: PlacementGroup parameter. Either None, in which case the placement group fate-shares with its creator and is deleted once its creator dies, or “detached”, which means the placement group will live as a global object independent of the creator (see the sketch below).

    num_slices: Number of TPU slices in the SlicePlacementGroup. Defaults to 1 when unspecified.
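
For example, a multi-slice, detached reservation can be requested by combining these parameters. This is a minimal sketch; the topology, version, and slice count are illustrative and must match resources available in your cluster:

    from ray.util.tpu import SlicePlacementGroup

    # Reserve two v5p slices with a 2x2x2 topology each. "detached" keeps the
    # reservation alive independently of the creating driver or actor.
    multi_slice = SlicePlacementGroup(
        topology="2x2x2",
        accelerator_version="v5p",
        strategy="SPREAD",
        lifetime="detached",
        num_slices=2,
    )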

    Examples:

    import ray
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
    from ray.util.tpu import SlicePlacementGroup
    
    slice_handle = SlicePlacementGroup(topology="4x4", accelerator_version="v6e")
    slice_pg = slice_handle.placement_group
    ray.get(slice_pg.ready(), timeout=10)
    
    @ray.remote(num_cpus=0, resources={'TPU': 4})
    def spmd_task(world, rank):
        print(f"Current TPU is rank {rank} of {world}")
    
    tasks = [
        spmd_task.options(
            scheduling_strategy=PlacementGroupSchedulingStrategy(
                placement_group=slice_pg,
            )
        ).remote(world=4, rank=i)
        for i in range(slice_handle.num_workers)
    ]
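
After the SPMD tasks finish, the reservation can be released with the generic placement group API. This is a follow-up sketch continuing the example above; remove_placement_group is the standard ray.util call, not a SlicePlacementGroup-specific method:

    # Block until every SPMD task has completed.
    ray.get(tasks)

    # Release the underlying reservation once it is no longer needed.
    ray.util.remove_placement_group(slice_pg)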
    

PublicAPI (alpha): This API is in alpha and may change before becoming stable.

Methods

Attributes

accelerator_version

The TPU accelerator version of the slice.

chips_per_host

The number of chips per host for this TPU slice.

num_slices

The number of TPU slices this SlicePlacementGroup spans.

num_workers

The total number of hosts in the SlicePlacementGroup.

placement_group

The underlying PlacementGroup object.

topology

The physical topology of the TPU slice.
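
These attributes can be used to size SPMD workloads. The following is a minimal sketch reusing slice_handle from the example above; the printed values depend on the topology and cluster:

    print(slice_handle.topology)             # e.g. "4x4"
    print(slice_handle.accelerator_version)  # e.g. "v6e"
    print(slice_handle.chips_per_host)       # chips available on each host
    print(slice_handle.num_workers)          # total hosts across all slices
    print(slice_handle.num_slices)           # number of TPU slices reserved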