LLM Training: Difference between revisions

From Research Computing Center Wiki
Jump to navigation Jump to search
(Added a test)
(Removed a test)
Tags: Manual revert Visual edit
 
Line 18: Line 18:


===== TRL =====
===== TRL =====
(software module coming soon, can be installed in venv meanwhile) Hello World!<syntaxhighlight lang="bash">
(software module coming soon, can be installed in venv meanwhile)<syntaxhighlight lang="bash">
python3 -m venv ~/trl_venv
python3 -m venv ~/trl_venv
source ~/trl_venv/bin/activate
source ~/trl_venv/bin/activate

Latest revision as of 14:52, 18 September 2024

Introduction

HuggingFace

Hub

Libraries

Transformers
module load Transformers
Datasets
module load datasets
TRL

(software module coming soon, can be installed in venv meanwhile)

python3 -m venv ~/trl_venv
source ~/trl_venv/bin/activate
pip install --require-virtualenv trl

Compute Resources

52,002 instruct articles for 3 epochs against Meta-Llama-3-8B loaded in 4bit with the PEFT library

Tested Accelerators
Vendor Product Backend VRAM (GB)
Nvidia L4 CUDA 24
A100 80
H100
AMD MI210 ROCm 64
LLM Training Compute Resource Consumption
# of Acc. Acc. Hardware Training Duration Notes VRAM Usage (GB / %)
1x Nvidia L4 Pending
3x Nvidia L4
4x Nvidia L4
1x AMD MI210
3x AMD MI210
1x Nvidia A100
1x Pending PDBS: 1
2h50m25s PDBS: 5 ~70.704 / 88.38%
2h32m32s PDBS: 7 ~76.472 / 95.59%
3x 1h14s PDBS: 7 ~224.384 / 93.49%
4x 47m15s PDBS: 7 ~306.824 / 95.88%
1x Nvidia H100 Planned
3x
4x

Training Script (w/HuggingFace)

import torch

from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from datasets import load_from_disk
from trl import SFTTrainer, AutoModelForCausalLMWithValueHead, ModelConfig, get_peft_config, get_quantization_config, get_kbit_device_map, DataCollatorForCompletionOnlyLM
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
 
base_model = "/scratch/$$YOUR_MYID/llm/models/hf/Meta-Llama-3-8B"
output_dir = "/lscratch/$$YOUR_MYID/guac0/"

report_to = "wandb"

attn_implementation = "flash_attention_2"

def prompt_formatting_func(self, article):
    output_texts = []

    for i in range(len(article['hash'])):
        text = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n{% for message in messages %}\n{% if message['role'] == 'system' %}### Instruction:\n{% elif message['role'] == 'user' %}### Input:\n{% elif message['role'] == 'assistant' %}### Response:\n{% endif %}{{message['content']}}\n{% endfor %}\n### Response:\n".render(messages = article['messages'][i])
        output_texts.append(text)
    return output_texts

if __name__ == "__main__":
    # Basic model config
    model_config = ModelConfig(
        model_name_or_path      = base_model,
        attn_implementation     = attn_implementation,
    )
    quant_config = get_quantization_config(model_config)

    model_kwargs = dict(
        torch_dtype         = "auto",
        load_in_4bit        = True,
        trust_remote_code   = False, # Don't
        attn_implementation = attn_implementation,
        use_cache           = False, # false if grad chkpnting
        quantization_config = get_quantization_config(model_config),
        device_map          = get_kbit_device_map(),
    )

    # Load model & tokenizer
    tokenizer  = AutoTokenizer.from_pretrained(model_path)
    model      = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)

    tokenizer.pad_token = tokenizer.eos_token

    lora_config = LoraConfig(
        r               = 64,
        lora_alpha      = 16,
        lora_dropout    = 0.05,
        bias            = "none",
        task_type       = "CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)

    train_dataset = load_from_disk("/scratch/ks98810/llm/datasets/guac-merge0")
    training_args = TrainingArguments(
        logging_strategy            = "steps",
        logging_steps               = 500,
        logging_first_step          = True,
        report_to                   = report_to,
        num_train_epochs            = 3,
        output_dir                  = kwargs.get("output_dir", defaults["output_path"]),
        per_device_train_batch_size = 1,
        learning_rate               = 2e-4,
    )
    response_template = "### Response:\n"
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer = tokenizer)
    trainer = SFTTrainer(
        model,
        args                = training_args,
        train_dataset       = train_dataset,
        dataset_text_field  = "text",
        max_seq_length      = 4096,
        peft_config         = lora_config,
        formatting_func     = prompt_formatting_func,
        data_collator       = collator,
    )

    trainer.train()
    trainer.save_model(kwargs.get("output_dir", output_dir))

