Llama model pre-training on Intel Gaudi#

In this Jupyter notebook, we will pre-train a huggyllama/llama-7b model by using Intel Gaudi accelerators.

We will use PyTorch for model training and Ray for distributed training.

Intel Gaudi AI Processors (HPUs) are AI hardware accelerators designed by Habana Labs. For more information, see Gaudi Architecture and Gaudi Developer Docs.

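Basic features for this pre-training example are:

  • Pre-training huggyllama/llama-7b from a randomly initialized model (only the model configuration is downloaded) on the wikitext-2-raw-v1 dataset.

  • DeepSpeed ZeRO stage 3 with bf16 precision.

  • Distributed training on multiple HPUs with Ray Train's TorchTrainer and the HCCL backend.

  • A selectable HPU execution mode: lazy, eager, or eager with torch.compile.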

Prepare environment#

This example runs on a single node with 8 HPUs (one Ray training worker per HPU).

We recommend using a prebuilt container to run these examples. To run a container, you need Docker. See Install Docker Engine for installation instructions.

Next, follow Run Using Containers to install the Habana drivers and container runtime.

Get docker image#

# Other Gaudi Docker images are listed at https://vault.habana.ai/ui/native/gaudi-docker
docker pull \
    vault.habana.ai/gaudi-docker/1.15.1/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest

Run docker image#

docker run -it \
    --runtime=habana \
    -e HABANA_VISIBLE_DEVICES=all \
    -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
    --cap-add=sys_nice \
    --net=host \
    --ipc=host \
    vault.habana.ai/gaudi-docker/1.15.1/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest
# Consider mounting your workspace into the container, e.g. add -v /path/to/workspace:/workspace

Install dependency#

# Version requirements:
#   - "ray>=2.20.0"
#   - "optimum-habana>1.11.1" if the execution mode is "eager" or "eager.compile"
# The requirement spec is quoted so the shell does not treat [train] as a glob.
pip install "ray[train]>=2.20.0" notebook transformers datasets evaluate peft accelerate scikit-learn optimum-habana

# Install DeepSpeed (HabanaAI fork), pinned to the release matching the 1.15.1 docker image
pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.15.1

# This notebook was verified with the following package versions:
# transformers==4.38.2
# datasets==2.19.1
# evaluate==0.4.2
# peft==0.4.0
# accelerate==0.27.2
# scikit-learn==1.4.2
# optimum-habana==1.11.1

# deepspeed==0.12.4+hpu.synapse.v1.15.0

Import necessary libraries#

#!/usr/bin/env python

import os
from typing import Any, Dict
from torch.utils.data import DataLoader

import transformers
from itertools import chain
from datasets import load_dataset
from transformers import default_data_collator
from transformers.testing_utils import CaptureLogger
from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments
from optimum.habana.utils import set_seed

Build datasets#

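Download and load the dataset from huggingface.co. If the dataset has no validation split, a percentage of the training split is held out as validation.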

def load_datasets(config):
    """Download and load a text dataset from the Hugging Face Hub.

    Args:
        config: dict with
            - ``"name"``: dataset name on the Hub (e.g. ``"wikitext"``).
            - ``"config_name"``: dataset configuration (e.g. ``"wikitext-2-raw-v1"``).
            - ``"validation_split_percentage"`` (optional, default 5): percentage
              of the train split held out as validation when the dataset has no
              ``"validation"`` split of its own.

    Returns:
        The dataset dict returned by ``load_dataset``, always containing a
        ``"validation"`` split.
    """
    dataset_name = config["name"]
    dataset_config_name = config["config_name"]
    # Only consulted when the dataset ships without a "validation" split.
    validation_split_percentage = config.get("validation_split_percentage", 5)

    # Loading options shared by every load_dataset call below.
    load_kwargs = {"cache_dir": None, "token": None, "streaming": False}

    raw_datasets = load_dataset(dataset_name, dataset_config_name, **load_kwargs)
    if "validation" not in raw_datasets.keys():
        # Take the first N% of "train" as validation and the rest as training,
        # so the two splits do not overlap.
        raw_datasets["validation"] = load_dataset(
            dataset_name,
            dataset_config_name,
            split=f"train[:{validation_split_percentage}%]",
            **load_kwargs,
        )
        raw_datasets["train"] = load_dataset(
            dataset_name,
            dataset_config_name,
            split=f"train[{validation_split_percentage}%:]",
            **load_kwargs,
        )

    return raw_datasets

Load tokenizer#

Download the tokenizer vocabulary from huggingface.co.

def load_tokenizer(config):
    """Load the tokenizer named by ``config["name"]`` from the Hugging Face Hub.

    The fast tokenizer implementation from the "main" revision is requested,
    with the default cache, no auth token and ``trust_remote_code`` disabled.
    """
    return transformers.AutoTokenizer.from_pretrained(
        config["name"],
        cache_dir=None,
        use_fast=True,
        revision="main",
        token=None,
        trust_remote_code=False,
    )

Tokenize dataset#

Convert the raw text into token IDs.

def tokenize_dataset(datasets, tokenizer):
    """Tokenize every split of ``datasets`` and drop the original columns.

    The text column is "text" if the train split has one, otherwise the train
    split's first column. All train-split columns are removed from the output,
    leaving only the tokenizer's fields.
    """
    train_columns = list(datasets["train"].features)
    if "text" in train_columns:
        text_column = "text"
    else:
        text_column = train_columns[0]

    tokenization_logger = transformers.utils.logging.get_logger(
        "transformers.tokenization_utils_base"
    )
    long_input_marker = "Token indices sequence length is longer than the"

    def _tokenize(batch):
        # Capture the tokenizer's log output so the "sequence too long" warning
        # can be detected; causal-LM inputs may be far longer than block_size.
        with CaptureLogger(tokenization_logger) as captured:
            encoded = tokenizer(batch[text_column])
        if long_input_marker in captured.out:
            tokenization_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
                " before being passed to the model."
            )
        return encoded

    return datasets.map(
        _tokenize,
        batched=True,
        num_proc=None,
        remove_columns=train_columns,
        load_from_cache_file=True,
        desc="Running tokenizer on dataset",
    )

Group dataset#

This preprocessing step concatenates all texts in the dataset and splits them into chunks of block_size tokens, which makes pre-training much faster.

def group_dataset(config, datasets, tokenizer):
    """Concatenate tokenized texts and split them into fixed-size blocks.

    Args:
        config: model config dict. ``config["name"]`` is the Hub model name whose
            ``max_position_embeddings`` bounds the default block size. An optional
            ``config["block_size"]`` overrides the default; it is clamped to
            ``tokenizer.model_max_length``.
        datasets: tokenized dataset dict (e.g. the output of ``tokenize_dataset``).
        tokenizer: tokenizer that produced ``datasets``.

    Returns:
        Dataset dict whose examples are ``block_size``-token chunks, with a
        ``labels`` column copied from ``input_ids``.

    Raises:
        ValueError: if the resolved block size is not positive.
    """
    auto_config = transformers.AutoConfig.from_pretrained(config["name"])
    max_pos_embeddings = auto_config.max_position_embeddings

    requested_block_size = config.get("block_size")
    if requested_block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > max_pos_embeddings:
            # model_max_length exceeds what the model's position embeddings
            # support; fall back to at most 1024 tokens.
            block_size = min(1024, max_pos_embeddings) if max_pos_embeddings > 0 else 1024
            print(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                f"Using block_size={block_size} instead. You can change that default value by setting "
                f"\"block_size\" in the model config."
            )
    else:
        block_size = int(requested_block_size)
        if block_size > tokenizer.model_max_length:
            print(
                f"The block_size passed ({block_size}) is larger than the maximum length for the model "
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
            block_size = tokenizer.model_max_length

    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}.")

    def group_texts(examples):
        # Concatenate every column across the batch.
        concatenated = {k: list(chain.from_iterable(examples[k])) for k in examples}
        total_length = len(concatenated[next(iter(examples))])
        # Drop the remainder that does not fill a whole block; a batch shorter
        # than block_size therefore contributes no examples.
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated.items()
        }
        # Causal LM: labels are the inputs themselves (the model shifts them).
        result["labels"] = result["input_ids"].copy()
        return result

    return datasets.map(
        group_texts,
        batched=True,
        num_proc=None,
        load_from_cache_file=True,
        desc=f"Grouping texts in chunks of {block_size}",
    )

Load model#

Download the model configuration (config.json) from huggingface.co and build the model from it. Only the configuration is fetched: the model weights are freshly initialized, as expected for pre-training.

def load_model(config):
    """Build a causal-LM model from a Hub model configuration.

    ``config["name"]`` selects the configuration to download. The optional
    ``config["config"]`` dict is forwarded as keyword overrides to
    ``AutoConfig.from_pretrained``. The model is created with ``from_config``,
    so no pretrained weights are loaded.
    """
    overrides = config.get("config", {})
    model_config = transformers.AutoConfig.from_pretrained(
        pretrained_model_name_or_path=config["name"],
        **overrides,
    )
    return transformers.AutoModelForCausalLM.from_config(model_config, trust_remote_code=False)

Prepare trainer#

Instantiate a GaudiTrainer with the model, gaudi_config, training_args, and tokenizer.

No evaluation dataset is passed; this example only trains.

def get_trainer(training_args, datasets, tokenizer, model):
    """Create a GaudiTrainer that trains on ``datasets["train"]`` only.

    The Gaudi configuration is loaded from ``training_args.gaudi_config_name``
    (revision "main"), and batches are built with ``default_data_collator``.
    No evaluation dataset is attached.
    """
    trainer_kwargs = dict(
        model=model,
        gaudi_config=GaudiConfig.from_pretrained(training_args.gaudi_config_name, revision="main"),
        args=training_args,
        train_dataset=datasets["train"],
        eval_dataset=None,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
    )
    return GaudiTrainer(**trainer_kwargs)

Training Function#

This function is executed by each worker during training and performs the following steps:

  • Prepare a GaudiTrainingArguments object and set the random seed.

  • Load the datasets from huggingface.co.

  • Load the pretrained tokenizer from huggingface.co.

  • Tokenize the datasets with the loaded tokenizer.

  • Concatenate all texts from the dataset and split them into chunks of block_size.

  • Build the model from its configuration.

  • Instantiate a GaudiTrainer with training_args, datasets, tokenizer, and model.

  • Call trainer.train().

  • Save the model.

def pretrain_llama(config: Dict[str, Any]):
    """Per-worker training loop run by Ray's TorchTrainer.

    Builds the training arguments and seeds the RNGs. It then loads, tokenizes and
    groups the dataset, builds the model, trains it, saves it and prints the
    training result.

    Args:
        config: dict with "training_args", "datasets", "tokenizer" and "model"
            sections, forwarded to the corresponding helpers.
    """
    training_args = GaudiTrainingArguments(**config["training_args"])
    set_seed(training_args.seed)

    # Data pipeline: raw text -> token ids -> fixed-size blocks.
    raw_datasets = load_datasets(config["datasets"])
    tokenizer = load_tokenizer(config["tokenizer"])
    lm_datasets = group_dataset(
        config["model"], tokenize_dataset(raw_datasets, tokenizer), tokenizer
    )

    model = load_model(config["model"])
    trainer = get_trainer(training_args, lm_datasets, tokenizer, model)

    train_output = trainer.train()
    trainer.save_model()
    print(train_output)

Main Training Function#

The main function sets up the distributed training environment using Ray and starts the training process. To enable training using HPU, we only need to make the following changes:

  • Set the execution mode for training. Supported execution modes are:

    • “lazy”: Deferred execution of graphs, comprising ops delivered from the script op by op, similar to Eager mode. It gives the Eager mode experience with performance on Gaudi. Unlike Eager mode with torch.compile, the graph is analyzed in each iteration, leading to higher CPU usage.

    • “eager”: Op-by-op execution as defined in standard PyTorch Eager mode scripts.

    • “eager.compile”: Eager mode extended with torch.compile. It is similar to Eager mode, but the complete model or part of it (such as a function) is wrapped into a graph. Parts that are not wrapped are executed eagerly.

    More details on the theory can be found here, and detailed performance results can be found here.

  • Require an HPU for each worker in ScalingConfig

  • Set backend to hccl in TorchConfig

def main(num_workers, execution_mode):
    """Launch distributed Llama pre-training on HPUs with Ray Train.

    Args:
        num_workers: number of Ray training workers; each worker reserves
            1 CPU and 1 HPU.
        execution_mode: one of "lazy", "eager" or "eager.compile". It selects
            ``use_lazy_mode`` for the Gaudi trainer. For "eager.compile" it also
            turns on torch.compile with the "hpu_backend" backend. The caller
            must set the ``PT_HPU_LAZY_MODE`` environment variable to match
            before calling this function.

    Raises:
        ValueError: if ``execution_mode`` is not a supported mode.
    """
    supported_modes = ("lazy", "eager", "eager.compile")
    if execution_mode not in supported_modes:
        raise ValueError(
            f"Unsupported execution_mode {execution_mode!r}; expected one of {supported_modes}."
        )

    import ray
    from ray.train import ScalingConfig
    from ray.train.torch import TorchTrainer, TorchConfig

    pretrain_config = {
        "datasets": {
            "name": "wikitext",
            "config_name": "wikitext-2-raw-v1",
        },
        "tokenizer": {
            "name": "huggyllama/llama-7b",
            "config": {}
        },
        "model": {
            "name": "huggyllama/llama-7b",
            "config": {
                "torch_dtype": "bfloat16",
            },
        },
        "training_args": {
            "per_device_train_batch_size": 1,
            "do_train": True,
            "save_strategy": "no",
            "output_dir": "/tmp/ray/pretrain-llama-2",
            "logging_steps": 1,
            "gaudi_config_name": "Habana/llama",
            "use_habana": True,
            "throughput_warmup_steps": 3,
            # Must agree with the selected execution mode (and PT_HPU_LAZY_MODE):
            # only "lazy" runs in lazy mode; both eager variants do not.
            "use_lazy_mode": execution_mode == "lazy",
            "overwrite_output_dir": True,
            "seed": 42,
            "bf16": True,
            "report_to": 'tensorboard',
            "deepspeed": {
                "steps_per_print": 64,
                "train_batch_size": "auto",
                "train_micro_batch_size_per_gpu": "auto",
                "gradient_accumulation_steps": "auto",
                "bf16": {
                    "enabled": True
                },
                "gradient_clipping": 1.0,
                "zero_optimization": {
                    "stage": 3,
                    "overlap_comm": False,
                    "reduce_scatter": False,
                    "contiguous_gradients": False,
                    "stage3_gather_16bit_weights_on_model_save": True
                }
            },
        },
    }

    # Eager mode with compile must specify a compile backend; enable
    # torch.compile explicitly and use the HPU backend.
    if execution_mode == "eager.compile":
        pretrain_config["training_args"].update(
            {"torch_compile": True, "torch_compile_backend": "hpu_backend"}
        )

    # One HPU (plus one CPU) per worker; the accelerator is an HPU, not a GPU.
    scaling_config = ScalingConfig(
        num_workers=num_workers,
        use_gpu=False,
        resources_per_worker={"CPU": 1, "HPU": 1},
    )

    # Set backend to hccl in TorchConfig
    torch_config = TorchConfig(backend="hccl")

    # Avoid the re-initialization error when the notebook cell is run again
    # in the same session.
    if not ray.is_initialized():
        ray.init()

    # Initialize a Ray TorchTrainer
    trainer = TorchTrainer(
        train_loop_per_worker=pretrain_llama,
        train_loop_config=pretrain_config,
        torch_config=torch_config,
        scaling_config=scaling_config
    )

    result = trainer.fit()
    print(result)

Start Training#

Finally, we call the main function to start the pre-training process.

Before calling the main function, you must set some environment variables.

  1. The visible devices. The environment variables HABANA_VISIBLE_DEVICES and HABANA_VISIBLE_MODULES control which HPU devices are visible to the application, so you must set these two environment variables properly. For more details on the usage of HABANA_VISIBLE_DEVICES and HABANA_VISIBLE_MODULES, please visit here.

  2. The execution mode. Different execution modes have different runtime performance. The default execution mode is lazy mode.

# execution_mode must be one of "lazy", "eager", "eager.compile".
execution_mode = "lazy"

# Prepare the environment before main() launches training.
# NOTE: when relying on RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES,
# HABANA_VISIBLE_MODULES must be set explicitly as well, e.g.
# os.environ["HABANA_VISIBLE_MODULES"] = "0,1,2,3"
os.environ.update(
    {
        "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES": "0",
        # "1" only for lazy mode; both eager variants need "0".
        "PT_HPU_LAZY_MODE": "1" if execution_mode == "lazy" else "0",
    }
)

main(num_workers=8, execution_mode=execution_mode)

Possible outputs#

...

(RayTrainWorker pid=289322) Setting up process group for: env:// [rank=0, world_size=8]
(TorchTrainer pid=288676) Started distributed worker processes: 
(TorchTrainer pid=288676) - (ip=100.83.111.228, pid=289322) world_rank=0, local_rank=0, node_rank=0
(TorchTrainer pid=288676) - (ip=100.83.111.228, pid=289323) world_rank=1, local_rank=1, node_rank=0
(TorchTrainer pid=288676) - (ip=100.83.111.228, pid=289324) world_rank=2, local_rank=2, node_rank=0
(TorchTrainer pid=288676) - (ip=100.83.111.228, pid=289325) world_rank=3, local_rank=3, node_rank=0
(TorchTrainer pid=288676) - (ip=100.83.111.228, pid=289327) world_rank=4, local_rank=4, node_rank=0
(TorchTrainer pid=288676) - (ip=100.83.111.228, pid=289326) world_rank=5, local_rank=5, node_rank=0
(TorchTrainer pid=288676) - (ip=100.83.111.228, pid=289328) world_rank=6, local_rank=6, node_rank=0
(TorchTrainer pid=288676) - (ip=100.83.111.228, pid=289329) world_rank=7, local_rank=7, node_rank=0

...

(RayTrainWorker pid=289322) ============================= HABANA PT BRIDGE CONFIGURATION =========================== 
(RayTrainWorker pid=289322)  PT_HPU_LAZY_MODE = 1
(RayTrainWorker pid=289322)  PT_RECIPE_CACHE_PATH = 
(RayTrainWorker pid=289322)  PT_CACHE_FOLDER_DELETE = 0
(RayTrainWorker pid=289322)  PT_HPU_RECIPE_CACHE_CONFIG = 
(RayTrainWorker pid=289322)  PT_HPU_MAX_COMPOUND_OP_SIZE = 9223372036854775807
(RayTrainWorker pid=289322)  PT_HPU_LAZY_ACC_PAR_MODE = 1
(RayTrainWorker pid=289322)  PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES = 0
(RayTrainWorker pid=289322) ---------------------------: System Configuration :---------------------------
(RayTrainWorker pid=289322) Num CPU Cores : 152
(RayTrainWorker pid=289322) CPU RAM       : 1056440348 KB
(RayTrainWorker pid=289322) ------------------------------------------------------------------------------

...

(RayTrainWorker pid=289322) {'loss': 11.1784, 'grad_norm': 11.160387992858887, 'learning_rate': 4.9903660886319845e-05, 'epoch': 0.01, 'memory_allocated (GB)': 26.34, 'max_memory_allocated (GB)': 66.83, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 11.1116, 'grad_norm': 11.13752555847168, 'learning_rate': 4.9807321772639694e-05, 'epoch': 0.01, 'memory_allocated (GB)': 27.32, 'max_memory_allocated (GB)': 71.35, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 10.8931, 'grad_norm': 11.067651748657227, 'learning_rate': 4.971098265895954e-05, 'epoch': 0.02, 'memory_allocated (GB)': 27.32, 'max_memory_allocated (GB)': 75.01, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 10.3421, 'grad_norm': 10.925484657287598, 'learning_rate': 4.9614643545279386e-05, 'epoch': 0.02, 'memory_allocated (GB)': 27.33, 'max_memory_allocated (GB)': 75.08, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 10.007, 'grad_norm': 9.689080238342285, 'learning_rate': 4.9518304431599236e-05, 'epoch': 0.03, 'memory_allocated (GB)': 27.33, 'max_memory_allocated (GB)': 75.08, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 9.8195, 'grad_norm': 18.040328979492188, 'learning_rate': 4.942196531791908e-05, 'epoch': 0.03, 'memory_allocated (GB)': 27.33, 'max_memory_allocated (GB)': 75.14, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 9.6815, 'grad_norm': 29.881019592285156, 'learning_rate': 4.932562620423892e-05, 'epoch': 0.04, 'memory_allocated (GB)': 27.32, 'max_memory_allocated (GB)': 75.14, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 9.4898, 'grad_norm': 12.468446731567383, 'learning_rate': 4.922928709055877e-05, 'epoch': 0.05, 'memory_allocated (GB)': 27.31, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 9.5611, 'grad_norm': 8.117713928222656, 'learning_rate': 4.913294797687861e-05, 'epoch': 0.05, 'memory_allocated (GB)': 27.32, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 9.2297, 'grad_norm': 14.138890266418457, 'learning_rate': 4.903660886319846e-05, 'epoch': 0.06, 'memory_allocated (GB)': 27.35, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 9.0812, 'grad_norm': 7.828359127044678, 'learning_rate': 4.894026974951831e-05, 'epoch': 0.06, 'memory_allocated (GB)': 27.32, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 9.9278, 'grad_norm': 40.32044219970703, 'learning_rate': 4.8843930635838154e-05, 'epoch': 0.07, 'memory_allocated (GB)': 27.32, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 8.5225, 'grad_norm': 7.01698637008667, 'learning_rate': 4.8747591522157996e-05, 'epoch': 0.08, 'memory_allocated (GB)': 27.36, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 8.3957, 'grad_norm': 9.207005500793457, 'learning_rate': 4.8651252408477846e-05, 'epoch': 0.08, 'memory_allocated (GB)': 27.33, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 8.3269, 'grad_norm': 15.509377479553223, 'learning_rate': 4.855491329479769e-05, 'epoch': 0.09, 'memory_allocated (GB)': 27.34, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 8.392, 'grad_norm': 11.741216659545898, 'learning_rate': 4.845857418111754e-05, 'epoch': 0.09, 'memory_allocated (GB)': 27.36, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 8.341, 'grad_norm': 13.54684066772461, 'learning_rate': 4.836223506743739e-05, 'epoch': 0.1, 'memory_allocated (GB)': 27.33, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 8.132, 'grad_norm': 6.200448513031006, 'learning_rate': 4.826589595375723e-05, 'epoch': 0.1, 'memory_allocated (GB)': 27.31, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 7.8102, 'grad_norm': 5.493015766143799, 'learning_rate': 4.816955684007707e-05, 'epoch': 0.11, 'memory_allocated (GB)': 27.3, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 7.6805, 'grad_norm': 7.432443141937256, 'learning_rate': 4.807321772639692e-05, 'epoch': 0.12, 'memory_allocated (GB)': 27.33, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 7.6716, 'grad_norm': 18.697616577148438, 'learning_rate': 4.7976878612716764e-05, 'epoch': 0.12, 'memory_allocated (GB)': 27.34, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 7.531, 'grad_norm': 9.172748565673828, 'learning_rate': 4.7880539499036607e-05, 'epoch': 0.13, 'memory_allocated (GB)': 27.32, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 7.4479, 'grad_norm': 7.693913459777832, 'learning_rate': 4.7784200385356456e-05, 'epoch': 0.13, 'memory_allocated (GB)': 27.34, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 7.4504, 'grad_norm': 4.102222442626953, 'learning_rate': 4.7687861271676305e-05, 'epoch': 0.14, 'memory_allocated (GB)': 27.34, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 7.2147, 'grad_norm': 4.539271831512451, 'learning_rate': 4.759152215799615e-05, 'epoch': 0.14, 'memory_allocated (GB)': 27.37, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 7.2953, 'grad_norm': 4.624892711639404, 'learning_rate': 4.7495183044316e-05, 'epoch': 0.15, 'memory_allocated (GB)': 27.37, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 7.279, 'grad_norm': 3.8493056297302246, 'learning_rate': 4.739884393063584e-05, 'epoch': 0.16, 'memory_allocated (GB)': 27.37, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 7.2769, 'grad_norm': 3.396097183227539, 'learning_rate': 4.730250481695568e-05, 'epoch': 0.16, 'memory_allocated (GB)': 27.31, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 7.2125, 'grad_norm': 4.0201640129089355, 'learning_rate': 4.720616570327553e-05, 'epoch': 0.17, 'memory_allocated (GB)': 27.34, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 7.1199, 'grad_norm': 4.433038234710693, 'learning_rate': 4.710982658959538e-05, 'epoch': 0.17, 'memory_allocated (GB)': 27.35, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 7.0391, 'grad_norm': 2.8623831272125244, 'learning_rate': 4.7013487475915223e-05, 'epoch': 0.18, 'memory_allocated (GB)': 27.34, 'max_memory_allocated (GB)': 79.56, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 6.8758, 'grad_norm': 3.1782188415527344, 'learning_rate': 4.6917148362235066e-05, 'epoch': 0.18, 'memory_allocated (GB)': 27.29, 'max_memory_allocated (GB)': 93.29, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 6.6878, 'grad_norm': 2.3016743659973145, 'learning_rate': 4.6820809248554915e-05, 'epoch': 0.19, 'memory_allocated (GB)': 27.37, 'max_memory_allocated (GB)': 93.29, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 6.637, 'grad_norm': 4.136375904083252, 'learning_rate': 4.672447013487476e-05, 'epoch': 0.2, 'memory_allocated (GB)': 27.33, 'max_memory_allocated (GB)': 93.29, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 6.8968, 'grad_norm': 3.34140682220459, 'learning_rate': 4.662813102119461e-05, 'epoch': 0.2, 'memory_allocated (GB)': 27.35, 'max_memory_allocated (GB)': 93.29, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 6.9145, 'grad_norm': 2.7163383960723877, 'learning_rate': 4.653179190751446e-05, 'epoch': 0.21, 'memory_allocated (GB)': 27.32, 'max_memory_allocated (GB)': 93.29, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 6.7147, 'grad_norm': 2.5218122005462646, 'learning_rate': 4.64354527938343e-05, 'epoch': 0.21, 'memory_allocated (GB)': 27.32, 'max_memory_allocated (GB)': 93.29, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 6.7815, 'grad_norm': 3.993046522140503, 'learning_rate': 4.633911368015414e-05, 'epoch': 0.22, 'memory_allocated (GB)': 27.32, 'max_memory_allocated (GB)': 93.29, 'total_memory_available (GB)': 94.62}
(RayTrainWorker pid=289322) {'loss': 6.8765, 'grad_norm': 2.5143563747406006, 'learning_rate': 4.624277456647399e-05, 'epoch': 0.23, 'memory_allocated (GB)': 27.34, 'max_memory_allocated (GB)': 93.29, 'total_memory_available (GB)': 94.62}

...