Fine-tune a PyTorch Lightning Text Classifier with Ray Data#

Note

This intermediate example demonstrates how to use Ray Data with PyTorch Lightning in Ray Train.

If you just want to quickly convert your existing PyTorch Lightning scripts into Ray Train, you can refer to the Lightning Quick Start Guide.

This demo shows how to fine-tune a text classifier on the CoLA (Corpus of Linguistic Acceptability) dataset using a pre-trained BERT model. It follows three steps:

  • Preprocess the CoLA dataset with Ray Data.

  • Define a training function with PyTorch Lightning.

  • Launch distributed training with Ray Train’s TorchTrainer.

Run the following line to install the necessary dependencies:

!pip install numpy datasets "transformers>=4.19.1" "pytorch_lightning>=1.6.5"

Start by importing the needed libraries:

import ray
import torch
import numpy as np
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset, load_metric

Pre-process CoLA Dataset#

CoLA is a binary sentence classification dataset with about 10.6K labeled sentences in total, of which roughly 8.5K form the GLUE training split. First, download the dataset using the Hugging Face datasets API, then create a Ray Dataset for each split.

dataset = load_dataset("glue", "cola")

train_dataset = ray.data.from_huggingface(dataset["train"])
validation_dataset = ray.data.from_huggingface(dataset["validation"])

Next, tokenize the input sentences and pad each ID sequence to length 128 using the bert-base-cased tokenizer. The map_batches call applies this preprocessing function to every batch of samples.

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

def tokenize_sentence(batch):
    outputs = tokenizer(
        batch["sentence"].tolist(),
        max_length=128,
        truncation=True,
        padding="max_length",
        return_tensors="np",
    )
    outputs["label"] = batch["label"]
    return outputs

train_dataset = train_dataset.map_batches(tokenize_sentence, batch_format="numpy")
validation_dataset = validation_dataset.map_batches(tokenize_sentence, batch_format="numpy")
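To make the effect of `max_length=128`, `truncation=True`, and `padding="max_length"` concrete, here is a minimal pure-Python sketch of what fixed-length padding does to one sequence of token IDs. `pad_or_truncate` is a hypothetical helper, not part of transformers, and the token IDs are made up for illustration. A real tokenizer also preserves special tokens when truncating, which this toy version does not.

```python
# Toy illustration of padding to a fixed length with truncation.
# `pad_or_truncate` is a hypothetical helper; pad_id=0 is an assumption.

def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Return (input_ids, attention_mask), both exactly max_length long."""
    kept = token_ids[:max_length]      # truncate sequences that are too long
    n_pad = max_length - len(kept)     # number of pad positions to append
    input_ids = kept + [pad_id] * n_pad
    # 1 marks a real token, 0 marks padding the model should ignore.
    attention_mask = [1] * len(kept) + [0] * n_pad
    return input_ids, attention_mask


ids, mask = pad_or_truncate([101, 1188, 1110, 102], max_length=8)
print(ids)   # [101, 1188, 1110, 102, 0, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0, 0, 0]
```

Because every row ends up with the same length, the batches that `tokenize_sentence` returns can be stacked into rectangular NumPy arrays.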

Define a PyTorch Lightning model#

You don’t have to make any changes to your LightningModule definition. Just copy and paste your code here:

class SentimentModel(pl.LightningModule):
    def __init__(self, lr=2e-5, eps=1e-8):
        super().__init__()
        self.lr = lr
        self.eps = eps
        self.num_classes = 2
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "bert-base-cased", num_labels=self.num_classes
        )
        self.metric = load_metric("glue", "cola")
        self.predictions = []
        self.references = []

    def forward(self, batch):
        input_ids, attention_mask = batch["input_ids"], batch["attention_mask"]
        outputs = self.model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        return logits

    def training_step(self, batch, batch_idx):
        labels = batch["label"]
        logits = self.forward(batch)
        loss = F.cross_entropy(logits.view(-1, self.num_classes), labels)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        labels = batch["label"]
        logits = self.forward(batch)
        preds = torch.argmax(logits, dim=1)
        self.predictions.append(preds)
        self.references.append(labels)

    def on_validation_epoch_end(self):
        predictions = torch.concat(self.predictions).view(-1)
        references = torch.concat(self.references).view(-1)
        matthews_correlation = self.metric.compute(
            predictions=predictions, references=references
        )

        # self.metric.compute() returns a dictionary:
        # e.g. {"matthews_correlation": 0.53}
        self.log_dict(matthews_correlation, sync_dist=True)
        self.predictions.clear()
        self.references.clear()

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=self.eps)
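The validation hook above logs the Matthews correlation coefficient, the standard metric for CoLA. As a rough illustration of what that number measures, the following sketch computes it by hand for binary labels. The function below is a hypothetical stand-in written for this example, not the code the metric library runs, and the sample labels are made up.

```python
import math


def matthews_corrcoef(predictions, references):
    """Matthews correlation for binary labels (0/1); 0.0 if undefined."""
    pairs = list(zip(predictions, references))
    tp = sum(p == 1 and r == 1 for p, r in pairs)  # true positives
    tn = sum(p == 0 and r == 0 for p, r in pairs)  # true negatives
    fp = sum(p == 1 and r == 0 for p, r in pairs)  # false positives
    fn = sum(p == 0 and r == 1 for p, r in pairs)  # false negatives
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0


# Made-up labels for illustration only.
references = [1, 1, 1, 0, 0, 1, 0, 1]
predictions = [1, 1, 0, 0, 1, 1, 0, 1]
mcc = matthews_corrcoef(predictions, references)
print(round(mcc, 4))  # 0.4667
```

The score ranges from -1 (every prediction wrong) through 0 (no better than chance) to 1 (every prediction right), which makes it more informative than accuracy when the two classes are imbalanced.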

Define a training function#

Define a training function that includes all of your Lightning training logic. TorchTrainer launches this function on each worker in parallel.

import ray.train
from ray.train.lightning import (
    prepare_trainer,
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
)

train_func_config = {
    "lr": 1e-5,
    "eps": 1e-8,
    "batch_size": 16,
    "max_epochs": 5,
}

def train_func(config):
    # Unpack the input configs passed from `TorchTrainer(train_loop_config)`
    lr = config["lr"]
    eps = config["eps"]
    batch_size = config["batch_size"]
    max_epochs = config["max_epochs"]

    # Fetch the Dataset shards
    train_ds = ray.train.get_dataset_shard("train")
    val_ds = ray.train.get_dataset_shard("validation")

    # Create a dataloader for Ray Datasets
    train_ds_loader = train_ds.iter_torch_batches(batch_size=batch_size)
    val_ds_loader = val_ds.iter_torch_batches(batch_size=batch_size)

    # Model
    model = SentimentModel(lr=lr, eps=eps)

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        devices="auto",
        strategy=RayDDPStrategy(),
        plugins=[RayLightningEnvironment()],
        callbacks=[RayTrainReportCallback()],
        enable_progress_bar=False,
    )

    trainer = prepare_trainer(trainer)

    trainer.fit(model, train_dataloaders=train_ds_loader, val_dataloaders=val_ds_loader)

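To enable distributed training with Ray Train, configure the Lightning Trainer with the following utilities, then pass the trainer through prepare_trainer as in the code above:

  • RayDDPStrategy

  • RayLightningEnvironment

  • RayTrainReportCallback

To ingest Ray Data with Lightning Trainer, follow these three steps:

  • Pass the full Ray Datasets to TorchTrainer through its datasets argument (shown in the next section).

  • Call ray.train.get_dataset_shard inside the training function to fetch each worker's dataset shard.

  • Call iter_torch_batches on each shard to create an iterable data loader for the Lightning Trainer.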

Distributed training with Ray TorchTrainer#

Next, define a TorchTrainer to launch your training function on 4 GPU workers.

You can pass the full Ray dataset to the datasets argument of TorchTrainer. TorchTrainer automatically shards the datasets among multiple workers.

from ray.train.torch import TorchTrainer
from ray.train import RunConfig, ScalingConfig, CheckpointConfig, DataConfig


# Save the top-2 checkpoints according to the evaluation metric
# The checkpoints and metrics are reported by `RayTrainReportCallback`
run_config = RunConfig(
    name="ptl-sent-classification",
    checkpoint_config=CheckpointConfig(
        num_to_keep=2,
        checkpoint_score_attribute="matthews_correlation",
        checkpoint_score_order="max",
    ),
)

# Schedule four workers for DDP training (1 GPU/worker by default)
scaling_config = ScalingConfig(num_workers=4, use_gpu=True)

trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    train_loop_config=train_func_config,
    scaling_config=scaling_config,
    run_config=run_config,
    datasets={"train": train_dataset, "validation": validation_dataset}, # <- Feed the Ray Datasets here
)
result = trainer.fit()
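The `CheckpointConfig` above asks Ray Train to keep only the two checkpoints with the highest `matthews_correlation`. As a simplified mental model of that retention policy, the sketch below ranks a list of per-epoch reports and keeps the best `k`. `keep_top_k` is a hypothetical helper and the scores are made up; Ray's actual bookkeeping may differ in details.

```python
def keep_top_k(reports, k, score_key, order="max"):
    """Return the k best reports, best first (hypothetical helper)."""
    ranked = sorted(
        reports,
        key=lambda report: report[score_key],
        reverse=(order == "max"),  # "max" means a higher score is better
    )
    return ranked[:k]


# Made-up per-epoch scores, for illustration only.
reports = [
    {"epoch": 0, "matthews_correlation": 0.41},
    {"epoch": 1, "matthews_correlation": 0.52},
    {"epoch": 2, "matthews_correlation": 0.55},
    {"epoch": 3, "matthews_correlation": 0.50},
    {"epoch": 4, "matthews_correlation": 0.54},
]
kept = keep_top_k(reports, k=2, score_key="matthews_correlation", order="max")
print([report["epoch"] for report in kept])  # [2, 4]
```

With `checkpoint_score_order="max"`, the checkpoints that survive are the ones with the best validation score, not necessarily the most recent ones.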

Tune Status

Current time: 2023-08-14 16:51:48
Running for: 00:05:50.88
Memory: 34.5/186.6 GiB

System Info

Using FIFO scheduling algorithm.
Logical resource usage: 1.0/48 CPUs, 4.0/4 GPUs

Trial Status

Trial name                status      loc                  iter  total time (s)  train_loss  matthews_correlation  epoch
TorchTrainer_b723f_00000  TERMINATED  10.0.63.245:150507      5         337.748   0.0199119              0.577705      4
(TorchTrainer pid=150507) Starting distributed worker processes: ['150618 (10.0.63.245)', '150619 (10.0.63.245)', '150620 (10.0.63.245)', '150621 (10.0.63.245)']
(RayTrainWorker pid=150618) Setting up process group for: env:// [rank=0, world_size=4]
(RayTrainWorker pid=150620) Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
(RayTrainWorker pid=150620) - This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
(RayTrainWorker pid=150620) - This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
(RayTrainWorker pid=150620) Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.weight', 'classifier.bias']
(RayTrainWorker pid=150620) You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
(RayTrainWorker pid=150618) GPU available: True, used: True
(RayTrainWorker pid=150618) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=150618) IPU available: False, using: 0 IPUs
(RayTrainWorker pid=150618) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=150618)   | Name  | Type                          | Params
(RayTrainWorker pid=150618) --------------------------------------------------------
(RayTrainWorker pid=150618) 0 | model | BertForSequenceClassification | 108 M 
(RayTrainWorker pid=150618) --------------------------------------------------------
(RayTrainWorker pid=150618) 108 M     Trainable params
(RayTrainWorker pid=150618) 0         Non-trainable params
(RayTrainWorker pid=150618) 108 M     Total params
(RayTrainWorker pid=150618) 433.247   Total estimated model params size (MB)
(RayTrainWorker pid=150620) Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(tokenize_sentence)]
(RayTrainWorker pid=150620) Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=2000000000.0), locality_with_output=True, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)
(RayTrainWorker pid=150620) Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`
(SplitCoordinator pid=150822) Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(tokenize_sentence)] -> OutputSplitter[split(4, equal=True)]
(SplitCoordinator pid=150822) Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=2000000000.0), locality_with_output=['d4dd34cdb4b35e8b1e0f1ab4187b66ed900ab78de951f03e1125233b', 'd4dd34cdb4b35e8b1e0f1ab4187b66ed900ab78de951f03e1125233b', 'd4dd34cdb4b35e8b1e0f1ab4187b66ed900ab78de951f03e1125233b', 'd4dd34cdb4b35e8b1e0f1ab4187b66ed900ab78de951f03e1125233b'], preserve_order=False, actor_locality_enabled=True, verbose_progress=False)
2023-08-14 16:51:48,299	INFO tune.py:1146 -- Total run time: 350.99 seconds (350.87 seconds for the tuning loop).

Note

This example uses Ray Data for data ingestion to speed up preprocessing, but you can also keep using the native PyTorch DataLoader or LightningDataModule. See Train a Pytorch Lightning Image Classifier.

result
Result(
  metrics={'train_loss': 0.019911885261535645, 'matthews_correlation': 0.577705364544777, 'epoch': 4, 'step': 670},
  path='/home/ray/ray_results/ptl-sent-classification/TorchTrainer_b723f_00000_0_2023-08-14_16-45-57',
  checkpoint=TorchCheckpoint(local_path=/home/ray/ray_results/ptl-sent-classification/TorchTrainer_b723f_00000_0_2023-08-14_16-45-57/checkpoint_000004)
)
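The `step` value of 670 in the final metrics can be cross-checked with a little arithmetic on how the data is sharded. The sketch below assumes that the GLUE CoLA training split has 8,551 rows, that the split across workers is equal and drops any remainder rows (the logs show `split(4, equal=True)`), and that the last partial batch in each epoch is kept. These are assumptions for illustration, not values read from the run.

```python
import math

num_train_rows = 8551  # assumption: size of the GLUE CoLA training split
num_workers = 4        # from ScalingConfig(num_workers=4)
batch_size = 16        # from train_func_config
max_epochs = 5         # from train_func_config

# An equal split gives every worker the same row count, dropping the remainder.
rows_per_worker = num_train_rows // num_workers

# Each worker iterates over its own shard; the last batch may be partial.
steps_per_epoch = math.ceil(rows_per_worker / batch_size)

total_steps = steps_per_epoch * max_epochs
print(rows_per_worker, steps_per_epoch, total_steps)  # 2137 134 670
```

Under these assumptions each worker takes 134 optimizer steps per epoch, so five epochs give the 670 steps reported above.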

See also#