LLM Training
Jump to navigation
Jump to search
Introduction
HuggingFace
Hub
Libraries
Transformers
module load Transformers
Datasets
module load datasets
TRL
(A software module is coming soon; in the meantime, TRL can be installed in a Python virtual environment.)
python3 -m venv ~/trl_venv
source ~/trl_venv/bin/activate
pip install --require-virtualenv trl
Compute Resources
Benchmark workload: fine-tuning Meta-Llama-3-8B (loaded in 4-bit, with adapters from the PEFT library) on 52,002 instruct articles for 3 epochs.
Vendor | Product | Backend | VRAM (GB) |
---|---|---|---|
Nvidia | L4 | CUDA | 24 |
Nvidia | A100 | CUDA | 80 |
Nvidia | H100 | CUDA | 80 |
AMD | MI210 | ROCm | 64 |
# of Acc. | Acc. Hardware | Training Duration | Notes (PDBS = per-device batch size) |
---|---|---|---|
1x | Nvidia L4 | Pending | |
3x | Nvidia L4 | ||
4x | Nvidia L4 | ||
1x | AMD MI210 | ||
3x | AMD MI210 | ||
1x | Nvidia A100 | ||
1x | PDBS: 1 | ||
2h50m25s | PDBS: 5 | ||
2h32m32s | PDBS: 7 | ||
3x | PDBS: 7 | ||
4x | 47m15s | PDBS: 7 | |
1x | Nvidia H100 | Planned | |
3x | |||
4x |
Training Script (w/HuggingFace)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from datasets import load_from_disk
from trl import SFTTrainer, AutoModelForCausalLMWithValueHead, ModelConfig, get_peft_config, get_quantization_config, get_kbit_device_map, DataCollatorForCompletionOnlyLM
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
# Experiment tracker name; handed to TrainingArguments(report_to=...).
report_to = "wandb"

# Attention kernel requested both in the ModelConfig and in the model kwargs.
attn_implementation = "flash_attention_2"

# Filesystem locations. "$$YOUR_MYID" looks like a placeholder for the
# cluster user ID and must be replaced before running.
base_model = "/scratch/$$YOUR_MYID/llm/models/hf/Meta-Llama-3-8B"
output_dir = "/lscratch/$$YOUR_MYID/guac0/"
def prompt_formatting_func(self, article=None):
    """Render a batch of chat-style examples into instruction-prompt strings.

    The function is handed to ``SFTTrainer(formatting_func=...)``, which calls
    it with a single positional argument (the batch). The historical
    signature had a leading ``self`` parameter, so both call shapes are
    accepted: ``prompt_formatting_func(batch)`` and
    ``prompt_formatting_func(anything, batch)``.

    Args:
        self: The batch itself when called with one argument; otherwise an
            ignored placeholder kept for backward compatibility.
        article: A mapping whose ``"messages"`` entry is a list of
            conversations. Each conversation is a list of mappings with
            ``"role"`` and ``"content"`` keys.

    Returns:
        A list with one formatted prompt string per conversation.
    """
    if article is None:
        # Single-argument call: the batch arrived in the first slot.
        article = self

    preamble = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n"
    )
    # Section header emitted before a message's content, keyed by role.
    # Messages with any other role are emitted without a header.
    role_headers = {
        "system": "### Instruction:\n",
        "user": "### Input:\n",
        "assistant": "### Response:\n",
    }

    output_texts = []
    # Iterate the conversations directly instead of indexing by the length of
    # the unrelated 'hash' column.
    for messages in article['messages']:
        parts = [preamble]
        for message in messages:
            header = role_headers.get(message['role'], "")
            parts.append(f"\n{header}{message['content']}\n")
        # NOTE(review): the original template always closes with an empty
        # response header, even when an assistant message was already
        # rendered. Kept as-is; confirm this is intended for training data.
        parts.append("\n### Response:\n")
        output_texts.append("".join(parts))
    return output_texts
if __name__ == "__main__":
    # Fine-tune the base model with LoRA adapters on a dataset loaded from
    # disk, then save the result to `output_dir`.

    # Basic model config. 4-bit loading is requested here so that
    # get_quantization_config() can derive the quantization settings from it,
    # instead of passing a separate `load_in_4bit` kwarg to from_pretrained().
    model_config = ModelConfig(
        model_name_or_path=base_model,
        attn_implementation=attn_implementation,
        load_in_4bit=True,
    )
    quant_config = get_quantization_config(model_config)
    model_kwargs = dict(
        torch_dtype="auto",
        trust_remote_code=False,  # Don't
        attn_implementation=attn_implementation,
        use_cache=False,  # false if grad chkpnting
        quantization_config=quant_config,
        device_map=get_kbit_device_map(),
    )

    # Load model & tokenizer (previously referenced an undefined `model_path`).
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
    # Reuse the end-of-sequence token for padding.
    tokenizer.pad_token = tokenizer.eos_token

    # LoRA adapter settings. The model is not wrapped with get_peft_model()
    # here: the config is handed to SFTTrainer through `peft_config`, so
    # wrapping it manually as well would be redundant.
    lora_config = LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # Same "$$YOUR_MYID" placeholder convention as `base_model`/`output_dir`.
    train_dataset = load_from_disk("/scratch/$$YOUR_MYID/llm/datasets/guac-merge0")

    training_args = TrainingArguments(
        logging_strategy="steps",
        logging_steps=500,
        logging_first_step=True,
        report_to=report_to,
        num_train_epochs=3,
        # Previously read from undefined `kwargs`/`defaults` names.
        output_dir=output_dir,
        per_device_train_batch_size=1,
        learning_rate=2e-4,
    )

    # Must match the response header emitted by prompt_formatting_func.
    response_template = "### Response:\n"
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

    trainer = SFTTrainer(
        model,
        args=training_args,
        train_dataset=train_dataset,
        # Hand over the tokenizer configured above so the trainer and the
        # collator agree on the padding token.
        tokenizer=tokenizer,
        # NOTE(review): presumably unused while `formatting_func` is
        # supplied - confirm against the installed TRL version.
        dataset_text_field="text",
        max_seq_length=4096,
        peft_config=lora_config,
        formatting_func=prompt_formatting_func,
        data_collator=collator,
    )
    trainer.train()
    trainer.save_model(output_dir)
