# Source snapshot 7d52396 (file size: 7,266 bytes)
"""
Instruction-tuning with LoRA on the Alpaca dataset.
Note: If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line
`torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101).
"""
import sys
from pathlib import Path
import os
import time
import lightning as L
import numpy as np
import torch
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
from generate import generate
from lit_llama.lora import mark_only_lora_as_trainable, lora, lora_state_dict
from lit_llama.model import LLaMA, LLaMAConfig
from lit_llama.tokenizer import Tokenizer
from scripts.prepare_alpaca import generate_prompt
# --- Run configuration -------------------------------------------------------
# When True, generate_response() wraps the instruction in the prompt template
# produced by generate_prompt() before sampling.
instruction_tuning = True

# Cadences. eval_interval and save_interval are compared against the optimizer
# step counter in train(); log_interval is compared against the micro-batch
# iteration number.
log_interval = 1
eval_interval = 100
save_interval = 100
eval_iters = 100  # number of batches averaged in one validate() call

# --- Hyperparameters ---------------------------------------------------------
learning_rate = 3e-4
weight_decay = 0.0
warmup_iters = 100  # optimizer steps of linear learning-rate warmup

# Batching: one optimizer step accumulates gradients over
# batch_size // micro_batch_size micro-batches.
micro_batch_size = 4
batch_size = 128
gradient_accumulation_iters = batch_size // micro_batch_size
assert gradient_accumulation_iters > 0

# Total number of micro-batch iterations.
# NOTE(review): 50000 * 3 presumably means dataset size x epochs - confirm.
max_iters = (50000 * 3) // micro_batch_size
max_seq_length = 256  # see scripts/prepare_alpaca.py

# LoRA adapter settings, forwarded to lit_llama.lora.lora() in main().
lora_r = 8
lora_alpha = 16
lora_dropout = 0.05
def main(
    data_dir: str = "data/alpaca",
    pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth",
    tokenizer_path: str = "checkpoints/lit-llama/tokenizer.model",
    out_dir: str = "out/lora/alpaca",
):
    """Fine-tune the 7B LLaMA model with LoRA and save the adapter weights.

    Args:
        data_dir: Directory containing ``train.pt`` and ``test.pt``.
        pretrained_path: Checkpoint file holding the pretrained model weights.
        tokenizer_path: Tokenizer model file; used for the sample generation
            printed during validation.
        out_dir: Directory receiving the intermediate and final LoRA
            checkpoints. Created if it does not exist.
    """
    fabric = L.Fabric(accelerator="cuda", devices=1, precision="bf16-true")
    fabric.launch()
    # Offset the seed by rank so each process draws different batches.
    fabric.seed_everything(1337 + fabric.global_rank)

    if fabric.global_rank == 0:
        os.makedirs(out_dir, exist_ok=True)

    train_data, val_data = load_datasets(data_dir=data_dir)

    config = LLaMAConfig.from_name("7B")
    config.block_size = max_seq_length

    pretrained_state = torch.load(pretrained_path)

    with fabric.init_module(), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
        model = LLaMA(config)
        # strict=False because missing keys due to LoRA weights not contained in checkpoint state
        model.load_state_dict(pretrained_state, strict=False)

    # The weights have been copied into the model; drop the loaded state dict
    # so it does not stay in host memory for the whole training run.
    del pretrained_state

    mark_only_lora_as_trainable(model)

    # Pass the configured weight decay explicitly; otherwise AdamW silently
    # falls back to its own default and the module-level setting is ignored.
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    model, optimizer = fabric.setup(model, optimizer)
    train(fabric, model, optimizer, train_data, val_data, tokenizer_path, out_dir)

    # Save the final LoRA checkpoint at the end of training
    checkpoint = lora_state_dict(model)
    fabric.save(os.path.join(out_dir, "lit-llama-lora-finetuned.pth"), checkpoint)
def train(
    fabric: L.Fabric,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    train_data: np.ndarray,
    val_data: np.ndarray,
    tokenizer_path: str,
    out_dir: str,
) -> None:
    """Run the fine-tuning loop for ``max_iters`` micro-batch iterations.

    Gradients are accumulated over ``gradient_accumulation_iters`` micro-batches
    before each optimizer step. Validation and LoRA checkpointing are keyed on
    the optimizer step count, logging on the micro-batch iteration number.

    Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
    """
    step_count = 0

    for iter_num in range(max_iters):
        # Linear warmup, driven by completed optimizer steps rather than
        # by micro-batch iterations.
        if step_count <= warmup_iters:
            warm_lr = learning_rate * step_count / warmup_iters
            for group in optimizer.param_groups:
                group["lr"] = warm_lr

        t0 = time.time()

        input_ids, targets = get_batch(fabric, train_data)

        # True on every micro-batch except the one that closes an accumulation
        # window; gradient sync is skipped while this holds.
        is_accumulating = (iter_num + 1) % gradient_accumulation_iters != 0
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            logits = model(input_ids)
            loss = loss_fn(logits, targets)
            # Scale so the accumulated gradient is the mean over the window.
            fabric.backward(loss / gradient_accumulation_iters)

        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()
            step_count += 1

            if step_count % eval_interval == 0:
                val_loss = validate(fabric, model, val_data, tokenizer_path)
                fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
                fabric.barrier()

            if step_count % save_interval == 0:
                print(f"Saving LoRA weights to {out_dir}")
                # Only the LoRA weights are written, not the full model.
                # TODO: Provide a function/script to merge the LoRA weights with pretrained weights
                ckpt_path = os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth")
                fabric.save(ckpt_path, lora_state_dict(model))

        # Measured after any validation/checkpointing done in this iteration.
        dt = time.time() - t0
        if iter_num % log_interval == 0:
            fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
