Configuring Training Datasets#

AIR builds its training data pipeline on Ray Datasets, which is a scalable, framework-agnostic data loading and preprocessing library. Datasets enables AIR to seamlessly load data for local and distributed training with Train.

This page describes how to set up and configure these datasets in Train for different scenarios and scales.

Overview#

The following figure illustrates a simple Ray AIR training job that (1) loads parquet data from S3, (2) applies a simple user-defined function to preprocess batches of data, and (3) runs an AIR Trainer with the given dataset and preprocessor.

[Figure: ingest.svg — parquet data loaded from S3, preprocessed in batches, and fed to an AIR Trainer]

Let’s walk through the stages of what happens when Trainer.fit() is called.

Preprocessing: First, AIR will fit the preprocessor (e.g., compute statistics) on the "train" dataset, and then transform all given datasets with the fitted preprocessor. This is done by calling fit_transform() on the train dataset passed to the Trainer, followed by transform() on the remaining datasets.
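
As a rough mental model, the preprocessing step behaves like the following sketch (a StandardScaler applied to small, hypothetical in-memory datasets; in practice the Trainer orchestrates these calls for you):

import ray
from ray.data.preprocessors import StandardScaler

# Hypothetical stand-ins for the datasets passed to the Trainer.
train_ds = ray.data.from_items([{"value": float(i)} for i in range(8)])
valid_ds = ray.data.from_items([{"value": float(i)} for i in range(4)])

prep = StandardScaler(columns=["value"])
train_ds = prep.fit_transform(train_ds)  # Fit statistics on "train", then transform it.
valid_ds = prep.transform(valid_ds)      # Transform the other datasets with the fitted stats.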

Training: Then, AIR passes the preprocessed dataset to Train workers (Ray actors) launched by the Trainer. Each worker calls get_dataset_shard() to get a handle to its assigned data shard. This returns a DatasetIterator, which can be used to loop over the data with iter_batches(), iter_torch_batches(), or to_tf(). Each of these returns a batch iterator for one epoch (a full pass over the original dataset).
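
For example, a worker that wants framework-native tensors could consume its shard as in this minimal sketch (a Torch-style loop with an illustrative batch size; full runnable examples follow below):

from ray.air import session


def train_loop_per_worker():
    # Each worker gets a handle to its assigned shard.
    data_shard = session.get_dataset_shard("train")
    for epoch in range(2):
        # Each call to iter_torch_batches() starts a new pass (epoch) over the shard.
        for batch in data_shard.iter_torch_batches(batch_size=32):
            pass  # batch maps column names to torch.Tensors.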

Getting Started#

The following is a simple example of how to configure ingest for a dummy TorchTrainer. Below, we are passing a small tensor dataset to the Trainer via the datasets argument. In the Trainer’s train_loop_per_worker, we access the preprocessed dataset using get_dataset_shard().

import ray
from ray.air import session
from ray.data import DatasetIterator
from ray.data.preprocessors import BatchMapper
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig

# A simple preprocessor that just scales all values by 2.0.
preprocessor = BatchMapper(lambda df: df * 2, batch_format="pandas")


def train_loop_per_worker():
    # Get a handle to the worker's assigned DatasetIterator shard.
    data_shard: DatasetIterator = session.get_dataset_shard("train")

    # Manually iterate over the data 10 times (10 epochs).
    for _ in range(10):
        for batch in data_shard.iter_batches():
            print("Do some training on batch", batch)

    # Print the stats for performance debugging.
    print(data_shard.stats())


my_trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=1),
    datasets={
        "train": ray.data.range_tensor(1000),
    },
    preprocessor=preprocessor,
)
my_trainer.fit()

For local development and testing, you can also use the helper function make_local_dataset_iterator() to get a local DatasetIterator.
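
For example, a quick local sanity check might look like the following sketch (assuming make_local_dataset_iterator() is importable from ray.air.util.check_ingest and accepts the dataset, preprocessor, and per-dataset DatasetConfig):

import ray
from ray.air.config import DatasetConfig
from ray.air.util.check_ingest import make_local_dataset_iterator
from ray.data.preprocessors import BatchMapper

# Same toy preprocessor as above.
preprocessor = BatchMapper(lambda df: df * 2, batch_format="pandas")

# Build a local DatasetIterator over the preprocessed data, without launching a Trainer.
it = make_local_dataset_iterator(
    ray.data.range_tensor(10),
    preprocessor,
    DatasetConfig(),
)
for batch in it.iter_batches():
    print("Got local batch", batch)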

Configuring Ingest#

You can use the DatasetConfig object to configure how Datasets are preprocessed and split across training workers. Each DataParallelTrainer accepts a dataset_config constructor argument: a mapping from dataset name to a DatasetConfig object. If no dataset_config is passed in, the default configuration is used:

# The default DataParallelTrainer dataset config, which is inherited
# by sub-classes such as TorchTrainer, HorovodTrainer, etc.
_dataset_config = {
    # Fit preprocessors on the train dataset only. Split the dataset
    # across workers if scaling_config["num_workers"] > 1.
    "train": DatasetConfig(fit=True, split=True),
    # For all other datasets, use the defaults (don't fit, don't split).
    # The datasets will be transformed by the fitted preprocessor.
    "*": DatasetConfig(),
}

Here are some examples of configuring Dataset ingest options and what they do:

Enabling Streaming Ingest#

By default, AIR loads all datasets into the Ray object store at the start of training (bulk ingest). This provides the best performance if the cluster can fit the datasets entirely in memory, or if the preprocessing step is expensive to run more than once.

import ray
from ray.air import session
from ray.data import DatasetIterator
from ray.data.preprocessors import BatchMapper
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig, DatasetConfig

# A simple preprocessor that just scales all values by 2.0.
preprocessor = BatchMapper(lambda df: df * 2, batch_format="pandas")


def train_loop_per_worker():
    # Get a handle to the worker's assigned DatasetIterator shard.
    data_shard: DatasetIterator = session.get_dataset_shard("train")

    # Manually iterate over the data 10 times (10 epochs).
    for _ in range(10):
        for batch in data_shard.iter_batches():
            print("Do some training on batch", batch)

    # Print the stats for performance debugging.
    print(data_shard.stats())


my_trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=1),
    datasets={
        "train": ray.data.range_tensor(1000),
    },
    dataset_config={
        # -1 is the default: bulk ingest (compute and cache the full dataset up front).
        "train": DatasetConfig(max_object_store_memory_fraction=-1),
    },
    preprocessor=preprocessor,
)
my_trainer.fit()

You should use bulk ingest when:

  • you have enough memory to fit the data blocks in the cluster’s object store; or

  • your preprocessing transform is expensive to recompute on each epoch

In streaming ingest mode, instead of loading the entire dataset into the Ray object store at once, AIR will load a fraction of the dataset at a time. This can be desirable when the dataset is very large, and caching it all at once would cause expensive disk spilling. The downside is that the dataset will have to be preprocessed on each epoch, which may be more expensive. Preprocessing is overlapped with training computation, but overall training throughput may still decrease if preprocessing is more expensive than the training computation (forward pass, backward pass, gradient sync).

To enable this mode, use the max_object_store_memory_fraction argument. This argument defaults to -1, meaning that bulk ingest should be used and the entire dataset should be computed and cached before training starts.

Use a float value 0 or greater to indicate the “window” size, i.e. the maximum fraction of object store memory that should be used at once. A reasonable value is 0.2, meaning 20% of available object store memory. Larger window sizes can improve performance by increasing parallelism. A window size of 1 or greater will likely result in spilling.

import ray
from ray.air import session
from ray.data import DatasetIterator
from ray.data.preprocessors import BatchMapper
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig, DatasetConfig

# A simple preprocessor that just scales all values by 2.0.
preprocessor = BatchMapper(lambda df: df * 2, batch_format="pandas")


def train_loop_per_worker():
    data_shard: DatasetIterator = session.get_dataset_shard("train")

    # Iterate over 10 epochs of data.
    for _ in range(10):
        for batch in data_shard.iter_batches():
            print("Do some training on batch", batch)

    # View the stats for performance debugging.
    print(data_shard.stats())


my_trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=1),
    datasets={
        "train": ray.data.range_tensor(1000),
    },
    dataset_config={
        # Use 20% of object store memory.
        "train": DatasetConfig(max_object_store_memory_fraction=0.2),
    },
    preprocessor=preprocessor,
)
my_trainer.fit()

Use streaming ingest when:

  • you have large datasets that don’t fit into memory; and

  • re-executing the preprocessing step on each epoch is faster than caching the preprocessed dataset on disk and reloading from disk on each epoch

Note that this feature is experimental and the actual object store memory usage may vary. Please file a GitHub issue if you run into problems.

Shuffling Data#

Shuffling or data randomization is important for training high-quality models.

By default, AIR shuffles the assignment of data blocks (files) to dataset shards between epochs. You can disable this behavior by setting randomize_block_order to False in your DatasetConfig.

To randomize data records within a file, perform a local or global shuffle.

Local shuffling is the recommended approach for randomizing data order. To use local shuffle, simply specify a non-zero local_shuffle_buffer_size as an argument to iter_batches(). The iterator will then use a local buffer of the given size to randomize record order. The larger the buffer size, the more randomization will be applied, but it will also use more memory.

See iter_batches() for more details.

import ray
from ray.air import session
from ray.data import DatasetIterator
from ray.train.torch import TorchTrainer
from ray.air.config import DatasetConfig, ScalingConfig


def train_loop_per_worker():
    data_shard: DatasetIterator = session.get_dataset_shard("train")

    # Iterate over 10 epochs of data.
    for epoch in range(10):
        for batch in data_shard.iter_batches(
            batch_size=10_000,
            local_shuffle_buffer_size=100_000,
        ):
            print("Do some training on batch", batch)

    # View the stats for performance debugging.
    print(data_shard.stats())


my_trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
    datasets={"train": ray.data.range_tensor(1000)},
    dataset_config={
        # global_shuffle is disabled by default, but we're emphasizing here that you
        # would NOT want to use both global and local shuffling together.
        "train": DatasetConfig(global_shuffle=False),
    },
)
print(my_trainer.get_dataset_config())
# -> {'train': DatasetConfig(fit=True, split=True, global_shuffle=False, ...)}
my_trainer.fit()

You should use local shuffling when:

  • a small in-memory buffer provides enough randomization; or

  • you want the highest possible ingest performance; or

  • your model is not overly sensitive to shuffle quality

Global shuffling provides more uniformly random (decorrelated) samples and is carried out via a distributed map-reduce operation. This higher quality shuffle can often lead to more precision gain per training step, but it is also an expensive distributed operation and will decrease the ingest throughput. The shuffle step is overlapped with training computation, so as long as the shuffled ingest throughput matches or exceeds the model training (forward pass, backward pass, gradient sync) throughput, this higher-quality shuffle shouldn’t slow down the overall training.

If global shuffling is causing the ingest throughput to become the training bottleneck, local shuffling may be a better option.

import ray
from ray.air import session
from ray.data import DatasetIterator
from ray.train.torch import TorchTrainer
from ray.air.config import DatasetConfig, ScalingConfig


def train_loop_per_worker():
    data_shard: DatasetIterator = session.get_dataset_shard("train")

    # Iterate over 10 epochs of data.
    for epoch in range(10):
        for batch in data_shard.iter_batches():
            print("Do some training on batch", batch)

    # View the stats for performance debugging.
    print(data_shard.stats())


my_trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
    datasets={"train": ray.data.range_tensor(1000)},
    dataset_config={
        "train": DatasetConfig(global_shuffle=True),
    },
)
print(my_trainer.get_dataset_config())
# -> {'train': DatasetConfig(fit=True, split=True, global_shuffle=True, ...)}
my_trainer.fit()

You should use global shuffling when:

  • you suspect high-quality shuffles may significantly improve model quality; and

  • absolute ingest performance is less of a concern

Applying randomized preprocessing (experimental)#

The standard preprocessor passed to the Trainer is only applied once to the initial dataset when using bulk ingest. However, in some cases you may want to reapply a preprocessor on each epoch, for example to augment your training dataset with a randomized transform.

To support this use case, AIR offers an additional per-epoch preprocessor that gets reapplied on each epoch, after all other preprocessors and right before dataset consumption (e.g., using iter_batches()). Per-epoch preprocessing also executes in parallel with dataset consumption to reduce pauses in dataset consumption.

This example shows how to use this feature to apply a randomized preprocessor on top of the standard preprocessor.

import random

import ray
from ray.air import session
from ray.data import DatasetIterator
from ray.data.preprocessors import BatchMapper
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig, DatasetConfig

# A simple preprocessor that just scales all values by 2.0.
preprocessor = BatchMapper(lambda df: df * 2, batch_format="pandas")

# A randomized preprocessor that adds a random float to all values, to be
# reapplied on each epoch after `preprocessor`. Each epoch will therefore add a
# different random float to the scaled dataset.
add_noise = BatchMapper(lambda df: df + random.random(), batch_format="pandas")


def train_loop_per_worker():
    # Get a handle to the worker's assigned DatasetIterator shard.
    data_shard: DatasetIterator = session.get_dataset_shard("train")

    # Manually iterate over the data 10 times (10 epochs).
    for _ in range(10):
        for batch in data_shard.iter_batches():
            print("Do some training on batch", batch)

    # Print the stats for performance debugging.
    print(data_shard.stats())


my_trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=1),
    datasets={
        "train": ray.data.range_tensor(100),
    },
    dataset_config={
        "train": DatasetConfig(
            # Don't randomize order, just to make it easier to read the results.
            randomize_block_order=False,
            per_epoch_preprocessor=add_noise,
        ),
    },
    preprocessor=preprocessor,
)
my_trainer.fit()

Splitting Auxiliary Datasets#

During data parallel training, the datasets are split so that each model replica is training on a different shard of data. By default, only the "train" dataset is split. All the other Datasets are not split and the entire dataset is returned by get_dataset_shard().

However, you may want to split a large validation dataset as well, for example to do data parallel validation. This example shows how to override the split config for the “valid” and “test” datasets, which means both will be .split() across the training workers.

import ray
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig, DatasetConfig

train_ds = ray.data.range_tensor(1000)
valid_ds = ray.data.range_tensor(100)
test_ds = ray.data.range_tensor(100)

my_trainer = TorchTrainer(
    lambda: None,  # No-op training loop.
    scaling_config=ScalingConfig(num_workers=2),
    datasets={
        "train": train_ds,
        "valid": valid_ds,
        "test": test_ds,
    },
    dataset_config={
        "valid": DatasetConfig(split=True),
        "test": DatasetConfig(split=True),
    },
)
print(my_trainer.get_dataset_config())
# -> {'train': DatasetConfig(fit=True, split=True, ...),
#     'valid': DatasetConfig(fit=False, split=True, ...),
#     'test': DatasetConfig(fit=False, split=True, ...), ...}

Disabling Preprocessor Transforms#

By default, the provided Preprocessor is fit on the "train" dataset and is then used to transform all the datasets. However, you may want to disable the preprocessor transforms for certain datasets.

This example shows overriding the transform config for the “side” dataset. This means that the original dataset will be returned by .get_dataset_shard("side").

import ray
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig, DatasetConfig

train_ds = ray.data.range_tensor(1000)
side_ds = ray.data.range_tensor(10)

my_trainer = TorchTrainer(
    lambda: None,  # No-op training loop.
    scaling_config=ScalingConfig(num_workers=2),
    datasets={
        "train": train_ds,
        "side": side_ds,
    },
    dataset_config={
        "side": DatasetConfig(transform=False),
    },
)
print(my_trainer.get_dataset_config())
# -> {'train': DatasetConfig(fit=True, split=True, ...),
#     'side': DatasetConfig(fit=False, split=False, transform=False, ...), ...}

Dataset Resources#

Datasets uses Ray tasks to execute data processing operations. These tasks use CPU resources in the cluster during execution, which may compete with resources needed for Training.

By default, Dataset tasks use cluster CPU resources for execution. This can sometimes conflict with Trainer resource requests. For example, if Trainers allocate all CPU resources in the cluster, then no Datasets tasks can run.

import ray
from ray.air import session
from ray.data.preprocessors import BatchMapper
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig

# Create a cluster with 4 CPU slots available.
ray.init(num_cpus=4)

# A simple example training loop.
def train_loop_per_worker():
    data_shard = session.get_dataset_shard("train")
    for _ in range(10):
        for batch in data_shard.iter_batches():
            print("Do some training on batch", batch)


# A simple preprocessor that just scales all values by 2.0.
preprocessor = BatchMapper(lambda df: df * 2, batch_format="pandas")

my_trainer = TorchTrainer(
    train_loop_per_worker,
    # This will hang if you set num_workers=4, since the
    # Trainer will reserve all 4 CPUs for workers, leaving
    # none left for Datasets execution.
    scaling_config=ScalingConfig(num_workers=2),
    datasets={
        "train": ray.data.range_tensor(1000),
    },
    preprocessor=preprocessor,
)
my_trainer.fit()

Unreserved CPUs work well when:

  • you are running only one Trainer and the cluster has enough CPUs; or

  • your Trainers are configured to use GPUs and not CPUs

The _max_cpu_fraction_per_node option of ScalingConfig can be used to exclude CPUs from placement group scheduling. In the example below, setting this parameter to 0.8 enables Tune trials to run smoothly without risk of deadlock by reserving 20% of node CPUs for Datasets execution.

import ray
from ray.air import session
from ray.data.preprocessors import BatchMapper
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig

# Create a cluster with 4 CPU slots available.
ray.init(num_cpus=4)

# A simple example training loop.
def train_loop_per_worker():
    data_shard = session.get_dataset_shard("train")
    for _ in range(10):
        for batch in data_shard.iter_batches():
            print("Do some training on batch", batch)


# A simple preprocessor that just scales all values by 2.0.
preprocessor = BatchMapper(lambda df: df * 2, batch_format="pandas")

my_trainer = TorchTrainer(
    train_loop_per_worker,
    # With _max_cpu_fraction_per_node=0.8, the Trainer's placement group can
    # reserve at most 80% of this node's CPUs (3 of 4), leaving at least one
    # CPU free for Datasets execution.
    scaling_config=ScalingConfig(num_workers=2, _max_cpu_fraction_per_node=0.8),
    datasets={
        "train": ray.data.range_tensor(1000),
    },
    preprocessor=preprocessor,
)
my_trainer.fit()

You should use reserved CPUs when:

  • you are running multiple concurrent CPU Trainers using Tune; or

  • you want to ensure predictable Datasets performance

Warning

_max_cpu_fraction_per_node is experimental and not currently recommended for use with autoscaling clusters (scale-up will not trigger properly).

Debugging Ingest with the DummyTrainer#

Data ingest problems can be challenging to debug when combined in a full training pipeline. To isolate data ingest issues from other possible training problems, we provide the DummyTrainer utility class. You can also use the helper function make_local_dataset_iterator() to get a local DatasetIterator for debugging purposes. Let’s walk through using DummyTrainer to understand and resolve an ingest misconfiguration.

Setting it up#

First, let’s create a synthetic in-memory dataset and set up a simple preprocessor pipeline. For this example, we’ll run it on a 3-node cluster with m5.4xlarge nodes. In practice we might want to use a single machine to keep data local, but we’ll use a cluster for illustrative purposes.

import ray
from ray.data.preprocessors import Chain, BatchMapper
from ray.air.util.check_ingest import DummyTrainer
from ray.air.config import ScalingConfig

# Generate a synthetic dataset of ~10GiB of float64 data. The dataset is sharded
# into 100 blocks (parallelism=100).
dataset = ray.data.range_tensor(50000, shape=(80, 80, 4), parallelism=100)

# An example preprocessor chain that just scales all values by 4.0 in two stages.
preprocessor = Chain(
    BatchMapper(lambda df: df * 2, batch_format="pandas"),
    BatchMapper(lambda df: df * 2, batch_format="pandas"),
)

Next, we instantiate and fit a DummyTrainer with a single training worker and no GPUs. You can customize these parameters to simulate your real training use case (e.g., 16 trainers, each with GPUs enabled).

# Setup the dummy trainer that prints ingest stats.
# Run and print ingest stats.
trainer = DummyTrainer(
    scaling_config=ScalingConfig(num_workers=1, use_gpu=False),
    datasets={"train": dataset},
    preprocessor=preprocessor,
    num_epochs=1,  # Stop after this number of epochs is read.
    prefetch_batches=1,  # Number of batches to prefetch when reading data.
    batch_size=None,  # Use whole blocks as batches.
)
trainer.fit()

Understanding the output#

Let’s walk through the output. First, the job starts and executes preprocessing; you can see below that the preprocessing runs in about 6.8s. The dataset stats for the preprocessing are also printed:

Starting dataset preprocessing
Preprocessed datasets in 6.874227493000035 seconds
Preprocessor Chain(preprocessors=(BatchMapper(fn=<lambda>), BatchMapper(fn=<lambda>)))
Preprocessor transform stats:

Stage 1 read->map_batches: 100/100 blocks executed in 4.57s
* Remote wall time: 120.68ms min, 522.36ms max, 251.53ms mean, 25.15s total
* Remote cpu time: 116.55ms min, 278.08ms max, 216.38ms mean, 21.64s total
* Output num rows: 500 min, 500 max, 500 mean, 50000 total
* Output size bytes: 102400128 min, 102400128 max, 102400128 mean, 10240012800 total
* Tasks per node: 16 min, 48 max, 33 mean; 3 nodes used

Stage 2 map_batches: 100/100 blocks executed in 2.22s
* Remote wall time: 89.07ms min, 302.71ms max, 175.12ms mean, 17.51s total
* Remote cpu time: 89.22ms min, 207.53ms max, 137.5ms mean, 13.75s total
* Output num rows: 500 min, 500 max, 500 mean, 50000 total
* Output size bytes: 102400128 min, 102400128 max, 102400128 mean, 10240012800 total
* Tasks per node: 30 min, 37 max, 33 mean; 3 nodes used

When the training job finishes running, it prints out some more statistics:

P50/P95/Max batch delay (s) 1.101227020500005 1.120024863100042 1.9424749629999951
Num epochs read 1
Num batches read 100
Num bytes read 9765.64 MiB
Mean throughput 116.59 MiB/s

Let’s break it down:

  • Batch delay: Time the trainer spends waiting for the next data batch to be fetched. Ideally this value is as close to zero as possible. If it is too high, Ray may be spending too much time downloading data from remote nodes to the trainer node.

  • Num epochs read: The number of times the trainer read the dataset during the run.

  • Num batches read: The number of batches read.

  • Num bytes read: The number of bytes read.

  • Mean throughput: The average read throughput.

Finally, we can query memory statistics (this can be run in the middle of a job) to get an idea of how this workload used the object store.

ray memory --stats-only

As you can see, this run used 18GiB of object store memory, which was 32% of the total object store memory available on the cluster. No disk spilling was reported:

--- Aggregate object store stats across all nodes ---
Plasma memory usage 18554 MiB, 242 objects, 32.98% full, 0.17% needed
Objects consumed by Ray tasks: 38965 MiB.

Debugging the performance problem#

So why was the data ingest only 116MiB/s above? That’s sufficient for many models, but one would expect it to be faster given that the trainer is doing nothing except reading the data. Based on the stats above, there was no object spilling, but there was a high batch delay.

We can guess that perhaps AIR was spending too much time loading blocks from other machines, since we were using a multi-node cluster. We can test this by setting prefetch_batches=10 to prefetch batches more aggressively and rerunning training.
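
Here is a sketch of the re-run, reusing the dataset and preprocessor from the setup above and changing only the prefetch setting (the benchmark that produced the numbers below also read through multiple epochs; only the prefetch change is shown here):

# Re-run ingest with more aggressive prefetching.
trainer = DummyTrainer(
    scaling_config=ScalingConfig(num_workers=1, use_gpu=False),
    datasets={"train": dataset},
    preprocessor=preprocessor,
    prefetch_batches=10,  # Prefetch more batches to hide cross-node fetch latency.
    batch_size=None,  # Use whole blocks as batches.
)
trainer.fit()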

P50/P95/Max batch delay (s) 0.0006792084998323844 0.0009853049503362856 0.12657493300002898
Num epochs read 47
Num batches read 4700
Num bytes read 458984.95 MiB
Mean throughput 15136.18 MiB/s

That’s much better! Now we can see that our DummyTrainer is ingesting data at a rate of 15000MiB/s, and was able to read through many more epochs of training. This high throughput was possible because all of the data fit into memory on a single node.

Going from DummyTrainer to your real Trainer#

Once you’re happy with the ingest performance of the DummyTrainer on synthetic data, the next step is to adapt it for your real workload scenario. This involves:

  • Scaling the DummyTrainer: Change the scaling config of the DummyTrainer and cluster configuration to reflect your target workload.

  • Switching the Dataset: Change the dataset from synthetic tensor data to reading your real dataset.

  • Switching the Trainer: Swap the DummyTrainer with your real trainer.

Switching these components one by one allows performance problems to be easily isolated and reproduced.

Performance Tips#

Memory availability: To maximize ingest performance, consider using machines with sufficient memory to fit the dataset entirely in memory. This avoids the need for disk spilling, streamed ingest, or fetching data across the network. As a rule of thumb, a Ray cluster with fewer but bigger nodes will outperform a Ray cluster with more, smaller nodes due to better memory locality.

Autoscaling: We generally recommend first trying out AIR training with a fixed-size cluster. This makes it easier to understand and debug issues. Once you are happy with performance, you can enable autoscaling, e.g., to scale out experiment sweeps with Tune. We also recommend starting with a single node type when autoscaling. Autoscaling with heterogeneous clusters can optimize costs, but may complicate performance debugging.

Partitioning: By default, Datasets will automatically select the read parallelism based on the current cluster size and number of files. If you run into out-of-memory errors during preprocessing, consider increasing the number of blocks to reduce their size. To increase the max number of partitions, you can manually set the parallelism option when calling ray.data.read_*(). To change the number of partitions at runtime, use ds.repartition(N). As a rule of thumb, blocks should be no more than 1-2GiB each.
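
For example (with a hypothetical S3 path), you might adjust partitioning like this:

import ray

# Hypothetical input path. A higher `parallelism` yields more, smaller blocks.
ds = ray.data.read_parquet("s3://my-bucket/my-data", parallelism=200)

# Alternatively, change the number of blocks after reading.
ds = ds.repartition(100)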

Dataset Sharing#

When you pass Datasets to a Tuner, Datasets are executed independently per-trial. This could potentially duplicate data reads in the cluster. To share Dataset blocks between trials, call ds = ds.materialize() prior to passing the Dataset to the Tuner. This ensures that the initial read operation will not be repeated per trial.
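
A minimal sketch of this pattern with a no-op training loop (Tuner search space and other details elided):

import ray
from ray import tune
from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer

ds = ray.data.range_tensor(1000)
ds = ds.materialize()  # Execute and cache the read once so all trials share the blocks.

trainer = TorchTrainer(
    lambda: None,  # No-op training loop, for illustration only.
    scaling_config=ScalingConfig(num_workers=1),
    datasets={"train": ds},
)
tune.Tuner(trainer).fit()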

FAQ#

How do I pass in a DatasetPipeline to my Trainer?#

The Trainer interface only accepts a standard Dataset and not a DatasetPipeline. Instead, you can configure the ingest via the dataset_config that is passed to your Trainer. Internally, Ray AIR will convert the provided Dataset into a DatasetPipeline with the specified configurations.

See the Enabling Streaming Ingest and Shuffling Data sections for full examples.

How do I shard validation and test datasets?#

By default only the "train" Dataset is sharded. To also shard validation and test datasets, you can configure the dataset_config that is passed to your Trainer. See the Splitting Auxiliary Datasets section for a full example.