LLM Training

From Research Computing Center Wiki
Jump to navigation Jump to search

Introduction

HuggingFace

Hub

Libraries

Transformers
module load Transformers
Datasets
module load datasets
TRL

(A software module is coming soon; in the meantime, TRL can be installed in a virtual environment.)

# Create a virtual environment for TRL under the home directory.
python3 -m venv ~/trl_venv
# Activate it in the current shell so that pip installs into it.
source ~/trl_venv/bin/activate
# --require-virtualenv makes pip refuse to run outside an active virtual environment.
pip install --require-virtualenv trl

Compute Resources

LLM Training Compute Resource Consumption
Identifier Accelerator Resources Methods Training Duration Notes
alpaca.1xMIG.A100 LoRA (4-bit), FlashAttention2 Planned
alpaca.1xL4 1 * Nvidia L4 (24GB VRAM each) LoRA (4-bit), FlashAttention2 Pending
alpaca.3xL4 3 * Nvidia L4 (24GB VRAM each) LoRA (4-bit), FlashAttention2 Pending
alpaca.4xL4 4 * Nvidia L4 (24GB VRAM each) LoRA (4-bit), FlashAttention2 Pending
alpaca.1xMI210 1 * AMD MI210 (64GB VRAM each) LoRA (4-bit), FlashAttention2 Pending
alpaca.3xMI210 3 * AMD MI210 (64GB VRAM each) LoRA (4-bit), FlashAttention2 Pending
alpaca.1xA100 1 * Nvidia A100 (80GB VRAM each) LoRA (4-bit) Pending
alpaca.1xA100 1 * Nvidia A100 (80GB VRAM each) LoRA (4-bit), FlashAttention2 PDBS: 1
alpaca.3xA100 3 * Nvidia A100 (80GB VRAM each) LoRA (4-bit), FlashAttention2 PDBS: 1
alpaca.4xA100 4 * Nvidia A100 (80GB VRAM each) LoRA (4-bit), FlashAttention2 PDBS: 1
alpaca.1xH100 1 * Nvidia H100 (80GB VRAM each) LoRA (4-bit), FlashAttention2 Planned
alpaca.3xH100 3 * Nvidia H100 (80GB VRAM each) LoRA (4-bit), FlashAttention2 Planned
alpaca.4xH100 4 * Nvidia H100 (80GB VRAM each) LoRA (4-bit), FlashAttention2 Planned

Training Script (w/HuggingFace)

import torch

from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from datasets import load_from_disk
from trl import SFTTrainer, ORPOTrainer, AutoModelForCausalLMWithValueHead, ModelConfig, get_peft_config, get_quantization_config, get_kbit_device_map, DataCollatorForCompletionOnlyLM
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
 
# Attention implementation handed to both ModelConfig and the model-loading kwargs.
attn_implementation = "flash_attention_2"

# Value passed to TrainingArguments(report_to=...).
report_to = "wandb"

# Filesystem locations. "$$YOUR_MYID" is a placeholder: substitute your own
# MyID before running the script.
output_dir = "/lscratch/$$YOUR_MYID/guac0/"
base_model = "/scratch/$$YOUR_MYID/llm/models/hf/Meta-Llama-3-8B"

def prompt_formatting_func(article):
    """Turn a batch of chat-style examples into Alpaca-style prompt strings.

    The original version called ``.render()`` on a plain ``str`` (which has no
    such method) and declared a stray ``self`` parameter even though it is a
    module-level function handed to the trainer as ``formatting_func``. The
    prompt is now assembled directly in Python, producing the layout the
    template text described.

    Args:
        article: A batch mapping whose ``"messages"`` entry holds one list of
            messages per example. Each message is a mapping with ``"role"``
            and ``"content"`` keys.
            (assumes the dataset is passed in batched, column-oriented form --
            TODO confirm against the dataset on disk)

    Returns:
        A list with one formatted prompt string per example in the batch.
    """
    # Section header emitted before each message, chosen by role. A role that
    # is not listed gets no header, mirroring the template's if/elif chain,
    # which has no else branch.
    role_headers = {
        "system": "### Instruction:\n",
        "user": "### Input:\n",
        "assistant": "### Response:\n",
    }
    preamble = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n"
    )

    output_texts = []
    for messages in article["messages"]:
        turns = "".join(
            f"\n{role_headers.get(message['role'], '')}{message['content']}\n"
            for message in messages
        )
        # Every prompt ends with the response marker, as in the template text.
        output_texts.append(f"{preamble}{turns}\n### Response:\n")
    return output_texts

