# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Set up custom environment before nearly anything else is imported
# NOTE: this should be the first import (do not reorder)
from maskrcnn_benchmark.utils.env import setup_environment  # noqa F401 isort:skip

import argparse
import datetime
import functools
import io
import itertools
import json
import os
import pdb
import tempfile
import time
from collections import defaultdict
from multiprocessing import Pool

import matplotlib.pyplot as plt
import numpy as np
import openai  # used by the LLM helper below
import torch
import torch.distributed as dist
import wandb
from tqdm import tqdm
from transformers import AutoTokenizer, CLIPTokenizerFast

import omnilabeltools as olt
from omnilabeltools import OmniLabel, OmniLabelEval, visualize_image_sample

from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.data import make_data_loader
from maskrcnn_benchmark.engine.inference import inference, create_positive_dict, clean_name
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.comm import synchronize, get_rank, is_main_process, all_gather
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir
from maskrcnn_benchmark.utils.stats import get_model_complexity_info


class LLM:
    """Thin wrapper around the OpenAI completion APIs with a simple retry loop.

    The prompt file must contain the placeholder string "PROMPT", which is
    replaced with the query entity on each call.
    """

    def __init__(self, version, prompt_file=None, temp=1.0):
        self.version = version
        self.prompt_file = prompt_file
        self.temp = temp
        with open(self.prompt_file, "r") as f:
            self.prompt = f.read()

    def __call__(self, entity):
        time.sleep(0.1)
        success = False
        fail_count = 0
        if isinstance(entity, list):
            prompt = [self.prompt.replace("PROMPT", e) for e in entity]
        else:
            if self.version == "chat":
                raw_prompt = self.prompt.replace("PROMPT", entity)
                # the prompt file may already contain a JSON-encoded message list;
                # fall back to a single user message otherwise
                try:
                    prompt = json.loads(raw_prompt)
                except Exception:
                    prompt = [{"role": "user", "content": raw_prompt}]
            else:
                prompt = self.prompt.replace("PROMPT", entity)
        while not success:
            try:
                if self.version == "chat":
                    model = "gpt-3.5-turbo"
                    response = openai.ChatCompletion.create(
                        model=model,
                        messages=prompt,
                        temperature=self.temp,
                    )
                else:
                    model = "curie" if self.version == "curie" else "text-davinci-003"
                    response = openai.Completion.create(
                        model=model,
                        prompt=prompt,
                        temperature=self.temp,
                        max_tokens=128,
                        top_p=1,
                        frequency_penalty=0.0,
                        presence_penalty=0.0,
                    )
                success = True
                fail_count = 0
            except Exception as e:
                print(f"Exception: {e}")
                time.sleep(0.1)
                fail_count += 1
                if fail_count > 10:
                    print("Too many failures")
                    return "Too many failures"
        if isinstance(entity, list):
            if self.version == "chat":
                return [r["message"]["content"] for r in response["choices"]]
            return [r["text"] for r in response["choices"]]
        if self.version == "chat":
            return response["choices"][0]["message"]["content"]
        return response["choices"][0]["text"]

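# Usage sketch (assumes OPENAI_API_KEY is configured; the prompt path matches
# the one used in main() below, the query string and output are hypothetical):
#
#   llm = LLM(version="chat", prompt_file="tools/data_process/prompts/noun.v1.txt", temp=0.0)
#   center_noun = llm("large dog lying on the sofa")  # e.g. "dog", depending on the prompt
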
def init_distributed_mode(args):
    """Initialize distributed training, if appropriate"""
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    # args.distributed = True
    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
    dist.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
        timeout=datetime.timedelta(0, 7200),
    )
    dist.barrier()
    setup_for_distributed(args.rank == 0)


def setup_for_distributed(is_master):
    """Disable printing when not in the master process."""
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def remove_full_stop(description_list):
    ret_list = []
    for descript in description_list:
        if descript[-1] == '.':
            descript = descript[:-1]  # remove trailing '.'
        ret_list.append(descript)
    return ret_list


def num_of_words(text):
    return len(text.split(' '))


def create_queries_and_maps(labels, label_list, tokenizer, additional_labels=None, cfg=None,
                            center_nouns_length=None, override_tokens_positive=None):
    # Clean label list
    label_list = [clean_name(i) for i in label_list]
    # Form the query and get the mapping
    tokens_positive = []
    start_i = 0
    end_i = 0
    objects_query = "Detect: "
    # objects_query = ""
    prefix_length = len(objects_query)

    # separator between labels, following the separation used during training
    separation_tokens = cfg.DATASETS.SEPARATION_TOKENS
    caption_prompt = cfg.DATASETS.CAPTION_PROMPT
    use_caption_prompt = cfg.DATASETS.USE_CAPTION_PROMPT and caption_prompt is not None

    for _index, label in enumerate(label_list):
        if use_caption_prompt:
            objects_query += caption_prompt[_index]["prefix"]

        start_i = len(objects_query)

        if use_caption_prompt:
            objects_query += caption_prompt[_index]["name"]
        else:
            objects_query += label

        if "a kind of " in label:
            end_i = len(label.split(",")[0]) + start_i
        else:
            end_i = len(objects_query)
        tokens_positive.append([(start_i, end_i)])  # every label has a [(start, end)]

        if use_caption_prompt:
            objects_query += caption_prompt[_index]["suffix"]

        if _index != len(label_list) - 1:
            objects_query += separation_tokens

    if additional_labels is not None:
        objects_query += separation_tokens
        for _index, label in enumerate(additional_labels):
            objects_query += label
            if _index != len(additional_labels) - 1:
                objects_query += separation_tokens

    # print(objects_query)

    if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "bert-base-uncased":
        tokenized = tokenizer(objects_query, return_tensors="pt")
    elif cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "roberta-base":
        tokenized = tokenizer(objects_query, return_tensors="pt")
    elif cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
        tokenized = tokenizer(
            objects_query,
            max_length=cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN,
            truncation=True,
            return_tensors="pt",
        )
    else:
        raise NotImplementedError

    if override_tokens_positive is not None:
        new_tokens_positive = []
        for override in override_tokens_positive:
            new_tokens_positive.append((override[0] + prefix_length, override[1] + prefix_length))
        tokens_positive = [new_tokens_positive]  # there is only one label in this case

    # Create the mapping between the tokenized sentence and the original labels
    # if one_hot:
    #     positive_map_token_to_label, positive_map_label_to_token = create_one_hot_dict(labels, no_minus_one_for_one_hot=cfg.DATASETS.NO_MINUS_ONE_FOR_ONE_HOT)
    # else:
    positive_map_token_to_label, positive_map_label_to_token = create_positive_dict(
        tokenized, tokens_positive, labels=labels
    )  # from token position to original label

    return objects_query, positive_map_label_to_token

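# Worked example (assuming cfg.DATASETS.SEPARATION_TOKENS == ". " and no caption
# prompt): label_list == ["red car", "dog"] yields
#   objects_query   == "Detect: red car. dog"
#   tokens_positive == [[(8, 15)], [(17, 20)]]
# i.e. the character span of each label inside the query, which
# create_positive_dict then converts into token-level spans.
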
def main():
    parser = argparse.ArgumentParser(description="PyTorch Detection to Grounding Inference")
    parser.add_argument(
        "--config-file",
        default="configs/pretrain/glip_Swin_T_O365_GoldG.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--weight",
        default=None,
        metavar="FILE",
        help="path to model weight file",
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training")
    parser.add_argument("--task_config", default=None, help="path to the task-specific config file")
    parser.add_argument("--chunk_size", default=20, type=int, help="number of descriptions per forward pass")
    parser.add_argument("--threshold", default=None, type=float, help="score threshold for keeping boxes")
    parser.add_argument("--topk_per_eval", default=None, type=int, help="max number of boxes kept per chunk")
    parser.add_argument("--group_query", action="store_true", help="group free-form queries together")
    parser.add_argument("--noun_phrase_file", default=None, type=str, help="noun phrase cache file")
    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

    if distributed:
        # torch.cuda.set_device(args.local_rank)
        # torch.distributed.init_process_group(
        #     backend="nccl", init_method="env://"
        # )
        init_distributed_mode(args)
        print("Passed distributed init")

    cfg.local_rank = args.local_rank
    cfg.num_gpus = num_gpus
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    log_dir = cfg.OUTPUT_DIR
    if args.weight:
        log_dir = os.path.join(log_dir, "eval", os.path.splitext(os.path.basename(args.weight))[0])
    if log_dir:
        mkdir(log_dir)
    logger = setup_logger("maskrcnn_benchmark", log_dir, get_rank())
    logger.info(args)
    logger.info("Using {} GPUs".format(num_gpus))
    logger.info(cfg)
    # logger.info("Collecting env info (might take some time)")
    # logger.info("\n" + collect_env_info())

    device = cfg.MODEL.DEVICE
    cpu_device = torch.device("cpu")
    model = build_detection_model(cfg)
    model.to(device)

    # we currently disable this
    # params, flops = get_model_complexity_info(model,
    #     (3, cfg.INPUT.MAX_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST),
    #     input_constructor=lambda x: {'images': [torch.rand(x).cuda()]})
    # print("FLOPs: {}, #Parameter: {}".format(params, flops))

    checkpointer = DetectronCheckpointer(cfg, model, save_dir=cfg.OUTPUT_DIR)
    if args.weight:
        _ = checkpointer.load(args.weight, force=True)
    else:
        _ = checkpointer.load(cfg.MODEL.WEIGHT)

    # parse the training iteration out of the checkpoint filename
    if args.weight:
        weight_iter = os.path.splitext(os.path.basename(args.weight))[0].split("_")[-1]
        try:
            weight_iter = int(weight_iter)
        except Exception:
            weight_iter = 1
    else:
        weight_iter = 1
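    # e.g. "--weight .../model_0270000.pth" gives weight_iter == 270000, which
    # is used as the wandb logging step below; filenames without a trailing
    # iteration number fall back to 1.
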
    # get the wandb name
    train_wandb_name = os.path.basename(cfg.OUTPUT_DIR)
    eval_wandb_name = train_wandb_name + "_eval" + "_Fixed{}_Chunk{}".format(
        not cfg.DATASETS.LVIS_USE_NORMAL_AP, cfg.TEST.CHUNKED_EVALUATION)

    if is_main_process() and train_wandb_name != "__test__":
        # if a finished run with the same name exists, collect its history so
        # earlier logged steps can be replayed into the new run
        api = wandb.Api()
        runs = api.runs('haroldli/language_det_eval')
        matched_run = None
        history = []
        exclude_keys = ['_runtime', '_timestamp']
        for run in runs:
            if run.name == eval_wandb_name and str(run._state) == "finished":
                print("run found", run.name)
                print(run.summary)
                matched_run = run
                run_his = matched_run.scan_history()
                # print([len(i) for i in run_his])
                for stat in run_his:
                    stat_i = {k: v for k, v in stat.items() if k not in exclude_keys and v is not None}
                    if len(stat_i) > 1:
                        history.append(stat_i)
                # matched_run.delete()
                break  # only update one
        wandb_run = wandb.init(
            project='language_det_eval',
            job_type='evaluate',
            name=eval_wandb_name,
        )
        # pprint(history)
        # exclude_keys = ['_step', '_runtime', '_timestamp']
        # for stat in history:
        #     wandb.log(
        #         {k: v for k, v in stat.items() if k not in exclude_keys},
        #         step=stat['_step'],
        #     )
    else:
        wandb_run = None
        history = None

    print("weight_iter: ", weight_iter)
    print("train_wandb_name: ", train_wandb_name)
    print("eval_wandb_name: ", eval_wandb_name)

    # build tokenizer to process data
    # tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "bert-base-uncased":
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    elif cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "roberta-base":
        tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    elif cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
        if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
            # the mask token literal is the byte-level BPE form used in CLIP's vocabulary
            tokenizer = CLIPTokenizerFast.from_pretrained(
                "openai/clip-vit-base-patch32", from_slow=True, mask_token="ðŁĴij"
            )
        else:
            tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32", from_slow=True)
    else:
        tokenizer = None
        raise NotImplementedError

    ### inference & evaluation
    topk_per_eval = args.topk_per_eval
    threshold = args.threshold
    model.eval()
    chunk_size = args.chunk_size  # number of descriptions per forward pass
    if cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD":
        class_plus = 1
    else:
        class_plus = 0

    task_config = args.task_config
    assert task_config is not None, "task_config should be assigned"
    cfg_ = cfg.clone()
    cfg_.defrost()
    cfg_.merge_from_file(task_config)
    cfg_.merge_from_list(args.opts)

    dataset_name = cfg_.DATASETS.TEST[0]
    output_folder = os.path.join(log_dir, "inference", dataset_name)
    if not os.path.exists(output_folder):
        mkdir(output_folder)

    data_loaders_val = make_data_loader(cfg_, is_train=False, is_distributed=distributed)
    _iterator = tqdm(data_loaders_val[0])  # only for the first test set
    predictions = []

    # ad hoc: enable this block to load ground truth for positive-rate stats;
    # note that omni_label is otherwise undefined, so the lookup in the loop
    # below falls back to positive_labels = None
    # if "coco" in cfg_.DATASETS.TEST[0]:
    #     gt_json = 'DATASET/omnilabel/dataset_all_val_v0.1.3_coco.json'
    # elif "oi_v5" in cfg_.DATASETS.TEST[0]:
    #     gt_json = 'DATASET/omnilabel/dataset_all_val_v0.1.3_openimagesv5.json'
    # elif "oi_v6" in cfg_.DATASETS.TEST[0]:
    #     gt_json = 'DATASET/omnilabel/dataset_all_val_v0.1.3_openimagesv6.json'
    # else:
    #     assert(0)
    # omni_label = OmniLabel(path_json=gt_json)

    if args.noun_phrase_file is not None:
        try:
            noun_phrase = json.load(open(args.noun_phrase_file))
        except Exception:
            noun_phrase = {}
            print("No noun phrase file found, will generate one")
        llm = LLM(version="chat", prompt_file="tools/data_process/prompts/noun.v1.txt", temp=0.0)
    else:
        noun_phrase = {}

    # stats
    pos_rates = []
    query_length = []
    all_info = []
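    # Each image is scored against its candidate descriptions in chunks:
    # free-form descriptions (more than two words) go through one at a time
    # (or eight at a time with --group_query), plain category names go through
    # chunk_size at a time; per-chunk continuous label ids are mapped back to
    # OmniLabel description ids afterwards.
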
    for iidx, batch in enumerate(_iterator):
        images, targets, image_ids, *_ = batch
        # import ipdb
        # ipdb.set_trace()
        images = images.to(device)
        text_queries = targets[0].get_field('inference_obj_descriptions')
        text_queries_ids = targets[0].get_field("inference_obj_description_ids")
        image_size = targets[0].size
        image_id = image_ids[0]
        # pdb.set_trace()
        # print(data_loaders_val[0].dataset.dataset_dicts[iidx])
        # all_info.append(data_loaders_val[0].dataset.dataset_dicts[iidx])

        # get the positive labels if ground truth is available (requires the
        # ad hoc omni_label block above to be enabled)
        try:
            positive_info = omni_label.get_image_sample(image_id)
            positive_instances = positive_info['instances']
            positive_labels = []
            for i in positive_instances:
                positive_labels.extend(i['description_ids'])
            positive_labels = list(set(positive_labels))
        except Exception:
            positive_labels = None

        des_id_start = 0

        # rearrange the queries: free-form descriptions first, then category names
        query_indexes = [i for i in range(len(text_queries_ids)) if num_of_words(text_queries[i]) > 2]
        cat_indexes = [i for i in range(len(text_queries_ids)) if num_of_words(text_queries[i]) <= 2]
        if args.group_query:
            text_queries_ids = [text_queries_ids[i] for i in query_indexes] + [text_queries_ids[i] for i in cat_indexes]
            text_queries = [text_queries[i] for i in query_indexes] + [text_queries[i] for i in cat_indexes]

        while des_id_start < len(text_queries_ids):
            # a few descriptions at a time
            _det_phrase = False  # whether the current chunk is a single free-form description
            if args.group_query:
                if num_of_words(text_queries[des_id_start]) > 2:
                    description_list = remove_full_stop(text_queries[des_id_start:des_id_start + 8])
                    description_id_list = text_queries_ids[des_id_start:des_id_start + 8]
                    des_id_start += 8
                else:
                    description_list = remove_full_stop(text_queries[des_id_start:des_id_start + chunk_size])
                    description_id_list = text_queries_ids[des_id_start:des_id_start + chunk_size]
                    des_id_start += chunk_size
            else:
                if num_of_words(text_queries[des_id_start]) > 2:
                    _det_phrase = True
                    description_list = remove_full_stop([text_queries[des_id_start]])
                    description_id_list = [text_queries_ids[des_id_start]]
                    des_id_start += 1
                else:
                    description_list = remove_full_stop(text_queries[des_id_start:des_id_start + chunk_size])
                    description_id_list = text_queries_ids[des_id_start:des_id_start + chunk_size]
                    des_id_start += chunk_size

            # create positive map, always use continuous labels starting from 1
            continue_labels = np.arange(0, chunk_size) + class_plus

            if _det_phrase and args.noun_phrase_file is not None:
                # try to find the center noun phrase
                center_noun = noun_phrase.get(description_list[0], None)
                if center_noun is None:
                    center_noun = llm(description_list[0])
                    if len(center_noun) == 0:
                        center_noun = description_list[0]  # failed case
                    noun_phrase[description_list[0]] = center_noun
                start = description_list[0].lower().find(center_noun.lower())
                end = start + len(center_noun)
                override_tokens_positive = [(start, end)]
                print(description_list[0], center_noun, override_tokens_positive)
                cur_queries, positive_map_label_to_token = create_queries_and_maps(
                    continue_labels, description_list, tokenizer, cfg=cfg,
                    override_tokens_positive=override_tokens_positive)
            else:
                cur_queries, positive_map_label_to_token = create_queries_and_maps(
                    continue_labels, description_list, tokenizer, cfg=cfg)

            # intersection between positive labels and current description ids
            set_description_id_list = set(description_id_list)
            if positive_labels is not None:
                pos_rate = len(set_description_id_list.intersection(set(positive_labels))) / len(set_description_id_list)
                pos_rates.append(pos_rate)
                query_length.append(len(set_description_id_list))

            # print(cur_queries)
            with torch.no_grad():
                output = model(images, captions=[cur_queries], positive_map=positive_map_label_to_token)
                output = output[0].to(cpu_device).convert(mode="xywh")
            output = output.resize(image_size)  # back to the original scale
            # print(output)
            # import ipdb
            # ipdb.set_trace()

            # thresholding
            if threshold is not None:
                scores = output.get_field('scores')
                output = output[scores > threshold]

            # sort by score and keep the top-k
            if topk_per_eval is not None:
                scores = output.get_field('scores')
                _, sortIndices = scores.sort(descending=True)
                output = output[sortIndices]
                output = output[:topk_per_eval]

            # map continuous ids back to description ids
            cont_ids_2_descript_ids = {i: v for i, v in enumerate(description_id_list)}
            pred_boxes = output.bbox
            pred_labels = output.get_field('labels') - class_plus  # continuous ids, starting from 0
            pred_scores = output.get_field('scores')
            for box_idx, box in enumerate(pred_boxes):
                predictions.append({
                    "image_id": image_id,
                    "bbox": box.cpu().tolist(),
                    "description_ids": [cont_ids_2_descript_ids[pred_labels[box_idx].item()]],
                    "scores": [pred_scores[box_idx].item()],
                })
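    # Each record follows the OmniLabel results format, e.g. (hypothetical values):
    #   {"image_id": 42, "bbox": [x, y, w, h], "description_ids": [137], "scores": [0.83]}
    # with the box in xywh coordinates at the original image scale.
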
    # print("pos_rate: %.2f" % (np.mean(pos_rates)), pos_rates)
    # print("query_length: %.2f" % (np.mean(query_length)), query_length)

    # draw a histogram of pos_rate
    plt.hist(pos_rates, bins=10)
    plt.savefig(os.path.join(output_folder, "pos_rate.png"))
    plt.close()

    # persist the (possibly extended) noun phrase cache
    if args.noun_phrase_file is not None:
        with open(args.noun_phrase_file, "w") as f:
            json.dump(noun_phrase, f, indent=4)

    # collect predictions from all GPUs
    synchronize()
    all_predictions = all_gather(predictions)
    all_predictions = list(itertools.chain(*all_predictions))
    if not is_main_process():
        return

    result_save_json = "%s_results.json" % (dataset_name)
    results_path = os.path.join(output_folder, result_save_json)
    print('Saving to', results_path)
    with open(results_path, 'w') as f:
        json.dump(all_predictions, f)

    from maskrcnn_benchmark.config.paths_catalog import DatasetCatalog
    datasetMeta = DatasetCatalog.get(dataset_name)
    gt_path_json = datasetMeta['args']['ann_file']
    # import ipdb
    # ipdb.set_trace()

    # evaluation
    gt = OmniLabel(gt_path_json)  # load ground truth dataset
    dt = gt.load_res(results_path)  # load prediction results
    ole = OmniLabelEval(gt, dt)
    # ole.params.resThrs = ...  # set evaluation parameters as desired
    ole.evaluate()
    ole.accumulate()
    score = ole.summarize()

    # OUTPUTS/GLIP_MODEL17/eval/model_0270000/inference/omnilabel_val/omnilabel_val_results.json
    # with open("tools/files/omnilabel_coco.json", "a") as f:
    #     json.dump(all_info, f)

    if is_main_process():
        if wandb_run is not None:
            # dataset_name = cfg.DATASETS.TEST[0]
            write_to_wandb_log(score, dataset_name, weight_iter, history)
            with open("{}/detailed.json".format(output_folder), "w") as f:
                json.dump(score, f)
            wandb_run.save("{}/detailed.json".format(output_folder))
    print(score)


def write_to_wandb_log(score, dataset_name, weight_iter, history):
    all_results = defaultdict(dict)
    exclude_keys = ['_step', '_runtime', '_timestamp']
    if history is not None:
        for stat in history:
            all_results[stat['_step']].update({k: v for k, v in stat.items() if k not in exclude_keys})

    result_dict = {}
    for score_i in score:
        if score_i["metric"]['metric'] == "AP" and score_i["metric"]['iou'] == "0.50:0.95" and score_i["metric"]['area'] == "all":
            result_dict[f"{dataset_name}_AP_{score_i['metric']['description']}"] = score_i['value']
    # wandb.log({f"{dataset_name}_mAP_all": mAP_all, f"{dataset_name}_mAP_rare": mAP_rare,
    #            f"{dataset_name}_mAP_common": mAP_common, f"{dataset_name}_mAP_frequent": mAP_frequent},
    #           step=weight_iter)
    all_results[weight_iter].update(result_dict)

    # wandb requires steps to be logged in increasing order, so replay every
    # step up to the largest one, logging an empty dict for missing steps
    max_key = max(all_results.keys())
    for i in range(max_key + 1):
        if i in all_results:
            wandb.log(all_results[i], step=i)
        else:
            wandb.log({}, step=i)
    # for k in sorted(all_results.keys()):
    #     # need to do consecutive logging
    #     wandb.log(all_results[k], step=k)

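# write_to_wandb_log expects the list returned by OmniLabelEval.summarize();
# as consumed above, each entry looks like (hypothetical value):
#   {"metric": {"metric": "AP", "iou": "0.50:0.95", "area": "all",
#               "description": "descr"}, "value": 0.231}
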
if __name__ == "__main__":
    main()


'''
Standalone evaluation snippet, kept for reference:

from omnilabeltools import OmniLabel, OmniLabelEval

gt = OmniLabel('DATASET/omnilabel/dataset_all_val_v0.1.3_openimagesv5.json')  # load ground truth dataset
dt = gt.load_res("OUTPUTS/GLIP_MODEL17/eval/model_0270000/inference/omnilabel_val/omnilabel_val_results.json")  # load prediction results
ole = OmniLabelEval(gt, dt)
ole.evaluate()
ole.accumulate()
ole.summarize()

gt = OmniLabel('DATASET/omnilabel/dataset_all_val_v0.1.3_coco.json')  # load ground truth dataset
dt = gt.load_res("OUTPUTS/GLIP_MODEL17/eval/model_0270000/inference/omnilabel_val/omnilabel_val_results.json")

gt = OmniLabel('DATASET/omnilabel/dataset_all_val_v0.1.3_object365.json')  # load ground truth dataset
dt = gt.load_res("OUTPUTS/GLIP_MODEL17/eval/model_0270000/inference/omnilabel_val/omnilabel_val_results.json")
'''