# Copyright (c) Facebook, Inc. and its affiliates.
import json

import numpy as np
import torch
from torch.nn import functional as F


def load_class_freq(
        path='datasets/metadata/lvis_v1_train_cat_info.json',
        freq_weight=1.0):
    """Load per-category image counts and raise them to a power.

    The JSON file at ``path`` must hold a list of dicts, each with at least
    an ``'id'`` and an ``'image_count'`` key.

    Args:
        path: Location of the category-info JSON file.
        freq_weight: Exponent applied element-wise to the image counts.

    Returns:
        A float tensor with one entry per category, ordered by ascending
        ``'id'``, holding ``image_count ** freq_weight``.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(path, 'r') as f:
        cat_info = json.load(f)
    image_counts = torch.tensor(
        [c['image_count'] for c in sorted(cat_info, key=lambda x: x['id'])])
    return image_counts.float() ** freq_weight


def get_fed_loss_inds(gt_classes, num_sample_cats, C, weight=None):
    """Pick the category indices that take part in the federated loss.

    The unique values of ``gt_classes`` are always kept. If there are fewer
    than ``num_sample_cats`` of them, additional distinct indices are drawn
    without replacement from ``[0, C)`` until ``num_sample_cats`` indices
    have been collected. When the unique values already number
    ``num_sample_cats`` or more, they are returned unchanged (no truncation).

    Args:
        gt_classes: Integer tensor of class indices. Values are used to
            index a length ``C + 1`` vector, so they must lie in ``[0, C]``.
        num_sample_cats: Minimum number of indices to return.
        C: Number of sampleable categories. Index ``C`` itself always gets
            zero sampling probability (presumably the background slot --
            confirm against the caller).
        weight: Optional tensor of ``C`` non-negative sampling weights. When
            omitted, every not-yet-present category is equally likely.

    Returns:
        A 1-D tensor holding the sorted unique values of ``gt_classes``
        followed by any extra sampled indices.
    """
    appeared = torch.unique(gt_classes)  # C'
    if len(appeared) >= num_sample_cats:
        return appeared

    prob = appeared.new_ones(C + 1).float()
    prob[-1] = 0  # the last slot (index C) is never sampled
    if weight is not None:
        prob[:C] = weight.float().clone()
    # Categories already present must not be drawn a second time.
    prob[appeared] = 0
    more_appeared = torch.multinomial(
        prob, num_sample_cats - len(appeared),
        replacement=False)
    return torch.cat([appeared, more_appeared])


def reset_cls_test(model, cls_path, num_classes):
    """Replace the ``zs_weight`` classifier weights on every box predictor.

    Args:
        model: Object exposing ``roi_heads.num_classes``, a sequence
            ``roi_heads.box_predictor`` whose items each carry a
            ``cls_score`` with ``norm_weight`` and ``zs_weight`` attributes,
            and a ``device`` attribute.
        cls_path: Either a path string to a NumPy file holding a 2-D array
            that is transposed after loading (C x D on disk according to the
            shape comments below), or an already-loaded D x C tensor that is
            used as-is.
        num_classes: Value stored into ``model.roi_heads.num_classes``.

    Returns:
        None. ``model`` is modified in place; every predictor ends up
        referencing the same weight tensor.
    """
    model.roi_heads.num_classes = num_classes
    if isinstance(cls_path, str):
        print('Resetting zs_weight', cls_path)
        zs_weight = torch.tensor(
            np.load(cls_path),
            dtype=torch.float32).permute(1, 0).contiguous()  # D x C
    else:
        zs_weight = cls_path
    # Append one all-zero column after the C category columns.
    zs_weight = torch.cat(
        [zs_weight, zs_weight.new_zeros((zs_weight.shape[0], 1))],
        dim=1)  # D x (C + 1)
    # The normalisation flag is read from the first predictor only.
    if model.roi_heads.box_predictor[0].cls_score.norm_weight:
        zs_weight = F.normalize(zs_weight, p=2, dim=0)
    zs_weight = zs_weight.to(model.device)
    for predictor in model.roi_heads.box_predictor:
        # Remove the existing attribute first, then bind the new tensor as a
        # fresh attribute.
        del predictor.cls_score.zs_weight
        predictor.cls_score.zs_weight = zs_weight