# magma/magma/train_loop.py
import torch
from tqdm import tqdm
from .utils import reduce_losses, to_cuda_half
from torchvision.utils import make_grid
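
# `reduce_losses` and `to_cuda_half` come from magma/utils.py and are not shown
# here. As a reading aid only, `to_cuda_half` is assumed to behave roughly like
# the sketch below (move each tensor to the GPU, casting floating-point tensors
# to fp16); the real helper may differ:
def _to_cuda_half_sketch(*tensors):
    return tuple(
        t.cuda().half() if t.is_floating_point() else t.cuda() for t in tensors
    )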

def train_step(config, train_loader, model_engine):
    """Run one training step of the captioning objective.

    Draws `config.gradient_accumulation_steps` micro-batches; with a
    DeepSpeed-style engine, each `model_engine.step()` is a micro-step and
    the optimizer only updates at the accumulation boundary.
    """
    losses = []

    for _ in range(config.gradient_accumulation_steps):
        images, captions = next(train_loader)
        images, captions = images.half().cuda(), captions.cuda()
        if config.run_blind:
            # Ablation: zero out the images so the model trains on text alone.
            images = torch.zeros_like(images)
        outputs = model_engine(images, captions)
        loss = outputs.loss
        losses.append(loss)
        model_engine.backward(loss)
        model_engine.step()

    # Average micro-batch losses, then reduce across data-parallel workers.
    return reduce_losses(torch.mean(torch.stack(losses))).item()
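
# Illustrative only; not part of the original module. The loaders above are
# consumed with `next(...)`, so they are assumed to be inexhaustible iterators.
# A minimal sketch of an outer loop driving `train_step`; `config.train_steps`
# and `config.log_every` are assumed fields, not ones this file defines:
def _example_train_loop(config, dataloader, model_engine):
    def _cycle(loader):
        # Re-iterate the DataLoader forever so `next()` never raises StopIteration.
        while True:
            yield from loader

    train_loader = _cycle(dataloader)
    for step in range(config.train_steps):
        loss = train_step(config, train_loader, model_engine)
        if step % config.log_every == 0:
            print(f"step {step}: train loss {loss:.4f}")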

def train_step_classification(config, train_loader, model_engine, return_accuracy=True):
    """Run one training step of the classification objective.

    Returns the reduced loss, plus top-1 accuracy if `return_accuracy` is set.
    """
    losses = []
    if return_accuracy:
        accuracies = []

    for _ in range(config.gradient_accumulation_steps):
        images, captions, class_labels = next(train_loader)
        images, captions, class_labels = to_cuda_half(images, captions, class_labels)
        if config.run_blind:
            images = torch.zeros_like(images)
        loss, logits = model_engine(images, captions, class_labels)
        losses.append(loss)
        if return_accuracy:
            # Top-1 accuracy over the micro-batch.
            argmax_pred = logits.argmax(dim=-1)
            accuracies.append((argmax_pred == class_labels).float().mean())
        model_engine.backward(loss)
        model_engine.step()

    loss_reduced = reduce_losses(torch.mean(torch.stack(losses))).item()
    if return_accuracy:
        accuracy_reduced = reduce_losses(torch.mean(torch.stack(accuracies))).item()
        return loss_reduced, accuracy_reduced
    return loss_reduced

def eval_step(config, eval_loader, model_engine):
    """Average the captioning loss over `config.eval_steps` batches.

    Gradients are not disabled here; see the usage sketch below this function.
    """
    losses = []

    for _ in tqdm(range(config.eval_steps), desc="evaluating..."):
        images, captions = next(eval_loader)
        images, captions = images.half().cuda(), captions.cuda()
        if config.run_blind:
            images = torch.zeros_like(images)
        outputs = model_engine(images, captions)
        loss = outputs.loss
        losses.append(loss)

    return reduce_losses(torch.mean(torch.stack(losses))).item()
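
# Illustrative only; not part of the original module. `eval_step` runs the
# same forward pass as training and does not disable gradients itself, so a
# caller would typically wrap it like this:
def _example_evaluate(config, eval_loader, model_engine):
    model_engine.eval()  # put dropout etc. into inference mode
    with torch.no_grad():  # skip gradient bookkeeping during evaluation
        eval_loss = eval_step(config, eval_loader, model_engine)
    model_engine.train()
    return eval_loss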

def eval_step_classification(config, eval_loader, model_engine, return_accuracy=True):
    """Evaluation counterpart of `train_step_classification`.

    Mirrors the training step minus backward/step, drawing the same number of
    micro-batches per call.
    """
    losses = []
    if return_accuracy:
        accuracies = []

    for _ in range(config.gradient_accumulation_steps):
        images, captions, class_labels = next(eval_loader)
        images, captions, class_labels = to_cuda_half(images, captions, class_labels)
        if config.run_blind:
            images = torch.zeros_like(images)
        loss, logits = model_engine(images, captions, class_labels)
        losses.append(loss)
        if return_accuracy:
            argmax_pred = logits.argmax(dim=-1)
            accuracies.append((argmax_pred == class_labels).float().mean())

    loss_reduced = reduce_losses(torch.mean(torch.stack(losses))).item()
    if return_accuracy:
        accuracy_reduced = reduce_losses(torch.mean(torch.stack(accuracies))).item()
        return loss_reduced, accuracy_reduced
    return loss_reduced

def inference_step(config, eval_loader, model_engine):
    """Generate captions for a batch and return a sample grid with its captions."""
    images, _ = next(eval_loader)
    images = images.half().cuda()
    if config.run_blind:
        images = torch.zeros_like(images)
    captions = model_engine(
        images, captions=None, inference=True
    )  # [caption1, caption2, ... b]

    # Assemble a small grid of sample images alongside their generated captions.
    width = min(2, images.shape[0])
    image_grid = make_grid(images[:width])
    caption = ""
    for i in range(width):
        caption += f"Caption {i}: \n{captions[i]}\n"

    return image_grid, caption
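
# Illustrative only; not part of the original module. One way a caller might
# log `inference_step`'s outputs, assuming Weights & Biases as the experiment
# tracker (`wandb` is an assumption; this file does not depend on it):
def _example_log_inference(config, eval_loader, model_engine, step):
    import wandb

    image_grid, caption = inference_step(config, eval_loader, model_engine)
    # make_grid returns a CHW tensor; cast to fp32 on CPU before handing it to wandb.
    wandb.log(
        {"inference/samples": wandb.Image(image_grid.float().cpu(), caption=caption)},
        step=step,
    )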