ray.tune.Callback.get_state

Callback.get_state() → Dict | None

Get the state of the callback.

This method should be implemented by subclasses to return a dictionary representation of the object’s current state.

This is called automatically by Tune to periodically checkpoint callback state. Upon Tune experiment restoration, callback state will be restored via set_state().
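The save/restore cycle described above can be illustrated without Ray. The following is a minimal sketch under stated assumptions. TrialCounterCallback and checkpoint_and_restore are hypothetical names. The class does not subclass the real ray.tune.Callback. pickle only stands in for whatever serialization Tune actually uses. The sketch shows that a fresh callback instance gets its counter back only through set_state().

```python
import pickle
from typing import Any, Dict, Optional


class TrialCounterCallback:
    """Hypothetical stand-in for a ray.tune.Callback subclass.

    Ray is not imported here; the class only mirrors the
    get_state()/set_state() contract described above.
    """

    def __init__(self) -> None:
        self._num_started = 0

    def on_trial_start(self, iteration: int, trials: Any, trial: Any, **info: Any) -> None:
        self._num_started += 1

    def get_state(self) -> Optional[Dict]:
        return {"num_started": self._num_started}

    def set_state(self, state: Dict) -> None:
        self._num_started = state["num_started"]


def checkpoint_and_restore(callback: TrialCounterCallback) -> TrialCounterCallback:
    """Hypothetical helper that imitates one checkpoint/restore cycle.

    pickle is only a stand-in for however Tune actually persists callback state.
    """
    saved = pickle.dumps(callback.get_state())  # "checkpoint" the state
    fresh = TrialCounterCallback()              # new instance, as after a restart
    state = pickle.loads(saved)
    if state is not None:                       # None means "nothing to restore"
        fresh.set_state(state)
    return fresh


original = TrialCounterCallback()
for i in range(3):
    original.on_trial_start(iteration=i, trials=[], trial=None)

restored = checkpoint_and_restore(original)
print(restored.get_state())  # {'num_started': 3}
```

The state passes through serialization in this sketch, so the returned dictionary is kept to plain, picklable values. That is a reasonable precaution for real callbacks too.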

from typing import Dict, List, Optional

from ray.tune import Callback
from ray.tune.experiment import Trial

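class MyCallback(Callback):
    def __init__(self):
        # IDs of every trial that has started; this is the state to checkpoint.
        self._trial_ids = set()

    def on_trial_start(
        self, iteration: int, trials: List["Trial"], trial: "Trial", **info
    ):
        # Called by Tune each time a trial starts.
        self._trial_ids.add(trial.trial_id)

    def get_state(self) -> Optional[Dict]:
        # Return a copy so the snapshot is not changed by later additions.
        return {"trial_ids": self._trial_ids.copy()}

    def set_state(self, state: Dict) -> None:
        # Restore the dictionary previously returned by get_state().
        self._trial_ids = state["trial_ids"]

Returns: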

State of the callback. Should be None if the callback does not have any state to save (this is the default).

Return type:

dict or None
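The None default and the .copy() call in the example above can be illustrated with a small sketch that does not need Ray. BaseCallbackSketch, LoggingOnlyCallback, TrialIdCallback, and collect_states are hypothetical names. BaseCallbackSketch only imitates the documented default of returning None and is not the real ray.tune.Callback. The snapshot taken before the second trial starts stays unchanged, because get_state() returns a copy of the set.

```python
from types import SimpleNamespace
from typing import Any, Dict, List, Optional


class BaseCallbackSketch:
    """Hypothetical stand-in imitating the documented default:
    get_state() returns None, meaning there is no state to save."""

    def get_state(self) -> Optional[Dict]:
        return None

    def set_state(self, state: Dict) -> None:
        pass


class LoggingOnlyCallback(BaseCallbackSketch):
    """Reacts to events but keeps nothing, so it inherits get_state() -> None."""

    def on_trial_start(self, iteration: int, trials: Any, trial: Any, **info: Any) -> None:
        print(f"trial {trial.trial_id} started at iteration {iteration}")


class TrialIdCallback(BaseCallbackSketch):
    """Tracks trial IDs, mirroring MyCallback in the example above."""

    def __init__(self) -> None:
        self._trial_ids = set()

    def on_trial_start(self, iteration: int, trials: Any, trial: Any, **info: Any) -> None:
        self._trial_ids.add(trial.trial_id)

    def get_state(self) -> Optional[Dict]:
        # Copy, so the snapshot does not change when more trials start.
        return {"trial_ids": self._trial_ids.copy()}

    def set_state(self, state: Dict) -> None:
        self._trial_ids = set(state["trial_ids"])


def collect_states(callbacks: List[BaseCallbackSketch]) -> List[Optional[Dict]]:
    """Hypothetical helper: snapshot the state of every callback."""
    return [cb.get_state() for cb in callbacks]


tracker = TrialIdCallback()
# SimpleNamespace fakes a Trial object that only has a trial_id attribute.
tracker.on_trial_start(iteration=0, trials=[], trial=SimpleNamespace(trial_id="trial_1"))

states = collect_states([LoggingOnlyCallback(), tracker])

# A later trial does not leak into the earlier snapshot.
tracker.on_trial_start(iteration=1, trials=[], trial=SimpleNamespace(trial_id="trial_2"))
print(states)  # [None, {'trial_ids': {'trial_1'}}]
```

Without the copy, the stored snapshot and the live set would be the same object. In that case "trial_2" would also appear in the earlier snapshot.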