ray.train.tensorflow.TensorflowCheckpoint
- class ray.train.tensorflow.TensorflowCheckpoint(*args, **kwargs)
Bases: ray.air.checkpoint.Checkpoint
A Checkpoint with TensorFlow-specific functionality.

Create this from a generic Checkpoint by calling TensorflowCheckpoint.from_checkpoint(ckpt).

PublicAPI (beta): This API is in beta and may change before becoming stable.
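Because from_checkpoint is called on the class itself, it acts as an alternate constructor: calling it on the subclass presumably yields an instance of that subclass wrapping the same checkpoint data. The pure-Python sketch below illustrates that classmethod pattern with two hypothetical classes, GenericCheckpoint and ToyTensorflowCheckpoint. It does not import Ray and is not Ray's implementation.

```python
from typing import Any, Dict, Type, TypeVar

T = TypeVar("T", bound="GenericCheckpoint")


class GenericCheckpoint:
    """Hypothetical stand-in for a framework-agnostic checkpoint (not Ray's class)."""

    def __init__(self, data: Dict[str, Any]):
        self._data = data

    @classmethod
    def from_checkpoint(cls: Type[T], other: "GenericCheckpoint") -> T:
        # `cls` is whichever class the method is called on, so calling it on
        # a subclass re-wraps the same underlying data in that subclass.
        return cls(other._data)

    def to_dict(self) -> Dict[str, Any]:
        return dict(self._data)


class ToyTensorflowCheckpoint(GenericCheckpoint):
    """Hypothetical subclass that adds framework-specific helpers."""

    def framework(self) -> str:
        return "tensorflow"
```

For example, ToyTensorflowCheckpoint.from_checkpoint(GenericCheckpoint({"epoch": 3})) returns a ToyTensorflowCheckpoint whose to_dict() is still {"epoch": 3}, with the extra framework() helper available.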
- classmethod from_model(model: keras.engine.training.Model, *, preprocessor: Optional[Preprocessor] = None) → TensorflowCheckpoint
Create a Checkpoint that stores a Keras model.

The checkpoint created with this method stores only the model's weights, so it must be paired with the model when used: pass the model to get_model() when loading.

- Parameters
model – The Keras model whose weights are stored in the checkpoint.
preprocessor – A fitted preprocessor to be applied before inference.
- Returns
A TensorflowCheckpoint containing the specified model's weights.
Examples
>>> from ray.train.tensorflow import TensorflowCheckpoint
>>> import tensorflow as tf
>>>
>>> model = tf.keras.applications.resnet.ResNet101()
>>> checkpoint = TensorflowCheckpoint.from_model(model)
- classmethod from_h5(file_path: str, *, preprocessor: Optional[Preprocessor] = None) → TensorflowCheckpoint
Create a Checkpoint that stores a Keras model from H5 format.

The checkpoint generated by this method contains all the information needed, so no model needs to be supplied when using this checkpoint.

file_path must remain valid after this function returns. As a side effect, this method may add new files or directories to the parent directory of file_path.

- Parameters
file_path – The path to the .h5 file to load the model from. This is the same path used for model.save(path).
preprocessor – A fitted preprocessor to be applied before inference.
- Returns
A TensorflowCheckpoint converted from H5 format.
Examples
>>> import tensorflow as tf
>>> import ray
>>> from ray.train.batch_predictor import BatchPredictor
>>> from ray.train.tensorflow import (
...     TensorflowCheckpoint, TensorflowTrainer, TensorflowPredictor
... )
>>> from ray.air import session
>>> from ray.air.config import ScalingConfig

>>> def train_func():
...     model = tf.keras.Sequential(
...         [
...             tf.keras.layers.InputLayer(input_shape=()),
...             tf.keras.layers.Flatten(),
...             tf.keras.layers.Dense(10),
...             tf.keras.layers.Dense(1),
...         ]
...     )
...     model.save("my_model.h5")
...     checkpoint = TensorflowCheckpoint.from_h5("my_model.h5")
...     session.report({"my_metric": 1}, checkpoint=checkpoint)

>>> trainer = TensorflowTrainer(
...     train_loop_per_worker=train_func,
...     scaling_config=ScalingConfig(num_workers=2))

>>> result_checkpoint = trainer.fit().checkpoint

>>> batch_predictor = BatchPredictor.from_checkpoint(
...     result_checkpoint, TensorflowPredictor)
>>> batch_predictor.predict(ray.data.range(3))
- classmethod from_saved_model(dir_path: str, *, preprocessor: Optional[Preprocessor] = None) → TensorflowCheckpoint
Create a Checkpoint that stores a Keras model from SavedModel format.

The checkpoint generated by this method contains all the information needed, so no model needs to be supplied when using this checkpoint.

dir_path must remain valid after this function returns. As a side effect, this method may add new files or directories to dir_path.

- Parameters
dir_path – The directory containing the saved model. This is the same directory passed to model.save(dir_path).
preprocessor – A fitted preprocessor to be applied before inference.
- Returns
A TensorflowCheckpoint converted from SavedModel format.
Examples
>>> import tensorflow as tf
>>> import ray
>>> from ray.train.batch_predictor import BatchPredictor
>>> from ray.train.tensorflow import (
...     TensorflowCheckpoint, TensorflowTrainer, TensorflowPredictor)
>>> from ray.air import session
>>> from ray.air.config import ScalingConfig

>>> def train_fn():
...     model = tf.keras.Sequential(
...         [
...             tf.keras.layers.InputLayer(input_shape=()),
...             tf.keras.layers.Flatten(),
...             tf.keras.layers.Dense(10),
...             tf.keras.layers.Dense(1),
...         ])
...     model.save("my_model")
...     checkpoint = TensorflowCheckpoint.from_saved_model("my_model")
...     session.report({"my_metric": 1}, checkpoint=checkpoint)

>>> trainer = TensorflowTrainer(
...     train_loop_per_worker=train_fn,
...     scaling_config=ScalingConfig(num_workers=2))

>>> result_checkpoint = trainer.fit().checkpoint

>>> batch_predictor = BatchPredictor.from_checkpoint(
...     result_checkpoint, TensorflowPredictor)
>>> batch_predictor.predict(ray.data.range(3))
- get_model(model: Optional[Union[keras.engine.training.Model, Callable[[], keras.engine.training.Model]]] = None, model_definition: Optional[Callable[[], keras.engine.training.Model]] = None) → keras.engine.training.Model
Retrieve the model stored in this checkpoint.
- Parameters
model – Required only if the original checkpoint was created via TensorflowCheckpoint.from_model. Either a Keras model or a zero-argument callable that returns one.
model_definition – Deprecated. Use model instead.
- Returns
The TensorFlow Keras model stored in the checkpoint.
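The rules above (pass model only for checkpoints created with from_model, model may be an instance or a zero-argument factory, and model_definition is deprecated in favour of model) can be sketched in plain Python. The sketch uses hypothetical ToyModel and ToyCheckpoint classes as stand-ins; it imports neither Ray nor TensorFlow and is not Ray's implementation. One design point it illustrates: since Keras models are themselves callable, code that accepts "a model or a factory" should test for the model type before calling the argument.

```python
import warnings
from typing import Callable, List, Optional, Union


class ToyModel:
    """Hypothetical stand-in for a Keras model; NOT tf.keras.Model."""

    def __init__(self, weights: Optional[List[float]] = None):
        self.weights = list(weights or [])

    def __call__(self, x: float) -> float:
        # Like Keras models, instances are callable (a forward pass).
        return sum(w * x for w in self.weights)


class ToyCheckpoint:
    """Hypothetical sketch of the get_model contract; not Ray's code."""

    def __init__(self, weights: List[float], full_model: Optional[ToyModel] = None):
        self._weights = list(weights)
        # Only set for checkpoints that also carry the architecture
        # (the role H5 / SavedModel files play in the real API).
        self._full_model = full_model

    @classmethod
    def from_model(cls, model: ToyModel) -> "ToyCheckpoint":
        # Weights only, so get_model() will need a model to load them into.
        return cls(model.weights)

    @classmethod
    def from_full_model(cls, model: ToyModel) -> "ToyCheckpoint":
        # Hypothetical analogue of from_h5 / from_saved_model.
        return cls(model.weights, full_model=model)

    def get_model(
        self,
        model: Optional[Union[ToyModel, Callable[[], ToyModel]]] = None,
        model_definition: Optional[Callable[[], ToyModel]] = None,
    ) -> ToyModel:
        if model_definition is not None:
            warnings.warn("model_definition is deprecated; use model", DeprecationWarning)
            if model is None:
                model = model_definition
        if self._full_model is not None:
            # Self-contained checkpoint: no model argument is needed.
            return self._full_model
        if model is None:
            raise ValueError("weights-only checkpoint: pass `model` to get_model()")
        # Check the type before calling: a model is callable too, so
        # callable(model) alone cannot tell a model from a factory.
        if not isinstance(model, ToyModel):
            model = model()
        model.weights = list(self._weights)
        return model
```

In the real API, the weights-only case corresponds to calling checkpoint.get_model(model) on a checkpoint produced by TensorflowCheckpoint.from_model, and the self-contained case to calling checkpoint.get_model() on one produced by from_h5 or from_saved_model.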