LLM Training


Introduction

This page collects notes on fine-tuning large language models (LLMs) on the cluster with the HuggingFace ecosystem: the libraries involved, compute resource consumption across GPU configurations, a reference training script, and a matching Slurm submission script.

HuggingFace

Hub
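
The training workflow below loads base models from a local path under /scratch rather than pulling them from the Hub at run time. A minimal sketch of pre-downloading a model repository from the Hub into that layout with huggingface_hub (the model ID, target path, and <myid> placeholder are illustrative only; gated models such as Llama 3 additionally require an accepted license and an HF access token):

from huggingface_hub import snapshot_download

# Download a model repository from the HuggingFace Hub into scratch storage
# so training jobs can load it from a local path.
# Repo ID and destination are examples only.
snapshot_download(
    repo_id   = "meta-llama/Meta-Llama-3-8B",
    local_dir = "/scratch/<myid>/llm/models/hf/Meta-Llama-3-8B",
)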

Libraries

Transformers
module load Transformers
Datasets
module load datasets
TRL

(a software module is planned; until it is available, TRL can be installed into a Python virtual environment)

python3 -m venv ~/trl_venv
source ~/trl_venv/bin/activate
pip install --require-virtualenv trl
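
A quick check that the install is importable from the virtual environment (run with the venv's python; the version printed depends on what pip resolved):

# Minimal sanity check that TRL is importable from the new environment.
import trl
print("trl:", trl.__version__)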

Compute Resources

LLM Training Compute Resource Consumption

Identifier        | Accelerator Resources            | Methods                       | Training Duration | Notes
alpaca.1xMIG.A100 | 1 * Nvidia A100 MIG instance     | LoRA (4-bit), FlashAttention2 |                   | Planned
alpaca.1xL4       | 1 * Nvidia L4 (24GB VRAM each)   | LoRA (4-bit), FlashAttention2 |                   | Pending
alpaca.3xL4       | 3 * Nvidia L4 (24GB VRAM each)   | LoRA (4-bit), FlashAttention2 |                   | Pending
alpaca.4xL4       | 4 * Nvidia L4 (24GB VRAM each)   | LoRA (4-bit), FlashAttention2 |                   | Pending
alpaca.1xMI210    | 1 * AMD MI210 (64GB VRAM each)   | LoRA (4-bit), FlashAttention2 |                   | Pending
alpaca.3xMI210    | 3 * AMD MI210 (64GB VRAM each)   | LoRA (4-bit), FlashAttention2 |                   | Pending
alpaca.1xA100     | 1 * Nvidia A100 (80GB VRAM each) | LoRA (4-bit)                  |                   | Pending
alpaca.1xA100     | 1 * Nvidia A100 (80GB VRAM each) | LoRA (4-bit), FlashAttention2 |                   | PDBS: 1
alpaca.3xA100     | 3 * Nvidia A100 (80GB VRAM each) | LoRA (4-bit), FlashAttention2 |                   | PDBS: 1
alpaca.4xA100     | 4 * Nvidia A100 (80GB VRAM each) | LoRA (4-bit), FlashAttention2 |                   | PDBS: 1
alpaca.1xH100     | 1 * Nvidia H100 (80GB VRAM each) | LoRA (4-bit), FlashAttention2 |                   | Planned
alpaca.3xH100     | 3 * Nvidia H100 (80GB VRAM each) | LoRA (4-bit), FlashAttention2 |                   | Planned
alpaca.4xH100     | 4 * Nvidia H100 (80GB VRAM each) | LoRA (4-bit), FlashAttention2 |                   | Planned

(PDBS: per-device batch size.)
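
As a rough sense of why every configuration in the table uses 4-bit LoRA, a back-of-the-envelope memory estimate for an 8B-parameter base model (a sketch with assumed numbers, ignoring activations, optimizer state, and framework overhead):

# Back-of-the-envelope VRAM estimate for 4-bit base weights plus LoRA adapters.
# All numbers are rough approximations for illustration only.
base_params     = 8e9      # ~8B parameters (e.g. Meta-Llama-3-8B)
bytes_per_param = 0.5      # 4-bit quantization ~ 0.5 bytes per parameter

weights_gb = base_params * bytes_per_param / 1e9

# LoRA adapters are a small fraction of the base model; with r = 64 on the
# attention projections this is typically well under 1% of base_params.
lora_params = 0.01 * base_params        # generous upper bound
lora_gb     = lora_params * 2 / 1e9     # adapters kept in 16-bit (2 bytes each)

print(f"quantized base weights : ~{weights_gb:.0f} GB")
print(f"LoRA adapters (bound)  : ~{lora_gb:.2f} GB")
# Activations, adapter gradients/optimizer state, and CUDA overhead come on
# top, which is why the smaller cards appear only with 4-bit LoRA and a
# per-device batch size of 1.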

Training Script (w/HuggingFace)

import torch

from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from datasets import load_from_disk
from trl import SFTTrainer, ModelConfig, get_quantization_config, get_kbit_device_map, DataCollatorForCompletionOnlyLM
from peft import LoraConfig, get_peft_model
from jinja2 import Template  # used to render the prompt template below
 
base_model = "/scratch/ks98810/llm/models/hf/Meta-Llama-3-8B"
output_dir = "/lscratch/ks98810/guac0/"

report_to = "wandb"

attn_implementation = "flash_attention_2"

# Jinja2 template that renders a chat transcript into an Alpaca-style prompt.
prompt_template = Template("Below is an instruction that describes a task. Write a response that appropriately completes the request.\n{% for message in messages %}\n{% if message['role'] == 'system' %}### Instruction:\n{% elif message['role'] == 'user' %}### Input:\n{% elif message['role'] == 'assistant' %}### Response:\n{% endif %}{{message['content']}}\n{% endfor %}\n### Response:\n")

# Called by SFTTrainer with a batch of examples; returns one formatted
# prompt string per example in the batch.
def prompt_formatting_func(article):
    output_texts = []

    for i in range(len(article['hash'])):
        text = prompt_template.render(messages = article['messages'][i])
        output_texts.append(text)
    return output_texts
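
# Illustration only (hypothetical example, not taken from the dataset):
# a batch entry whose 'messages' field is
#   [{"role": "user", "content": "Summarize photosynthesis."},
#    {"role": "assistant", "content": "Plants convert light into chemical energy."}]
# renders to:
#   Below is an instruction that describes a task. Write a response that appropriately completes the request.
#
#   ### Input:
#   Summarize photosynthesis.
#
#   ### Response:
#   Plants convert light into chemical energy.
#
#   ### Response: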

if __name__ == "__main__":
    # Basic model config; load_in_4bit here is what makes
    # get_quantization_config() return a 4-bit BitsAndBytesConfig.
    model_config = ModelConfig(
        model_name_or_path      = base_model,
        attn_implementation     = attn_implementation,
        load_in_4bit            = True,
    )
    quant_config = get_quantization_config(model_config)

    model_kwargs = dict(
        torch_dtype         = "auto",
        trust_remote_code   = False, # Don't
        attn_implementation = attn_implementation,
        use_cache           = False, # false if grad chkpnting
        quantization_config = quant_config,
        device_map          = get_kbit_device_map(),
    )

    # Load model & tokenizer from the local base model path
    tokenizer  = AutoTokenizer.from_pretrained(base_model)
    model      = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)

    tokenizer.pad_token = tokenizer.eos_token

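    # LoRA hyperparameters: r is the adapter rank, lora_alpha the scaling
    # factor applied to the adapter updates, lora_dropout regularizes the
    # adapters, and task_type marks this as causal language modeling.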
    lora_config = LoraConfig(
        r               = 64,
        lora_alpha      = 16,
        lora_dropout    = 0.05,
        bias            = "none",
        task_type       = "CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)

    train_dataset = load_from_disk("/scratch/ks98810/llm/datasets/guac-merge0")
    training_args = TrainingArguments(
        logging_strategy            = "steps",
        logging_steps               = 500,
        logging_first_step          = True,
        report_to                   = report_to,
        num_train_epochs            = 3,
        output_dir                  = output_dir,
        per_device_train_batch_size = 1,
        learning_rate               = 2e-4,
    )
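    # The completion-only collator masks every token before the response
    # template, so the loss is computed only on the response portion.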
    response_template = "### Response:\n"
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer = tokenizer)
    trainer = SFTTrainer(
        model,
        args                = training_args,
        train_dataset       = train_dataset,
        tokenizer           = tokenizer,
        max_seq_length      = 4096,
        peft_config         = lora_config,
        formatting_func     = prompt_formatting_func,  # supplies the training text
        data_collator       = collator,
    )

    trainer.train()
    trainer.save_model(output_dir)
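
The submission script below passes -b, -o, and -e flags to the training script, while the reference script above hardcodes those values. A minimal argparse sketch of how those flags could be wired in (flag names follow the submission script; the mapping onto TrainingArguments fields and the placeholder default path are assumptions):

import argparse

# Hypothetical CLI wiring for the -b/-o/-e flags used by the Slurm script.
def parse_args():
    parser = argparse.ArgumentParser(description="LoRA fine-tuning job")
    parser.add_argument("-b", "--per-device-batch-size", type=int, default=1)
    parser.add_argument("-o", "--output-dir", default="/lscratch/<myid>/guac0/")
    parser.add_argument("-e", "--epochs", type=int, default=3)
    return parser.parse_args()

# Example use inside the training script:
#   args = parse_args()
#   training_args = TrainingArguments(
#       output_dir                  = args.output_dir,
#       num_train_epochs            = args.epochs,
#       per_device_train_batch_size = args.per_device_batch_size,
#       ...
#   )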

Job Submission Script

#!/usr/bin/env bash
#SBATCH --job-name=train_guac0_4xA100
#SBATCH --cpus-per-task=32
#SBATCH --partition=gpu_p
#SBATCH --gres=gpu:A100:4
#SBATCH --ntasks=1
#SBATCH --mem=256gb
#SBATCH --time=06:00:00
#SBATCH --output=logs/%x.%j.out
#SBATCH --error=logs/%x.%j.err
#SBATCH --mail-type=ALL
#SBATCH --mail-user=ks98810@uga.edu

export JOB_CUSTODIAN="ks98810"
export JOB_GROUP="gclab"

export PROJECT_DIR="/work/$JOB_GROUP/$JOB_CUSTODIAN/"
export SCRATCH_DIR="/scratch/$JOB_CUSTODIAN/"
export LSCRATCH_DIR="/lscratch/$JOB_CUSTODIAN/"

export PROJECT_NAME="guac0"
export PROJECT_VARIANT="flash-attn0"

export PROJECT_TITLE="$PROJECT_NAME.$PROJECT_VARIANT.$SLURM_JOBID"

export RESULT_DEPOT="$SCRATCH_DIR/llm/models/hf/$PROJECT_TITLE"
export TRAINING_OUTPUT="$LSCRATCH_DIR/$PROJECT_TITLE"

export OMP_NUM_THREADS=16

export PER_DEVICE_BATCH_SIZE=1
export GPUS_PER_NODE=4

export TRAINING_EPOCHS=3

export TRAINING_VENV="/scratch/ks98810/llm/projects/workbench/venv/"
export TRAINING_SCRIPT="/scratch/ks98810/llm/projects/guac/scripts/training/train_guac0.py"
export TRAINING_ARGS="-b $PER_DEVICE_BATCH_SIZE -o $TRAINING_OUTPUT -e $TRAINING_EPOCHS"

export WANDB_PROJECT="$PROJECT_NAME"
export WANDB_LOG_MODEL="checkpoint"
export WANDB_JOB_TYPE="training"
export WANDB_NAME="$PROJECT_TITLE"

export CUDA_VERSION="12.1.1"
export RDZV_BACKEND="c10d"
export RDZV_ID=2299
export RDZV_PORT=29500

cd $SLURM_SUBMIT_DIR

module load CUDA/$CUDA_VERSION diffusers ccache wandb flash-attn

head_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
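# head_node_ip (above) is the first hostname in the Slurm allocation; torchrun
# uses it as the rendezvous endpoint for all worker processes.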

export LAUNCHER="python -m torch.distributed.run \
        --nproc_per_node $GPUS_PER_NODE \
        --nnodes $SLURM_NNODES \
        --rdzv_id $RDZV_ID \
        --rdzv_backend $RDZV_BACKEND \
        --rdzv_endpoint $head_node_ip:$RDZV_PORT \
"

source $TRAINING_VENV/bin/activate

export CMD="$LAUNCHER $TRAINING_SCRIPT $TRAINING_ARGS"
srun --jobid $SLURM_JOB_ID bash -c "$CMD"

deactivate
rsync -r "$TRAINING_OUTPUT/" "$RESULT_DEPOT"
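
For reference, python -m torch.distributed.run starts one worker process per GPU and sets the standard rendezvous environment variables in each of them; the HuggingFace Trainer reads these automatically, which is why the training script contains no explicit distributed setup. A small sketch of what each worker sees (illustrative only):

import os

# Variables set by torch.distributed.run for every worker process; the
# HuggingFace Trainer uses them to configure data-parallel training.
for var in ("RANK", "LOCAL_RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
    print(var, "=", os.environ.get(var, "<unset>"))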