Experiment Tracking
Most experiment tracking libraries work out of the box with Ray Train. This guide shows how to set up your code so that your favorite experiment tracking library works with distributed training in Ray Train. The end of the guide lists common errors to help you debug the setup.
The following pseudocode demonstrates how to use native experiment tracking library calls inside of Ray Train:
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
def train_func():
    # Training code and native experiment tracking library calls go here.
    ...
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
trainer = TorchTrainer(train_func, scaling_config=scaling_config)
result = trainer.fit()
Ray Train lets you use native experiment tracking libraries by customizing the tracking logic inside the train_func function. In this way, you can port your experiment tracking logic to Ray Train with minimal changes.
Getting Started
Let’s start by looking at some code snippets.
The following examples use Weights & Biases (W&B) and MLflow, but they're adaptable to other frameworks.
W&B
import ray
from ray import train
import wandb
# Step 1
# This ensures that all ray worker processes have `WANDB_API_KEY` set.
ray.init(runtime_env={"env_vars": {"WANDB_API_KEY": "your_api_key"}})
def train_func():
    # Step 1 and 2
    if train.get_context().get_world_rank() == 0:
        wandb.init(
            name=...,
            project=...,
            # ...
        )
    # ...
    loss = optimize()
    metrics = {"loss": loss}
    # Step 3
    if train.get_context().get_world_rank() == 0:
        # Only report the results from the rank 0 worker to W&B to avoid duplication.
        wandb.log(metrics)
    # ...
    # Step 4
    # Make sure that all logs are uploaded to the W&B backend.
    if train.get_context().get_world_rank() == 0:
        wandb.finish()
MLflow
import os
from ray import train
import mlflow
# Run the following on the head node:
# $ databricks configure --token
# $ mv ~/.databrickscfg YOUR_SHARED_STORAGE_PATH
# This function assumes `databricks_config_file` is specified in the Trainer's `train_loop_config`.
def train_func(config):
    # Step 1 and 2
    os.environ["DATABRICKS_CONFIG_FILE"] = config["databricks_config_file"]
    mlflow.set_tracking_uri("databricks")
    mlflow.set_experiment(experiment_id=...)
    if train.get_context().get_world_rank() == 0:
        # Start a run only on the rank 0 worker so that MLflow doesn't create one run per worker.
        mlflow.start_run()
    # ...
    loss = optimize()
    metrics = {"loss": loss}
    # Step 3
    if train.get_context().get_world_rank() == 0:
        # Only report the results from the rank 0 worker to MLflow to avoid duplication.
        mlflow.log_metrics(metrics)
    # ...
    # Step 4
    # Make sure that all logs are uploaded to the MLflow backend.
    if train.get_context().get_world_rank() == 0:
        mlflow.end_run()
Tip
A major difference between distributed and non-distributed training is that in distributed training,
multiple processes run in parallel, and under certain setups they produce the same results. If all
of them report results to the tracking backend, you may get duplicated results. To address that,
Ray Train lets you apply logging logic only to the rank 0 worker, which you can identify with the following method:
ray.train.get_context().get_world_rank().
from ray import train
def train_func():
    ...
    if train.get_context().get_world_rank() == 0:
        # Add your logging logic only for the rank 0 worker.
        ...
    ...
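To keep the rank check in one place, you can wrap the backend's logging function instead of repeating the `if` at every call site. The following is a minimal sketch, not a Ray Train API: `RankZeroLogger` is a hypothetical name, and it assumes the backend exposes a single logging callable such as `wandb.log` or `mlflow.log_metrics`. The rank lookup is injectable so that you can exercise the class without a Ray cluster.

```python
from typing import Any, Callable, Optional


def _ray_world_rank() -> int:
    # Imported lazily so the class can be used (and tested) without Ray installed.
    import ray.train

    return ray.train.get_context().get_world_rank()


class RankZeroLogger:
    """Forward logging calls to `log_fn` only on the rank 0 worker (hypothetical helper)."""

    def __init__(
        self,
        log_fn: Callable[[dict], Any],
        get_rank: Optional[Callable[[], int]] = None,
    ) -> None:
        self._log_fn = log_fn
        # Defaults to Ray Train's world rank; pass a stub to test outside Ray.
        self._get_rank = get_rank or _ray_world_rank

    @property
    def is_rank_zero(self) -> bool:
        return self._get_rank() == 0

    def log(self, metrics: dict) -> None:
        # Non-zero ranks drop the call, so the backend records each step once.
        if self.is_rank_zero:
            self._log_fn(metrics)
```

Inside `train_func`, you would create it once, for example `tracker = RankZeroLogger(wandb.log)`, and call `tracker.log({"loss": loss.item()})` on every worker; only rank 0 forwards the call.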
The interaction with the experiment tracking backend within the train_func has four logical steps:
- Set up the connection to a tracking backend
- Configure and launch a run
- Log metrics
- Finish the run
The following sections describe each step in more detail.
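The four steps map naturally onto a small driver function. The sketch below assumes a hypothetical `TrackingBackend` interface with one method per step; W&B, MLflow, and Comet don't share such an interface, so in practice you'd write a thin adapter for each library. `RecordingBackend` is an in-memory stand-in for trying it out.

```python
from typing import Dict, Iterable, List, Protocol


class TrackingBackend(Protocol):
    """Hypothetical interface with one method per step (not a Ray Train API)."""

    def connect(self) -> None: ...  # Step 1: credentials, tracking URI

    def start_run(self, run_id: str) -> None: ...  # Step 2: configure and launch

    def log(self, metrics: Dict[str, float]) -> None: ...  # Step 3: log metrics

    def finish(self) -> None: ...  # Step 4: flush and close the run


def track(
    backend: TrackingBackend,
    run_id: str,
    metric_stream: Iterable[Dict[str, float]],
) -> None:
    """Drive the four steps in order; call this only on the rank 0 worker."""
    backend.connect()
    backend.start_run(run_id)
    try:
        for metrics in metric_stream:
            backend.log(metrics)
    finally:
        # Step 4 runs even if training raises, so buffered logs still get flushed.
        backend.finish()


class RecordingBackend:
    """In-memory stand-in that records the order of calls."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def connect(self) -> None:
        self.calls.append("connect")

    def start_run(self, run_id: str) -> None:
        self.calls.append(f"start:{run_id}")

    def log(self, metrics: Dict[str, float]) -> None:
        self.calls.append(f"log:{metrics['loss']}")

    def finish(self) -> None:
        self.calls.append("finish")
```

The `try`/`finally` is the main design choice: Step 4 is the one most easily skipped when training fails partway through.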
Step 1: Connect to your tracking backend
First, decide which tracking backend to use: W&B, MLflow, TensorBoard, Comet, etc. If applicable, make sure that you properly set up credentials on each training worker.
W&B offers both online and offline modes.
Online
For online mode, because you log to W&B’s tracking service, ensure that you set the credentials inside of train_func. See Set up credentials for more information.
# This is equivalent to `os.environ["WANDB_API_KEY"] = "your_api_key"`
wandb.login(key="your_api_key")
Offline
For offline mode, because you log to a local file system, point the offline directory to a shared storage path that all nodes can write to. See Set up a shared file system for more information.
import os
import wandb
os.environ["WANDB_MODE"] = "offline"
wandb.init(dir="some_shared_storage_path/wandb")
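If you switch between the two modes, a small helper can build the `wandb.init` arguments on each worker. This is a sketch: `wandb_init_kwargs` is a hypothetical name, and it passes `mode="offline"` to `wandb.init`, which W&B accepts as an alternative to setting the `WANDB_MODE` environment variable. For online mode it only checks that `WANDB_API_KEY` is present, not that it's valid.

```python
import os
from typing import Any, Dict


def wandb_init_kwargs(offline: bool, shared_storage_path: str) -> Dict[str, Any]:
    """Build keyword arguments for `wandb.init` (hypothetical helper)."""
    if offline:
        # Offline runs write to disk, so use a path that every node can reach.
        return {
            "mode": "offline",
            "dir": os.path.join(shared_storage_path, "wandb"),
        }
    # Online runs need credentials on this worker, not just on the driver.
    if not os.environ.get("WANDB_API_KEY"):
        raise RuntimeError("WANDB_API_KEY is not set on this worker.")
    return {"mode": "online"}
```

Usage, with `/mnt/shared` as a placeholder path: `wandb.init(project="demo-project", **wandb_init_kwargs(offline=True, shared_storage_path="/mnt/shared"))`.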
MLflow offers both local and remote (for example, to Databricks' MLflow service) modes.
Local
For local mode, because you log to a local file system, point the tracking directory to a shared storage path that all nodes can write to. See Set up a shared file system for more information.
mlflow.set_tracking_uri(uri="file://some_shared_storage_path/mlruns")
mlflow.start_run()
Remote, hosted by Databricks
Ensure that all nodes have access to the Databricks config file. See Set up credentials for more information.
# The MLflow client looks for a Databricks config file
# at the location specified by `os.environ["DATABRICKS_CONFIG_FILE"]`.
os.environ["DATABRICKS_CONFIG_FILE"] = config["databricks_config_file"]
mlflow.set_tracking_uri("databricks")
mlflow.start_run()
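For local mode, note that under standard URI parsing the component right after `file://` is a host name, so `file://some_shared_storage_path/mlruns` doesn't name a relative directory. The sketch below builds an unambiguous `file:///...` URI from an absolute path with `pathlib.Path.as_uri()`; `mlflow_tracking_uri` is a hypothetical helper name, and the result is meant to be passed to `mlflow.set_tracking_uri`.

```python
from pathlib import Path


def mlflow_tracking_uri(shared_storage_path: str = "", remote: bool = False) -> str:
    """Return a tracking URI for `mlflow.set_tracking_uri` (hypothetical helper)."""
    if remote:
        # Databricks-hosted MLflow; credentials come from DATABRICKS_CONFIG_FILE.
        return "databricks"
    path = Path(shared_storage_path)
    if not path.is_absolute():
        # A relative path would resolve against each node's working directory.
        raise ValueError(f"Expected an absolute shared path, got {shared_storage_path!r}")
    # as_uri() yields file:///abs/path/mlruns, with an empty host component.
    return (path / "mlruns").as_uri()
```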
Set up credentials
Refer to each tracking library’s API documentation on setting up credentials. This step usually involves setting an environment variable or accessing a config file.
The easiest way to pass an environment variable credential to training workers is through runtime environments, where you initialize with the following code:
import ray
# This makes sure that training workers have the same env var set
ray.init(runtime_env={"env_vars": {"SOME_API_KEY": "your_api_key"}})
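To avoid hard-coding secrets in the script, you can read them from the driver's environment and fail early if any are missing, rather than getting an authentication error later on a remote worker. A minimal sketch; `runtime_env_with_credentials` is a hypothetical helper, and the returned dict uses the same `env_vars` key as the `ray.init` call above.

```python
import os
from typing import Dict, Iterable


def runtime_env_with_credentials(names: Iterable[str]) -> Dict[str, Dict[str, str]]:
    """Copy the named env vars from the driver into a runtime_env dict (hypothetical helper)."""
    names = list(names)  # Iterate twice below, so materialize generators first.
    missing = [name for name in names if not os.environ.get(name)]
    if missing:
        raise RuntimeError(f"Set these env vars before launching: {missing}")
    return {"env_vars": {name: os.environ[name] for name in names}}
```

Usage: `ray.init(runtime_env=runtime_env_with_credentials(["WANDB_API_KEY"]))`.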
If your credentials live in a config file, ensure that the file is accessible to all nodes. One way to do this is by setting up shared storage. Another way is to save a copy on each node.
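One way to follow the shared-storage approach is to copy the config file into the shared directory from the driver on the head node, then pass the resulting path to workers; the MLflow snippet in Getting Started reads it from `config["databricks_config_file"]`. A sketch under those assumptions: `stage_config_file` is a hypothetical name, and the shared directory must be mounted at the same path on every node.

```python
import os
import shutil


def stage_config_file(src: str, shared_dir: str) -> str:
    """Copy a credentials file (e.g. ~/.databrickscfg) into shared storage (hypothetical helper).

    Returns the destination path, suitable for passing to workers through
    `train_loop_config`.
    """
    src = os.path.expanduser(src)
    if not os.path.isfile(src):
        raise FileNotFoundError(f"No config file at {src}")
    os.makedirs(shared_dir, exist_ok=True)
    dst = os.path.join(shared_dir, os.path.basename(src))
    # copy2 also preserves the source file's permission bits.
    shutil.copy2(src, dst)
    return dst
```

Usage: `train_loop_config={"databricks_config_file": stage_config_file("~/.databrickscfg", "/mnt/shared/creds")}`, where `/mnt/shared/creds` is a placeholder.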
Step 2: Configure and start the run
This step usually involves picking an identifier for the run and associating it with a project. Refer to the tracking libraries’ documentation for semantics.
Tip
When performing fault-tolerant training with auto-restoration, use a consistent ID to configure all tracking runs that logically belong to the same training run.
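One way to get a consistent ID is to derive it deterministically from names that stay the same across restarts, such as the `name` you pass to `ray.train.RunConfig`. A sketch, assuming a hypothetical `stable_run_id` helper; hashing is only there to produce a short ID made of hexadecimal characters.

```python
import hashlib


def stable_run_id(experiment_name: str, trial_name: str) -> str:
    """Derive a deterministic run ID so restarts reuse the same tracking run (hypothetical helper)."""
    key = f"{experiment_name}/{trial_name}".encode("utf-8")
    # Same inputs always give the same 16-character hex ID.
    return hashlib.sha256(key).hexdigest()[:16]
```

On rank 0 you could then call `wandb.init(id=run_id, resume="allow")`, which resumes an existing W&B run with that ID or creates it if none exists.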
Step 3: Log metrics
You can customize how to log parameters, metrics, models, or media content within
train_func, just as in a non-distributed training script.
You can also use native integrations that a particular tracking framework has with
specific training frameworks, for example, mlflow.pytorch.autolog(),
lightning.pytorch.loggers.MLFlowLogger, etc.
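Tracking backends generally expect plain Python numbers, while training code produces tensors. A small normalization step before logging keeps the call backend-agnostic. A sketch, assuming a hypothetical `to_scalars` helper that converts any value exposing `.item()`, as PyTorch tensors and NumPy scalars do.

```python
from typing import Any, Dict


def to_scalars(metrics: Dict[str, Any]) -> Dict[str, Any]:
    """Convert tensor-like scalars to plain Python numbers (hypothetical helper)."""
    out = {}
    for key, value in metrics.items():
        # Plain int/float values have no `.item()` and pass through unchanged.
        if hasattr(value, "item"):
            value = value.item()
        out[key] = value
    return out
```

Calling `.item()` on a GPU tensor forces a device-to-host synchronization, so for long runs you may prefer to log every N steps rather than on every batch.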
Step 4: Finish the run
This step ensures that all logs are synced to the tracking service. Depending on the implementation of various tracking libraries, sometimes logs are first cached locally and only synced to the tracking service in an asynchronous fashion. Finishing the run makes sure that all logs are synced by the time training workers exit.
# https://docs.wandb.ai/ref/python/finish
wandb.finish()
# https://mlflow.org/docs/1.2.0/python_api/mlflow.html
mlflow.end_run()
# https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment/#experimentend
# `experiment` is the comet_ml.Experiment instance that you created for the run.
experiment.end()
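If training raises, it's useful to mark the run as failed rather than leaving it open. The following sketch is a context manager that calls a run-ending function with a terminal status; it's written against `mlflow.end_run(status=...)`, whose statuses include `"FINISHED"` and `"FAILED"`. `finishing` is a hypothetical name.

```python
from contextlib import contextmanager
from typing import Callable, Iterator


@contextmanager
def finishing(end_run: Callable[..., None]) -> Iterator[None]:
    """Call `end_run(status=...)` when the block exits (hypothetical helper)."""
    try:
        yield
    except BaseException:
        # Also covers KeyboardInterrupt, so the run isn't left in RUNNING state.
        end_run(status="FAILED")
        raise
    else:
        end_run(status="FINISHED")
```

On rank 0 you would wrap the training loop in `with finishing(mlflow.end_run):` (other ranks can use `contextlib.nullcontext()`). For W&B, an adapter such as `lambda status: wandb.finish(exit_code=0 if status == "FINISHED" else 1)` plays the same role.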
Examples
The following are runnable examples for PyTorch and PyTorch Lightning.
PyTorch
Log to W&B
from filelock import FileLock
import os
import torch
import wandb
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import resnet18
import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
# Run the following script with the WANDB_API_KEY env var set.
assert os.environ.get("WANDB_API_KEY", None), "Please set WANDB_API_KEY env var."
# This makes sure that all workers have this env var set.
ray.init(
    runtime_env={"env_vars": {"WANDB_API_KEY": os.environ["WANDB_API_KEY"]}}
)
def train_func(config):
    if ray.train.get_context().get_world_rank() == 0:
        wandb.init()
    # Model, Loss, Optimizer
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    model = ray.train.torch.prepare_model(model)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.module.parameters(), lr=0.001)
    # Data
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.28604,), (0.32025,))]
    )
    with FileLock("./data.lock"):
        train_data = datasets.FashionMNIST(
            root="./data", train=True, download=True, transform=transform
        )
    train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
    train_loader = ray.train.torch.prepare_data_loader(train_loader)
    # Training
    for epoch in range(1):
        if ray.train.get_context().get_world_size() > 1:
            train_loader.sampler.set_epoch(epoch)
        for images, labels in train_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if ray.train.get_context().get_world_rank() == 0:
                wandb.log({"loss": loss.item(), "epoch": epoch})
    if ray.train.get_context().get_world_rank() == 0:
        wandb.finish()
trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=2),
)
trainer.fit()
Log to file-based MLflow
# Run the following script with the SHARED_STORAGE_PATH env var set.
# The MLflow offline logs are saved to SHARED_STORAGE_PATH/mlruns.
from filelock import FileLock
import mlflow
import os
import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
import torch
from torchvision import datasets, transforms
from torchvision.models import resnet18
from torch.utils.data import DataLoader
assert os.environ.get(
    "SHARED_STORAGE_PATH", None
), "Please set SHARED_STORAGE_PATH env var."
# Assumes you are passing a `save_dir` in `config`
def train_func(config):
    save_dir = config["save_dir"]
    if ray.train.get_context().get_world_rank() == 0:
        mlflow.set_tracking_uri(f"file:{save_dir}")
        mlflow.set_experiment("my_experiment")
        mlflow.start_run()
    # Model, Loss, Optimizer
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    model = ray.train.torch.prepare_model(model)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.module.parameters(), lr=0.001)
    # Data
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.28604,), (0.32025,))]
    )
    with FileLock("./data.lock"):
        train_data = datasets.FashionMNIST(
            root="./data", train=True, download=True, transform=transform
        )
    train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
    train_loader = ray.train.torch.prepare_data_loader(train_loader)
    # Training
    for epoch in range(1):
        if ray.train.get_context().get_world_size() > 1:
            train_loader.sampler.set_epoch(epoch)
        for images, labels in train_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if ray.train.get_context().get_world_rank() == 0:
                mlflow.log_metrics({"loss": loss.item(), "epoch": epoch})
    if ray.train.get_context().get_world_rank() == 0:
        mlflow.end_run()
trainer = TorchTrainer(
    train_func,
    train_loop_config={
        "save_dir": os.path.join(os.environ["SHARED_STORAGE_PATH"], "mlruns")
    },
    scaling_config=ScalingConfig(num_workers=2),
)
trainer.fit()
PyTorch Lightning
You can use the native Logger integration in PyTorch Lightning with W&B, CometML, MLflow, and TensorBoard, while using Ray Train’s TorchTrainer.
The following examples walk you through the process. The code is runnable.
W&B
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
# Create dummy data
X = torch.randn(128, 3)  # 128 samples, 3 features
y = torch.randint(0, 2, (128,))  # 128 binary labels
# Create a TensorDataset to wrap the data
dataset = TensorDataset(X, y)
# Create a DataLoader to iterate over the dataset
batch_size = 8
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Define a dummy model
class DummyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 1)
    def forward(self, x):
        return self.layer(x)
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.binary_cross_entropy_with_logits(y_hat.flatten(), y.float())
        # The metrics below will be reported to Loggers
        self.log("train_loss", loss)
        self.log_dict({
            "metric_1": 1 / (batch_idx + 1), "metric_2": batch_idx * 100
        })
        return loss
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
import os
import pytorch_lightning as pl
import wandb
from pytorch_lightning.loggers.wandb import WandbLogger
import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
def train_func(config):
    logger = None
    if ray.train.get_context().get_world_rank() == 0:
        logger = WandbLogger(name="demo-run", project="demo-project")
    ptl_trainer = pl.Trainer(
        max_epochs=5,
        accelerator="cpu",
        logger=logger,
        log_every_n_steps=1,
    )
    model = DummyModel()
    ptl_trainer.fit(model, train_dataloaders=dataloader)
    if ray.train.get_context().get_world_rank() == 0:
        wandb.finish()
scaling_config = ScalingConfig(num_workers=2, use_gpu=False)
assert (
    "WANDB_API_KEY" in os.environ
), 'Please set WANDB_API_KEY="abcde" when running this script.'
# This ensures that all workers have this env var set.
ray.init(
    runtime_env={"env_vars": {"WANDB_API_KEY": os.environ["WANDB_API_KEY"]}}
)
trainer = TorchTrainer(
    train_func,
    scaling_config=scaling_config,
)
trainer.fit()
MLflow
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
# Create dummy data
X = torch.randn(128, 3)  # 128 samples, 3 features
y = torch.randint(0, 2, (128,))  # 128 binary labels
# Create a TensorDataset to wrap the data
dataset = TensorDataset(X, y)
# Create a DataLoader to iterate over the dataset
batch_size = 8
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Define a dummy model
class DummyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 1)
    def forward(self, x):
        return self.layer(x)
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.binary_cross_entropy_with_logits(y_hat.flatten(), y.float())
        # The metrics below will be reported to Loggers
        self.log("train_loss", loss)
        self.log_dict({
            "metric_1": 1 / (batch_idx + 1), "metric_2": batch_idx * 100
        })
        return loss
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
import os
import pytorch_lightning as pl
from pytorch_lightning.loggers.mlflow import MLFlowLogger
import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
def train_func(config):
    save_dir = config["save_dir"]
    logger = None
    if ray.train.get_context().get_world_rank() == 0:
        logger = MLFlowLogger(
            experiment_name="demo-project",
            tracking_uri=f"file:{save_dir}",
        )
    ptl_trainer = pl.Trainer(
        max_epochs=5,
        accelerator="cpu",
        logger=logger,
        log_every_n_steps=1,
    )
    model = DummyModel()
    ptl_trainer.fit(model, train_dataloaders=dataloader)
scaling_config = ScalingConfig(num_workers=2, use_gpu=False)
assert (
    "SHARED_STORAGE_PATH" in os.environ
), "Please set SHARED_STORAGE_PATH=/a/b/c when running this script."
trainer = TorchTrainer(
    train_func,
    train_loop_config={
        "save_dir": os.path.join(os.environ["SHARED_STORAGE_PATH"], "mlruns")
    },
    scaling_config=scaling_config,
)
trainer.fit()
Comet
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
# Create dummy data
X = torch.randn(128, 3)  # 128 samples, 3 features
y = torch.randint(0, 2, (128,))  # 128 binary labels
# Create a TensorDataset to wrap the data
dataset = TensorDataset(X, y)
# Create a DataLoader to iterate over the dataset
batch_size = 8
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Define a dummy model
class DummyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 1)
    def forward(self, x):
        return self.layer(x)
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.binary_cross_entropy_with_logits(y_hat.flatten(), y.float())
        # The metrics below will be reported to Loggers
        self.log("train_loss", loss)
        self.log_dict({
            "metric_1": 1 / (batch_idx + 1), "metric_2": batch_idx * 100
        })
        return loss
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
import os
import pytorch_lightning as pl
from pytorch_lightning.loggers.comet import CometLogger
import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
def train_func(config):
    logger = None
    if ray.train.get_context().get_world_rank() == 0:
        logger = CometLogger(api_key=os.environ["COMET_API_KEY"])
    ptl_trainer = pl.Trainer(
        max_epochs=5,
        accelerator="cpu",
        logger=logger,
        log_every_n_steps=1,
    )
    model = DummyModel()
    ptl_trainer.fit(model, train_dataloaders=dataloader)
scaling_config = ScalingConfig(num_workers=2, use_gpu=False)
assert (
    "COMET_API_KEY" in os.environ
), 'Please set COMET_API_KEY="abcde" when running this script.'
# This makes sure that all workers have this env var set.
ray.init(runtime_env={"env_vars": {"COMET_API_KEY": os.environ["COMET_API_KEY"]}})
trainer = TorchTrainer(
    train_func,
    scaling_config=scaling_config,
)
trainer.fit()
TensorBoard
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
# Create dummy data
X = torch.randn(128, 3)  # 128 samples, 3 features
y = torch.randint(0, 2, (128,))  # 128 binary labels
# Create a TensorDataset to wrap the data
dataset = TensorDataset(X, y)
# Create a DataLoader to iterate over the dataset
batch_size = 8
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Define a dummy model
class DummyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 1)
    def forward(self, x):
        return self.layer(x)
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.binary_cross_entropy_with_logits(y_hat.flatten(), y.float())
        # The metrics below will be reported to Loggers
        self.log("train_loss", loss)
        self.log_dict({
            "metric_1": 1 / (batch_idx + 1), "metric_2": batch_idx * 100
        })
        return loss
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
import os
import pytorch_lightning as pl
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
def train_func(config):
    save_dir = config["save_dir"]
    logger = None
    if ray.train.get_context().get_world_rank() == 0:
        logger = TensorBoardLogger(name="demo-run", save_dir=save_dir)
    ptl_trainer = pl.Trainer(
        max_epochs=5,
        accelerator="cpu",
        logger=logger,
        log_every_n_steps=1,
    )
    model = DummyModel()
    ptl_trainer.fit(model, train_dataloaders=dataloader)
scaling_config = ScalingConfig(num_workers=2, use_gpu=False)
assert (
    "SHARED_STORAGE_PATH" in os.environ
), "Please set SHARED_STORAGE_PATH=/a/b/c when running this script."
trainer = TorchTrainer(
    train_func,
    train_loop_config={
        "save_dir": os.path.join(os.environ["SHARED_STORAGE_PATH"], "tensorboard")
    },
    scaling_config=scaling_config,
)
trainer.fit()
Common Errors
Missing Credentials
I have already run the `wandb login` CLI, but I'm still getting
wandb: ERROR api_key not configured (no-tty). call wandb.login(key=[your_api_key]).
This is probably because wandb credentials aren't set up correctly
on the worker nodes. Make sure that you run wandb.login
or pass WANDB_API_KEY to each training function.
See Set up credentials for more details.
Missing Configurations
I have already run `databricks configure`, but I'm still getting
databricks_cli.utils.InvalidConfigurationError: You haven't configured the CLI yet!
This is usually because databricks configure
generates ~/.databrickscfg only on the head node. Move this file to a shared
location or copy it to each node.
See Set up credentials for more details.
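Both errors surface inside a library call on a remote worker. A check at the top of `train_func` can fail faster and report which node is misconfigured. A sketch; `preflight_check` is a hypothetical helper, and it only verifies that the variables and files exist, not that the credentials are valid.

```python
import os
import socket
from typing import Iterable


def preflight_check(env_vars: Iterable[str] = (), file_env_vars: Iterable[str] = ()) -> None:
    """Raise with the hostname if credentials are missing on this node (hypothetical helper).

    Every name in `env_vars` must be set; every name in `file_env_vars` must
    point to an existing file.
    """
    host = socket.gethostname()
    problems = [f"{name} is not set" for name in env_vars if not os.environ.get(name)]
    for name in file_env_vars:
        path = os.environ.get(name)
        if not path or not os.path.isfile(path):
            problems.append(f"{name} does not point to a file (got {path!r})")
    if problems:
        raise RuntimeError(f"Credential check failed on {host}: " + "; ".join(problems))
```

For example, call `preflight_check(env_vars=["WANDB_API_KEY"])` for W&B, or `preflight_check(file_env_vars=["DATABRICKS_CONFIG_FILE"])` after setting that variable in `train_func`.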