if __name__ == "__main__":
    # Script entry point: load the base model and tokenizer, attach LoRA
    # adapters, fine-tune with SFTTrainer on the on-disk dataset, and save
    # the result to ``output_dir``.
    #
    # Fixes relative to the earlier version of this script:
    #   * ``model_path`` was undefined; the module-level ``base_model`` is used.
    #   * ``kwargs`` / ``defaults`` were undefined; the module-level
    #     ``output_dir`` is used.
    #   * ``quant_config`` is computed once and reused instead of recomputed.
    #   * The dataset path uses the ``$$YOUR_MYID`` placeholder like the other
    #     paths instead of one specific user's scratch directory.

    # Basic model config
    model_config = ModelConfig(
        model_name_or_path      = base_model,
        attn_implementation     = attn_implementation,
    )
    # NOTE(review): model_config does not request 4-bit loading itself, so
    # this presumably yields no quantization config and the ``load_in_4bit``
    # kwarg below is what actually enables 4-bit loading -- confirm.
    quant_config = get_quantization_config(model_config)

    model_kwargs = dict(
        torch_dtype         = "auto",
        load_in_4bit        = True,
        trust_remote_code   = False, # Don't
        attn_implementation = attn_implementation,
        use_cache           = False, # false if grad chkpnting
        quantization_config = quant_config,
        device_map          = get_kbit_device_map(),
    )

    # Load model & tokenizer
    tokenizer  = AutoTokenizer.from_pretrained(base_model)
    model      = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)

    # Reuse the end-of-sequence token for padding.
    tokenizer.pad_token = tokenizer.eos_token

    lora_config = LoraConfig(
        r               = 64,
        lora_alpha      = 16,
        lora_dropout    = 0.05,
        bias            = "none",
        task_type       = "CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)

    # Replace $$YOUR_MYID with your own MyID (same placeholder as base_model).
    train_dataset = load_from_disk("/scratch/$$YOUR_MYID/llm/datasets/guac-merge0")
    training_args = TrainingArguments(
        logging_strategy            = "steps",
        logging_steps               = 500,
        logging_first_step          = True,
        report_to                   = report_to,
        num_train_epochs            = 3,
        output_dir                  = output_dir,
        per_device_train_batch_size = 1,
        learning_rate               = 2e-4,
    )
    # Must match the marker that prompt_formatting_func appends to each prompt.
    response_template = "### Response:\n"
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer = tokenizer)
    trainer = SFTTrainer(
        model,
        args                = training_args,
        train_dataset       = train_dataset,
        dataset_text_field  = "text",
        max_seq_length      = 4096,
        peft_config         = lora_config,
        formatting_func     = prompt_formatting_func,
        data_collator       = collator,
    )

    trainer.train()
    trainer.save_model(output_dir)

Job Submission Script

#!/usr/bin/env bash
# Slurm batch script: runs the training script through torch.distributed.run
# inside a virtual environment, then copies the training output from
# node-local scratch to the scratch filesystem.
#
# Replace every "$$YOUR_MYID" / "$$YOUR_LAB" placeholder before submitting.
# Left as-is, bash expands "$$" to the shell's process ID.
#SBATCH --job-name=train_guac0_1xA100
#SBATCH --cpus-per-task=16
#SBATCH --partition=gpu_p
#SBATCH --gres=gpu:A100:1
#SBATCH --ntasks=1
#SBATCH --mem=64gb
#SBATCH --time=03:00:00
#SBATCH --output=logs/%x.%j.out
#SBATCH --error=logs/%x.%j.err

#SBATCH --mail-type=ALL
#SBATCH --mail-user=$$YOUR_MYID@uga.edu

export JOB_CUSTODIAN="$$YOUR_MYID"
export JOB_GROUP="$$YOUR_LAB"

export PROJECT_DIR="/work/$JOB_GROUP/$JOB_CUSTODIAN/"
export SCRATCH_DIR="/scratch/$JOB_CUSTODIAN/"
export LSCRATCH_DIR="/lscratch/$JOB_CUSTODIAN/"

export PROJECT_NAME="guac0"
export PROJECT_VARIANT="flash-attn0"

# Unique per job: name.variant.<Slurm job id>
export PROJECT_TITLE="$PROJECT_NAME.$PROJECT_VARIANT.$SLURM_JOBID"

# Training writes to TRAINING_OUTPUT; the final rsync copies it to RESULT_DEPOT.
export RESULT_DEPOT="$SCRATCH_DIR/llm/models/hf/$PROJECT_TITLE"
export TRAINING_OUTPUT="$LSCRATCH_DIR/$PROJECT_TITLE"

export TRAINING_BASE_MODEL="Meta-Llama-3-8B"

export OMP_NUM_THREADS=16
export PER_DEVICE_BATCH_SIZE=1
export GPUS_PER_NODE=1
export TRAINING_EPOCHS=3
# Previously undefined, which left "-s" without a value in TRAINING_ARGS.
# 4096 matches max_seq_length in the training script shown on this page.
export MAX_SEQ_LENGTH=4096

export TRAINING_VENV="/scratch/$$YOUR_MYID/llm/projects/workbench/venv/"
export TRAINING_SCRIPT="/scratch/$$YOUR_MYID/llm/projects/guac/scripts/training/train_guac0.py"
export TRAINING_ARGS="-b $PER_DEVICE_BATCH_SIZE -m $TRAINING_BASE_MODEL -o $TRAINING_OUTPUT -e $TRAINING_EPOCHS -s $MAX_SEQ_LENGTH"

export WANDB_PROJECT="$PROJECT_NAME"
export WANDB_LOG_MODEL="checkpoint"
export WANDB_JOB_TYPE="training"
export WANDB_NAME="$PROJECT_TITLE"

export CUDA_VERSION="12.1.1"
export RDZV_BACKEND="c10d"
export RDZV_ID=2299
export RDZV_PORT=29500

cd "$SLURM_SUBMIT_DIR"

module load "CUDA/$CUDA_VERSION" diffusers ccache wandb flash-attn

# The first host in the allocation serves as the rendezvous endpoint.
head_node_ip=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)

export LAUNCHER="python -m torch.distributed.run \
        --nproc_per_node $GPUS_PER_NODE \
        --nnodes $SLURM_NNODES \
        --rdzv_id $RDZV_ID \
        --rdzv_backend $RDZV_BACKEND \
        --rdzv_endpoint $head_node_ip:$RDZV_PORT \
"

source "$TRAINING_VENV/bin/activate"

# CMD is deliberately passed as one string so that "bash -c" word-splits it.
export CMD="$LAUNCHER $TRAINING_SCRIPT $TRAINING_ARGS"
srun --jobid "$SLURM_JOB_ID" bash -c "$CMD"

deactivate
rsync -r "$TRAINING_OUTPUT" "$RESULT_DEPOT"