{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fine-tune a Hugging Face Transformers Model" ] }, { "cell_type": "markdown", "metadata": { "id": "VaFMt6AIhYbK" }, "source": [ "This notebook is based on an official Hugging Face example, [How to fine-tune a model on text classification](https://github.com/huggingface/notebooks/blob/6ca682955173cc9d36ffa431ddda505a048cbe80/examples/text_classification.ipynb). It shows how to convert vanilla HF code to Ray Train, changing the training logic only where necessary.\n", "\n", "This notebook consists of the following steps:\n", "1. [Set up Ray](#hf-setup)\n", "2. [Load the dataset](#hf-load)\n", "3. [Preprocess the dataset with Ray Data](#hf-preprocess)\n", "4. [Run the training with Ray Train](#hf-train)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "sQbdfyWQhYbO" }, "source": [ "Uncomment and run the following line to install all the necessary dependencies. This notebook is tested with `transformers==4.19.1`." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "YajFzmkthYbO" }, "outputs": [], "source": [ "#! pip install \"datasets\" \"transformers>=4.19.0\" \"torch>=1.10.0\" \"mlflow\"" ] }, { "cell_type": "markdown", "metadata": { "id": "pvSRaEHChYbP" }, "source": [ "(hf-setup)=\n", "## Set up Ray" ] }, { "cell_type": "markdown", "metadata": { "id": "LRdL3kWBhYbQ" }, "source": [ "Use `ray.init()` to initialize a local cluster. By default, this cluster contains only the machine you are running this notebook on. You can also run this notebook on an [Anyscale](https://www.anyscale.com/) cluster." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "MOsHUjgdIrIW", "outputId": "e527bdbb-2f28-4142-cca0-762e0566cbcd", "tags": [] }, "outputs": [], "source": [ "from pprint import pprint\n", "import ray\n", "\n", "ray.init()" ] }, { "cell_type": "markdown", "metadata": { "id": "oJiSdWy2hYbR" }, "source": [ "Check the resources available in the cluster. If you are running this notebook on your local machine or Google Colab, you should see the number of CPU cores and GPUs available on your machine." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "KlMz0dt9hYbS", "outputId": "2d485449-ee69-4334-fcba-47e0ceb63078", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'CPU': 48.0,\n", " 'GPU': 4.0,\n", " 'accelerator_type:T4': 1.0,\n", " 'anyscale/accelerator_shape:4xT4': 1.0,\n", " 'anyscale/node-group:head': 1.0,\n", " 'anyscale/provider:aws': 1.0,\n", " 'anyscale/region:us-west-2': 1.0,\n", " 'memory': 206158430208.0,\n", " 'node:10.0.114.132': 1.0,\n", " 'node:__internal_head__': 1.0,\n", " 'object_store_memory': 58913938636.0}\n" ] } ], "source": [ "pprint(ray.cluster_resources())" ] }, { "cell_type": "markdown", "metadata": { "id": "uS6oeJELhYbS" }, "source": [ "This notebook fine-tunes a [HF Transformers](https://github.com/huggingface/transformers) model for one of the text classification tasks of the [GLUE Benchmark](https://gluebenchmark.com/). It runs the training using Ray Train.\n", "\n", "You can change the following two variables to control whether the training, which happens later, uses CPUs or GPUs, and how many workers to spawn. Each worker claims one CPU or GPU. Make sure not to request more resources than the cluster has. By default, the training runs with one GPU worker." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "gAbhv9OqhYbT", "tags": [] }, "outputs": [], "source": [ "use_gpu = True # set this to False to run on CPUs\n", "num_workers = 1 # set this to the number of GPUs or CPUs you want to use" ] }, { "cell_type": "markdown", "metadata": { "id": "rEJBSTyZIrIb" }, "source": [ "## Fine-tune a model on a text classification task" ] }, { "cell_type": "markdown", "metadata": { "id": "kTCFado4IrIc" }, "source": [ "The GLUE Benchmark is a group of nine classification tasks on sentences or pairs of sentences. To learn more, see the [original notebook](https://github.com/huggingface/notebooks/blob/6ca682955173cc9d36ffa431ddda505a048cbe80/examples/text_classification.ipynb).\n", "\n", "Each task is named by its acronym, except `mnli-mm`, which denotes the mismatched version of MNLI. It has the same training set as `mnli` but different validation and test sets." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "YZbiBDuGIrId", "tags": [] }, "outputs": [], "source": [ "GLUE_TASKS = [\n", " \"cola\",\n", " \"mnli\",\n", " \"mnli-mm\",\n", " \"mrpc\",\n", " \"qnli\",\n", " \"qqp\",\n", " \"rte\",\n", " \"sst2\",\n", " \"stsb\",\n", " \"wnli\",\n", "]" ] }, { "cell_type": "markdown", "metadata": { "id": "4RRkXuteIrIh" }, "source": [ "This notebook runs on any of the tasks in the list above, with any model checkpoint from the [Model Hub](https://huggingface.co/models) as long as that model has a version with a classification head. Depending on the model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. 
Set these three parameters, and the rest of the notebook should run smoothly:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "zVvslsfMIrIh", "tags": [] }, "outputs": [], "source": [ "task = \"cola\"\n", "model_checkpoint = \"distilbert-base-uncased\"\n", "batch_size = 16" ] }, { "cell_type": "markdown", "metadata": { "id": "whPRbBNbIrIl" }, "source": [ "(hf-load)=\n", "### Loading the dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "W7QYTpxXIrIl" }, "source": [ "Use the [HF Datasets](https://github.com/huggingface/datasets) library to download the data and get the metric to use for evaluation and to compare your model to the benchmark. You can do both easily with the `load_dataset` and `load_metric` functions.\n", "\n", "Apart from `mnli-mm`, which is a special case, you can pass the task name directly to those functions.\n", "\n", "Run the normal HF Datasets code to load the dataset from the Hub." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 200 }, "id": "MwhAeEOuhYbV", "outputId": "3aff8c73-d6eb-4784-890a-a419403b5bda", "tags": [] }, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "actual_task = \"mnli\" if task == \"mnli-mm\" else task\n", "datasets = load_dataset(\"glue\", actual_task)" ] }, { "cell_type": "markdown", "metadata": { "id": "RzfPtOMoIrIu" }, "source": [ "The `datasets` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key each for the training, validation, and test sets, with more keys for the mismatched validation and test sets in the special case of `mnli`." 
] }, { "cell_type": "markdown", "metadata": { "id": "n9qywopnIrJH" }, "source": [ "(hf-preprocess)=\n", "### Preprocessing the data with Ray Data" ] }, { "cell_type": "markdown", "metadata": { "id": "YVx71GdAIrJH" }, "source": [ "Before you can feed these texts to the model, you need to preprocess them. Use a HF Transformers `Tokenizer`, which tokenizes the inputs, including converting the tokens to their corresponding IDs in the pretrained vocabulary, and puts them in a format the model expects. It also generates the other inputs that the model requires.\n", "\n", "To do all of this preprocessing, instantiate your tokenizer with the `AutoTokenizer.from_pretrained` method, which ensures that you:\n", "\n", "- Get a tokenizer that corresponds to the model architecture you want to use.\n", "- Download the vocabulary used when pretraining this specific checkpoint." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 145 }, "id": "eXNLu_-nIrJI", "outputId": "f545a7a5-f341-4315-cd89-9942a657aa31", "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ray/anaconda3/lib/python3.9/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", " _torch_pytree._register_pytree_node(\n", "/home/ray/anaconda3/lib/python3.9/site-packages/huggingface_hub/file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", " warnings.warn(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "cbfed0e37b1f4546a98878cc6090ff6f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer_config.json: 0%| | 0.00/48.0 [00:00SplitBlocks(96) 1: 0.00 row [00:00, ? 
row/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1b3a0c80878d46bfa496425a6c750405", "version_major": 2, "version_minor": 0 }, "text/plain": [ "(pid=41621) - split(1, equal=True) 2: 0.00 row [00:00, ? row/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36m(SplitCoordinator pid=41621)\u001b[0m Registered dataset logger for dataset train_23_0\n", "\u001b[36m(SplitCoordinator pid=41621)\u001b[0m Starting execution of Dataset train_23_0. Full logs are in /tmp/ray/session_2025-07-09_15-09-59_163606_3385/logs/ray-data\n", "\u001b[36m(SplitCoordinator pid=41621)\u001b[0m Execution plan of Dataset train_23_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadParquet] -> OutputSplitter[split(1, equal=True)]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "== Status ==\n", "Current time: 2025-07-09 15:56:52 (running for 00:00:20.23)\n", "Using FIFO scheduling algorithm.\n", "Logical resource usage: 1.0/48 CPUs, 1.0/4 GPUs (0.0/1.0 anyscale/accelerator_shape:4xT4, 0.0/1.0 anyscale/provider:aws, 0.0/1.0 anyscale/region:us-west-2, 0.0/1.0 accelerator_type:T4, 0.0/1.0 anyscale/node-group:head)\n", "Result logdir: /tmp/ray/session_2025-07-09_15-09-59_163606_3385/artifacts/2025-07-09_15-56-32/TorchTrainer_2025-07-09_15-56-32/driver_artifacts\n", "Number of trials: 1/1 (1 RUNNING)\n", "\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36m(RayTrainWorker pid=41521)\u001b[0m /tmp/ipykernel_40967/133795194.py:24: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. 
(Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)\n", "\u001b[36m(RayTrainWorker pid=41521)\u001b[0m [rank0]:[W reducer.cpp:1389] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "== Status ==\n", "Current time: 2025-07-09 15:56:57 (running for 00:00:25.25)\n", "Using FIFO scheduling algorithm.\n", "Logical resource usage: 1.0/48 CPUs, 1.0/4 GPUs (0.0/1.0 anyscale/accelerator_shape:4xT4, 0.0/1.0 anyscale/provider:aws, 0.0/1.0 anyscale/region:us-west-2, 0.0/1.0 accelerator_type:T4, 0.0/1.0 anyscale/node-group:head)\n", "Result logdir: /tmp/ray/session_2025-07-09_15-09-59_163606_3385/artifacts/2025-07-09_15-56-32/TorchTrainer_2025-07-09_15-56-32/driver_artifacts\n", "Number of trials: 1/1 (1 RUNNING)\n", "\n", "\n", "== Status ==\n", "Current time: 2025-07-09 15:57:02 (running for 00:00:30.27)\n", "Using FIFO scheduling algorithm.\n", "Logical resource usage: 1.0/48 CPUs, 1.0/4 GPUs (0.0/1.0 anyscale/provider:aws, 0.0/1.0 anyscale/region:us-west-2, 0.0/1.0 anyscale/node-group:head, 0.0/1.0 anyscale/accelerator_shape:4xT4, 0.0/1.0 accelerator_type:T4)\n", "Result logdir: /tmp/ray/session_2025-07-09_15-09-59_163606_3385/artifacts/2025-07-09_15-56-32/TorchTrainer_2025-07-09_15-56-32/driver_artifacts\n", "Number of trials: 1/1 (1 RUNNING)\n", "\n", "\n", "== Status ==\n", "Current time: 2025-07-09 15:57:07 (running for 00:00:35.29)\n", "Using FIFO scheduling algorithm.\n", "Logical resource usage: 1.0/48 CPUs, 1.0/4 GPUs (0.0/1.0 
anyscale/provider:aws, 0.0/1.0 anyscale/region:us-west-2, 0.0/1.0 anyscale/node-group:head, 0.0/1.0 anyscale/accelerator_shape:4xT4, 0.0/1.0 accelerator_type:T4)\n", "Result logdir: /tmp/ray/session_2025-07-09_15-09-59_163606_3385/artifacts/2025-07-09_15-56-32/TorchTrainer_2025-07-09_15-56-32/driver_artifacts\n", "Number of trials: 1/1 (1 RUNNING)\n", "\n", "\n", "== Status ==\n", "Current time: 2025-07-09 15:57:12 (running for 00:00:40.32)\n", "Using FIFO scheduling algorithm.\n", "Logical resource usage: 1.0/48 CPUs, 1.0/4 GPUs (0.0/1.0 anyscale/provider:aws, 0.0/1.0 anyscale/accelerator_shape:4xT4, 0.0/1.0 anyscale/node-group:head, 0.0/1.0 anyscale/region:us-west-2, 0.0/1.0 accelerator_type:T4)\n", "Result logdir: /tmp/ray/session_2025-07-09_15-09-59_163606_3385/artifacts/2025-07-09_15-56-32/TorchTrainer_2025-07-09_15-56-32/driver_artifacts\n", "Number of trials: 1/1 (1 RUNNING)\n", "\n", "\n", "== Status ==\n", "Current time: 2025-07-09 15:57:17 (running for 00:00:45.34)\n", "Using FIFO scheduling algorithm.\n", "Logical resource usage: 1.0/48 CPUs, 1.0/4 GPUs (0.0/1.0 anyscale/provider:aws, 0.0/1.0 anyscale/accelerator_shape:4xT4, 0.0/1.0 anyscale/node-group:head, 0.0/1.0 anyscale/region:us-west-2, 0.0/1.0 accelerator_type:T4)\n", "Result logdir: /tmp/ray/session_2025-07-09_15-09-59_163606_3385/artifacts/2025-07-09_15-56-32/TorchTrainer_2025-07-09_15-56-32/driver_artifacts\n", "Number of trials: 1/1 (1 RUNNING)\n", "\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36m(SplitCoordinator pid=41621)\u001b[0m ✔️ Dataset train_23_0 execution finished in 28.21 seconds\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[36m(RayTrainWorker pid=41521)\u001b[0m {'loss': 0.5441, 'learning_rate': 9.9812734082397e-06, 'epoch': 0.5}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36m(SplitCoordinator pid=41622)\u001b[0m Registered dataset logger for dataset eval_24_0\n", "\u001b[36m(SplitCoordinator 
pid=41622)\u001b[0m Starting execution of Dataset eval_24_0. Full logs are in /tmp/ray/session_2025-07-09_15-09-59_163606_3385/logs/ray-data\n", "\u001b[36m(SplitCoordinator pid=41622)\u001b[0m Execution plan of Dataset eval_24_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadParquet] -> OutputSplitter[split(1, equal=True)]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7afaf8757b8e49ff8ba1432fc50fb053", "version_major": 2, "version_minor": 0 }, "text/plain": [ "(pid=41622) Running 0: 0.00 row [00:00, ? row/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1c0abd88c1694102bd779fca0ea61804", "version_major": 2, "version_minor": 0 }, "text/plain": [ "(pid=41622) - ReadParquet->SplitBlocks(96) 1: 0.00 row [00:00, ? row/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a5cd7456953145c195d74c23ed8badd7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "(pid=41622) - split(1, equal=True) 2: 0.00 row [00:00, ? 
row/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "== Status ==\n", "Current time: 2025-07-09 15:57:22 (running for 00:00:50.36)\n", "Using FIFO scheduling algorithm.\n", "Logical resource usage: 1.0/48 CPUs, 1.0/4 GPUs (0.0/1.0 accelerator_type:T4, 0.0/1.0 anyscale/region:us-west-2, 0.0/1.0 anyscale/provider:aws, 0.0/1.0 anyscale/accelerator_shape:4xT4, 0.0/1.0 anyscale/node-group:head)\n", "Result logdir: /tmp/ray/session_2025-07-09_15-09-59_163606_3385/artifacts/2025-07-09_15-56-32/TorchTrainer_2025-07-09_15-56-32/driver_artifacts\n", "Number of trials: 1/1 (1 RUNNING)\n", "\n", "\n", "\u001b[36m(RayTrainWorker pid=41521)\u001b[0m {'eval_loss': 0.51453697681427, 'eval_matthews_correlation': 0.37793570732654813, 'eval_runtime': 1.8456, 'eval_samples_per_second': 565.126, 'eval_steps_per_second': 35.761, 'epoch': 0.5}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-07-09 15:57:26,970\tWARNING experiment_state.py:206 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds and may become a bottleneck. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this warning by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0). 
Set it to 0 to completely suppress this warning.\n", "\u001b[36m(RayTrainWorker pid=41521)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/ray/ray_results/TorchTrainer_2025-07-09_15-56-32/TorchTrainer_f5114_00000_0_2025-07-09_15-56-32/checkpoint_000000)\n", "\u001b[36m(SplitCoordinator pid=41622)\u001b[0m ✔️ Dataset eval_24_0 execution finished in 1.73 seconds\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "99ce23489c4a46ec9f621589bd83ff9a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "(pid=41621) Running 0: 0.00 row [00:00, ? row/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c77d0ec4db2e40b982a32d3171238ac8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "(pid=41621) - ReadParquet->SplitBlocks(96) 1: 0.00 row [00:00, ? row/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e20adeec80504e68875294b053e0c1c0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "(pid=41621) - split(1, equal=True) 2: 0.00 row [00:00, ? 
row/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "== Status ==\n", "Current time: 2025-07-09 15:57:27 (running for 00:00:55.36)\n", "Using FIFO scheduling algorithm.\n", "Logical resource usage: 1.0/48 CPUs, 1.0/4 GPUs (0.0/1.0 accelerator_type:T4, 0.0/1.0 anyscale/region:us-west-2, 0.0/1.0 anyscale/provider:aws, 0.0/1.0 anyscale/accelerator_shape:4xT4, 0.0/1.0 anyscale/node-group:head)\n", "Result logdir: /tmp/ray/session_2025-07-09_15-09-59_163606_3385/artifacts/2025-07-09_15-56-32/TorchTrainer_2025-07-09_15-56-32/driver_artifacts\n", "Number of trials: 1/1 (1 RUNNING)\n", "\n", "\n", "== Status ==\n", "Current time: 2025-07-09 15:57:32 (running for 00:01:00.38)\n", "Using FIFO scheduling algorithm.\n", "Logical resource usage: 1.0/48 CPUs, 1.0/4 GPUs (0.0/1.0 anyscale/region:us-west-2, 0.0/1.0 anyscale/provider:aws, 0.0/1.0 accelerator_type:T4, 0.0/1.0 anyscale/accelerator_shape:4xT4, 0.0/1.0 anyscale/node-group:head)\n", "Result logdir: /tmp/ray/session_2025-07-09_15-09-59_163606_3385/artifacts/2025-07-09_15-56-32/TorchTrainer_2025-07-09_15-56-32/driver_artifacts\n", "Number of trials: 1/1 (1 RUNNING)\n", "\n", "\n", "== Status ==\n", "Current time: 2025-07-09 15:57:38 (running for 00:01:05.41)\n", "Using FIFO scheduling algorithm.\n", "Logical resource usage: 1.0/48 CPUs, 1.0/4 GPUs (0.0/1.0 anyscale/region:us-west-2, 0.0/1.0 anyscale/provider:aws, 0.0/1.0 accelerator_type:T4, 0.0/1.0 anyscale/accelerator_shape:4xT4, 0.0/1.0 anyscale/node-group:head)\n", "Result logdir: /tmp/ray/session_2025-07-09_15-09-59_163606_3385/artifacts/2025-07-09_15-56-32/TorchTrainer_2025-07-09_15-56-32/driver_artifacts\n", "Number of trials: 1/1 (1 RUNNING)\n", "\n", "\n", "== Status ==\n", "Current time: 2025-07-09 15:57:43 (running for 00:01:10.43)\n", "Using FIFO scheduling algorithm.\n", "Logical resource usage: 1.0/48 CPUs, 1.0/4 GPUs (0.0/1.0 anyscale/node-group:head, 0.0/1.0 anyscale/region:us-west-2, 0.0/1.0 
accelerator_type:T4, 0.0/1.0 anyscale/accelerator_shape:4xT4, 0.0/1.0 anyscale/provider:aws)\n", "Result logdir: /tmp/ray/session_2025-07-09_15-09-59_163606_3385/artifacts/2025-07-09_15-56-32/TorchTrainer_2025-07-09_15-56-32/driver_artifacts\n", "Number of trials: 1/1 (1 RUNNING)\n", "\n", "\n", "== Status ==\n", "Current time: 2025-07-09 15:57:48 (running for 00:01:15.45)\n", "Using FIFO scheduling algorithm.\n", "Logical resource usage: 1.0/48 CPUs, 1.0/4 GPUs (0.0/1.0 anyscale/node-group:head, 0.0/1.0 anyscale/region:us-west-2, 0.0/1.0 accelerator_type:T4, 0.0/1.0 anyscale/accelerator_shape:4xT4, 0.0/1.0 anyscale/provider:aws)\n", "Result logdir: /tmp/ray/session_2025-07-09_15-09-59_163606_3385/artifacts/2025-07-09_15-56-32/TorchTrainer_2025-07-09_15-56-32/driver_artifacts\n", "Number of trials: 1/1 (1 RUNNING)\n", "\n", "\n", "== Status ==\n", "Current time: 2025-07-09 15:57:53 (running for 00:01:20.47)\n", "Using FIFO scheduling algorithm.\n", "Logical resource usage: 1.0/48 CPUs, 1.0/4 GPUs (0.0/1.0 anyscale/region:us-west-2, 0.0/1.0 anyscale/provider:aws, 0.0/1.0 anyscale/accelerator_shape:4xT4, 0.0/1.0 anyscale/node-group:head, 0.0/1.0 accelerator_type:T4)\n", "Result logdir: /tmp/ray/session_2025-07-09_15-09-59_163606_3385/artifacts/2025-07-09_15-56-32/TorchTrainer_2025-07-09_15-56-32/driver_artifacts\n", "Number of trials: 1/1 (1 RUNNING)\n", "\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36m(SplitCoordinator pid=41621)\u001b[0m ✔️ Dataset train_23_1 execution finished in 26.58 seconds\n", "\u001b[36m(SplitCoordinator pid=41621)\u001b[0m Registered dataset logger for dataset train_23_1\n", "\u001b[36m(SplitCoordinator pid=41621)\u001b[0m Starting execution of Dataset train_23_1. 
Full logs are in /tmp/ray/session_2025-07-09_15-09-59_163606_3385/logs/ray-data\n", "\u001b[36m(SplitCoordinator pid=41621)\u001b[0m Execution plan of Dataset train_23_1: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadParquet] -> OutputSplitter[split(1, equal=True)]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[36m(RayTrainWorker pid=41521)\u001b[0m {'loss': 0.3864, 'learning_rate': 0.0, 'epoch': 1.5}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36m(SplitCoordinator pid=41622)\u001b[0m Registered dataset logger for dataset eval_24_1\n", "\u001b[36m(SplitCoordinator pid=41622)\u001b[0m Starting execution of Dataset eval_24_1. Full logs are in /tmp/ray/session_2025-07-09_15-09-59_163606_3385/logs/ray-data\n", "\u001b[36m(SplitCoordinator pid=41622)\u001b[0m Execution plan of Dataset eval_24_1: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadParquet] -> OutputSplitter[split(1, equal=True)]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3853079fdd524064890b0dfccb41aa9b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "(pid=41622) Running 0: 0.00 row [00:00, ? row/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7398954babdc47798b8da28a2adfc080", "version_major": 2, "version_minor": 0 }, "text/plain": [ "(pid=41622) - ReadParquet->SplitBlocks(96) 1: 0.00 row [00:00, ? row/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "37889407c41e4a5782e8f74b02a401a3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "(pid=41622) - split(1, equal=True) 2: 0.00 row [00:00, ? 
row/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[36m(RayTrainWorker pid=41521)\u001b[0m {'eval_loss': 0.5683005452156067, 'eval_matthews_correlation': 0.45115517656589194, 'eval_runtime': 1.6027, 'eval_samples_per_second': 650.77, 'eval_steps_per_second': 41.18, 'epoch': 1.5}\n", "== Status ==\n", "Current time: 2025-07-09 15:57:58 (running for 00:01:25.49)\n", "Using FIFO scheduling algorithm.\n", "Logical resource usage: 1.0/48 CPUs, 1.0/4 GPUs (0.0/1.0 anyscale/region:us-west-2, 0.0/1.0 anyscale/provider:aws, 0.0/1.0 anyscale/accelerator_shape:4xT4, 0.0/1.0 anyscale/node-group:head, 0.0/1.0 accelerator_type:T4)\n", "Result logdir: /tmp/ray/session_2025-07-09_15-09-59_163606_3385/artifacts/2025-07-09_15-56-32/TorchTrainer_2025-07-09_15-56-32/driver_artifacts\n", "Number of trials: 1/1 (1 RUNNING)\n", "\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-07-09 15:57:59,354\tWARNING experiment_state.py:206 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds and may become a bottleneck. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", "You can suppress this warning by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0). 
Set it to 0 to completely suppress this warning.\n", "\u001b[36m(RayTrainWorker pid=41521)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/ray/ray_results/TorchTrainer_2025-07-09_15-56-32/TorchTrainer_f5114_00000_0_2025-07-09_15-56-32/checkpoint_000001)\n", "\u001b[36m(SplitCoordinator pid=41622)\u001b[0m ✔️ Dataset eval_24_1 execution finished in 1.49 seconds\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[36m(RayTrainWorker pid=41521)\u001b[0m {'train_runtime': 66.7725, 'train_samples_per_second': 255.914, 'train_steps_per_second': 15.995, 'train_loss': 0.4653928092356478, 'epoch': 1.5}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-07-09 15:58:00,649\tINFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/home/ray/ray_results/TorchTrainer_2025-07-09_15-56-32' in 0.0022s.\n", "2025-07-09 15:58:00,651\tINFO tune.py:1041 -- Total run time: 88.09 seconds (88.03 seconds for the tuning loop).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "== Status ==\n", "Current time: 2025-07-09 15:58:00 (running for 00:01:28.04)\n", "Using FIFO scheduling algorithm.\n", "Logical resource usage: 1.0/48 CPUs, 1.0/4 GPUs (0.0/1.0 anyscale/region:us-west-2, 0.0/1.0 anyscale/provider:aws, 0.0/1.0 anyscale/accelerator_shape:4xT4, 0.0/1.0 anyscale/node-group:head, 0.0/1.0 accelerator_type:T4)\n", "Result logdir: /tmp/ray/session_2025-07-09_15-09-59_163606_3385/artifacts/2025-07-09_15-56-32/TorchTrainer_2025-07-09_15-56-32/driver_artifacts\n", "Number of trials: 1/1 (1 TERMINATED)\n", "\n", "\n" ] } ], "source": [ "result = trainer.fit()" ] }, { "cell_type": "markdown", "metadata": { "id": "4cnWqUWmhYba" }, "source": [ "You can use the returned `Result` object to access metrics and the Ray Train `Checkpoint` associated with the last iteration." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "AMN5qjUwhYba", "outputId": "7b754c36-c58b-4ff4-d7a8-63ec9764bd0c" }, "outputs": [ { "data": { "text/plain": [ "Result(\n", " metrics={'loss': 0.3864, 'learning_rate': 0.0, 'epoch': 1.5, 'step': 1068, 'eval_loss': 0.5683005452156067, 'eval_matthews_correlation': 0.45115517656589194, 'eval_runtime': 1.6027, 'eval_samples_per_second': 650.77, 'eval_steps_per_second': 41.18},\n", " path='/home/ray/ray_results/TorchTrainer_2025-07-09_15-56-32/TorchTrainer_f5114_00000_0_2025-07-09_15-56-32',\n", " filesystem='local',\n", " checkpoint=Checkpoint(filesystem=local, path=/home/ray/ray_results/TorchTrainer_2025-07-09_15-56-32/TorchTrainer_f5114_00000_0_2025-07-09_15-56-32/checkpoint_000001)\n", ")" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "result" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## See also\n", "\n", "* {doc}`Ray Train Examples <../../examples>` for more use cases\n", "* {ref}`Ray Train User Guides ` for how-to guides\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "huggingface_text_classification.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.23" }, "orphan": true }, "nbformat": 4, "nbformat_minor": 4 }