{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# GPT-J-6B Fine-Tuning with Ray Train and DeepSpeed\n", "\n", "\n", " \"try-anyscale-quickstart\"\n", "\n", "

\n", "\n", "This example showcases how to use Ray Train for **GPT-J fine-tuning**. GPT-J is a GPT-2-like causal language model trained on the Pile dataset. This particular model has 6 billion parameters. For more information, see [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj).\n", "\n", "This example uses the Ray Train 🤗 Transformers integration and a pre-trained model from the Hugging Face Hub. Note that this example is adaptable to other similar models.\n", "\n", "This is an advanced example that focuses on the performance and distributed computing aspects of Ray Train. For a beginner-friendly introduction to the Ray Train 🤗 Transformers integration, see {ref}`Basic Example for HuggingFace Transformers `.\n", "\n", "Read [Ray Train Key Concepts](train-key-concepts) and [Ray Data Integration User Guides](data-ingest-torch) before starting this example.\n", "\n", "```{note}\n", "To run this example, make sure your Ray cluster has access to at least one GPU with 16 or more GBs of memory. The required amount of memory depends on the model. This notebook is tested with 16 g4dn.4xlarge instances (including the head node).\n", "```\n", "\n", "This notebook has the following steps:\n", "1. [Set up Ray](#gptj-setup)\n", "2. [Load the dataset](#gptj-load)\n", "3. [Preprocess the dataset with Ray Data](#gptj-preprocess)\n", "4. [Run the training with Ray Train](#gptj-train)\n", "5. [Generate text from prompt](#gptj-predict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Uncomment and run the following line in order to install all the necessary dependencies (this notebook was tested with `accelerate=0.18.0`, `transformers==4.26.0`, `deepspeed==0.12.3`):" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "! pip install -q \"datasets\" \"evaluate\" \"accelerate==0.18.0\" \"transformers==4.26.0\" \"torch>=1.12.0\" \"deepspeed==0.12.3\"" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import os" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(gptj-setup)=\n", "## Set up Ray\n", "\n", "First, let's set some global variables. We will use 16 workers, each being assigned 1 GPU and 8 CPUs." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "tags": [] }, "outputs": [], "source": [ "model_name = \"EleutherAI/gpt-j-6B\"\n", "use_gpu = True\n", "num_workers = 16\n", "cpus_per_worker = 8" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will use `ray.init()` to initialize a local cluster. By default, this cluster will be comprised of only the machine you are running this notebook on. You can also run this notebook on an Anyscale cluster.\n", "\n", "We define a {ref}`runtime environment ` to ensure that the Ray workers have access to all the necessary packages. You can omit the `runtime_env` argument if you have all of the packages already installed on each node in your cluster." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "import ray\n", "\n", "ray.init(\n", " runtime_env={\n", " \"pip\": [\n", " \"datasets\",\n", " \"evaluate\",\n", " # The latest combination accelerate==0.25.0, transformers==4.36.0, deepspeed==0.12.4\n", " # has issues with DeepSpeed process group initialization,\n", " # and will result in a batch_size validation problem.\n", " # TODO(ml-team): get rid of the pins once the issue is fixed.\n", " \"accelerate==0.18.0\",\n", " \"transformers==4.26.0\",\n", " \"torch>=1.12.0\",\n", " \"deepspeed==0.12.3\",\n", " ],\n", " },\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-cell" ] }, "outputs": [], "source": [ "# THIS SHOULD BE HIDDEN IN DOCS AND ONLY RAN IN CI\n", "# Download the model from our S3 mirror as it's faster\n", "\n", "import ray\n", "import subprocess\n", "import ray.util.scheduling_strategies\n", "\n", "\n", "def force_on_node(node_id: str, remote_func_or_actor_class):\n", " scheduling_strategy = ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(\n", " node_id=node_id, soft=False\n", " )\n", " options = {\"scheduling_strategy\": scheduling_strategy}\n", " return remote_func_or_actor_class.options(**options)\n", "\n", "\n", "def run_on_every_node(remote_func_or_actor_class, **remote_kwargs):\n", " refs = []\n", " for node in ray.nodes():\n", " if node[\"Alive\"] and node[\"Resources\"].get(\"GPU\", None):\n", " refs.append(\n", " force_on_node(node[\"NodeID\"], remote_func_or_actor_class).remote(\n", " **remote_kwargs\n", " )\n", " )\n", " return ray.get(refs)\n", "\n", "\n", "@ray.remote(num_gpus=1)\n", "def download_model():\n", " from transformers.utils.hub import TRANSFORMERS_CACHE\n", "\n", " path = os.path.expanduser(\n", " os.path.join(TRANSFORMERS_CACHE, \"models--EleutherAI--gpt-j-6B\")\n", " )\n", " subprocess.run([\"mkdir\", \"-p\", os.path.join(path, \"snapshots\", \"main\")])\n", " subprocess.run([\"mkdir\", \"-p\", os.path.join(path, \"refs\")])\n", " if os.path.exists(os.path.join(path, \"refs\", \"main\")):\n", " return\n", " subprocess.run(\n", " [\n", " \"aws\",\n", " \"s3\",\n", " \"sync\",\n", " \"--no-sign-request\",\n", " \"s3://large-dl-models-mirror/models--EleutherAI--gpt-j-6B/main/\",\n", " os.path.join(path, \"snapshots\", \"main\"),\n", " ]\n", " )\n", " with open(os.path.join(path, \"snapshots\", \"main\", \"hash\"), \"r\") as f:\n", " f_hash = f.read().strip()\n", " with open(os.path.join(path, \"refs\", \"main\"), \"w\") as f:\n", " f.write(f_hash)\n", " os.rename(\n", " os.path.join(path, \"snapshots\", \"main\"), os.path.join(path, \"snapshots\", f_hash)\n", " )\n", "\n", "\n", "_ = run_on_every_node(download_model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(gptj-load)=\n", "## Loading the dataset\n", "\n", "We will be fine-tuning the model on the [`tiny_shakespeare` dataset](https://huggingface.co/datasets/tiny_shakespeare), comprised of 40,000 lines of Shakespeare from a variety of Shakespeare's plays. The aim will be to make the GPT-J model better at generating text in the style of Shakespeare." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "print(\"Loading tiny_shakespeare dataset\")\n", "current_dataset = load_dataset(\"tiny_shakespeare\")\n", "current_dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will use [Ray Data](data) for distributed preprocessing and data ingestion. We can easily convert the dataset obtained from Hugging Face Hub to Ray Data by using {meth}`ray.data.from_huggingface`." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'train': MaterializedDataset(num_blocks=1, num_rows=1, schema={text: string}),\n", " 'validation': MaterializedDataset(num_blocks=1, num_rows=1, schema={text: string})}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import ray.data\n", "\n", "ray_datasets = {\n", " \"train\": ray.data.from_huggingface(current_dataset[\"train\"]),\n", " \"validation\": ray.data.from_huggingface(current_dataset[\"validation\"]),\n", "}\n", "\n", "ray_datasets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(gptj-preprocess)=\n", "Note that the dataset is represented by a single line of large string, and needs some preprocessing. To do this, use the {meth}`~ray.data.Dataset.map_batches` API to apply transformation functions to batches of data.\n", "\n", "The `split_text` function takes the single string and splits it into separate lines, removing empty lines and character names ending with ':' (eg. 'ROMEO:'). The `tokenize` function takes the lines and tokenizes them using the 🤗 Tokenizer associated with the model, ensuring each entry has the same length (`block_size`) by padding and truncating. This preprocessing is necessary for training.\n", "\n", "```{note}\n", "This preprocessing can be done in other ways. 
A common pattern is to tokenize first, and then split the obtained tokens into equally-sized blocks.\n", "```" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "tags": [] }, "outputs": [], "source": [ "block_size = 512" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'train': MapBatches(tokenize)\n", " +- MapBatches(split_text)\n", " +- Dataset(num_blocks=1, num_rows=1, schema={text: string}),\n", " 'validation': MapBatches(tokenize)\n", " +- MapBatches(split_text)\n", " +- Dataset(num_blocks=1, num_rows=1, schema={text: string})}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from transformers import AutoTokenizer\n", "\n", "\n", "def split_text(batch: pd.DataFrame) -> pd.DataFrame:\n", " text = list(batch[\"text\"])\n", " flat_text = \"\".join(text)\n", " split_text = [\n", " x.strip()\n", " for x in flat_text.split(\"\\n\")\n", " if x.strip() and not x.strip()[-1] == \":\"\n", " ]\n", " return pd.DataFrame(split_text, columns=[\"text\"])\n", "\n", "\n", "def tokenize(batch: pd.DataFrame) -> dict:\n", " tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)\n", " tokenizer.pad_token = tokenizer.eos_token\n", " ret = tokenizer(\n", " list(batch[\"text\"]),\n", " truncation=True,\n", " max_length=block_size,\n", " padding=\"max_length\",\n", " return_tensors=\"np\",\n", " )\n", " ret[\"labels\"] = ret[\"input_ids\"].copy()\n", " return dict(ret)\n", "\n", "\n", "processed_datasets = {\n", " key: (\n", " ds.map_batches(split_text, batch_format=\"pandas\")\n", " .map_batches(tokenize, batch_format=\"pandas\")\n", " )\n", " for key, ds in ray_datasets.items()\n", "}\n", "processed_datasets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(gptj-train)=\n", "### Fine-tuning the model with Ray Train\n", "\n", "Configure Ray Train's {class}`~ray.train.torch.TorchTrainer` to perform distributed fine-tuning of the model. Specify a `train_loop_per_worker` function, which defines the training logic to be distributed by Ray using Distributed Data Parallelism, which uses the PyTorch Distributed backend internally. Each worker has its own copy of the model, but operates on different data. At the end of each step, all the workers sync gradients.\n", "\n", "Because GPT-J is a relatively large model, it may not be possible to fit it on smaller GPU types (<=16 GB GRAM). To deal with that issue, this example uses [DeepSpeed](https://github.com/microsoft/DeepSpeed), a library to optimize the training process and to offload and partition optimizer and parameter states, reducing GRAM usage. Furthermore, DeepSpeed ZeRO Stage 3 can load large models without running out of memory.\n", "\n", "🤗 Transformers and Ray Train's {ref}`integrations ` allow you to easily configure and use DDP and DeepSpeed. All you need to do is specify the DeepSpeed configuration in the [`TrainingArguments`](https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments) object.\n", "\n", "```{tip}\n", "There are many DeepSpeed settings that allow you to trade-off speed for memory usage. The settings used below are tailored to the cluster setup used (16 g4dn.4xlarge nodes) and per device batch size of 16. Some things to keep in mind:\n", "- If your GPUs support bfloat16, use that instead of float16 mixed precision to get better performance and prevent overflows. 
Replace `fp16=True` with `bf16=True` in `TrainingArguments`.\n", "- If you are running out of GPU memory (GRAM): try reducing the batch size (defined in the cell below the next one) or setting `\"overlap_comm\": False` in the DeepSpeed config.\n", "- If you are running out of RAM: add more nodes to your cluster, use nodes with more RAM, set `\"pin_memory\": False` in the DeepSpeed config, reduce the batch size, or remove `\"offload_param\"` from the DeepSpeed config (if present).\n", "\n", "For more information on DeepSpeed configuration, refer to [Hugging Face documentation](https://huggingface.co/docs/transformers/main_classes/deepspeed) and [DeepSpeed documentation](https://www.deepspeed.ai/docs/config-json/).\n", "\n", "Additionally, if you prefer a lower-level API, the logic below can be expressed as an [Accelerate training loop](https://github.com/huggingface/accelerate/blob/main/examples/by_feature/deepspeed_with_config_support.py) distributed by a Ray Train {class}`~ray.train.torch.torch_trainer.TorchTrainer`.\n", "```\n", "\n", "#### Training speed\n", "\n", "As this example uses data parallelism, each worker operates on its own shard of the data. The batch size set in `train_ds.iter_torch_batches` is the **per device batch size** (per worker batch size). By changing the number of workers, you can change the **effective batch size** and thus the time needed for training to complete. Calculate the effective batch size as `per device batch size * number of workers * number of gradient accumulation steps`; the snippet at the end of this section spells out this arithmetic. As you add more workers, the effective batch size rises and thus less time is needed to complete a full epoch. While the speedup is not exactly linear due to extra communication overheads, in many cases it can be close to linear.\n", "\n", "The preprocessed dataset has 1348 examples. We have set the per device batch size to 16.\n", "\n", "* With 16 g4dn.4xlarge nodes, the effective batch size was 256, which equals 85 steps per epoch. One epoch took **~2440 seconds** (including initialization time).\n", "\n", "* With 32 g4dn.4xlarge nodes, the effective batch size was 512, which equals 43 steps per epoch. One epoch took **~1280 seconds** (including initialization time)."
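, "\n",
"As a rough sanity check of that arithmetic, the following sketch recomputes the effective batch size used in this example (the variable names are illustrative and are not part of the training code):\n", "\n",
"```python\n",
"# Effective batch size = per-device batch size * number of workers\n",
"#                        * gradient accumulation steps.\n",
"per_device_batch_size = 16\n",
"num_workers = 16  # 32 in the second configuration above\n",
"gradient_accumulation_steps = 1  # this example does not accumulate gradients\n",
"\n",
"effective_batch_size = (\n",
"    per_device_batch_size * num_workers * gradient_accumulation_steps\n",
")\n",
"print(effective_batch_size)  # 256 with 16 workers, 512 with 32 workers\n",
"```"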
] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "import evaluate\n", "import torch\n", "from transformers import (\n", " Trainer,\n", " TrainingArguments,\n", " GPTJForCausalLM,\n", " AutoTokenizer,\n", " default_data_collator,\n", ")\n", "from transformers.utils.logging import disable_progress_bar, enable_progress_bar\n", "\n", "from ray import train\n", "from ray.train.huggingface.transformers import prepare_trainer, RayTrainReportCallback\n", "\n", "\n", "def train_func(config):\n", " # Use the actual number of CPUs assigned by Ray\n", " os.environ[\"OMP_NUM_THREADS\"] = str(\n", " train.get_context().get_trial_resources().bundles[-1].get(\"CPU\", 1)\n", " )\n", " # Enable tf32 for better performance\n", " torch.backends.cuda.matmul.allow_tf32 = True\n", "\n", " batch_size = config.get(\"batch_size\", 4)\n", " epochs = config.get(\"epochs\", 2)\n", " warmup_steps = config.get(\"warmup_steps\", 0)\n", " learning_rate = config.get(\"learning_rate\", 0.00002)\n", " weight_decay = config.get(\"weight_decay\", 0.01)\n", " steps_per_epoch = config.get(\"steps_per_epoch\")\n", "\n", " deepspeed = {\n", " \"fp16\": {\n", " \"enabled\": \"auto\",\n", " \"initial_scale_power\": 8,\n", " \"hysteresis\": 4,\n", " \"consecutive_hysteresis\": True,\n", " },\n", " \"bf16\": {\"enabled\": \"auto\"},\n", " \"optimizer\": {\n", " \"type\": \"AdamW\",\n", " \"params\": {\n", " \"lr\": \"auto\",\n", " \"betas\": \"auto\",\n", " \"eps\": \"auto\",\n", " },\n", " },\n", " \"zero_optimization\": {\n", " \"stage\": 3,\n", " \"offload_optimizer\": {\n", " \"device\": \"cpu\",\n", " \"pin_memory\": True,\n", " },\n", " \"overlap_comm\": True,\n", " \"contiguous_gradients\": True,\n", " \"reduce_bucket_size\": \"auto\",\n", " \"stage3_prefetch_bucket_size\": \"auto\",\n", " \"stage3_param_persistence_threshold\": \"auto\",\n", " \"gather_16bit_weights_on_model_save\": True,\n", " \"round_robin_gradients\": True,\n", " },\n", " \"gradient_accumulation_steps\": \"auto\",\n", " \"gradient_clipping\": \"auto\",\n", " \"steps_per_print\": 10,\n", " \"train_batch_size\": \"auto\",\n", " \"train_micro_batch_size_per_gpu\": \"auto\",\n", " \"wall_clock_breakdown\": False,\n", " }\n", "\n", " print(\"Preparing training arguments\")\n", " training_args = TrainingArguments(\n", " \"output\",\n", " logging_steps=1,\n", " save_strategy=\"steps\",\n", " save_steps=steps_per_epoch,\n", " max_steps=steps_per_epoch * epochs,\n", " per_device_train_batch_size=batch_size,\n", " gradient_accumulation_steps=1,\n", " learning_rate=learning_rate,\n", " weight_decay=weight_decay,\n", " warmup_steps=warmup_steps,\n", " label_names=[\"input_ids\", \"attention_mask\"],\n", " push_to_hub=False,\n", " report_to=\"none\",\n", " disable_tqdm=True, # declutter the output a little\n", " fp16=True,\n", " gradient_checkpointing=True,\n", " deepspeed=deepspeed,\n", " )\n", " disable_progress_bar()\n", "\n", " tokenizer = AutoTokenizer.from_pretrained(model_name)\n", " tokenizer.pad_token = tokenizer.eos_token\n", "\n", " print(\"Loading model\")\n", "\n", " model = GPTJForCausalLM.from_pretrained(model_name, use_cache=False)\n", " model.resize_token_embeddings(len(tokenizer))\n", "\n", " print(\"Model loaded\")\n", "\n", " enable_progress_bar()\n", "\n", " metric = evaluate.load(\"accuracy\")\n", "\n", " train_ds = train.get_dataset_shard(\"train\")\n", " eval_ds = train.get_dataset_shard(\"validation\")\n", "\n", " train_ds_iterable = train_ds.iter_torch_batches(\n", " 
batch_size=batch_size,\n", " local_shuffle_buffer_size=train.get_context().get_world_size() * batch_size,\n", " )\n", " eval_ds_iterable = eval_ds.iter_torch_batches(batch_size=batch_size)\n", "\n", " def compute_metrics(eval_pred):\n", " logits, labels = eval_pred\n", " predictions = np.argmax(logits, axis=-1)\n", " return metric.compute(predictions=predictions, references=labels)\n", "\n", " trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=train_ds_iterable,\n", " eval_dataset=eval_ds_iterable,\n", " compute_metrics=compute_metrics,\n", " tokenizer=tokenizer,\n", " data_collator=default_data_collator,\n", " )\n", "\n", " # Add callback to report checkpoints to Ray Train\n", " trainer.add_callback(RayTrainReportCallback())\n", " trainer = prepare_trainer(trainer)\n", " trainer.train()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After defining the training function, instantiate the {class}`~ray.train.torch.TorchTrainer`. Aside from the function, set the `scaling_config` to control the number of workers and amount of resources to use, and `datasets`(the preprocessed Ray Datasets) to use for training and evaluation.\n", "\n", "```{note}\n", "Running with multiple nodes necessitates the persistence of checkpoints\n", "and other outputs to some external storage for access after training has completed.\n", "**You should set up cloud storage or NFS, then replace `storage_path` with your own cloud bucket URI or NFS path.**\n", "\n", "See {ref}`Configuration and Persistent Storage` for more details.\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "storage_path = \"s3://your-bucket-here\" # TODO: Set up cloud storage\n", "# storage_path=\"/mnt/path/to/nfs\" # TODO: Alternatively, set up NFS" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "import os, re\n", "\n", "artifact_storage = os.environ.get(\"ANYSCALE_ARTIFACT_STORAGE\", \"artifact_storage\")\n", "user_name = re.sub(r\"\\s+\", \"__\", os.environ.get(\"ANYSCALE_USERNAME\", \"user\"))\n", "storage_path = f\"{artifact_storage}/{user_name}/gptj-deepspeed-finetune\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "batch_size = 16\n", "train_ds_size = processed_datasets[\"train\"].count()\n", "steps_per_epoch = train_ds_size // (batch_size * num_workers)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "# SMOKE TEST SETTINGS FOR CI\n", "steps_per_epoch = 10\n", "num_workers = 8\n", "batch_size = 1" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "tags": [] }, "outputs": [], "source": [ "from ray.train.torch import TorchTrainer\n", "from ray.train import RunConfig, ScalingConfig\n", "\n", "trainer = TorchTrainer(\n", " train_loop_per_worker=train_func,\n", " train_loop_config={\n", " \"epochs\": 1,\n", " \"batch_size\": batch_size, # per device\n", " \"steps_per_epoch\": steps_per_epoch,\n", " },\n", " scaling_config=ScalingConfig(\n", " num_workers=num_workers,\n", " use_gpu=use_gpu,\n", " resources_per_worker={\"GPU\": 1, \"CPU\": cpus_per_worker},\n", " ),\n", " datasets=processed_datasets,\n", " run_config=RunConfig(storage_path=storage_path),\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, call the {meth}`~ray.train.torch.TorchTrainer.fit` method to start training with Ray Train. 
Save the {class}`~ray.train.Result` object to a variable to access metrics and checkpoints." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "tags": [ "hide-output" ] }, "outputs": [ { "data": { "text/html": [ "
\n", "
\n", "
\n", "

Tune Status

\n", " \n", "\n", "\n", "\n", "\n", "\n", "
Current time:2023-08-18 18:54:02
Running for: 00:44:50.37
Memory: 10.2/62.0 GiB
\n", "
\n", "
\n", "
\n", "

System Info

\n", " Using FIFO scheduling algorithm.
Logical resource usage: 129.0/256 CPUs, 16.0/16 GPUs\n", "
\n", " \n", "
\n", "
\n", "
\n", "

Trial Status

\n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "
Trial name status loc iter total time (s) loss learning_rate epoch
TorchTrainer_01ea5_00000TERMINATED10.0.60.59:8839 1 2663.78 0.069 2.38095e-07 1
\n", "
\n", "
\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(TrainTrainable pid=8839, ip=10.0.60.59)\u001b[0m 2023-08-18 18:09:16.315108: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA\n", "\u001b[2m\u001b[36m(TrainTrainable pid=8839, ip=10.0.60.59)\u001b[0m To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "\u001b[2m\u001b[36m(TrainTrainable pid=8839, ip=10.0.60.59)\u001b[0m 2023-08-18 18:09:16.462944: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", "\u001b[2m\u001b[36m(TrainTrainable pid=8839, ip=10.0.60.59)\u001b[0m 2023-08-18 18:09:17.336229: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n", "\u001b[2m\u001b[36m(TrainTrainable pid=8839, ip=10.0.60.59)\u001b[0m 2023-08-18 18:09:17.336299: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n", "\u001b[2m\u001b[36m(TrainTrainable pid=8839, ip=10.0.60.59)\u001b[0m 2023-08-18 18:09:17.336306: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n", "\u001b[2m\u001b[36m(TrainTrainable pid=8839, ip=10.0.60.59)\u001b[0m --------------------------------------------------------------------------\n", "\u001b[2m\u001b[36m(TrainTrainable pid=8839, ip=10.0.60.59)\u001b[0m Aim collects anonymous usage analytics. 
\n", "\u001b[2m\u001b[36m(TrainTrainable pid=8839, ip=10.0.60.59)\u001b[0m Read how to opt-out here: \n", "\u001b[2m\u001b[36m(TrainTrainable pid=8839, ip=10.0.60.59)\u001b[0m https://aimstack.readthedocs.io/en/latest/community/telemetry.html \n", "\u001b[2m\u001b[36m(TrainTrainable pid=8839, ip=10.0.60.59)\u001b[0m --------------------------------------------------------------------------\n", "\u001b[2m\u001b[36m(TrainTrainable pid=8839, ip=10.0.60.59)\u001b[0m comet_ml is installed but `COMET_API_KEY` is not set.\n", "\u001b[2m\u001b[36m(TorchTrainer pid=8839, ip=10.0.60.59)\u001b[0m Starting distributed worker processes: ['8911 (10.0.60.59)', '36675 (10.0.13.222)', '8880 (10.0.63.99)', '8867 (10.0.49.236)', '49329 (10.0.40.253)', '8845 (10.0.18.195)', '36249 (10.0.11.26)', '8858 (10.0.0.119)', '8857 (10.0.44.114)', '8885 (10.0.47.209)', '36311 (10.0.27.53)', '8830 (10.0.30.35)', '8875 (10.0.0.80)', '8851 (10.0.43.240)', '9631 (10.0.57.153)', '36262 (10.0.52.191)']\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m Setting up process group for: env:// [rank=0, world_size=16]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m 2023-08-18 18:09:25.209122: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m 2023-08-18 18:09:25.358493: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m 2023-08-18 18:09:26.095161: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m 2023-08-18 18:09:26.095229: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m 2023-08-18 18:09:26.095236: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. 
If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n", "\u001b[2m\u001b[36m(SplitCoordinator pid=8980, ip=10.0.60.59)\u001b[0m Auto configuring locality_with_output=['6002ded0aaa53ce9a0351d22a72b344ef411a422919132f41d9f937a', 'd3bbd390b6fe73f26202f96d75998946cf3e8b457528d426db0c6e07', 'fe6aaf54317ee630a02d23e0d49581b57b5cd51316eaf769e28bb045', 'f7de4694a4f764c05a9c51a6a4bd40ac33f3fced3b25127b25cd4ac3', '42866a2fba4ce2ab4b6645c4d731d486b762e2b23ac24cafccba7096', '8a7272830662c7e756a656de0a9b433a3a1f9b990768f692b6fe11a7', 'bba62e8b57552509c62a6b6b7fd67c1a2280b9d81b3d9c41eb4d1b9b', 'b40764f303538c24bc439106f2e7b2144d382bfed6c9fdec15ab828e', 'd1de4d4b6d44eff93857026df4ef0f70e24e3dc91e15d87015f2ed32', '4d6a9dc1aa7bfc80cb73d9f66f4e28041807f12769391f5643bce143', '8bcc7235f459b61be21fe158d0bae4fef2ec6de013ec60e7aaf7897a', '73c50b995811afa0ece70fd3d4466b7fd0dc85a97d6807128b2c47da', '03bf3d374a9f857b1cd1aebdbe028208f7904b077fb151790e03e9fe', '9f7fc101a7d6b3e17b72e57ca1c92f91d13aa385a6740f99d58ec016', '867844d104a8e9351a1dcc8bbd61d99906a8dc5b53e220c2ae2efbe1', '7677b344c59d6b30c3db451f48e346d61bb60cc798e5567aa4e0a1ea']\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m comet_ml is installed but `COMET_API_KEY` is not set.\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8867, ip=10.0.49.236)\u001b[0m --------------------------------------------------------------------------\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8867, ip=10.0.49.236)\u001b[0m Aim collects anonymous usage analytics. \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8867, ip=10.0.49.236)\u001b[0m Read how to opt-out here: \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8867, ip=10.0.49.236)\u001b[0m https://aimstack.readthedocs.io/en/latest/community/telemetry.html \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8867, ip=10.0.49.236)\u001b[0m --------------------------------------------------------------------------\n", "\u001b[2m\u001b[36m(SplitCoordinator pid=8980, ip=10.0.60.59)\u001b[0m 2023-08-18 18:09:26.534936: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA\u001b[32m [repeated 16x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)\u001b[0m\n", "\u001b[2m\u001b[36m(SplitCoordinator pid=8980, ip=10.0.60.59)\u001b[0m To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(SplitCoordinator pid=8980, ip=10.0.60.59)\u001b[0m 2023-08-18 18:09:26.667181: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\u001b[32m [repeated 16x across cluster]\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=8885, ip=10.0.47.209)\u001b[0m Preparing training arguments\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m Loading model\n", "\u001b[2m\u001b[1m\u001b[36m(autoscaler +3m53s)\u001b[0m [workspace snapshot] New snapshot created successfully (size: 172.52 MB).\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:12:01,852] [INFO] [partition_parameters.py:454:__exit__] finished initializing model with 6.05B parameters\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m Preparing training arguments\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8880, ip=10.0.63.99)\u001b[0m Loading model\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8851, ip=10.0.43.240)\u001b[0m Model loaded\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Downloading builder script: 100%|██████████| 4.20k/4.20k [00:00<00:00, 22.1MB/s]\n", "\u001b[2m\u001b[36m(SplitCoordinator pid=8980, ip=10.0.60.59)\u001b[0m 2023-08-18 18:09:27.424862: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\u001b[32m [repeated 32x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(SplitCoordinator pid=8980, ip=10.0.60.59)\u001b[0m 2023-08-18 18:09:27.424869: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36311, ip=10.0.27.53)\u001b[0m comet_ml is installed but `COMET_API_KEY` is not set.\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36262, ip=10.0.52.191)\u001b[0m --------------------------------------------------------------------------\u001b[32m [repeated 26x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36262, ip=10.0.52.191)\u001b[0m Aim collects anonymous usage analytics. 
\u001b[32m [repeated 13x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36262, ip=10.0.52.191)\u001b[0m Read how to opt-out here: \u001b[32m [repeated 13x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36262, ip=10.0.52.191)\u001b[0m https://aimstack.readthedocs.io/en/latest/community/telemetry.html \u001b[32m [repeated 13x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m max_steps is given, it will override any value given in num_train_epochs\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m Using cuda_amp half precision backend\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:12:36,256] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed info: version=0.9.2, git-hash=unknown, git-branch=unknown\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:12:36,373] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: False\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=8858, ip=10.0.0.119)\u001b[0m Using /home/ray/.cache/torch_extensions/py39_cu118 as PyTorch extensions root...\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8858, ip=10.0.0.119)\u001b[0m Creating extension directory /home/ray/.cache/torch_extensions/py39_cu118/cpu_adam...\n", "Downloading builder script: 100%|██████████| 4.20k/4.20k [00:00<00:00, 19.8MB/s]\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8857, ip=10.0.44.114)\u001b[0m max_steps is given, it will override any value given in num_train_epochs\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8857, ip=10.0.44.114)\u001b[0m Using cuda_amp half precision backend\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m Detected CUDA files, patching ldflags\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m Emitting ninja build file /home/ray/.cache/torch_extensions/py39_cu118/cpu_adam/build.ninja...\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m Building extension module cpu_adam...\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m Allowing ninja to set a default number of workers... 
(overridable by setting the environment variable MAX_JOBS=N)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=8858, ip=10.0.0.119)\u001b[0m [1/3] /usr/local/cuda/bin/nvcc -DTORCH_EXTENSION_NAME=cpu_adam -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\\\"_gcc\\\" -DPYBIND11_STDLIB=\\\"_libstdcpp\\\" -DPYBIND11_BUILD_ABI=\\\"_cxxabi1011\\\" -I/home/ray/anaconda3/lib/python3.9/site-packages/deepspeed/ops/csrc/includes -I/usr/local/cuda/include -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include/TH -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /home/ray/anaconda3/include/python3.9 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_75,code=compute_75 -gencode=arch=compute_75,code=sm_75 --compiler-options '-fPIC' -O3 --use_fast_math -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_75,code=compute_75 -c /home/ray/anaconda3/lib/python3.9/site-packages/deepspeed/ops/csrc/common/custom_cuda_kernel.cu -o custom_cuda_kernel.cuda.o \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8830, ip=10.0.30.35)\u001b[0m Model loaded\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m [2/3] c++ -MMD -MF cpu_adam.o.d -DTORCH_EXTENSION_NAME=cpu_adam -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\\\"_gcc\\\" -DPYBIND11_STDLIB=\\\"_libstdcpp\\\" -DPYBIND11_BUILD_ABI=\\\"_cxxabi1011\\\" -I/home/ray/anaconda3/lib/python3.9/site-packages/deepspeed/ops/csrc/includes -I/usr/local/cuda/include -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include/TH -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /home/ray/anaconda3/include/python3.9 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -O3 -std=c++14 -g -Wno-reorder -L/usr/local/cuda/lib64 -lcudart -lcublas -g -march=native -fopenmp -D__AVX512__ -D__ENABLE_CUDA__ -c /home/ray/anaconda3/lib/python3.9/site-packages/deepspeed/ops/csrc/adam/cpu_adam.cpp -o cpu_adam.o \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m [1/3] /usr/local/cuda/bin/nvcc -DTORCH_EXTENSION_NAME=cpu_adam -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\\\"_gcc\\\" -DPYBIND11_STDLIB=\\\"_libstdcpp\\\" -DPYBIND11_BUILD_ABI=\\\"_cxxabi1011\\\" -I/home/ray/anaconda3/lib/python3.9/site-packages/deepspeed/ops/csrc/includes -I/usr/local/cuda/include -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include/TH -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /home/ray/anaconda3/include/python3.9 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ 
-D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_75,code=compute_75 -gencode=arch=compute_75,code=sm_75 --compiler-options '-fPIC' -O3 --use_fast_math -std=c++14 -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_75,code=compute_75 -c /home/ray/anaconda3/lib/python3.9/site-packages/deepspeed/ops/csrc/common/custom_cuda_kernel.cu -o custom_cuda_kernel.cuda.o \u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m [3/3] c++ cpu_adam.o custom_cuda_kernel.cuda.o -shared -lcurand -L/home/ray/anaconda3/lib/python3.9/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/local/cuda/lib64 -lcudart -o cpu_adam.so\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m Time to load cpu_adam op: 31.202290058135986 seconds\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m Loading extension module cpu_adam...\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m Using /home/ray/.cache/torch_extensions/py39_cu118 as PyTorch extensions root...\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m Creating extension directory /home/ray/.cache/torch_extensions/py39_cu118/cpu_adam...\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m Detected CUDA files, patching ldflags\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m Emitting ninja build file /home/ray/.cache/torch_extensions/py39_cu118/cpu_adam/build.ninja...\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m Building extension module cpu_adam...\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m Allowing ninja to set a default number of workers... 
(overridable by setting the environment variable MAX_JOBS=N)\u001b[32m [repeated 15x across cluster]\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m Adam Optimizer #0 is created with AVX512 arithmetic capability.\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m Config: alpha=0.000020, betas=(0.900000, 0.999000), weight_decay=0.000000, adam_w=1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m Building extension module utils...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:13,196] [INFO] [logging.py:96:log_dist] [Rank 0] Using DeepSpeed Optimizer param name adamw as basic optimizer\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:13,212] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Basic Optimizer = DeepSpeedCPUAdam\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:13,212] [INFO] [utils.py:54:is_zero_supported_optimizer] Checking ZeRO support for optimizer=DeepSpeedCPUAdam type=\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:13,212] [INFO] [logging.py:96:log_dist] [Rank 0] Creating fp16 ZeRO stage 3 optimizer, MiCS is enabled False, Hierarchical params gather False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:13,212] [INFO] [logging.py:96:log_dist] [Rank 0] Creating torch.float16 ZeRO stage 3 optimizer\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:13,520] [INFO] [utils.py:785:see_memory_usage] Stage 3 initialize beginning\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:13,521] [INFO] [utils.py:786:see_memory_usage] MA 0.11 GB Max_MA 1.26 GB CA 1.54 GB Max_CA 2 GB \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:13,521] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 8.96 GB, percent = 14.4%\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:13,523] [INFO] [stage3.py:113:__init__] Reduce bucket size 16777216\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:13,523] [INFO] [stage3.py:114:__init__] Prefetch bucket size 15099494\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m [1/2] c++ -MMD -MF flatten_unflatten.o.d -DTORCH_EXTENSION_NAME=utils -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\\\"_gcc\\\" -DPYBIND11_STDLIB=\\\"_libstdcpp\\\" -DPYBIND11_BUILD_ABI=\\\"_cxxabi1011\\\" -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include/TH -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include/THC -isystem /home/ray/anaconda3/include/python3.9 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /home/ray/anaconda3/lib/python3.9/site-packages/deepspeed/ops/csrc/utils/flatten_unflatten.cpp -o flatten_unflatten.o \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m [2/3] c++ -MMD -MF cpu_adam.o.d -DTORCH_EXTENSION_NAME=cpu_adam -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\\\"_gcc\\\" 
-DPYBIND11_STDLIB=\\\"_libstdcpp\\\" -DPYBIND11_BUILD_ABI=\\\"_cxxabi1011\\\" -I/home/ray/anaconda3/lib/python3.9/site-packages/deepspeed/ops/csrc/includes -I/usr/local/cuda/include -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include/TH -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /home/ray/anaconda3/include/python3.9 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -O3 -std=c++14 -g -Wno-reorder -L/usr/local/cuda/lib64 -lcudart -lcublas -g -march=native -fopenmp -D__AVX512__ -D__ENABLE_CUDA__ -c /home/ray/anaconda3/lib/python3.9/site-packages/deepspeed/ops/csrc/adam/cpu_adam.cpp -o cpu_adam.o \u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m [3/3] c++ cpu_adam.o custom_cuda_kernel.cuda.o -shared -lcurand -L/home/ray/anaconda3/lib/python3.9/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/usr/local/cuda/lib64 -lcudart -o cpu_adam.so\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m Time to load cpu_adam op: 34.29589319229126 seconds\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m Adam Optimizer #0 is created with AVX512 arithmetic capability.\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m Config: alpha=0.000020, betas=(0.900000, 0.999000), weight_decay=0.000000, adam_w=1\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m [2/2] c++ flatten_unflatten.o -shared -L/home/ray/anaconda3/lib/python3.9/site-packages/torch/lib -lc10 -ltorch_cpu -ltorch -ltorch_python -o utils.so\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m Time to load utils op: 15.381849527359009 seconds\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m Loading extension module utils...\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m Loading extension module cpu_adam...\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m Using /home/ray/.cache/torch_extensions/py39_cu118 as PyTorch extensions root...\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m Creating extension directory /home/ray/.cache/torch_extensions/py39_cu118/utils...\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m Emitting ninja build file /home/ray/.cache/torch_extensions/py39_cu118/utils/build.ninja...\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m Allowing ninja to set a default number of workers... 
(overridable by setting the environment variable MAX_JOBS=N)\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m Building extension module utils...\u001b[32m [repeated 15x across cluster]\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:29,490] [INFO] [utils.py:785:see_memory_usage] DeepSpeedZeRoOffload initialize [begin]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:29,491] [INFO] [utils.py:786:see_memory_usage] MA 0.11 GB Max_MA 0.11 GB CA 1.54 GB Max_CA 2 GB \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:29,491] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 8.96 GB, percent = 14.5%\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m Parameter Offload: Total persistent parameters: 811008 in 114 params\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:29,763] [INFO] [utils.py:785:see_memory_usage] DeepSpeedZeRoOffload initialize [end]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:29,764] [INFO] [utils.py:786:see_memory_usage] MA 0.11 GB Max_MA 0.11 GB CA 1.54 GB Max_CA 2 GB \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:29,764] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 8.96 GB, percent = 14.5%\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:30,012] [INFO] [utils.py:785:see_memory_usage] Before creating fp16 partitions\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:30,013] [INFO] [utils.py:786:see_memory_usage] MA 0.11 GB Max_MA 0.11 GB CA 1.54 GB Max_CA 2 GB \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:30,013] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 8.96 GB, percent = 14.5%\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m [1/2] c++ -MMD -MF flatten_unflatten.o.d -DTORCH_EXTENSION_NAME=utils -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\\\"_gcc\\\" -DPYBIND11_STDLIB=\\\"_libstdcpp\\\" -DPYBIND11_BUILD_ABI=\\\"_cxxabi1011\\\" -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include/TH -isystem /home/ray/anaconda3/lib/python3.9/site-packages/torch/include/THC -isystem /home/ray/anaconda3/include/python3.9 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /home/ray/anaconda3/lib/python3.9/site-packages/deepspeed/ops/csrc/utils/flatten_unflatten.cpp -o flatten_unflatten.o \u001b[32m [repeated 15x across cluster]\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m Loading extension module utils...\u001b[32m [repeated 15x across cluster]\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m [2/2] c++ flatten_unflatten.o -shared -L/home/ray/anaconda3/lib/python3.9/site-packages/torch/lib -lc10 -ltorch_cpu -ltorch -ltorch_python -o utils.so\u001b[32m [repeated 15x across 
cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m Time to load utils op: 16.94431161880493 seconds\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:31,872] [INFO] [utils.py:785:see_memory_usage] After creating fp16 partitions: 1\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:31,873] [INFO] [utils.py:786:see_memory_usage] MA 0.11 GB Max_MA 0.11 GB CA 1.54 GB Max_CA 2 GB \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:31,873] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 9.98 GB, percent = 16.1%\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:32,120] [INFO] [utils.py:785:see_memory_usage] Before creating fp32 partitions\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:32,121] [INFO] [utils.py:786:see_memory_usage] MA 0.11 GB Max_MA 0.11 GB CA 1.54 GB Max_CA 2 GB \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:32,121] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 9.98 GB, percent = 16.1%\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:32,624] [INFO] [utils.py:785:see_memory_usage] After creating fp32 partitions\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:32,624] [INFO] [utils.py:786:see_memory_usage] MA 0.11 GB Max_MA 0.11 GB CA 1.54 GB Max_CA 2 GB \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:32,625] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 11.39 GB, percent = 18.4%\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:32,870] [INFO] [utils.py:785:see_memory_usage] Before initializing optimizer states\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:32,870] [INFO] [utils.py:786:see_memory_usage] MA 0.11 GB Max_MA 0.11 GB CA 1.54 GB Max_CA 2 GB \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:32,871] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 11.39 GB, percent = 18.4%\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:34,834] [INFO] [utils.py:785:see_memory_usage] After initializing optimizer states\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:34,835] [INFO] [utils.py:786:see_memory_usage] MA 0.11 GB Max_MA 0.11 GB CA 1.54 GB Max_CA 2 GB \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:34,835] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 16.25 GB, percent = 26.2%\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:34,835] [INFO] [stage3.py:392:_setup_for_real_optimizer] optimizer state initialized\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=8830, ip=10.0.30.35)\u001b[0m Using /home/ray/.cache/torch_extensions/py39_cu118 as PyTorch extensions root...\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8830, ip=10.0.30.35)\u001b[0m No modifications detected for re-loaded extension module utils, skipping build step...\n", "\u001b[2m\u001b[36m(RayTrainWorker 
pid=8830, ip=10.0.30.35)\u001b[0m Loading extension module utils...\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=9631, ip=10.0.57.153)\u001b[0m Loading extension module utils...\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=9631, ip=10.0.57.153)\u001b[0m ***** Running training *****\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=9631, ip=10.0.57.153)\u001b[0m Num examples = 10752\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=9631, ip=10.0.57.153)\u001b[0m Num Epochs = 9223372036854775807\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=9631, ip=10.0.57.153)\u001b[0m Instantaneous batch size per device = 8\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=9631, ip=10.0.57.153)\u001b[0m Total train batch size (w. parallel, distributed & accumulation) = 128\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=9631, ip=10.0.57.153)\u001b[0m Gradient Accumulation steps = 1\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=9631, ip=10.0.57.153)\u001b[0m Total optimization steps = 84\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=9631, ip=10.0.57.153)\u001b[0m Number of trainable parameters = 0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=8830, ip=10.0.30.35)\u001b[0m Time to load utils op: 0.0005006790161132812 seconds\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=9631, ip=10.0.57.153)\u001b[0m Time to load utils op: 0.0005137920379638672 seconds\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,692] [INFO] [utils.py:785:see_memory_usage] After initializing ZeRO optimizer\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,693] [INFO] [utils.py:786:see_memory_usage] MA 0.14 GB Max_MA 0.91 GB CA 1.54 GB Max_CA 2 GB \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,693] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 17.3 GB, percent = 27.9%\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,694] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Final Optimizer = adamw\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,694] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed using client callable to create LR scheduler\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,694] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed LR Scheduler = \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,694] [INFO] [logging.py:96:log_dist] [Rank 0] step=0, skipped=0, lr=[2e-05], mom=[[0.9, 0.999]]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,695] [INFO] [config.py:955:print] DeepSpeedEngine configuration:\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,696] [INFO] [config.py:959:print] activation_checkpointing_config {\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"partition_activations\": false, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"contiguous_memory_optimization\": false, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"cpu_checkpointing\": false, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"number_checkpoints\": null, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"synchronize_checkpoint_boundary\": false, \n", 
"\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"profile\": false\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m }\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,696] [INFO] [config.py:959:print] aio_config ................... {'block_size': 1048576, 'queue_depth': 8, 'thread_count': 1, 'single_submit': False, 'overlap_events': True}\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,696] [INFO] [config.py:959:print] amp_enabled .................. False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,696] [INFO] [config.py:959:print] amp_params ................... False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,696] [INFO] [config.py:959:print] autotuning_config ............ {\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"enabled\": false, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"start_step\": null, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"end_step\": null, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"metric_path\": null, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"arg_mappings\": null, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"metric\": \"throughput\", \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"model_info\": null, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"results_dir\": \"autotuning_results\", \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"exps_dir\": \"autotuning_exps\", \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"overwrite\": true, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"fast\": true, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"start_profile_step\": 3, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"end_profile_step\": 5, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"tuner_type\": \"gridsearch\", \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"tuner_early_stopping\": 5, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"tuner_num_trials\": 50, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"model_info_path\": null, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"mp_size\": 1, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"max_train_batch_size\": null, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"min_train_batch_size\": 1, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"max_train_micro_batch_size_per_gpu\": 1.024000e+03, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"min_train_micro_batch_size_per_gpu\": 1, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"num_tuning_micro_batch_sizes\": 3\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m }\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,696] [INFO] [config.py:959:print] bfloat16_enabled ............. 
False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,696] [INFO] [config.py:959:print] checkpoint_parallel_write_pipeline False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,696] [INFO] [config.py:959:print] checkpoint_tag_validation_enabled True\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,696] [INFO] [config.py:959:print] checkpoint_tag_validation_fail False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,696] [INFO] [config.py:959:print] comms_config ................. \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,696] [INFO] [config.py:959:print] communication_data_type ...... None\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,696] [INFO] [config.py:959:print] compression_config ........... {'weight_quantization': {'shared_parameters': {'enabled': False, 'quantizer_kernel': False, 'schedule_offset': 0, 'quantize_groups': 1, 'quantize_verbose': False, 'quantization_type': 'symmetric', 'quantize_weight_in_forward': False, 'rounding': 'nearest', 'fp16_mixed_quantize': False, 'quantize_change_ratio': 0.001}, 'different_groups': {}}, 'activation_quantization': {'shared_parameters': {'enabled': False, 'quantization_type': 'symmetric', 'range_calibration': 'dynamic', 'schedule_offset': 1000}, 'different_groups': {}}, 'sparse_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'row_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'head_pruning': {'shared_parameters': {'enabled': False, 'method': 'topk', 'schedule_offset': 1000}, 'different_groups': {}}, 'channel_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'layer_reduction': {'enabled': False}}\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,696] [INFO] [config.py:959:print] curriculum_enabled_legacy .... False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,696] [INFO] [config.py:959:print] curriculum_params_legacy ..... False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,696] [INFO] [config.py:959:print] data_efficiency_config ....... {'enabled': False, 'seed': 1234, 'data_sampling': {'enabled': False, 'num_epochs': 1000, 'num_workers': 0, 'curriculum_learning': {'enabled': False}}, 'data_routing': {'enabled': False, 'random_ltd': {'enabled': False, 'layer_token_lr_schedule': {'enabled': False}}}}\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,696] [INFO] [config.py:959:print] data_efficiency_enabled ...... False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,696] [INFO] [config.py:959:print] dataloader_drop_last ......... False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] disable_allgather ............ False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] dump_state ................... 
False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] dynamic_loss_scale_args ...... {'init_scale': 256, 'scale_window': 1000, 'delayed_shift': 2, 'min_scale': 1}\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] eigenvalue_enabled ........... False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] eigenvalue_gas_boundary_resolution 1\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] eigenvalue_layer_name ........ bert.encoder.layer\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] eigenvalue_layer_num ......... 0\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] eigenvalue_max_iter .......... 100\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] eigenvalue_stability ......... 1e-06\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] eigenvalue_tol ............... 0.01\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] eigenvalue_verbose ........... False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] elasticity_enabled ........... False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] flops_profiler_config ........ {\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"enabled\": false, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"profile_step\": 1, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"module_depth\": -1, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"top_modules\": 1, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"detailed\": true, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"output_file\": null\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m }\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] fp16_auto_cast ............... False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] fp16_enabled ................. True\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] fp16_master_weights_and_gradients False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] global_rank .................. 0\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] grad_accum_dtype ............. None\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] gradient_accumulation_steps .. 
1\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] gradient_clipping ............ 1.0\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] gradient_predivide_factor .... 1.0\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] hybrid_engine ................ enabled=False max_out_tokens=512 inference_tp_size=1 release_inference_cache=False pin_parameters=True tp_gather_partition_size=8\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] initial_dynamic_scale ........ 256\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] load_universal_checkpoint .... False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] loss_scale ................... 0\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] memory_breakdown ............. False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] mics_hierarchial_params_gather False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] mics_shard_size .............. -1\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] monitor_config ............... tensorboard=TensorBoardConfig(enabled=False, output_path='', job_name='DeepSpeedJobName') wandb=WandbConfig(enabled=False, group=None, team=None, project='deepspeed') csv_monitor=CSVConfig(enabled=False, output_path='', job_name='DeepSpeedJobName') enabled=False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] nebula_config ................ {\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"enabled\": false, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"persistent_storage_path\": null, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"persistent_time_interval\": 100, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"num_of_version_in_retention\": 2, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"enable_nebula_load\": true, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"load_path\": null\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m }\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] optimizer_legacy_fusion ...... False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] optimizer_name ............... adamw\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] optimizer_params ............. 
{'lr': 2e-05, 'betas': [0.9, 0.999], 'eps': 1e-08}\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,697] [INFO] [config.py:959:print] pipeline ..................... {'stages': 'auto', 'partition': 'best', 'seed_layers': False, 'activation_checkpoint_interval': 0}\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,698] [INFO] [config.py:959:print] pld_enabled .................. False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,698] [INFO] [config.py:959:print] pld_params ................... False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,698] [INFO] [config.py:959:print] prescale_gradients ........... False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,698] [INFO] [config.py:959:print] scheduler_name ............... None\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,698] [INFO] [config.py:959:print] scheduler_params ............. None\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,698] [INFO] [config.py:959:print] sparse_attention ............. None\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,698] [INFO] [config.py:959:print] sparse_gradients_enabled ..... False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,698] [INFO] [config.py:959:print] steps_per_print .............. 10\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,698] [INFO] [config.py:959:print] train_batch_size ............. 128\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,698] [INFO] [config.py:959:print] train_micro_batch_size_per_gpu 8\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,698] [INFO] [config.py:959:print] use_node_local_storage ....... False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,698] [INFO] [config.py:959:print] wall_clock_breakdown ......... False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,698] [INFO] [config.py:959:print] world_size ................... 16\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,698] [INFO] [config.py:959:print] zero_allow_untested_optimizer False\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,698] [INFO] [config.py:959:print] zero_config .................. 
stage=3 contiguous_gradients=True reduce_scatter=True reduce_bucket_size=16777216 allgather_partitions=True allgather_bucket_size=500,000,000 overlap_comm=True load_from_fp32_weights=True elastic_checkpoint=False offload_param=DeepSpeedZeroOffloadParamConfig(device='cpu', nvme_path=None, buffer_count=5, buffer_size=100,000,000, max_in_cpu=1,000,000,000, pin_memory=True) offload_optimizer=DeepSpeedZeroOffloadOptimizerConfig(device='cpu', nvme_path=None, buffer_count=4, pin_memory=True, pipeline=False, pipeline_read=False, pipeline_write=False, fast_init=False) sub_group_size=1,000,000,000 cpu_offload_param=None cpu_offload_use_pin_memory=None cpu_offload=None prefetch_bucket_size=15099494 param_persistence_threshold=40960 model_persistence_threshold=sys.maxsize max_live_parameters=1,000,000,000 max_reuse_distance=1,000,000,000 gather_16bit_weights_on_model_save=True stage3_gather_fp16_weights_on_model_save=False ignore_unused_parameters=True legacy_stage1=False round_robin_gradients=True mics_shard_size=-1 mics_hierarchical_params_gather=False memory_efficient_linear=True\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,698] [INFO] [config.py:959:print] zero_enabled ................. True\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,698] [INFO] [config.py:959:print] zero_force_ds_cpu_optimizer .. True\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,698] [INFO] [config.py:959:print] zero_optimization_stage ...... 3\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:13:40,698] [INFO] [config.py:945:print_user_config] json = {\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"fp16\": {\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"enabled\": true, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"initial_scale_power\": 8\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m }, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"bf16\": {\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"enabled\": false\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m }, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"optimizer\": {\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"type\": \"AdamW\", \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"params\": {\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"lr\": 2e-05, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"betas\": [0.9, 0.999], \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"eps\": 1e-08\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m }\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m }, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"zero_optimization\": {\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"stage\": 3, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"offload_optimizer\": {\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"device\": \"cpu\", \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"pin_memory\": true\n", 
"\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m }, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"offload_param\": {\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"device\": \"cpu\", \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"pin_memory\": true\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m }, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"overlap_comm\": true, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"contiguous_gradients\": true, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"reduce_bucket_size\": 1.677722e+07, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"stage3_prefetch_bucket_size\": 1.509949e+07, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"stage3_param_persistence_threshold\": 4.096000e+04, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"gather_16bit_weights_on_model_save\": true, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"round_robin_gradients\": true\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m }, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"gradient_accumulation_steps\": 1, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"gradient_clipping\": 1.0, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"steps_per_print\": 10, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"train_batch_size\": 128, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"train_micro_batch_size_per_gpu\": 8, \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \"wall_clock_breakdown\": false\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m }\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "(pid=8980, ip=10.0.60.59) Running 0: 0%| | 0/1 [00:00 TaskPoolMapOperator[MapBatches(split_text)->MapBatches(tokenize)] -> OutputSplitter[split(16, equal=True)]\n", "\u001b[2m\u001b[36m(SplitCoordinator pid=8980, ip=10.0.60.59)\u001b[0m Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=2000000000.0), locality_with_output=['6002ded0aaa53ce9a0351d22a72b344ef411a422919132f41d9f937a', 'd3bbd390b6fe73f26202f96d75998946cf3e8b457528d426db0c6e07', 'fe6aaf54317ee630a02d23e0d49581b57b5cd51316eaf769e28bb045', 'f7de4694a4f764c05a9c51a6a4bd40ac33f3fced3b25127b25cd4ac3', '42866a2fba4ce2ab4b6645c4d731d486b762e2b23ac24cafccba7096', '8a7272830662c7e756a656de0a9b433a3a1f9b990768f692b6fe11a7', 'bba62e8b57552509c62a6b6b7fd67c1a2280b9d81b3d9c41eb4d1b9b', 'b40764f303538c24bc439106f2e7b2144d382bfed6c9fdec15ab828e', 'd1de4d4b6d44eff93857026df4ef0f70e24e3dc91e15d87015f2ed32', '4d6a9dc1aa7bfc80cb73d9f66f4e28041807f12769391f5643bce143', '8bcc7235f459b61be21fe158d0bae4fef2ec6de013ec60e7aaf7897a', '73c50b995811afa0ece70fd3d4466b7fd0dc85a97d6807128b2c47da', '03bf3d374a9f857b1cd1aebdbe028208f7904b077fb151790e03e9fe', '9f7fc101a7d6b3e17b72e57ca1c92f91d13aa385a6740f99d58ec016', '867844d104a8e9351a1dcc8bbd61d99906a8dc5b53e220c2ae2efbe1', '7677b344c59d6b30c3db451f48e346d61bb60cc798e5567aa4e0a1ea'], preserve_order=False, 
actor_locality_enabled=True, verbose_progress=False)\n", "\u001b[2m\u001b[36m(SplitCoordinator pid=8980, ip=10.0.60.59)\u001b[0m Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\n", "\u001b[2m\u001b[36m(MapBatches(split_text)->MapBatches(tokenize) pid=10097, ip=10.0.60.59)\u001b[0m 2023-08-18 18:13:42.547741: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA\n", "\u001b[2m\u001b[36m(MapBatches(split_text)->MapBatches(tokenize) pid=10097, ip=10.0.60.59)\u001b[0m To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "\u001b[2m\u001b[36m(MapBatches(split_text)->MapBatches(tokenize) pid=10097, ip=10.0.60.59)\u001b[0m 2023-08-18 18:13:42.685843: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", "\u001b[2m\u001b[36m(MapBatches(split_text)->MapBatches(tokenize) pid=10097, ip=10.0.60.59)\u001b[0m 2023-08-18 18:13:43.506819: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n", "\u001b[2m\u001b[36m(MapBatches(split_text)->MapBatches(tokenize) pid=10097, ip=10.0.60.59)\u001b[0m 2023-08-18 18:13:43.506880: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n", "\u001b[2m\u001b[36m(MapBatches(split_text)->MapBatches(tokenize) pid=10097, ip=10.0.60.59)\u001b[0m 2023-08-18 18:13:43.506887: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. 
If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m Time to load utils op: 0.0003864765167236328 seconds\u001b[32m [repeated 14x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8851, ip=10.0.43.240)\u001b[0m {'loss': 12.1235, 'learning_rate': 1.9761904761904763e-05, 'epoch': 0.01}\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8851, ip=10.0.43.240)\u001b[0m {'loss': 6.7834, 'learning_rate': 1.9523809523809524e-05, 'epoch': 0.02}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8857, ip=10.0.44.114)\u001b[0m {'loss': 2.2151, 'learning_rate': 1.928571428571429e-05, 'epoch': 0.04}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=9631, ip=10.0.57.153)\u001b[0m {'loss': 0.1739, 'learning_rate': 1.904761904761905e-05, 'epoch': 0.05}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[1m\u001b[36m(autoscaler +8m53s)\u001b[0m [workspace snapshot] New snapshot created successfully (size: 172.58 MB).\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8858, ip=10.0.0.119)\u001b[0m {'loss': 0.121, 'learning_rate': 1.880952380952381e-05, 'epoch': 0.06}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m {'loss': 0.1422, 'learning_rate': 1.8571428571428575e-05, 'epoch': 0.07}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36249, ip=10.0.11.26)\u001b[0m {'loss': 0.1007, 'learning_rate': 1.8333333333333333e-05, 'epoch': 0.08}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8867, ip=10.0.49.236)\u001b[0m {'loss': 0.1082, 'learning_rate': 1.8095238095238097e-05, 'epoch': 0.1}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8858, ip=10.0.0.119)\u001b[0m {'loss': 0.094, 'learning_rate': 1.785714285714286e-05, 'epoch': 0.11}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8880, ip=10.0.63.99)\u001b[0m {'loss': 0.0936, 'learning_rate': 1.761904761904762e-05, 'epoch': 0.12}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:18:36,553] [INFO] [logging.py:96:log_dist] [Rank 0] step=10, skipped=0, lr=[1.761904761904762e-05], mom=[[0.9, 0.999]]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:18:36,554] [INFO] [timer.py:199:stop] epoch=0/micro_step=10/global_step=10, RunningAvgSamplesPerSec=4.768458258762969, CurrSamplesPerSec=4.833942877725304, MemAllocated=0.16GB, MaxMemAllocated=8.93GB\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8857, ip=10.0.44.114)\u001b[0m {'loss': 0.0921, 'learning_rate': 1.7380952380952384e-05, 'epoch': 0.13}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8851, ip=10.0.43.240)\u001b[0m {'loss': 0.0915, 'learning_rate': 1.7142857142857142e-05, 'epoch': 0.14}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m {'loss': 0.0883, 'learning_rate': 1.6904761904761906e-05, 'epoch': 0.15}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36311, 
ip=10.0.27.53)\u001b[0m {'loss': 0.0868, 'learning_rate': 1.6666666666666667e-05, 'epoch': 0.17}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m {'loss': 0.0815, 'learning_rate': 1.642857142857143e-05, 'epoch': 0.18}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[1m\u001b[36m(autoscaler +13m58s)\u001b[0m [workspace snapshot] New snapshot created successfully (size: 172.58 MB).\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8875, ip=10.0.0.80)\u001b[0m {'loss': 0.0825, 'learning_rate': 1.6190476190476193e-05, 'epoch': 0.19}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m {'loss': 0.0813, 'learning_rate': 1.5952380952380954e-05, 'epoch': 0.2}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8880, ip=10.0.63.99)\u001b[0m {'loss': 0.0816, 'learning_rate': 1.5714285714285715e-05, 'epoch': 0.21}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8851, ip=10.0.43.240)\u001b[0m {'loss': 0.0813, 'learning_rate': 1.5476190476190476e-05, 'epoch': 0.23}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8885, ip=10.0.47.209)\u001b[0m {'loss': 0.0765, 'learning_rate': 1.523809523809524e-05, 'epoch': 0.24}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:23:03,756] [INFO] [logging.py:96:log_dist] [Rank 0] step=20, skipped=0, lr=[1.523809523809524e-05], mom=[[0.9, 0.999]]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:23:03,756] [INFO] [timer.py:199:stop] epoch=0/micro_step=20/global_step=20, RunningAvgSamplesPerSec=4.781402482813706, CurrSamplesPerSec=4.7832870646183325, MemAllocated=0.16GB, MaxMemAllocated=8.93GB\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8858, ip=10.0.0.119)\u001b[0m {'loss': 0.0833, 'learning_rate': 1.5000000000000002e-05, 'epoch': 0.25}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8857, ip=10.0.44.114)\u001b[0m {'loss': 0.084, 'learning_rate': 1.4761904761904763e-05, 'epoch': 0.26}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m {'loss': 0.0839, 'learning_rate': 1.4523809523809524e-05, 'epoch': 0.27}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8867, ip=10.0.49.236)\u001b[0m {'loss': 0.0825, 'learning_rate': 1.4285714285714287e-05, 'epoch': 0.29}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36262, ip=10.0.52.191)\u001b[0m {'loss': 0.0838, 'learning_rate': 1.4047619047619048e-05, 'epoch': 0.3}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8867, ip=10.0.49.236)\u001b[0m {'loss': 0.0847, 'learning_rate': 1.3809523809523811e-05, 'epoch': 0.31}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[1m\u001b[36m(autoscaler +18m58s)\u001b[0m [workspace snapshot] New snapshot created successfully (size: 172.58 MB).\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8851, ip=10.0.43.240)\u001b[0m {'loss': 0.0788, 'learning_rate': 1.3571428571428574e-05, 'epoch': 0.32}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8875, ip=10.0.0.80)\u001b[0m {'loss': 0.0832, 'learning_rate': 
1.3333333333333333e-05, 'epoch': 0.33}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8875, ip=10.0.0.80)\u001b[0m {'loss': 0.0811, 'learning_rate': 1.3095238095238096e-05, 'epoch': 0.35}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m {'loss': 0.0759, 'learning_rate': 1.2857142857142859e-05, 'epoch': 0.36}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:27:35,516] [INFO] [logging.py:96:log_dist] [Rank 0] step=30, skipped=0, lr=[1.2857142857142859e-05], mom=[[0.9, 0.999]]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:27:35,517] [INFO] [timer.py:199:stop] epoch=0/micro_step=30/global_step=30, RunningAvgSamplesPerSec=4.756191577689035, CurrSamplesPerSec=4.775146730091594, MemAllocated=0.16GB, MaxMemAllocated=8.93GB\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8875, ip=10.0.0.80)\u001b[0m {'loss': 0.0774, 'learning_rate': 1.261904761904762e-05, 'epoch': 0.37}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8858, ip=10.0.0.119)\u001b[0m {'loss': 0.0751, 'learning_rate': 1.2380952380952383e-05, 'epoch': 0.38}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36311, ip=10.0.27.53)\u001b[0m {'loss': 0.0744, 'learning_rate': 1.2142857142857142e-05, 'epoch': 0.39}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8845, ip=10.0.18.195)\u001b[0m {'loss': 0.0722, 'learning_rate': 1.1904761904761905e-05, 'epoch': 0.4}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8880, ip=10.0.63.99)\u001b[0m {'loss': 0.0742, 'learning_rate': 1.1666666666666668e-05, 'epoch': 0.42}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8857, ip=10.0.44.114)\u001b[0m {'loss': 0.0764, 'learning_rate': 1.1428571428571429e-05, 'epoch': 0.43}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8851, ip=10.0.43.240)\u001b[0m {'loss': 0.0786, 'learning_rate': 1.1190476190476192e-05, 'epoch': 0.44}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[1m\u001b[36m(autoscaler +24m4s)\u001b[0m [workspace snapshot] New snapshot created successfully (size: 172.58 MB).\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=9631, ip=10.0.57.153)\u001b[0m {'loss': 0.0738, 'learning_rate': 1.0952380952380955e-05, 'epoch': 0.45}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m {'loss': 0.0784, 'learning_rate': 1.0714285714285714e-05, 'epoch': 0.46}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8885, ip=10.0.47.209)\u001b[0m {'loss': 0.0786, 'learning_rate': 1.0476190476190477e-05, 'epoch': 0.48}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:32:06,009] [INFO] [logging.py:96:log_dist] [Rank 0] step=40, skipped=0, lr=[1.0476190476190477e-05], mom=[[0.9, 0.999]]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:32:06,009] [INFO] [timer.py:199:stop] epoch=0/micro_step=40/global_step=40, RunningAvgSamplesPerSec=4.750214082000028, CurrSamplesPerSec=4.781755388354574, MemAllocated=0.16GB, 
MaxMemAllocated=8.93GB\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8845, ip=10.0.18.195)\u001b[0m {'loss': 0.0714, 'learning_rate': 1.0238095238095238e-05, 'epoch': 0.49}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36311, ip=10.0.27.53)\u001b[0m {'loss': 0.0739, 'learning_rate': 1e-05, 'epoch': 0.5}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8851, ip=10.0.43.240)\u001b[0m {'loss': 0.0767, 'learning_rate': 9.761904761904762e-06, 'epoch': 0.51}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m {'loss': 0.0827, 'learning_rate': 9.523809523809525e-06, 'epoch': 0.52}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36262, ip=10.0.52.191)\u001b[0m {'loss': 0.0751, 'learning_rate': 9.285714285714288e-06, 'epoch': 0.54}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m {'loss': 0.0737, 'learning_rate': 9.047619047619049e-06, 'epoch': 0.55}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8885, ip=10.0.47.209)\u001b[0m {'loss': 0.0755, 'learning_rate': 8.80952380952381e-06, 'epoch': 0.56}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m {'loss': 0.0745, 'learning_rate': 8.571428571428571e-06, 'epoch': 0.57}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=9631, ip=10.0.57.153)\u001b[0m {'loss': 0.0753, 'learning_rate': 8.333333333333334e-06, 'epoch': 0.58}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[1m\u001b[36m(autoscaler +29m9s)\u001b[0m [workspace snapshot] New snapshot created successfully (size: 172.59 MB).\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8845, ip=10.0.18.195)\u001b[0m {'loss': 0.0739, 'learning_rate': 8.095238095238097e-06, 'epoch': 0.6}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:36:34,033] [INFO] [logging.py:96:log_dist] [Rank 0] step=50, skipped=0, lr=[8.095238095238097e-06], mom=[[0.9, 0.999]]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:36:34,033] [INFO] [timer.py:199:stop] epoch=0/micro_step=50/global_step=50, RunningAvgSamplesPerSec=4.75579745222066, CurrSamplesPerSec=4.705258125568294, MemAllocated=0.16GB, MaxMemAllocated=8.93GB\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8885, ip=10.0.47.209)\u001b[0m {'loss': 0.073, 'learning_rate': 7.857142857142858e-06, 'epoch': 0.61}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8830, ip=10.0.30.35)\u001b[0m {'loss': 0.0721, 'learning_rate': 7.61904761904762e-06, 'epoch': 0.62}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36311, ip=10.0.27.53)\u001b[0m {'loss': 0.0729, 'learning_rate': 7.380952380952382e-06, 'epoch': 0.63}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8880, ip=10.0.63.99)\u001b[0m {'loss': 0.0714, 'learning_rate': 7.1428571428571436e-06, 'epoch': 0.64}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m {'loss': 0.0745, 'learning_rate': 6.9047619047619055e-06, 'epoch': 0.65}\u001b[32m [repeated 16x across 
cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=9631, ip=10.0.57.153)\u001b[0m {'loss': 0.0726, 'learning_rate': 6.666666666666667e-06, 'epoch': 0.67}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m {'loss': 0.0699, 'learning_rate': 6.4285714285714295e-06, 'epoch': 0.68}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36262, ip=10.0.52.191)\u001b[0m {'loss': 0.0732, 'learning_rate': 6.1904761904761914e-06, 'epoch': 0.69}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m {'loss': 0.0714, 'learning_rate': 5.9523809523809525e-06, 'epoch': 0.7}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8851, ip=10.0.43.240)\u001b[0m {'loss': 0.0709, 'learning_rate': 5.7142857142857145e-06, 'epoch': 0.71}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:41:07,338] [INFO] [logging.py:96:log_dist] [Rank 0] step=60, skipped=0, lr=[5.7142857142857145e-06], mom=[[0.9, 0.999]]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:41:07,338] [INFO] [timer.py:199:stop] epoch=0/micro_step=60/global_step=60, RunningAvgSamplesPerSec=4.74341422313603, CurrSamplesPerSec=4.640637786972311, MemAllocated=0.16GB, MaxMemAllocated=8.93GB\n", "\u001b[2m\u001b[1m\u001b[36m(autoscaler +34m9s)\u001b[0m [workspace snapshot] New snapshot created successfully (size: 172.59 MB).\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8875, ip=10.0.0.80)\u001b[0m {'loss': 0.071, 'learning_rate': 5.476190476190477e-06, 'epoch': 0.73}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36311, ip=10.0.27.53)\u001b[0m {'loss': 0.0714, 'learning_rate': 5.2380952380952384e-06, 'epoch': 0.74}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8875, ip=10.0.0.80)\u001b[0m {'loss': 0.0703, 'learning_rate': 5e-06, 'epoch': 0.75}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m {'loss': 0.0733, 'learning_rate': 4.761904761904762e-06, 'epoch': 0.76}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8845, ip=10.0.18.195)\u001b[0m {'loss': 0.0686, 'learning_rate': 4.523809523809524e-06, 'epoch': 0.77}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8851, ip=10.0.43.240)\u001b[0m {'loss': 0.068, 'learning_rate': 4.2857142857142855e-06, 'epoch': 0.79}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36311, ip=10.0.27.53)\u001b[0m {'loss': 0.071, 'learning_rate': 4.047619047619048e-06, 'epoch': 0.8}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m {'loss': 0.0708, 'learning_rate': 3.80952380952381e-06, 'epoch': 0.81}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36311, ip=10.0.27.53)\u001b[0m {'loss': 0.0766, 'learning_rate': 3.5714285714285718e-06, 'epoch': 0.82}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8858, ip=10.0.0.119)\u001b[0m {'loss': 0.0743, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.83}\u001b[32m [repeated 16x across 
cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:45:31,965] [INFO] [logging.py:96:log_dist] [Rank 0] step=70, skipped=0, lr=[3.3333333333333333e-06], mom=[[0.9, 0.999]]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:45:31,965] [INFO] [timer.py:199:stop] epoch=0/micro_step=70/global_step=70, RunningAvgSamplesPerSec=4.757168325507401, CurrSamplesPerSec=4.8146031804109555, MemAllocated=0.16GB, MaxMemAllocated=8.93GB\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8830, ip=10.0.30.35)\u001b[0m {'loss': 0.0752, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.85}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:45:58,184] [INFO] [loss_scaler.py:188:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 256, but hysteresis is 2. Reducing hysteresis to 1\n", "\u001b[2m\u001b[1m\u001b[36m(autoscaler +39m14s)\u001b[0m [workspace snapshot] New snapshot created successfully (size: 172.59 MB).\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8845, ip=10.0.18.195)\u001b[0m {'loss': 0.0717, 'learning_rate': 3.0952380952380957e-06, 'epoch': 0.86}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:46:26,433] [WARNING] [stage3.py:1826:step] 2 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding get_accelerator().empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m {'loss': 0.0695, 'learning_rate': 2.8571428571428573e-06, 'epoch': 0.87}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36249, ip=10.0.11.26)\u001b[0m {'loss': 0.0709, 'learning_rate': 2.6190476190476192e-06, 'epoch': 0.88}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8885, ip=10.0.47.209)\u001b[0m {'loss': 0.0729, 'learning_rate': 2.380952380952381e-06, 'epoch': 0.89}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8880, ip=10.0.63.99)\u001b[0m {'loss': 0.0752, 'learning_rate': 2.1428571428571427e-06, 'epoch': 0.9}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8845, ip=10.0.18.195)\u001b[0m {'loss': 0.0712, 'learning_rate': 1.904761904761905e-06, 'epoch': 0.92}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=9631, ip=10.0.57.153)\u001b[0m {'loss': 0.0708, 'learning_rate': 1.6666666666666667e-06, 'epoch': 0.93}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36249, ip=10.0.11.26)\u001b[0m {'loss': 0.0723, 'learning_rate': 1.4285714285714286e-06, 'epoch': 0.94}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8845, ip=10.0.18.195)\u001b[0m {'loss': 0.0689, 'learning_rate': 1.1904761904761906e-06, 'epoch': 0.95}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:50:01,494] [INFO] [logging.py:96:log_dist] 
[Rank 0] step=80, skipped=1, lr=[1.1904761904761906e-06], mom=[[0.9, 0.999]]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:50:01,494] [INFO] [timer.py:199:stop] epoch=0/micro_step=80/global_step=80, RunningAvgSamplesPerSec=4.756310378443122, CurrSamplesPerSec=4.758170892979721, MemAllocated=0.16GB, MaxMemAllocated=8.93GB\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m {'loss': 0.0715, 'learning_rate': 9.523809523809525e-07, 'epoch': 0.96}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8880, ip=10.0.63.99)\u001b[0m {'loss': 0.07, 'learning_rate': 7.142857142857143e-07, 'epoch': 0.98}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=9631, ip=10.0.57.153)\u001b[0m {'loss': 0.0716, 'learning_rate': 4.7619047619047623e-07, 'epoch': 0.99}\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[1m\u001b[36m(autoscaler +44m19s)\u001b[0m [workspace snapshot] New snapshot created successfully (size: 172.60 MB).\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8880, ip=10.0.63.99)\u001b[0m {'loss': 0.069, 'learning_rate': 2.3809523809523811e-07, 'epoch': 1.0}\u001b[32m [repeated 16x across cluster]\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m Saving model checkpoint to output/checkpoint-84\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m Configuration saved in output/checkpoint-84/config.json\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m Configuration saved in output/checkpoint-84/generation_config.json\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m Using /home/ray/.cache/torch_extensions/py39_cu118 as PyTorch extensions root...\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m No modifications detected for re-loaded extension module utils, skipping build step...\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m Loading extension module utils...\u001b[32m [repeated 14x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m ***** Running training *****\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m Num examples = 10752\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m Num Epochs = 9223372036854775807\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m Instantaneous batch size per device = 8\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m Total train batch size (w. 
parallel, distributed & accumulation) = 128\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m Gradient Accumulation steps = 1\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m Total optimization steps = 84\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m Number of trainable parameters = 0\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m Model weights saved in output/checkpoint-84/pytorch_model.bin\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m tokenizer config file saved in output/checkpoint-84/tokenizer_config.json\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m Special tokens file saved in output/checkpoint-84/special_tokens_map.json\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m [2023-08-18 18:52:12,213] [INFO] [torch_checkpoint_engine.py:33:commit] [Torch] Checkpoint global_step84 is ready now!\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36249, ip=10.0.11.26)\u001b[0m {'loss': 0.069, 'learning_rate': 2.3809523809523811e-07, 'epoch': 1.0}\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:52:12,213] [INFO] [logging.py:96:log_dist] [Rank 0] [Torch] Checkpoint global_step84 is about to be saved!\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:52:12,213] [INFO] [engine.py:3337:save_16bit_model] Saving model weights to output/checkpoint-84/pytorch_model.bin, tag: global_step84\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:52:12,213] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving output/checkpoint-84/pytorch_model.bin...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m /home/ray/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py:1802: UserWarning: Positional args are being deprecated, use kwargs instead. 
Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=49329)\u001b[0m warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:52:27,660] [INFO] [torch_checkpoint_engine.py:23:save] [Torch] Saved output/checkpoint-84/pytorch_model.bin.\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:52:27,673] [INFO] [logging.py:96:log_dist] [Rank 0] [Torch] Checkpoint global_step84 is about to be saved!\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:52:27,684] [INFO] [logging.py:96:log_dist] [Rank 0] Saving model checkpoint: output/checkpoint-84/global_step84/zero_pp_rank_0_mp_rank_00_model_states.pt\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:52:27,685] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving output/checkpoint-84/global_step84/zero_pp_rank_0_mp_rank_00_model_states.pt...\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:52:27,660] [INFO] [torch_checkpoint_engine.py:33:commit] [Torch] Checkpoint global_step84 is ready now!\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36262, ip=10.0.52.191)\u001b[0m [2023-08-18 18:52:27,685] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving output/checkpoint-84/global_step84/zero_pp_rank_15_mp_rank_00_model_states.pt...\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=9631, ip=10.0.57.153)\u001b[0m [2023-08-18 18:52:32,337] [INFO] [engine.py:3228:_save_zero_checkpoint] zero checkpoint saved output/checkpoint-84/global_step84/zero_pp_rank_14_mp_rank_00_optim_states.pt\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:52:36,011] [INFO] [torch_checkpoint_engine.py:23:save] [Torch] Saved output/checkpoint-84/global_step84/zero_pp_rank_0_mp_rank_00_optim_states.pt.\u001b[32m [repeated 32x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36675, ip=10.0.13.222)\u001b[0m [2023-08-18 18:52:27,684] [INFO] [logging.py:96:log_dist] [Rank 1] Saving model checkpoint: output/checkpoint-84/global_step84/zero_pp_rank_1_mp_rank_00_model_states.pt\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8867, ip=10.0.49.236)\u001b[0m [2023-08-18 18:52:27,873] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving output/checkpoint-84/global_step84/zero_pp_rank_3_mp_rank_00_optim_states.pt...\u001b[32m [repeated 30x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=36311, ip=10.0.27.53)\u001b[0m [2023-08-18 18:52:36,193] [INFO] [torch_checkpoint_engine.py:33:commit] [Torch] Checkpoint global_step84 is ready now!\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=8885, ip=10.0.47.209)\u001b[0m \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8885, ip=10.0.47.209)\u001b[0m \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8885, ip=10.0.47.209)\u001b[0m Training completed. 
Do not forget to share your model on huggingface.co/models =)\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8885, ip=10.0.47.209)\u001b[0m \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8885, ip=10.0.47.209)\u001b[0m \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8867, ip=10.0.49.236)\u001b[0m /home/ray/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py:1802: UserWarning: Positional args are being deprecated, use kwargs instead. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8867, ip=10.0.49.236)\u001b[0m warnings.warn(\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "2023-08-18 18:53:44,782\tWARNING syncer.py:853 -- Ray AIR no longer supports the synchronization of checkpoints and other artifacts from worker nodes to the head node. This means that the checkpoints and artifacts saved by trials scheduled on worker nodes will not be accessible during the run (e.g., resuming from a checkpoint after a failure) or after the run (e.g., loading the checkpoint of a trial that ran on an already terminated worker node).\n", "\n", "To fix this issue, configure AIR to use either:\n", "(1) Cloud storage: `RunConfig(storage_path='s3://your/bucket')`\n", "(2) A network filesystem mounted on all nodes: `RunConfig(storage_path='/mnt/path/to/nfs_storage')`\n", "See this Github issue for more details on transitioning to cloud storage/NFS as well as an explanation on why this functionality is being removed: https://github.com/ray-project/ray/issues/37177\n", "If you are already using NFS, you can ignore this warning message.\n", "\n", "Other temporary workarounds:\n", "- If you want to avoid errors/warnings and continue running with syncing explicitly turned off, set `RunConfig(SyncConfig(syncer=None))`\n", "- Or, to re-enable the head node syncing behavior, set the environment variable RAY_AIR_REENABLE_DEPRECATED_SYNC_TO_HEAD_NODE=1\n", " - **Note that this functionality will tentatively be hard-deprecated in Ray 2.7.** See the linked issue for the latest information.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=36262, ip=10.0.52.191)\u001b[0m {'train_runtime': 2355.3551, 'train_samples_per_second': 4.565, 'train_steps_per_second': 0.036, 'train_loss': 0.32820896875290645, 'epoch': 1.0}\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m [2023-08-18 18:52:36,012] [INFO] [engine.py:3228:_save_zero_checkpoint] zero checkpoint saved output/checkpoint-84/global_step84/zero_pp_rank_0_mp_rank_00_optim_states.pt\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8875, ip=10.0.0.80)\u001b[0m [2023-08-18 18:52:36,193] [INFO] [torch_checkpoint_engine.py:33:commit] [Torch] Checkpoint global_step84 is ready now!\u001b[32m [repeated 15x across cluster]\u001b[0m\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m \u001b[32m [repeated 60x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=8911, ip=10.0.60.59)\u001b[0m Training completed. 
Do not forget to share your model on huggingface.co/models =)\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "2023-08-18 18:54:02,594\tINFO tune.py:1146 -- Total run time: 2691.03 seconds (2676.82 seconds for the tuning loop).\n" ] } ], "source": [ "results = trainer.fit()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Use the returned {class}`~ray.train.Result` object to access metrics and the Ray Train {class}`~ray.train.Checkpoint` associated with the last iteration." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "Checkpoint(filesystem=, path=anyscale-staging-data-cld-kvedzwag2qa8i5bjxuevf5i7/org_7c1Kalm9WcX2bNIjW53GUT/cld_kvedZWag2qA8i5BjxUevf5i7/artifact_storage/yunxuan__xiao/gptj-deepspeed-finetune/TorchTrainer_2023-08-18_18-09-11/TorchTrainer_01ea5_00000_0_2023-08-18_18-09-12/checkpoint_000000)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "checkpoint = results.checkpoint\n", "checkpoint" ] },
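{ "cell_type": "markdown", "metadata": {}, "source": [ "The returned {class}`~ray.train.Result` object also exposes the metrics reported on the last iteration through `results.metrics`. The next cell is a minimal sketch: it prints a few keys that this run reported (for example `train_loss`, `train_runtime`, and `epoch`, visible in the logs above); the exact keys depend on what the training loop reports." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Minimal sketch: `results.metrics` is a plain dict with the metrics reported on the last iteration.\n", "metrics = results.metrics\n", "\n", "# Print a few keys if this run reported them; the exact keys depend on the training loop.\n", "print({key: metrics[key] for key in (\"train_loss\", \"train_runtime\", \"epoch\") if key in metrics})" ] },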
{ "cell_type": "markdown", "metadata": {}, "source": [ "(gptj-predict)=\n", "## Generate text from prompt\n", "\n", "First, download the persistent Ray Train checkpoint locally and load the fine-tuned model weights and tokenizer from the checkpoint. Then use 🤗 Transformers [`pipeline`](https://huggingface.co/docs/transformers/en/main_classes/pipelines) to generate predictions from the fine-tuned model.\n", "\n", "```{tip}\n", "For large-scale batch inference, see {ref}`End-to-end: Offline Batch Inference `.\n", "```" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "!aws configure set s3.max_concurrent_requests 32\n", "!aws configure set default.s3.preferred_transfer_client crt\n", "!aws configure set default.s3.target_bandwidth 100Gb/s\n", "!aws configure set default.s3.multipart_chunksize 8MB" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "import os\n", "\n", "os.system(f\"aws s3 sync s3://{checkpoint.path} /mnt/local_storage/\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Set the `task` to `\"text-generation\"`, and set `device_map=\"auto\"` so that 🤗 Transformers automatically places the model on the right device." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "tags": [] }, "outputs": [], "source": [ "import torch\n", "from transformers import pipeline, AutoTokenizer, GPTJForCausalLM\n", "\n", "model = GPTJForCausalLM.from_pretrained(\"/mnt/local_storage/checkpoint\")\n", "tokenizer = AutoTokenizer.from_pretrained(\"/mnt/local_storage/checkpoint\")\n", "\n", "pipe = pipeline(\n", " model=model,\n", " tokenizer=tokenizer,\n", " task=\"text-generation\",\n", " torch_dtype=torch.float16,\n", " device_map=\"auto\",\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[{'generated_text': 'Romeo and Juliet. This very night shall they come. A word with you, sir.'}]\n", "[{'generated_text': 'Romeo! I know thee not. Lord Mercutio, is it you! Signior Montague.'}]\n", "[{'generated_text': 'Juliet, look up in the vault, and there shalt find a grave; within the monument there is a table:'}]\n" ] } ], "source": [ "# Generate from prompts!\n", "for sentence in pipe(\n", " [\"Romeo and Juliet\", \"Romeo\", \"Juliet\"], do_sample=True, min_length=20\n", "):\n", " print(sentence)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" }, "orphan": true, "vscode": { "interpreter": { "hash": "3c0d54d489a08ae47a06eae2fd00ff032d6cddb527c382959b7b2575f6a8167f" } } }, "nbformat": 4, "nbformat_minor": 4 }