ray.air.integrations.wandb.WandbLoggerCallback#
- class ray.air.integrations.wandb.WandbLoggerCallback(project: Optional[str] = None, group: Optional[str] = None, api_key_file: Optional[str] = None, api_key: Optional[str] = None, excludes: Optional[List[str]] = None, log_config: bool = False, upload_checkpoints: bool = False, save_checkpoints: bool = False, upload_timeout: int = 1800, **kwargs)[source]#
Bases: ray.tune.logger.logger.LoggerCallback
Weights & Biases (https://www.wandb.ai/) is a tool for experiment tracking, model optimization, and dataset versioning. This Ray Tune LoggerCallback sends metrics to Wandb for automatic tracking and visualization.

Example

import random

from ray import tune
from ray.air import session, RunConfig
from ray.air.integrations.wandb import WandbLoggerCallback


def train_func(config):
    offset = random.random() / 5
    for epoch in range(2, config["epochs"]):
        acc = 1 - (2 + config["lr"]) ** -epoch - random.random() / epoch - offset
        loss = (2 + config["lr"]) ** -epoch + random.random() / epoch + offset
        session.report({"acc": acc, "loss": loss})


tuner = tune.Tuner(
    train_func,
    param_space={
        "lr": tune.grid_search([0.001, 0.01, 0.1, 1.0]),
        "epochs": 10,
    },
    run_config=RunConfig(
        callbacks=[WandbLoggerCallback(project="Optimization_Project")]
    ),
)
results = tuner.fit()
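Each Tune trial is logged as its own Wandb run inside the given project, so the grid search above produces four runs that can be compared side by side in the Wandb UI.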
- Parameters
project – Name of the Wandb project. Mandatory.
group – Name of the Wandb group. Defaults to the trainable name.
api_key_file – Path to file containing the Wandb API KEY. This file only needs to be present on the node running the Tune script if using the WandbLogger.
api_key – Wandb API Key. Alternative to setting api_key_file.
excludes – List of metrics and config that should be excluded from the log.
log_config – Boolean indicating if the config parameter of the results dict should be logged. This makes sense if parameters will change during training, e.g. with PopulationBasedTraining. Defaults to False.
upload_checkpoints – If True, model checkpoints will be uploaded to Wandb as artifacts. Defaults to False.
**kwargs – The keyword arguments will be passed to wandb.init().
Wandb's group, run_id and run_name are automatically selected by Tune, but can be overridden by filling out the respective configuration values. See https://docs.wandb.ai/library/init for all other valid configuration settings.
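As a minimal sketch of forwarding extra wandb.init() settings through the callback (the entity, tags, and group values below are placeholders, not part of the API):

from ray.air.integrations.wandb import WandbLoggerCallback

# Keyword arguments that WandbLoggerCallback does not consume itself
# (e.g. entity, tags, mode) are forwarded to wandb.init() for each trial run.
callback = WandbLoggerCallback(
    project="Optimization_Project",
    group="lr_sweep",                 # overrides the default group (trainable name)
    entity="my-team",                 # placeholder Wandb entity/team name
    tags=["ray-tune", "grid-search"],
    mode="offline",                   # log locally, sync to Wandb later
)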
- AUTO_CONFIG_KEYS = ['trial_id', 'experiment_tag', 'node_ip', 'experiment_id', 'hostname', 'pid', 'date']#
Results that are saved with wandb.config instead of wandb.log.
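If these bookkeeping keys are not wanted in the logged run, one option is to filter them out with the excludes parameter; a hedged sketch, assuming excludes applies to these auto-config keys as well as to reported metrics:

from ray.air.integrations.wandb import WandbLoggerCallback

# Drop some of the auto-config bookkeeping keys from what gets logged.
callback = WandbLoggerCallback(
    project="Optimization_Project",
    excludes=["node_ip", "hostname", "pid", "date"],
)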
- setup(*args, **kwargs)[source]#
Called once at the very beginning of training.
Any Callback setup should be added here (setting environment variables, etc.)
- Parameters
stop – Stopping criteria. If time_budget_s was passed to air.RunConfig, a TimeoutStopper will be passed here, either by itself or as part of a CombinedStopper.
num_samples – Number of times to sample from the hyperparameter space. Defaults to 1. If grid_search is provided as an argument, the grid will be repeated num_samples times. If this is -1, (virtually) infinite samples are generated until a stopping condition is met.
total_num_samples – Total number of samples factoring in grid search samplers.
**info – Kwargs dict for forward compatibility.
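As an illustrative sketch (not part of the documented API surface), a subclass can hook into setup() to perform one-time preparation before any trial-specific logging happens, as long as it forwards the arguments to the parent class:

import os

from ray.air.integrations.wandb import WandbLoggerCallback


class EnvAwareWandbCallback(WandbLoggerCallback):
    """Hypothetical subclass that prepares the environment in setup()."""

    def setup(self, *args, **kwargs):
        # Runs once at the very beginning of training; a convenient place
        # for one-off environment preparation such as setting Wandb env vars.
        os.environ.setdefault("WANDB_SILENT", "true")
        super().setup(*args, **kwargs)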
- log_trial_start(trial: ray.tune.experiment.trial.Trial)[source]#
Handle logging when a trial starts.
- Parameters
trial – Trial object.
- log_trial_result(iteration: int, trial: ray.tune.experiment.trial.Trial, result: Dict)[source]#
Handle logging when a trial reports a result.
- Parameters
trial – Trial object.
result – Result dictionary.
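As a hedged sketch of customizing this hook, a subclass could rewrite selected metric keys before deferring to the parent class (the tune/ prefix and the acc/loss names are illustrative only):

from typing import Dict

from ray.air.integrations.wandb import WandbLoggerCallback
from ray.tune.experiment.trial import Trial


class PrefixedWandbCallback(WandbLoggerCallback):
    """Hypothetical subclass that renames selected metrics before logging."""

    def log_trial_result(self, iteration: int, trial: Trial, result: Dict):
        # Prefix the user-reported metrics so related experiments are easy
        # to tell apart in the Wandb UI, then let the parent class log them.
        renamed = {
            (f"tune/{k}" if k in ("acc", "loss") else k): v
            for k, v in result.items()
        }
        super().log_trial_result(iteration, trial, renamed)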
- log_trial_save(trial: ray.tune.experiment.trial.Trial)[source]#
Handle logging when a trial saves a checkpoint.
- Parameters
trial – Trial object.
- log_trial_end(trial: ray.tune.experiment.trial.Trial, failed: bool = False)[source]#
Handle logging when a trial ends.
- Parameters
trial – Trial object.
failed – True if the trial failed (e.g. when it raised an exception), False if it finished gracefully.
- on_experiment_end(trials: List[ray.tune.experiment.trial.Trial], **info)[source]#
Wait for the actors to finish their call to wandb.finish. This includes uploading all logs and artifacts to Wandb.