# Source code for ray.train.tensorflow.train_loop_utils
import tensorflow as tf
from ray.util.annotations import PublicAPI
[docs]
@PublicAPI(stability="beta")
def prepare_dataset_shard(tf_dataset_shard: tf.data.Dataset):
    """Reconfigure an already-sharded TensorFlow Dataset for training.

    Apply this to the TensorFlow ``Dataset`` produced by calling
    ``iter_tf_batches()`` on the ``ray.data.Dataset`` returned by
    ``ray.train.get_dataset_shard()``. Ray has already split that data
    across the workers, so TensorFlow must not shard it a second time.

    Args:
        tf_dataset_shard (tf.data.Dataset): A TensorFlow Dataset.

    Returns:
        A TensorFlow Dataset on which:
            - autosharding is turned off
            - prefetching is turned on with autotune enabled
    """
    shard_options = tf.data.Options()
    # The data was sharded by Ray before reaching this worker, so turn off
    # TensorFlow's own auto-sharding to avoid splitting it again.
    shard_options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.OFF
    )
    unsharded_view = tf_dataset_shard.with_options(shard_options)
    # AUTOTUNE leaves the prefetch buffer size to the tf.data runtime.
    return unsharded_view.prefetch(tf.data.AUTOTUNE)