|
|
|
import json
import os

import numpy as np
import torch
from torch.nn import functional as F
|
|
|
def load_class_freq(
        path='datasets/metadata/lvis_v1_train_cat_info.json', freq_weight=1.0):
    """Load per-category image counts from JSON and raise them to a power.

    Args:
        path: JSON file containing a list of category records, each with at
            least ``'id'`` and ``'image_count'`` keys.
        freq_weight: exponent applied element-wise to the image counts.

    Returns:
        1-D float tensor of ``image_count ** freq_weight``, ordered by
        ascending category ``'id'``.
    """
    # Context manager so the file handle is closed even if parsing fails.
    with open(path, 'r', encoding='utf-8') as f:
        cat_info = json.load(f)
    counts = torch.tensor(
        [c['image_count'] for c in sorted(cat_info, key=lambda x: x['id'])])
    return counts.float() ** freq_weight
|
|
|
|
|
def get_fed_loss_inds(gt_classes, num_sample_cats, C, weight=None):
    """Choose the category indices used by the federated loss.

    Every class present in ``gt_classes`` is kept. If fewer than
    ``num_sample_cats`` classes appear, additional classes are drawn without
    replacement (``torch.multinomial``) from the classes that did not appear,
    with probability proportional to ``weight`` (uniform when ``weight`` is
    None). Index ``C`` (the last slot of the ``C + 1`` probability vector) is
    never sampled.

    Args:
        gt_classes: integer tensor of ground-truth class ids.
        num_sample_cats: desired number of categories to return.
        C: number of categories; the probability vector has ``C + 1`` slots.
        weight: optional sampling weights copied into the first ``C`` slots
            (assumed to have length ``C``).

    Returns:
        1-D integer tensor: the unique appeared classes (sorted, as returned
        by ``torch.unique``) followed by any sampled extras. It may contain
        fewer than ``num_sample_cats`` entries when not enough classes have a
        non-zero sampling probability.
    """
    appeared = torch.unique(gt_classes)
    prob = appeared.new_ones(C + 1).float()
    prob[-1] = 0  # index C is never a sampling candidate
    if len(appeared) < num_sample_cats:
        if weight is not None:
            prob[:C] = weight.float()  # slice assignment copies the data
        prob[appeared] = 0  # do not re-sample classes already present
        # multinomial without replacement raises if asked for more samples
        # than there are non-zero entries, so clamp to what is available.
        num_extra = min(num_sample_cats - len(appeared),
                        int((prob > 0).sum()))
        if num_extra > 0:
            more_appeared = torch.multinomial(
                prob, num_extra, replacement=False)
            appeared = torch.cat([appeared, more_appeared])
    return appeared
|
|
|
|
|
|
|
def reset_cls_test(model, cls_path, num_classes):
    """Replace the zero-shot classifier weights on every box predictor.

    Args:
        model: object exposing ``roi_heads.num_classes``,
            ``roi_heads.box_predictor`` (a sequence of heads, each with a
            ``cls_score`` that has ``norm_weight`` and ``zs_weight``), and
            ``device``.
        cls_path: either a path (``str``, including subclasses such as
            ``numpy.str_``, or ``os.PathLike``) to a ``.npy`` file, whose
            array is transposed before use, or a tensor that is used as-is
            (presumably already in the transposed layout; confirm with
            callers).
        num_classes: value assigned to ``model.roi_heads.num_classes``.
    """
    model.roi_heads.num_classes = num_classes
    if isinstance(cls_path, (str, os.PathLike)):
        print('Resetting zs_weight', cls_path)
        zs_weight = torch.tensor(
            np.load(cls_path),
            dtype=torch.float32).permute(1, 0).contiguous()
    else:
        zs_weight = cls_path
    # Append one all-zero column along dim 1.
    zs_weight = torch.cat(
        [zs_weight, zs_weight.new_zeros((zs_weight.shape[0], 1))],
        dim=1)
    if model.roi_heads.box_predictor[0].cls_score.norm_weight:
        # L2-normalize each column (dim=0); the zero column stays zero.
        zs_weight = F.normalize(zs_weight, p=2, dim=0)
    zs_weight = zs_weight.to(model.device)
    for predictor in model.roi_heads.box_predictor:
        # Delete before assigning: if zs_weight is registered on an
        # nn.Module (e.g. as a Parameter), assigning a plain tensor in place
        # would raise or keep the old registration.
        del predictor.cls_score.zs_weight
        predictor.cls_score.zs_weight = zs_weight