Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates.
import json
import os

import numpy as np
import torch
from torch.nn import functional as F
def load_class_freq(
        path='datasets/metadata/lvis_v1_train_cat_info.json', freq_weight=1.0):
    """Load per-category image counts and turn them into sampling weights.

    Reads a JSON list of category records, orders them by their ``'id'``
    field, and raises each record's ``'image_count'`` to the power
    ``freq_weight``.

    Args:
        path: Path to a JSON file holding a list of dicts, each with at least
            ``'id'`` and ``'image_count'`` keys.
        freq_weight: Exponent applied to the counts (1.0 keeps the raw counts).

    Returns:
        A 1-D float tensor with one weight per record. Position ``i`` holds the
        weight of the record with the i-th smallest ``'id'``. NOTE(review):
        callers presumably rely on ids being contiguous so that position equals
        class index; confirm.
    """
    # Context manager closes the handle even if JSON parsing fails; JSON is
    # UTF-8 by spec, so don't depend on the platform default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        cat_info = json.load(f)
    counts = torch.tensor(
        [c['image_count'] for c in sorted(cat_info, key=lambda x: x['id'])])
    return counts.float() ** freq_weight
def get_fed_loss_inds(gt_classes, num_sample_cats, C, weight=None):
    """Pick the category indices that take part in the federated loss.

    Every class present in ``gt_classes`` is kept. If fewer than
    ``num_sample_cats`` classes are present, extra classes are drawn without
    replacement from ``[0, C)``. The draw is weighted by ``weight`` when it is
    given and uniform otherwise. Index ``C`` is never sampled.

    Args:
        gt_classes: Integer tensor of ground-truth class indices. It may
            contain ``C``, which is kept if present but never sampled.
        num_sample_cats: Target number of categories to return.
        C: Number of foreground categories.
        weight: Optional length-``C`` tensor of non-negative sampling weights.

    Returns:
        A 1-D tensor containing the sorted unique classes of ``gt_classes``,
        followed by any sampled extra classes. If fewer classes have non-zero
        probability than are needed, all of them are returned instead of
        raising.
    """
    appeared = torch.unique(gt_classes)  # C' unique classes, sorted
    num_missing = num_sample_cats - len(appeared)
    if num_missing <= 0:
        return appeared
    prob = appeared.new_ones(C + 1).float()
    prob[-1] = 0  # never sample the extra slot at index C
    if weight is not None:
        prob[:C] = weight.float()  # slice assignment copies; no clone needed
    prob[appeared] = 0  # classes already present must not be drawn again
    # torch.multinomial(replacement=False) raises when asked for more samples
    # than there are non-zero-probability entries, so clamp to what exists.
    num_draw = min(num_missing, int((prob > 0).sum()))
    if num_draw == 0:
        return appeared
    more_appeared = torch.multinomial(prob, num_draw, replacement=False)
    return torch.cat([appeared, more_appeared])
def reset_cls_test(model, cls_path, num_classes):
    """Replace the zero-shot classifier weights of every box-predictor head.

    Args:
        model: Detector whose ``roi_heads`` exposes ``num_classes`` and an
            indexable, iterable ``box_predictor``. Each predictor's
            ``cls_score`` must have ``norm_weight`` and ``zs_weight``
            attributes. ``model.device`` is the target device.
        cls_path: Either a path (``str`` or ``os.PathLike``) to a file readable
            by ``np.load``, or a tensor. A loaded array is transposed, so it is
            presumably stored as C x D (TODO confirm). A tensor is used as-is
            and must already be D x C.
        num_classes: New value for ``model.roi_heads.num_classes``.
    """
    model.roi_heads.num_classes = num_classes
    if isinstance(cls_path, (str, os.PathLike)):
        print('Resetting zs_weight', cls_path)
        zs_weight = torch.tensor(
            np.load(cls_path),
            dtype=torch.float32).permute(1, 0).contiguous()  # D x C
    else:
        zs_weight = cls_path
    # Append an all-zero column at index C.
    zs_weight = torch.cat(
        [zs_weight, zs_weight.new_zeros((zs_weight.shape[0], 1))],
        dim=1)  # D x (C + 1)
    # The first head's flag decides L2 column-normalization for all heads.
    if model.roi_heads.box_predictor[0].cls_score.norm_weight:
        zs_weight = F.normalize(zs_weight, p=2, dim=0)
    zs_weight = zs_weight.to(model.device)
    for predictor in model.roi_heads.box_predictor:
        # Delete first so the new tensor replaces the old attribute outright.
        # NOTE(review): if zs_weight was a registered buffer, this turns it into
        # a plain attribute; presumably intentional, so confirm.
        del predictor.cls_score.zs_weight
        predictor.cls_score.zs_weight = zs_weight  # one tensor shared by all heads