import os
import random
import numpy as np
from PIL import Image
from loguru import logger
import sys
import inspect
import math
import torch
import torch.distributed as dist
from collections import OrderedDict
from torch import nn


def init_random_seed(seed=None, device='cuda', rank=0, world_size=1):
    """Initialize random seed."""
    if seed is not None:
        return seed

    seed = np.random.randint(2**31)
    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()


def set_random_seed(seed, deterministic=False):
    """Set random seed."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed each DataLoader worker so random state differs across workers and ranks."""
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
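
# Example (illustrative sketch, not part of the original module): typical wiring of the
# seeding helpers above with a PyTorch DataLoader. The dataset/args names are placeholders.
#
#   from functools import partial
#
#   seed = init_random_seed(args.seed, rank=rank, world_size=world_size)
#   set_random_seed(seed, deterministic=args.deterministic)
#   loader = torch.utils.data.DataLoader(
#       dataset,
#       batch_size=args.batch_size,
#       num_workers=args.num_workers,
#       worker_init_fn=partial(worker_init_fn, num_workers=args.num_workers,
#                              rank=rank, seed=seed),
#   )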


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        if self.name == "Lr":
            fmtstr = "{name}={val" + self.fmt + "}"
        else:
            fmtstr = "{name}={val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        logger.info(" ".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"


def get_caller_name(depth=0):
    """
    Args:
        depth (int): depth of the caller context; 0 returns the immediate caller.
            Default value: 0.

    Returns:
        str: module name of the caller
    """
    frame = inspect.currentframe().f_back
    for _ in range(depth):
        frame = frame.f_back

    return frame.f_globals["__name__"]


class StreamToLoguru:
    """
    Stream object that redirects writes to a logger instance.
    """

    def __init__(self, level="INFO", caller_names=("apex", "pycocotools")):
        """
        Args:
            level(str): log level string of loguru. Default value: "INFO".
            caller_names(tuple): caller names of redirected modules.
                Default value: ("apex", "pycocotools").
        """
        self.level = level
        self.linebuf = ""
        self.caller_names = caller_names

    def write(self, buf):
        full_name = get_caller_name(depth=1)
        module_name = full_name.rsplit(".", maxsplit=-1)[0]
        if module_name in self.caller_names:
            for line in buf.rstrip().splitlines():
                # log with extra stack depth so loguru attributes the record
                # to the original caller rather than to this wrapper
                logger.opt(depth=2).log(self.level, line.rstrip())
        else:
            sys.__stdout__.write(buf)

    def flush(self):
        pass


def redirect_sys_output(log_level="INFO"):
    redirect_logger = StreamToLoguru(log_level)
    sys.stderr = redirect_logger
    sys.stdout = redirect_logger


def setup_logger(save_dir, filename="log.txt", mode="a"):
    """Setup the loguru logger for training and testing.

    Args:
        save_dir(str): location to save the log file.
        filename(str): log file name.
        mode(str): log file write mode, "a" (append) or "o" (override). Default is "a".
    """
    loguru_format = (
        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>")

    logger.remove()
    save_file = os.path.join(save_dir, filename)
    if mode == "o" and os.path.exists(save_file):
        os.remove(save_file)

    # console handler
    logger.add(
        sys.stderr,
        format=loguru_format,
        level="INFO",
        enqueue=True,
    )
    # file handler
    logger.add(save_file)

    # route print/stdout of third-party modules through loguru as well
    redirect_sys_output("INFO")
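
# Example (illustrative, not from the original module): initialize logging once at the
# start of a run; the output directory below is a placeholder.
#
#   setup_logger("./outputs/exp1", filename="train_log.txt", mode="a")
#   logger.info("logging initialized")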


def trainMetric(pred, label):
    """Return the number of correct top-1 predictions in the batch."""
    pred = torch.argmax(pred, dim=1)
    prec = torch.sum(pred == label)
    return prec


def token_wise_similarity(rep1, rep2, mask=None, chunk_size=1024):
    """Max-over-tokens similarity between two sets of token representations.

    rep1: (B1, T1, D), rep2: (B2, T2, D). Returns a (B1, B2) matrix where each rep1
    token is matched to its most similar rep2 token and the result is averaged over
    rep1 tokens (using ``mask`` to count valid tokens when provided). rep2 is
    processed in chunks of ``chunk_size`` to bound memory usage.
    """
    batch_size1, n_token1, feat_dim = rep1.shape
    batch_size2, n_token2, _ = rep2.shape
    num_folds = math.ceil(batch_size2 / chunk_size)
    output = []
    for i in range(num_folds):
        rep2_seg = rep2[i * chunk_size:(i + 1) * chunk_size]
        out_i = rep1.reshape(-1, feat_dim) @ rep2_seg.reshape(-1, feat_dim).T
        out_i = out_i.reshape(batch_size1, n_token1, -1, n_token2).max(3)[0]
        if mask is None:
            out_i = out_i.mean(1)
        else:
            # padded tokens are assumed to contribute negligibly; division by the
            # number of valid tokens happens after all chunks are concatenated
            out_i = out_i.sum(1)
        output.append(out_i)
    output = torch.cat(output, dim=1)
    if mask is not None:
        output = output / mask.sum(1, keepdim=True).clamp_(min=1)
    return output
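
# Worked shape example (illustrative): with rep1 of shape (4, 16, 512), rep2 of shape
# (3000, 16, 512) and chunk_size=1024, rep2 is processed in 3 chunks and the result has
# shape (4, 3000); entry (i, j) is the mean over rep1[i] tokens of the best-matching
# rep2[j] token similarity.
#
#   sim = token_wise_similarity(rep1, rep2)   # (4, 3000)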


def compute_acc(logits, targets, topk=5):
    """Per-sample top-1 / top-k hit indicators (standard top-k accuracy computation)."""
    targets = targets.squeeze(1)
    # (N, topk) indices of the top-k predictions, transposed to (topk, N) so each
    # row can be compared against the ground-truth labels
    pred = logits.topk(topk, 1, True, True)[1].t()
    correct = pred.eq(targets.view(1, -1).expand_as(pred)).contiguous()
    acc_1 = correct[:1].sum(0)
    acc_k = correct[:topk].sum(0)
    return acc_1, acc_k


def compute_mAP(predicted_probs, true_labels):
    # compute_AP returns (per-class AP, number of classes without positives)
    aps, _ = compute_AP(predicted_probs, true_labels)
    aps = aps[~torch.isnan(aps)]
    mAP = aps.mean()
    return mAP


def compute_F1(predictions, labels, k_val=5):
    """Precision, recall and F1 of the top-k predicted labels, pooled over the batch.

    Note: ``predictions`` is modified in place (it is turned into a 0/1 top-k mask).
    """
    labels = labels.squeeze(1)
    idx = predictions.topk(dim=1, k=k_val)[1]
    predictions.fill_(0)
    predictions.scatter_(dim=1, index=idx,
                         src=torch.ones(predictions.size(0), k_val).to(predictions.device))
    mask = predictions == 1
    TP = (labels[mask] == 1).sum().float()
    tpfp = mask.sum().float()
    tpfn = (labels == 1).sum().float()
    p = TP / tpfp
    r = TP / tpfn
    f1 = 2 * p * r / (p + r)
    return f1, p, r


def compute_AP(predictions, labels):
    """Per-class average precision.

    ``labels`` are expected in {-1, 0, 1}: entries with |label| == 1 are annotated
    (1 = positive, -1 = negative) and 0 entries are ignored.
    """
    num_class = predictions.size(1)
    ap = torch.zeros(num_class).to(predictions.device)
    empty_class = 0
    for idx_cls in range(num_class):
        prediction = predictions[:, idx_cls]
        label = labels[:, idx_cls]
        mask = label.abs() == 1
        if (label > 0).sum() == 0:
            empty_class += 1
            continue
        binary_label = torch.clamp(label[mask], min=0, max=1)
        sorted_pred, sort_idx = prediction[mask].sort(descending=True)
        sorted_label = binary_label[sort_idx]
        tmp = (sorted_label == 1).float()
        tp = tmp.cumsum(0)
        fp = (sorted_label != 1).float().cumsum(0)
        num_pos = binary_label.sum()
        rec = tp / num_pos
        prec = tp / (tp + fp)
        ap_cls = (tmp * prec).sum() / num_pos
        ap[idx_cls].copy_(ap_cls)
    return ap, empty_class
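
# Worked example (illustrative): for one class whose annotated labels, sorted by
# descending score, are [1, 0, 1, 0], precision at the two positive hits is 1/1 and
# 2/3, and with num_pos = 2 the loop above gives AP = (1/1 + 2/3) / 2 = 0.833...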


def compute_ACG(predictions, labels, k_val=5):
    """Retrieval metrics over the top-k results: ACG, NDCG, mAP and ACG-weighted mAP."""
    gt = labels.squeeze(1)
    idx = predictions.topk(dim=1, k=k_val)[1]
    pred = gt[idx, :]
    pred[pred == -1] = 0
    c = labels.eq(pred)
    # r: number of matching label entries for each of the top-k results
    r = c.sum(-1)

    # average cumulative gain over the top-k results
    acg = c.sum(-1).sum(-1) / k_val
    lg = torch.log1p(torch.arange(1, k_val + 1, 1).float()).to(r.device)

    dcg = (torch.pow(2, r) - 1) / lg
    ir, _ = r.sort(-1, descending=True)
    idcg = (torch.pow(2, ir) - 1) / lg
    idcg[idcg == 0] = 1e-6
    ndcg = dcg.sum(-1) / idcg.sum(-1)

    # binary relevance and precision at each rank
    pos = r.clone()
    pos[pos != 0] = 1
    j = torch.arange(1, k_val + 1, 1).to(pos.device)
    P = torch.cumsum(pos, 1) / j
    Npos = torch.sum(pos, 1)
    Npos[Npos == 0] = 1
    mAP = torch.sum(P * pos, 1) / Npos

    # ACG-weighted mAP
    acgj = torch.cumsum(r, 1) / j
    wmap = torch.sum(acgj * pos, 1) / Npos

    return acg, ndcg, mAP, wmap


def compute_mAPw(predictions, labels, k_val=5):
    """Mean average precision over the top-k retrieved results.

    NOTE: a minimal completion that mirrors the precision/mAP steps in
    ``compute_ACG``; adjust if a different weighting was intended.
    """
    gt = labels.squeeze(1)
    idx = predictions.topk(dim=1, k=k_val)[1]
    pred = gt[idx, :]
    pred[pred == -1] = 0
    c = labels.eq(pred)
    r = c.sum(-1)
    pos = r.clone()
    pos[pos != 0] = 1
    j = torch.arange(1, k_val + 1, 1).to(pos.device)
    P = torch.cumsum(pos, 1) / j
    Npos = torch.sum(pos, 1)
    Npos[Npos == 0] = 1
    mAPw = torch.sum(P * pos, 1) / Npos
    return mAPw


def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate with half-cycle cosine after warmup."""
    if epoch < args.warmup_epochs:
        lr = args.base_lr * epoch / args.warmup_epochs
    else:
        lr = args.min_lr + (args.base_lr - args.min_lr) * 0.5 * \
            (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
    for param_group in optimizer.param_groups:
        if "lr_scale" in param_group:
            param_group["lr"] = lr * param_group["lr_scale"]
        else:
            param_group["lr"] = lr
    return lr
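
# Worked example (illustrative): with base_lr=1e-3, min_lr=1e-5, warmup_epochs=5 and
# epochs=50, the lr ramps linearly to 1e-3 over the first 5 epochs, then decays with a
# half-cycle cosine: 1e-3 at epoch 5, roughly 1e-5 + (1e-3 - 1e-5) * 0.5 ~= 5.05e-4 at
# epoch 27.5, and it approaches 1e-5 at epoch 50.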


def load_ckpt(weight_dir, model, map_location, args):
    """Load a checkpoint into ``model``, stripping any 'module.' (DataParallel) prefix."""
    checkpoint = torch.load(weight_dir, map_location=map_location)
    if args.resume:
        resume_epoch = checkpoint['epoch']
    else:
        resume_epoch = 0
    pre_weight = checkpoint['state_dict']

    # strip the 'module.' prefix added by DataParallel/DistributedDataParallel
    new_pre_weight = OrderedDict()
    model_dict = model.state_dict()
    for k, v in pre_weight.items():
        new_k = k.replace('module.', '') if 'module' in k else k
        new_pre_weight[new_k] = v
    pre_weight = new_pre_weight

    model_dict.update(pre_weight)
    model.load_state_dict(model_dict, strict=False)

    return model, resume_epoch
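
# Example (illustrative, placeholder path/args): resuming a model on a single GPU.
#
#   model, start_epoch = load_ckpt("checkpoints/last.pth", model,
#                                  map_location="cuda:0", args=args)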


def load_ckpt_fpn(weight_dir, model, map_location):
    checkpoint = torch.load(weight_dir, map_location=map_location)
    pre_weight = checkpoint['state_dict']
    epoch = checkpoint['epoch']
    new_pre_weight = OrderedDict()

    model_dict = model.state_dict()

    for k, v in pre_weight.items():
        new_k = k.replace('module.', '') if 'module' in k else k
        new_pre_weight[new_k] = v

    pre_weight = new_pre_weight

    model_dict.update(pre_weight)
    model.load_state_dict(model_dict, strict=True)

    return model, epoch


def load_ckpt_old(weight_dir, model, map_location):
    checkpoint = torch.load(weight_dir, map_location=map_location)
    pre_weight = checkpoint['state_dict']
    epoch = checkpoint['epoch']
    new_pre_weight = OrderedDict()

    model_dict = model.state_dict()

    for k, v in pre_weight.items():
        new_k = k.replace('module.', '') if 'module' in k else k
        # skip FPN and gap parameters so they keep the new model's initialization
        if not (new_k.startswith('FPN') or new_k.startswith('gap')):
            new_pre_weight[new_k] = v

    pre_weight = new_pre_weight

    model_dict.update(pre_weight)
    model.load_state_dict(model_dict, strict=False)

    return model, epoch


def compare_ckpt(model1, model2):
    """Copy all 'projT*' tensors from the ``model1`` state dict into ``model2``."""
    V = dict()
    for k, v in model1.items():
        if k.startswith('projT'):
            V[k] = v

    for k, v in model2.items():
        if k in V:
            model2[k] = V[k]

    return model2