Source code for ray.air.integrations.keras

from typing import Dict, List, Optional, Union

from tensorflow.keras.callbacks import Callback as KerasCallback

from ray.air import session
from ray.train.tensorflow import TensorflowCheckpoint
from ray.util.annotations import PublicAPI, Deprecated


class _Callback(KerasCallback):
    """Base class for Air's Keras callbacks."""

    _allowed = [
        "epoch_begin",
        "epoch_end",
        "train_batch_begin",
        "train_batch_end",
        "test_batch_begin",
        "test_batch_end",
        "predict_batch_begin",
        "predict_batch_end",
        "train_begin",
        "train_end",
        "test_begin",
        "test_end",
        "predict_begin",
        "predict_end",
    ]

    def __init__(self, on: Union[str, List[str]] = "validation_end"):
        super(_Callback, self).__init__()

        if not isinstance(on, list):
            on = [on]
        if any(w not in self._allowed for w in on):
            raise ValueError(
                "Invalid trigger time selected: {}. Must be one of {}".format(
                    on, self._allowed
                )
            )
        self._on = on

    def _handle(self, logs: Dict, when: str):
        raise NotImplementedError

    def on_epoch_begin(self, epoch, logs=None):
        if "epoch_begin" in self._on:
            self._handle(logs, "epoch_begin")

    def on_epoch_end(self, epoch, logs=None):
        if "epoch_end" in self._on:
            self._handle(logs, "epoch_end")

    def on_train_batch_begin(self, batch, logs=None):
        if "train_batch_begin" in self._on:
            self._handle(logs, "train_batch_begin")

    def on_train_batch_end(self, batch, logs=None):
        if "train_batch_end" in self._on:
            self._handle(logs, "train_batch_end")

    def on_test_batch_begin(self, batch, logs=None):
        if "test_batch_begin" in self._on:
            self._handle(logs, "test_batch_begin")

    def on_test_batch_end(self, batch, logs=None):
        if "test_batch_end" in self._on:
            self._handle(logs, "test_batch_end")

    def on_predict_batch_begin(self, batch, logs=None):
        if "predict_batch_begin" in self._on:
            self._handle(logs, "predict_batch_begin")

    def on_predict_batch_end(self, batch, logs=None):
        if "predict_batch_end" in self._on:
            self._handle(logs, "predict_batch_end")

    def on_train_begin(self, logs=None):
        if "train_begin" in self._on:
            self._handle(logs, "train_begin")

    def on_train_end(self, logs=None):
        if "train_end" in self._on:
            self._handle(logs, "train_end")

    def on_test_begin(self, logs=None):
        if "test_begin" in self._on:
            self._handle(logs, "test_begin")

    def on_test_end(self, logs=None):
        if "test_end" in self._on:
            self._handle(logs, "test_end")

    def on_predict_begin(self, logs=None):
        if "predict_begin" in self._on:
            self._handle(logs, "predict_begin")

    def on_predict_end(self, logs=None):
        if "predict_end" in self._on:
            self._handle(logs, "predict_end")


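# Illustrative sketch (not part of the original module): a minimal `_Callback`
# subclass showing how the `on` dispatch above is intended to be used. The base
# class routes each Keras hook named in `on` to `_handle`; the class name
# `_PrintLossCallback` is hypothetical.
class _PrintLossCallback(_Callback):
    def __init__(self):
        # Only fire on the hooks we care about.
        super().__init__(on=["epoch_end", "train_end"])

    def _handle(self, logs: Dict, when: str):
        # `logs` is the Keras logs dict for the hook named by `when`.
        print(when, (logs or {}).get("loss"))

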
@PublicAPI(stability="alpha")
class ReportCheckpointCallback(_Callback):
    """Keras callback for Ray AIR reporting and checkpointing.

    .. note::
        Metrics are always reported with checkpoints, even if the event isn't
        specified in ``report_metrics_on``.

    Example:
        .. code-block:: python

            ############# Using it in TrainSession ###############
            from ray.air.integrations.keras import ReportCheckpointCallback

            def train_loop_per_worker():
                strategy = tf.distribute.MultiWorkerMirroredStrategy()
                with strategy.scope():
                    model = build_model()

                model.fit(dataset_shard, callbacks=[ReportCheckpointCallback()])

    Args:
        metrics: Metrics to report. If this is a list, each item describes the
            metric key reported to Keras, and it's reported under the same name.
            If this is a dict, each key is the name reported and the respective
            value is the metric key reported to Keras. If this is None, all Keras
            logs are reported.
        report_metrics_on: When to report metrics. Must be one of the Keras event
            hooks (less the ``on_``), e.g. "train_start" or "predict_end". Defaults
            to "epoch_end".
        checkpoint_on: When to save checkpoints. Must be one of the Keras event
            hooks (less the ``on_``), e.g. "train_start" or "predict_end". Defaults
            to "epoch_end".
    """

    def __init__(
        self,
        checkpoint_on: Union[str, List[str]] = "epoch_end",
        report_metrics_on: Union[str, List[str]] = "epoch_end",
        metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
    ):
        if isinstance(checkpoint_on, str):
            checkpoint_on = [checkpoint_on]
        if isinstance(report_metrics_on, str):
            report_metrics_on = [report_metrics_on]

        on = list(set(checkpoint_on + report_metrics_on))
        super().__init__(on=on)

        self._checkpoint_on: List[str] = checkpoint_on
        self._report_metrics_on: List[str] = report_metrics_on
        self._metrics = metrics

    def _handle(self, logs: Dict, when: str):
        assert when in self._checkpoint_on or when in self._report_metrics_on

        metrics = self._get_reported_metrics(logs)

        if when in self._checkpoint_on:
            checkpoint = TensorflowCheckpoint.from_model(self.model)
        else:
            checkpoint = None

        session.report(metrics, checkpoint=checkpoint)

    def _get_reported_metrics(self, logs: Dict) -> Dict:
        assert isinstance(self._metrics, (type(None), str, list, dict))

        if self._metrics is None:
            reported_metrics = logs
        elif isinstance(self._metrics, str):
            reported_metrics = {self._metrics: logs[self._metrics]}
        elif isinstance(self._metrics, list):
            reported_metrics = {metric: logs[metric] for metric in self._metrics}
        elif isinstance(self._metrics, dict):
            reported_metrics = {
                key: logs[metric] for key, metric in self._metrics.items()
            }

        assert isinstance(reported_metrics, dict)
        return reported_metrics

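
# Illustrative usage (not part of the original module): remapping Keras log keys via
# the `metrics` dict and reporting metrics on more hooks than are checkpointed.
# `model` and `dataset_shard` are placeholders assumed to come from the caller's
# training loop.
def _example_report_checkpoint_usage(model, dataset_shard):
    callback = ReportCheckpointCallback(
        checkpoint_on="epoch_end",
        report_metrics_on=["epoch_end", "train_end"],
        # Report Keras' "loss" under the key "train_loss".
        metrics={"train_loss": "loss"},
    )
    model.fit(dataset_shard, callbacks=[callback])

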
@Deprecated
class Callback(_Callback):
    """
    Keras callback for Ray AIR reporting and checkpointing.

    You can use this in both TuneSession and TrainSession.

    Example:
        .. code-block:: python

            ############# Using it in TrainSession ###############
            from ray.air.integrations.keras import Callback

            def train_loop_per_worker():
                strategy = tf.distribute.MultiWorkerMirroredStrategy()
                with strategy.scope():
                    model = build_model()
                    # model.compile(...)

                model.fit(dataset_shard, callbacks=[Callback()])

    Args:
        metrics: Metrics to report. If this is a list, each item describes the
            metric key reported to Keras, and it will be reported under the same
            name. If this is a dict, each key will be the name reported and the
            respective value will be the metric key reported to Keras. If this is
            None, all Keras logs will be reported.
        on: When to report metrics. Must be one of the Keras event hooks (less the
            ``on_``), e.g. "train_start" or "predict_end". Defaults to "epoch_end".
        frequency: Checkpoint frequency. If this is an integer ``n``, checkpoints
            are saved every ``n`` times each hook was called. If this is a list,
            it specifies the checkpoint frequencies for each hook individually.
    """

    def __init__(
        self,
        metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
        on: Union[str, List[str]] = "epoch_end",
        frequency: Union[int, List[int]] = 1,
    ):
        # TODO: Remove this class in 2.6.
        raise DeprecationWarning(
            "`ray.air.integrations.keras.Callback` is deprecated. Use "
            "`ray.air.integrations.keras.ReportCheckpointCallback` instead.",
        )
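

# Migration sketch (illustrative, not part of the original module): the deprecated
# `Callback(metrics=..., on=...)` maps onto `ReportCheckpointCallback` roughly as
# below; the old `frequency` argument has no direct equivalent. The metric mapping
# `{"train_loss": "loss"}` is a hypothetical example.
def _example_migrated_callback():
    # Before (deprecated): Callback(metrics={"train_loss": "loss"}, on="epoch_end")
    return ReportCheckpointCallback(
        metrics={"train_loss": "loss"},
        report_metrics_on="epoch_end",
        checkpoint_on="epoch_end",
    )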