LightGPT / instruction-tune.py
Andrew DalPino
Use SmolTalk dataset for instruction-tuning
111c2a3
import random
from argparse import ArgumentParser
import torch
from torch.utils.data import DataLoader
from torch.optim import Adafactor
from torch.amp import autocast
from torch.cuda import is_available as cuda_is_available, is_bf16_supported
from torch.utils.data import random_split
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.text import Perplexity
from model import LightGPT, LightGPTInstruct
from data import SmolTalk
import tiktoken
from tqdm import tqdm
def main():
    """Instruction-tune a pretrained LightGPT checkpoint on the SmolTalk dataset.

    Parses command-line options, loads the base model checkpoint from
    ``--base_model_path``, wraps it in ``LightGPTInstruct`` (configured by
    ``--rank``, ``--alpha`` and ``--dropout``) and trains it with Adafactor
    using gradient accumulation over ``--gradient_accumulation_steps``
    micro-batches. Every ``--eval_interval`` epochs the held-out 10% split is
    scored with perplexity; every ``--checkpoint_interval`` epochs the model
    and optimizer state are written to ``--checkpoint_path``. With
    ``--resume`` that checkpoint is reloaded before training continues.

    Raises:
        RuntimeError: if a CUDA device is requested but CUDA is unavailable.
    """
    parser = ArgumentParser(description="Instruction-tune the GPT.")

    parser.add_argument(
        "--base_model_path", default="./checkpoints/checkpoint.pt", type=str
    )
    parser.add_argument("--max_tokens_per_sample", default=1048, type=int)
    # NOTE(review): --mask_input is parsed but never read below, so it has no
    # effect on training as written. TODO: wire it into SmolTalk or remove it.
    parser.add_argument("--mask_input", action="store_true")
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--gradient_accumulation_steps", default=64, type=int)
    parser.add_argument("--learning_rate", default=5e-4, type=float)
    parser.add_argument("--rms_decay", default=-0.8, type=float)
    parser.add_argument("--optimizer_low_memory", action="store_true")
    parser.add_argument("--num_epochs", default=4, type=int)
    parser.add_argument("--rank", default=8, type=int)
    parser.add_argument("--alpha", default=1.0, type=float)
    parser.add_argument("--dropout", default=0.05, type=float)
    parser.add_argument("--activation_checkpointing", action="store_true")
    parser.add_argument("--eval_interval", default=1, type=int)
    parser.add_argument("--checkpoint_interval", default=1, type=int)
    parser.add_argument(
        "--checkpoint_path", default="./checkpoints/lora_instruction.pt", type=str
    )
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--run_dir_path", default="./runs/instruction-tune", type=str)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--seed", default=None, type=int)

    args = parser.parse_args()

    if "cuda" in args.device and not cuda_is_available():
        raise RuntimeError("Cuda is not available.")

    torch.set_float32_matmul_precision("high")

    dtype = (
        torch.bfloat16
        if "cuda" in args.device and is_bf16_supported()
        else torch.float32
    )

    # autocast expects a bare device type ("cuda", "cpu", ...), not an indexed
    # device string like "cuda:0", so derive the type from the parsed device.
    device_type = torch.device(args.device).type

    # Only enable mixed precision when a reduced-precision dtype was selected;
    # autocasting "to" float32 is pointless and warns/errors on some backends.
    amp_context = autocast(
        device_type=device_type,
        dtype=dtype,
        enabled=dtype != torch.float32,
    )

    # Compare against None so that an explicit --seed 0 is honored.
    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    logger = SummaryWriter(args.run_dir_path)

    checkpoint = torch.load(
        args.base_model_path, map_location=args.device, weights_only=True
    )

    model_args = checkpoint["model_args"]

    tokenizer = tiktoken.get_encoding(checkpoint["token_encoding"])

    dataset = SmolTalk(
        tokenizer,
        subset="all",
        max_tokens_per_sample=args.max_tokens_per_sample,
    )

    # random_split draws from torch's global RNG, which is only seeded when
    # --seed is given; without it a resumed run gets a different split and may
    # evaluate on examples it has already trained on.
    training, testing = random_split(dataset, (0.9, 0.1))

    train_loader = DataLoader(
        training,
        collate_fn=dataset.collate,
        batch_size=args.batch_size,
        pin_memory="cpu" not in args.device,
        shuffle=True,
    )
    test_loader = DataLoader(
        testing,
        collate_fn=dataset.collate,
        batch_size=args.batch_size,
        pin_memory="cpu" not in args.device,
        shuffle=False,
    )

    model = LightGPT(**model_args)

    if args.activation_checkpointing:
        model.enable_activation_checkpointing()

    # The base model is compiled *before* its weights are loaded; presumably
    # the base checkpoint was saved from a compiled model so its state-dict
    # keys match the compiled wrapper. TODO confirm against the pretraining script.
    model = torch.compile(model)

    model.load_state_dict(checkpoint["model"])

    print("Model checkpoint loaded")

    lora_args = {
        "rank": args.rank,
        "alpha": args.alpha,
        "dropout": args.dropout,
    }

    model = LightGPTInstruct(model, **lora_args).to(args.device)

    print("Compiling model")
    model.compile()

    optimizer = Adafactor(
        model.parameters(),
        lr=args.learning_rate,
        beta2_decay=args.rms_decay,
        foreach=not args.optimizer_low_memory,
    )

    starting_epoch = 1

    if args.resume:
        checkpoint = torch.load(
            args.checkpoint_path, map_location=args.device, weights_only=True
        )

        # strict=False: the saved entry is loaded non-strictly, so keys missing
        # from it are tolerated.
        model.load_state_dict(checkpoint["lora"], strict=False)
        optimizer.load_state_dict(checkpoint["optimizer"])

        starting_epoch += checkpoint["epoch"]

        print("Previous checkpoint resumed successfully")

    model.train()

    print(f"Model has {model.num_trainable_params:,} trainable parameters")

    perplexity_metric = Perplexity(ignore_index=dataset.PADDING_INDEX).to(args.device)

    print("Instruction-tuning ...")

    for epoch in range(starting_epoch, args.num_epochs + 1):
        total_cross_entropy, total_batches = 0.0, 0
        step = 0

        for step, (x, y) in enumerate(
            tqdm(train_loader, desc=f"Epoch {epoch}", leave=False), start=1
        ):
            x = x.to(args.device, non_blocking=True)
            y = y.to(args.device, non_blocking=True)

            with amp_context:
                _, loss = model(x, y)

                # Scale so the accumulated gradient averages over micro-batches.
                scaled_loss = loss / args.gradient_accumulation_steps

            # Backward runs outside autocast, as recommended by PyTorch AMP.
            scaled_loss.backward()

            total_cross_entropy += loss.item()

            if step % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            total_batches += 1

        # Apply gradients from a trailing partial accumulation window so they
        # neither leak into the next epoch's first update nor get dropped when
        # a checkpoint is saved and later resumed.
        if step % args.gradient_accumulation_steps != 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

        average_cross_entropy = total_cross_entropy / total_batches

        logger.add_scalar("cross entropy", average_cross_entropy, epoch)

        print(
            f"Epoch {epoch}: Cross Entropy: {average_cross_entropy:.5f}",
        )

        if epoch % args.eval_interval == 0:
            model.eval()

            with torch.no_grad():
                for x, y in tqdm(test_loader, desc="Testing", leave=False):
                    x = x.to(args.device, non_blocking=True)
                    y = y.to(args.device, non_blocking=True)

                    y_pred, _ = model(x)

                    perplexity_metric.update(y_pred, y)

            perplexity = perplexity_metric.compute()

            logger.add_scalar("perplexity", perplexity, epoch)

            print(f"Perplexity: {perplexity:.3f}")

            perplexity_metric.reset()

            model.train()

        if epoch % args.checkpoint_interval == 0:
            checkpoint = {
                "epoch": epoch,
                "lora_args": lora_args,
                "lora": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }

            torch.save(checkpoint, args.checkpoint_path)

            print("Checkpoint saved")

    # Flush any buffered TensorBoard events before the process exits.
    logger.close()

    print("Done!")


if __name__ == "__main__":
    main()