ray.air.session.get_dataset_shard
- ray.air.session.get_dataset_shard(dataset_name: Optional[str] = None) -> Optional[DatasetIterator]
Returns the ray.data.DatasetIterator shard for this worker.
Call iter_torch_batches() or to_tf() on this shard to convert it to the appropriate framework-specific data type.

```python
import ray
from ray import train
from ray.air import session
from ray.air.config import ScalingConfig
# This import was missing in the original example.
from ray.train.torch import TorchTrainer

def train_loop_per_worker():
    # Net is assumed to be a user-defined torch.nn.Module.
    model = Net()
    for _ in range(100):
        # Trainer will automatically handle sharding.
        data_shard = session.get_dataset_shard("train")
        for batch in data_shard.iter_torch_batches():
            ...  # train on the batch
    return model

train_dataset = ray.data.from_items(
    [{"x": x, "y": x + 1} for x in range(32)])
trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
    datasets={"train": train_dataset})
trainer.fit()
```
- Parameters
  dataset_name – If a dictionary of Datasets was passed to the Trainer, specifies which dataset shard to return.
- Returns
  The DatasetIterator shard to use for this worker. If no dataset was passed into the Trainer, returns None.
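The Trainer splits the dataset across workers so that each worker's shard covers a distinct subset of the rows. As a rough illustration of that idea (this is plain Python, not Ray's actual sharding implementation, and `shard_dataset` is a hypothetical helper), an even split across two workers might look like:

```python
def shard_dataset(items, num_workers):
    """Round-robin split of items across workers.

    Illustrative only: Ray's real sharding is handled internally by the
    Trainer, and each worker retrieves its shard via get_dataset_shard().
    """
    shards = [[] for _ in range(num_workers)]
    for i, item in enumerate(items):
        shards[i % num_workers].append(item)
    return shards

# Mirrors the 32-row dataset from the example above.
data = [{"x": x, "y": x + 1} for x in range(32)]
shards = shard_dataset(data, num_workers=2)
# With 2 workers, each shard holds 16 of the 32 rows.
```

Because sharding happens automatically, the training loop only ever sees its own shard and does not need to know the total number of workers.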
PublicAPI (beta): This API is in beta and may change before becoming stable.