Using Weights & Biases with Tune

Weights & Biases (Wandb) is a tool for experiment tracking, model optimization, and dataset versioning. It is very popular in the machine learning and data science community for its superb visualization tools.

Ray Tune currently offers two lightweight integrations for Weights & Biases. One is the WandbLoggerCallback, which automatically logs metrics reported to Tune to the Wandb API.

The other one is the @wandb_mixin decorator, which can be used with the function API. It automatically initializes the Wandb API with Tune’s training information. You can then use the Wandb API as you normally would, e.g. calling wandb.log() to log your training process.

Running A Weights & Biases Example

In the following example we’re going to use both of the methods above, the WandbLoggerCallback and the wandb_mixin decorator, to log metrics. Let’s start with a few crucial imports:

import numpy as np
import wandb

from ray import tune
from ray.tune import Trainable
from ray.tune.integration.wandb import (
    WandbLoggerCallback,
    WandbTrainableMixin,
    wandb_mixin,
)

Next, let’s define a simple objective function (a Tune Trainable) that reports a random loss to Tune. The objective function itself is not important for this example, since our focus is on the Weights & Biases integration.

def objective(config, checkpoint_dir=None):
    for i in range(30):
        loss = config["mean"] + config["sd"] * np.random.randn()
        tune.report(loss=loss)

Provided you have an api_key_file pointing to your Weights & Biases API key, you can define a simple grid-search Tune run using the WandbLoggerCallback as follows:

def tune_function(api_key_file):
    """Example for using a WandbLoggerCallback with the function API"""
    analysis = tune.run(
        objective,
        metric="loss",
        mode="min",
        config={
            "mean": tune.grid_search([1, 2, 3, 4, 5]),
            "sd": tune.uniform(0.2, 0.8),
        },
        callbacks=[
            WandbLoggerCallback(api_key_file=api_key_file, project="Wandb_example")
        ],
    )
    return analysis.best_config
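
If you have not stored your API key in a file yet, you can write one with a few lines of plain Python. The following is a minimal sketch: the path ~/.wandb_api_key and the placeholder key are assumptions, so substitute your own values.

import os

# Write the Wandb API key to a file the WandbLoggerCallback can read.
# "YOUR_API_KEY" is a placeholder for the key from your Wandb account settings.
api_key_file = os.path.expanduser("~/.wandb_api_key")
with open(api_key_file, "w") as f:
    f.write("YOUR_API_KEY")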

To use the wandb_mixin decorator, you can simply decorate the objective function from earlier. Note that we also use wandb.log(...) to log the loss to Weights & Biases as a dictionary. Otherwise, the decorated version of our objective is identical to the original.

@wandb_mixin
def decorated_objective(config, checkpoint_dir=None):
    for i in range(30):
        loss = config["mean"] + config["sd"] * np.random.randn()
        tune.report(loss=loss)
        wandb.log(dict(loss=loss))

With the decorated_objective defined, running a Tune experiment is as simple as providing this objective and passing the api_key_file to the wandb key of your Tune config:

def tune_decorated(api_key_file):
    """Example for using the @wandb_mixin decorator with the function API"""
    analysis = tune.run(
        decorated_objective,
        metric="loss",
        mode="min",
        config={
            "mean": tune.grid_search([1, 2, 3, 4, 5]),
            "sd": tune.uniform(0.2, 0.8),
            "wandb": {"api_key_file": api_key_file, "project": "Wandb_example"},
        },
    )
    return analysis.best_config

Finally, you can also define a class-based Tune Trainable by using the WandbTrainableMixin in your objective:

class WandbTrainable(WandbTrainableMixin, Trainable):
    def step(self):
        # The mixin has already initialized Wandb, so we can call wandb.log directly.
        for i in range(30):
            loss = self.config["mean"] + self.config["sd"] * np.random.randn()
            wandb.log({"loss": loss})
        return {"loss": loss, "done": True}

Running Tune with this WandbTrainable works exactly the same as with the function API. The below tune_trainable function differs from tune_decorated above only in the first argument we pass to tune.run():

def tune_trainable(api_key_file):
    """Example for using a WandTrainableMixin with the class API"""
    analysis = tune.run(
        WandbTrainable,
        metric="loss",
        mode="min",
        config={
            "mean": tune.grid_search([1, 2, 3, 4, 5]),
            "sd": tune.uniform(0.2, 0.8),
            "wandb": {"api_key_file": api_key_file, "project": "Wandb_example"},
        },
    )
    return analysis.best_config

Since you may not have an API key for Wandb, we can mock the Wandb logger and test all three of our training functions as follows. If you do have an API key file, make sure to set mock_api to False and pass in the right api_key_file below.

import tempfile
from unittest.mock import MagicMock

mock_api = True

api_key_file = "~/.wandb_api_key"

if mock_api:
    # Mock out all Wandb components so that no real API calls are made.
    WandbLoggerCallback._logger_process_cls = MagicMock
    decorated_objective.__mixins__ = tuple()
    WandbTrainable._wandb = MagicMock()
    wandb = MagicMock()  # noqa: F811
    temp_file = tempfile.NamedTemporaryFile()
    temp_file.write(b"1234")
    temp_file.flush()
    api_key_file = temp_file.name

tune_function(api_key_file)
tune_decorated(api_key_file)
tune_trainable(api_key_file)

if mock_api:
    temp_file.close()

This completes our Tune and Wandb walk-through. In the following sections you can find more details on the API of the Tune-Wandb integration.

Tune Wandb API Reference

WandbLoggerCallback

class ray.tune.integration.wandb.WandbLoggerCallback(project: str, group: Optional[str] = None, api_key_file: Optional[str] = None, api_key: Optional[str] = None, excludes: Optional[List[str]] = None, log_config: bool = False, save_checkpoints: bool = False, **kwargs)[source]

Weights & Biases (https://www.wandb.ai/) is a tool for experiment tracking, model optimization, and dataset versioning. This Ray Tune LoggerCallback sends metrics to Wandb for automatic tracking and visualization.

Parameters
  • project – Name of the Wandb project. Mandatory.

  • group – Name of the Wandb group. Defaults to the trainable name.

  • api_key_file – Path to file containing the Wandb API KEY. This file only needs to be present on the node running the Tune script if using the WandbLoggerCallback.

  • api_key – Wandb API Key. Alternative to setting api_key_file.

  • excludes – List of metrics that should be excluded from the log.

  • log_config – Boolean indicating if the config parameter of the results dict should be logged. This makes sense if parameters will change during training, e.g. with PopulationBasedTraining (see the sketch after this parameter list). Defaults to False.

  • save_checkpoints – If True, model checkpoints will be saved to Wandb as artifacts. Defaults to False.

  • **kwargs – The keyword arguments will be passed to wandb.init().
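
To illustrate log_config, here is a minimal sketch that combines the callback with a PopulationBasedTraining scheduler, under which trial configs mutate during training. The train_fn, the search space, and the mutation values are assumptions made for illustration:

from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
from ray.tune.integration.wandb import WandbLoggerCallback

# With PopulationBasedTraining, trial configs change while training runs,
# so log_config=True keeps Wandb's record of each trial's parameters current.
pbt = PopulationBasedTraining(
    time_attr="training_iteration",
    hyperparam_mutations={"lr": [0.001, 0.01, 0.1]},
)

tune.run(
    train_fn,  # assumed to be a trainable that reports a "loss" metric
    metric="loss",
    mode="min",
    scheduler=pbt,
    config={"lr": 0.01},
    callbacks=[
        WandbLoggerCallback(
            project="Optimization_Project",
            api_key_file="/path/to/file",
            log_config=True,
        )
    ],
)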

Wandb’s group, run_id and run_name are automatically selected by Tune, but can be overwritten by filling out the respective configuration values.

Please see here for all other valid configuration settings: https://docs.wandb.ai/library/init

Example:

from ray.tune.logger import DEFAULT_LOGGERS
from ray.tune.integration.wandb import WandbLoggerCallback
tune.run(
    train_fn,
    config={
        # define search space here
        "parameter_1": tune.choice([1, 2, 3]),
        "parameter_2": tune.choice([4, 5, 6]),
    },
    callbacks=[WandbLoggerCallback(
        project="Optimization_Project",
        api_key_file="/path/to/file",
        log_config=True)])
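
Since group is an explicit parameter of the callback (and remaining keyword arguments are forwarded to wandb.init()), overriding the automatically selected group is a one-line change. A minimal sketch with placeholder names:

# "my_custom_group" is a placeholder; passing group overrides Tune's default.
callback = WandbLoggerCallback(
    project="Optimization_Project",
    group="my_custom_group",
    api_key_file="/path/to/file",
)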

Wandb-Mixin

ray.tune.integration.wandb.wandb_mixin(func: Callable)[source]

Weights & Biases (https://www.wandb.ai/) is a tool for experiment tracking, model optimization, and dataset versioning. This Ray Tune Trainable mixin helps initialize the Wandb API for use with the Trainable class or with @wandb_mixin for the function API.

For basic usage, simply decorate your training function with the @wandb_mixin decorator:

from ray.tune.integration.wandb import wandb_mixin

@wandb_mixin
def train_fn(config):
    wandb.log({"metric": 1})  # wandb.log expects a dict of metric values

Wandb configuration is done by passing a wandb key to the config parameter of tune.run() (see example below).

The content of the wandb config entry is passed to wandb.init() as keyword arguments. The exceptions are the following settings, which are used to configure the WandbTrainableMixin itself:

Parameters
  • api_key_file – Path to file containing the Wandb API KEY. This file must be on all nodes if using the wandb_mixin.

  • api_key – Wandb API Key. Alternative to setting api_key_file.

Wandb’s group, run_id and run_name are automatically selected by Tune, but can be overwritten by filling out the respective configuration values.

Please see here for all other valid configuration settings: https://docs.wandb.ai/library/init

Example:

from ray import tune
from ray.tune.integration.wandb import wandb_mixin

@wandb_mixin
def train_fn(config):
    for i in range(10):
        loss = config["a"] + config["b"]
        wandb.log({"loss": loss})
    tune.report(loss=loss, done=True)

tune.run(
    train_fn,
    config={
        # define search space here
        "a": tune.choice([1, 2, 3]),
        "b": tune.choice([4, 5, 6]),
        # wandb configuration
        "wandb": {
            "project": "Optimization_Project",
            "api_key_file": "/path/to/file"
        }
    })
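
With the mixin, the same kind of override goes through the wandb config entry, since its contents (apart from api_key_file and api_key) are forwarded to wandb.init(). A minimal sketch with a placeholder group name:

tune.run(
    train_fn,
    config={
        "a": tune.choice([1, 2, 3]),
        "b": tune.choice([4, 5, 6]),
        "wandb": {
            "project": "Optimization_Project",
            "group": "my_custom_group",  # placeholder; overrides Tune's default
            "api_key_file": "/path/to/file"
        }
    })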