# Source code for ray.tune.integration.mxnet

from typing import Dict, List, Union, Optional

from ray import tune

import mxnet
from mxnet.model import save_checkpoint, BatchEndParam
import numpy as np

import os


class TuneCallback:
    """Common ancestor of Tune's MXNet callbacks.

    The class adds no attributes or methods of its own; it only provides a
    shared base type for the callbacks defined in this module.
    """


class TuneReportCallback(TuneCallback):
    """MXNet to Ray Tune reporting callback

    Reports metrics to Ray Tune.

    This has to be passed to MXNet as the ``eval_end_callback``.

    Args:
        metrics: Metrics to report to Tune. If this is a list, each item
            describes the metric key reported to MXNet, and it will be
            reported under the same name to Tune. If this is a dict, each
            key will be the name reported to Tune and the respective value
            will be the metric key reported to MXNet. A single string is
            treated as a one-item list. If ``None`` (the default) or empty,
            every metric provided by MXNet is reported under its own name.

    Raises:
        KeyError: When called, if a requested metric key is not among the
            metrics provided by MXNet.

    Example:

    .. code-block:: python

        from ray.tune.integration.mxnet import TuneReportCallback

        # mlp_model is a MXNet model
        mlp_model.fit(
            train_iter,
            # ...
            eval_metric="acc",
            eval_end_callback=TuneReportCallback({
                "mean_accuracy": "accuracy"
            }))

    """

    def __init__(
        self, metrics: Optional[Union[str, List[str], Dict[str, str]]] = None
    ):
        # A bare string would otherwise be iterated character by character.
        if isinstance(metrics, str):
            metrics = [metrics]
        self._metrics = metrics

    def __call__(self, param: BatchEndParam):
        """Report the evaluation metrics carried by ``param`` to Tune.

        Nothing is reported when ``param.eval_metric`` is falsy.
        """
        if not param.eval_metric:
            return

        # (name, value) pairs as provided by MXNet, keyed by metric name.
        available = dict(param.eval_metric.get_name_value())

        if not self._metrics:
            # No selection configured: forward everything unchanged.
            report_dict = available
        else:
            report_dict = {}
            renamed = isinstance(self._metrics, dict)
            for key in self._metrics:
                # Dict form maps Tune name -> MXNet name; list form uses
                # the same name on both sides.
                metric = self._metrics[key] if renamed else key
                report_dict[key] = available[metric]

        tune.report(**report_dict)
class TuneCheckpointCallback(TuneCallback):
    """MXNet checkpoint callback

    Saves checkpoints after each epoch. This has to be passed to the
    ``epoch_end_callback`` of the MXNet model.

    Checkpoints are currently not registered if no ``tune.report()`` call
    is made afterwards. You have to use this in conjunction with the
    ``TuneReportCallback`` to work!

    Args:
        filename: Filename of the checkpoint within the checkpoint
            directory. Defaults to "checkpoint".
        frequency: Integer indicating how often checkpoints should be
            saved. A checkpoint is written for every epoch whose number is
            a multiple of this value, so it must not be zero.

    Example:

    .. code-block:: python

        from ray.tune.integration.mxnet import (
            TuneReportCallback,
            TuneCheckpointCallback,
        )

        # mlp_model is a MXNet model
        mlp_model.fit(
            train_iter,
            # ...
            eval_metric="acc",
            eval_end_callback=TuneReportCallback({
                "mean_accuracy": "accuracy"
            }),
            epoch_end_callback=TuneCheckpointCallback(
                filename="mxnet_cp", frequency=3
            ))

    """

    def __init__(self, filename: str = "checkpoint", frequency: int = 1):
        self._filename = filename
        self._frequency = frequency

    def __call__(
        self,
        epoch: int,
        sym: mxnet.symbol.Symbol,
        arg: Dict[str, np.ndarray],
        aux: Dict[str, np.ndarray],
    ):
        """Save a checkpoint for ``epoch`` if it falls on the frequency.

        ``sym``, ``arg`` and ``aux`` are forwarded unchanged to MXNet's
        ``save_checkpoint``.
        """
        # Skip epochs that are not a multiple of the configured frequency.
        if epoch % self._frequency != 0:
            return

        with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
            save_checkpoint(
                os.path.join(checkpoint_dir, self._filename),
                epoch,
                sym,
                arg,
                aux,
            )