|
|
|
import argparse |
|
import json |
|
from collections import defaultdict |
|
|
|
|
|
|
|
# Table pairing each COCO category id with a synset identifier string.
# The script below uses it to translate the ids of a COCO-style category
# file into synsets, so that synonyms can be looked up for categories that
# carry no 'synonyms' field of their own.
# The id sequence has gaps (e.g. 12, 26, 29, 30 are absent); entries are
# kept in ascending id order.
COCO_SYNSET_CATEGORIES = [
    {"synset": synset, "coco_cat_id": coco_cat_id}
    for coco_cat_id, synset in (
        (1, "person.n.01"),
        (2, "bicycle.n.01"),
        (3, "car.n.01"),
        (4, "motorcycle.n.01"),
        (5, "airplane.n.01"),
        (6, "bus.n.01"),
        (7, "train.n.01"),
        (8, "truck.n.01"),
        (9, "boat.n.01"),
        (10, "traffic_light.n.01"),
        (11, "fireplug.n.01"),
        (13, "stop_sign.n.01"),
        (14, "parking_meter.n.01"),
        (15, "bench.n.01"),
        (16, "bird.n.01"),
        (17, "cat.n.01"),
        (18, "dog.n.01"),
        (19, "horse.n.01"),
        (20, "sheep.n.01"),
        (21, "beef.n.01"),
        (22, "elephant.n.01"),
        (23, "bear.n.01"),
        (24, "zebra.n.01"),
        (25, "giraffe.n.01"),
        (27, "backpack.n.01"),
        (28, "umbrella.n.01"),
        (31, "bag.n.04"),
        (32, "necktie.n.01"),
        (33, "bag.n.06"),
        (34, "frisbee.n.01"),
        (35, "ski.n.01"),
        (36, "snowboard.n.01"),
        (37, "ball.n.06"),
        (38, "kite.n.03"),
        (39, "baseball_bat.n.01"),
        (40, "baseball_glove.n.01"),
        (41, "skateboard.n.01"),
        (42, "surfboard.n.01"),
        (43, "tennis_racket.n.01"),
        (44, "bottle.n.01"),
        (46, "wineglass.n.01"),
        (47, "cup.n.01"),
        (48, "fork.n.01"),
        (49, "knife.n.01"),
        (50, "spoon.n.01"),
        (51, "bowl.n.03"),
        (52, "banana.n.02"),
        (53, "apple.n.01"),
        (54, "sandwich.n.01"),
        (55, "orange.n.01"),
        (56, "broccoli.n.01"),
        (57, "carrot.n.01"),
        (58, "sausage.n.01"),
        (59, "pizza.n.01"),
        (60, "doughnut.n.02"),
        (61, "cake.n.03"),
        (62, "chair.n.01"),
        (63, "sofa.n.01"),
        (64, "pot.n.04"),
        (65, "bed.n.01"),
        (67, "dining_table.n.01"),
        (70, "toilet.n.02"),
        (72, "television_receiver.n.01"),
        (73, "laptop.n.01"),
        (74, "mouse.n.04"),
        (75, "remote_control.n.01"),
        (76, "computer_keyboard.n.01"),
        (77, "cellular_telephone.n.01"),
        (78, "microwave.n.02"),
        (79, "oven.n.01"),
        (80, "toaster.n.02"),
        (81, "sink.n.01"),
        (82, "electric_refrigerator.n.01"),
        (84, "book.n.01"),
        (85, "clock.n.01"),
        (86, "vase.n.01"),
        (87, "scissors.n.01"),
        (88, "teddy.n.01"),
        (89, "hand_blower.n.01"),
        (90, "toothbrush.n.01"),
    )
]
|
|
|
def map_name(x):
    """Normalize a category name/synonym for caption matching.

    Underscores become spaces, everything from the first '(' onwards is
    dropped, and the result is lower-cased with surrounding whitespace
    removed.
    """
    spaced = x.replace('_', ' ')
    # partition() yields the full string as the head when no '(' is present.
    head = spaced.partition('(')[0]
    return head.lower().strip()
|
|
|
def _attach_captions(cc_data):
    """Store each image's captions on the image record itself.

    Captions are gathered from ``cc_data['annotations']`` (keyed by
    ``image_id``) and written to ``image['captions']`` for every entry of
    ``cc_data['images']``. Images without annotations get an empty list.

    Returns:
        The total number of captions attached.
    """
    captions_by_image = defaultdict(list)
    for ann in cc_data['annotations']:
        captions_by_image[ann['image_id']].append(ann['caption'])
    num_caps = 0
    for image in cc_data['images']:
        image['captions'] = captions_by_image[image['id']]
        num_caps += len(image['captions'])
    return num_caps


