ray.data.from_torch

ray.data.from_torch(dataset: torch.utils.data.Dataset) → Dataset

Create a Dataset from a Torch Dataset.

Note

The input dataset can be either map-style or iterable-style and can hold an arbitrarily large amount of data. The data is streamed sequentially by a single read task.

Examples

>>> import ray
>>> from torchvision import datasets
>>> dataset = datasets.MNIST("data", download=True)
>>> ds = ray.data.from_torch(dataset)
>>> ds
MaterializedDataset(num_blocks=..., num_rows=60000, schema={item: object})
>>> ds.take(1)
[{'item': (<PIL.Image.Image image mode=L size=28x28 at 0x...>, 5)}]

Parameters:

dataset – A Torch Dataset.

Returns:

A Dataset containing the Torch dataset samples.