ray.train.torch.get_device

ray.train.torch.get_device() -> torch.device | List[torch.device]

Gets the correct torch device configured for this process.

Returns a list of devices if more than one GPU per worker is requested; otherwise returns a single torch.device.

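Assumes that CUDA_VISIBLE_DEVICES is set and contains every GPU ID returned by ray.get_gpu_ids(). The returned device index is the position of the assigned GPU ID within CUDA_VISIBLE_DEVICES, not the physical GPU ID.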
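The ID-to-index mapping can be sketched in plain Python. CUDA renumbers the GPUs listed in CUDA_VISIBLE_DEVICES from 0 in list order, so a GPU's process-local index is its position in that list. The helper `device_names` below is hypothetical and is not Ray's implementation: it takes the environment string and the IDs that ray.get_gpu_ids() would return, and it produces "cuda:N" strings instead of torch.device objects so it runs without Ray or PyTorch. It only covers the case where at least one GPU has been assigned.

```python
from typing import List, Sequence, Union


def device_names(
    cuda_visible_devices: str, gpu_ids: Sequence[Union[int, str]]
) -> Union[str, List[str]]:
    """Hypothetical helper: map assigned GPU IDs to local "cuda:N" names.

    CUDA renumbers the GPUs listed in CUDA_VISIBLE_DEVICES as 0, 1, 2, ...
    in list order, so each GPU's local index is its position in that list.
    Returns one name for one GPU and a list of names for several GPUs.
    """
    if not gpu_ids:
        raise ValueError("this sketch assumes at least one GPU is assigned")
    # CUDA_VISIBLE_DEVICES is a comma-separated string of GPU IDs.
    visible = [d.strip() for d in cuda_visible_devices.split(",") if d.strip()]
    names = []
    for gpu_id in gpu_ids:
        # Compare as strings, since the environment variable holds strings.
        key = str(gpu_id)
        if key not in visible:
            raise RuntimeError(
                f"GPU {key} is not listed in CUDA_VISIBLE_DEVICES={cuda_visible_devices!r}"
            )
        names.append(f"cuda:{visible.index(key)}")
    return names[0] if len(names) == 1 else names


print(device_names("3,4", [3]))              # cuda:0
print(device_names("0,1,2,3,4", [4]))        # cuda:4
print(device_names("0,1,2,3,4,5", [4, 5]))   # ['cuda:4', 'cuda:5']
```

The first call shows why the returned index can differ from the physical ID: GPU 3 is the first visible device, so the process sees it as cuda:0.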

Example

>>> # os.environ["CUDA_VISIBLE_DEVICES"] = "3,4"
>>> # ray.get_gpu_ids() == [3]
>>> # torch.cuda.is_available() == True
>>> # get_device() == torch.device("cuda:0")

>>> # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4"
>>> # ray.get_gpu_ids() == [4]
>>> # torch.cuda.is_available() == True
>>> # get_device() == torch.device("cuda:4")

>>> # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5"
>>> # ray.get_gpu_ids() == [4, 5]
>>> # torch.cuda.is_available() == True
>>> # get_device() == [torch.device("cuda:4"), torch.device("cuda:5")]
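Because the return type depends on how many GPUs the worker was given, code that expects a single device may need to normalize the result. A minimal sketch, assuming PyTorch is installed; `primary_device` is a hypothetical helper, not a Ray API. Inside a Ray Train training function you would pass it the value of ray.train.torch.get_device(); here a CPU device stands in so the sketch runs without Ray or GPUs.

```python
from typing import List, Union

import torch


def primary_device(devices: Union[torch.device, List[torch.device]]) -> torch.device:
    """Return one torch.device from a get_device()-style return value.

    Hypothetical helper (not part of Ray): a single torch.device is returned
    unchanged; for a list, the first entry is returned.
    """
    if isinstance(devices, torch.device):
        return devices
    if not devices:
        raise ValueError("expected at least one device")
    return devices[0]


# In a Ray Train training function you might write:
#     device = primary_device(ray.train.torch.get_device())
# A CPU device stands in here so the sketch runs anywhere.
device = primary_device(torch.device("cpu"))
batch = torch.zeros(4, 3).to(device)  # move a tensor onto the chosen device
print(batch.device)
```

Taking the first entry suits single-device code paths; code that actually spreads work across several GPUs should keep the full list.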