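"""Evaluate referring expression methods (baseline, random, parse) with a CLIP-based executor.

Example invocation (script name and paths are illustrative):

    python main.py \
        --input_file data/refcoco_val.jsonl \
        --image_root /path/to/coco/train2014 \
        --method parse \
        --output_file results.jsonl
"""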
from collections import defaultdict
import json
import argparse
import os
import random
import numpy as np  # needed to serialize numpy types in the results below
import torch
from PIL import Image
from tqdm import tqdm
from interpreter import *
from executor import *
from methods import *
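# Registry mapping the --method argument to its implementation class.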
METHODS_MAP = {
"baseline": Baseline,
"random": Random,
"parse": Parse,
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input_file", type=str, help="input file with expressions and annotations in jsonlines format")
parser.add_argument("--image_root", type=str, help="path to images (train2014 directory of COCO)")
parser.add_argument("--clip_model", type=str, default="RN50x16,ViT-B/32", help="which clip model to use (should use RN50x4, ViT-B/32, or both separated by a comma")
parser.add_argument("--clip_type", type=str, default="aclip", help="which clip model to use (should use RN50x4, ViT-B/32, or both separated by a comma")
parser.add_argument("--albef_path", type=str, default=None, help="to use ALBEF (instead of CLIP), specify the path to the ALBEF checkpoint")
parser.add_argument("--method", type=str, default="parse", help="method to solve expressions")
parser.add_argument("--box_representation_method", type=str, default="crop,blur", help="method of representing boxes as individual images (crop, blur, or both separated by a comma)")
parser.add_argument("--box_method_aggregator", type=str, default="sum", help="method of combining box representation scores")
parser.add_argument("--box_area_threshold", type=float, default=0.0, help="minimum area (as a proportion of image area) for a box to be considered as the answer")
parser.add_argument("--output_file", type=str, default=None, help="(optional) output path to save results")
parser.add_argument("--detector_file", type=str, default=None, help="(optional) file containing object detections. if not provided, the gold object boxes will be used.")
parser.add_argument("--mock", action="store_true", help="(optional) mock CLIP execution.")
parser.add_argument("--device", type=int, default=0, help="CUDA device to use.")
parser.add_argument("--shuffle_words", action="store_true", help="If true, shuffle words in the sentence")
parser.add_argument("--gradcam_alpha", type=float, nargs='+', help="alpha value to use for gradcam method")
parser.add_argument("--enlarge_boxes", type=float, default=0.0, help="(optional) whether to enlarge boxes when passing them to the model")
parser.add_argument("--part", type=str, default=None, help="(optional) specify how many parts to divide the dataset into and which part to run in the format NUM_PARTS,PART_NUM")
parser.add_argument("--batch_size", type=int, default=1, help="number of instances to process in one model call (only supported for baseline model)")
parser.add_argument("--baseline_head", action="store_true", help="For baseline, controls whether model is called on both full expression and head noun chunk of expression")
parser.add_argument("--mdetr", type=str, default=None, help="to use MDETR as the executor model, specify the name of the MDETR model")
parser.add_argument("--albef_block_num", type=int, default=8, help="block num for ALBEF gradcam")
parser.add_argument("--albef_mode", type=str, choices=["itm", "itc"], default="itm")
parser.add_argument("--expand_position_embedding",action="store_true")
parser.add_argument("--gradcam_background", action="store_true")
parser.add_argument("--mdetr_given_bboxes", action="store_true")
parser.add_argument("--mdetr_use_token_mapping", action="store_true")
parser.add_argument("--non_square_size", action="store_true")
parser.add_argument("--blur_std_dev", type=int, default=100, help="standard deviation of Gaussian blur")
parser.add_argument("--gradcam_ensemble_before", action="store_true", help="Average gradcam maps of different models before summing over the maps")
parser.add_argument("--cache_path", type=str, default=None, help="cache features")
# Arguments related to Parse method.
parser.add_argument("--no_rel", action="store_true", help="Disable relation extraction.")
parser.add_argument("--no_sup", action="store_true", help="Disable superlative extraction.")
parser.add_argument("--no_null", action="store_true", help="Disable null keyword heuristics.")
parser.add_argument("--ternary", action="store_true", help="Disable ternary relation extraction.")
parser.add_argument("--baseline_threshold", type=float, default=float("inf"), help="(Parse) Threshold to use relations/superlatives.")
parser.add_argument("--temperature", type=float, default=1., help="(Parse) Sigmoid temperature.")
parser.add_argument("--superlative_head_only", action="store_true", help="(Parse) Superlatives only quanntify head predicate.")
parser.add_argument("--sigmoid", action="store_true", help="(Parse) Use sigmoid, not softmax.")
parser.add_argument("--no_possessive", action="store_true", help="(Parse) Model extraneous relations as possessive relations.")
parser.add_argument("--expand_chunks", action="store_true", help="(Parse) Expand noun chunks to include descendant tokens that aren't ancestors of tokens in other chunks")
parser.add_argument("--parse_no_branch", action="store_true", help="(Parse) Only do the parsing procedure if some relation/superlative keyword is in the expression")
parser.add_argument("--possessive_no_expand", action="store_true", help="(Parse) Expand ent2 in possessive case")
args = parser.parse_args()
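# The input file is jsonlines: one JSON object per image, each with annotations and sentences.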
with open(args.input_file) as f:
lines = f.readlines()
data = [json.loads(line) for line in lines]
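# Fall back to CPU when CUDA is unavailable or a negative device index is given.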
device = f"cuda:{args.device}" if torch.cuda.is_available() and args.device >= 0 else "cpu"
gradcam = args.method == "gradcam"
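# Note: this script instantiates the CLIP-based executor regardless of the ALBEF/MDETR arguments above.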
executor = ClipExecutor(
    clip_model=args.clip_model, box_representation_method=args.box_representation_method,
    method_aggregator=args.box_method_aggregator, device=device,
    square_size=not args.non_square_size, expand_position_embedding=args.expand_position_embedding,
    blur_std_dev=args.blur_std_dev, cache_path=args.cache_path,
    input_file=args.input_file, clip_type=args.clip_type)
method = METHODS_MAP[args.method](args)
correct_count = 0
total_count = 0
if args.output_file:
output_file = open(args.output_file, "w")
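# Detections may be a dict mapping image_id -> list of boxes, or a flat list of {"image_id": ..., "box": ...} records.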
if args.detector_file:
detector_file = open(args.detector_file)
detections_list = json.load(detector_file)
detector_file.close()
if isinstance(detections_list, dict):
detections_map = {int(image_id): detections_list[image_id] for image_id in detections_list}
else:
detections_map = defaultdict(list)
for detection in detections_list:
detections_map[detection["image_id"]].append(detection["box"])
part = 0
if args.part is not None: # for multi-GPU / partial-data evaluation
num_parts, part = map(int, args.part.split(","))
data = data[int(len(data)*part/num_parts):int(len(data)*(part+1)/num_parts)]
batch_count = 0
batch_boxes = []
batch_gold_boxes = []
batch_gold_index = []
batch_file_names = []
batch_sentences = []
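# Main evaluation loop: one datum per image, each with one or more referring expressions.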
for datum in tqdm(data):
if "coco" in datum["file_name"].lower():
file_name = "_".join(datum["file_name"].split("_")[:-1])+".jpg"
else:
file_name = datum["file_name"]
img_path = os.path.join(args.image_root, file_name)
img = Image.open(img_path).convert('RGB')
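# Gold boxes come from the annotations; bbox format is [x, y, w, h].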
gold_boxes = [Box(x=ann["bbox"][0], y=ann["bbox"][1], w=ann["bbox"][2], h=ann["bbox"][3]) for ann in datum["anns"]]
if isinstance(datum["ann_id"], (int, str)):
datum["ann_id"] = [datum["ann_id"]]
assert isinstance(datum["ann_id"], list)
gold_index = [i for i in range(len(datum["anns"])) if datum["anns"][i]["id"] in datum["ann_id"]]
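# Prefer detector boxes when a detector file was provided; otherwise use the gold boxes.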
if args.detector_file:
boxes = [Box(x=box[0], y=box[1], w=box[2], h=box[3]) for box in detections_map[int(datum["image_id"])]]
if len(boxes) == 0:
boxes = [Box(x=0, y=0, w=img.width, h=img.height)]
else:
boxes = gold_boxes
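# Evaluate every sentence (referring expression) attached to this image.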
for sentence in datum["sentences"]:
env = Environment(img, boxes, executor, (args.mdetr is not None and not args.mdetr_given_bboxes), str(datum["image_id"]), img_path)
if args.shuffle_words:
words = sentence["raw"].lower().split()
random.shuffle(words)
result = method.execute(" ".join(words), env)
else:
result = method.execute(sentence["raw"].lower(), env)
boxes = env.boxes
print(sentence["raw"].lower())
correct = False
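# A prediction counts as correct if its IoU with any gold box exceeds 0.5.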
for g_index in gold_index:
if iou(boxes[result["pred"]], gold_boxes[g_index]) > 0.5:
correct = True
break
if correct:
result["correct"] = 1
correct_count += 1
else:
result["correct"] = 0
if args.detector_file:
argmax_ious = []
max_ious = []
for g_index in gold_index:
ious = [iou(box, gold_boxes[g_index]) for box in boxes]
argmax_iou = -1
max_iou = 0
if max(ious) >= 0.5:
for index, value in enumerate(ious):
if value > max_iou:
max_iou = value
argmax_iou = index
argmax_ious.append(argmax_iou)
max_ious.append(max_iou)
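# Across all gold boxes, keep the detected-box index with the highest IoU (>= 0.5), else -1.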
argmax_iou = -1
max_iou = 0
if max(max_ious) >= 0.5:
for index, value in zip(argmax_ious, max_ious):
if value > max_iou:
max_iou = value
argmax_iou = index
result["gold_index"] = argmax_iou
else:
result["gold_index"] = gold_index
result["bboxes"] = [[box.left, box.top, box.right, box.bottom] for box in boxes]
result["file_name"] = file_name
result["probabilities"] = result["probs"]
result["text"] = sentence["raw"].lower()
if args.output_file:
# Serialize numpy arrays for JSON.
for key in result:
if isinstance(result[key], np.ndarray):
result[key] = result[key].tolist()
if isinstance(result[key], np.int64):
result[key] = result[key].item()
output_file.write(json.dumps(result)+"\n")
total_count += 1
print(f"est_acc: {100 * correct_count / total_count:.3f}")
if args.output_file:
output_file.close()
print(f"acc: {100 * correct_count / total_count:.3f}")
acc = 100 * correct_count / total_count
result = {}
result['acc'] = acc
os.makedirs('./output', exist_ok=True)
input_base = os.path.basename(args.input_file).split('.')[0]
json.dump(acc, open(os.path.join('./output', input_base + '_acc_' + str(part) + '.json'), 'w'))
json.dump(str(correct_count) + ' ' + str(total_count), open(os.path.join('./output', input_base + '_count_' + str(part) + '.json'), 'w'))
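# Print and save any additional per-method statistics gathered during execution.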
stats = method.get_stats()
if stats:
pairs = sorted(list(stats.items()), key=lambda tup: tup[0])
for key, value in pairs:
result[key] = value
if isinstance(value, float):
print(f"{key}: {value:.5f}")
else:
print(f"{key}: {value}")
json.dump(result, open(os.path.join('./output', input_base + '_' + str(part) + '.json'), 'w'))