# HPSv2/src/training/train.py
import hashlib
import itertools
import json
import logging
import math
import os
import random
import tempfile
import time

import einops
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.nn.utils.rnn import pad_sequence

from open_clip import get_cast_dtype, CLIP, CustomTextCLIP
from ..open_clip.loss import PreferenceLoss, RankingLoss, HPSLoss
from .data import ImageRewardDataset, RankingDataset
from .distributed import is_master, barrier
from .precision import get_autocast
from .zero_shot import zero_shot_eval
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
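# Illustrative usage of AverageMeter (a minimal sketch, not part of the
# training loop):
#   meter = AverageMeter()
#   meter.update(0.5, n=4)   # sum = 2.0, count = 4
#   meter.update(1.0, n=4)   # sum = 6.0, count = 8
#   meter.avg                # -> 0.75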
def postprocess_clip_output(model_out):
return {
"image_features": model_out[0],
"text_features": model_out[1],
"logit_scale": model_out[2]
}
def unwrap_model(model):
if hasattr(model, 'module'):
return model.module
else:
return model
def backward(total_loss, scaler):
if scaler is not None:
scaler.scale(total_loss).backward()
else:
total_loss.backward()
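# Note on the scaler path: with AMP, GradScaler.scale() multiplies the loss
# before backward so small fp16 gradients do not underflow; the matching
# scaler.step()/scaler.update() calls in the training loop below unscale and
# apply those gradients.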
def random_sampling_iterator(iterators, sampling_ratios, data_types, num_iters):
iterators = [iter(iterator) for iterator in iterators]
num_iterators = len(iterators)
loop_counter = 0
while loop_counter < num_iters:
current_state = random.getstate()
random.seed(loop_counter)
iterator_idx = random.choices(range(num_iterators), sampling_ratios)[0]
random.setstate(current_state)
yield next(iterators[iterator_idx]), data_types[iterator_idx]
loop_counter += 1
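# Illustrative sketch of the mixing iterator (hypothetical loader names, not
# part of this module):
#   mixed = random_sampling_iterator(
#       [pref_loader, rating_loader],          # any iterables of batches
#       sampling_ratios=[0.7, 0.3],            # relative weights for random.choices
#       data_types=['preference', 'rating'],
#       num_iters=1000,
#   )
#   for batch, data_type in mixed:
#       ...
# Seeding with loop_counter (and restoring the RNG state afterwards) makes the
# per-step dataset choice deterministic and identical across all ranks.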
def train_iters(model, data, iterations, optimizer, scaler, scheduler, dist_model, args, tb_writer=None):
device = torch.device(args.device)
autocast = get_autocast(args.precision)
cast_dtype = get_cast_dtype(args.precision)
model.train()
ce_loss = PreferenceLoss()
mse_loss = torch.nn.MSELoss()
rk_loss = RankingLoss()
hps_loss = HPSLoss()
if args.distill:
dist_model.eval()
for train_set in data['train']:
train_set.set_epoch(0) # set epoch in process safe manner via sampler or shared_epoch
data_types = [d.data_type for d in data['train']]
train_data_sample_ratios = [sample_ratio for sample_ratio, ignore in zip(args.train_data_sample_ratio, args.ignore_in_train) if not ignore]
dataloader = random_sampling_iterator([dataset.dataloader for dataset in data['train']], train_data_sample_ratios, data_types, iterations)
sample_digits = math.ceil(math.log(sum([dataset.dataloader.num_samples for dataset in data['train']]) + 1, 10))
losses_m = {}
batch_time_m = AverageMeter()
data_time_m = AverageMeter()
end = time.time()
for step, (batch, data_type) in enumerate(dataloader):
        # TODO: currently only tested with accum_freq == 1
if not args.skip_scheduler:
scheduler(step)
if data_type == 'preference':
images, num_images, labels, texts = batch
texts = texts.to(device=device, non_blocking=True)
        elif data_type in ('rating', 'regional'):
            images, labels = batch
elif data_type == 'ranking':
images, num_images, labels, texts = batch
texts = texts.to(device=device, non_blocking=True)
elif data_type == 'HPD':
images, labels, texts = batch
# num_per_prompts = num_per_prompts.to(device=device, non_blocking=True)
texts = texts.to(device=device, non_blocking=True)
images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
labels = labels.to(device=device, non_blocking=True)
data_time_m.update(time.time() - end)
optimizer.zero_grad()
if args.accum_freq == 1:
with autocast():
                if data_type == 'rating' or args.no_text_condition:
                    # Text-free path: score images directly from visual features.
                    image_features = unwrap_model(model).visual(images)
                    scores = unwrap_model(model).score_predictor(image_features)
                    if args.no_text_condition:
                        # Preference-style batch: group the per-image scores by
                        # prompt and train with cross-entropy over padded groups.
                        paired_logits_list = [logit[:, 0] for logit in scores.split(num_images.tolist())]
                        paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-999)
                        total_loss = F.cross_entropy(paired_logits, labels)
                    else:
                        # Rating batch: regress the predicted score to the label.
                        total_loss = mse_loss(scores.squeeze(), labels.to(scores.dtype))
                elif data_type == 'preference':
output = model(images, texts)
image_features, text_features, logit_scale = output["image_features"], output["text_features"], output["logit_scale"]
# total_loss = loss(image_features, text_features, logit_scale)
logits_per_image = logit_scale * image_features @ text_features.T
total_loss = ce_loss(logits_per_image, num_images, labels)
elif data_type == 'HPD':
output = model(images, texts)
image_features, text_features, logit_scale = output["image_features"], output["text_features"], output["logit_scale"]
logits_per_text = logit_scale * text_features @ image_features.T
total_loss = hps_loss(logits_per_text, labels)
elif data_type == 'ranking':
output = model(images, texts)
image_features, text_features, logit_scale = output["image_features"], output["text_features"], output["logit_scale"]
# logits_per_image = logit_scale * image_features @ text_features.T
score = logit_scale * image_features @ text_features.T
total_loss = rk_loss(score, num_images, labels, args.margin)
elif data_type == 'regional':
# logit_scale = model.logit_scale
                    feature_map = unwrap_model(model).visual(images, skip_pool=True)[:, 1:]
                    logits = unwrap_model(model).region_predictor(feature_map)
                    wh = int(math.sqrt(feature_map.size(1)))
                    ps = images.size(2) // wh  # patch size in pixels
                    logits = logits.unflatten(1, (wh, wh))[:, :, :, 0]
                    # downsample the labels to match the feature map size
                    patches = einops.reduce(labels, 'b (h s1) (w s2) -> b h w', 'mean', s1=ps, s2=ps)
                    patches = (patches > 0).float()
                    # match the target dtype to the (possibly autocast) logits
                    total_loss = mse_loss(logits.sigmoid(), patches.to(logits.dtype))
backward(total_loss, scaler)
losses = dict(total_loss=total_loss)
if scaler is not None:
if args.horovod:
optimizer.synchronize()
scaler.unscale_(optimizer)
if args.grad_clip_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
with optimizer.skip_synchronize():
scaler.step(optimizer)
else:
if args.grad_clip_norm is not None:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
scaler.step(optimizer)
scaler.update()
else:
if args.grad_clip_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
optimizer.step()
# Note: we clamp to 4.6052 = ln(100), as in the original paper.
with torch.no_grad():
unwrap_model(model).logit_scale.clamp_(0, math.log(100))
batch_time_m.update(time.time() - end)
end = time.time()
batch_count = step + 1
if is_master(args) and (step % args.log_every_n_steps == 0 or batch_count == iterations):
batch_size = len(images)
num_samples = batch_count * args.accum_freq
percent_complete = 100.0 * batch_count / iterations
# NOTE loss is coarsely sampled, just master node and per log update
for key, val in losses.items():
if key not in losses_m:
losses_m[key] = AverageMeter()
losses_m[key].update(val.item(), batch_size)
logit_scale_scalar = unwrap_model(model).logit_scale.item()
loss_log = " ".join(
[
f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})"
for loss_name, loss_m in losses_m.items()
]
)
samples_per_second = args.accum_freq * args.world_size / batch_time_m.val
samples_per_second_per_gpu = args.accum_freq / batch_time_m.val
logging.info(
f"Train iterations: [{num_samples:>{sample_digits}}/{iterations} ({percent_complete:.0f}%)] "
f"Data (t): {data_time_m.avg:.3f} "
f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu "
f"LR: {optimizer.param_groups[0]['lr']:5f} "
f"Logit Scale: {logit_scale_scalar:.3f} " + loss_log
)
# Save train loss / etc. Using non avg meter values as loggers have their own smoothing
log_data = {
"data_time": data_time_m.val,
"batch_time": batch_time_m.val,
"samples_per_second": samples_per_second,
"samples_per_second_per_gpu": samples_per_second_per_gpu,
"scale": logit_scale_scalar,
"lr": optimizer.param_groups[0]["lr"]
}
log_data.update({name:val.val for name,val in losses_m.items()})
for name, val in log_data.items():
name = "train/" + name
if tb_writer is not None:
tb_writer.add_scalar(name, val, step)
# resetting batch / data time meters per log window
batch_time_m.reset()
data_time_m.reset()
def evaluate_preference(model, data, args):
model = unwrap_model(model)
model.eval()
dataloader = data.dataloader
samples_per_val = dataloader.num_samples
device = torch.device(args.device)
autocast = get_autocast(args.precision)
cast_dtype = get_cast_dtype(args.precision)
total = 0
correct = 0
with torch.no_grad():
for i, batch in enumerate(dataloader):
if i % args.world_size != args.rank:
continue
images, num_images, labels, texts = batch
images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
texts = texts.to(device=device, non_blocking=True)
with autocast():
if args.no_text_condition:
image_features = model.visual(images)
logit_scale = model.logit_scale
scores = model.score_predictor(image_features)
                    paired_logits_list = [logit[:, 0] for logit in scores.split(num_images.tolist())]
else:
outputs = model(images, texts)
image_features, text_features, logit_scale = outputs["image_features"], outputs["text_features"], outputs["logit_scale"]
logits_per_image = logit_scale * image_features @ text_features.T
paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
predicted = torch.tensor([k.argmax().item() for k in paired_logits_list])
correct += (predicted == labels).int().sum().item()
total += predicted.numel()
# write to a temp file
file_name = hashlib.md5(str(args.name).encode()).hexdigest()
with open(f"{file_name}_{args.rank}.json", "w") as f:
json.dump(dict(
correct=correct,
total=total,
), f)
time.sleep(0.1)
barrier(args)
correct = 0
total = 0
if is_master(args):
for i in range(args.world_size):
with open(f"{file_name}_{i}.json", "r") as f:
data = json.load(f)
correct += data["correct"]
total += data["total"]
os.remove(f"{file_name}_{i}.json")
        logging.info(
            f"Final Acc: {correct / (total + 1e-6):.4f}\t")
return correct / (total + 1e-6)
def evaluate_regional(model, data, args):
    # unwrap and switch to eval mode, matching the other evaluators
    model = unwrap_model(model)
    model.eval()
    dataloader = data.dataloader
samples_per_val = dataloader.num_samples
device = torch.device(args.device)
autocast = get_autocast(args.precision)
cast_dtype = get_cast_dtype(args.precision)
num_samples = len(dataloader)
threshold = 0.5
with torch.no_grad():
score = 0
total = 0
for i, batch in enumerate(dataloader):
images, labels = batch
images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
labels = labels.to(device=device, non_blocking=True)
with autocast():
feature_map = model.visual(images, skip_pool=True)[:, 1:]
logits = model.region_predictor(feature_map)
wh = int(math.sqrt(feature_map.size(1)))
ps = images.size(2) // wh
logits = logits.unflatten(1, (wh, wh))[:,:,:,0]
# downsample the labels to match the feature map size
patches = einops.reduce(labels, 'b (h s1) (w s2) -> b h w', 'mean', s1=ps, s2=ps)
patches = (patches > 0).float()
pred_mask = (logits.sigmoid() > threshold).float()
#calc IOU
intersection = (pred_mask * patches).sum()
union = pred_mask.sum() + patches.sum() - intersection
iou_score = intersection / union
score += iou_score
total += 1
if is_master(args) and (i % 100) == 0:
logging.info(
# f"[{i} / {samples_per_val}]\t"
f"[{i} / {len(dataloader)}]\t"
f"Current IoU: {score / (total + 0.001):.4f}\t")
if is_master(args):
logging.info(
f"Final IoU: {score / (total + 0.001):.4f}\t")
return score / (total + 0.001)
def inversion_score(p1, p2):
assert len(p1) == len(p2), f'{len(p1)}, {len(p2)}'
n = len(p1)
cnt = 0
for i in range(n-1):
for j in range(i+1, n):
if p1[i] > p1[j] and p2[i] < p2[j]:
cnt += 1
elif p1[i] < p1[j] and p2[i] > p2[j]:
cnt += 1
return 1 - cnt / (n * (n - 1) / 2)
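# Worked example (illustrative): with p1 = [0, 1, 2] and p2 = [0, 2, 1], only
# the pair (1, 2) is ordered differently, so cnt = 1 of n*(n-1)/2 = 3 pairs
# and the score is 1 - 1/3 ≈ 0.667; identical rankings score 1.0 and a fully
# reversed ranking scores 0.0.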
def model_pair_score(score: dict, p1, p2, num_image):
    """Tally, for each unordered model pair (i, j), whether the rankings
    p1 and p2 agree on the relative order of models i and j."""
model_pairs = set()
for i in range(num_image):
if i not in score.keys():
score[i] = {}
for j in range(num_image):
if j not in score[i].keys():
score[i][j] = 0
if j == i or (i, j) in model_pairs or (j, i) in model_pairs:
continue
model_pairs.add((i,j))
if (p1[i] - p1[j]) * (p2[i] - p2[j]) > 0:
score[i][j] += 1
return score
def all_gather(tensor):
world_size = torch.distributed.get_world_size()
tensor_list = [torch.ones_like(tensor) for _ in range(world_size)]
torch.distributed.all_gather(tensor_list, tensor, async_op=False)
return torch.cat(tensor_list, dim=0)
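# Note: torch.distributed.all_gather requires the tensor to have the same
# shape on every rank; the gathered shards are concatenated along dim 0.
# Illustrative sketch (assumes an initialized process group; `device` is a
# placeholder):
#   local = torch.full((2,), args.rank, device=device)
#   full = all_gather(local)   # shape: (2 * world_size,)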
def evaluate_ranking(model, data, args):
model = unwrap_model(model)
model.eval()
dataloader = data.dataloader
samples_per_val = dataloader.num_samples
device = torch.device(args.device)
autocast = get_autocast(args.precision)
cast_dtype = get_cast_dtype(args.precision)
score = 0
# pair_score = {}
with torch.no_grad():
for i, batch in enumerate(dataloader):
            # shard by global rank, matching evaluate_preference and the
            # per-rank result files written below
            if i % args.world_size != args.rank:
continue
images, num_images, labels, texts = batch
images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
texts = texts.to(device=device, non_blocking=True)
num_images = num_images.to(device=device, non_blocking=True)
labels = labels.to(device=device, non_blocking=True)
with autocast():
if args.no_text_condition:
image_features = model.visual(images)
logit_scale = model.logit_scale
scores = model.score_predictor(image_features)
                    paired_logits_list = [logit[:, 0] for logit in scores.split(num_images.tolist())]
else:
outputs = model(images, texts)
image_features, text_features, logit_scale = outputs["image_features"], outputs["text_features"], outputs["logit_scale"]
logits_per_image = logit_scale * image_features @ text_features.T
paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
predicted = [torch.argsort(-k) for k in paired_logits_list]
            hps_ranking = [[predicted[i].tolist().index(j) for j in range(n)] for i, n in enumerate(num_images)]
labels = [label for label in labels.split(num_images.tolist())]
if isinstance(dataloader.dataset, RankingDataset):
score += sum([inversion_score(hps_ranking[i], labels[i]) for i in range(len(hps_ranking))])
elif isinstance(dataloader.dataset, ImageRewardDataset):
                score += sum([calc_ImageReward(paired_logits_list[i].tolist(), labels[i]) for i in range(len(hps_ranking))])
# write score to a tempfile, file name is a hash string
file_name = hashlib.md5(str(args.name).encode()).hexdigest()
with open(f"{file_name}_{args.rank}.tmp", "w") as f:
f.write(str(score))
time.sleep(0.1)
barrier(args)
score = 0
if is_master(args):
for i in range(args.world_size):
with open(f"{file_name}_{i}.tmp", "r") as f:
score += float(f.read())
os.remove(f"{file_name}_{i}.tmp")
score = score / samples_per_val
logging.info(
f"Final Acc: {score:.4f}\t")
# return score, pair_score
return score
def calc_ImageReward(pred, gt):
    # Inversion-score calculation following ImageReward. It differs slightly
    # from inversion_score() because the ImageReward benchmark contains tied
    # ground-truth rankings, which are skipped rather than counted.
    # gt holds rank positions (lower = preferred); pred holds raw scores
    # (higher = preferred), so a concordant pair has opposite orderings.
    tol_cnt = 0.
    true_cnt = 0.
    n = len(gt)
    for i in range(n):
        for j in range(i + 1, n):
            if gt[i] > gt[j]:
                tol_cnt += 1
                if pred[i] < pred[j]:
                    true_cnt += 1
            elif gt[i] < gt[j]:
                tol_cnt += 1
                if pred[i] > pred[j]:
                    true_cnt += 1
    return true_cnt / tol_cnt
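# Worked example (illustrative): gt = [1, 1, 2] (rank positions, lower is
# preferred) and pred = [0.9, 0.8, 0.3] (scores, higher is preferred).
# The (0, 1) pair is tied in gt and skipped; both remaining pairs are
# concordant, so the score is true_cnt / tol_cnt = 2 / 2 = 1.0.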
def get_clip_metrics(image_features, text_features, logit_scale):
metrics = {}
logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
logits_per_text = logits_per_image.t().detach().cpu()
logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}
ground_truth = torch.arange(len(text_features)).view(-1, 1)
for name, logit in logits.items():
ranking = torch.argsort(logit, descending=True)
preds = torch.where(ranking == ground_truth)[1]
preds = preds.detach().cpu().numpy()
metrics[f"{name}_mean_rank"] = preds.mean() + 1
metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
for k in [1, 5, 10]:
metrics[f"{name}_R@{k}"] = np.mean(preds < k)
return metrics
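# Illustrative sketch (assumes L2-normalized features and a matched batch;
# the random features here are placeholders):
#   img = F.normalize(torch.randn(8, 512), dim=-1)
#   txt = F.normalize(torch.randn(8, 512), dim=-1)
#   metrics = get_clip_metrics(img, txt, logit_scale=100.0)
#   # keys: image_to_text_mean_rank, image_to_text_R@1, text_to_image_R@5, ...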
def maybe_compute_generative_loss(model_out):
if "logits" in model_out and "labels" in model_out:
token_logits = model_out["logits"]
token_labels = model_out["labels"]
return F.cross_entropy(token_logits.permute(0, 2, 1), token_labels)