"""Run a referring-expression method over a jsonlines dataset and report accuracy.

For every sentence of every datum the selected method predicts one box; the
prediction counts as correct when its IoU with any gold box exceeds 0.5.
Per-sentence results are optionally written as jsonlines to ``--output_file``,
and summary JSON files are always written under ``./output``.
"""
from collections import defaultdict
import json
import argparse
import os
import random

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from interpreter import *
from executor import *
from methods import *

METHODS_MAP = {
    "baseline": Baseline,
    "random": Random,
    "parse": Parse,
}

# Directory that receives the summary JSON files.
OUTPUT_DIR = './output'


def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser with every option the script accepts."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, help="input file with expressions and annotations in jsonlines format")
    parser.add_argument("--image_root", type=str, help="path to images (train2014 directory of COCO)")
    parser.add_argument("--clip_model", type=str, default="RN50x16,ViT-B/32", help="which clip model to use (should use RN50x4, ViT-B/32, or both separated by a comma")
    # NOTE(review): this help text is identical to --clip_model's and looks
    # copy-pasted; confirm what --clip_type actually selects.
    parser.add_argument("--clip_type", type=str, default="aclip", help="which clip model to use (should use RN50x4, ViT-B/32, or both separated by a comma")
    parser.add_argument("--albef_path", type=str, default=None, help="to use ALBEF (instead of CLIP), specify the path to the ALBEF checkpoint")
    parser.add_argument("--method", type=str, default="parse", help="method to solve expressions")
    parser.add_argument("--box_representation_method", type=str, default="crop,blur", help="method of representing boxes as individual images (crop, blur, or both separated by a comma)")
    parser.add_argument("--box_method_aggregator", type=str, default="sum", help="method of combining box representation scores")
    parser.add_argument("--box_area_threshold", type=float, default=0.0, help="minimum area (as a proportion of image area) for a box to be considered as the answer")
    parser.add_argument("--output_file", type=str, default=None, help="(optional) output path to save results")
    parser.add_argument("--detector_file", type=str, default=None, help="(optional) file containing object detections. if not provided, the gold object boxes will be used.")
    parser.add_argument("--mock", action="store_true", help="(optional) mock CLIP execution.")
    parser.add_argument("--device", type=int, default=0, help="CUDA device to use.")
    parser.add_argument("--shuffle_words", action="store_true", help="If true, shuffle words in the sentence")
    parser.add_argument("--gradcam_alpha", type=float, nargs='+', help="alpha value to use for gradcam method")
    parser.add_argument("--enlarge_boxes", type=float, default=0.0, help="(optional) whether to enlarge boxes when passing them to the model")
    parser.add_argument("--part", type=str, default=None, help="(optional) specify how many parts to divide the dataset into and which part to run in the format NUM_PARTS,PART_NUM")
    parser.add_argument("--batch_size", type=int, default=1, help="number of instances to process in one model call (only supported for baseline model)")
    parser.add_argument("--baseline_head", action="store_true", help="For baseline, controls whether model is called on both full expression and head noun chunk of expression")
    parser.add_argument("--mdetr", type=str, default=None, help="to use MDETR as the executor model, specify the name of the MDETR model")
    parser.add_argument("--albef_block_num", type=int, default=8, help="block num for ALBEF gradcam")
    parser.add_argument("--albef_mode", type=str, choices=["itm", "itc"], default="itm")
    parser.add_argument("--expand_position_embedding", action="store_true")
    parser.add_argument("--gradcam_background", action="store_true")
    parser.add_argument("--mdetr_given_bboxes", action="store_true")
    parser.add_argument("--mdetr_use_token_mapping", action="store_true")
    parser.add_argument("--non_square_size", action="store_true")
    parser.add_argument("--blur_std_dev", type=int, default=100, help="standard deviation of Gaussian blur")
    parser.add_argument("--gradcam_ensemble_before", action="store_true", help="Average gradcam maps of different models before summing over the maps")
    parser.add_argument("--cache_path", type=str, default=None, help="cache features")

    # Arguments related to Parse method.
    parser.add_argument("--no_rel", action="store_true", help="Disable relation extraction.")
    parser.add_argument("--no_sup", action="store_true", help="Disable superlative extraction.")
    parser.add_argument("--no_null", action="store_true", help="Disable null keyword heuristics.")
    parser.add_argument("--ternary", action="store_true", help="Disable ternary relation extraction.")
    parser.add_argument("--baseline_threshold", type=float, default=float("inf"), help="(Parse) Threshold to use relations/superlatives.")
    parser.add_argument("--temperature", type=float, default=1., help="(Parse) Sigmoid temperature.")
    parser.add_argument("--superlative_head_only", action="store_true", help="(Parse) Superlatives only quanntify head predicate.")
    parser.add_argument("--sigmoid", action="store_true", help="(Parse) Use sigmoid, not softmax.")
    parser.add_argument("--no_possessive", action="store_true", help="(Parse) Model extraneous relations as possessive relations.")
    parser.add_argument("--expand_chunks", action="store_true", help="(Parse) Expand noun chunks to include descendant tokens that aren't ancestors of tokens in other chunks")
    parser.add_argument("--parse_no_branch", action="store_true", help="(Parse) Only do the parsing procedure if some relation/superlative keyword is in the expression")
    parser.add_argument("--possessive_no_expand", action="store_true", help="(Parse) Expand ent2 in possessive case")
    return parser


