Source code for ray.util.accelerators.tpu
from typing import Optional
from ray._private.accelerators import TPUAcceleratorManager
from ray.util.annotations import PublicAPI
[docs]
@PublicAPI(stability="alpha")
def get_current_pod_name() -> Optional[str]:
    """Get the name of the TPU pod this worker belongs to.

    Returns:
        The TPU pod name, or None when the worker is not part of a TPU pod.
    """
    pod_name = TPUAcceleratorManager.get_current_node_tpu_name()
    # The accelerator manager reports "no pod" as an empty string; expose
    # that to callers as None instead.
    return None if pod_name == "" else pod_name
[docs]
@PublicAPI(stability="alpha")
def get_current_pod_worker_count() -> Optional[int]:
    """Get the number of workers in the TPU pod this worker belongs to.

    Returns:
        The total worker count of the TPU pod, or None if the worker is not
        part of a TPU pod.
    """
    worker_count = TPUAcceleratorManager.get_num_workers_in_current_tpu_pod()
    return worker_count