Source code for ray.train.tensorflow.train_loop_utils

import tensorflow as tf

from ray.util.annotations import PublicAPI


@PublicAPI(stability="beta")
def prepare_dataset_shard(
    tf_dataset_shard: tf.data.Dataset,
) -> tf.data.Dataset:
    """A utility function that overrides default config for Tensorflow Dataset.

    This should be used on a TensorFlow ``Dataset`` created by calling
    ``iter_tf_batches()`` on a ``ray.data.Dataset`` returned by
    ``ray.train.get_dataset_shard()`` since the dataset has already
    been sharded across the workers.

    Args:
        tf_dataset_shard: A TensorFlow Dataset.

    Returns:
        A TensorFlow Dataset with:
            - autosharding turned off
            - prefetching turned on with autotune enabled
    """
    options = tf.data.Options()
    # The input is already a per-worker shard (see docstring), so the
    # auto-shard policy is switched off rather than left at its default.
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.OFF
    )
    # Apply the options first, then prefetch with an autotuned buffer size.
    return tf_dataset_shard.with_options(options).prefetch(tf.data.AUTOTUNE)