Job Submission Script

#!/usr/bin/env bash
#SBATCH --job-name=train_guac0_1xA100
#SBATCH --cpus-per-task=16
#SBATCH --partition=gpu_p
#SBATCH --gres=gpu:A100:1
#SBATCH --ntasks=1
#SBATCH --mem=64gb
#SBATCH --time=03:00:00
#SBATCH --output=logs/%x.%j.out
#SBATCH --error=logs/%x.%j.err

#SBATCH --mail-type=ALL
#SBATCH --mail-user=$$YOUR_MYID@uga.edu

export JOB_CUSTODIAN="$$YOUR_MYID"
export JOB_GROUP="$$YOUR_LAB"

export PROJECT_DIR="/work/$JOB_GROUP/$JOB_CUSTODIAN/"
export SCRATCH_DIR="/scratch/$JOB_CUSTODIAN/"
export LSCRATCH_DIR="/lscratch/$JOB_CUSTODIAN/"

export PROJECT_NAME="guac0"
export PROJECT_VARIANT="flash-attn0"

export PROJECT_TITLE="$PROJECT_NAME.$PROJECT_VARIANT.$SLURM_JOBID"

export RESULT_DEPOT="$SCRATCH_DIR/llm/models/hf/$PROJECT_TITLE"
export TRAINING_OUTPUT="$LSCRATCH_DIR/$PROJECT_TITLE"

export TRAINING_BASE_MODEL="Meta-Llama-3-8B"

export OMP_NUM_THREADS=16
export PER_DEVICE_BATCH_SIZE=1
export GPUS_PER_NODE=1
export TRAINING_EPOCHS=3

export TRAINING_VENV="/scratch/$$YOUR_MYID/llm/projects/workbench/venv/"
export TRAINING_SCRIPT="/scratch/$$YOUR_MYID/llm/projects/guac/scripts/training/train_guac0.py"
export TRAINING_ARGS="-b $PER_DEVICE_BATCH_SIZE -m $TRAINING_BASE_MODEL -o $TRAINING_OUTPUT -e $TRAINING_EPOCHS -s $MAX_SEQ_LENGTH"

export WANDB_PROJECT="$PROJECT_NAME"
export WANDB_LOG_MODEL="checkpoint"
export WANDB_JOB_TYPE="training"
export WANDB_NAME="$PROJECT_TITLE"

export CUDA_VERSION="12.1.1"
export RDZV_BACKEND="c10d"
export RDZV_ID=2299
export RDZV_PORT=29500

cd $SLURM_SUBMIT_DIR

module load CUDA/$CUDA_VERSION diffusers ccache wandb flash-attn

head_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)

export LAUNCHER="python -m torch.distributed.run \
        --nproc_per_node $GPUS_PER_NODE \
        --nnodes $SLURM_NNODES \
        --rdzv_id $RDZV_ID \
        --rdzv_backend $RDZV_BACKEND \
        --rdzv_endpoint $head_node_ip:$RDZV_PORT \
"

source $TRAINING_VENV/bin/activate

export CMD="$LAUNCHER $TRAINING_SCRIPT $TRAINING_ARGS"
srun --jobid $SLURM_JOB_ID bash -c "$CMD"

deactivate
rsync -r $TRAINING_OUTPUT $RESULT_DEPOT