LLM Training: Difference between revisions
Jump to navigation
Jump to search
(Added a test) |
(Removed a test) Tags: Manual revert Visual edit |
||
Line 18: | Line 18: | ||
===== TRL ===== | ===== TRL ===== | ||
(software module coming soon, can be installed in venv meanwhile) | (software module coming soon, can be installed in venv meanwhile)<syntaxhighlight lang="bash"> | ||
python3 -m venv ~/trl_venv | python3 -m venv ~/trl_venv | ||
source ~/trl_venv/bin/activate | source ~/trl_venv/bin/activate |
Latest revision as of 14:52, 18 September 2024
Introduction
HuggingFace
Hub
Libraries
Transformers
module load Transformers
Datasets
module load datasets
TRL
(software module coming soon, can be installed in venv meanwhile)
python3 -m venv ~/trl_venv
source ~/trl_venv/bin/activate
pip install --require-virtualenv trl
Compute Resources
52,002 instruct articles for 3 epochs against Meta-Llama-3-8B loaded in 4bit with the PEFT library
Vendor | Product | Backend | VRAM (GB) |
---|---|---|---|
Nvidia | L4 | CUDA | 24 |
A100 | 80 | ||
H100 | |||
AMD | MI210 | ROCm | 64 |
# of Acc. | Acc. Hardware | Training Duration | Notes | VRAM Usage (GB / %) |
---|---|---|---|---|
1x | Nvidia L4 | Pending | ||
3x | Nvidia L4 | |||
4x | Nvidia L4 | |||
1x | AMD MI210 | |||
3x | AMD MI210 | |||
1x | Nvidia A100 | |||
1x | Pending | PDBS: 1 | ||
2h50m25s | PDBS: 5 | ~70.704 / 88.38% | ||
2h32m32s | PDBS: 7 | ~76.472 / 95.59% | ||
3x | 1h14s | PDBS: 7 | ~224.384 / 93.49% | |
4x | 47m15s | PDBS: 7 | ~306.824 / 95.88% | |
1x | Nvidia H100 | Planned | ||
3x | ||||
4x |
Training Script (w/HuggingFace)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from datasets import load_from_disk
from trl import SFTTrainer, AutoModelForCausalLMWithValueHead, ModelConfig, get_peft_config, get_quantization_config, get_kbit_device_map, DataCollatorForCompletionOnlyLM
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
base_model = "/scratch/$$YOUR_MYID/llm/models/hf/Meta-Llama-3-8B"
output_dir = "/lscratch/$$YOUR_MYID/guac0/"
report_to = "wandb"
attn_implementation = "flash_attention_2"
def prompt_formatting_func(self, article):
output_texts = []
for i in range(len(article['hash'])):
text = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n{% for message in messages %}\n{% if message['role'] == 'system' %}### Instruction:\n{% elif message['role'] == 'user' %}### Input:\n{% elif message['role'] == 'assistant' %}### Response:\n{% endif %}{{message['content']}}\n{% endfor %}\n### Response:\n".render(messages = article['messages'][i])
output_texts.append(text)
return output_texts
if __name__ == "__main__":
# Basic model config
model_config = ModelConfig(
model_name_or_path = base_model,
attn_implementation = attn_implementation,
)
quant_config = get_quantization_config(model_config)
model_kwargs = dict(
torch_dtype = "auto",
load_in_4bit = True,
trust_remote_code = False, # Don't
attn_implementation = attn_implementation,
use_cache = False, # false if grad chkpnting
quantization_config = get_quantization_config(model_config),
device_map = get_kbit_device_map(),
)
# Load model & tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
tokenizer.pad_token = tokenizer.eos_token
lora_config = LoraConfig(
r = 64,
lora_alpha = 16,
lora_dropout = 0.05,
bias = "none",
task_type = "CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
train_dataset = load_from_disk("/scratch/ks98810/llm/datasets/guac-merge0")
training_args = TrainingArguments(
logging_strategy = "steps",
logging_steps = 500,
logging_first_step = True,
report_to = report_to,
num_train_epochs = 3,
output_dir = kwargs.get("output_dir", defaults["output_path"]),
per_device_train_batch_size = 1,
learning_rate = 2e-4,
)
response_template = "### Response:\n"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer = tokenizer)
trainer = SFTTrainer(
model,
args = training_args,
train_dataset = train_dataset,
dataset_text_field = "text",
max_seq_length = 4096,
peft_config = lora_config,
formatting_func = prompt_formatting_func,
data_collator = collator,
)
trainer.train()
trainer.save_model(kwargs.get("output_dir", output_dir))
Job Submission Script
#!/usr/bin/env bash
#SBATCH --job-name=train_guac0_1xA100
#SBATCH --cpus-per-task=16
#SBATCH --partition=gpu_p
#SBATCH --gres=gpu:A100:1
#SBATCH --ntasks=1
#SBATCH --mem=64gb
#SBATCH --time=03:00:00
#SBATCH --output=logs/%x.%j.out
#SBATCH --error=logs/%x.%j.err
#SBATCH --mail-type=ALL
#SBATCH --mail-user=$$YOUR_MYID@uga.edu
export JOB_CUSTODIAN="$$YOUR_MYID"
export JOB_GROUP="$$YOUR_LAB"
export PROJECT_DIR="/work/$JOB_GROUP/$JOB_CUSTODIAN/"
export SCRATCH_DIR="/scratch/$JOB_CUSTODIAN/"
export LSCRATCH_DIR="/lscratch/$JOB_CUSTODIAN/"
export PROJECT_NAME="guac0"
export PROJECT_VARIANT="flash-attn0"
export PROJECT_TITLE="$PROJECT_NAME.$PROJECT_VARIANT.$SLURM_JOBID"
export RESULT_DEPOT="$SCRATCH_DIR/llm/models/hf/$PROJECT_TITLE"
export TRAINING_OUTPUT="$LSCRATCH_DIR/$PROJECT_TITLE"
export TRAINING_BASE_MODEL="Meta-Llama-3-8B"
export OMP_NUM_THREADS=16
export PER_DEVICE_BATCH_SIZE=1
export GPUS_PER_NODE=1
export TRAINING_EPOCHS=3
export TRAINING_VENV="/scratch/$$YOUR_MYID/llm/projects/workbench/venv/"
export TRAINING_SCRIPT="/scratch/$$YOUR_MYID/llm/projects/guac/scripts/training/train_guac0.py"
export TRAINING_ARGS="-b $PER_DEVICE_BATCH_SIZE -m $TRAINING_BASE_MODEL -o $TRAINING_OUTPUT -e $TRAINING_EPOCHS -s $MAX_SEQ_LENGTH"
export WANDB_PROJECT="$PROJECT_NAME"
export WANDB_LOG_MODEL="checkpoint"
export WANDB_JOB_TYPE="training"
export WANDB_NAME="$PROJECT_TITLE"
export CUDA_VERSION="12.1.1"
export RDZV_BACKEND="c10d"
export RDZV_ID=2299
export RDZV_PORT=29500
cd $SLURM_SUBMIT_DIR
module load CUDA/$CUDA_VERSION diffusers ccache wandb flash-attn
head_node_ip=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
export LAUNCHER="python -m torch.distributed.run \
--nproc_per_node $GPUS_PER_NODE \
--nnodes $SLURM_NNODES \
--rdzv_id $RDZV_ID \
--rdzv_backend $RDZV_BACKEND \
--rdzv_endpoint $head_node_ip:$RDZV_PORT \
"
source $TRAINING_VENV/bin/activate
export CMD="$LAUNCHER $TRAINING_SCRIPT $TRAINING_ARGS"
srun --jobid $SLURM_JOB_ID bash -c "$CMD"
deactivate
rsync -r $TRAINING_OUTPUT $RESULT_DEPOT