# Source: paella/Paella/src/train.py on Hugging Face, uploaded by pcuenq (HF staff)
# in commit cab8a49 ("Add copy of github repo"); original file size 2.98 kB.
import os
import torch
import numpy as np
from tqdm import tqdm
from modules import Paella
from torch import nn, optim
from warmup_scheduler import GradualWarmupScheduler
from utils import get_dataloader, load_conditional_models
# --- Run identity and output locations ---------------------------------------
run_name = "Paella-ByT5-XL-v1"
# NOTE(review): train() creates this directory, but the checkpoint is written
# to checkpoint_path (relative to the working directory), not inside it.
output_path = "output"
checkpoint_path = run_name + ".pt"

# --- Data and conditioning models --------------------------------------------
# The two paths are left empty here; presumably they must be set before a run.
dataset_path = ""
vqmodel_path = ""
byt5_model_name = "google/byt5-xl"

# --- Optimisation schedule ---------------------------------------------------
steps = 100_000              # last training-step index used by train()
warmup_updates = 10_000      # passed to the warmup scheduler as total_epoch
batch_size = 16
lr = 0.0001
checkpoint_frequency = 2000  # save a checkpoint every this many steps
train_device = "cuda"
def _cycle(loader):
    """Yield items from *loader* indefinitely, starting a fresh pass after each one.

    Unlike ``itertools.cycle`` this keeps no copy of already-seen batches; it
    simply iterates *loader* again. If a whole pass yields nothing, the
    generator stops instead of spinning forever.
    """
    while True:
        empty_pass = True
        for batch in loader:
            empty_pass = False
            yield batch
        if empty_pass:
            return


def _save_checkpoint(model, optimizer, scheduler, step):
    """Write model, optimizer and scheduler state for *step* to ``checkpoint_path``."""
    torch.save(
        {
            'state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_last_step': scheduler.last_epoch,
            'iter': step,
        },
        checkpoint_path,
    )


def train():
    """Train a Paella model on (images, captions) batches conditioned on ByT5.

    Configuration comes from the module-level constants (dataset_path,
    batch_size, lr, steps, checkpoint_frequency, ...). If ``checkpoint_path``
    exists, model/optimizer/scheduler state is restored from it and training
    resumes at the following step. Steps ``start_iter .. steps`` are then run,
    re-iterating the dataloader as often as needed.

    Each step:
    - Computes ByT5 caption embeddings, VQGAN latents and noised latents under
      ``torch.no_grad()``.
    - Predicts ``latents`` from the noised version and applies a
      label-smoothed cross-entropy loss, so ``latents`` must be class indices.
    - Performs one clipped AdamW update and one warmup-scheduler step.

    A checkpoint is written every ``checkpoint_frequency`` steps and once more
    after the final step.
    """
    os.makedirs(output_path, exist_ok=True)
    device = torch.device(train_device)
    dataloader = get_dataloader(dataset_path, batch_size=batch_size)
    checkpoint = torch.load(checkpoint_path, map_location=device) if os.path.exists(checkpoint_path) else None

    # NOTE(review): 2560 presumably has to equal the hidden size of the
    # byt5_model_name encoder output — confirm before changing the model.
    model = Paella(byt5_embd=2560).to(device)
    vqgan, (byt5_tokenizer, byt5) = load_conditional_models(byt5_model_name, vqmodel_path, device)

    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_updates)
    # reduction='none' yields one loss value per predicted token; it is
    # averaged to a scalar before backward() below.
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1, reduction='none')

    start_iter = 1
    if checkpoint is not None:
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.last_epoch = checkpoint['scheduler_last_step']
        # scheduler.step() is called exactly once per training step below, so
        # its counter identifies the last completed step.
        start_iter = checkpoint['scheduler_last_step'] + 1
        del checkpoint  # release the loaded tensors before training starts

    pbar = tqdm(range(start_iter, steps + 1))
    model.train()
    last_step = start_iter - 1
    # zip() advances the progress bar together with the data and ends after
    # `steps`; `i` is the global (resume-aware) step number.
    for i, (images, captions) in zip(pbar, _cycle(dataloader)):
        images = images.to(device)
        with torch.no_grad():
            # For ~5% of batches every caption is replaced by an empty string,
            # presumably to learn an unconditional mode for classifier-free
            # guidance (TODO confirm).
            if np.random.rand() < 0.05:
                byt5_captions = [''] * len(captions)
            else:
                byt5_captions = captions
            byt5_tokens = byt5_tokenizer(
                byt5_captions, padding="longest", return_tensors="pt", max_length=768, truncation=True
            ).input_ids.to(device)
            byt_embeddings = byt5(input_ids=byt5_tokens).last_hidden_state
            t = (1 - torch.rand(images.size(0), device=device))  # one level per sample, in (0, 1]
            latents = vqgan.encode(images)[2]
            noised_latents, _ = model.add_noise(latents, t)

        # The forward pass must be outside no_grad so the loss can backpropagate.
        pred = model(noised_latents, t, byt_embeddings)
        loss = criterion(pred, latents).mean()
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        acc = (pred.argmax(1) == latents).float().mean()
        pbar.set_postfix({'bs': images.size(0), 'loss': loss.item(), 'acc': acc.item(), 'grad_norm': grad_norm.item(), 'lr': optimizer.param_groups[0]['lr'], 'total_steps': scheduler.last_epoch})
        last_step = i
        if i % checkpoint_frequency == 0:
            _save_checkpoint(model, optimizer, scheduler, i)

    pbar.close()
    # Persist the final state if the last step was not already checkpointed.
    if last_step >= start_iter and last_step % checkpoint_frequency != 0:
        _save_checkpoint(model, optimizer, scheduler, last_step)
if __name__ == '__main__':
    # Start training only when run as a script, not when this module is imported.
    train()