ray.data.from_torch

ray.data.from_torch(dataset: torch.utils.data.Dataset) → ray.data.dataset.Dataset

Create a dataset from a Torch dataset.

This function is inefficient. Use it only to read small datasets or to prototype.

Warning

If your dataset is large, this function may execute slowly or raise an out-of-memory error. To avoid these issues, read the underlying data with a function like read_images().

Note

This function isn’t parallelized. It loads the entire dataset into the head node’s memory before moving the data to the distributed object store.
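
To make the memory cost concrete, the following is a conceptual sketch of sequential materialization in plain Python. It is not Ray's implementation. ToyMapStyleDataset, materialize, and split_into_blocks are hypothetical names introduced for illustration, and the splitting step is an assumption about how samples might be grouped into blocks. The point is that every sample is read one at a time in a single process, so the whole dataset sits in one list before any block exists.

```python
from typing import Any, List, Tuple


class ToyMapStyleDataset:
    """Hypothetical stand-in for a map-style torch.utils.data.Dataset.

    Like a Torch map-style dataset, it exposes __len__ and __getitem__.
    """

    def __init__(self, num_samples: int) -> None:
        self._num_samples = num_samples

    def __len__(self) -> int:
        return self._num_samples

    def __getitem__(self, index: int) -> Tuple[float, int]:
        if not 0 <= index < self._num_samples:
            raise IndexError(index)
        # Each sample is a (feature, label) pair.
        return (float(index), index % 10)


def materialize(dataset: Any) -> List[Any]:
    """Read every sample sequentially into a single in-memory list."""
    return [dataset[i] for i in range(len(dataset))]


def split_into_blocks(samples: List[Any], num_blocks: int) -> List[List[Any]]:
    """Split the materialized list into contiguous, roughly equal blocks."""
    block_size, remainder = divmod(len(samples), num_blocks)
    blocks: List[List[Any]] = []
    start = 0
    for i in range(num_blocks):
        # The first `remainder` blocks each take one extra sample.
        end = start + block_size + (1 if i < remainder else 0)
        blocks.append(samples[start:end])
        start = end
    return blocks


# The full list exists in this process before any block is created.
samples = materialize(ToyMapStyleDataset(1000))
blocks = split_into_blocks(samples, 4)
```

Peak memory in this sketch grows with the total number of samples, which is why a reader that loads files in parallel is the better choice for large datasets.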

Examples

>>> import ray
>>> from torchvision import datasets
>>> dataset = datasets.MNIST("data", download=True)
>>> dataset = ray.data.from_torch(dataset)
>>> dataset
Dataset(num_blocks=200, num_rows=60000, schema=<class 'tuple'>)
>>> dataset.take(1)  
[(<PIL.Image.Image image mode=L size=28x28 at 0x...>, 5)]

Parameters

dataset – A Torch dataset.

Returns

A Dataset that contains the samples stored in the Torch dataset.

PublicAPI: This API is stable across Ray releases.