# Training script: Paella image generator conditioned on ByT5-XL text embeddings.
import os

import numpy as np
import torch
from torch import nn, optim
from tqdm import tqdm
from warmup_scheduler import GradualWarmupScheduler

from modules import Paella
from utils import get_dataloader, load_conditional_models
# --- Training hyper-parameters -----------------------------------------------
steps = 100_000                # total number of optimizer updates
warmup_updates = 10000         # warmup length for GradualWarmupScheduler
batch_size = 16
checkpoint_frequency = 2000    # save a checkpoint every N updates
lr = 1e-4
train_device = "cuda"

# --- Paths and model identifiers ---------------------------------------------
dataset_path = ""              # TODO: set path to the training dataset
byt5_model_name = "google/byt5-xl"
vqmodel_path = ""              # TODO: set path to the VQGAN weights
run_name = "Paella-ByT5-XL-v1"
output_path = "output"
checkpoint_path = f"{run_name}.pt"  # checkpoint lives next to the script
def train():
    """Train Paella on VQGAN latents, conditioned on ByT5 text embeddings.

    Resumes from ``checkpoint_path`` when that file exists and saves a
    checkpoint every ``checkpoint_frequency`` optimizer updates.
    All hyper-parameters are read from module-level constants.
    """
    os.makedirs(output_path, exist_ok=True)
    device = torch.device(train_device)
    dataloader = get_dataloader(dataset_path, batch_size=batch_size)
    checkpoint = torch.load(checkpoint_path, map_location=device) if os.path.exists(checkpoint_path) else None

    model = Paella(byt5_embd=2560).to(device)
    vqgan, (byt5_tokenizer, byt5) = load_conditional_models(byt5_model_name, vqmodel_path, device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_updates)
    # reduction='none' yields a per-element loss tensor; it is reduced to a
    # scalar with .mean() below before calling backward().
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1, reduction='none')

    start_iter = 1
    if checkpoint is not None:
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.last_epoch = checkpoint['scheduler_last_step']
        start_iter = checkpoint['scheduler_last_step'] + 1
        del checkpoint  # free the state dicts before training allocations

    model.train()
    batch_iter = iter(dataloader)
    pbar = tqdm(range(start_iter, steps + 1))
    for i in pbar:
        # Cycle the dataloader indefinitely: training length is governed by
        # `steps`, not by the number of batches in one epoch.
        try:
            images, captions = next(batch_iter)
        except StopIteration:
            batch_iter = iter(dataloader)
            images, captions = next(batch_iter)
        images = images.to(device)

        with torch.no_grad():
            # Drop the caption 5% of the time (classifier-free guidance style
            # unconditional training).
            if np.random.rand() < 0.05:
                byt5_captions = [''] * len(captions)
            else:
                byt5_captions = captions
            byt5_tokens = byt5_tokenizer(byt5_captions, padding="longest", return_tensors="pt", max_length=768, truncation=True).input_ids.to(device)
            byt_embeddings = byt5(input_ids=byt5_tokens).last_hidden_state
            # Noise timestep in (0, 1]; conditioning targets are VQGAN indices.
            t = (1 - torch.rand(images.size(0), device=device))
            latents = vqgan.encode(images)[2]
            noised_latents, _ = model.add_noise(latents, t)

        pred = model(noised_latents, t, byt_embeddings)
        # BUGFIX: reduce the per-element loss to a scalar — with
        # reduction='none', calling backward() on the raw tensor raises.
        loss = criterion(pred, latents).mean()

        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # BUGFIX: optimizer.step() was missing — gradients were computed and
        # clipped but never applied, so the model weights never updated.
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        acc = (pred.argmax(1) == latents).float().mean()
        pbar.set_postfix({'bs': images.size(0), 'loss': loss.item(), 'acc': acc.item(), 'grad_norm': grad_norm.item(), 'lr': optimizer.param_groups[0]['lr'], 'total_steps': scheduler.last_epoch})

        if i % checkpoint_frequency == 0:
            torch.save({'state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_last_step': scheduler.last_epoch, 'iter': i}, checkpoint_path)
# Script entry point: run a full training session when executed directly.
if __name__ == "__main__":
    train()