# Source code for ray.train.v2.api.callback
from typing import Any, Dict, List, Optional
from ray.train import Checkpoint
from ray.train.v2._internal.execution.context import TrainRunContext
from ray.util.annotations import DeveloperAPI
@DeveloperAPI
class RayTrainCallback:
    """Common base type shared by every Ray Train callback.

    This class declares no hooks of its own. More specific callback
    interfaces (such as ``UserCallback``) subclass it and add their hooks.
    """
@DeveloperAPI
class UserCallback(RayTrainCallback):
    """Callback interface for custom user-defined callbacks that handle events
    during training.

    This callback is called on the Ray Train controller process, not on the
    worker processes.

    Every hook below has a no-op default implementation, so a subclass only
    needs to override the hooks it is interested in.
    """

    def after_report(
        self,
        run_context: TrainRunContext,
        metrics: List[Dict[str, Any]],
        checkpoint: Optional[Checkpoint],
    ) -> None:
        """Called after all workers have reported a metric + checkpoint
        via `ray.train.report`.

        Args:
            run_context: The `TrainRunContext` for the current training run.
            metrics: A list of metric dictionaries reported by workers,
                where metrics[i] is the metrics dict reported by worker i.
            checkpoint: A Checkpoint object that has been persisted to
                storage. This is None if no workers reported a checkpoint
                (e.g. `ray.train.report(metrics, checkpoint=None)`).
        """
        # Default hook is a no-op; subclasses override to react to reports.
        pass

    def after_exception(
        self, run_context: TrainRunContext, worker_exceptions: Dict[int, Exception]
    ) -> None:
        """Called after one or more workers have raised an exception.

        Args:
            run_context: The `TrainRunContext` for the current training run.
            worker_exceptions: A dict from worker world rank to the exception
                raised by that worker.
        """
        # Default hook is a no-op; subclasses override to handle failures.
        pass