Job Submission Script
#!/usr/bin/env bash
# Slurm batch script: launches the training script through
# torch.distributed.run on one A100 and copies the output from node-local
# scratch to the scratch filesystem afterwards.
#
# Before submitting:
#   * Replace every "$$YOUR_MYID" / "$$YOUR_LAB" placeholder. Left as-is, bash
#     expands "$$" to the shell's process ID.
#   * Create the "logs/" directory in the submit directory; the --output and
#     --error paths below point into it.
#SBATCH --job-name=train_guac0_1xA100
#SBATCH --cpus-per-task=16
#SBATCH --partition=gpu_p
#SBATCH --gres=gpu:A100:1
#SBATCH --ntasks=1
#SBATCH --mem=64gb
#SBATCH --time=03:00:00
#SBATCH --output=logs/%x.%j.out
#SBATCH --error=logs/%x.%j.err
#SBATCH --mail-type=ALL
#SBATCH --mail-user=$$YOUR_MYID@uga.edu

# --- Ownership and storage locations ---------------------------------------
export JOB_CUSTODIAN="$$YOUR_MYID"
export JOB_GROUP="$$YOUR_LAB"
export PROJECT_DIR="/work/$JOB_GROUP/$JOB_CUSTODIAN/"
export SCRATCH_DIR="/scratch/$JOB_CUSTODIAN/"
export LSCRATCH_DIR="/lscratch/$JOB_CUSTODIAN/"

# --- Run naming --------------------------------------------------------------
export PROJECT_NAME="guac0"
export PROJECT_VARIANT="flash-attn0"
export PROJECT_TITLE="$PROJECT_NAME.$PROJECT_VARIANT.$SLURM_JOBID"
export RESULT_DEPOT="$SCRATCH_DIR/llm/models/hf/$PROJECT_TITLE"
export TRAINING_OUTPUT="$LSCRATCH_DIR/$PROJECT_TITLE"

# --- Training parameters -----------------------------------------------------
export TRAINING_BASE_MODEL="Meta-Llama-3-8B"
export OMP_NUM_THREADS=16
export PER_DEVICE_BATCH_SIZE=1
export GPUS_PER_NODE=1
export TRAINING_EPOCHS=3
# Previously undefined, which left "-s" without a value in TRAINING_ARGS.
export MAX_SEQ_LENGTH=4096
export TRAINING_VENV="/scratch/$$YOUR_MYID/llm/projects/workbench/venv/"
export TRAINING_SCRIPT="/scratch/$$YOUR_MYID/llm/projects/guac/scripts/training/train_guac0.py"
# NOTE(review): these flags only take effect if the training script parses
# its command line - confirm against the script actually deployed.
export TRAINING_ARGS="-b $PER_DEVICE_BATCH_SIZE -m $TRAINING_BASE_MODEL -o $TRAINING_OUTPUT -e $TRAINING_EPOCHS -s $MAX_SEQ_LENGTH"

# --- Weights & Biases --------------------------------------------------------
export WANDB_PROJECT="$PROJECT_NAME"
export WANDB_LOG_MODEL="checkpoint"
export WANDB_JOB_TYPE="training"
export WANDB_NAME="$PROJECT_TITLE"

# --- Toolchain and rendezvous settings --------------------------------------
export CUDA_VERSION="12.1.1"
export RDZV_BACKEND="c10d"
export RDZV_ID=2299
export RDZV_PORT=29500

cd "$SLURM_SUBMIT_DIR"
module load "CUDA/$CUDA_VERSION" diffusers ccache wandb flash-attn

# The first host of the allocation serves as the rendezvous endpoint.
head_node_ip=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export LAUNCHER="python -m torch.distributed.run \
--nproc_per_node $GPUS_PER_NODE \
--nnodes $SLURM_NNODES \
--rdzv_id $RDZV_ID \
--rdzv_backend $RDZV_BACKEND \
--rdzv_endpoint $head_node_ip:$RDZV_PORT \
"

source "$TRAINING_VENV/bin/activate"
export CMD="$LAUNCHER $TRAINING_SCRIPT $TRAINING_ARGS"
srun --jobid "$SLURM_JOB_ID" bash -c "$CMD"
deactivate

# Copy the run's output off node-local scratch. The trailing slash on the
# source copies the directory's contents into RESULT_DEPOT instead of nesting
# a second "$PROJECT_TITLE" directory inside it.
mkdir -p "$RESULT_DEPOT"
rsync -r "$TRAINING_OUTPUT/" "$RESULT_DEPOT/"