test_model / finetune_lora.py
khoicrtp's picture
Upload 13 files
d417650
"""
Instruction-tuning with LoRA on the Alpaca dataset.
Note: If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
`torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
"""
import os
import time
import lightning as L
import numpy as np
import torch
from generate import generate
from lit_llama.lora import mark_only_lora_as_trainable, lora, lora_state_dict
from lit_llama.model import LLaMA, LLaMAConfig
from lit_llama.tokenizer import Tokenizer
from scripts.prepare_alpaca import generate_prompt
# --- Cadence settings -------------------------------------------------------
# eval_interval and save_interval are compared against the optimizer-step
# counter in train(); log_interval is compared against the micro-batch
# iteration counter; eval_iters is the number of batches averaged in validate().
log_interval = 1
eval_interval = 100
save_interval = 100
eval_iters = 100

# --- Optimisation hyperparameters -------------------------------------------
learning_rate = 3e-4
weight_decay = 0.0
warmup_steps = 100  # optimizer steps over which the learning rate ramps up linearly

# One optimizer step consumes `batch_size` samples, fed through the model in
# chunks of `micro_batch_size` whose gradients are accumulated.
batch_size = 128
micro_batch_size = 4
gradient_accumulation_steps = batch_size // micro_batch_size

# Number of micro-batch iterations (alternative left by the author:
# 50000 * 3 // micro_batch_size).
max_iters = 10_000

# Used as the model block size and as the generation length cap;
# see scripts/prepare_alpaca.py
max_seq_length = 256

# --- LoRA settings ----------------------------------------------------------
lora_r = 8
lora_alpha = 16
lora_dropout = 0.05
def main(
    data_dir: str = "data/alpaca",
    pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth",
    out_dir: str = "out/lora/alpaca",
):
    """Fine-tune a pretrained LLaMA "7B" checkpoint with LoRA and save the LoRA weights.

    Args:
        data_dir: Directory holding the ``train.pt`` and ``test.pt`` datasets.
        pretrained_path: Path of the pretrained model state dict to start from.
        out_dir: Directory that receives intermediate and final LoRA checkpoints.
    """
    # GPU alternative: L.Fabric(accelerator="cuda", precision="bf16-true")
    fabric = L.Fabric(accelerator="cpu", devices=1, precision="bf16-true")
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)

    if fabric.global_rank == 0:
        os.makedirs(out_dir, exist_ok=True)

    print("loading dataset ", data_dir)
    train_data, val_data = load_datasets(data_dir=data_dir)
    print("train data: ", len(train_data))
    print("val data: ", len(val_data))

    config = LLaMAConfig.from_name("7B")
    config.block_size = max_seq_length

    print("loading pretrained model ", pretrained_path)
    # Load onto the CPU regardless of the device the checkpoint was saved from,
    # so a CUDA-saved checkpoint still loads on a machine without a GPU;
    # load_state_dict copies the values onto the model's own device afterwards.
    # NOTE: torch.load unpickles its input - only load checkpoints you trust.
    checkpoint = torch.load(pretrained_path, map_location="cpu")

    with fabric.init_module(), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
        model = LLaMA(config)
        # strict=False because missing keys due to LoRA weights not contained in checkpoint state
        model.load_state_dict(checkpoint, strict=False)

    # The values have been copied into the model; drop the duplicate state dict
    # so it does not occupy memory for the whole training run.
    del checkpoint

    mark_only_lora_as_trainable(model)

    # Pass the configured weight decay explicitly; otherwise AdamW falls back to
    # its own non-zero default and the module-level setting is silently ignored.
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    model, optimizer = fabric.setup(model, optimizer)

    print("start training")
    train(fabric, model, optimizer, train_data, val_data, out_dir)

    # Save the final LoRA checkpoint at the end of training
    print(f"Saving LoRA weights to {out_dir}")
    lora_weights = lora_state_dict(model)
    fabric.save(os.path.join(out_dir, "lit-llama-lora-finetuned.pth"), lora_weights)
