ray.air.session.get_dataset_shard

ray.air.session.get_dataset_shard(dataset_name: Optional[str] = None) -> Optional[ray.data.DatasetIterator]

Returns the ray.data.DatasetIterator shard for this worker.

Call iter_torch_batches() or to_tf() on this shard to convert it to the appropriate framework-specific data type.

import ray
from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer

def train_loop_per_worker():
    model = Net()  # Net is a user-defined torch.nn.Module.
    for epoch in range(100):
        # Trainer will automatically handle sharding.
        data_shard = session.get_dataset_shard("train")
        for batch in data_shard.iter_torch_batches():
            ...  # Run one training step on the batch.
    return model

train_dataset = ray.data.from_items(
    [{"x": x, "y": x + 1} for x in range(32)])
trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
    datasets={"train": train_dataset},
)
trainer.fit()
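
For TensorFlow workloads, the same shard can be consumed through to_tf() instead of iter_torch_batches(). A minimal sketch of the worker loop under that assumption; the column names "x" and "y" follow the example dataset above, and the exact to_tf() signature may vary across Ray versions:

import ray
from ray.air import session

def train_loop_per_worker():
    data_shard = session.get_dataset_shard("train")
    # Convert the shard into a tf.data.Dataset of (features, labels) batches.
    tf_dataset = data_shard.to_tf(
        feature_columns="x", label_columns="y", batch_size=8)
    for features, labels in tf_dataset:
        ...  # Run one TensorFlow training step on the batch.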
Parameters

dataset_name – If a dictionary of Datasets was passed to the Trainer, specifies which dataset shard to return (see the multi-dataset sketch below).

Returns

The DatasetIterator shard to use for this worker. If no dataset was passed into the Trainer, returns None.
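
When a dictionary of datasets is passed, each worker can fetch the shard for each key independently. A minimal sketch, assuming a hypothetical second "valid" dataset registered alongside "train"; note that under AIR's default dataset configuration, typically only the "train" dataset is split across workers, while other datasets are returned whole to each worker:

import ray
from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer

def train_loop_per_worker():
    # Each named dataset yields its own per-worker iterator.
    train_shard = session.get_dataset_shard("train")
    valid_shard = session.get_dataset_shard("valid")
    ...  # Train on train_shard, evaluate on valid_shard.

train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
valid_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(8)])
trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
    datasets={"train": train_dataset, "valid": valid_dataset},
)
trainer.fit()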

PublicAPI (beta): This API is in beta and may change before becoming stable.