def _load_categories(cat_path, cc_data):
    """Load the category list from ``cat_path``.

    If the loaded categories carry no 'synonyms' field, each category id is
    translated to a synset through COCO_SYNSET_CATEGORIES and its synonyms
    are copied from the category with the same synset in
    ``cc_data['categories']``; such categories are also given
    ``frequency = 'f'``.

    Raises:
        KeyError: if a category id has no entry in COCO_SYNSET_CATEGORIES or
            its synset is missing from ``cc_data['categories']``.
    """
    print('Loading', cat_path)
    with open(cat_path, 'r') as f:
        cats = json.load(f)['categories']
    # Guard against an empty category list before peeking at cats[0].
    if cats and 'synonyms' not in cats[0]:
        cocoid2synset = {
            c['coco_cat_id']: c['synset'] for c in COCO_SYNSET_CATEGORIES}
        synset2synonyms = {
            c['synset']: c['synonyms'] for c in cc_data['categories']}
        for cat in cats:
            cat['synonyms'] = synset2synonyms[cocoid2synset[cat['id']]]
            cat['frequency'] = 'f'
    return cats


def _match_category_ids(caption, class_data):
    """Return the ids of all categories with a name occurring in ``caption``.

    ``class_data`` maps a category id to names already wrapped in single
    spaces (``' name '``). The caption is padded with a space on both sides
    so that one substring test covers a name in the middle, at the start,
    at the end, and a caption that consists of the name alone.
    """
    padded = ' ' + caption + ' '
    return [cat_id for cat_id, names in class_data.items()
            if any(name in padded for name in names)]


def _full_out_path(out_path):
    """Return ``out_path`` with its '.json' suffix replaced by '_full.json'.

    A path that does not end in '.json' simply gets '_full.json' appended
    instead of losing its last five characters.
    """
    suffix = '.json'
    stem = out_path[:-len(suffix)] if out_path.endswith(suffix) else out_path
    return stem + '_full.json'


def main(argv=None):
    """Tag captioned images with the categories named in their captions.

    Reads the annotation JSON given by ``--cc_ann``, optionally attaches
    captions to images (``--convert_caption``) and swaps in another category
    list (``--cat_path``), then sets ``image['pos_category_ids']`` to the ids
    of all categories whose synonyms appear in the image caption. Images
    without any match are dropped unless ``--keep_images`` is given. The
    result is written as JSON to ``--out_path`` (with a '_full.json' suffix
    when ``--keep_images`` is set).

    Args:
        argv: Optional list of command-line arguments; ``None`` makes
            argparse read ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--cc_ann', default='datasets/cc3m/train_image_info.json')
    parser.add_argument('--out_path', default='datasets/cc3m/train_image_info_tags.json')
    parser.add_argument('--keep_images', action='store_true')
    parser.add_argument('--allcaps', action='store_true')
    parser.add_argument('--cat_path', default='')
    parser.add_argument('--convert_caption', action='store_true')
    args = parser.parse_args(argv)

    with open(args.cc_ann, 'r') as f:
        cc_data = json.load(f)
    if args.convert_caption:
        print('# captions', _attach_captions(cc_data))
    if args.cat_path != '':
        cc_data['categories'] = _load_categories(args.cat_path, cc_data)

    categories = cc_data['categories']
    id2cat = {c['id']: c for c in categories}
    class_count = {c['id']: 0 for c in categories}
    # Names are wrapped in spaces so that only whole words/phrases match.
    class_data = {
        c['id']: [' ' + map_name(s) + ' ' for s in c['synonyms']]
        for c in categories}
    print('class_data', class_data)

    images = []
    num_images = len(cc_data['images'])
    for i, image in enumerate(cc_data['images']):
        if i % 10000 == 0:
            print(i, num_images)
        captions = image['captions']
        if args.allcaps:
            caption = ' '.join(captions).lower()
        else:
            # An image may have no caption at all; treat it as empty text
            # rather than failing on captions[0].
            caption = captions[0].lower() if captions else ''
        image['pos_category_ids'] = _match_category_ids(caption, class_data)
        for cat_id in image['pos_category_ids']:
            class_count[cat_id] += 1
        if image['pos_category_ids'] or args.keep_images:
            images.append(image)

    zero_class = []
    for cat_id, count in class_count.items():
        print(id2cat[cat_id]['name'], count, end=', ')
        if count == 0:
            zero_class.append(id2cat[cat_id])
    print('==')
    print('zero class', zero_class)

    # Per-frequency-label totals; categories without a 'frequency' field are
    # simply not counted under any label.
    for freq in ['r', 'c', 'f']:
        print('#Images', freq, sum(
            v for k, v in class_count.items()
            if id2cat[k].get('frequency') == freq))

    out_data = {'images': images, 'categories': categories,
                'annotations': []}
    for k, v in out_data.items():
        print(k, len(v))
    out_path = args.out_path
    if args.keep_images and not out_path.endswith('_full.json'):
        out_path = _full_out_path(out_path)
    print('Writing to', out_path)
    try:
        with open(out_path, 'w') as f:
            json.dump(out_data, f)
    except (OSError, TypeError, ValueError) as err:
        # Writing stays best-effort (no crash), but the failure is reported
        # instead of being silently discarded.
        print('Failed to write', out_path, '-', err)


if __name__ == '__main__':
    main()
|
|