def _load_detections(path: str) -> dict:
    """Load detected boxes from *path* and return a mapping image_id -> boxes.

    Two JSON layouts are accepted: a dict keyed by image id (keys are
    converted to ``int``), or a list of records with ``"image_id"`` and
    ``"box"`` entries, which are grouped per image.
    """
    with open(path) as detector_file:
        detections_list = json.load(detector_file)
    if isinstance(detections_list, dict):
        return {int(image_id): image_boxes for image_id, image_boxes in detections_list.items()}
    detections_map = defaultdict(list)
    for detection in detections_list:
        detections_map[detection["image_id"]].append(detection["box"])
    return detections_map


def _resolve_file_name(raw_name: str) -> str:
    """Return the image file name to load for a datum's ``file_name``.

    Names containing "coco" (case-insensitive) have their last
    underscore-separated token dropped and ".jpg" appended; any other name
    is used unchanged.
    """
    if "coco" in raw_name.lower():
        return "_".join(raw_name.split("_")[:-1]) + ".jpg"
    return raw_name


def _best_detection_index(boxes, gold_boxes, gold_index, threshold=0.5):
    """Return the index in *boxes* with the highest IoU against any gold box.

    Gold boxes are visited in *gold_index* order and ties keep the first
    candidate seen. Returns -1 when the best IoU is below *threshold*, and
    also when *boxes* or *gold_index* is empty.
    """
    best_index = -1
    best_iou = 0
    for g_index in gold_index:
        for index, box in enumerate(boxes):
            value = iou(box, gold_boxes[g_index])
            if value > best_iou:
                best_iou = value
                best_index = index
    return best_index if best_iou >= threshold else -1


def _to_jsonable(value):
    """Convert NumPy arrays and NumPy scalars to plain Python for ``json``."""
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        # Covers np.int64 as well as every other NumPy scalar type.
        return value.item()
    return value


def _summary_path(input_file: str, tag: str, part: int) -> str:
    """Build ``./output/<input stem><tag><part>.json`` for a summary file."""
    stem = input_file.split('/')[-1].split('.')[0]
    return os.path.join(OUTPUT_DIR, stem + tag + str(part) + '.json')


def _dump_json(payload, path: str) -> None:
    """Write *payload* as JSON to *path*, closing the file afterwards."""
    with open(path, 'w') as handle:
        json.dump(payload, handle)