def train(
    fabric: L.Fabric,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    train_data: np.ndarray,
    val_data: np.ndarray,
    out_dir: str,
) -> None:
    """The training loop.

    Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.

    Runs ``max_iters`` micro-batch iterations. Gradients are accumulated over
    ``gradient_accumulation_steps`` iterations before each optimizer step.
    Validation and LoRA checkpointing are triggered every ``eval_interval`` /
    ``save_interval`` optimizer steps.

    Args:
        fabric: The Fabric instance used for backward, printing, barriers and saving.
        model: The model to train.
        optimizer: The optimizer whose learning rate is warmed up and stepped here.
        train_data: Training samples, passed through to ``get_batch``.
        val_data: Validation samples, passed through to ``validate``.
        out_dir: Directory that receives the intermediate LoRA checkpoints.
    """
    step_count = 0
    print("max iters:", max_iters)

    for iter_num in range(max_iters):
        print("iter_num", iter_num)

        if step_count <= warmup_steps:
            # linear warmup
            lr = learning_rate * step_count / warmup_steps
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

        t0 = time.time()

        input_ids, targets = get_batch(fabric, train_data)
        logits = model(input_ids)
        print("calculate loss")
        loss = loss_fn(logits, targets)
        print("backward")
        # Scale the loss so the gradients accumulated over
        # `gradient_accumulation_steps` micro-batches form a mean rather than a sum.
        fabric.backward(loss / gradient_accumulation_steps)

        if (iter_num + 1) % gradient_accumulation_steps == 0:
            print("step optimizer")
            optimizer.step()
            optimizer.zero_grad()
            step_count += 1

            # Evaluation and checkpointing are tied to optimizer steps, so they
            # are only considered right after a step has been taken.
            if step_count % eval_interval == 0:
                val_loss = validate(fabric, model, val_data)
                fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
                fabric.barrier()

            if step_count % save_interval == 0:
                print(f"Saving LoRA weights to {out_dir}")
                # We are only saving the LoRA weights
                # TODO: Provide a function/script to merge the LoRA weights with pretrained weights
                checkpoint = lora_state_dict(model)
                fabric.save(os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"), checkpoint)

        dt = time.time() - t0
        if iter_num % log_interval == 0:
            # Report the unscaled per-micro-batch loss.
            fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
def generate_response(model, instruction):
    """Generate text from ``model`` for a single instruction.

    The instruction is wrapped into a prompt with ``generate_prompt`` (using an
    empty "input" field), encoded with the tokenizer loaded from
    ``checkpoints/lit-llama/tokenizer.model`` and passed to ``generate`` with a
    budget of 100 new tokens. The decoded result is returned as-is; it is not
    split on the "### Response:" marker.
    """
    tok = Tokenizer("checkpoints/lit-llama/tokenizer.model")
    prompt_text = generate_prompt({"instruction": instruction, "input": ""})
    prompt_ids = tok.encode(prompt_text, bos=True, eos=False, device=model.device)

    generated_ids = generate(
        model,
        idx=prompt_ids,
        max_seq_length=max_seq_length,
        max_new_tokens=100,
    )
    return tok.decode(generated_ids)
@torch.no_grad()
def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> float:
    """Estimate the validation loss and print one generated sample.

    The model is put in eval mode, ``eval_iters`` random batches are drawn from
    ``val_data`` and their losses averaged. A fixed example instruction is then
    run through ``generate_response`` and printed, after which the model is
    switched back to train mode.

    Args:
        fabric: The Fabric instance used for printing and batch placement.
        model: The model to evaluate.
        val_data: Validation samples, passed through to ``get_batch``.

    Returns:
        The mean loss over the sampled batches, as a Python float.
    """
    fabric.print("Validating ...")
    model.eval()

    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
        input_ids, targets = get_batch(fabric, val_data)
        logits = model(input_ids)
        loss = loss_fn(logits, targets)
        losses[k] = loss.item()
    out = losses.mean()

    # produce an example:
    instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
    output = generate_response(model, instruction)
    fabric.print(instruction)
    fabric.print(output)

    model.train()
    # .item() converts the 0-dim tensor to a plain float, matching the annotation.
    return out.item()
def loss_fn(logits, targets):
    """Next-token cross-entropy between ``logits`` and ``targets``.

    The last position of ``logits`` and the first position of ``targets`` are
    dropped so that the output at position n is scored against the token at
    position n+1. Target entries equal to -1 are ignored.
    """
    vocab_size = logits.size(-1)
    # Drop the final prediction (nothing follows it) and flatten to (N, vocab).
    flat_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    # Drop the first target (nothing predicts it) and flatten to (N,).
    flat_targets = targets[..., 1:].contiguous().view(-1)
    return torch.nn.functional.cross_entropy(flat_logits, flat_targets, ignore_index=-1)
def get_batch(fabric: L.Fabric, data: list):
    """Sample a random micro-batch and move it to the Fabric device.

    ``micro_batch_size`` samples are drawn uniformly at random (with
    replacement) from ``data``. Each sample's "input_ids" and "labels" are cast
    to int64 and right-padded to the length of the longest "input_ids" sequence
    in the batch: inputs with 0, labels with -1 (the index ``loss_fn`` ignores).

    Args:
        fabric: The Fabric instance that decides the target device.
        data: Indexable collection of samples with "input_ids" and "labels" tensors.

    Returns:
        A tuple ``(x, y)`` of stacked input and label tensors on the Fabric device.
    """
    ix = torch.randint(len(data), (micro_batch_size,))

    input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
    labels = [data[i]["labels"].type(torch.int64) for i in ix]

    max_len = max(len(s) for s in input_ids)

    def pad_right(seq, pad_id):
        # pad right based on the longest sequence
        n = max_len - len(seq)
        return torch.cat((seq, torch.full((n,), pad_id, dtype=seq.dtype)))

    x = torch.stack([pad_right(seq, pad_id=0) for seq in input_ids])
    y = torch.stack([pad_right(seq, pad_id=-1) for seq in labels])

    # Pinned host memory only helps host-to-GPU copies; skip it when training on
    # the CPU, where it is pointless and can fail without an accelerator.
    if fabric.device.type == "cuda":
        x, y = x.pin_memory(), y.pin_memory()
    x, y = fabric.to_device((x, y))
    return x, y
def load_datasets(data_dir):
    """Load the training and validation splits stored in ``data_dir``.

    Reads ``train.pt`` and ``test.pt`` with ``torch.load`` and returns them as
    a ``(train_data, val_data)`` tuple.
    """
    train_path = os.path.join(data_dir, "train.pt")
    val_path = os.path.join(data_dir, "test.pt")
    return torch.load(train_path), torch.load(val_path)
if __name__ == "__main__":
    from jsonargparse.cli import CLI

    # If you hit the error "Expected is_sm80 to be true, but got false",
    # enable the following line:
    # torch.backends.cuda.enable_flash_sdp(False)
    torch.set_float32_matmul_precision("high")

    # Expose main()'s keyword arguments as command-line options and run it.
    CLI(main)