# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Set up custom environment before nearly anything else is imported
# NOTE: this should be the first import (do not reorder)
from maskrcnn_benchmark.utils.env import setup_environment # noqa F401 isort:skip
import argparse
import os
import functools
import io
import datetime
import itertools
import json
from tqdm import tqdm
import numpy as np
import torch
import torch.distributed as dist
from collections import defaultdict
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.data import make_data_loader
from maskrcnn_benchmark.engine.inference import inference, create_positive_dict, clean_name
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.comm import synchronize, get_rank, is_main_process, all_gather
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir
from maskrcnn_benchmark.utils.stats import get_model_complexity_info
from omnilabeltools import OmniLabel, OmniLabelEval, visualize_image_sample
import time
import tempfile
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, CLIPTokenizerFast
import omnilabeltools as olt
import pdb
import wandb
try:
    import openai  # needed only by the LLM helper below (legacy openai<1.0 Completion/ChatCompletion API)
except ImportError:
    openai = None
from multiprocessing import Pool
class LLM:
def __init__(self, version, prompt_file = None, temp = 1.0):
self.version = version
self.prompt_file = prompt_file
self.temp = temp
with open(self.prompt_file, "r") as f:
self.prompt = f.read()
def __call__(self, entity):
time.sleep(0.1)
success = False
fail_count = 0
if isinstance(entity, list):
prompt = [self.prompt.replace("PROMPT", e) for e in entity]
else:
if self.version == "chat":
raw_prompt = self.prompt.replace("PROMPT", entity)
try:
prompt = json.loads(raw_prompt)
                except json.JSONDecodeError:  # prompt file is plain text, not a JSON message list
prompt = [{"role": "user", "content": raw_prompt}]
else:
prompt = self.prompt.replace("PROMPT", entity)
while not success:
try:
if self.version == "chat":
model = "gpt-3.5-turbo"
response = openai.ChatCompletion.create(
model=model,
messages = prompt,
temperature=self.temp,
)
else:
if self.version == "curie":
model = "curie"
else:
model = "text-davinci-003"
response = openai.Completion.create(
model=model,
prompt=prompt,
temperature=self.temp,
max_tokens=128,
top_p=1,
frequency_penalty=0.0,
presence_penalty=0.0,
)
success = True
fail_count = 0
except Exception as e:
print(f"Exception: {e}")
time.sleep(0.1)
fail_count += 1
if fail_count > 10:
print("Too many failures")
return "Too many failures"
if isinstance(entity, list):
if self.version == "chat":
return [r["message"]["content"] for r in response["choices"]]
else:
return [r["text"] for r in response["choices"]]
else:
if self.version == "chat":
return response["choices"][0]["message"]["content"]
else:
return response["choices"][0]["text"]
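# Example usage of the LLM helper (a sketch; assumes the openai package is installed and
# configured with a valid API key, mirroring how it is called further below for noun extraction):
#   llm = LLM(version="chat", prompt_file="tools/data_process/prompts/noun.v1.txt", temp=0.0)
#   center_noun = llm("a person riding a red bicycle")  # hypothetical description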
def init_distributed_mode(args):
"""Initialize distributed training, if appropriate"""
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
else:
print("Not using distributed mode")
args.distributed = False
return
# args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = "nccl"
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
dist.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url,
world_size=args.world_size,
rank=args.rank,
timeout=datetime.timedelta(0, 7200),
)
dist.barrier()
setup_for_distributed(args.rank == 0)
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def remove_full_stop(description_list):
    # Strip a trailing full stop from each description (guards against empty strings)
    ret_list = []
    for descript in description_list:
        if descript.endswith('.'):
            descript = descript[:-1]
        ret_list.append(descript)
    return ret_list
def num_of_words(text):
    # Whitespace-split word count, used to separate free-form descriptions from short category names
    return len(text.split())
def create_queries_and_maps(labels, label_list, tokenizer, additional_labels=None, cfg=None, center_nouns_length = None, override_tokens_positive = None):
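    """Build a single text query from per-object descriptions and the label-to-token map.

    A sketch of the behaviour (assuming default settings, no caption prompt): the entries in
    label_list are concatenated after the "Detect: " prefix, separated by
    cfg.DATASETS.SEPARATION_TOKENS, and tokens_positive records the character span of each
    description so create_positive_dict can map every id in `labels` to its token positions.
    Returns (objects_query, positive_map_label_to_token).
    """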
# Clean label list
label_list = [clean_name(i) for i in label_list]
# Form the query and get the mapping
tokens_positive = []
start_i = 0
end_i = 0
objects_query = "Detect: "
#objects_query = ""
prefix_length = len(objects_query)
# sep between tokens, follow training
separation_tokens = cfg.DATASETS.SEPARATION_TOKENS
caption_prompt = cfg.DATASETS.CAPTION_PROMPT
use_caption_prompt = cfg.DATASETS.USE_CAPTION_PROMPT and caption_prompt is not None
for _index, label in enumerate(label_list):
if use_caption_prompt:
objects_query += caption_prompt[_index]["prefix"]
start_i = len(objects_query)
if use_caption_prompt:
objects_query += caption_prompt[_index]["name"]
else:
objects_query += label
if "a kind of " in label:
end_i = len(label.split(",")[0]) + start_i
else:
end_i = len(objects_query)
tokens_positive.append([(start_i, end_i)]) # Every label has a [(start, end)]
if use_caption_prompt:
objects_query += caption_prompt[_index]["suffix"]
if _index != len(label_list) - 1:
objects_query += separation_tokens
if additional_labels is not None:
objects_query += separation_tokens
for _index, label in enumerate(additional_labels):
objects_query += label
if _index != len(additional_labels) - 1:
objects_query += separation_tokens
# print(objects_query)
if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "bert-base-uncased":
tokenized = tokenizer(objects_query, return_tensors="pt")
elif cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "roberta-base":
tokenized = tokenizer(objects_query, return_tensors="pt")
elif cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
tokenized = tokenizer(
objects_query, max_length=cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN, truncation=True, return_tensors="pt"
)
else:
raise NotImplementedError
if override_tokens_positive is not None:
new_tokens_positive = []
for override in override_tokens_positive:
new_tokens_positive.append((override[0] + prefix_length, override[1] + prefix_length))
tokens_positive = [new_tokens_positive] # this is because we only have one label
# Create the mapping between tokenized sentence and the original label
# if one_hot:
# positive_map_token_to_label, positive_map_label_to_token = create_one_hot_dict(labels, no_minus_one_for_one_hot=cfg.DATASETS.NO_MINUS_ONE_FOR_ONE_HOT)
# else:
positive_map_token_to_label, positive_map_label_to_token = create_positive_dict(
tokenized, tokens_positive, labels=labels
) # from token position to original label
return objects_query, positive_map_label_to_token
def main():
parser = argparse.ArgumentParser(description="PyTorch Detection to Grounding Inference")
parser.add_argument(
"--config-file",
default="configs/pretrain/glip_Swin_T_O365_GoldG.yaml",
metavar="FILE",
help="path to config file",
)
parser.add_argument(
"--weight",
default=None,
metavar="FILE",
help="path to config file",
)
parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument(
"opts", help="Modify config options using the command-line", default=None, nargs=argparse.REMAINDER
)
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training")
parser.add_argument("--task_config", default=None)
parser.add_argument("--chunk_size", default=20, type=int, help="number of descriptions each time")
parser.add_argument("--threshold", default=None, type=float, help="number of boxes stored in each run")
parser.add_argument("--topk_per_eval", default=None, type=int, help="number of boxes stored in each run")
parser.add_argument("--group_query", action="store_true", help="group query")
parser.add_argument("--noun_phrase_file", default=None, type=str, help="noun phrase file")
args = parser.parse_args()
num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
distributed = num_gpus > 1
if distributed:
# torch.cuda.set_device(args.local_rank)
# torch.distributed.init_process_group(
# backend="nccl", init_method="env://"
# )
init_distributed_mode(args)
print("Passed distributed init")
cfg.local_rank = args.local_rank
cfg.num_gpus = num_gpus
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
log_dir = cfg.OUTPUT_DIR
if args.weight:
log_dir = os.path.join(log_dir, "eval", os.path.splitext(os.path.basename(args.weight))[0])
if log_dir:
mkdir(log_dir)
logger = setup_logger("maskrcnn_benchmark", log_dir, get_rank())
logger.info(args)
logger.info("Using {} GPUs".format(num_gpus))
logger.info(cfg)
# logger.info("Collecting env info (might take some time)")
# logger.info("\n" + collect_env_info())
device = cfg.MODEL.DEVICE
cpu_device = torch.device("cpu")
model = build_detection_model(cfg)
model.to(device)
# we currently disable this
# params, flops = get_model_complexity_info(model,
# (3, cfg.INPUT.MAX_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST),
# input_constructor=lambda x: {'images': [torch.rand(x).cuda()]})
# print("FLOPs: {}, #Parameter: {}".format(params, flops))
checkpointer = DetectronCheckpointer(cfg, model, save_dir=cfg.OUTPUT_DIR)
if args.weight:
_ = checkpointer.load(args.weight, force=True)
else:
_ = checkpointer.load(cfg.MODEL.WEIGHT)
if args.weight:
weight_iter = os.path.splitext(os.path.basename(args.weight))[0].split("_")[-1]
try:
weight_iter = int(weight_iter)
        except ValueError:
weight_iter = 1
else:
weight_iter = 1
# get the wandb name
train_wandb_name = os.path.basename(cfg.OUTPUT_DIR)
eval_wandb_name = train_wandb_name + "_eval" + "_Fixed{}_Chunk{}".format(not cfg.DATASETS.LVIS_USE_NORMAL_AP, cfg.TEST.CHUNKED_EVALUATION)
if is_main_process() and train_wandb_name != "__test__":
api = wandb.Api()
runs = api.runs('haroldli/language_det_eval')
matched_run = None
history = []
exclude_keys = ['_runtime', '_timestamp']
for run in runs:
if run.name == eval_wandb_name and str(run._state) == "finished":
print("run found", run.name)
print(run.summary)
matched_run = run
run_his = matched_run.scan_history()
#print([len(i) for i in run_his])
for stat in run_his:
stat_i = {k: v for k, v in stat.items() if k not in exclude_keys and v is not None}
if len(stat_i) > 1:
history.append(stat_i)
#matched_run.delete()
break # only update one
wandb_run = wandb.init(
project = 'language_det_eval',
job_type = 'evaluate',
name = eval_wandb_name,
)
#pprint(history)
# exclude_keys = ['_step', '_runtime', '_timestamp']
# for stat in history:
# wandb.log(
# {k: v for k, v in stat.items() if k not in exclude_keys},
# step = stat['_step'],
# )
else:
wandb_run = None
history = None
print("weight_iter: ", weight_iter)
print("train_wandb_name: ", train_wandb_name)
print("eval_wandb_name: ", eval_wandb_name)
# build tokenizer to process data
# tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "bert-base-uncased":
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
elif cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "roberta-base":
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
elif cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
tokenizer = CLIPTokenizerFast.from_pretrained(
"openai/clip-vit-base-patch32", from_slow=True, mask_token="ðŁĴij</w>"
)
else:
tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", from_slow=True)
else:
tokenizer = None
raise NotImplementedError
### inference & evaluation
topk_per_eval = args.topk_per_eval
threshold = args.threshold
model.eval()
chunk_size = args.chunk_size # num of texts each time
if cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD":
class_plus = 1
else:
class_plus = 0
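    # class_plus offsets the continuous label ids fed to the model (VLDYHEAD starts labels at 1);
    # it is subtracted again after inference to recover 0-based indices into the current
    # description chunk (a convention inferred from the code below, not documented elsewhere).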
task_config = args.task_config
assert task_config is not None, "task_config should be assigned"
cfg_ = cfg.clone()
cfg_.defrost()
cfg_.merge_from_file(task_config)
cfg_.merge_from_list(args.opts)
dataset_name = cfg_.DATASETS.TEST[0]
output_folder = os.path.join(log_dir, "inference", dataset_name)
if not os.path.exists(output_folder):
mkdir(output_folder)
data_loaders_val = make_data_loader(cfg_, is_train=False, is_distributed=distributed)
_iterator = tqdm(data_loaders_val[0]) # only for the first test set
predictions = []
    # ad hoc ground-truth loading (disabled); omni_label is only available below when this block is enabled
# if "coco" in cfg_.DATASETS.TEST[0]:
# gt_json = 'DATASET/omnilabel/dataset_all_val_v0.1.3_coco.json'
# elif "oi_v5" in cfg_.DATASETS.TEST[0]:
# gt_json = 'DATASET/omnilabel/dataset_all_val_v0.1.3_openimagesv5.json'
# elif "oi_v6" in cfg_.DATASETS.TEST[0]:
# gt_json = 'DATASET/omnilabel/dataset_all_val_v0.1.3_openimagesv6.json'
# else:
# assert(0)
# omni_label = OmniLabel(path_json=gt_json)
if args.noun_phrase_file is not None:
try:
noun_phrase = json.load(open(args.noun_phrase_file))
except:
noun_phrase = {}
print("No noun phrase file found, will generate one")
llm = LLM(version="chat", prompt_file="tools/data_process/prompts/noun.v1.txt", temp=0.0)
else:
noun_phrase = {}
# stats
pos_rates = []
query_length = []
all_info = []
for iidx, batch in enumerate(_iterator):
images, targets, image_ids, *_ = batch
# import ipdb
# ipdb.set_trace()
images = images.to(device)
text_queries = targets[0].get_field('inference_obj_descriptions')
text_queries_ids = targets[0].get_field("inference_obj_description_ids")
image_size = targets[0].size
image_id = image_ids[0]
# pdb.set_trace()
#print(data_loaders_val[0].dataset.dataset_dicts[iidx])
#all_info.append(data_loaders_val[0].dataset.dataset_dicts[iidx])
# get the positive label if there is one
try:
positive_info = omni_label.get_image_sample(image_id)
positive_instances = positive_info['instances']
positive_labels = []
for i in positive_instances: positive_labels.extend(i['description_ids'])
positive_labels = list(set(positive_labels))
        except:  # falls back when omni_label is not defined (see the disabled block above)
positive_labels = None
des_id_start = 0
        # split indices into free-form descriptions (>2 words) and short category names
query_indexes = [i for i in range(len(text_queries_ids)) if num_of_words(text_queries[i]) > 2]
cat_indexes = [i for i in range(len(text_queries_ids)) if num_of_words(text_queries[i]) <= 2]
# rearrange the queries
if args.group_query:
text_queries_ids = [text_queries_ids[i] for i in query_indexes] + [text_queries_ids[i] for i in cat_indexes]
text_queries = [text_queries[i] for i in query_indexes] + [text_queries[i] for i in cat_indexes]
        while des_id_start < len(text_queries_ids):
            # take one chunk of descriptions at a time
            _det_phrase = False  # only set True for a single free-form description in the non-grouped path
if args.group_query:
if num_of_words(text_queries[des_id_start]) > 2:
description_list = remove_full_stop(text_queries[des_id_start:des_id_start+8])
description_id_list = text_queries_ids[des_id_start:des_id_start+8]
des_id_start += 8
else:
description_list = remove_full_stop(text_queries[des_id_start:des_id_start+chunk_size])
description_id_list = text_queries_ids[des_id_start:des_id_start+chunk_size]
des_id_start += chunk_size
else:
if num_of_words(text_queries[des_id_start]) > 2:
_det_phrase = True
description_list = remove_full_stop([text_queries[des_id_start]])
description_id_list = [text_queries_ids[des_id_start]]
des_id_start += 1
else:
_det_phrase = False
description_list = remove_full_stop(text_queries[des_id_start:des_id_start+chunk_size])
description_id_list = text_queries_ids[des_id_start:des_id_start+chunk_size]
des_id_start += chunk_size
            # create positive map; always use continuous labels starting from 1
continue_labels = np.arange(0, chunk_size) + class_plus
if _det_phrase and args.noun_phrase_file is not None:
                # try to find the center noun phrase (from the cache, or via the LLM if missing)
center_noun = noun_phrase.get(description_list[0], None)
if center_noun is None:
center_noun = llm(description_list[0])
if len(center_noun) == 0:
center_noun = description_list[0] # failed case
noun_phrase[description_list[0]] = center_noun
start = description_list[0].lower().find(center_noun.lower())
end = start + len(center_noun)
override_tokens_positive = [(start, end)]
print(description_list[0], center_noun, override_tokens_positive)
cur_queries, positive_map_label_to_token = create_queries_and_maps(continue_labels, description_list, tokenizer, cfg=cfg, override_tokens_positive=override_tokens_positive)
else:
cur_queries, positive_map_label_to_token = create_queries_and_maps(continue_labels, description_list, tokenizer, cfg=cfg)
set_description_id_list = set(description_id_list)
# intersection between positive labels and current description ids
if positive_labels is not None:
pos_rate = len(set_description_id_list.intersection(set(positive_labels))) / len(set_description_id_list)
pos_rates.append(pos_rate)
query_length.append(len(set_description_id_list))
# print(cur_queries)
with torch.no_grad():
output = model(images, captions=[cur_queries], positive_map=positive_map_label_to_token)
output = output[0].to(cpu_device).convert(mode="xywh")
            output = output.resize(image_size)  # back to the original image scale
# print(output)
# import ipdb
# ipdb.set_trace()
            # score thresholding
if threshold is not None:
scores = output.get_field('scores')
output = output[scores > threshold]
# sorted by scores
if topk_per_eval is not None:
scores = output.get_field('scores')
_, sortIndices = scores.sort(descending=True)
output = output[sortIndices]
# topk
output = output[:topk_per_eval]
# map continuous id to description id
cont_ids_2_descript_ids = {i:v for i, v in enumerate(description_id_list)}
pred_boxes = output.bbox
pred_labels = output.get_field('labels') - class_plus # continuous ids, starting from 0
pred_scores = output.get_field('scores')
# convert continuous id to description id
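            # Each stored entry follows the OmniLabel result format expected by gt.load_res (sketch):
            #   {"image_id": ..., "bbox": [x, y, w, h], "description_ids": [id], "scores": [score]}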
for box_idx, box in enumerate(pred_boxes):
predictions.append({
"image_id": image_id,
"bbox": box.cpu().tolist(),
"description_ids": [cont_ids_2_descript_ids[pred_labels[box_idx].item()]],
"scores": [pred_scores[box_idx].item()],
})
#print("pos_rate: %.2f"%(np.mean(pos_rates)), pos_rates)
#print("query_length: %.2f"%(np.mean(query_length)), query_length)
# draw a histogram of pos_rate
plt.hist(pos_rates, bins=10)
plt.savefig(os.path.join(output_folder, "pos_rate.png"))
plt.close()
if args.noun_phrase_file is not None:
with open(args.noun_phrase_file, "w") as f:
json.dump(noun_phrase, f, indent=4)
# collect predictions from all GPUs
synchronize()
all_predictions = all_gather(predictions)
all_predictions = list(itertools.chain(*all_predictions))
if not is_main_process():
return
result_save_json = "%s_results.json"%(dataset_name)
results_path = os.path.join(output_folder, result_save_json)
print('Saving to', results_path)
    with open(results_path, 'w') as f:
        json.dump(all_predictions, f)
from maskrcnn_benchmark.config.paths_catalog import DatasetCatalog
datasetMeta = DatasetCatalog.get(dataset_name)
gt_path_json = datasetMeta['args']['ann_file']
# import ipdb
# ipdb.set_trace()
# evaluation
gt = OmniLabel(gt_path_json) # load ground truth dataset
dt = gt.load_res(results_path) # load prediction results
ole = OmniLabelEval(gt, dt)
# ole.params.resThrs = ... # set evaluation parameters as desired
ole.evaluate()
ole.accumulate()
score = ole.summarize()
# OUTPUTS/GLIP_MODEL17/eval/model_0270000/inference/omnilabel_val/omnilabel_val_results.json
#with open("tools/files/omnilabel_coco.json", "a") as f:
# json.dump(all_info, f)
if is_main_process():
if wandb_run is not None:
#
dataset_name = cfg.DATASETS.TEST[0]
write_to_wandb_log(score, dataset_name, weight_iter, history)
with open("{}/detailed.json".format(output_folder), "w") as f:
json.dump(score, f)
wandb_run.save("{}/detailed.json".format(output_folder))
print(score)
def write_to_wandb_log(score, dataset_name, weight_iter, history):
all_results = defaultdict(dict)
exclude_keys = ['_step', '_runtime', '_timestamp']
if history is not None:
for stat in history:
all_results[stat['_step']].update({k: v for k, v in stat.items() if k not in exclude_keys})
result_dict = {}
for score_i in score:
if score_i["metric"]['metric'] == "AP" and score_i["metric"]['iou'] == "0.50:0.95" and score_i["metric"]['area'] == "all":
result_dict[f"{dataset_name}_AP_{score_i['metric']['description']}"] = score_i['value']
#wandb.log({f"{dataset_name}_mAP_all": mAP_all, f"{dataset_name}_mAP_rare": mAP_rare, f"{dataset_name}_mAP_common": mAP_common, f"{dataset_name}_mAP_frequent": mAP_frequent}, step = weight_iter)
all_results[weight_iter].update(result_dict)
    # wandb requires monotonically increasing steps, so log every step index up to the max
max_key = max(all_results.keys())
for i in range(max_key + 1):
if i in all_results:
wandb.log(all_results[i], step = i)
else:
wandb.log({}, step = i)
# for k in sorted(all_results.keys()):
# # need to do consecutive logging
# wandb.log(all_results[k], step = k)
if __name__ == "__main__":
main()
'''
Standalone snippet for evaluating a saved results file in a Python shell:
from omnilabeltools import OmniLabel, OmniLabelEval
gt = OmniLabel('DATASET/omnilabel/dataset_all_val_v0.1.3_openimagesv5.json') # load ground truth dataset
dt = gt.load_res("OUTPUTS/GLIP_MODEL17/eval/model_0270000/inference/omnilabel_val/omnilabel_val_results.json") # load prediction results
ole = OmniLabelEval(gt, dt)
ole.evaluate()
ole.accumulate()
ole.summarize()
gt = OmniLabel('DATASET/omnilabel/dataset_all_val_v0.1.3_coco.json') # load ground truth dataset
dt = gt.load_res("OUTPUTS/GLIP_MODEL17/eval/model_0270000/inference/omnilabel_val/omnilabel_val_results.json")
gt = OmniLabel('DATASET/omnilabel/dataset_all_val_v0.1.3_object365.json') # load ground truth dataset
dt = gt.load_res("OUTPUTS/GLIP_MODEL17/eval/model_0270000/inference/omnilabel_val/omnilabel_val_results.json")
'''