LLM Training


Introduction

This page documents how to fine-tune large language models (LLMs) on the cluster using the HuggingFace ecosystem. It covers the relevant software modules, the accelerators that have been tested, an example supervised fine-tuning (SFT) script, and a matching Slurm job submission script.

HuggingFace

Hub
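
The base models used on this page are staged on scratch rather than streamed from the Hub at training time. As a minimal sketch (the repository ID and destination path are assumptions, and gated models such as Llama 3 require an accepted license and an access token), a model can be mirrored from the Hub with the huggingface_hub library:

from huggingface_hub import snapshot_download

# Hypothetical example: mirror a gated Hub repository into scratch so training jobs
# can load it from a local path (replace $$YOUR_MYID and set HF_TOKEN beforehand).
snapshot_download(
    repo_id   = "meta-llama/Meta-Llama-3-8B",
    local_dir = "/scratch/$$YOUR_MYID/llm/models/hf/Meta-Llama-3-8B",
)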

Libraries

Transformers
module load Transformers
Datasets
module load datasets
TRL

(A software module is coming soon; in the meantime, TRL can be installed in a virtual environment:)

python3 -m venv ~/trl_venv
source ~/trl_venv/bin/activate
pip install --require-virtualenv trl
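
The training script below reads its dataset from local disk with load_from_disk (from the Datasets library above). As a minimal sketch, not the recipe used for the guac-merge0 dataset, a dataset pulled from the Hub can be staged to scratch once and reused (the repository ID and path here are assumptions):

from datasets import load_dataset

# Hypothetical example: download a Hub dataset on a login/transfer node and persist it
# to scratch so training jobs can read it with load_from_disk() without re-downloading.
dataset = load_dataset("yahma/alpaca-cleaned", split = "train")
dataset.save_to_disk("/scratch/$$YOUR_MYID/llm/datasets/alpaca-cleaned")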

Compute Resources

The benchmarks below fine-tune on 52,002 instruct articles for 3 epochs against Meta-Llama-3-8B, loaded in 4-bit with the PEFT library.

Tested Accelerators
Vendor   Product   Backend   VRAM (GB)
Nvidia   L4        CUDA      24
Nvidia   A100      CUDA      80
Nvidia   H100      CUDA
AMD      MI210     ROCm      64
LLM Training Compute Resource Consumption
# of Acc.   Acc. Hardware   Training Duration   Notes     VRAM Usage (GB / %)
1x          Nvidia L4       Pending
3x          Nvidia L4
4x          Nvidia L4
1x          AMD MI210
3x          AMD MI210
1x          Nvidia A100     Pending             PDBS: 1
1x          Nvidia A100     2h50m25s            PDBS: 5   ~70.704 / 88.38%
1x          Nvidia A100     2h32m32s            PDBS: 7   ~76.472 / 95.59%
3x          Nvidia A100     1h14s               PDBS: 7   ~224.384 / 93.49%
4x          Nvidia A100     47m15s              PDBS: 7   ~306.824 / 95.88%
1x          Nvidia H100     Planned
3x          Nvidia H100
4x          Nvidia H100

PDBS is the per-device train batch size (per_device_train_batch_size in the training script). The percentage is the observed VRAM usage divided by the total VRAM of the accelerators used, e.g. for 3x A100: 224.384 GB / (3 × 80 GB) ≈ 93.49%.

Training Script (w/HuggingFace)

The script below performs a LoRA supervised fine-tune of Meta-Llama-3-8B loaded in 4-bit, using TRL's SFTTrainer and an Alpaca-style prompt format.

import torch

from jinja2 import Template
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from datasets import load_from_disk
from trl import SFTTrainer, ModelConfig, get_quantization_config, get_kbit_device_map, DataCollatorForCompletionOnlyLM
from peft import LoraConfig
 
base_model = "/scratch/$$YOUR_MYID/llm/models/hf/Meta-Llama-3-8B"
output_dir = "/lscratch/$$YOUR_MYID/guac0/"

report_to = "wandb"

attn_implementation = "flash_attention_2"

def prompt_formatting_func(article):
    # Render each article's chat messages into an Alpaca-style prompt string.
    template = Template("Below is an instruction that describes a task. Write a response that appropriately completes the request.\n{% for message in messages %}\n{% if message['role'] == 'system' %}### Instruction:\n{% elif message['role'] == 'user' %}### Input:\n{% elif message['role'] == 'assistant' %}### Response:\n{% endif %}{{message['content']}}\n{% endfor %}\n### Response:\n")
    output_texts = []

    for i in range(len(article['hash'])):
        output_texts.append(template.render(messages = article['messages'][i]))
    return output_texts

if __name__ == "__main__":
    # Basic model config
    model_config = ModelConfig(
        model_name_or_path      = base_model,
        attn_implementation     = attn_implementation,
        load_in_4bit            = True,
    )
    quant_config = get_quantization_config(model_config)

    model_kwargs = dict(
        torch_dtype         = "auto",
        trust_remote_code   = False, # Don't
        attn_implementation = attn_implementation,
        use_cache           = False, # false if grad chkpnting
        quantization_config = quant_config, # 4-bit BitsAndBytes config built from model_config
        device_map          = get_kbit_device_map(),
    )

    # Load model & tokenizer
    tokenizer  = AutoTokenizer.from_pretrained(base_model)
    model      = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)

    tokenizer.pad_token = tokenizer.eos_token # Llama-3's tokenizer ships without a pad token; reuse EOS for padding

    # LoRA adapter configuration; the adapter is applied by SFTTrainer via peft_config below.
    lora_config = LoraConfig(
        r               = 64,
        lora_alpha      = 16,
        lora_dropout    = 0.05,
        bias            = "none",
        task_type       = "CAUSAL_LM",
    )

    train_dataset = load_from_disk("/scratch/$$YOUR_MYID/llm/datasets/guac-merge0")
    training_args = TrainingArguments(
        logging_strategy            = "steps",
        logging_steps               = 500,
        logging_first_step          = True,
        report_to                   = report_to,
        num_train_epochs            = 3,
        output_dir                  = output_dir,
        per_device_train_batch_size = 1,
        learning_rate               = 2e-4,
    )
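    # The completion-only collator masks all tokens before the response template so the
    # loss is computed only on the "### Response:" text, not on the prompt.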
    response_template = "### Response:\n"
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer = tokenizer)
    trainer = SFTTrainer(
        model,
        args                = training_args,
        train_dataset       = train_dataset,
        tokenizer           = tokenizer,
        max_seq_length      = 4096,
        peft_config         = lora_config,
        formatting_func     = prompt_formatting_func, # used instead of dataset_text_field
        data_collator       = collator,
    )

    trainer.train()
    trainer.save_model(output_dir)
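
The job submission script below passes -b/-m/-o/-e/-s flags via $TRAINING_ARGS, while the example above hardcodes those settings. A minimal argparse shim, shown only as a sketch (flag names and defaults are assumptions chosen to mirror $TRAINING_ARGS), could map the flags onto the script's configuration:

import argparse

def parse_cli_args():
    # Illustrative only: flag names mirror $TRAINING_ARGS in the job submission script below.
    parser = argparse.ArgumentParser(description = "LLM SFT training")
    parser.add_argument("-b", "--per-device-batch-size", type = int, default = 1)
    parser.add_argument("-m", "--base-model", default = "Meta-Llama-3-8B")
    parser.add_argument("-o", "--output-dir", default = "./output")
    parser.add_argument("-e", "--epochs", type = int, default = 3)
    parser.add_argument("-s", "--max-seq-length", type = int, default = 4096)
    return parser.parse_args()

The parsed values would then replace the hardcoded per_device_train_batch_size, num_train_epochs, output_dir, and max_seq_length above.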

Job Submission Script

The Slurm script below runs the training script on a single A100 through torch.distributed.run, so the same launcher scales to multiple GPUs by changing GPUS_PER_NODE and the --gres request.

#!/usr/bin/env bash
#SBATCH --job-name=train_guac0_1xA100
#SBATCH --cpus-per-task=16
#SBATCH --partition=gpu_p
#SBATCH --gres=gpu:A100:1
#SBATCH --ntasks=1
#SBATCH --mem=64gb
#SBATCH --time=03:00:00
#SBATCH --output=logs/%x.%j.out
#SBATCH --error=logs/%x.%j.err
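# Slurm does not create the logs/ directory for --output/--error; run "mkdir -p logs"
# in the submission directory before submitting, or the job may fail to start.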

#SBATCH --mail-type=ALL
#SBATCH --mail-user=$$YOUR_MYID@uga.edu

export JOB_CUSTODIAN="$$YOUR_MYID"
export JOB_GROUP="$$YOUR_LAB"

export PROJECT_DIR="/work/$JOB_GROUP/$JOB_CUSTODIAN/"
export SCRATCH_DIR="/scratch/$JOB_CUSTODIAN/"
export LSCRATCH_DIR="/lscratch/$JOB_CUSTODIAN/"

export PROJECT_NAME="guac0"
export PROJECT_VARIANT="flash-attn0"

export PROJECT_TITLE="$PROJECT_NAME.$PROJECT_VARIANT.$SLURM_JOBID"

export RESULT_DEPOT="$SCRATCH_DIR/llm/models/hf/$PROJECT_TITLE"
export TRAINING_OUTPUT="$LSCRATCH_DIR/$PROJECT_TITLE"

export TRAINING_BASE_MODEL="Meta-Llama-3-8B"

export OMP_NUM_THREADS=16
export PER_DEVICE_BATCH_SIZE=1
export GPUS_PER_NODE=1
export TRAINING_EPOCHS=3
export MAX_SEQ_LENGTH=4096

export TRAINING_VENV="/scratch/$$YOUR_MYID/llm/projects/workbench/venv/"
export TRAINING_SCRIPT="/scratch/$$YOUR_MYID/llm/projects/guac/scripts/training/train_guac0.py"
export TRAINING_ARGS="-b $PER_DEVICE_BATCH_SIZE -m $TRAINING_BASE_MODEL -o $TRAINING_OUTPUT -e $TRAINING_EPOCHS -s $MAX_SEQ_LENGTH"

export WANDB_PROJECT="$PROJECT_NAME"
export WANDB_LOG_MODEL="checkpoint"
export WANDB_JOB_TYPE="training"
export WANDB_NAME="$PROJECT_TITLE"

export CUDA_VERSION="12.1.1"
export RDZV_BACKEND="c10d"
export RDZV_ID=2299
export RDZV_PORT=29500

cd $SLURM_SUBMIT_DIR

module load CUDA/$CUDA_VERSION diffusers ccache wandb flash-attn

head_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
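# torch.distributed.run (torchrun) launches $GPUS_PER_NODE worker processes per node; the
# rendezvous settings below are only strictly needed for multi-node jobs, but are harmless
# on a single node.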

export LAUNCHER="python -m torch.distributed.run \
        --nproc_per_node $GPUS_PER_NODE \
        --nnodes $SLURM_NNODES \
        --rdzv_id $RDZV_ID \
        --rdzv_backend $RDZV_BACKEND \
        --rdzv_endpoint $head_node_ip:$RDZV_PORT \
"

source $TRAINING_VENV/bin/activate

export CMD="$LAUNCHER $TRAINING_SCRIPT $TRAINING_ARGS"
srun --jobid $SLURM_JOB_ID bash -c "$CMD"

deactivate
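
# /lscratch is node-local and is cleaned up when the job ends, so copy the training
# output back to shared scratch before the job exits.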
rsync -r $TRAINING_OUTPUT $RESULT_DEPOT