import os
import sys

import torch
import torch.nn as nn
import bitsandbytes as bnb
from datasets import load_dataset
import transformers
from peft import PeftModel
import wandb

assert (
    "LlamaTokenizer" in transformers._import_structure["models.llama"]
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import (
    prepare_model_for_int8_training,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
)
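
# Note: prepare_model_for_int8_training is the older peft name; newer peft
# releases (0.4+) renamed it to prepare_model_for_kbit_training, so pin an
# older peft or swap the name depending on your environment.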
|
MICRO_BATCH_SIZE = 64  # per-device batch size for a single forward/backward pass
BATCH_SIZE = 128  # effective batch size, reached via gradient accumulation
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
EPOCHS = 2
LEARNING_RATE = 3e-4
CUTOFF_LEN = 256  # fixed token length of every training example
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
VAL_SET_SIZE = 2000
TARGET_MODULES = [
    "q_proj",
    "v_proj",
]
# DATA_PATH = "alpaca_data.json"  # alternative English instruction dataset
DATA_PATH = "belle_open_source_1M.train.json"
OUTPUT_DIR = "lora-alpaca"
|
device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
    # under DDP, pin this process to its local GPU and split the accumulation
    # steps across ranks so the effective batch size stays at BATCH_SIZE
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
    GRADIENT_ACCUMULATION_STEPS = GRADIENT_ACCUMULATION_STEPS // world_size
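
# Launch example (script name hypothetical): `python finetune.py` on one GPU,
# or `torchrun --nproc_per_node=2 finetune.py` for DDP; torchrun sets the
# WORLD_SIZE and LOCAL_RANK environment variables read above.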
|
model = LlamaForCausalLM.from_pretrained(
    "/ndk/ai-repos/train-llama/models/7b_hf",
    load_in_8bit=True,  # bitsandbytes int8 base weights; adapters train on top
    device_map=device_map,
)
tokenizer = LlamaTokenizer.from_pretrained(
    "/ndk/ai-repos/train-llama/models/7b_hf", add_eos_token=True
)

# freezes the int8 base weights, casts layer norms to fp32 for stability, and
# enables gradient checkpointing before the LoRA adapters are attached
model = prepare_model_for_int8_training(model)
|
config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=TARGET_MODULES,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)
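
# With r=8 and lora_alpha=16 the adapter update is scaled by alpha/r = 2 (the
# standard LoRA convention); q_proj and v_proj are the query/value projections
# in each LLaMA attention block.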
|
# Resume from a previously trained adapter rather than attaching a fresh one
# via get_peft_model(model, config); is_trainable=True keeps the loaded LoRA
# weights trainable (without it, recent peft versions load them frozen, in
# inference mode).
model = PeftModel.from_pretrained(
    model,
    "./lora-alpaca-cn-remote",
    torch_dtype=torch.float16,
    is_trainable=True,
)
|
tokenizer.pad_token_id = 0  # pad with <unk> (id 0) so padding differs from the eos token
data = load_dataset("json", data_files=DATA_PATH)
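
# Each JSON record must provide the fields consumed by generate_prompt below,
# i.e. "input" and "target"; the sample below is hypothetical:
#   {"input": "Translate to English: 你好", "target": "Hello"}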
|
def generate_prompt(data_point):
    # Chinese prompt template; in English: "Below is an instruction that
    # describes a task. Write a response that appropriately completes the
    # request.", with "### 输入:" = "### Input:" and "### 输出:" = "### Output:".
    return f"""以下是描述任务的说明。编写适当地完成请求的响应。
### 输入:
{data_point["input"]}

### 输出:
{data_point["target"]}"""
|
def tokenize(prompt):
    # Tokenize to CUTOFF_LEN + 1 positions, then drop the final one so every
    # example comes out exactly CUTOFF_LEN tokens long.
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=CUTOFF_LEN + 1,
        padding="max_length",
    )
    return {
        "input_ids": result["input_ids"][:-1],
        "attention_mask": result["attention_mask"][:-1],
    }
|
def generate_and_tokenize_prompt(data_point):
    prompt = generate_prompt(data_point)
    return tokenize(prompt)
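
# Illustration with a hypothetical record:
#   generate_and_tokenize_prompt({"input": "1 + 1 = ?", "target": "2"})
# returns {"input_ids": [...], "attention_mask": [...]}, both CUTOFF_LEN long.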
|
if VAL_SET_SIZE > 0:
    # carve out a fixed, seeded validation split, then tokenize both halves
    train_val = data["train"].train_test_split(
        test_size=VAL_SET_SIZE, shuffle=True, seed=42
    )
    train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
    val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
else:
    train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
    val_data = None
|
wandb.init(project="llama-lora")
|
trainer = transformers.Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=MICRO_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        warmup_steps=100,
        num_train_epochs=EPOCHS,
        learning_rate=LEARNING_RATE,
        fp16=True,
        logging_steps=20,
        evaluation_strategy="steps" if VAL_SET_SIZE > 0 else "no",
        save_strategy="steps",
        eval_steps=200 if VAL_SET_SIZE > 0 else None,
        save_steps=200,
        output_dir=OUTPUT_DIR,
        save_total_limit=3,
        load_best_model_at_end=VAL_SET_SIZE > 0,
        ddp_find_unused_parameters=False if ddp else None,
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
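
# With mlm=False the collator builds causal-LM labels: a copy of input_ids
# with pad positions set to -100 so the loss ignores them.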
|
model.config.use_cache = False  # KV cache is useless during training and clashes with checkpointing

# Patch state_dict so Trainer checkpoints and save_pretrained serialize only
# the LoRA adapter weights instead of the full base model.
old_state_dict = model.state_dict
model.state_dict = (
    lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))
|
if torch.__version__ >= "2" and sys.platform != "win32":
    model = torch.compile(model)  # PyTorch 2.x graph compilation; unsupported on Windows
|
trainer.train()

model.save_pretrained(OUTPUT_DIR)
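
# save_pretrained on a PeftModel writes just the adapter artifacts
# (adapter_model.bin and adapter_config.json) to OUTPUT_DIR.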
|
print("\nIf there's a warning about missing keys above, please disregard :)")