{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fine-tune a Hugging Face Transformers Model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VaFMt6AIhYbK"
},
"source": [
"This notebook is based on an official Hugging Face example, [How to fine-tune a model on text classification](https://github.com/huggingface/notebooks/blob/6ca682955173cc9d36ffa431ddda505a048cbe80/examples/text_classification.ipynb). This notebook shows the process of conversion from vanilla HF to Ray Train without changing the training logic unless necessary.\n",
"\n",
"This notebook consists of the following steps:\n",
"1. [Set up Ray](#hf-setup)\n",
"2. [Load the dataset](#hf-load)\n",
"3. [Preprocess the dataset with Ray Data](#hf-preprocess)\n",
"4. [Run the training with Ray Train](#hf-train)\n",
"5. [Optionally, share the model with the community](#hf-share)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sQbdfyWQhYbO"
},
"source": [
"Uncomment and run the following line to install all the necessary dependencies. (This notebook is being tested with `transformers==4.19.1`.):"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "YajFzmkthYbO"
},
"outputs": [],
"source": [
"#! pip install \"datasets\" \"transformers>=4.19.0\" \"torch>=1.10.0\" \"mlflow\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pvSRaEHChYbP"
},
"source": [
"(hf-setup)=\n",
"## Set up Ray"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LRdL3kWBhYbQ"
},
"source": [
"Use `ray.init()` to initialize a local cluster. By default, this cluster contains only the machine you are running this notebook on. You can also run this notebook on an [Anyscale](https://www.anyscale.com/) cluster."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "MOsHUjgdIrIW",
"outputId": "e527bdbb-2f28-4142-cca0-762e0566cbcd",
"tags": []
},
"outputs": [],
"source": [
"from pprint import pprint\n",
"import ray\n",
"\n",
"ray.init()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oJiSdWy2hYbR"
},
"source": [
"Check the resources our cluster is composed of. If you are running this notebook on your local machine or Google Colab, you should see the number of CPU cores and GPUs available on the your machine."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "KlMz0dt9hYbS",
"outputId": "2d485449-ee69-4334-fcba-47e0ceb63078",
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'CPU': 48.0,\n",
" 'GPU': 4.0,\n",
" 'accelerator_type:None': 1.0,\n",
" 'memory': 206158430208.0,\n",
" 'node:10.0.27.125': 1.0,\n",
" 'node:__internal_head__': 1.0,\n",
" 'object_store_memory': 59052625920.0}\n"
]
}
],
"source": [
"pprint(ray.cluster_resources())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uS6oeJELhYbS"
},
"source": [
"This notebook fine-tunes a [HF Transformers](https://github.com/huggingface/transformers) model for one of the text classification task of the [GLUE Benchmark](https://gluebenchmark.com/). It runs the training using Ray Train.\n",
"\n",
"You can change these two variables to control whether the training, which happens later, uses CPUs or GPUs, and how many workers to spawn. Each worker claims one CPU or GPU. Make sure to not request more resources than the resources present. By default, the training runs with one GPU worker."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "gAbhv9OqhYbT",
"tags": []
},
"outputs": [],
"source": [
"use_gpu = True # set this to False to run on CPUs\n",
"num_workers = 1 # set this to number of GPUs or CPUs you want to use"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rEJBSTyZIrIb"
},
"source": [
"## Fine-tune a model on a text classification task"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kTCFado4IrIc"
},
"source": [
"The GLUE Benchmark is a group of nine classification tasks on sentences or pairs of sentences. To learn more, see the [original notebook](https://github.com/huggingface/notebooks/blob/6ca682955173cc9d36ffa431ddda505a048cbe80/examples/text_classification.ipynb).\n",
"\n",
"Each task has a name that is its acronym, with `mnli-mm` to indicate that it is a mismatched version of MNLI. Each one has the same training set as `mnli` but different validation and test sets."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "YZbiBDuGIrId",
"tags": []
},
"outputs": [],
"source": [
"GLUE_TASKS = [\n",
" \"cola\",\n",
" \"mnli\",\n",
" \"mnli-mm\",\n",
" \"mrpc\",\n",
" \"qnli\",\n",
" \"qqp\",\n",
" \"rte\",\n",
" \"sst2\",\n",
" \"stsb\",\n",
" \"wnli\",\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4RRkXuteIrIh"
},
"source": [
"This notebook runs on any of the tasks in the list above, with any model checkpoint from the [Model Hub](https://huggingface.co/models) as long as that model has a version with a classification head. Depending on the model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set these three parameters, and the rest of the notebook should run smoothly:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "zVvslsfMIrIh",
"tags": []
},
"outputs": [],
"source": [
"task = \"cola\"\n",
"model_checkpoint = \"distilbert-base-uncased\"\n",
"batch_size = 16"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "whPRbBNbIrIl"
},
"source": [
"(hf-load)=\n",
"### Loading the dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "W7QYTpxXIrIl"
},
"source": [
"Use the [HF Datasets](https://github.com/huggingface/datasets) library to download the data and get the metric to use for evaluation and to compare your model to the benchmark. You can do this comparison easily with the `load_dataset` and `load_metric` functions.\n",
"\n",
"Apart from `mnli-mm` being special code, you can directly pass the task name to those functions.\n",
"\n",
"Run the normal HF Datasets code to load the dataset from the Hub."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 200
},
"id": "MwhAeEOuhYbV",
"outputId": "3aff8c73-d6eb-4784-890a-a419403b5bda",
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Reusing dataset glue (/home/ray/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8217c4d4e1e7402c92477b3e2cf8961c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"actual_task = \"mnli\" if task == \"mnli-mm\" else task\n",
"datasets = load_dataset(\"glue\", actual_task)"
]
},
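{
"cell_type": "markdown",
"metadata": {},
"source": [
"The preceding cell only loads the dataset. To also get the evaluation metric mentioned above, you can use `load_metric` in the same way. The following is a minimal sketch; the toy predictions and references are only illustrative, and newer `datasets` releases deprecate `load_metric` in favor of the separate `evaluate` library:\n",
"\n",
"```python\n",
"from datasets import load_metric\n",
"\n",
"# Load the metric matching the GLUE task (Matthews correlation for CoLA).\n",
"metric = load_metric(\"glue\", actual_task)\n",
"# Compute it on a toy example to see the output format.\n",
"print(metric.compute(predictions=[0, 1, 1], references=[0, 1, 0]))\n",
"```"
]
},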
{
"cell_type": "markdown",
"metadata": {
"id": "RzfPtOMoIrIu"
},
"source": [
"The `dataset` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key for the training, validation, and test set, with more keys for the mismatched validation and test set in the special case of `mnli`."
]
},
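{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, you can inspect the available splits and look at a single training example:\n",
"\n",
"```python\n",
"# The DatasetDict maps split names to Dataset objects.\n",
"print(datasets)\n",
"# Each CoLA example has a sentence, a label, and an index.\n",
"print(datasets[\"train\"][0])\n",
"```"
]
},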
{
"cell_type": "markdown",
"metadata": {
"id": "n9qywopnIrJH"
},
"source": [
"(hf-preprocess)=\n",
"### Preprocessing the data with Ray Data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YVx71GdAIrJH"
},
"source": [
"Before you can feed these texts to the model, you need to preprocess them. Preprocess them with a HF Transformers' `Tokenizer`, which tokenizes the inputs, including converting the tokens to their corresponding IDs in the pretrained vocabulary, and puts them in a format the model expects. It also generates the other inputs that the model requires.\n",
"\n",
"To do all of this preprocessing, instantiate your tokenizer with the `AutoTokenizer.from_pretrained` method, which ensures that you:\n",
"\n",
"- Get a tokenizer that corresponds to the model architecture you want to use.\n",
"- Download the vocabulary used when pretraining this specific checkpoint."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 145
},
"id": "eXNLu_-nIrJI",
"outputId": "f545a7a5-f341-4315-cd89-9942a657aa31",
"tags": []
},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Vl6IidfdIrJK"
},
"source": [
"Pass `use_fast=True` to the preceding call to use one of the fast tokenizers, backed by Rust, from the HF Tokenizers library. These fast tokenizers are available for almost all models, but if you get an error with the previous call, remove the argument."
]
},
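{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, you can call the tokenizer directly on a sample sentence. This is only an illustrative sketch; the exact token IDs depend on the checkpoint you chose:\n",
"\n",
"```python\n",
"# Tokenize a single sentence to inspect the fields the model expects.\n",
"encoded = tokenizer(\"Ray Train makes distributed fine-tuning simple.\")\n",
"print(encoded.keys())        # typically input_ids and attention_mask\n",
"print(encoded[\"input_ids\"])  # integer IDs from the pretrained vocabulary\n",
"```"
]
},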
{
"cell_type": "markdown",
"metadata": {
"id": "qo_0B1M2IrJM"
},
"source": [
"To preprocess the dataset, you need the names of the columns containing the sentence(s). The following dictionary keeps track of the correspondence task to column names:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "fyGdtK9oIrJM",
"tags": []
},
"outputs": [],
"source": [
"task_to_keys = {\n",
" \"cola\": (\"sentence\", None),\n",
" \"mnli\": (\"premise\", \"hypothesis\"),\n",
" \"mnli-mm\": (\"premise\", \"hypothesis\"),\n",
" \"mrpc\": (\"sentence1\", \"sentence2\"),\n",
" \"qnli\": (\"question\", \"sentence\"),\n",
" \"qqp\": (\"question1\", \"question2\"),\n",
" \"rte\": (\"sentence1\", \"sentence2\"),\n",
" \"sst2\": (\"sentence\", None),\n",
" \"stsb\": (\"sentence1\", \"sentence2\"),\n",
" \"wnli\": (\"sentence1\", \"sentence2\"),\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "256fOuzjhYbY"
},
"source": [
"Instead of using HF Dataset objects directly, convert them to [Ray Data](data). Arrow tables back both of them, so the conversion is straightforward. Use the built-in {meth}`~ray.data.from_huggingface` function."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"{'train': MaterializedDataset(\n",
" num_blocks=1,\n",
" num_rows=8551,\n",
" schema={sentence: string, label: int64, idx: int32}\n",
" ),\n",
" 'validation': MaterializedDataset(\n",
" num_blocks=1,\n",
" num_rows=1043,\n",
" schema={sentence: string, label: int64, idx: int32}\n",
" ),\n",
" 'test': MaterializedDataset(\n",
" num_blocks=1,\n",
" num_rows=1063,\n",
" schema={sentence: string, label: int64, idx: int32}\n",
" )}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import ray.data\n",
"\n",
"ray_datasets = {\n",
" \"train\": ray.data.from_huggingface(datasets[\"train\"]),\n",
" \"validation\": ray.data.from_huggingface(datasets[\"validation\"]),\n",
" \"test\": ray.data.from_huggingface(datasets[\"test\"]),\n",
"}\n",
"ray_datasets"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2C0hcmp9IrJQ"
},
"source": [
"You can then write the function that preprocesses the samples. Feed them to the `tokenizer` with the argument `truncation=True`. This configuration ensures that the `tokenizer` truncates and pads to the longest sequence in the batch, any input longer than what the model selected can handle."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "vc0BSBLIIrJQ",
"tags": []
},
"outputs": [],
"source": [
"import numpy as np\n",
"from typing import Dict\n",
"\n",
"\n",
"# Tokenize input sentences\n",
"def collate_fn(examples: Dict[str, np.array]):\n",
" sentence1_key, sentence2_key = task_to_keys[task]\n",
" if sentence2_key is None:\n",
" outputs = tokenizer(\n",
" list(examples[sentence1_key]),\n",
" truncation=True,\n",
" padding=\"longest\",\n",
" return_tensors=\"pt\",\n",
" )\n",
" else:\n",
" outputs = tokenizer(\n",
" list(examples[sentence1_key]),\n",
" list(examples[sentence2_key]),\n",
" truncation=True,\n",
" padding=\"longest\",\n",
" return_tensors=\"pt\",\n",
" )\n",
"\n",
" outputs[\"labels\"] = torch.LongTensor(examples[\"label\"])\n",
"\n",
" # Move all input tensors to GPU\n",
" for key, value in outputs.items():\n",
" outputs[key] = value.cuda()\n",
"\n",
" return outputs"
]
},
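{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want to check the collate function locally before training, you can call it on a small hand-built batch. This sketch assumes you run it on a machine where the `use_gpu` setting above matches the available hardware; the tensor shapes depend on the tokenizer:\n",
"\n",
"```python\n",
"# Build a toy batch in the same column format that Ray Data yields for CoLA.\n",
"toy_batch = {\n",
"    \"sentence\": np.array([\"This is a test.\", \"Another short sentence.\"]),\n",
"    \"label\": np.array([1, 0]),\n",
"}\n",
"# The output contains input_ids, attention_mask, and labels tensors.\n",
"print({key: value.shape for key, value in collate_fn(toy_batch).items()})\n",
"```"
]
},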
{
"cell_type": "markdown",
"metadata": {
"id": "545PP3o8IrJV"
},
"source": [
"(hf-train)=\n",
"### Fine-tuning the model with Ray Train"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FBiW8UpKIrJW"
},
"source": [
"Now that the data is ready, download the pretrained model and fine-tune it.\n",
"\n",
"Because all of the tasks involve sentence classification, use the `AutoModelForSequenceClassification` class. For more specifics about each individual training component, see the [original notebook](https://github.com/huggingface/notebooks/blob/6ca682955173cc9d36ffa431ddda505a048cbe80/examples/text_classification.ipynb). The original notebook uses the same tokenizer used to encode the dataset in this notebook's preceding example.\n",
"\n",
"The main difference when using Ray Train is that you need to define the training logic as a function (`train_func`). You pass this [training function](train-overview-training-function) to the {class}`~ray.train.torch.TorchTrainer` to on every Ray worker. The training then proceeds using PyTorch DDP.\n",
"\n",
"\n",
"```{note}\n",
"\n",
"Be sure to initialize the model, metric, and tokenizer within the function. Otherwise, you may encounter serialization errors.\n",
"\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "TlqNaB8jIrJW",
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-09-06 14:25:28.144428: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2023-09-06 14:25:28.284936: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
"2023-09-06 14:25:29.025734: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n",
"2023-09-06 14:25:29.025801: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n",
"2023-09-06 14:25:29.025807: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n",
"comet_ml is installed but `COMET_API_KEY` is not set.\n"
]
}
],
"source": [
"import torch\n",
"import numpy as np\n",
"\n",
"from datasets import load_metric\n",
"from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
"\n",
"import ray.train\n",
"from ray.train.huggingface.transformers import prepare_trainer, RayTrainReportCallback\n",
"\n",
"num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n",
"metric_name = (\n",
" \"pearson\"\n",
" if task == \"stsb\"\n",
" else \"matthews_correlation\"\n",
" if task == \"cola\"\n",
" else \"accuracy\"\n",
")\n",
"model_name = model_checkpoint.split(\"/\")[-1]\n",
"validation_key = (\n",
" \"validation_mismatched\"\n",
" if task == \"mnli-mm\"\n",
" else \"validation_matched\"\n",
" if task == \"mnli\"\n",
" else \"validation\"\n",
")\n",
"name = f\"{model_name}-finetuned-{task}\"\n",
"\n",
"# Calculate the maximum steps per epoch based on the number of rows in the training dataset.\n",
"# Make sure to scale by the total number of training workers and the per device batch size.\n",
"max_steps_per_epoch = ray_datasets[\"train\"].count() // (batch_size * num_workers)\n",
"\n",
"\n",
"def train_func(config):\n",
" print(f\"Is CUDA available: {torch.cuda.is_available()}\")\n",
"\n",
" metric = load_metric(\"glue\", actual_task)\n",
" tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n",
" model = AutoModelForSequenceClassification.from_pretrained(\n",
" model_checkpoint, num_labels=num_labels\n",
" )\n",
"\n",
" train_ds = ray.train.get_dataset_shard(\"train\")\n",
" eval_ds = ray.train.get_dataset_shard(\"eval\")\n",
"\n",
" train_ds_iterable = train_ds.iter_torch_batches(\n",
" batch_size=batch_size, collate_fn=collate_fn\n",
" )\n",
" eval_ds_iterable = eval_ds.iter_torch_batches(\n",
" batch_size=batch_size, collate_fn=collate_fn\n",
" )\n",
"\n",
" print(\"max_steps_per_epoch: \", max_steps_per_epoch)\n",
"\n",
" args = TrainingArguments(\n",
" name,\n",
" evaluation_strategy=\"epoch\",\n",
" save_strategy=\"epoch\",\n",
" logging_strategy=\"epoch\",\n",
" per_device_train_batch_size=batch_size,\n",
" per_device_eval_batch_size=batch_size,\n",
" learning_rate=config.get(\"learning_rate\", 2e-5),\n",
" num_train_epochs=config.get(\"epochs\", 2),\n",
" weight_decay=config.get(\"weight_decay\", 0.01),\n",
" push_to_hub=False,\n",
" max_steps=max_steps_per_epoch * config.get(\"epochs\", 2),\n",
" disable_tqdm=True, # declutter the output a little\n",
" no_cuda=not use_gpu, # you need to explicitly set no_cuda if you want CPUs\n",
" report_to=\"none\",\n",
" )\n",
"\n",
" def compute_metrics(eval_pred):\n",
" predictions, labels = eval_pred\n",
" if task != \"stsb\":\n",
" predictions = np.argmax(predictions, axis=1)\n",
" else:\n",
" predictions = predictions[:, 0]\n",
" return metric.compute(predictions=predictions, references=labels)\n",
"\n",
" trainer = Trainer(\n",
" model,\n",
" args,\n",
" train_dataset=train_ds_iterable,\n",
" eval_dataset=eval_ds_iterable,\n",
" tokenizer=tokenizer,\n",
" compute_metrics=compute_metrics,\n",
" )\n",
"\n",
" trainer.add_callback(RayTrainReportCallback())\n",
"\n",
" trainer = prepare_trainer(trainer)\n",
"\n",
" print(\"Starting training\")\n",
" trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CdzABDVcIrJg"
},
"source": [
"With your `train_func` complete, you can now instantiate the {class}`~ray.train.torch.TorchTrainer`. Aside from calling the function, set the `scaling_config`, which controls the amount of workers and resources used, and the `datasets` to use for training and evaluation."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "RElw7OgLhYba",
"tags": []
},
"outputs": [],
"source": [
"from ray.train.torch import TorchTrainer\n",
"from ray.train import RunConfig, ScalingConfig, CheckpointConfig\n",
"\n",
"trainer = TorchTrainer(\n",
" train_func,\n",
" scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),\n",
" datasets={\n",
" \"train\": ray_datasets[\"train\"],\n",
" \"eval\": ray_datasets[\"validation\"],\n",
" },\n",
" run_config=RunConfig(\n",
" checkpoint_config=CheckpointConfig(\n",
" num_to_keep=1,\n",
" checkpoint_score_attribute=\"eval_loss\",\n",
" checkpoint_score_order=\"min\",\n",
" ),\n",
" ),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XvS136zKhYba"
},
"source": [
"Finally, call the `fit` method to start training with Ray Train. Save the `Result` object to a variable so you can access metrics and checkpoints."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "uNx5pyRlIrJh",
"outputId": "8496fe4f-f1c3-48ad-a6d3-b16a65716135",
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"
\n",
"
\n",
"
Tune Status
\n",
"
\n",
"\n",
"Current time: | 2023-09-06 14:27:12 |
\n",
"Running for: | 00:01:40.12 |
\n",
"Memory: | 18.4/186.6 GiB |
\n",
"\n",
"
\n",
"
\n",
"
\n",
"
\n",
"
System Info
\n",
" Using FIFO scheduling algorithm.
Logical resource usage: 1.0/48 CPUs, 1.0/4 GPUs (0.0/1.0 accelerator_type:None)\n",
" \n",
" \n",
"
\n",
"
\n",
"
\n",
"
Trial Status
\n",
"
\n",
"\n",
"Trial name | status | loc | iter | total time (s) | loss | learning_rate | epoch |
\n",
"\n",
"\n",
"TorchTrainer_e8bd4_00000 | TERMINATED | 10.0.27.125:43821 | 2 | 76.6259 | 0.3866 | 0 | 1.5 |
\n",
"\n",
"
\n",
"
\n",
"
\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(TrainTrainable pid=43821)\u001b[0m 2023-09-06 14:25:35.638885: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA\n",
"\u001b[2m\u001b[36m(TrainTrainable pid=43821)\u001b[0m To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"\u001b[2m\u001b[36m(TrainTrainable pid=43821)\u001b[0m 2023-09-06 14:25:35.782950: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
"\u001b[2m\u001b[36m(TrainTrainable pid=43821)\u001b[0m 2023-09-06 14:25:36.501583: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n",
"\u001b[2m\u001b[36m(TrainTrainable pid=43821)\u001b[0m 2023-09-06 14:25:36.501653: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n",
"\u001b[2m\u001b[36m(TrainTrainable pid=43821)\u001b[0m 2023-09-06 14:25:36.501660: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n",
"\u001b[2m\u001b[36m(TrainTrainable pid=43821)\u001b[0m comet_ml is installed but `COMET_API_KEY` is not set.\n",
"\u001b[2m\u001b[36m(TorchTrainer pid=43821)\u001b[0m Starting distributed worker processes: ['43946 (10.0.27.125)']\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Setting up process group for: env:// [rank=0, world_size=1]\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m 2023-09-06 14:25:42.756510: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m 2023-09-06 14:25:42.903398: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=44017)\u001b[0m Auto configuring locality_with_output=['84374908fd32ea9885fdd6d21aadf2ce3e296daf28a26522e7a8d026']\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m 2023-09-06 14:25:43.737476: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m 2023-09-06 14:25:43.737544: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m 2023-09-06 14:25:43.737554: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m comet_ml is installed but `COMET_API_KEY` is not set.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Is CUDA available: True\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight']\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias', 'pre_classifier.bias', 'pre_classifier.weight']\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=44016)\u001b[0m Auto configuring locality_with_output=['84374908fd32ea9885fdd6d21aadf2ce3e296daf28a26522e7a8d026']\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m max_steps_per_epoch: 534\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m max_steps is given, it will override any value given in num_train_epochs\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m /home/ray/anaconda3/lib/python3.9/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Starting training\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m ***** Running training *****\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Num examples = 17088\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Num Epochs = 9223372036854775807\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Instantaneous batch size per device = 16\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Total train batch size (w. parallel, distributed & accumulation) = 16\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Gradient Accumulation steps = 1\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Total optimization steps = 1068\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m /tmp/ipykernel_43503/4088900328.py:23: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=44016)\u001b[0m Executing DAG InputDataBuffer[Input] -> OutputSplitter[split(1, equal=True)]\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=44016)\u001b[0m Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=['84374908fd32ea9885fdd6d21aadf2ce3e296daf28a26522e7a8d026'], preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=44016)\u001b[0m Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=44016) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m [W reducer.cpp:1300] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m {'loss': 0.5414, 'learning_rate': 9.9812734082397e-06, 'epoch': 0.5}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m ***** Running Evaluation *****\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Num examples: Unknown\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Batch size = 16\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=44017)\u001b[0m Executing DAG InputDataBuffer[Input] -> OutputSplitter[split(1, equal=True)]\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=44017)\u001b[0m Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=['84374908fd32ea9885fdd6d21aadf2ce3e296daf28a26522e7a8d026'], preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=44017)\u001b[0m Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=44017) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Saving model checkpoint to distilbert-base-uncased-finetuned-cola/checkpoint-535\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Configuration saved in distilbert-base-uncased-finetuned-cola/checkpoint-535/config.json\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m {'eval_loss': 0.5018134117126465, 'eval_matthews_correlation': 0.4145623770066859, 'eval_runtime': 0.6595, 'eval_samples_per_second': 1581.584, 'eval_steps_per_second': 100.081, 'epoch': 0.5}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Model weights saved in distilbert-base-uncased-finetuned-cola/checkpoint-535/pytorch_model.bin\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m tokenizer config file saved in distilbert-base-uncased-finetuned-cola/checkpoint-535/tokenizer_config.json\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Special tokens file saved in distilbert-base-uncased-finetuned-cola/checkpoint-535/special_tokens_map.json\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/mnt/cluster_storage/ray_results/TorchTrainer_2023-09-06_14-25-31/TorchTrainer_e8bd4_00000_0_2023-09-06_14-25-32/checkpoint_000000)\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=44016)\u001b[0m Executing DAG InputDataBuffer[Input] -> OutputSplitter[split(1, equal=True)]\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=44016)\u001b[0m Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=['84374908fd32ea9885fdd6d21aadf2ce3e296daf28a26522e7a8d026'], preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=44016)\u001b[0m Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=44016) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m {'loss': 0.3866, 'learning_rate': 0.0, 'epoch': 1.5}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m ***** Running Evaluation *****\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Num examples: Unknown\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Batch size = 16\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=44017)\u001b[0m Executing DAG InputDataBuffer[Input] -> OutputSplitter[split(1, equal=True)]\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=44017)\u001b[0m Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=['84374908fd32ea9885fdd6d21aadf2ce3e296daf28a26522e7a8d026'], preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=44017)\u001b[0m Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=44017) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Saving model checkpoint to distilbert-base-uncased-finetuned-cola/checkpoint-1068\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Configuration saved in distilbert-base-uncased-finetuned-cola/checkpoint-1068/config.json\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m {'eval_loss': 0.5527923107147217, 'eval_matthews_correlation': 0.44860917123689154, 'eval_runtime': 0.6646, 'eval_samples_per_second': 1569.42, 'eval_steps_per_second': 99.311, 'epoch': 1.5}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Model weights saved in distilbert-base-uncased-finetuned-cola/checkpoint-1068/pytorch_model.bin\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m tokenizer config file saved in distilbert-base-uncased-finetuned-cola/checkpoint-1068/tokenizer_config.json\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Special tokens file saved in distilbert-base-uncased-finetuned-cola/checkpoint-1068/special_tokens_map.json\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/mnt/cluster_storage/ray_results/TorchTrainer_2023-09-06_14-25-31/TorchTrainer_e8bd4_00000_0_2023-09-06_14-25-32/checkpoint_000001)\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m \n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m \n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m Training completed. Do not forget to share your model on huggingface.co/models =)\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m \n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=43946)\u001b[0m {'train_runtime': 66.0485, 'train_samples_per_second': 258.719, 'train_steps_per_second': 16.17, 'train_loss': 0.46413421630859375, 'epoch': 1.5}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-09-06 14:27:12,180\tWARNING experiment_state.py:371 -- Experiment checkpoint syncing has been triggered multiple times in the last 30.0 seconds. A sync will be triggered whenever a trial has checkpointed more than `num_to_keep` times since last sync or if 300 seconds have passed since last sync. If you have set `num_to_keep` in your `CheckpointConfig`, consider increasing the checkpoint frequency or keeping more checkpoints. You can supress this warning by changing the `TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S` environment variable.\n",
"2023-09-06 14:27:12,184\tINFO tune.py:1141 -- Total run time: 100.17 seconds (85.12 seconds for the tuning loop).\n"
]
}
],
"source": [
"result = trainer.fit()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4cnWqUWmhYba"
},
"source": [
"You can use the returned `Result` object to access metrics and the Ray Train `Checkpoint` associated with the last iteration."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "AMN5qjUwhYba",
"outputId": "7b754c36-c58b-4ff4-d7a8-63ec9764bd0c"
},
"outputs": [
{
"data": {
"text/plain": [
"Result(\n",
" metrics={'loss': 0.3866, 'learning_rate': 0.0, 'epoch': 1.5, 'step': 1068, 'eval_loss': 0.5527923107147217, 'eval_matthews_correlation': 0.44860917123689154, 'eval_runtime': 0.6646, 'eval_samples_per_second': 1569.42, 'eval_steps_per_second': 99.311},\n",
" path='/mnt/cluster_storage/ray_results/TorchTrainer_2023-09-06_14-25-31/TorchTrainer_e8bd4_00000_0_2023-09-06_14-25-32',\n",
" filesystem='local',\n",
" checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/ray_results/TorchTrainer_2023-09-06_14-25-31/TorchTrainer_e8bd4_00000_0_2023-09-06_14-25-32/checkpoint_000001)\n",
")"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result"
]
},
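{
"cell_type": "markdown",
"metadata": {},
"source": [
"To run inference with the fine-tuned weights, you can load the model back from the checkpoint. The following is a minimal sketch; it assumes the checkpoint directory layout that `RayTrainReportCallback` produces, where the Transformers checkpoint lives in a subdirectory named by `RayTrainReportCallback.CHECKPOINT_NAME`:\n",
"\n",
"```python\n",
"import os\n",
"from transformers import AutoModelForSequenceClassification\n",
"from ray.train.huggingface.transformers import RayTrainReportCallback\n",
"\n",
"# Download (if remote) and open the checkpoint as a local directory.\n",
"with result.checkpoint.as_directory() as checkpoint_dir:\n",
"    # Assumption: the HF checkpoint is stored under this subdirectory name.\n",
"    checkpoint_path = os.path.join(checkpoint_dir, RayTrainReportCallback.CHECKPOINT_NAME)\n",
"    finetuned_model = AutoModelForSequenceClassification.from_pretrained(checkpoint_path)\n",
"```"
]
},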
{
"cell_type": "markdown",
"metadata": {},
"source": [
"(hf-predict)=\n",
"### Tune hyperparameters with Ray Tune"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To tune any hyperparameters of the model, pass your `TorchTrainer` into a `Tuner` and define the search space.\n",
"\n",
"You can also take advantage of the advanced search algorithms and schedulers from Ray Tune. This example uses an `ASHAScheduler` to aggresively terminate underperforming trials."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-09-06 14:46:47,821\tINFO tuner_internal.py:508 -- A `RunConfig` was passed to both the `Tuner` and the `TorchTrainer`. The run config passed to the `Tuner` is the one that will be used.\n"
]
}
],
"source": [
"from ray import tune\n",
"from ray.tune import Tuner\n",
"from ray.tune.schedulers.async_hyperband import ASHAScheduler\n",
"\n",
"tune_epochs = 4\n",
"tuner = Tuner(\n",
" trainer,\n",
" param_space={\n",
" \"train_loop_config\": {\n",
" \"learning_rate\": tune.grid_search([2e-5, 2e-4, 2e-3, 2e-2]),\n",
" \"epochs\": tune_epochs,\n",
" }\n",
" },\n",
" tune_config=tune.TuneConfig(\n",
" metric=\"eval_loss\",\n",
" mode=\"min\",\n",
" num_samples=1,\n",
" scheduler=ASHAScheduler(\n",
" max_t=tune_epochs,\n",
" ),\n",
" ),\n",
" run_config=RunConfig(\n",
" name=\"tune_transformers\",\n",
" checkpoint_config=CheckpointConfig(\n",
" num_to_keep=1,\n",
" checkpoint_score_attribute=\"eval_loss\",\n",
" checkpoint_score_order=\"min\",\n",
" ),\n",
" ),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"
\n",
"
\n",
"
Tune Status
\n",
"
\n",
"\n",
"Current time: | 2023-09-06 14:49:04 |
\n",
"Running for: | 00:02:16.18 |
\n",
"Memory: | 19.6/186.6 GiB |
\n",
"\n",
"
\n",
"
\n",
"
\n",
"
\n",
"
System Info
\n",
" Using AsyncHyperBand: num_stopped=4
Bracket: Iter 4.000: -0.6517604142427444 | Iter 1.000: -0.5936744660139084
Logical resource usage: 1.0/48 CPUs, 1.0/4 GPUs (0.0/1.0 accelerator_type:None)\n",
" \n",
" \n",
"
\n",
"
\n",
"
\n",
"
Trial Status
\n",
"
\n",
"\n",
"Trial name | status | loc | train_loop_config/le\n",
"arning_rate | iter | total time (s) | loss | learning_rate | epoch |
\n",
"\n",
"\n",
"TorchTrainer_e1825_00000 | TERMINATED | 10.0.27.125:57496 | 2e-05 | 4 | 128.443 | 0.1934 | 0 | 3.25 |
\n",
"TorchTrainer_e1825_00001 | TERMINATED | 10.0.27.125:57497 | 0.0002 | 1 | 41.2486 | 0.616 | 0.000149906 | 0.25 |
\n",
"TorchTrainer_e1825_00002 | TERMINATED | 10.0.27.125:57498 | 0.002 | 1 | 41.1336 | 0.6699 | 0.00149906 | 0.25 |
\n",
"TorchTrainer_e1825_00003 | TERMINATED | 10.0.27.125:57499 | 0.02 | 4 | 126.699 | 0.6073 | 0 | 3.25 |
\n",
"\n",
"
\n",
"
\n",
"
\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(TrainTrainable pid=57498)\u001b[0m 2023-09-06 14:46:52.049839: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA\n",
"\u001b[2m\u001b[36m(TrainTrainable pid=57498)\u001b[0m To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"\u001b[2m\u001b[36m(TrainTrainable pid=57498)\u001b[0m 2023-09-06 14:46:52.195780: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
"\u001b[2m\u001b[36m(TrainTrainable pid=57498)\u001b[0m 2023-09-06 14:46:52.944517: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n",
"\u001b[2m\u001b[36m(TrainTrainable pid=57498)\u001b[0m 2023-09-06 14:46:52.944590: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n",
"\u001b[2m\u001b[36m(TrainTrainable pid=57498)\u001b[0m 2023-09-06 14:46:52.944597: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n",
"\u001b[2m\u001b[36m(TrainTrainable pid=57498)\u001b[0m comet_ml is installed but `COMET_API_KEY` is not set.\n",
"\u001b[2m\u001b[36m(TorchTrainer pid=57498)\u001b[0m Starting distributed worker processes: ['57731 (10.0.27.125)']\n",
"\u001b[2m\u001b[36m(TrainTrainable pid=57499)\u001b[0m 2023-09-06 14:46:52.229406: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(TrainTrainable pid=57499)\u001b[0m To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(TrainTrainable pid=57499)\u001b[0m 2023-09-06 14:46:52.378805: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Setting up process group for: env:// [rank=0, world_size=1]\n",
"\u001b[2m\u001b[36m(TrainTrainable pid=57499)\u001b[0m 2023-09-06 14:46:53.174151: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\u001b[32m [repeated 6x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(TrainTrainable pid=57499)\u001b[0m 2023-09-06 14:46:53.174160: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(TrainTrainable pid=57499)\u001b[0m comet_ml is installed but `COMET_API_KEY` is not set.\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=57927)\u001b[0m Auto configuring locality_with_output=['84374908fd32ea9885fdd6d21aadf2ce3e296daf28a26522e7a8d026']\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Is CUDA available: True\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m max_steps_per_epoch: 534\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.bias', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias']\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'pre_classifier.weight', 'classifier.weight', 'pre_classifier.bias']\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"\u001b[2m\u001b[36m(TorchTrainer pid=57499)\u001b[0m Starting distributed worker processes: ['57746 (10.0.27.125)']\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m 2023-09-06 14:47:00.036649: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m 2023-09-06 14:47:00.198894: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m Setting up process group for: env:// [rank=0, world_size=1]\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m 2023-09-06 14:47:01.085704: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\u001b[32m [repeated 8x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m 2023-09-06 14:47:01.085711: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m comet_ml is installed but `COMET_API_KEY` is not set.\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=57965)\u001b[0m Auto configuring locality_with_output=['84374908fd32ea9885fdd6d21aadf2ce3e296daf28a26522e7a8d026']\u001b[32m [repeated 7x across cluster]\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Starting training\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m max_steps is given, it will override any value given in num_train_epochs\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m /home/ray/anaconda3/lib/python3.9/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m warnings.warn(\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_transform.bias']\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.weight', 'classifier.bias', 'pre_classifier.bias']\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57731)\u001b[0m Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias']\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.bias', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight']\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier.bias', 'classifier.weight']\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m ***** Running training *****\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Num examples = 34176\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Num Epochs = 9223372036854775807\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Instantaneous batch size per device = 16\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Total train batch size (w. parallel, distributed & accumulation) = 16\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Gradient Accumulation steps = 1\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Total optimization steps = 2136\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m /tmp/ipykernel_43503/4088900328.py:23: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=57927)\u001b[0m Executing DAG InputDataBuffer[Input] -> OutputSplitter[split(1, equal=True)]\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=57927)\u001b[0m Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=['84374908fd32ea9885fdd6d21aadf2ce3e296daf28a26522e7a8d026'], preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=57927)\u001b[0m Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=57927) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m [W reducer.cpp:1300] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=57946) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=57954) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=57965) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m {'loss': 0.5481, 'learning_rate': 1.4990636704119851e-05, 'epoch': 0.25}\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m Is CUDA available: True\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m max_steps_per_epoch: 534\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m Starting training\u001b[32m [repeated 3x across cluster]\u001b[0m\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m ***** Running Evaluation *****\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Num examples: Unknown\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Batch size = 16\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57731)\u001b[0m Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'pre_classifier.weight', 'classifier.weight', 'pre_classifier.bias']\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m max_steps is given, it will override any value given in num_train_epochs\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m /home/ray/anaconda3/lib/python3.9/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m warnings.warn(\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m ***** Running training *****\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m Num examples = 34176\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m Num Epochs = 9223372036854775807\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m Instantaneous batch size per device = 16\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m Total train batch size (w. parallel, distributed & accumulation) = 16\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m Gradient Accumulation steps = 1\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m Total optimization steps = 2136\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m /tmp/ipykernel_43503/4088900328.py:23: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=57965)\u001b[0m Executing DAG InputDataBuffer[Input] -> OutputSplitter[split(1, equal=True)]\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=57965)\u001b[0m Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=['84374908fd32ea9885fdd6d21aadf2ce3e296daf28a26522e7a8d026'], preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=57965)\u001b[0m Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m [W reducer.cpp:1300] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())\u001b[32m [repeated 3x across cluster]\u001b[0m\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=57928) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=57955) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=57947) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=57966) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m {'eval_loss': 0.5202918648719788, 'eval_matthews_correlation': 0.37321205597032797, 'eval_runtime': 0.7255, 'eval_samples_per_second': 1437.704, 'eval_steps_per_second': 90.976, 'epoch': 0.25}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Saving model checkpoint to distilbert-base-uncased-finetuned-cola/checkpoint-535\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Configuration saved in distilbert-base-uncased-finetuned-cola/checkpoint-535/config.json\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Model weights saved in distilbert-base-uncased-finetuned-cola/checkpoint-535/pytorch_model.bin\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m tokenizer config file saved in distilbert-base-uncased-finetuned-cola/checkpoint-535/tokenizer_config.json\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Special tokens file saved in distilbert-base-uncased-finetuned-cola/checkpoint-535/special_tokens_map.json\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/ray/ray_results/tune_transformers/TorchTrainer_e1825_00000_0_learning_rate=0.0000_2023-09-06_14-46-48/checkpoint_000000)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=57927) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=57954) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m {'loss': 0.6064, 'learning_rate': 0.009981273408239701, 'epoch': 1.25}\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m {'eval_loss': 0.6181353330612183, 'eval_matthews_correlation': 0.0, 'eval_runtime': 0.7543, 'eval_samples_per_second': 1382.828, 'eval_steps_per_second': 87.504, 'epoch': 0.25}\u001b[32m [repeated 3x across cluster]\u001b[0m\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m ***** Running Evaluation *****\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m Num examples: Unknown\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m Batch size = 16\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=57954)\u001b[0m Executing DAG InputDataBuffer[Input] -> OutputSplitter[split(1, equal=True)]\u001b[32m [repeated 6x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=57954)\u001b[0m Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=['84374908fd32ea9885fdd6d21aadf2ce3e296daf28a26522e7a8d026'], preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\u001b[32m [repeated 6x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=57954)\u001b[0m Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\u001b[32m [repeated 6x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m Saving model checkpoint to distilbert-base-uncased-finetuned-cola/checkpoint-535\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m Configuration saved in distilbert-base-uncased-finetuned-cola/checkpoint-535/config.json\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m Model weights saved in distilbert-base-uncased-finetuned-cola/checkpoint-535/pytorch_model.bin\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m tokenizer config file saved in distilbert-base-uncased-finetuned-cola/checkpoint-535/tokenizer_config.json\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m Special tokens file saved in distilbert-base-uncased-finetuned-cola/checkpoint-535/special_tokens_map.json\u001b[32m [repeated 3x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57740)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/ray/ray_results/tune_transformers/TorchTrainer_e1825_00001_1_learning_rate=0.0002_2023-09-06_14-46-48/checkpoint_000000)\u001b[32m [repeated 3x across cluster]\u001b[0m\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=57955) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=57928) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=57954) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=57927) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m {'loss': 0.6061, 'learning_rate': 0.004971910112359551, 'epoch': 2.25}\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m {'eval_loss': 0.5246258974075317, 'eval_matthews_correlation': 0.489934557943789, 'eval_runtime': 0.6462, 'eval_samples_per_second': 1614.032, 'eval_steps_per_second': 102.134, 'epoch': 1.25}\u001b[32m [repeated 2x across cluster]\u001b[0m\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m ***** Running Evaluation *****\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m Num examples: Unknown\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m Batch size = 16\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=57927)\u001b[0m Executing DAG InputDataBuffer[Input] -> OutputSplitter[split(1, equal=True)]\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=57927)\u001b[0m Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=['84374908fd32ea9885fdd6d21aadf2ce3e296daf28a26522e7a8d026'], preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=57927)\u001b[0m Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Saving model checkpoint to distilbert-base-uncased-finetuned-cola/checkpoint-1070\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Configuration saved in distilbert-base-uncased-finetuned-cola/checkpoint-1070/config.json\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Model weights saved in distilbert-base-uncased-finetuned-cola/checkpoint-1070/pytorch_model.bin\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m tokenizer config file saved in distilbert-base-uncased-finetuned-cola/checkpoint-1070/tokenizer_config.json\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Special tokens file saved in distilbert-base-uncased-finetuned-cola/checkpoint-1070/special_tokens_map.json\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/ray/ray_results/tune_transformers/TorchTrainer_e1825_00000_0_learning_rate=0.0000_2023-09-06_14-46-48/checkpoint_000001)\u001b[32m [repeated 2x across cluster]\u001b[0m\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=57955) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=57928) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=57954) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=57927) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m {'loss': 0.6073, 'learning_rate': 0.0, 'epoch': 3.25}\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m {'eval_loss': 0.6450843811035156, 'eval_matthews_correlation': 0.5259674254268325, 'eval_runtime': 0.6474, 'eval_samples_per_second': 1611.106, 'eval_steps_per_second': 101.949, 'epoch': 2.25}\u001b[32m [repeated 2x across cluster]\u001b[0m\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m ***** Running Evaluation *****\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m Num examples: Unknown\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m Batch size = 16\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=57927)\u001b[0m Executing DAG InputDataBuffer[Input] -> OutputSplitter[split(1, equal=True)]\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=57927)\u001b[0m Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=['84374908fd32ea9885fdd6d21aadf2ce3e296daf28a26522e7a8d026'], preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(SplitCoordinator pid=57927)\u001b[0m Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\u001b[32m [repeated 4x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Saving model checkpoint to distilbert-base-uncased-finetuned-cola/checkpoint-1605\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Configuration saved in distilbert-base-uncased-finetuned-cola/checkpoint-1605/config.json\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Model weights saved in distilbert-base-uncased-finetuned-cola/checkpoint-1605/pytorch_model.bin\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m tokenizer config file saved in distilbert-base-uncased-finetuned-cola/checkpoint-1605/tokenizer_config.json\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Special tokens file saved in distilbert-base-uncased-finetuned-cola/checkpoint-1605/special_tokens_map.json\u001b[32m [repeated 2x across cluster]\u001b[0m\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/ray/ray_results/tune_transformers/TorchTrainer_e1825_00000_0_learning_rate=0.0000_2023-09-06_14-46-48/checkpoint_000002)\u001b[32m [repeated 2x across cluster]\u001b[0m\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=57955) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"(pid=57928) Running 0: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m \n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m \n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m Training completed. Do not forget to share your model on huggingface.co/models =)\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m \n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=57746)\u001b[0m {'train_runtime': 115.5377, 'train_samples_per_second': 295.8, 'train_steps_per_second': 18.487, 'train_loss': 0.6787891173630618, 'epoch': 3.25}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-09-06 14:49:04,574\tINFO tune.py:1141 -- Total run time: 136.19 seconds (136.17 seconds for the tuning loop).\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m \n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m \n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m Training completed. Do not forget to share your model on huggingface.co/models =)\n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m \n",
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(RayTrainWorker pid=57741)\u001b[0m {'train_runtime': 117.6791, 'train_samples_per_second': 290.417, 'train_steps_per_second': 18.151, 'train_loss': 0.3468295286657212, 'epoch': 3.25}\n"
]
}
],
"source": [
"tune_results = tuner.fit()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"View the results of the tuning run as a dataframe, and find the best result."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" loss | \n",
" learning_rate | \n",
" epoch | \n",
" step | \n",
" eval_loss | \n",
" eval_matthews_correlation | \n",
" eval_runtime | \n",
" eval_samples_per_second | \n",
" eval_steps_per_second | \n",
" timestamp | \n",
" ... | \n",
" time_total_s | \n",
" pid | \n",
" hostname | \n",
" node_ip | \n",
" time_since_restore | \n",
" iterations_since_restore | \n",
" checkpoint_dir_name | \n",
" config/train_loop_config/learning_rate | \n",
" config/train_loop_config/epochs | \n",
" logdir | \n",
"
\n",
" \n",
" \n",
" \n",
" 1 | \n",
" 0.6160 | \n",
" 0.000150 | \n",
" 0.25 | \n",
" 535 | \n",
" 0.618135 | \n",
" 0.000000 | \n",
" 0.7543 | \n",
" 1382.828 | \n",
" 87.504 | \n",
" 1694036857 | \n",
" ... | \n",
" 41.248600 | \n",
" 57497 | \n",
" ip-10-0-27-125 | \n",
" 10.0.27.125 | \n",
" 41.248600 | \n",
" 1 | \n",
" checkpoint_000000 | \n",
" 0.00020 | \n",
" 4 | \n",
" e1825_00001 | \n",
"
\n",
" \n",
" 2 | \n",
" 0.6699 | \n",
" 0.001499 | \n",
" 0.25 | \n",
" 535 | \n",
" 0.619657 | \n",
" 0.000000 | \n",
" 0.7449 | \n",
" 1400.202 | \n",
" 88.603 | \n",
" 1694036856 | \n",
" ... | \n",
" 41.133609 | \n",
" 57498 | \n",
" ip-10-0-27-125 | \n",
" 10.0.27.125 | \n",
" 41.133609 | \n",
" 1 | \n",
" checkpoint_000000 | \n",
" 0.00200 | \n",
" 4 | \n",
" e1825_00002 | \n",
"
\n",
" \n",
" 3 | \n",
" 0.6073 | \n",
" 0.000000 | \n",
" 3.25 | \n",
" 2136 | \n",
" 0.619694 | \n",
" 0.000000 | \n",
" 0.6329 | \n",
" 1648.039 | \n",
" 104.286 | \n",
" 1694036942 | \n",
" ... | \n",
" 126.699238 | \n",
" 57499 | \n",
" ip-10-0-27-125 | \n",
" 10.0.27.125 | \n",
" 126.699238 | \n",
" 4 | \n",
" checkpoint_000003 | \n",
" 0.02000 | \n",
" 4 | \n",
" e1825_00003 | \n",
"
\n",
" \n",
" 0 | \n",
" 0.1934 | \n",
" 0.000000 | \n",
" 3.25 | \n",
" 2136 | \n",
" 0.747960 | \n",
" 0.520756 | \n",
" 0.6530 | \n",
" 1597.187 | \n",
" 101.068 | \n",
" 1694036944 | \n",
" ... | \n",
" 128.443495 | \n",
" 57496 | \n",
" ip-10-0-27-125 | \n",
" 10.0.27.125 | \n",
" 128.443495 | \n",
" 4 | \n",
" checkpoint_000003 | \n",
" 0.00002 | \n",
" 4 | \n",
" e1825_00000 | \n",
"
\n",
" \n",
"
\n",
"
4 rows × 26 columns
\n",
"
"
],
"text/plain": [
" loss learning_rate epoch step eval_loss eval_matthews_correlation \\\n",
"1 0.6160 0.000150 0.25 535 0.618135 0.000000 \n",
"2 0.6699 0.001499 0.25 535 0.619657 0.000000 \n",
"3 0.6073 0.000000 3.25 2136 0.619694 0.000000 \n",
"0 0.1934 0.000000 3.25 2136 0.747960 0.520756 \n",
"\n",
" eval_runtime eval_samples_per_second eval_steps_per_second timestamp \\\n",
"1 0.7543 1382.828 87.504 1694036857 \n",
"2 0.7449 1400.202 88.603 1694036856 \n",
"3 0.6329 1648.039 104.286 1694036942 \n",
"0 0.6530 1597.187 101.068 1694036944 \n",
"\n",
" ... time_total_s pid hostname node_ip time_since_restore \\\n",
"1 ... 41.248600 57497 ip-10-0-27-125 10.0.27.125 41.248600 \n",
"2 ... 41.133609 57498 ip-10-0-27-125 10.0.27.125 41.133609 \n",
"3 ... 126.699238 57499 ip-10-0-27-125 10.0.27.125 126.699238 \n",
"0 ... 128.443495 57496 ip-10-0-27-125 10.0.27.125 128.443495 \n",
"\n",
" iterations_since_restore checkpoint_dir_name \\\n",
"1 1 checkpoint_000000 \n",
"2 1 checkpoint_000000 \n",
"3 4 checkpoint_000003 \n",
"0 4 checkpoint_000003 \n",
"\n",
" config/train_loop_config/learning_rate config/train_loop_config/epochs \\\n",
"1 0.00020 4 \n",
"2 0.00200 4 \n",
"3 0.02000 4 \n",
"0 0.00002 4 \n",
"\n",
" logdir \n",
"1 e1825_00001 \n",
"2 e1825_00002 \n",
"3 e1825_00003 \n",
"0 e1825_00000 \n",
"\n",
"[4 rows x 26 columns]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tune_results.get_dataframe().sort_values(\"eval_loss\")"
]
},
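{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each row corresponds to one trial and contains the last metrics that the trial reported, together with its hyperparameters in the `config/...` columns. As a sketch (the column names below match the ones shown in the table above), you can narrow the view to the columns of interest:\n",
"\n",
"```python\n",
"df = tune_results.get_dataframe()\n",
"df[[\"eval_loss\", \"eval_matthews_correlation\", \"config/train_loop_config/learning_rate\"]]\n",
"```"
]
},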
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"best_result = tune_results.get_best_result()"
]
},
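{
"cell_type": "markdown",
"metadata": {},
"source": [
"With no arguments, `get_best_result` uses the metric and mode that the `Tuner` was configured with. You can also pass them explicitly; the following sketch assumes you want the trial with the lowest evaluation loss:\n",
"\n",
"```python\n",
"# Sketch: pick the trial with the lowest reported eval_loss.\n",
"best_result = tune_results.get_best_result(metric=\"eval_loss\", mode=\"min\")\n",
"```"
]
},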
{
"cell_type": "markdown",
"metadata": {},
"source": [
"(hf-share)=\n",
"### Share the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mS8PId_NhYbb"
},
"source": [
"To share the model with the community, a few more steps follow.\n",
"\n",
"You conducted the training on the Ray cluster, but want share the model from the local enviroment. This configuration allows you to easily authenticate.\n",
"\n",
"First, store your authentication token from the Hugging Face website. Sign up [here](https://huggingface.co/join) if you haven't already. Then execute the following cell and input your username and password:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2LClXkN8hYbb",
"tags": [
"remove-cell-ci"
]
},
"outputs": [],
"source": [
"from huggingface_hub import notebook_login\n",
"\n",
"notebook_login()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SybKUDryhYbb"
},
"source": [
"Then you need to install Git-LFS. Uncomment the following instructions:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"id": "_wF6aT-0hYbb",
"tags": [
"remove-cell-ci"
]
},
"outputs": [],
"source": [
"# !apt install git-lfs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5fr6E0e8hYbb"
},
"source": [
"Load the model with the best-performing checkpoint:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cjH2A8m6hYbc",
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"from ray.train import Checkpoint\n",
"\n",
"checkpoint: Checkpoint = best_result.checkpoint\n",
"\n",
"with checkpoint.as_directory() as checkpoint_dir:\n",
" checkpoint_path = os.path.join(checkpoint_dir, \"checkpoint\")\n",
" model = AutoModelForSequenceClassification.from_pretrained(checkpoint_path)"
]
},
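{
"cell_type": "markdown",
"metadata": {},
"source": [
"`checkpoint.as_directory()` gives you a local directory containing the files the training workers saved, downloading them first if the checkpoint lives in remote storage. Because the `Trainer` also saved the tokenizer files into the same checkpoint (see the `tokenizer_config.json` and special tokens files in the logs above), you can restore the tokenizer from the same path. A minimal sketch, reusing `checkpoint` from the previous cell:\n",
"\n",
"```python\n",
"from transformers import AutoTokenizer\n",
"\n",
"with checkpoint.as_directory() as checkpoint_dir:\n",
"    checkpoint_path = os.path.join(checkpoint_dir, \"checkpoint\")\n",
"    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)\n",
"```"
]
},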
{
"cell_type": "markdown",
"metadata": {
"id": "tgV2xKfFhYbc"
},
"source": [
"You can now upload the result of the training to the Hub. Execute this instruction:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XSkfJe3nhYbc",
"tags": [
"remove-cell-ci"
]
},
"outputs": [],
"source": [
"model.push_to_hub()"
]
},
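{
"cell_type": "markdown",
"metadata": {},
"source": [
"`model.push_to_hub()` uploads only the model weights and configuration. If you also restored the tokenizer from the checkpoint (as in the earlier sketch), you can push it to the same repository so that others can load both with a single identifier. The repository name below is a placeholder:\n",
"\n",
"```python\n",
"# \"your-username/the-name-you-picked\" is a placeholder repository name.\n",
"tokenizer.push_to_hub(\"your-username/the-name-you-picked\")\n",
"```"
]
},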
{
"cell_type": "markdown",
"metadata": {
"id": "UL-Boc4dhYbc"
},
"source": [
"You can now share this model. Others can load it with the identifier `\"your-username/the-name-you-picked\"`. For example:\n",
"\n",
"```python\n",
"from transformers import AutoModelForSequenceClassification\n",
"\n",
"model = AutoModelForSequenceClassification.from_pretrained(\"sgugger/my-awesome-model\")\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## See also\n",
"\n",
"* {doc}`Ray Train Examples <../../examples>` for more use cases\n",
"* {ref}`Ray Train User Guides ` for how-to guides\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "huggingface_text_classification.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orphan": true,
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}