ray.train.get_dataset_shard#

ray.train.get_dataset_shard(dataset_name: str | None = None) → DataIterator | None#

Returns the ray.data.DataIterator shard for this worker.

Call iter_torch_batches() or to_tf() on this shard to convert it to the appropriate framework-specific data type.

import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_loop_per_worker(config):
    ...
    for epoch in range(2):
        # Trainer will automatically handle sharding.
        data_shard = train.get_dataset_shard("train")
        for batch in data_shard.iter_torch_batches():
            ...

train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
    datasets={"train": train_dataset}
)
trainer.fit()
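In the example above, the Trainer shards "train" so that each of the two workers iterates a disjoint slice of the dataset. The sketch below is plain Python, not Ray's actual partitioning scheme, and only illustrates the concept: splitting rows round-robin so the shards are disjoint and together cover every row exactly once.

```python
# Conceptual sketch of sharding (NOT Ray's implementation): worker k
# gets rows k, k + num_workers, k + 2 * num_workers, ...
def shard_for_worker(rows, worker_rank, num_workers):
    return rows[worker_rank::num_workers]

rows = list(range(10))
num_workers = 2
shards = [shard_for_worker(rows, rank, num_workers) for rank in range(num_workers)]

# Shards are disjoint and together cover the full dataset.
assert sorted(sum(shards, [])) == rows
```

In Ray itself this split happens automatically per worker; the training function never partitions the data by hand.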
Parameters:

dataset_name – If a dictionary of Datasets was passed to the Trainer, specifies which dataset shard to return.

Returns:

The DataIterator shard to use for this worker. If no dataset was passed into the Trainer, returns None.
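Because the return value can be None, code that optionally consumes a dataset (for example a validation set the caller may or may not have passed) should guard before iterating. A hedged sketch of that pattern, using a plain dict and lists to stand in for the Trainer's datasets mapping and its DataIterator shards:

```python
# Stand-in for the shards the Trainer would build from its `datasets`
# argument; a real shard would be a ray.data.DataIterator. Note that no
# "val" dataset was passed here.
shards = {"train": ["batch0", "batch1"]}

def get_shard(name):
    # Mirrors the documented contract: None when no such dataset exists.
    return shards.get(name)

val_shard = get_shard("val")
batches_seen = 0
if val_shard is not None:  # skip validation when there is no "val" dataset
    for batch in val_shard:
        batches_seen += 1

assert batches_seen == 0
```

The same guard applies verbatim when calling train.get_dataset_shard("val") inside train_loop_per_worker.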