|
import random |
|
|
|
from argparse import ArgumentParser |
|
|
|
import torch |
|
|
|
from torch.utils.data import DataLoader |
|
from torch.optim import Adafactor |
|
from torch.amp import autocast |
|
from torch.cuda import is_available as cuda_is_available, is_bf16_supported |
|
from torch.utils.data import random_split |
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
from torchmetrics.text import Perplexity |
|
|
|
from model import LightGPT, LightGPTInstruct |
|
from data import SmolTalk |
|
|
|
import tiktoken |
|
|
|
from tqdm import tqdm |
|
|
|
|
|
def main():
    """Instruction-tune a pretrained LightGPT checkpoint with LoRA on SmolTalk.

    All settings come from the command line. The base checkpoint at
    ``--base_model_path`` provides the model constructor arguments, the
    tiktoken encoding name and the pretrained weights. The model is wrapped in
    ``LightGPTInstruct`` with the LoRA settings and trained with Adafactor
    using gradient accumulation.

    Per epoch, the mean training cross entropy is logged to TensorBoard. Every
    ``--eval_interval`` epochs, test-set perplexity is computed and logged.
    Every ``--checkpoint_interval`` epochs, a checkpoint holding the epoch, the
    LoRA settings, ``model.state_dict()`` and the optimizer state is written to
    ``--checkpoint_path``.

    Raises:
        RuntimeError: If a CUDA device is requested but CUDA is not available.
    """
    parser = ArgumentParser(description="Instruction-tune the GPT.")

    parser.add_argument(
        "--base_model_path", default="./checkpoints/checkpoint.pt", type=str
    )
    parser.add_argument("--max_tokens_per_sample", default=1048, type=int)
    parser.add_argument("--mask_input", action="store_true")
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--gradient_accumulation_steps", default=64, type=int)
    parser.add_argument("--learning_rate", default=5e-4, type=float)
    parser.add_argument("--rms_decay", default=-0.8, type=float)
    parser.add_argument("--optimizer_low_memory", action="store_true")
    parser.add_argument("--num_epochs", default=4, type=int)
    parser.add_argument("--rank", default=8, type=int)
    parser.add_argument("--alpha", default=1.0, type=float)
    parser.add_argument("--dropout", default=0.05, type=float)
    parser.add_argument("--activation_checkpointing", action="store_true")
    parser.add_argument("--eval_interval", default=1, type=int)
    parser.add_argument("--checkpoint_interval", default=1, type=int)
    parser.add_argument(
        "--checkpoint_path", default="./checkpoints/lora_instruction.pt", type=str
    )
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--run_dir_path", default="./runs/instruction-tune", type=str)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--seed", default=None, type=int)

    args = parser.parse_args()

    if "cuda" in args.device and not cuda_is_available():
        raise RuntimeError("Cuda is not available.")

    torch.set_float32_matmul_precision("high")

    dtype = (
        torch.bfloat16
        if "cuda" in args.device and is_bf16_supported()
        else torch.float32
    )

    # autocast expects a bare device type ("cuda", "cpu"), not an indexed
    # device string such as "cuda:1". With a float32 target there is nothing to
    # down-cast, so autocast is disabled explicitly instead of emitting a
    # "dtype not supported, disabling autocast" warning.
    amp_context = autocast(
        device_type=torch.device(args.device).type,
        dtype=dtype,
        enabled=dtype != torch.float32,
    )

    # Compare against None so that an explicit --seed 0 is honoured (0 is falsy).
    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    logger = SummaryWriter(args.run_dir_path)

    checkpoint = torch.load(
        args.base_model_path, map_location=args.device, weights_only=True
    )

    model_args = checkpoint["model_args"]

    tokenizer = tiktoken.get_encoding(checkpoint["token_encoding"])

    dataset = SmolTalk(
        tokenizer,
        subset="all",
        max_tokens_per_sample=args.max_tokens_per_sample,
    )

    # NOTE(review): random_split draws from torch's default generator, so
    # without --seed the train/test split differs between the original run and
    # a --resume run. The resumed run may then evaluate on samples it trained on.
    training, testing = random_split(dataset, (0.9, 0.1))

    train_loader = DataLoader(
        training,
        collate_fn=dataset.collate,
        batch_size=args.batch_size,
        pin_memory="cpu" not in args.device,
        shuffle=True,
    )
    test_loader = DataLoader(
        testing,
        collate_fn=dataset.collate,
        batch_size=args.batch_size,
        pin_memory="cpu" not in args.device,
        shuffle=False,
    )

    model = LightGPT(**model_args)

    if args.activation_checkpointing:
        model.enable_activation_checkpointing()

    # NOTE(review): the base model is wrapped with torch.compile *before*
    # loading weights. Presumably this makes the state-dict keys match a
    # checkpoint saved from a compiled model ("_orig_mod." prefix). Confirm.
    model = torch.compile(model)

    model.load_state_dict(checkpoint["model"])

    # The weights are now copied into the model. Drop the checkpoint, whose
    # tensors were mapped onto args.device, so they do not occupy that device
    # for the rest of training.
    del checkpoint

    print("Model checkpoint loaded")

    lora_args = {
        "rank": args.rank,
        "alpha": args.alpha,
        "dropout": args.dropout,
    }

    model = LightGPTInstruct(model, **lora_args).to(args.device)

    print("Compiling model")

    model.compile()

    optimizer = Adafactor(
        model.parameters(),
        lr=args.learning_rate,
        beta2_decay=args.rms_decay,
        foreach=not args.optimizer_low_memory,
    )

    starting_epoch = 1

    if args.resume:
        checkpoint = torch.load(
            args.checkpoint_path, map_location=args.device, weights_only=True
        )

        # strict=False: missing or unexpected keys are tolerated rather than
        # raising.
        model.load_state_dict(checkpoint["lora"], strict=False)
        optimizer.load_state_dict(checkpoint["optimizer"])

        starting_epoch += checkpoint["epoch"]

        # Release the device-resident copy of the resumed state.
        del checkpoint

        print("Previous checkpoint resumed successfully")

    model.train()

    print(f"Model has {model.num_trainable_params:,} trainable parameters")

    perplexity_metric = Perplexity(ignore_index=dataset.PADDING_INDEX).to(args.device)

    print("Instruction-tuning ...")

    try:
        for epoch in range(starting_epoch, args.num_epochs + 1):
            total_cross_entropy, total_batches = 0.0, 0

            for step, (x, y) in enumerate(
                tqdm(train_loader, desc=f"Epoch {epoch}", leave=False), start=1
            ):
                x = x.to(args.device, non_blocking=True)
                y = y.to(args.device, non_blocking=True)

                with amp_context:
                    _, loss = model(x, y)

                # Scale so the gradient summed over an accumulation window
                # equals the mean over that window.
                scaled_loss = loss / args.gradient_accumulation_steps

                scaled_loss.backward()

                total_cross_entropy += loss.item()

                if step % args.gradient_accumulation_steps == 0:
                    optimizer.step()

                    optimizer.zero_grad(set_to_none=True)

                total_batches += 1

            # Apply any partial accumulation window left at the end of the epoch.
            # Otherwise its gradients leak into the next epoch's first update, or
            # are lost when training stops (they are not part of the checkpoint).
            if total_batches % args.gradient_accumulation_steps != 0:
                optimizer.step()

                optimizer.zero_grad(set_to_none=True)

            # max(..., 1) guards against an empty training loader.
            average_cross_entropy = total_cross_entropy / max(total_batches, 1)

            logger.add_scalar("cross entropy", average_cross_entropy, epoch)

            print(
                f"Epoch {epoch}: Cross Entropy: {average_cross_entropy:.5f}",
            )

            if epoch % args.eval_interval == 0:
                model.eval()

                for x, y in tqdm(test_loader, desc="Testing", leave=False):
                    x = x.to(args.device, non_blocking=True)
                    y = y.to(args.device, non_blocking=True)

                    with torch.no_grad():
                        y_pred, _ = model(x)

                    perplexity_metric.update(y_pred, y)

                perplexity = perplexity_metric.compute()

                logger.add_scalar("perplexity", perplexity, epoch)

                print(f"Perplexity: {perplexity:.3f}")

                perplexity_metric.reset()

                model.train()

            if epoch % args.checkpoint_interval == 0:
                # NOTE(review): this stores model.state_dict() under "lora".
                # Whether that holds only the LoRA tensors or the full model
                # depends on LightGPTInstruct. Confirm.
                checkpoint = {
                    "epoch": epoch,
                    "lora_args": lora_args,
                    "lora": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                }

                torch.save(checkpoint, args.checkpoint_path)

                print("Checkpoint saved")
    finally:
        # Flush pending TensorBoard events and release the event file, even if
        # training is interrupted.
        logger.close()

    print("Done!")
|
|
|
|
|
if __name__ == "__main__":
    # Run instruction-tuning only when this file is executed as a script,
    # not when it is imported.
    main()
|
|