import json
import os
from ImageBind.imagebind import data
from ImageBind.imagebind.models import imagebind_model
from ImageBind.imagebind.models.imagebind_model import ModalityType
from collections import OrderedDict
import torch
import argparse
from utils import crop_image, draw_bboxes, save_image, find_same_class, open_image_follow_symlink
from ultralytics import YOLO
from PIL import Image
import numpy as np
from models.TaskCLIP import TaskCLIP

id2task_name_file = './id2task_name.json'
task2prompt_file = './task20.json'

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-vlm_model', default='imagebind', help='Set front CLIP model')
    parser.add_argument('-od_model', default='yolox', help='Set object detection model')
    parser.add_argument('-device', default='cuda:0', help='Set running environment')
    parser.add_argument('-task_id', type=int, default=1, help='Set task id')
    parser.add_argument('-image_path', type=str, default='./images/demo_image_1.jpg', help='Set input image path')
    parser.add_argument('-activation', type=str, default='relu')
    parser.add_argument('-ratio_text', type=float, default=0.3)
    parser.add_argument('-ratio_image', type=float, default=0.3)
    parser.add_argument('-ratio_glob', type=float, default=0.3)
    parser.add_argument('-norm_before', action='store_true', default=False)
    parser.add_argument('-norm_after', action='store_true', default=False)
    parser.add_argument('-norm_range', type=str, default='10|30')
    parser.add_argument('-cross_attention', action='store_true', default=False)
    parser.add_argument('-eval_model_path', default='./test_model/decoder_epoch19.pt', help='Set path for loading trained TaskCLIP model')
    parser.add_argument('-threshold', type=float, default=0.01, help='Set threshold for positive detection')
    parser.add_argument('-forward', action='store_true', default=True)
    parser.add_argument('-cluster', action='store_true', default=True)
    parser.add_argument('-forward_thre', type=float, default=0.1, help='Set threshold for positive detection during forward optimization')
    args = parser.parse_args()

    device = args.device
    threshold = args.threshold

    # prepare task name and key words
    with open(id2task_name_file, 'r') as f:
        id2task_name = json.load(f)
    task_id = str(args.task_id)
    task_name = id2task_name[task_id]

    # prepare input image
    image_path = args.image_path
    image_name = args.image_path.split('/')[-1].split('.')[0]
    image = open_image_follow_symlink(image_path).convert('RGB')

    # load vision-language model
    vlm_model_name = args.vlm_model
    if vlm_model_name == 'imagebind':
        vlm_model = imagebind_model.imagebind_huge(pretrained=True).to(device)
        vlm_model.eval()
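    # Note: with pretrained=True the ImageBind-Huge checkpoint is fetched if it is not
    # already cached locally (the upstream ImageBind repo stores it under .checkpoints/);
    # behavior of this vendored copy is assumed to match.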

    # load object detection model (YOLO12 checkpoints of different sizes)
    if args.od_model == 'yolox':
        od_model = YOLO('./.checkpoints/yolo12x.pt')
    elif args.od_model == 'yolol':
        od_model = YOLO('./.checkpoints/yolo12l.pt')
    elif args.od_model == 'yolom':
        od_model = YOLO('./.checkpoints/yolo12m.pt')
    elif args.od_model == 'yolos':
        od_model = YOLO('./.checkpoints/yolo12s.pt')
    elif args.od_model == 'yolon':
        od_model = YOLO('./.checkpoints/yolo12n.pt')

    # get key words prompt
    with open(task2prompt_file, 'r') as f:
        prompt = json.load(f)
    prompt_use = ['The item is ' + kw for kw in prompt[task_name]]
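    # Each entry of prompt_use is a short sentence of the form 'The item is <keyword>',
    # one per task-related keyword listed under task_name in task20.json.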

    # run object detection and collect bbox info
    outputs = od_model(image_path)
    img = np.array(image)
    ocvimg = img[:, :, ::-1].copy()  # RGB -> BGR for OpenCV-style cropping/drawing
    bbox_list = outputs[0].boxes.xyxy.tolist()
    classes = outputs[0].boxes.cls.tolist()
    names = outputs[0].names
    confidences = outputs[0].boxes.conf.tolist()
    predict_res = []
    json_entry = {}
    json_entry['bbox'] = bbox_list
    json_entry['class'] = classes
    json_entry['confidences'] = confidences

    # crop bbox images
    seg_dic = crop_image(ocvimg, bbox_list)
    seg_list = [seg_dic[box_id] for box_id in seg_dic.keys()]
    if len(seg_list) == 0:
        print("*" * 100)
        print("Didn't detect any object in the image.")
        print("*" * 100)
        # nothing to crop or score; stop here
        raise SystemExit(0)
    N_seg = len(seg_list)

    # NOTE: test without reasoning model
    img_with_bbox = draw_bboxes(ocvimg, bbox_list, (0, 255, 0))
    save_image(img_with_bbox, f'./res/{task_id}/{image_name}_no_reasoning.jpg')

    # encode bbox images and prompt keywords
    with torch.no_grad():
        if vlm_model_name == 'imagebind':
            input = {
                ModalityType.TEXT: data.load_and_transform_text(prompt_use, device),
                ModalityType.VISION: data.read_and_transform_vision_data(seg_list, device),
            }
            embeddings = vlm_model(input)
            text_embeddings = embeddings[ModalityType.TEXT]
            bbox_embeddings = embeddings[ModalityType.VISION]
            input = {
                ModalityType.VISION: data.read_and_transform_vision_data([image], device),
            }
            embeddings = vlm_model(input)
            image_embedding = embeddings[ModalityType.VISION].squeeze(dim=0)
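    # text_embeddings has one row per prompt in prompt_use, bbox_embeddings one row per
    # crop in seg_list; with the imagebind_huge backbone both live in the same embedding
    # space (1024-dim for the released checkpoint), and image_embedding is a single
    # global vector of that width for the whole input image.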

    # prepare TaskCLIP model
    model_config = {
        'num_layers': 8,
        'norm': None,
        'return_intermediate': False,
        'd_model': image_embedding.shape[-1],
        'nhead': 4,
        'dim_feedforward': 2048,
        'dropout': 0.1,
        'N_words': text_embeddings.shape[0],
        'activation': args.activation,
        'normalize_before': False,
        'device': device,
        'ratio_text': args.ratio_text,
        'ratio_image': args.ratio_image,
        'ratio_glob': args.ratio_glob,
        'norm_before': args.norm_before,
        'norm_after': args.norm_after,
        'MIN_VAL': float(args.norm_range.split('|')[0]),
        'MAX_VAL': float(args.norm_range.split('|')[1]),
        'cross_attention': args.cross_attention,
    }
    task_clip_model = TaskCLIP(model_config, normalize_before=model_config['normalize_before'], device=model_config['device'])
    task_clip_model.load_state_dict(torch.load(args.eval_model_path, map_location=device))
    task_clip_model.to(device)

    # feed text, bbox, and image embeddings into the TaskCLIP reasoning model
    with torch.no_grad():
        task_clip_model.eval()
        tgt = bbox_embeddings
        memory = text_embeddings
        image_embedding = image_embedding.view(1, -1)
        tgt_new, memory_new, score_res, score_raw = task_clip_model(tgt, memory, image_embedding)
        score = score_res.view(-1)
        score = score.cpu().squeeze().detach().numpy().tolist()

    # post-processing and optimization
    predict_res = [
        {"category_id": -1, "score": -1, "class": int(json_entry['class'][i])}
        for i in range(len(bbox_list))
    ]

    # same class forward optimization
    if isinstance(score, list):
        visited = [0] * len(score)
        for i, x in enumerate(score):
            if visited[i] == 1:
                continue
            if x > threshold:
                visited[i] = 1
                predict_res[i]["category_id"] = 1
                predict_res[i]["score"] = float(x)
                if args.forward:
                    find_same_class(predict_res, score, visited, i, json_entry['class'], json_entry['confidences'], args.forward_thre)
            else:
                predict_res[i]["category_id"] = 0
                predict_res[i]["score"] = 1 - float(x)
    else:
        if score > threshold:
            predict_res[0]["category_id"] = 1
            predict_res[0]["score"] = float(score)
        else:
            predict_res[0]["category_id"] = 0
            predict_res[0]["score"] = 1 - float(score)

    # cluster bbox optimization
    if args.cluster and args.forward and N_seg > 1:
        cluster = {}
        for p in predict_res:
            if int(p["category_id"]) == 1:
                if p["class"] in cluster.keys():
                    cluster[p["class"]].append(p["score"])
                else:
                    cluster[p["class"]] = [p["score"]]
        # choose one cluster
        if len(cluster.keys()) > 1:
            cluster_ave = {}
            for c in cluster.keys():
                cluster_ave[c] = np.sum(cluster[c]) / len(cluster[c])
            select_class = max(cluster_ave, key=lambda k: cluster_ave[k])
            # remove lower score class
            for p in predict_res:
                if p["category_id"] == 1 and p["class"] != select_class:
                    p["category_id"] = 0

    # draw and save the final task-relevant detections
    score_final = [x["category_id"] for x in predict_res]
    # mask = score > threshold
    mask = np.array(score_final) == 1
    bbox_arr = np.asarray(bbox_list)
    bbox_select = bbox_arr[mask]
    img_with_bbox = draw_bboxes(ocvimg, bbox_select, (255, 0, 0))
    save_image(img_with_bbox, f'./res/{task_id}/{image_name}_reasoning.jpg')
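
# Example invocation (a sketch; the flags match the argparse block above and the paths
# are the script's own defaults, but the file name "demo.py" is a placeholder):
#   python demo.py -task_id 1 -image_path ./images/demo_image_1.jpg \
#       -od_model yolox -threshold 0.01 -eval_model_path ./test_model/decoder_epoch19.pt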