seokju cho
initial commit
f8f62f3
raw
history blame contribute delete
No virus
12.1 kB
import json
import logging
import math
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
try:
import wandb
except ImportError:
wandb = None
from open_clip import ClipLoss, get_cast_dtype
from .distributed import is_master
from .zero_shot import zero_shot_eval
from .precision import get_autocast
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def unwrap_model(model):
if hasattr(model, 'module'):
return model.module
else:
return model
def backward(total_loss, scaler):
if scaler is not None:
scaler.scale(total_loss).backward()
else:
total_loss.backward()
def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None):
device = torch.device(args.device)
autocast = get_autocast(args.precision)
cast_dtype = get_cast_dtype(args.precision)
model.train()
loss = ClipLoss(
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod)
data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch
dataloader = data['train'].dataloader
num_batches_per_epoch = dataloader.num_batches // args.accum_freq
sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))
if args.accum_freq > 1:
accum_images, accum_texts, accum_image_features, accum_text_features = [], [], [], []
loss_m = AverageMeter()
batch_time_m = AverageMeter()
data_time_m = AverageMeter()
end = time.time()
for i, batch in enumerate(dataloader):
i_accum = i // args.accum_freq
step = num_batches_per_epoch * epoch + i_accum
if not args.skip_scheduler:
scheduler(step)
images, texts = batch
images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
texts = texts.to(device=device, non_blocking=True)
data_time_m.update(time.time() - end)
optimizer.zero_grad()
if args.accum_freq == 1:
with autocast():
image_features, text_features, logit_scale = model(images, texts)
total_loss = loss(image_features, text_features, logit_scale)
backward(total_loss, scaler)
else:
# First, cache the features without any gradient tracking.
with torch.no_grad():
with autocast():
chunk_image_features, chunk_text_features, _ = model(images, texts)
accum_image_features.append(chunk_image_features)
accum_text_features.append(chunk_text_features)
accum_images.append(images)
accum_texts.append(texts)
# If (i + 1) % accum_freq is not zero, move on to the next batch.
if ((i + 1) % args.accum_freq) > 0:
# FIXME this makes data time logging unreliable when accumulating
continue
# Now, ready to take gradients for the last accum_freq batches.
# Re-do the forward pass for those batches, and use the cached features from the other batches as negatives.
# Call backwards each time, but only step optimizer at the end.
optimizer.zero_grad()
for j in range(args.accum_freq):
images = accum_images[j]
texts = accum_texts[j]
with autocast():
chunk_image_features, chunk_text_features, logit_scale = model(images, texts)
image_features = torch.cat(
accum_image_features[:j] + [chunk_image_features] + accum_image_features[j + 1:])
text_features = torch.cat(
accum_text_features[:j] + [chunk_text_features] + accum_text_features[j + 1:])
total_loss = loss(image_features, text_features, logit_scale)
backward(total_loss, scaler)
if scaler is not None:
if args.horovod:
optimizer.synchronize()
scaler.unscale_(optimizer)
if args.grad_clip_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
with optimizer.skip_synchronize():
scaler.step(optimizer)
else:
if args.grad_clip_norm is not None:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
scaler.step(optimizer)
scaler.update()
else:
if args.grad_clip_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
optimizer.step()
# reset gradient accum, if enabled
if args.accum_freq > 1:
accum_images, accum_texts, accum_image_features, accum_text_features = [], [], [], []
# Note: we clamp to 4.6052 = ln(100), as in the original paper.
with torch.no_grad():
unwrap_model(model).logit_scale.clamp_(0, math.log(100))
batch_time_m.update(time.time() - end)
end = time.time()
batch_count = i_accum + 1
if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch):
batch_size = len(images)
num_samples = batch_count * batch_size * args.accum_freq * args.world_size
samples_per_epoch = dataloader.num_samples
percent_complete = 100.0 * batch_count / num_batches_per_epoch
# NOTE loss is coarsely sampled, just master node and per log update
loss_m.update(total_loss.item(), batch_size)
logit_scale_scalar = logit_scale.item()
logging.info(
f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
f"Data (t): {data_time_m.avg:.3f} "
f"Batch (t): {batch_time_m.avg:.3f}, {args.accum_freq * args.batch_size * args.world_size / batch_time_m.val:#g}/s "
f"LR: {optimizer.param_groups[0]['lr']:5f} "
f"Logit Scale: {logit_scale_scalar:.3f}"
)
# Save train loss / etc. Using non avg meter values as loggers have their own smoothing
log_data = {
"loss": loss_m.val,
"data_time": data_time_m.val,
"batch_time": batch_time_m.val,
"samples_per_second": args.accum_freq * args.batch_size * args.world_size / batch_time_m.val,
"scale": logit_scale_scalar,
"lr": optimizer.param_groups[0]["lr"]
}
for name, val in log_data.items():
name = "train/" + name
if tb_writer is not None:
tb_writer.add_scalar(name, val, step)
if args.wandb:
assert wandb is not None, 'Please install wandb.'
wandb.log({name: val, 'step': step})
# resetting batch / data time meters per log window
batch_time_m.reset()
data_time_m.reset()
# end for
def evaluate(model, data, epoch, args, tb_writer=None):
metrics = {}
if not is_master(args):
return metrics
device = torch.device(args.device)
model.eval()
zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
metrics.update(zero_shot_metrics)
autocast = get_autocast(args.precision)
cast_dtype = get_cast_dtype(args.precision)
if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)):
dataloader = data['val'].dataloader
num_samples = 0
samples_per_val = dataloader.num_samples
# FIXME this does not scale past small eval datasets
# all_image_features @ all_text_features will blow up memory and compute very quickly
cumulative_loss = 0.0
all_image_features, all_text_features = [], []
with torch.no_grad():
for i, batch in enumerate(dataloader):
images, texts = batch
images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
texts = texts.to(device=device, non_blocking=True)
with autocast():
image_features, text_features, logit_scale = model(images, texts)
# features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly
# however, system RAM is easily exceeded and compute time becomes problematic
all_image_features.append(image_features.cpu())
all_text_features.append(text_features.cpu())
logit_scale = logit_scale.mean()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
batch_size = images.shape[0]
labels = torch.arange(batch_size, device=device).long()
total_loss = (
F.cross_entropy(logits_per_image, labels) +
F.cross_entropy(logits_per_text, labels)
) / 2
cumulative_loss += total_loss * batch_size
num_samples += batch_size
if is_master(args) and (i % 100) == 0:
logging.info(
f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t"
f"Loss: {cumulative_loss / num_samples:.6f}\t")
val_metrics = get_metrics(
image_features=torch.cat(all_image_features),
text_features=torch.cat(all_text_features),
logit_scale=logit_scale.cpu(),
)
loss = cumulative_loss / num_samples
metrics.update(
{**val_metrics, "val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples}
)
if not metrics:
return metrics
logging.info(
f"Eval Epoch: {epoch} "
+ "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
)
if args.save_logs:
for name, val in metrics.items():
if tb_writer is not None:
tb_writer.add_scalar(f"val/{name}", val, epoch)
with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
f.write(json.dumps(metrics))
f.write("\n")
if args.wandb:
assert wandb is not None, 'Please install wandb.'
for name, val in metrics.items():
wandb.log({f"val/{name}": val, 'epoch': epoch})
return metrics
def get_metrics(image_features, text_features, logit_scale):
metrics = {}
logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
logits_per_text = logits_per_image.t().detach().cpu()
logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}
ground_truth = torch.arange(len(text_features)).view(-1, 1)
for name, logit in logits.items():
ranking = torch.argsort(logit, descending=True)
preds = torch.where(ranking == ground_truth)[1]
preds = preds.detach().cpu().numpy()
metrics[f"{name}_mean_rank"] = preds.mean() + 1
metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
for k in [1, 5, 10]:
metrics[f"{name}_R@{k}"] = np.mean(preds < k)
return metrics