{ "cells": [ { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "# Using PyTorch Lightning with Tune\n", "\n", "(tune-pytorch-lightning-ref)=\n", "\n", "PyTorch Lightning is a framework which brings structure into training PyTorch models. It aims to avoid boilerplate code, so you don't have to write the same training loops all over again when building a new model.\n", "\n", "```{image} /images/pytorch_lightning_full.png\n", ":align: center\n", "```\n", "\n", "The main abstraction of PyTorch Lightning is the `LightningModule` class, which should be extended by your application. There is [a great post on how to transfer your models from vanilla PyTorch to Lightning](https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09).\n", "\n", "The class structure of PyTorch Lightning makes it very easy to define and tune model parameters. This tutorial will show you how to use Tune with Ray Train's {class}`TorchTrainer ` to find the best set of parameters for your application on the example of training a MNIST classifier. Notably, the `LightningModule` does not have to be altered at all for this - so you can use it plug and play for your existing models, assuming their parameters are configurable!\n", "\n", ":::{note}\n", "To run this example, you will need to install the following:\n", "\n", "```bash\n", "$ pip install \"ray[tune]\" torch torchvision pytorch_lightning\n", "```\n", ":::\n", "\n", "```{contents}\n", ":backlinks: none\n", ":local: true\n", "```\n", "\n", "## PyTorch Lightning classifier for MNIST\n", "\n", "Let's first start with the basic PyTorch Lightning implementation of an MNIST classifier. This classifier does not include any tuning code at this point.\n", "\n", "First, we run some imports:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import torch\n", "import tempfile\n", "import pytorch_lightning as pl\n", "import torch.nn.functional as F\n", "from filelock import FileLock\n", "from torchmetrics import Accuracy\n", "from torch.utils.data import DataLoader, random_split\n", "from torchvision.datasets import MNIST\n", "from torchvision import transforms" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "# If you want to run full test, please set SMOKE_TEST to False\n", "SMOKE_TEST = True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our example builds on the MNIST example from the [blog post](https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09) we mentioned before. We adapted the original model and dataset definitions into `MNISTClassifier` and `MNISTDataModule`. 
" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "class MNISTClassifier(pl.LightningModule):\n", " def __init__(self, config):\n", " super(MNISTClassifier, self).__init__()\n", " self.accuracy = Accuracy(task=\"multiclass\", num_classes=10, top_k=1)\n", " self.layer_1_size = config[\"layer_1_size\"]\n", " self.layer_2_size = config[\"layer_2_size\"]\n", " self.lr = config[\"lr\"]\n", "\n", " # mnist images are (1, 28, 28) (channels, width, height)\n", " self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_size)\n", " self.layer_2 = torch.nn.Linear(self.layer_1_size, self.layer_2_size)\n", " self.layer_3 = torch.nn.Linear(self.layer_2_size, 10)\n", " self.eval_loss = []\n", " self.eval_accuracy = []\n", "\n", " def cross_entropy_loss(self, logits, labels):\n", " return F.nll_loss(logits, labels)\n", "\n", " def forward(self, x):\n", " batch_size, channels, width, height = x.size()\n", " x = x.view(batch_size, -1)\n", "\n", " x = self.layer_1(x)\n", " x = torch.relu(x)\n", "\n", " x = self.layer_2(x)\n", " x = torch.relu(x)\n", "\n", " x = self.layer_3(x)\n", " x = torch.log_softmax(x, dim=1)\n", "\n", " return x\n", "\n", " def training_step(self, train_batch, batch_idx):\n", " x, y = train_batch\n", " logits = self.forward(x)\n", " loss = self.cross_entropy_loss(logits, y)\n", " accuracy = self.accuracy(logits, y)\n", "\n", " self.log(\"ptl/train_loss\", loss)\n", " self.log(\"ptl/train_accuracy\", accuracy)\n", " return loss\n", "\n", " def validation_step(self, val_batch, batch_idx):\n", " x, y = val_batch\n", " logits = self.forward(x)\n", " loss = self.cross_entropy_loss(logits, y)\n", " accuracy = self.accuracy(logits, y)\n", " self.eval_loss.append(loss)\n", " self.eval_accuracy.append(accuracy)\n", " return {\"val_loss\": loss, \"val_accuracy\": accuracy}\n", "\n", " def on_validation_epoch_end(self):\n", " avg_loss = torch.stack(self.eval_loss).mean()\n", " avg_acc = torch.stack(self.eval_accuracy).mean()\n", " self.log(\"ptl/val_loss\", avg_loss, sync_dist=True)\n", " self.log(\"ptl/val_accuracy\", avg_acc, sync_dist=True)\n", " self.eval_loss.clear()\n", " self.eval_accuracy.clear()\n", "\n", " def configure_optimizers(self):\n", " optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\n", " return optimizer\n", "\n", "\n", "class MNISTDataModule(pl.LightningDataModule):\n", " def __init__(self, batch_size=128):\n", " super().__init__()\n", " self.data_dir = tempfile.mkdtemp()\n", " self.batch_size = batch_size\n", " self.transform = transforms.Compose(\n", " [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n", " )\n", "\n", " def setup(self, stage=None):\n", " with FileLock(f\"{self.data_dir}.lock\"):\n", " mnist = MNIST(\n", " self.data_dir, train=True, download=True, transform=self.transform\n", " )\n", " self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])\n", "\n", " self.mnist_test = MNIST(\n", " self.data_dir, train=False, download=True, transform=self.transform\n", " )\n", "\n", " def train_dataloader(self):\n", " return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=4)\n", "\n", " def val_dataloader(self):\n", " return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=4)\n", "\n", " def test_dataloader(self):\n", " return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "default_config = {\n", " \"layer_1_size\": 128,\n", " 
\"layer_2_size\": 256,\n", "    \"lr\": 1e-3,\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define a training function that creates the model, datamodule, and Lightning trainer with Ray Train utilities." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from ray.train.lightning import (\n", "    RayDDPStrategy,\n", "    RayLightningEnvironment,\n", "    RayTrainReportCallback,\n", "    prepare_trainer,\n", ")\n", "\n", "\n", "def train_func(config):\n", "    dm = MNISTDataModule(batch_size=config[\"batch_size\"])\n", "    model = MNISTClassifier(config)\n", "\n", "    trainer = pl.Trainer(\n", "        devices=\"auto\",\n", "        accelerator=\"auto\",\n", "        strategy=RayDDPStrategy(),\n", "        callbacks=[RayTrainReportCallback()],\n", "        plugins=[RayLightningEnvironment()],\n", "        enable_progress_bar=False,\n", "    )\n", "    trainer = prepare_trainer(trainer)\n", "    trainer.fit(model, datamodule=dm)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tuning the model parameters\n", "\n", "The parameters above should already give you a good accuracy of over 90%. However, we might improve on this simply by changing some of the hyperparameters. For instance, maybe we would get an even higher accuracy if we used a smaller learning rate and a larger middle layer.\n", "\n", "Instead of manually looping through all the parameter combinations, let's use Tune to systematically try them out and find the best-performing set.\n", "\n", "First, we need some additional imports:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "from ray import tune\n", "from ray.tune.schedulers import ASHAScheduler" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Configuring the search space\n", "\n", "Now we configure the parameter search space. We would like to choose between different layer dimensions, learning rates, and batch sizes. The learning rate should be sampled log-uniformly between `0.0001` and `0.1`. The `tune.loguniform()` function is syntactic sugar that makes sampling across several orders of magnitude easier; in particular, small values are sampled as often as large ones. Similarly, `tune.choice()` samples uniformly from the provided options." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "search_space = {\n", "    \"layer_1_size\": tune.choice([32, 64, 128]),\n", "    \"layer_2_size\": tune.choice([64, 128, 256]),\n", "    \"lr\": tune.loguniform(1e-4, 1e-1),\n", "    \"batch_size\": tune.choice([32, 64]),\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Selecting a scheduler\n", "\n", "In this example, we use an [Asynchronous Hyperband](https://blog.ml.cmu.edu/2018/12/12/massively-parallel-hyperparameter-optimization/)\n", "scheduler. This scheduler decides at each iteration which trials are likely to perform\n", "badly, and stops these trials early. This way we don't waste any resources on bad hyperparameter\n", "configurations." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "# The maximum number of training epochs\n", "num_epochs = 5\n", "\n", "# Number of samples from the parameter space\n", "num_samples = 10" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "If you have more resources available, you can increase these values accordingly, e.g., more epochs and more parameter samples."
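, "\n", "\n", "As a quick sanity check, you can also draw a few configurations from the search space by hand. This is a minimal sketch; it assumes the `search_space` dict defined above and relies on the `sample()` method that Tune's search space objects provide:\n", "\n", "```python\n", "# Draw three example configurations from the search space defined above.\n", "for _ in range(3):\n", "    print({name: dist.sample() for name, dist in search_space.items()})\n", "```"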
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "if SMOKE_TEST:\n", "    num_epochs = 3\n", "    num_samples = 3" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training with GPUs\n", "\n", "We can specify the number of resources, including GPUs, that Tune should request for each trial.\n", "\n", "`TorchTrainer` takes care of the environment setup for Distributed Data Parallel training; the model and data are automatically distributed across GPUs. You only need to set the number of GPUs per worker in `ScalingConfig` and set `accelerator=\"auto\"` in your training function." ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "from ray.train import RunConfig, ScalingConfig, CheckpointConfig\n", "\n", "scaling_config = ScalingConfig(\n", "    num_workers=3, use_gpu=True, resources_per_worker={\"CPU\": 1, \"GPU\": 1}\n", ")\n", "\n", "run_config = RunConfig(\n", "    checkpoint_config=CheckpointConfig(\n", "        num_to_keep=2,\n", "        checkpoint_score_attribute=\"ptl/val_accuracy\",\n", "        checkpoint_score_order=\"max\",\n", "    ),\n", ")" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "if SMOKE_TEST:\n", "    scaling_config = ScalingConfig(\n", "        num_workers=3, use_gpu=False, resources_per_worker={\"CPU\": 1}\n", "    )" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "from ray.train.torch import TorchTrainer\n", "\n", "# Define a TorchTrainer without hyperparameters for the Tuner\n", "ray_trainer = TorchTrainer(\n", "    train_func,\n", "    scaling_config=scaling_config,\n", "    run_config=run_config,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Putting it together\n", "\n", "Lastly, we need to create a `Tuner()` object and start Ray Tune with `tuner.fit()`. The full code looks like this:" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [
{ "data": { "text/plain": [ "== Tune Status ==\n", "Current time: 2023-09-07 14:03:52 (running for 00:05:13.92)\n", "Memory: 20.5/186.6 GiB\n", "Using AsyncHyperBand: num_stopped=10\n", "Bracket: Iter 4.000: 0.9709362387657166 | Iter 2.000: 0.9617255330085754 | Iter 1.000: 0.9477165043354034\n", "Logical resource usage: 4.0/48 CPUs, 3.0/4 GPUs (0.0/1.0 accelerator_type:None)\n", "\n", "== Trial Status ==\n", "Trial name                status      loc               batch_size  layer_1_size  layer_2_size           lr  iter  total time (s)  ptl/train_loss  ptl/train_accuracy  ptl/val_loss\n", "TorchTrainer_5144b_00000  TERMINATED  10.0.0.84:63990           32            64           256  0.0316233        5         29.3336        0.973613            0.766667     0.580943\n", "TorchTrainer_5144b_00001  TERMINATED  10.0.0.84:71294           64           128            64  0.0839278        1         12.2275        2.19514             0.266667     1.56644\n", "TorchTrainer_5144b_00002  TERMINATED  10.0.0.84:73540           32            64           256  0.000233034      5         29.1314        0.146903            0.933333     0.114229\n", "TorchTrainer_5144b_00003  TERMINATED  10.0.0.84:80840           64           128            64  0.00109259       5         21.6534        0.0474913           0.966667     0.0714878\n", "TorchTrainer_5144b_00004  TERMINATED  10.0.0.84:88077           32            32           128  0.00114083       5         29.6367        0.0990443           0.966667     0.0891999\n", "TorchTrainer_5144b_00005  TERMINATED  10.0.0.84:95388           32            64            64  0.00924264       4         25.7089        0.0349707           1            0.153937\n", "TorchTrainer_5144b_00006  TERMINATED  10.0.0.84:101434          32           128           256  0.00325671       5         29.5763        0.0708755           0.966667     0.0820903\n", "TorchTrainer_5144b_00007  TERMINATED  10.0.0.84:108750          32            32            64  0.000123766      1         13.9326        0.27464             0.966667     0.401102\n", "TorchTrainer_5144b_00008  TERMINATED  10.0.0.84:111019          64           128           256  0.00371762       5         21.8337        0.00108961          1            0.0579874\n", "TorchTrainer_5144b_00009  TERMINATED  10.0.0.84:118255          32           128           128  0.00397956       5         29.8334        0.00940019          1            0.0685028\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "(TorchTrainer pid=63990) Starting distributed worker processes: ['64101 (10.0.0.84)', '64102 (10.0.0.84)', '64103 (10.0.0.84)']\n", "(RayTrainWorker pid=64101) Setting up process group for: env:// [rank=0, world_size=3]\n", "(RayTrainWorker pid=64101) GPU available: True, used: True\n", "(RayTrainWorker pid=64102) LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2]\n", "(RayTrainWorker pid=64101)   | Name     | Type               | Params\n", "(RayTrainWorker pid=64101) ------------------------------------------------\n", "(RayTrainWorker pid=64101) 0 | accuracy | MulticlassAccuracy | 0     \n", "(RayTrainWorker pid=64101) 1 | layer_1  | Linear             | 50.2 K\n", "(RayTrainWorker pid=64101) 2 | layer_2  | Linear             | 16.6 K\n", "(RayTrainWorker pid=64101) 3 | layer_3  | Linear             | 2.6 K \n", "(RayTrainWorker pid=64101) ------------------------------------------------\n", "(RayTrainWorker pid=64101) 69.5 K    Trainable params\n", "(RayTrainWorker pid=64101) 0         Non-trainable params\n", "(RayTrainWorker pid=64101) 69.5 K    Total params\n", "(RayTrainWorker pid=64102) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/ray/ray_results/TorchTrainer_2023-09-07_13-58-38/TorchTrainer_5144b_00000_0_batch_size=32,layer_1_size=64,layer_2_size=256,lr=0.0316_2023-09-07_13-58-38/checkpoint_000000)\n" ] } ], "source": [ "tuner = tune.Tuner(\n", "    ray_trainer,\n", "    param_space={\"train_loop_config\": search_space},\n", "    tune_config=tune.TuneConfig(\n", "        metric=\"ptl/val_accuracy\",\n", "        mode=\"max\",\n", "        num_samples=num_samples,\n", "        scheduler=scheduler,\n", "    ),\n", ")\n", "results = tuner.fit()" ] }
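, { "cell_type": "markdown", "metadata": {}, "source": [ "After the run finishes, `tuner.fit()` returns a `ResultGrid` that you can query for the best trial. The following is a minimal sketch; it assumes the `results` object from the cell above and uses the `ptl/val_accuracy` metric we logged in `on_validation_epoch_end`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Fetch the trial that achieved the highest reported validation accuracy.\n", "best_result = results.get_best_result(metric=\"ptl/val_accuracy\", mode=\"max\")\n", "\n", "# The hyperparameters of the best trial (nested under \"train_loop_config\").\n", "print(best_result.config)\n", "\n", "# The final reported metrics of the best trial.\n", "print(best_result.metrics[\"ptl/val_accuracy\"])" ] }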
, { "cell_type": "markdown", "metadata": {}, "source": [ "## More PyTorch Lightning Examples\n", "\n", "- {doc}`[Intermediate] Fine-tune a BERT Text Classifier with PyTorch Lightning and Ray Train <../../train/examples/lightning/lightning_cola_advanced>`\n", "- {doc}`[Advanced] Fine-tune dolly-v2-7b with PyTorch Lightning and FSDP <../../train/examples/lightning/dolly_lightning_fsdp_finetuning>`\n", "- {doc}`/tune/examples/includes/mlflow_ptl_example`: Example for using [MLflow](https://github.com/mlflow/mlflow/)\n", "  and [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) with Ray Tune.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.15" } }, "nbformat": 4, "nbformat_minor": 4 }