ray.train.tensorflow.TensorflowCheckpoint

class ray.train.tensorflow.TensorflowCheckpoint(*args, **kwargs)

Bases: ray.air.checkpoint.Checkpoint

A Checkpoint with TensorFlow-specific functionality.

Create this from a generic Checkpoint by calling TensorflowCheckpoint.from_checkpoint(ckpt).
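The following is a toy, pure-Python model of the conversion pattern. The class names are hypothetical and this is not Ray's implementation. It only shows the idea that a framework-specific subclass can re-wrap the same stored data as a generic checkpoint and then add helpers on top of it.

```python
class GenericCheckpoint:
    """Toy stand-in for a generic checkpoint that wraps some stored data."""

    def __init__(self, data: dict):
        self._data = data

    @classmethod
    def from_checkpoint(cls, other: "GenericCheckpoint") -> "GenericCheckpoint":
        # Re-wrap the same underlying data as an instance of `cls`.
        return cls(other._data)


class ToyTensorflowCheckpoint(GenericCheckpoint):
    """Toy subclass that adds framework-specific helpers over the same data."""

    def describe(self) -> str:
        return f"TensorFlow checkpoint with keys {sorted(self._data)}"


generic = GenericCheckpoint({"weights": [1.0, 2.0]})
tf_ckpt = ToyTensorflowCheckpoint.from_checkpoint(generic)
print(tf_ckpt.describe())  # TensorFlow checkpoint with keys ['weights']
```

No data is copied here. The subclass instance shares the original's stored data.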

PublicAPI (beta): This API is in beta and may change before becoming stable.

class Flavor(value)

Bases: enum.Enum

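An enumeration of the formats ("flavors") in which a TensorflowCheckpoint can store a model. The flavor records how the checkpoint was created, and therefore how the model is loaded back from it.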
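The sketch below mirrors the three constructors on this class. The member names MODEL_WEIGHTS, SAVED_MODEL and H5 are assumptions made for illustration, and so is the helper needs_model_to_load. The sketch captures the one distinction the constructors document: only a weights-only checkpoint needs a model to be supplied at load time.

```python
from enum import Enum


class Flavor(Enum):
    """Hypothetical mirror of TensorflowCheckpoint.Flavor (member names assumed)."""

    MODEL_WEIGHTS = 1  # from_model(): only the weights are stored
    SAVED_MODEL = 2  # from_saved_model(): a full SavedModel directory
    H5 = 3  # from_h5(): a full model in a single .h5 file


def needs_model_to_load(flavor: Flavor) -> bool:
    """Only a weights-only checkpoint requires the caller to supply the architecture."""
    return flavor is Flavor.MODEL_WEIGHTS


for flavor in Flavor:
    print(flavor.name, needs_model_to_load(flavor))
```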

classmethod from_model(model: keras.engine.training.Model, *, preprocessor: Optional[Preprocessor] = None) -> TensorflowCheckpoint

Create a Checkpoint that stores a Keras model.

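A checkpoint created with this method stores only the model's weights. The model (or a function that builds it) must therefore be supplied again when the checkpoint is used, for example through the model argument of get_model().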
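The toy sketch below shows why a weights-only checkpoint must be paired with a model. ToyModel, save_weights_only and load_weights_only are hypothetical stand-ins and are not part of Ray or Keras. Because the checkpoint holds only weights, the caller has to rebuild a matching architecture before those weights can be restored.

```python
import copy
from typing import Callable, List


class ToyModel:
    """Stand-in for a Keras model: a fixed architecture plus a list of weights."""

    def __init__(self, units: int = 2):
        self.weights: List[float] = [0.0] * units

    def get_weights(self) -> List[float]:
        return copy.deepcopy(self.weights)

    def set_weights(self, weights: List[float]) -> None:
        if len(weights) != len(self.weights):
            raise ValueError("weights do not match the model architecture")
        self.weights = copy.deepcopy(weights)


def save_weights_only(model: ToyModel) -> dict:
    # Only the weights are persisted; the architecture is not.
    return {"weights": model.get_weights()}


def load_weights_only(ckpt: dict, build_model: Callable[[], ToyModel]) -> ToyModel:
    # The caller must rebuild the same architecture before restoring weights.
    model = build_model()
    model.set_weights(ckpt["weights"])
    return model


trained = ToyModel()
trained.set_weights([0.5, -1.0])
ckpt = save_weights_only(trained)
restored = load_weights_only(ckpt, ToyModel)
print(restored.weights)  # [0.5, -1.0]
```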

Parameters
  • model – The Keras model whose weights are stored in the checkpoint.

  • preprocessor – A fitted preprocessor to be applied before inference.

Returns

A TensorflowCheckpoint containing the specified model's weights.

Examples

>>> from ray.train.tensorflow import TensorflowCheckpoint
>>> import tensorflow as tf
>>>
>>> model = tf.keras.applications.resnet.ResNet101()  
>>> checkpoint = TensorflowCheckpoint.from_model(model)  

classmethod from_h5(file_path: str, *, preprocessor: Optional[Preprocessor] = None) -> TensorflowCheckpoint

Create a Checkpoint that stores a Keras model saved in H5 format.

The checkpoint generated by this method contains all the information needed to restore the model. No model has to be supplied when using this checkpoint.

file_path must remain valid after this function returns. As a side effect, this method may add new files or directories to the parent directory of file_path.
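The toy sketch below shows why file_path has to stay valid. LazyFileCheckpoint is hypothetical and is not Ray's implementation. It assumes a checkpoint that records the path and reads the file only later, so deleting the file (here, by letting a temporary directory be cleaned up) breaks the checkpoint after construction has already succeeded.

```python
import os
import tempfile


class LazyFileCheckpoint:
    """Hypothetical checkpoint that stores only a path and reads the file later."""

    def __init__(self, file_path: str):
        self.file_path = file_path

    def read(self) -> bytes:
        # The file is opened only now, possibly long after construction.
        with open(self.file_path, "rb") as f:
            return f.read()


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "my_model.h5")
    with open(path, "wb") as f:
        f.write(b"placeholder, not a real H5 file")
    ckpt = LazyFileCheckpoint(path)
    # Fine: the directory still exists while the checkpoint is used.
    first_read = ckpt.read()

# The temporary directory is gone, so the recorded path is no longer valid.
try:
    ckpt.read()
    still_valid = True
except FileNotFoundError:
    still_valid = False
print(still_valid)  # False
```

Because this method may also add files next to file_path, one reasonable choice is to give the .h5 file a directory of its own.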

Parameters
  • file_path – The path to the .h5 file to load the model from. This is the same path that is passed to model.save(path).

  • preprocessor – A fitted preprocessor to be applied before inference.

Returns

A TensorflowCheckpoint converted from H5 format.

Examples

>>> import tensorflow as tf
>>> import ray
>>> from ray.train.batch_predictor import BatchPredictor
>>> from ray.train.tensorflow import (
...     TensorflowCheckpoint, TensorflowTrainer, TensorflowPredictor
... )
>>> from ray.air import session
>>> from ray.air.config import ScalingConfig
>>> def train_func():
...     model = tf.keras.Sequential(
...         [
...             tf.keras.layers.InputLayer(input_shape=()),
...             tf.keras.layers.Flatten(),
...             tf.keras.layers.Dense(10),
...             tf.keras.layers.Dense(1),
...         ]
...     )
...     model.save("my_model.h5")
...     checkpoint = TensorflowCheckpoint.from_h5("my_model.h5")
...     session.report({"my_metric": 1}, checkpoint=checkpoint)
>>> trainer = TensorflowTrainer(
...     train_loop_per_worker=train_func,
...     scaling_config=ScalingConfig(num_workers=2))
>>> result_checkpoint = trainer.fit().checkpoint  
>>> batch_predictor = BatchPredictor.from_checkpoint(
...     result_checkpoint, TensorflowPredictor)  
>>> batch_predictor.predict(ray.data.range(3))  

classmethod from_saved_model(dir_path: str, *, preprocessor: Optional[Preprocessor] = None) -> TensorflowCheckpoint

Create a Checkpoint that stores a Keras model saved in SavedModel format.

The checkpoint generated by this method contains all the information needed to restore the model. No model has to be supplied when using this checkpoint.

dir_path must remain valid after this function returns. As a side effect, this method may add new files or directories to dir_path.

Parameters
  • dir_path – The directory containing the saved model. This is the same directory as used by model.save(dir_path).

  • preprocessor – A fitted preprocessor to be applied before inference.
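The from_h5 and from_saved_model examples depend on how model.save() chooses a format. In TF 2.x with Keras 2, a path ending in .h5 produces a single HDF5 file, and a plain path such as "my_model" produces a SavedModel directory. The .keras suffix and Keras 3 behave differently. The helper suggest_constructor below is hypothetical and is not a Ray API. It picks the matching constructor name from what exists on disk.

```python
import os
import tempfile


def suggest_constructor(path: str) -> str:
    """Hypothetical helper: name the TensorflowCheckpoint constructor that fits a saved artifact."""
    if os.path.isdir(path):
        # A directory, as written by model.save("my_model") in SavedModel format.
        return "from_saved_model"
    if os.path.isfile(path) and path.lower().endswith(".h5"):
        # A single HDF5 file, as written by model.save("my_model.h5").
        return "from_h5"
    raise ValueError(f"unrecognized model artifact: {path!r}")


with tempfile.TemporaryDirectory() as tmp:
    saved_model_dir = os.path.join(tmp, "my_model")
    os.makedirs(saved_model_dir)
    h5_file = os.path.join(tmp, "my_model.h5")
    open(h5_file, "wb").close()  # empty placeholder file
    results = (suggest_constructor(saved_model_dir), suggest_constructor(h5_file))

print(results)  # ('from_saved_model', 'from_h5')
```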

Returns

A TensorflowCheckpoint converted from SavedModel format.

Examples

>>> import tensorflow as tf
>>> import ray
>>> from ray.train.batch_predictor import BatchPredictor
>>> from ray.train.tensorflow import (
...     TensorflowCheckpoint, TensorflowTrainer, TensorflowPredictor)
>>> from ray.air import session
>>> from ray.air.config import ScalingConfig
>>> def train_fn():
...     model = tf.keras.Sequential(
...         [
...             tf.keras.layers.InputLayer(input_shape=()),
...             tf.keras.layers.Flatten(),
...             tf.keras.layers.Dense(10),
...             tf.keras.layers.Dense(1),
...         ])
...     model.save("my_model")
...     checkpoint = TensorflowCheckpoint.from_saved_model("my_model")
...     session.report({"my_metric": 1}, checkpoint=checkpoint)
>>> trainer = TensorflowTrainer(
...     train_loop_per_worker=train_fn,
...     scaling_config=ScalingConfig(num_workers=2))
>>> result_checkpoint = trainer.fit().checkpoint  
>>> batch_predictor = BatchPredictor.from_checkpoint(
...     result_checkpoint, TensorflowPredictor)  
>>> batch_predictor.predict(ray.data.range(3))  

get_model(model: Optional[Union[keras.engine.training.Model, Callable[[], keras.engine.training.Model]]] = None, model_definition: Optional[Callable[[], keras.engine.training.Model]] = None) -> keras.engine.training.Model

Retrieve the model stored in this checkpoint.

Parameters
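  • model – Either a Keras model or a zero-argument callable that returns one. This argument is expected only if the checkpoint was created with TensorflowCheckpoint.from_model, because such a checkpoint stores only the model's weights.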

  • model_definition – This parameter is deprecated. Use model instead.
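The model parameter accepts either a model or a factory that builds one. The toy sketch below shows one way such an argument can be resolved. The Model class and resolve_model are hypothetical stand-ins and are not Ray's code. One subtlety matters: a Keras model instance is itself callable, so the sketch checks isinstance before treating the argument as a factory. It also routes the deprecated model_definition into model.

```python
import warnings
from typing import Callable, Optional, Union


class Model:
    """Stand-in for keras.Model. Like a real Keras model, an instance is callable."""

    def __call__(self, inputs):
        return inputs


ModelOrFactory = Union[Model, Callable[[], Model]]


def resolve_model(
    model: Optional[ModelOrFactory] = None,
    model_definition: Optional[Callable[[], Model]] = None,
) -> Model:
    if model_definition is not None:
        # Deprecated alias: warn, and fall back to it only if `model` is unset.
        warnings.warn("model_definition is deprecated; use model instead", DeprecationWarning)
        if model is None:
            model = model_definition
    if model is None:
        raise ValueError("a weights-only checkpoint needs a model or a model factory")
    # Check isinstance first: a model instance is callable, so testing callable()
    # alone would wrongly invoke it (with no inputs) as if it were a factory.
    if isinstance(model, Model):
        return model
    return model()


existing = Model()
assert resolve_model(existing) is existing  # an instance is passed through
assert isinstance(resolve_model(lambda: Model()), Model)  # a factory is called
```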

Returns

The TensorFlow Keras model stored in the checkpoint.