# Source code for ray.util.accelerators.tpu
from typing import Optional
from ray._private.accelerators import TPUAcceleratorManager
from ray.util.annotations import PublicAPI
[docs]
@PublicAPI(stability="alpha")
def get_current_pod_name() -> Optional[str]:
    """Look up the name of the TPU pod this worker belongs to.

    The name is obtained from
    ``TPUAcceleratorManager.get_current_node_tpu_name()``; an empty string
    result is reported to the caller as ``None``.

    Returns:
        The TPU pod name, or None if the worker is not part of a TPU pod.
    """
    pod_name = TPUAcceleratorManager.get_current_node_tpu_name()
    # Only the empty string is mapped to None; any other value passes through.
    return None if pod_name == "" else pod_name
[docs]
@PublicAPI(stability="alpha")
def get_current_pod_worker_count() -> Optional[int]:
    """Report how many workers make up the TPU pod this worker belongs to.

    This delegates directly to
    ``TPUAcceleratorManager.get_num_workers_in_current_tpu_pod()`` and
    returns its result unchanged.

    Returns:
        The total worker count of the TPU pod, or None if the worker is not
        part of a TPU pod.
    """
    worker_count = TPUAcceleratorManager.get_num_workers_in_current_tpu_pod()
    return worker_count
[docs]
@PublicAPI(stability="alpha")
def get_num_tpu_chips_on_node() -> int:
    """Return the number of TPU chips on the node.

    This delegates directly to
    ``TPUAcceleratorManager.get_current_node_num_accelerators()`` and
    returns its result unchanged.

    Returns:
        The total number of chips on the TPU node. Returns 0 if none are
        found.
    """
    return TPUAcceleratorManager.get_current_node_num_accelerators()