Source code for ray.train.tensorflow.train_loop_utils

import tensorflow as tf

from ray.util.annotations import PublicAPI

@PublicAPI(stability="beta")
def prepare_dataset_shard(tf_dataset_shard: tf.data.Dataset) -> tf.data.Dataset:
    """A utility function that overrides default config for Tensorflow Dataset.

    This should be used on a TensorFlow ``Dataset`` created by calling
    ``iter_tf_batches()`` on the Ray Data object returned by
    ``ray.train.get_dataset_shard()``, since that data has already been
    sharded across the workers.

    Args:
        tf_dataset_shard (tf.data.Dataset): A TensorFlow Dataset.

    Returns:
        A TensorFlow Dataset with:
            - autosharding turned off
            - prefetching turned on with autotune enabled
    """
    options = tf.data.Options()
    # The shard is already per-worker; letting tf.distribute auto-shard it
    # again would split this worker's data a second time.
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.OFF
    )
    # AUTOTUNE lets the tf.data runtime choose the prefetch buffer size.
    return tf_dataset_shard.with_options(options).prefetch(tf.data.AUTOTUNE)