ray.train.get_dataset_shard
- ray.train.get_dataset_shard(dataset_name: str | None = None) -> DataIterator | None
Returns the `ray.data.DataIterator` shard for this worker.

Call `iter_torch_batches()` or `to_tf()` on this shard to convert it to the appropriate framework-specific data type.

```python
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_loop_per_worker(config):
    ...
    for epoch in range(2):
        # Trainer will automatically handle sharding.
        data_shard = train.get_dataset_shard("train")
        for batch in data_shard.iter_torch_batches():
            ...

train_dataset = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
    datasets={"train": train_dataset},
)
trainer.fit()
```
- Parameters:
  - dataset_name – If a dictionary of datasets was passed to the Trainer, specifies which dataset shard to return.
- Returns:
  - The `DataIterator` shard to use for this worker. If no dataset is passed into the Trainer, returns None.
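To make the sharding behavior concrete, here is a plain-Python sketch of how a dataset might be split across workers and iterated in batches. This is illustrative only, not Ray's implementation: Ray Data splits datasets at the block level rather than row by row, and `iter_torch_batches()` additionally converts each batch to tensors. The helper names `shard_dataset` and `iter_batches` are hypothetical.

```python
from typing import Iterator, List

def shard_dataset(rows: List[dict], num_workers: int, worker_rank: int) -> List[dict]:
    # Round-robin split: worker i gets rows i, i + num_workers, i + 2 * num_workers, ...
    # (Illustrative only; Ray Data actually splits at the block level.)
    return rows[worker_rank::num_workers]

def iter_batches(shard: List[dict], batch_size: int) -> Iterator[List[dict]]:
    # Yield fixed-size batches from this worker's shard, analogous to
    # DataIterator.iter_torch_batches() minus the tensor conversion.
    for start in range(0, len(shard), batch_size):
        yield shard[start:start + batch_size]

rows = [{"x": i} for i in range(10)]
# Worker 0 of 2 sees every other row.
shard = shard_dataset(rows, num_workers=2, worker_rank=0)
batches = list(iter_batches(shard, batch_size=2))
```

Each worker iterates only over its own shard, so no row is processed twice across the worker group within one epoch.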