def main() -> None:
    """Parse arguments, evaluate every sentence, and write the summaries."""
    args = _build_parser().parse_args()

    with open(args.input_file) as f:
        data = [json.loads(line) for line in f.readlines()]

    device = f"cuda:{args.device}" if torch.cuda.is_available() and args.device >= 0 else "cpu"

    executor = ClipExecutor(
        clip_model=args.clip_model,
        box_representation_method=args.box_representation_method,
        method_aggregator=args.box_method_aggregator,
        device=device,
        square_size=not args.non_square_size,
        expand_position_embedding=args.expand_position_embedding,
        blur_std_dev=args.blur_std_dev,
        cache_path=args.cache_path,
        input_file=args.input_file,
        clip_type=args.clip_type,
    )

    method = METHODS_MAP[args.method](args)
    correct_count = 0
    total_count = 0

    detections_map = _load_detections(args.detector_file) if args.detector_file else None

    part = 0
    if args.part is not None:  # for multi-gpu test / part-data test
        part_fields = args.part.split(",")
        num_parts = int(part_fields[0])
        part = int(part_fields[1])
        data = data[int(len(data)*part/num_parts):int(len(data)*(part+1)/num_parts)]

    # The summaries are written only after the whole run, so make sure the
    # target directory exists up front rather than failing at the very end.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    output_file = open(args.output_file, "w") if args.output_file else None
    try:
        for datum in tqdm(data):
            file_name = _resolve_file_name(datum["file_name"])
            img_path = os.path.join(args.image_root, file_name)
            # convert() returns a new image, so the source file can be closed.
            with Image.open(img_path) as raw_img:
                img = raw_img.convert('RGB')
            gold_boxes = [
                Box(x=ann["bbox"][0], y=ann["bbox"][1], w=ann["bbox"][2], h=ann["bbox"][3])
                for ann in datum["anns"]
            ]
            if isinstance(datum["ann_id"], (int, str)):
                datum["ann_id"] = [datum["ann_id"]]
            assert isinstance(datum["ann_id"], list)
            gold_index = [i for i, ann in enumerate(datum["anns"]) if ann["id"] in datum["ann_id"]]
            if args.detector_file:
                # .get() lets both detector-file layouts fall back to the
                # full-image box when an image has no detections.
                detected = detections_map.get(int(datum["image_id"]), [])
                boxes = [Box(x=box[0], y=box[1], w=box[2], h=box[3]) for box in detected]
                if not boxes:
                    boxes = [Box(x=0, y=0, w=img.width, h=img.height)]
            else:
                boxes = gold_boxes

            for sentence in datum["sentences"]:
                text = sentence["raw"].lower()
                env = Environment(
                    img,
                    boxes,
                    executor,
                    (args.mdetr is not None and not args.mdetr_given_bboxes),
                    str(datum["image_id"]),
                    img_path,
                )
                if args.shuffle_words:
                    words = text.split()
                    random.shuffle(words)
                    result = method.execute(" ".join(words), env)
                else:
                    result = method.execute(text, env)
                # Deliberately kept from the original: the environment's boxes
                # replace `boxes`, so they also feed the next sentence's env.
                boxes = env.boxes
                print(text)

                correct = any(
                    iou(boxes[result["pred"]], gold_boxes[g_index]) > 0.5
                    for g_index in gold_index
                )
                result["correct"] = 1 if correct else 0
                if correct:
                    correct_count += 1

                if args.detector_file:
                    result["gold_index"] = _best_detection_index(boxes, gold_boxes, gold_index)
                else:
                    result["gold_index"] = gold_index
                result["bboxes"] = [[box.left, box.top, box.right, box.bottom] for box in boxes]
                result["file_name"] = file_name
                result["probabilities"] = result["probs"]
                result["text"] = text
                if output_file is not None:
                    # Serialize numpy arrays for JSON.
                    serializable = {key: _to_jsonable(value) for key, value in result.items()}
                    output_file.write(json.dumps(serializable)+"\n")
                total_count += 1
                print(f"est_acc: {100 * correct_count / total_count:.3f}")
    finally:
        if output_file is not None:
            output_file.close()

    # Report 0.0 instead of dividing by zero when nothing was evaluated.
    acc = 100 * correct_count / total_count if total_count else 0.0
    print(f"acc: {acc:.3f}")

    summary = {'acc': acc}
    _dump_json(acc, _summary_path(args.input_file, '_acc_', part))
    _dump_json(str(correct_count)+' '+str(total_count), _summary_path(args.input_file, '_count_', part))
    stats = method.get_stats()
    if stats:
        for key, value in sorted(stats.items(), key=lambda item: item[0]):
            summary[key] = value
            if isinstance(value, float):
                print(f"{key}: {value:.5f}")
            else:
                print(f"{key}: {value}")

    _dump_json(summary, _summary_path(args.input_file, '_', part))


if __name__ == "__main__":
    main()