ray.train.get_dataset_shard#

ray.train.get_dataset_shard(dataset_name: Optional[str] = None) → Optional[DataIterator]#

Returns the ray.data.DataIterator shard for this worker.

Call iter_torch_batches() or to_tf() on this shard to convert it to the appropriate framework-specific data type.

import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_loop_per_worker(config):
    ...
    for epoch in range(2):
        # Trainer will automatically handle sharding.
        data_shard = train.get_dataset_shard("train")
        for batch in data_shard.iter_torch_batches():
            ...

train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
    datasets={"train": train_dataset}
)
trainer.fit()
Parameters

dataset_name – If a dictionary of datasets was passed to the Trainer, specifies which dataset's shard to return.

Returns

The DataIterator shard to use for this worker. Returns None if no dataset was passed to the Trainer.