{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fine-tune a 🤗 Transformers model" ] }, { "cell_type": "markdown", "metadata": { "id": "VaFMt6AIhYbK" }, "source": [ "This notebook is based on [an official 🤗 notebook - \"How to fine-tune a model on text classification\"](https://github.com/huggingface/notebooks/blob/6ca682955173cc9d36ffa431ddda505a048cbe80/examples/text_classification.ipynb). Its main aim is to show how to convert a vanilla 🤗 Transformers workflow to [Ray AIR](https://docs.ray.io/en/latest/ray-air/getting-started.html) without changing the training logic unless necessary.\n", "\n", "In this notebook, we will:\n", "1. [Set up Ray](#setup)\n", "2. [Load the dataset](#load)\n", "3. [Preprocess the dataset with Ray AIR](#preprocess)\n", "4. [Run the training with Ray AIR](#train)\n", "5. [Predict on test data with Ray AIR](#predict)\n", "6. [Optionally, share the model with the community](#share)" ] }, { "cell_type": "markdown", "metadata": { "id": "sQbdfyWQhYbO" }, "source": [ "Uncomment and run the following line to install all the necessary dependencies (this notebook has been tested with `transformers==4.19.1`):" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "YajFzmkthYbO" }, "outputs": [], "source": [ "#! pip install \"datasets\" \"transformers>=4.19.0\" \"torch>=1.10.0\" \"mlflow\" \"ray[air]>=1.13\"" ] }, { "cell_type": "markdown", "metadata": { "id": "pvSRaEHChYbP" }, "source": [ "## Set up Ray " ] }, { "cell_type": "markdown", "metadata": { "id": "LRdL3kWBhYbQ" }, "source": [ "We will use `ray.init()` to initialize a local cluster. By default, this cluster consists only of the machine you are running this notebook on. If the `RAY_ADDRESS` environment variable is set, `ray.init()` connects to that existing cluster instead, as in the output below. You can also run this notebook on an Anyscale cluster.\n", "\n", "Note: this notebook *will not* run in Ray Client mode."
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "MOsHUjgdIrIW", "outputId": "e527bdbb-2f28-4142-cca0-762e0566cbcd" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-08-25 10:09:51,282\tINFO worker.py:1223 -- Using address localhost:9031 set in the environment variable RAY_ADDRESS\n", "2022-08-25 10:09:51,697\tINFO worker.py:1333 -- Connecting to existing Ray cluster at address: 172.31.80.117:9031...\n", "2022-08-25 10:09:51,706\tINFO worker.py:1509 -- Connected to Ray cluster. View the dashboard at \u001b[1m\u001b[32mhttps://session-i8ddtfaxhwypbvnyb9uzg7xs.i.anyscaleuserdata-staging.com/auth/?token=agh0_CkcwRQIhAJXwvxwq31GryaWthvXGCXZebsijbuqi7qL2pCa5uROOAiBGjzsyXAJFHLlaEI9zSlNI8ewtghKg5UV3t8NmlxuMcRJmEiCtvjcKE0VPiU7iQx51P9oPQjfpo5g1RJXccVSS5005cBgCIgNuL2E6DAj9xazjBhDwj4veAUIMCP3ClJgGEPCPi94B-gEeChxzZXNfaThERFRmQVhId1lwYlZueWI5dVpnN3hT&redirect_to=dashboard \u001b[39m\u001b[22m\n", "2022-08-25 10:09:51,709\tINFO packaging.py:342 -- Pushing file package 'gcs://_ray_pkg_3332f64b0a461fddc20be71129115d0a.zip' (0.34MiB) to Ray cluster...\n", "2022-08-25 10:09:51,714\tINFO packaging.py:351 -- Successfully pushed file package 'gcs://_ray_pkg_3332f64b0a461fddc20be71129115d0a.zip'.\n" ] }, { "data": { "text/html": [ "
\n", "
\n", "

Ray

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", "\n", "
Python version:3.8.5
Ray version: 2.0.0
Dashboard:http://session-i8ddtfaxhwypbvnyb9uzg7xs.i.anyscaleuserdata-staging.com/auth/?token=agh0_CkcwRQIhAJXwvxwq31GryaWthvXGCXZebsijbuqi7qL2pCa5uROOAiBGjzsyXAJFHLlaEI9zSlNI8ewtghKg5UV3t8NmlxuMcRJmEiCtvjcKE0VPiU7iQx51P9oPQjfpo5g1RJXccVSS5005cBgCIgNuL2E6DAj9xazjBhDwj4veAUIMCP3ClJgGEPCPi94B-gEeChxzZXNfaThERFRmQVhId1lwYlZueWI5dVpnN3hT&redirect_to=dashboard
\n", "
\n", "
\n" ], "text/plain": [ "RayContext(dashboard_url='session-i8ddtfaxhwypbvnyb9uzg7xs.i.anyscaleuserdata-staging.com/auth/?token=agh0_CkcwRQIhAJXwvxwq31GryaWthvXGCXZebsijbuqi7qL2pCa5uROOAiBGjzsyXAJFHLlaEI9zSlNI8ewtghKg5UV3t8NmlxuMcRJmEiCtvjcKE0VPiU7iQx51P9oPQjfpo5g1RJXccVSS5005cBgCIgNuL2E6DAj9xazjBhDwj4veAUIMCP3ClJgGEPCPi94B-gEeChxzZXNfaThERFRmQVhId1lwYlZueWI5dVpnN3hT&redirect_to=dashboard', python_version='3.8.5', ray_version='2.0.0', ray_commit='cba26cc83f6b5b8a2ff166594a65cb74c0ec8740', address_info={'node_ip_address': '172.31.80.117', 'raylet_ip_address': '172.31.80.117', 'redis_address': None, 'object_store_address': '/tmp/ray/session_2022-08-25_09-57-39_455459_216/sockets/plasma_store', 'raylet_socket_name': '/tmp/ray/session_2022-08-25_09-57-39_455459_216/sockets/raylet', 'webui_url': 'session-i8ddtfaxhwypbvnyb9uzg7xs.i.anyscaleuserdata-staging.com/auth/?token=agh0_CkcwRQIhAJXwvxwq31GryaWthvXGCXZebsijbuqi7qL2pCa5uROOAiBGjzsyXAJFHLlaEI9zSlNI8ewtghKg5UV3t8NmlxuMcRJmEiCtvjcKE0VPiU7iQx51P9oPQjfpo5g1RJXccVSS5005cBgCIgNuL2E6DAj9xazjBhDwj4veAUIMCP3ClJgGEPCPi94B-gEeChxzZXNfaThERFRmQVhId1lwYlZueWI5dVpnN3hT&redirect_to=dashboard', 'session_dir': '/tmp/ray/session_2022-08-25_09-57-39_455459_216', 'metrics_export_port': 55366, 'gcs_address': '172.31.80.117:9031', 'address': '172.31.80.117:9031', 'dashboard_agent_listen_port': 52365, 'node_id': '422ff33444fd0f870aa6e718628407400a0ec9483a637c3026c3f9a3'})" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pprint import pprint\n", "import ray\n", "\n", "ray.init()" ] }, { "cell_type": "markdown", "metadata": { "id": "oJiSdWy2hYbR" }, "source": [ "We can check the resources our cluster is composed of. If you are running this notebook on your local machine or Google Colab, you should see the number of CPU cores and GPUs available on the said machine." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "KlMz0dt9hYbS", "outputId": "2d485449-ee69-4334-fcba-47e0ceb63078" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'CPU': 208.0,\n", " 'GPU': 16.0,\n", " 'accelerator_type:T4': 4.0,\n", " 'memory': 616693614180.0,\n", " 'node:172.31.76.237': 1.0,\n", " 'node:172.31.80.117': 1.0,\n", " 'node:172.31.85.193': 1.0,\n", " 'node:172.31.85.32': 1.0,\n", " 'node:172.31.90.137': 1.0,\n", " 'object_store_memory': 259318055729.0}\n" ] } ], "source": [ "pprint(ray.cluster_resources())" ] }, { "cell_type": "markdown", "metadata": { "id": "uS6oeJELhYbS" }, "source": [ "In this notebook, we will see how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) models on a text classification task from the [GLUE Benchmark](https://gluebenchmark.com/). We will run the training using [Ray AIR](https://docs.ray.io/en/latest/ray-air/getting-started.html).\n", "\n", "You can change the two variables below to control whether the training (which we will get to later) uses CPUs or GPUs, and how many workers are spawned. Each worker will claim one CPU or GPU. Make sure not to request more resources than your cluster has available!\n", "\n", "By default, we will run the training with one GPU worker." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "gAbhv9OqhYbT" }, "outputs": [], "source": [ "use_gpu = True  # set this to False to run on CPUs\n", "num_workers = 1  # set this to the number of GPUs/CPUs you want to use" ] }, { "cell_type": "markdown", "metadata": { "id": "rEJBSTyZIrIb" }, "source": [ "## Fine-tuning a model on a text classification task" ] }, { "cell_type": "markdown", "metadata": { "id": "kTCFado4IrIc" }, "source": [ "The GLUE Benchmark is a group of nine classification tasks on sentences or pairs of sentences. 
If you would like to learn more, refer to the [original notebook](https://github.com/huggingface/notebooks/blob/6ca682955173cc9d36ffa431ddda505a048cbe80/examples/text_classification.ipynb).\n", "\n", "Each task is named by its acronym, with `mnli-mm` standing for the mismatched version of MNLI (i.e. the same training set as `mnli`, but different validation and test sets):" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "YZbiBDuGIrId" }, "outputs": [], "source": [ "GLUE_TASKS = [\"cola\", \"mnli\", \"mnli-mm\", \"mrpc\", \"qnli\", \"qqp\", \"rte\", \"sst2\", \"stsb\", \"wnli\"]" ] }, { "cell_type": "markdown", "metadata": { "id": "4RRkXuteIrIh" }, "source": [ "This notebook is built to run on any of the tasks in the list above, with any model checkpoint from the [Model Hub](https://huggingface.co/models), as long as that model has a version with a classification head. Depending on your model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set these three parameters, and the rest of the notebook should run smoothly:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "zVvslsfMIrIh" }, "outputs": [], "source": [ "task = \"cola\"\n", "model_checkpoint = \"distilbert-base-uncased\"\n", "batch_size = 16" ] }, { "cell_type": "markdown", "metadata": { "id": "whPRbBNbIrIl" }, "source": [ "### Loading the dataset " ] }, { "cell_type": "markdown", "metadata": { "id": "W7QYTpxXIrIl" }, "source": [ "We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the data and get the metric we need to use for evaluation (to compare our model to the benchmark). 
This can easily be done with the `load_dataset` and `load_metric` functions.\n", "\n", "Apart from `mnli-mm`, which is a special code, we can pass our task name directly to those functions.\n", "\n", "As Ray AIR doesn't provide integrations for 🤗 Datasets yet, we will simply run the normal 🤗 Datasets code to load the dataset from the Hub." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 200 }, "id": "MwhAeEOuhYbV", "outputId": "3aff8c73-d6eb-4784-890a-a419403b5bda" }, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "actual_task = \"mnli\" if task == \"mnli-mm\" else task\n", "datasets = load_dataset(\"glue\", actual_task)" ] }, { "cell_type": "markdown", "metadata": { "id": "RzfPtOMoIrIu" }, "source": [ "The `datasets` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key each for the training, validation, and test sets (with extra keys for the mismatched validation and test sets in the special case of `mnli`)." ] }, { "cell_type": "markdown", "metadata": { "id": "_TOee7nohYbW" }, "source": [ "We will also need the metric. To avoid serialization errors, we will load the metric inside the training workers later. For now, we only define the function that loads it." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "FNE583uBhYbW" }, "outputs": [], "source": [ "from datasets import load_metric\n", "\n", "def load_metric_fn():\n", "    return load_metric('glue', actual_task)" ] }, { "cell_type": "markdown", "metadata": { "id": "lnjDIuQ3IrI-" }, "source": [ "The metric is an instance of [`datasets.Metric`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Metric)."
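] }, { "cell_type": "markdown", "metadata": {}, "source": [
"For CoLA, the GLUE metric is the Matthews correlation coefficient (MCC), reported as `matthews_correlation`. The cell below is a minimal, dependency-free sketch of the binary MCC formula, included only to illustrate what the score measures; `matthews_correlation_sketch` is a hypothetical helper, and the rest of the notebook keeps using `load_metric_fn()`. Note that a model that always predicts the same class scores 0, however high its accuracy: the denominator becomes zero, and this sketch returns 0 in that case, as scikit-learn's `matthews_corrcoef` does."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import math\n",
"\n",
"\n",
"def matthews_correlation_sketch(predictions, references):\n",
"    # Binary MCC computed from confusion-matrix counts (illustrative only).\n",
"    pairs = list(zip(predictions, references))\n",
"    tp = sum(1 for p, r in pairs if p == 1 and r == 1)\n",
"    tn = sum(1 for p, r in pairs if p == 0 and r == 0)\n",
"    fp = sum(1 for p, r in pairs if p == 1 and r == 0)\n",
"    fn = sum(1 for p, r in pairs if p == 0 and r == 1)\n",
"    denominator = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))\n",
"    # A constant predictor makes the denominator 0; report 0 in that case.\n",
"    if denominator == 0:\n",
"        return 0.0\n",
"    return (tp * tn - fp * fn) / denominator\n",
"\n",
"\n",
"labels = [1, 1, 0, 1, 0, 1]\n",
"print(matthews_correlation_sketch(labels, labels))  # 1.0 (perfect prediction)\n",
"print(matthews_correlation_sketch([1, 0, 0, 1, 0, 1], labels))  # ~0.707 (one miss)\n",
"print(matthews_correlation_sketch([1] * len(labels), labels))  # 0.0 (constant prediction)"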
] }, { "cell_type": "markdown", "metadata": { "id": "n9qywopnIrJH" }, "source": [ "### Preprocessing the data with Ray AIR " ] }, { "cell_type": "markdown", "metadata": { "id": "YVx71GdAIrJH" }, "source": [ "Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers `Tokenizer`, which will (as the name indicates) tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary), put them in the format the model expects, and generate the other inputs the model requires.\n", "\n", "To do all of this, we instantiate our tokenizer with the `AutoTokenizer.from_pretrained` method, which will ensure that:\n", "\n", "- we get a tokenizer that corresponds to the model architecture we want to use,\n", "- we download the vocabulary used when pretraining this specific checkpoint." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 145 }, "id": "eXNLu_-nIrJI", "outputId": "f545a7a5-f341-4315-cd89-9942a657aa31" }, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "Vl6IidfdIrJK" }, "source": [ "We pass `use_fast=True` to the call above to use one of the fast tokenizers (backed by Rust) from the 🤗 Tokenizers library. Those fast tokenizers are available for almost all models, but if you get an error with the previous call, remove that argument." ] }, { "cell_type": "markdown", "metadata": { "id": "qo_0B1M2IrJM" }, "source": [ "To preprocess our dataset, we therefore need the names of the columns containing the sentence(s). 
The following dictionary maps each task to the names of its sentence columns:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "fyGdtK9oIrJM" }, "outputs": [], "source": [ "task_to_keys = {\n", "    \"cola\": (\"sentence\", None),\n", "    \"mnli\": (\"premise\", \"hypothesis\"),\n", "    \"mnli-mm\": (\"premise\", \"hypothesis\"),\n", "    \"mrpc\": (\"sentence1\", \"sentence2\"),\n", "    \"qnli\": (\"question\", \"sentence\"),\n", "    \"qqp\": (\"question1\", \"question2\"),\n", "    \"rte\": (\"sentence1\", \"sentence2\"),\n", "    \"sst2\": (\"sentence\", None),\n", "    \"stsb\": (\"sentence1\", \"sentence2\"),\n", "    \"wnli\": (\"sentence1\", \"sentence2\"),\n", "}" ] }, { "cell_type": "markdown", "metadata": { "id": "256fOuzjhYbY" }, "source": [ "For Ray AIR, instead of using 🤗 Dataset objects directly, we will convert them to [Ray Datasets](https://docs.ray.io/en/latest/data/dataset.html). Both are backed by Arrow tables, so the conversion is straightforward. We will use the built-in `ray.data.from_huggingface` function." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'train': Dataset(num_blocks=1, num_rows=8551, schema={sentence: string, label: int64, idx: int32}),\n", " 'validation': Dataset(num_blocks=1, num_rows=1043, schema={sentence: string, label: int64, idx: int32}),\n", " 'test': Dataset(num_blocks=1, num_rows=1063, schema={sentence: string, label: int64, idx: int32})}" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import ray.data\n", "\n", "ray_datasets = ray.data.from_huggingface(datasets)\n", "ray_datasets" ] }, { "cell_type": "markdown", "metadata": { "id": "2C0hcmp9IrJQ" }, "source": [ "We can then write the function that will preprocess our samples. We just feed them to the `tokenizer` with the argument `truncation=True`. 
This ensures that any input longer than the selected model can handle is truncated to the maximum length the model accepts.\n", "\n", "We use a `BatchMapper` to create a Ray AIR preprocessor that maps the function over the dataset in a distributed fashion. It will be applied during both training and prediction." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "vc0BSBLIIrJQ" }, "outputs": [], "source": [ "import pandas as pd\n", "from ray.data.preprocessors import BatchMapper\n", "\n", "def preprocess_function(examples: pd.DataFrame):\n", "    # If we only have one column, we are running inference,\n", "    # so there is no need to tokenize in that case.\n", "    if len(examples.columns) == 1:\n", "        return examples\n", "    examples = examples.to_dict(\"list\")\n", "    sentence1_key, sentence2_key = task_to_keys[task]\n", "    if sentence2_key is None:\n", "        ret = tokenizer(examples[sentence1_key], truncation=True)\n", "    else:\n", "        ret = tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)\n", "    # Add back the original columns\n", "    ret = {**examples, **ret}\n", "    return pd.DataFrame.from_dict(ret)\n", "\n", "batch_encoder = BatchMapper(preprocess_function, batch_format=\"pandas\")" ] }, { "cell_type": "markdown", "metadata": { "id": "545PP3o8IrJV" }, "source": [ "### Fine-tuning the model with Ray AIR " ] }, { "cell_type": "markdown", "metadata": { "id": "FBiW8UpKIrJW" }, "source": [ "Now that our data is ready, we can download the pretrained model and fine-tune it.\n", "\n", "Since all our tasks are about sentence classification, we use the `AutoModelForSequenceClassification` class.\n", "\n", "We will not go into details about each specific component of the training (see the [original notebook](https://github.com/huggingface/notebooks/blob/6ca682955173cc9d36ffa431ddda505a048cbe80/examples/text_classification.ipynb) for that). 
The tokenizer is the same one we used to encode the dataset before.\n", "\n", "The main difference when using Ray AIR is that we need to create our 🤗 Transformers `Trainer` inside a function (`trainer_init_per_worker`) and return it. That function will be passed to the `HuggingFaceTrainer` and will run on every Ray worker. The training will then proceed by means of PyTorch DistributedDataParallel (DDP).\n", "\n", "Make sure that you initialize the model, metric, and tokenizer inside that function. Otherwise, you may run into serialization errors.\n", "\n", "Furthermore, `push_to_hub=True` is not yet supported. Ray will, however, checkpoint the model at every epoch, allowing you to push it to the Hub manually. We will do that after the training.\n", "\n", "If you wish to use third-party logging libraries, such as MLflow or Weights & Biases, do not set them in `TrainingArguments` (they will be automatically disabled); instead, pass Ray AIR callbacks to the `run_config` of `HuggingFaceTrainer`. In this example, we will use MLflow."
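] }, { "cell_type": "markdown", "metadata": {}, "source": [
"With PyTorch DDP, every Ray worker runs its own copy of the 🤗 `Trainer` on its shard of the data, so the number of examples consumed per optimization step scales with `num_workers`. The cell below is a small arithmetic sketch of the *Total train batch size* that the 🤗 `Trainer` reports in its logs; `effective_batch_size` is a hypothetical helper, not part of Ray or 🤗 Transformers. Keep this in mind when changing `num_workers`: it changes the effective batch size even though `batch_size` (the per-device batch size) stays the same."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"def effective_batch_size(per_device_batch_size, num_workers, gradient_accumulation_steps=1):\n",
"    # Each DDP worker processes its own per-device batch per forward/backward pass;\n",
"    # gradients are averaged across workers (and accumulated over\n",
"    # gradient_accumulation_steps passes) before each optimizer update.\n",
"    return per_device_batch_size * num_workers * gradient_accumulation_steps\n",
"\n",
"\n",
"print(effective_batch_size(16, 1))  # 16 with this notebook's default num_workers = 1\n",
"print(effective_batch_size(16, 4))  # 64, as in the training logs, which ran with world_size=4"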
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "TlqNaB8jIrJW" }, "outputs": [], "source": [ "from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n", "import numpy as np\n", "import torch\n", "\n", "# MNLI has 3 classes, STS-B is a regression task (1 output), the others are binary\n", "num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n", "metric_name = \"pearson\" if task == \"stsb\" else \"matthews_correlation\" if task == \"cola\" else \"accuracy\"\n", "model_name = model_checkpoint.split(\"/\")[-1]\n", "validation_key = \"validation_mismatched\" if task == \"mnli-mm\" else \"validation_matched\" if task == \"mnli\" else \"validation\"\n", "name = f\"{model_name}-finetuned-{task}\"\n", "\n", "def trainer_init_per_worker(train_dataset, eval_dataset=None, **config):\n", "    print(f\"Is CUDA available: {torch.cuda.is_available()}\")\n", "    metric = load_metric_fn()\n", "    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n", "    model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)\n", "    args = TrainingArguments(\n", "        name,\n", "        evaluation_strategy=\"epoch\",\n", "        save_strategy=\"epoch\",\n", "        logging_strategy=\"epoch\",\n", "        learning_rate=config.get(\"learning_rate\", 2e-5),\n", "        per_device_train_batch_size=batch_size,\n", "        per_device_eval_batch_size=batch_size,\n", "        num_train_epochs=config.get(\"epochs\", 2),\n", "        weight_decay=config.get(\"weight_decay\", 0.01),\n", "        push_to_hub=False,\n", "        disable_tqdm=True,  # declutter the output a little\n", "        no_cuda=not use_gpu,  # you need to explicitly set no_cuda if you want CPUs\n", "    )\n", "\n", "    def compute_metrics(eval_pred):\n", "        predictions, labels = eval_pred\n", "        # Classification: take the arg-max class; STS-B (regression): take the single output\n", "        if task != \"stsb\":\n", "            predictions = np.argmax(predictions, axis=1)\n", "        else:\n", "            predictions = predictions[:, 0]\n", "        return metric.compute(predictions=predictions, references=labels)\n", "\n", "    trainer = Trainer(\n", "        model,\n", "        args,\n", "        train_dataset=train_dataset,\n", "        eval_dataset=eval_dataset,\n", "        tokenizer=tokenizer,\n", "        compute_metrics=compute_metrics,\n", "    )\n", "\n", "    print(\"Starting training\")\n", "    return trainer" ] }, { "cell_type": "markdown", "metadata": { "id": "CdzABDVcIrJg" }, "source": [ "With our `trainer_init_per_worker` complete, we can now instantiate the `HuggingFaceTrainer`. Aside from the function, we set the `scaling_config`, which controls the number of workers and the resources used, and the `datasets` we will use for training and evaluation.\n", "\n", "We specify the `MLflowLoggerCallback` inside the `run_config` and pass the preprocessor we defined earlier as an argument. The preprocessor will be included with the returned `Checkpoint`, meaning it will also be applied during inference." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "RElw7OgLhYba" }, "outputs": [], "source": [ "from ray.train.huggingface import HuggingFaceTrainer\n", "from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig\n", "from ray.air.integrations.mlflow import MLflowLoggerCallback\n", "\n", "trainer = HuggingFaceTrainer(\n", "    trainer_init_per_worker=trainer_init_per_worker,\n", "    scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),\n", "    datasets={\"train\": ray_datasets[\"train\"], \"evaluation\": ray_datasets[validation_key]},\n", "    run_config=RunConfig(\n", "        callbacks=[MLflowLoggerCallback(experiment_name=name)],\n", "        checkpoint_config=CheckpointConfig(num_to_keep=1, checkpoint_score_attribute=\"eval_loss\", checkpoint_score_order=\"min\"),\n", "    ),\n", "    preprocessor=batch_encoder,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "XvS136zKhYba" }, "source": [ "Finally, we call the `fit` method to start training with Ray AIR. We will save the `Result` object to a variable so that we can access its metrics and checkpoints."
] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "uNx5pyRlIrJh", "outputId": "8496fe4f-f1c3-48ad-a6d3-b16a65716135" }, "outputs": [ { "data": { "text/html": [ "== Status ==
Current time: 2022-08-25 10:14:09 (running for 00:04:06.45)
Memory usage on this node: 4.3/62.0 GiB
Using FIFO scheduling algorithm.
Resources requested: 0/208 CPUs, 0/16 GPUs, 0.0/574.34 GiB heap, 0.0/241.51 GiB objects (0.0/4.0 accelerator_type:T4)
Result logdir: /home/ray/ray_results/HuggingFaceTrainer_2022-08-25_10-10-02
Number of trials: 1/1 (1 TERMINATED)
\n", "\n", "\n", "\n", "\n", "\n", "\n", "
Trial name status loc iter total time (s) loss learning_rate epoch
HuggingFaceTrainer_c1ff5_00000TERMINATED172.31.90.137:947 2 200.2170.3886 0 2


" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "(RayTrainWorker pid=1114, ip=172.31.90.137) 2022-08-25 10:10:44,617\tINFO config.py:71 -- Setting up process group for: env:// [rank=0, world_size=4]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(RayTrainWorker pid=1114, ip=172.31.90.137) Is CUDA available: True\n", "(RayTrainWorker pid=1116, ip=172.31.90.137) Is CUDA available: True\n", "(RayTrainWorker pid=1117, ip=172.31.90.137) Is CUDA available: True\n", "(RayTrainWorker pid=1115, ip=172.31.90.137) Is CUDA available: True\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Downloading builder script: 5.76kB [00:00, 6.45MB/s] \n", "Downloading builder script: 5.76kB [00:00, 6.91MB/s] \n", "Downloading builder script: 5.76kB [00:00, 6.44MB/s] \n", "Downloading builder script: 5.76kB [00:00, 6.94MB/s] \n", "Downloading tokenizer_config.json: 100%|██████████| 28.0/28.0 [00:00<00:00, 30.5kB/s]\n", "Downloading config.json: 100%|██████████| 483/483 [00:00<00:00, 817kB/s]\n", "Downloading vocab.txt: 0%| | 0.00/226k [00:00" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we want to tune any hyperparameters of the model, we can do so by simply passing our `HuggingFaceTrainer` into a `Tuner` and defining the search space.\n", "\n", "We can also take advantage of the advanced search algorithms and schedulers provided by Ray Tune. In this example, we will use an `ASHAScheduler` to aggressively terminate underperforming trials."
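] }, { "cell_type": "markdown", "metadata": {}, "source": [
"The cell below is a simplified, pure-Python sketch of the asynchronous successive halving (ASHA) rule, included only to illustrate what the scheduler does; it is not Ray Tune's implementation, and `percentile`, `asha_rungs`, and `SimpleASHA` are hypothetical helpers. Since the `ASHAScheduler` below only sets `max_t`, we assume Ray's defaults of `grace_period=1` and `reduction_factor=4`, which give rungs at epochs 1 and 4 (matching the `Iter 4.000` and `Iter 1.000` entries in the status output). When a trial reaches a rung, it is stopped if its score falls below the 75th percentile (that is, outside the top `1/reduction_factor`) of the scores already recorded there; the first trial to reach a rung always continues. The scores in the example are made up."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import math\n",
"\n",
"\n",
"def percentile(values, q):\n",
"    # Linear interpolation between closest ranks (NumPy's default method).\n",
"    ordered = sorted(values)\n",
"    position = (len(ordered) - 1) * q / 100\n",
"    lower, upper = math.floor(position), math.ceil(position)\n",
"    return ordered[lower] + (ordered[upper] - ordered[lower]) * (position - lower)\n",
"\n",
"\n",
"def asha_rungs(max_t, grace_period=1, reduction_factor=4):\n",
"    # Milestones grace_period * reduction_factor**k that do not exceed max_t.\n",
"    rungs, milestone = [], grace_period\n",
"    while milestone <= max_t:\n",
"        rungs.append(milestone)\n",
"        milestone *= reduction_factor\n",
"    return rungs\n",
"\n",
"\n",
"class SimpleASHA:\n",
"    # Hypothetical single-bracket ASHA; scores are higher-is-better, so pass\n",
"    # -eval_loss when minimizing (Ray Tune negates the metric for mode='min').\n",
"    def __init__(self, max_t, grace_period=1, reduction_factor=4):\n",
"        self.reduction_factor = reduction_factor\n",
"        self.rungs = {m: {} for m in asha_rungs(max_t, grace_period, reduction_factor)}\n",
"\n",
"    def on_result(self, trial_id, iteration, score):\n",
"        # Returns True if the trial should be stopped.\n",
"        for milestone in sorted(self.rungs, reverse=True):\n",
"            recorded = self.rungs[milestone]\n",
"            if iteration < milestone or trial_id in recorded:\n",
"                continue\n",
"            # Compare against the scores already recorded at this rung.\n",
"            cutoff_q = (1 - 1 / self.reduction_factor) * 100\n",
"            stop = bool(recorded) and score < percentile(list(recorded.values()), cutoff_q)\n",
"            recorded[trial_id] = score\n",
"            return stop\n",
"        return False\n",
"\n",
"\n",
"scheduler = SimpleASHA(max_t=4)\n",
"print(sorted(scheduler.rungs))  # [1, 4]\n",
"# Hypothetical epoch-1 scores (negated eval_loss), reported in this order.\n",
"for trial_id, score in [(0, -0.55), (1, -0.65), (2, -0.66), (3, -0.70)]:\n",
"    print(trial_id, 'stop' if scheduler.on_result(trial_id, 1, score) else 'continue')"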
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "from ray import tune\n", "from ray.tune import Tuner\n", "from ray.tune.schedulers.async_hyperband import ASHAScheduler\n", "\n", "tune_epochs = 4\n", "tuner = Tuner(\n", " trainer,\n", " param_space={\n", " \"trainer_init_config\": {\n", " \"learning_rate\": tune.grid_search([2e-5, 2e-4, 2e-3, 2e-2]),\n", " \"epochs\": tune_epochs,\n", " }\n", " },\n", " tune_config=tune.TuneConfig(\n", " metric=\"eval_loss\",\n", " mode=\"min\",\n", " num_samples=1,\n", " scheduler=ASHAScheduler(\n", " max_t=tune_epochs,\n", " )\n", " ),\n", " run_config=RunConfig(\n", " checkpoint_config=CheckpointConfig(num_to_keep=1, checkpoint_score_attribute=\"eval_loss\", checkpoint_score_order=\"min\")\n", " ),\n", ")" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/html": [ "== Status ==
Current time: 2022-08-25 10:20:13 (running for 00:06:01.75)
Memory usage on this node: 4.4/62.0 GiB
Using AsyncHyperBand: num_stopped=4\n", "Bracket: Iter 4.000: -0.8064090609550476 | Iter 1.000: -0.6378736793994904
Resources requested: 0/208 CPUs, 0/16 GPUs, 0.0/574.34 GiB heap, 0.0/241.51 GiB objects (0.0/4.0 accelerator_type:T4)
Current best trial: 5654d_00001 with eval_loss=0.6492420434951782 and parameters={'trainer_init_config': {'learning_rate': 0.0002, 'epochs': 4}}
Result logdir: /home/ray/ray_results/HuggingFaceTrainer_2022-08-25_10-14-11
Number of trials: 4/4 (4 TERMINATED)
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
Trial name status loc trainer_init_conf... iter total time (s) loss learning_rate epoch
HuggingFaceTrainer_5654d_00000TERMINATED172.31.90.137:1729 2e-05 4 347.171 0.1958 0 4
HuggingFaceTrainer_5654d_00001TERMINATED172.31.76.237:1805 0.0002 1 95.24920.6225 0.00015 1
HuggingFaceTrainer_5654d_00002TERMINATED172.31.85.32:1322 0.002 1 93.76130.6463 0.0015 1
HuggingFaceTrainer_5654d_00003TERMINATED172.31.85.193:1060 0.02 1 99.36770.926 0.015 1


" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "(RayTrainWorker pid=1789, ip=172.31.90.137) 2022-08-25 10:14:23,379\tINFO config.py:71 -- Setting up process group for: env:// [rank=0, world_size=4]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(RayTrainWorker pid=1792, ip=172.31.90.137) Is CUDA available: True\n", "(RayTrainWorker pid=1790, ip=172.31.90.137) Is CUDA available: True\n", "(RayTrainWorker pid=1791, ip=172.31.90.137) Is CUDA available: True\n", "(RayTrainWorker pid=1789, ip=172.31.90.137) Is CUDA available: True\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "(RayTrainWorker pid=1974, ip=172.31.76.237) 2022-08-25 10:14:29,354\tINFO config.py:71 -- Setting up process group for: env:// [rank=0, world_size=4]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(RayTrainWorker pid=1977, ip=172.31.76.237) Is CUDA available: True\n", "(RayTrainWorker pid=1976, ip=172.31.76.237) Is CUDA available: True\n", "(RayTrainWorker pid=1975, ip=172.31.76.237) Is CUDA available: True\n", "(RayTrainWorker pid=1974, ip=172.31.76.237) Is CUDA available: True\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "(RayTrainWorker pid=1483, ip=172.31.85.32) 2022-08-25 10:14:35,313\tINFO config.py:71 -- Setting up process group for: env:// [rank=0, world_size=4]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(RayTrainWorker pid=1790, ip=172.31.90.137) Starting training\n", "(RayTrainWorker pid=1792, ip=172.31.90.137) Starting training\n", "(RayTrainWorker pid=1791, ip=172.31.90.137) Starting training\n", "(RayTrainWorker pid=1789, ip=172.31.90.137) Starting training\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "(RayTrainWorker pid=1789, ip=172.31.90.137) ***** Running training *****\n", "(RayTrainWorker pid=1789, ip=172.31.90.137) Num examples = 8551\n", "(RayTrainWorker pid=1789, ip=172.31.90.137) Num 
Epochs = 4\n", "(RayTrainWorker pid=1789, ip=172.31.90.137) Instantaneous batch size per device = 16\n", "(RayTrainWorker pid=1789, ip=172.31.90.137) Total train batch size (w. parallel, distributed & accumulation) = 64\n", "(RayTrainWorker pid=1789, ip=172.31.90.137) Gradient Accumulation steps = 1\n", "(RayTrainWorker pid=1789, ip=172.31.90.137) Total optimization steps = 2140\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(RayTrainWorker pid=1483, ip=172.31.85.32) Is CUDA available: True\n", "(RayTrainWorker pid=1485, ip=172.31.85.32) Is CUDA available: True\n", "(RayTrainWorker pid=1486, ip=172.31.85.32) Is CUDA available: True\n", "(RayTrainWorker pid=1484, ip=172.31.85.32) Is CUDA available: True\n", "(RayTrainWorker pid=1977, ip=172.31.76.237) Starting training\n", "(RayTrainWorker pid=1976, ip=172.31.76.237) Starting training\n", "(RayTrainWorker pid=1975, ip=172.31.76.237) Starting training\n", "(RayTrainWorker pid=1974, ip=172.31.76.237) Starting training\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "(RayTrainWorker pid=1974, ip=172.31.76.237) ***** Running training *****\n", "(RayTrainWorker pid=1974, ip=172.31.76.237) Num examples = 8551\n", "(RayTrainWorker pid=1974, ip=172.31.76.237) Num Epochs = 4\n", "(RayTrainWorker pid=1974, ip=172.31.76.237) Instantaneous batch size per device = 16\n", "(RayTrainWorker pid=1974, ip=172.31.76.237) Total train batch size (w. 
parallel, distributed & accumulation) = 64\n", "(RayTrainWorker pid=1974, ip=172.31.76.237) Gradient Accumulation steps = 1\n", "(RayTrainWorker pid=1974, ip=172.31.76.237) Total optimization steps = 2140\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(RayTrainWorker pid=1483, ip=172.31.85.32) Starting training\n", "(RayTrainWorker pid=1485, ip=172.31.85.32) Starting training\n", "(RayTrainWorker pid=1486, ip=172.31.85.32) Starting training\n", "(RayTrainWorker pid=1484, ip=172.31.85.32) Starting training\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "(RayTrainWorker pid=1483, ip=172.31.85.32) ***** Running training *****\n", "(RayTrainWorker pid=1483, ip=172.31.85.32) Num examples = 8551\n", "(RayTrainWorker pid=1483, ip=172.31.85.32) Num Epochs = 4\n", "(RayTrainWorker pid=1483, ip=172.31.85.32) Instantaneous batch size per device = 16\n", "(RayTrainWorker pid=1483, ip=172.31.85.32) Total train batch size (w. parallel, distributed & accumulation) = 64\n", "(RayTrainWorker pid=1483, ip=172.31.85.32) Gradient Accumulation steps = 1\n", "(RayTrainWorker pid=1483, ip=172.31.85.32) Total optimization steps = 2140\n", "(RayTrainWorker pid=1223, ip=172.31.85.193) 2022-08-25 10:14:48,193\tINFO config.py:71 -- Setting up process group for: env:// [rank=0, world_size=4]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(RayTrainWorker pid=1223, ip=172.31.85.193) Is CUDA available: True\n", "(RayTrainWorker pid=1224, ip=172.31.85.193) Is CUDA available: True\n", "(RayTrainWorker pid=1226, ip=172.31.85.193) Is CUDA available: True\n", "(RayTrainWorker pid=1225, ip=172.31.85.193) Is CUDA available: True\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Downloading builder script: 5.76kB [00:00, 6.59MB/s] \n", "Downloading builder script: 5.76kB [00:00, 6.52MB/s] \n", "Downloading builder script: 5.76kB [00:00, 6.07MB/s] \n", "Downloading builder script: 5.76kB [00:00, 6.81MB/s] \n", "Downloading 
tokenizer_config.json: 100%|██████████| 28.0/28.0 [00:00<00:00, 46.0kB/s]\n", "Downloading config.json: 100%|██████████| 483/483 [00:00<00:00, 766kB/s]\n", "Downloading vocab.txt: 0%| | 0.00/226k [00:00
\n", "" ], "text/plain": [ " loss learning_rate epoch step eval_loss eval_matthews_correlation \\\n", "1 0.6225 0.00015 1.0 535 0.649242 0.000000 \n", "3 0.9260 0.01500 1.0 535 0.652943 0.000000 \n", "2 0.6463 0.00150 1.0 535 0.658653 0.000000 \n", "0 0.1958 0.00000 4.0 2140 0.806409 0.532286 \n", "\n", " eval_runtime eval_samples_per_second eval_steps_per_second _timestamp \\\n", "1 1.0157 267.792 4.923 1661447759 \n", "3 0.9428 288.510 5.303 1661447782 \n", "2 0.9576 284.050 5.222 1661447764 \n", "0 1.0006 271.827 4.997 1661448005 \n", "\n", " ... pid hostname node_ip time_since_restore \\\n", "1 ... 1805 ip-172-31-76-237 172.31.76.237 95.249164 \n", "3 ... 1060 ip-172-31-85-193 172.31.85.193 99.367746 \n", "2 ... 1322 ip-172-31-85-32 172.31.85.32 93.761317 \n", "0 ... 1729 ip-172-31-90-137 172.31.90.137 347.170584 \n", "\n", " timesteps_since_restore iterations_since_restore warmup_time \\\n", "1 0 1 0.003661 \n", "3 0 1 0.004133 \n", "2 0 1 0.004533 \n", "0 0 4 0.003702 \n", "\n", " config/trainer_init_config/epochs config/trainer_init_config/learning_rate \\\n", "1 4 0.00020 \n", "3 4 0.02000 \n", "2 4 0.00200 \n", "0 4 0.00002 \n", "\n", " logdir \n", "1 /home/ray/ray_results/HuggingFaceTrainer_2022-... \n", "3 /home/ray/ray_results/HuggingFaceTrainer_2022-... \n", "2 /home/ray/ray_results/HuggingFaceTrainer_2022-... \n", "0 /home/ray/ray_results/HuggingFaceTrainer_2022-... 
\n", "\n", "[4 rows x 33 columns]" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tune_results.get_dataframe().sort_values(\"eval_loss\")" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "best_result = tune_results.get_best_result()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Predict on test data with Ray AIR " ] }, { "cell_type": "markdown", "metadata": { "id": "Tfoyu1q7hYbb" }, "source": [ "You can now use the best checkpoint to run prediction with `HuggingFacePredictor`, which wraps [🤗 Pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines). To distribute prediction, we use `BatchPredictor`. This is not necessary for the very small example we are using (you could use `HuggingFacePredictor` directly), but it scales well to much larger datasets." ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 262 }, "id": "UOUcBkX8IrJi", "outputId": "4dc16812-1400-482d-8c3f-85991ce4b081" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 12.41it/s]\n", "Map_Batches: 100%|██████████| 1/1 [00:00<00:00, 7.46it/s]\n", "Map Progress (1 actors 1 pending): 100%|██████████| 1/1 [00:18<00:00, 18.46s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'label': 'LABEL_1', 'score': 0.6822417974472046}\n", "{'label': 'LABEL_1', 'score': 0.6822402477264404}\n", "{'label': 'LABEL_1', 'score': 0.6822407841682434}\n", "{'label': 'LABEL_1', 'score': 0.6822386980056763}\n", "{'label': 'LABEL_1', 'score': 0.6822428107261658}\n", "{'label': 'LABEL_1', 'score': 0.6822453737258911}\n", "{'label': 'LABEL_1', 'score': 0.6822437047958374}\n", "{'label': 'LABEL_1', 'score': 0.6822428703308105}\n", "{'label': 'LABEL_1', 'score': 0.6822431683540344}\n", "{'label': 'LABEL_1', 'score': 
0.6822426915168762}\n", "{'label': 'LABEL_1', 'score': 0.6822447776794434}\n", "{'label': 'LABEL_1', 'score': 0.6822456121444702}\n", "{'label': 'LABEL_1', 'score': 0.6822471022605896}\n", "{'label': 'LABEL_1', 'score': 0.6822477579116821}\n", "{'label': 'LABEL_1', 'score': 0.682244598865509}\n", "{'label': 'LABEL_1', 'score': 0.6822422742843628}\n", "{'label': 'LABEL_1', 'score': 0.6822470426559448}\n", "{'label': 'LABEL_1', 'score': 0.6822417378425598}\n", "{'label': 'LABEL_1', 'score': 0.6822449564933777}\n", "{'label': 'LABEL_1', 'score': 0.682239294052124}\n" ] } ], "source": [ "from ray.train.huggingface import HuggingFacePredictor\n", "from ray.train.batch_predictor import BatchPredictor\n", "import pandas as pd\n", "\n", "predictor = BatchPredictor.from_checkpoint(\n", " checkpoint=best_result.checkpoint,\n", " predictor_cls=HuggingFacePredictor,\n", " task=\"text-classification\",\n", " device=0 if use_gpu else -1, # -1 is CPU, otherwise device index\n", ")\n", "prediction = predictor.predict(ray_datasets[\"test\"].map_batches(lambda x: x[[\"sentence\"]]), num_gpus_per_worker=int(use_gpu))\n", "prediction.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Share the model " ] }, { "cell_type": "markdown", "metadata": { "id": "mS8PId_NhYbb" }, "source": [ "To share your model with the community, there are a few more steps to follow.\n", "\n", "We conducted the training on the Ray cluster, but we share the model from the local environment, which makes it easy to authenticate.\n", "\n", "First, copy your authentication token from the Hugging Face website (sign up [here](https://huggingface.co/join) if you haven't already!) 
then execute the following cell and paste your token when prompted:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2LClXkN8hYbb", "tags": [ "remove-cell-ci" ] }, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "\n", "notebook_login()" ] }, { "cell_type": "markdown", "metadata": { "id": "SybKUDryhYbb" }, "source": [ "Then you need to install Git LFS. Uncomment and run the following line:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_wF6aT-0hYbb", "tags": [ "remove-cell-ci" ] }, "outputs": [], "source": [ "# !apt install git-lfs" ] }, { "cell_type": "markdown", "metadata": { "id": "5fr6E0e8hYbb" }, "source": [ "Now, load the fine-tuned model from the best checkpoint into the local environment:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cjH2A8m6hYbc", "tags": [ "remove-cell-ci" ] }, "outputs": [], "source": [ "from transformers import AutoModelForSequenceClassification\n", "from ray.train.huggingface import HuggingFaceCheckpoint\n", "\n", "# Load the model from the best checkpoint found by the tuning run\n", "checkpoint = HuggingFaceCheckpoint.from_checkpoint(best_result.checkpoint)\n", "hf_model = checkpoint.get_model(model=AutoModelForSequenceClassification)" ] }, { "cell_type": "markdown", "metadata": { "id": "tgV2xKfFhYbc" }, "source": [ "You can now upload the result of the training to the Hub by running this cell:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XSkfJe3nhYbc", "tags": [ "remove-cell-ci" ] }, "outputs": [], "source": [ "# \"the-name-you-picked\" is a placeholder; replace it with your own repository name\n", "hf_model.push_to_hub(\"the-name-you-picked\")" ] }, { "cell_type": "markdown", "metadata": { "id": "UL-Boc4dhYbc" }, "source": [ "You can now share this model with all your friends, family, and favorite pets: they can all load it with the identifier `\"your-username/the-name-you-picked\"`, for instance:\n", "\n", "```python\n", "from transformers import AutoModelForSequenceClassification\n", "\n", "model = AutoModelForSequenceClassification.from_pretrained(\"sgugger/my-awesome-model\")\n", "```" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": { "id": "ipJBReeWhYbc", "tags": [ "remove-cell-ci" ] }, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "huggingface_text_classification.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3.8.9 64-bit", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.6" }, "vscode": { "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } } }, "nbformat": 4, "nbformat_minor": 0 }