Detic / tools /get_lvis_cat_info.py
AK391
files
159f437
raw
history blame
1.61 kB
# Copyright (c) Facebook, Inc. and its affiliates.
import argparse
import json
def add_category_stats(cats, annotations, add_freq=False, r_thresh=10,
                       c_thresh=100):
    """Annotate every category dict in place with its usage counts.

    Each dict in ``cats`` gains two keys:

    * ``'image_count'``: number of distinct ``image_id`` values among the
      annotations of that category.
    * ``'instance_count'``: number of annotations of that category.

    When ``add_freq`` is true, each dict also gains a ``'frequency'`` key:
    ``'r'`` if ``image_count < r_thresh``, otherwise ``'c'`` if
    ``image_count < c_thresh``, otherwise ``'f'``.

    Args:
        cats: list of category dicts, each with an ``'id'`` key.
        annotations: iterable of dicts with ``'category_id'`` and
            ``'image_id'`` keys.
        add_freq: whether to assign the ``'frequency'`` key.
        r_thresh: image-count bound below which a category is tagged ``'r'``.
        c_thresh: image-count bound below which a category is tagged ``'c'``.

    Returns:
        A dict mapping ``'r'``, ``'c'`` and ``'f'`` to the number of
        categories given that tag (all zeros when ``add_freq`` is false).

    Raises:
        KeyError: if an annotation references a category id that does not
            appear in ``cats``.
    """
    image_ids = {cat['id']: set() for cat in cats}
    instance_count = {cat['id']: 0 for cat in cats}
    for ann in annotations:
        cat_id = ann['category_id']
        # A set de-duplicates images holding several instances of a category.
        image_ids[cat_id].add(ann['image_id'])
        instance_count[cat_id] += 1

    num_freqs = {'r': 0, 'c': 0, 'f': 0}
    for cat in cats:
        cat['image_count'] = len(image_ids[cat['id']])
        cat['instance_count'] = instance_count[cat['id']]
        if not add_freq:
            continue
        # The 'r' test goes first so that it takes precedence over 'c' even
        # if r_thresh is configured above c_thresh.
        if cat['image_count'] < r_thresh:
            freq = 'r'
        elif cat['image_count'] < c_thresh:
            freq = 'c'
        else:
            freq = 'f'
        cat['frequency'] = freq
        num_freqs[freq] += 1
    return num_freqs


def cat_info_path(ann_path):
    """Return the output path for the category info derived from ``ann_path``.

    A trailing ``.json`` is replaced by ``_cat_info.json``. A path without
    that suffix simply gets ``_cat_info.json`` appended instead of losing its
    last five characters.
    """
    suffix = '.json'
    stem = ann_path[:-len(suffix)] if ann_path.endswith(suffix) else ann_path
    return stem + '_cat_info.json'


def main(argv=None):
    """Read an annotation JSON file and write per-category statistics.

    The output file holds the annotated category list itself as the
    top-level JSON value (it is not wrapped in a ``{'categories': ...}``
    dict).

    Args:
        argv: optional list of command-line arguments; ``None`` makes
            argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ann", default='datasets/lvis/lvis_v1_train.json')
    parser.add_argument("--add_freq", action='store_true')
    parser.add_argument("--r_thresh", type=int, default=10)
    parser.add_argument("--c_thresh", type=int, default=100)
    args = parser.parse_args(argv)

    print('Loading', args.ann)
    # Context managers close the files deterministically; for the output
    # this also guarantees the JSON is flushed to disk before exiting.
    with open(args.ann, 'r') as f:
        data = json.load(f)
    cats = data['categories']
    num_freqs = add_category_stats(
        cats,
        data['annotations'],
        add_freq=args.add_freq,
        r_thresh=args.r_thresh,
        c_thresh=args.c_thresh,
    )
    print(cats)
    if args.add_freq:
        for freq in ['r', 'c', 'f']:
            print(freq, num_freqs[freq])

    out_path = cat_info_path(args.ann)
    print('Saving to', out_path)
    with open(out_path, 'w') as f:
        json.dump(cats, f)


if __name__ == '__main__':
    main()