def generate_response(model, instruction, tokenizer_path):
    """Generate text from ``model`` for a single instruction.

    The instruction is wrapped with ``generate_prompt`` when the module-level
    ``instruction_tuning`` flag is set; otherwise it is used verbatim as the
    prompt. At most 100 new tokens are requested.

    Returns the decoded output as-is; it is not split at the
    "### Response:" marker.
    """
    tokenizer = Tokenizer(tokenizer_path)

    if instruction_tuning:
        prompt = generate_prompt({"instruction": instruction, "input": ""})
    else:
        prompt = instruction

    prompt_ids = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
    generated_ids = generate(
        model,
        idx=prompt_ids,
        max_seq_length=max_seq_length,
        max_new_tokens=100,
    )
    return tokenizer.decode(generated_ids)
@torch.no_grad()
def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray, tokenizer_path: str) -> float:
    """Estimate the validation loss and print one sample generation.

    The model is switched to eval mode for the duration of the call and put
    back into train mode before returning. Gradients are disabled by the
    ``torch.no_grad`` decorator.

    Args:
        fabric: Fabric instance used for printing and for batch placement.
        model: Model to evaluate.
        val_data: Validation samples, passed through to ``get_batch``.
        tokenizer_path: Tokenizer model file used for the sample generation.

    Returns:
        The mean loss over ``eval_iters`` randomly drawn batches, as a Python
        float.
    """
    fabric.print("Validating ...")
    model.eval()

    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
        input_ids, targets = get_batch(fabric, val_data)
        logits = model(input_ids)
        loss = loss_fn(logits, targets)
        losses[k] = loss.item()
    out = losses.mean()

    # produce an example:
    instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
    output = generate_response(model, instruction, tokenizer_path)
    fabric.print(instruction)
    fabric.print(output)

    model.train()
    # .item() yields a plain float, which is what the caller formats with :.4f.
    return out.item()
def loss_fn(logits, targets):
    """Return the next-token cross-entropy loss.

    The last position of ``logits`` and the first position of ``targets`` are
    dropped so that the output at position n is scored against token n+1.
    Target entries equal to -1 are ignored.
    """
    vocab_size = logits.size(-1)
    shifted_logits = logits[..., :-1, :].reshape(-1, vocab_size)
    shifted_targets = targets[..., 1:].reshape(-1)
    return torch.nn.functional.cross_entropy(shifted_logits, shifted_targets, ignore_index=-1)
def get_batch(fabric: L.Fabric, data: list):
    """Draw a random micro-batch and move it to the Fabric device.

    ``micro_batch_size`` samples are picked uniformly at random (with
    replacement) from ``data``. Each sample's "input_ids" and "labels" are
    converted to int64 and right-padded to the longest "input_ids" sequence in
    the batch: inputs with 0, labels with -1.

    Returns:
        A tuple ``(x, y)`` of stacked input and label tensors.
    """
    ix = torch.randint(len(data), (micro_batch_size,))

    samples = [data[i] for i in ix]
    input_ids = [sample["input_ids"].type(torch.int64) for sample in samples]
    labels = [sample["labels"].type(torch.int64) for sample in samples]

    longest = max(len(seq) for seq in input_ids)

    def pad_right(seq, pad_id):
        # Append pad_id until the sequence reaches the longest length.
        filler = torch.full((longest - len(seq),), pad_id, dtype=seq.dtype)
        return torch.cat((seq, filler))

    x = torch.stack([pad_right(seq, pad_id=0) for seq in input_ids])
    y = torch.stack([pad_right(seq, pad_id=-1) for seq in labels])

    # Pin host memory before handing the pair to Fabric for device placement.
    x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
    return x, y
def load_datasets(data_dir):
    """Load the training and validation splits stored in ``data_dir``.

    Reads ``train.pt`` and ``test.pt`` with ``torch.load`` and returns them as
    ``(train_data, val_data)``.
    """
    train_path, val_path = (os.path.join(data_dir, name) for name in ("train.pt", "test.pt"))
    return torch.load(train_path), torch.load(val_path)
if __name__ == "__main__":
    from jsonargparse.cli import CLI

    # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
    # torch.backends.cuda.enable_flash_sdp(False)
    torch.set_float32_matmul_precision("high")

    # Hand main over to jsonargparse's CLI helper.
    CLI(main)
|