Monitoring and Logging Metrics

Ray Train provides an API for reporting intermediate results and checkpoints from the training function (which runs on distributed workers) to the Trainer (where your Python script is executed) by calling train.report(metrics). The results are collected from the distributed workers and passed to the driver to be logged and displayed.

Warning

Only the results from the rank 0 worker are used. However, to ensure consistency, train.report() must be called on every worker. If you want to aggregate results from multiple workers, see How to obtain and aggregate results from different workers?.

The primary use case for reporting is to log metrics (accuracy, loss, etc.) at the end of each training epoch.

from ray import train

def train_func():
    ...
    for i in range(num_epochs):
        result = model.train(...)
        train.report({"result": result})

In PyTorch Lightning, we use a callback to call train.report().

from ray import train
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback

class MyRayTrainReportCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        metrics = {k: v.item() for k, v in metrics.items()}

        train.report(metrics=metrics)

def train_func_per_worker():
    ...
    trainer = pl.Trainer(
        # ...
        callbacks=[MyRayTrainReportCallback()]
    )
    trainer.fit(...)  # pass your LightningModule and data here
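
The dictionary comprehension in the callback above assumes that every metric value exposes .item(). If some values are already plain Python numbers, calling .item() on them raises an AttributeError. The following is a minimal sketch of a more defensive conversion. The helper name to_reportable and the FakeScalar stand-in are hypothetical and are not part of Ray Train or PyTorch Lightning; the sketch uses only plain Python, so it runs without either library.

```python
def to_reportable(metrics):
    """Convert metric values to plain Python numbers where possible.

    Hypothetical helper: values exposing a zero-argument .item()
    (for example single-element tensors or NumPy scalars) are unwrapped;
    everything else is passed through unchanged.
    """
    out = {}
    for name, value in metrics.items():
        item = getattr(value, "item", None)
        out[name] = item() if callable(item) else value
    return out


class FakeScalar:
    """Stand-in for a single-element tensor, used only for this demo."""

    def __init__(self, value):
        self._value = value

    def item(self):
        return self._value


cleaned = to_reportable({"train_loss": FakeScalar(0.25), "epoch": 3})
print(cleaned)  # {'train_loss': 0.25, 'epoch': 3}
```

Inside the callback, the result of such a helper could be passed to train.report(metrics=...) in place of the comprehension.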

How to obtain and aggregate results from different workers?

In real applications, you may want to calculate optimization metrics besides accuracy and loss, such as recall, precision, and F-beta. You may also want to collect metrics from multiple workers. Although Ray Train currently reports metrics only from the rank 0 worker, you can use third-party libraries or the distributed primitives of your machine learning framework to aggregate metrics across workers before reporting them.
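
To illustrate why a proper cross-worker reduction matters, here is a minimal plain-Python sketch of the sum-and-count idea behind distributed mean metrics. It simulates two workers in a single process and does not use Ray or any distributed backend. The MeanState class and global_mean function are hypothetical names introduced only for this illustration.

```python
from dataclasses import dataclass


@dataclass
class MeanState:
    """Running sum and sample count held by one worker (hypothetical helper)."""

    total: float = 0.0
    count: int = 0

    def update(self, batch_values):
        # Accumulate the raw sum and the number of samples seen locally.
        self.total += sum(batch_values)
        self.count += len(batch_values)


def global_mean(states):
    """Sum totals and counts across workers, then divide once."""
    total = sum(s.total for s in states)
    count = sum(s.count for s in states)
    return total / count


# Two simulated workers that saw different numbers of samples.
worker0 = MeanState()
worker0.update([1.0, 1.0, 1.0])  # local mean 1.0 over 3 samples
worker1 = MeanState()
worker1.update([3.0])  # local mean 3.0 over 1 sample

# Averaging the local means weights both workers equally.
mean_of_means = (worker0.total / worker0.count + worker1.total / worker1.count) / 2
# Reducing sums and counts weights every sample equally.
aggregated = global_mean([worker0, worker1])

print(mean_of_means)  # 2.0
print(aggregated)  # 1.5
```

When workers process unequal numbers of samples, averaging per-worker means gives a different answer than reducing sums and counts. In a real job, the reduction would be performed by a metrics library or a collective operation of your framework rather than by hand.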

Ray Train natively supports TorchMetrics, which provides a collection of machine learning metrics for distributed, scalable PyTorch models.

Here is an example that reports the mean absolute percentage error (MAPE) and the mean validation loss aggregated across all workers, alongside the validation loss computed locally on the reporting worker.


# First, pip install torchmetrics
# This code is tested with torchmetrics==0.7.3 and torch==1.12.1

import ray.train.torch
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

import torch
import torch.nn as nn
import torchmetrics
from torch.optim import Adam
import numpy as np


def train_func(config):
    n = 100
    # create a toy dataset
    X = torch.Tensor(np.random.normal(0, 1, size=(n, 4)))
    X_valid = torch.Tensor(np.random.normal(0, 1, size=(n, 4)))
    Y = torch.Tensor(np.random.uniform(0, 1, size=(n, 1)))
    Y_valid = torch.Tensor(np.random.uniform(0, 1, size=(n, 1)))
    # toy neural network: a single linear layer
    # prepare_model wraps the model in DistributedDataParallel (DDP)
    model = ray.train.torch.prepare_model(nn.Linear(4, 1))
    criterion = nn.MSELoss()

    mape = torchmetrics.MeanAbsolutePercentageError()
    # for averaging loss
    mean_valid_loss = torchmetrics.MeanMetric()

    optimizer = Adam(model.parameters(), lr=3e-4)
    for epoch in range(config["num_epochs"]):
        model.train()
        y = model(X)

        # compute loss
        loss = criterion(y, Y)

        # back-propagate loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # evaluate
        model.eval()
        with torch.no_grad():
            pred = model(X_valid)
            valid_loss = criterion(pred, Y_valid)
            # save loss in aggregator
            mean_valid_loss(valid_loss)
            mape(pred, Y_valid)

        # collect all metrics
        # use .item() to obtain a value that can be reported
        valid_loss = valid_loss.item()
        mape_collected = mape.compute().item()
        mean_valid_loss_collected = mean_valid_loss.compute().item()

        train.report(
            {
                "mape_collected": mape_collected,
                "valid_loss": valid_loss,
                "mean_valid_loss_collected": mean_valid_loss_collected,
            }
        )

        # reset for next epoch
        mape.reset()
        mean_valid_loss.reset()


trainer = TorchTrainer(
    train_func,
    train_loop_config={"num_epochs": 5},
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
print(result.metrics["valid_loss"], result.metrics["mean_valid_loss_collected"])
# Example output (values vary between runs because the data is random):
# 0.5109779238700867 0.5512474775314331