import os
import random
import numpy as np
from PIL import Image
from loguru import logger
import sys
import inspect
import math
import torch
import torch.distributed as dist
from collections import OrderedDict
from torch import nn

def init_random_seed(seed=None, device='cuda', rank=0, world_size=1):
    """Initialize random seed."""
    if seed is not None:
        return seed
    # Make sure all ranks share the same random seed to prevent
    # some potential bugs. Please refer to
    # https://github.com/open-mmlab/mmdetection/issues/6339
    seed = np.random.randint(2**31)
    if world_size == 1:
        return seed
    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()

def set_random_seed(seed, deterministic=False):
    """Set random seed."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
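
# Usage sketch (hypothetical, not part of the original API): seed every rank
# identically at startup. Assumes torch.distributed is already initialized
# and CUDA is available when world_size > 1.
def _demo_seed_everything(rank=0, world_size=1):
    seed = init_random_seed(None, device='cuda', rank=rank, world_size=world_size)
    set_random_seed(seed, deterministic=True)
    return seed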

def worker_init_fn(worker_id, num_workers, rank, seed):
    # The seed of each worker is
    # num_workers * rank + worker_id + user_seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
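
# Usage sketch (hypothetical wiring, not from the original file): partially
# apply worker_init_fn and hand it to a DataLoader so each worker derives a
# distinct, reproducible seed. `dataset` is any torch-compatible Dataset.
def _demo_worker_seeding(dataset, rank, seed, num_workers=4):
    from functools import partial
    import torch.utils.data
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=32,
        num_workers=num_workers,
        worker_init_fn=partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed),
    )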

class AverageMeter(object):
    """Computes and stores the average and current value."""
    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # The learning-rate meter shows only the current value; other meters
        # show "current (running average)".
        if self.name == "Lr":
            fmtstr = "{name}={val" + self.fmt + "}"
        else:
            fmtstr = "{name}={val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)

class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        logger.info(" ".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the batch index to the width of the total count, e.g. "[  7/100]".
        num_digits = len(str(num_batches))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"

def get_caller_name(depth=0):
    """
    Args:
        depth (int): Depth of caller context, use 0 for caller depth.
            Default value: 0.
    Returns:
        str: module name of the caller
    """
    # the following logic is a little faster than inspect.stack()
    frame = inspect.currentframe().f_back
    for _ in range(depth):
        frame = frame.f_back
    return frame.f_globals["__name__"]

class StreamToLoguru:
    """
    Stream object that redirects writes to a logger instance.
    """
    def __init__(self, level="INFO", caller_names=("apex", "pycocotools")):
        """
        Args:
            level (str): log level string of loguru. Default value: "INFO".
            caller_names (tuple): caller names of redirected modules.
                Default value: ("apex", "pycocotools").
        """
        self.level = level
        self.linebuf = ""
        self.caller_names = caller_names

    def write(self, buf):
        full_name = get_caller_name(depth=1)
        module_name = full_name.split(".")[0]  # top-level package name
        if module_name in self.caller_names:
            for line in buf.rstrip().splitlines():
                # use caller level log
                logger.opt(depth=2).log(self.level, line.rstrip())
        else:
            sys.__stdout__.write(buf)

    def flush(self):
        pass

def redirect_sys_output(log_level="INFO"):
    redirect_logger = StreamToLoguru(log_level)
    sys.stderr = redirect_logger
    sys.stdout = redirect_logger

def setup_logger(save_dir, filename="log.txt", mode="a"):
    """Setup logger for training and testing.
    Args:
        save_dir (str): location to save log file
        filename (str): log save name.
        mode (str): log file write mode, "a" (append) or "o" (override).
            Default is "a".
    Return:
        logger instance.
    """
    loguru_format = (
        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>")
    logger.remove()
    save_file = os.path.join(save_dir, filename)
    if mode == "o" and os.path.exists(save_file):
        os.remove(save_file)
    logger.add(
        sys.stderr,
        format=loguru_format,
        level="INFO",
        enqueue=True,
    )
    logger.add(save_file)
    # redirect stdout/stderr to loguru
    redirect_sys_output("INFO")
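
# Usage sketch (hypothetical paths): call once at startup, before anything
# writes to stdout/stderr, since setup_logger redirects both to loguru.
def _demo_setup_logger(save_dir="./logs"):
    os.makedirs(save_dir, exist_ok=True)
    setup_logger(save_dir, filename="train_log.txt", mode="a")
    logger.info("logger initialized")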

def trainMetric(pred, label):
    # Number of correct top-1 predictions in the batch.
    pred = torch.argmax(pred, dim=1)
    prec = torch.sum(pred == label)
    return prec

# def compute_AP(predicted_probs, true_labels):
#     num_samples, num_classes = true_labels.shape
#
#     # list accumulating the AP of each class
#     aps = []
#
#     for class_idx in range(num_classes):
#         class_true_labels = true_labels[:, class_idx]
#         class_similarity_scores = predicted_probs[:, class_idx]
#
#         # sample indices sorted by similarity score
#         sorted_indices = torch.argsort(class_similarity_scores, descending=True)
#
#         # accumulate precision and recall rank by rank
#         tp = 0
#         fp = 0
#         precision_at_rank = []
#         recall_at_rank = []
#
#         for rank, idx in enumerate(sorted_indices):
#             if class_true_labels[idx] == 1:
#                 tp += 1
#             else:
#                 fp += 1
#             precision = tp / (tp + fp)
#             recall = tp / torch.sum(class_true_labels)
#             precision_at_rank.append(precision)
#             recall_at_rank.append(recall)
#
#         # compute average precision (AP) as the area under the curve
#         precision_at_rank = torch.tensor(precision_at_rank)
#         recall_at_rank = torch.tensor(recall_at_rank)
#         ap = torch.trapz(precision_at_rank, recall_at_rank)
#
#         aps.append(ap)
#
#     return aps

def token_wise_similarity(rep1, rep2, mask=None, chunk_size=1024):
    # Max-over-tokens similarity between two sets of token sequences:
    # rep1 is (B1, T1, D), rep2 is (B2, T2, D); returns a (B1, B2) matrix.
    # rep2 is processed in chunks to bound peak memory.
    batch_size1, n_token1, feat_dim = rep1.shape
    batch_size2, n_token2, _ = rep2.shape
    num_folds = math.ceil(batch_size2 / chunk_size)
    output = []
    for i in range(num_folds):
        rep2_seg = rep2[i * chunk_size:(i + 1) * chunk_size]
        out_i = rep1.reshape(-1, feat_dim) @ rep2_seg.reshape(-1, feat_dim).T
        out_i = out_i.reshape(batch_size1, n_token1, -1, n_token2).max(3)[0]
        if mask is None:
            out_i = out_i.mean(1)
        else:
            out_i = out_i.sum(1)
        output.append(out_i)
    output = torch.cat(output, dim=1)
    if mask is not None:
        # Average over valid tokens only; clamp avoids division by zero.
        output = output / mask.sum(1, keepdim=True).clamp_(min=1)
    return output
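
# Shape sketch (hypothetical sizes): rep1 is (B1, T1, D) and rep2 is (B2, T2, D);
# the result is a (B1, B2) similarity matrix. Chunking over B2 bounds memory.
def _demo_token_similarity():
    rep1 = torch.randn(8, 16, 64)     # 8 queries, 16 tokens each, dim 64
    rep2 = torch.randn(2000, 16, 64)  # gallery larger than one chunk
    sim = token_wise_similarity(rep1, rep2, chunk_size=1024)
    assert sim.shape == (8, 2000)
    return sim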

def compute_acc(logits, targets, topk=5):
    # Top-1 / top-k correctness per sample, returned as 0/1 count vectors.
    # The transpose puts ranks on dim 0 so row slicing selects top-1 vs top-k.
    targets = targets.squeeze(1)
    pred = logits.topk(topk, 1, True, True)[1].t()  # (topk, batch)
    correct = pred.eq(targets.view(1, -1).expand_as(pred)).contiguous()
    acc_1 = correct[:1].sum(0)
    acc_k = correct[:topk].sum(0)
    return acc_1, acc_k
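
# Usage sketch (hypothetical data): compute_acc returns per-sample 0/1 count
# vectors; average them to get top-1 / top-5 accuracy.
def _demo_accuracy():
    logits = torch.randn(32, 100)             # batch of 32, 100 classes
    targets = torch.randint(0, 100, (32, 1))  # labels carry a trailing dim of 1
    acc1, acc5 = compute_acc(logits, targets, topk=5)
    return acc1.float().mean().item(), acc5.float().mean().item()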

def compute_mAP(predicted_probs, true_labels):
    # compute_AP returns (per-class AP tensor, number of empty classes);
    # average the APs that are valid numbers.
    aps, _ = compute_AP(predicted_probs, true_labels)
    aps = aps[~torch.isnan(aps)]
    mAP = aps.mean()
    return mAP

def compute_F1(predictions, labels, k_val=5):
    labels = labels.squeeze(1)
    idx = predictions.topk(dim=1, k=k_val)[1]
    # Binarize in place: 1 for each sample's top-k classes, 0 elsewhere.
    predictions.fill_(0)
    predictions.scatter_(dim=1, index=idx,
                         src=torch.ones(predictions.size(0), k_val).to(predictions.device))
    mask = predictions == 1
    TP = (labels[mask] == 1).sum().float()
    tpfp = mask.sum().float()
    tpfn = (labels == 1).sum().float()
    p = TP / tpfp
    r = TP / tpfn
    f1 = 2 * p * r / (p + r)
    return f1, p, r

def compute_AP(predictions, labels):
    # Per-class average precision for multi-label targets encoded as +1 / -1
    # (entries with |label| != 1 are ignored). Returns the AP vector and the
    # number of classes that had no positives.
    num_class = predictions.size(1)
    ap = torch.zeros(num_class).to(predictions.device)
    empty_class = 0
    for idx_cls in range(num_class):
        prediction = predictions[:, idx_cls]
        label = labels[:, idx_cls]
        mask = label.abs() == 1
        if (label > 0).sum() == 0:
            empty_class += 1
            continue
        binary_label = torch.clamp(label[mask], min=0, max=1)
        sorted_pred, sort_idx = prediction[mask].sort(descending=True)
        sorted_label = binary_label[sort_idx]
        tmp = (sorted_label == 1).float()
        tp = tmp.cumsum(0)
        fp = (sorted_label != 1).float().cumsum(0)
        num_pos = binary_label.sum()
        rec = tp / num_pos
        prec = tp / (tp + fp)
        ap_cls = (tmp * prec).sum() / num_pos
        ap[idx_cls].copy_(ap_cls)
    return ap, empty_class
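
# Usage sketch (hypothetical multi-label data): labels are encoded as +1 / -1,
# matching what compute_AP expects. compute_mAP averages the per-class APs.
def _demo_map():
    preds = torch.rand(64, 20)                      # scores for 20 classes
    labels = torch.randint(0, 2, (64, 20)) * 2 - 1  # random +1 / -1 labels
    ap, n_empty = compute_AP(preds, labels)
    return compute_mAP(preds, labels), n_empty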

def compute_ACG(predictions, labels, k_val=5):
    gt = labels.squeeze(1)
    idx = predictions.topk(dim=1, k=k_val)[1]
    pred = gt[idx, :]
    pred[pred == -1] = 0
    c = labels.eq(pred)  # common labels between query and retrieved items
    r = c.sum(-1)        # similarity level (number of shared labels)
    # ACG: average cumulative gain over the top-k retrievals
    acg = c.sum(-1).sum(-1) / k_val
    lg = torch.log1p(torch.arange(1, k_val + 1, 1)).to(r.device)
    # NDCG
    dcg = (torch.pow(2, r) - 1) / lg
    ir, _ = r.sort(-1, descending=True)
    idcg = (torch.pow(2, ir) - 1) / lg
    idcg[idcg == 0] = 1e-6
    ndcg = dcg.sum(-1) / idcg.sum(-1)
    # mAP over the top-k (named mAP_k to avoid shadowing the builtin map)
    pos = r.clone()
    pos[pos != 0] = 1
    j = torch.arange(1, k_val + 1, 1).to(pos.device)
    P = torch.cumsum(pos, 1) / j
    Npos = torch.sum(pos, 1)
    Npos[Npos == 0] = 1
    mAP_k = torch.sum(P * pos, 1) / Npos
    # weighted mAP, using the cumulative ACG at each rank as the weight
    acgj = torch.cumsum(r, 1) / j
    wmap = torch.sum(acgj * pos, 1) / Npos
    return acg, ndcg, mAP_k, wmap

def compute_mAPw(predictions, labels, k_val=5):
    # The original function stopped mid-computation (cumsum was missing its dim
    # and there was no return); the tail below completes it following the same
    # mAP@k recipe used in compute_ACG.
    gt = labels.squeeze(1)
    idx = predictions.topk(dim=1, k=k_val)[1]
    pred = gt[idx, :]
    pred[pred == -1] = 0
    c = labels.eq(pred)
    r = c.sum(-1)
    pos = r.clone()
    pos[pos != 0] = 1
    j = torch.arange(1, k_val + 1, 1).to(pos.device)
    P = torch.cumsum(pos, 1) / j
    Npos = torch.sum(pos, 1)
    Npos[Npos == 0] = 1
    return torch.sum(P * pos, 1) / Npos

def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate with half-cycle cosine after warmup."""
    if epoch < args.warmup_epochs:
        lr = args.base_lr * epoch / args.warmup_epochs
    else:
        lr = args.min_lr + (args.base_lr - args.min_lr) * 0.5 * \
            (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
    for param_group in optimizer.param_groups:
        if "lr_scale" in param_group:
            param_group["lr"] = lr * param_group["lr_scale"]
        else:
            param_group["lr"] = lr
    return lr
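
# Usage sketch (hypothetical args object): call once per epoch before training;
# `args` is assumed to carry base_lr, min_lr, warmup_epochs and epochs.
def _demo_schedule(optimizer, args):
    for epoch in range(args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, args)
        logger.info(f"epoch {epoch}: lr={lr:.6f}")
        # ... run one training epoch here ...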

def load_ckpt(weight_dir, model, map_location, args):
    checkpoint = torch.load(weight_dir, map_location=map_location)
    if args.resume:
        resume_epoch = checkpoint['epoch']
    else:
        resume_epoch = 0
    pre_weight = checkpoint['state_dict']
    new_pre_weight = OrderedDict()
    # pre_weight = torch.jit.load(resume)
    model_dict = model.state_dict()
    new_model_dict = OrderedDict()
    # Strip the 'module.' prefix that DataParallel/DDP adds to parameter names.
    for k, v in pre_weight.items():
        new_k = k.replace('module.', '') if 'module' in k else k
        # for batch_size = 1:
        # new_k = new_k.replace('1', '2') if 'proj.1' in new_k else new_k
        new_pre_weight[new_k] = v
    # for k, v in model_dict.items():
    #     new_k = k.replace('module.', '') if 'module' in k else k
    #     new_model_dict[new_k] = v
    pre_weight = new_pre_weight  # ["model_state"]
    # Debug helper: count how many checkpoint weights match the model.
    # pretrained_dict = {}
    # t_n = 0
    # v_n = 0
    # for k, v in pre_weight.items():
    #     t_n += 1
    #     if k in new_model_dict:
    #         k = 'module.' + k if 'module' not in k else k
    #         v_n += 1
    #         pretrained_dict[k] = v
    #         print(k)
    # print(f'{v_n}/{t_n} weights have been loaded!')
    model_dict.update(pre_weight)
    model.load_state_dict(model_dict, strict=False)
    return model, resume_epoch

def load_ckpt_fpn(weight_dir, model, map_location):
    checkpoint = torch.load(weight_dir, map_location=map_location)  # load once
    pre_weight = checkpoint['state_dict']
    epoch = checkpoint['epoch']
    new_pre_weight = OrderedDict()
    model_dict = model.state_dict()
    for k, v in pre_weight.items():
        new_k = k.replace('module.', '') if 'module' in k else k
        # if not (new_k.startswith('FPN') or new_k.startswith('gap')):
        new_pre_weight[new_k] = v
    pre_weight = new_pre_weight
    model_dict.update(pre_weight)
    model.load_state_dict(model_dict, strict=True)
    return model, epoch

def load_ckpt_old(weight_dir, model, map_location):
    checkpoint = torch.load(weight_dir, map_location=map_location)  # load once
    pre_weight = checkpoint['state_dict']
    epoch = checkpoint['epoch']
    new_pre_weight = OrderedDict()
    model_dict = model.state_dict()
    # Keep everything except the FPN / gap heads.
    for k, v in pre_weight.items():
        new_k = k.replace('module.', '') if 'module' in k else k
        if not (new_k.startswith('FPN') or new_k.startswith('gap')):
            new_pre_weight[new_k] = v
    pre_weight = new_pre_weight
    model_dict.update(pre_weight)
    model.load_state_dict(model_dict, strict=False)
    return model, epoch

def compare_ckpt(model1, model2):
    # Copy all 'projT*' weights from model1's state dict into model2's.
    V = dict()
    for k, v in model1.items():
        if k.startswith('projT'):
            V[k] = v
    for k, v in model2.items():
        if k in V:
            model2[k] = V[k]
    return model2