# TaskCLIP / demo.py
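# Demo: task-oriented object detection. A YOLO detector proposes candidate boxes,
# ImageBind encodes the box crops and the task prompts, and a trained TaskCLIP
# decoder scores each box for how well it suits the requested task.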
import json
import os
from ImageBind.imagebind import data
from ImageBind.imagebind.models import imagebind_model
from ImageBind.imagebind.models.imagebind_model import ModalityType
from collections import OrderedDict
import torch
import argparse
from utils import crop_image, draw_bboxes, save_image, find_same_class, open_image_follow_symlink
from ultralytics import YOLO
from PIL import Image
import numpy as np
from models.TaskCLIP import TaskCLIP
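# mapping files: task id -> task name, and task name -> list of prompt keywords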
id2task_name_file = './id2task_name.json'
task2prompt_file = './task20.json'
if __name__ == "__main__":
parser = argparse.ArgumentParser()
    parser.add_argument('-vlm_model', default='imagebind', help='Set the front-end vision-language model')
parser.add_argument('-od_model', default='yolox', help='Set object detection model')
parser.add_argument('-device', default='cuda:0', help='Set running environment')
parser.add_argument('-task_id', type=int, default=1, help='Set task id')
parser.add_argument('-image_path', type=str, default='./images/demo_image_1.jpg', help='Set input image path')
parser.add_argument('-activation', type=str, default='relu')
parser.add_argument('-ratio_text', type=float, default=0.3)
parser.add_argument('-ratio_image', type=float, default=0.3)
parser.add_argument('-ratio_glob', type=float, default=0.3)
parser.add_argument('-norm_before', action='store_true', default=False)
parser.add_argument('-norm_after', action='store_true', default=False)
    parser.add_argument('-norm_range', type=str, default='10|30')
    parser.add_argument('-cross_attention', action='store_true', default=False)
    parser.add_argument('-eval_model_path', default='./test_model/decoder_epoch19.pt', help='Set path for loading trained TaskCLIP model')
parser.add_argument('-threshold', type=float, default=0.01, help='Set threshold for positive detection')
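    # NOTE: with action='store_true' and default=True, -forward and -cluster are effectively always enabled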
parser.add_argument('-forward', action='store_true', default=True)
parser.add_argument('-cluster', action='store_true', default=True)
parser.add_argument('-forward_thre', type=float, default=0.1, help='Set threshold for positive detection during forward optimization')
args = parser.parse_args()
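    # Example invocation (assumes the YOLO and TaskCLIP checkpoints referenced above are present locally):
    #   python demo.py -task_id 1 -image_path ./images/demo_image_1.jpg -od_model yolox -device cuda:0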
device = args.device
threshold = args.threshold
# prepare task name and key words
with open(id2task_name_file, 'r') as f:
id2task_name = json.load(f)
task_id = str(args.task_id)
task_name = id2task_name[task_id]
# prepare input image
image_path = args.image_path
image_name = args.image_path.split('/')[-1].split('.')[0]
image = open_image_follow_symlink(image_path).convert('RGB')
# load vision-language model
vlm_model_name = args.vlm_model
if vlm_model_name == 'imagebind':
vlm_model = imagebind_model.imagebind_huge(pretrained=True).to(device)
vlm_model.eval()
# load object detection model
if args.od_model == 'yolox':
od_model = YOLO('./.checkpoints/yolo12x.pt')
    elif args.od_model == 'yolol':
        od_model = YOLO('./.checkpoints/yolo12l.pt')
    elif args.od_model == 'yolom':
        od_model = YOLO('./.checkpoints/yolo12m.pt')
    elif args.od_model == 'yolos':
        od_model = YOLO('./.checkpoints/yolo12s.pt')
    elif args.od_model == 'yolon':
        od_model = YOLO('./.checkpoints/yolo12n.pt')
# get key words prompt
with open(task2prompt_file, 'r') as f:
prompt = json.load(f)
    prompt_use = ['The item is ' + p for p in prompt[task_name]]
# get bbox image
outputs = od_model(image_path)
img = np.array(image)
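    # convert the RGB PIL image to BGR channel order for the crop/draw utilities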
ocvimg = img[:, :, ::-1].copy()
bbox_list = outputs[0].boxes.xyxy.tolist()
classes = outputs[0].boxes.cls.tolist()
names = outputs[0].names
confidences = outputs[0].boxes.conf.tolist()
    predict_res = []
    json_entry = {
        'bbox': bbox_list,
        'class': classes,
        'confidences': confidences,
    }
# crop bbox images
seg_dic = crop_image(ocvimg, bbox_list)
    seg_list = list(seg_dic.values())
    if len(seg_list) == 0:
        print("*" * 100)
        print("Didn't detect any object in the image.")
        print("*" * 100)
        # nothing to score, so stop here rather than failing later in the pipeline
        raise SystemExit(0)
N_seg = len(seg_list)
# NOTE: test without reasoning model
img_with_bbox = draw_bboxes(ocvimg, bbox_list, (0, 255, 0))
    # make sure the per-task output directory exists before saving
    os.makedirs(f'./res/{task_id}', exist_ok=True)
    save_image(img_with_bbox, f'./res/{task_id}/{image_name}_no_reasoning.jpg')
# encode bbox image and prompt keywords
with torch.no_grad():
if vlm_model_name == 'imagebind':
            inputs = {
                ModalityType.TEXT: data.load_and_transform_text(prompt_use, device),
                ModalityType.VISION: data.read_and_transform_vision_data(seg_list, device),
            }
            embeddings = vlm_model(inputs)
text_embeddings = embeddings[ModalityType.TEXT]
bbox_embeddings = embeddings[ModalityType.VISION]
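            # also encode the full image once as a global context embedding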
            inputs = {
                ModalityType.VISION: data.read_and_transform_vision_data([image], device),
            }
            embeddings = vlm_model(inputs)
image_embedding = embeddings[ModalityType.VISION].squeeze(dim=0)
# prepare TaskCLIP model
num_layers = 8
nhead = 4
    model_config = {
        'num_layers': num_layers,
        'norm': None,
        'return_intermediate': False,
        'd_model': image_embedding.shape[-1],
        'nhead': nhead,
        'dim_feedforward': 2048,
        'dropout': 0.1,
        'N_words': text_embeddings.shape[0],
        'activation': args.activation,
        'normalize_before': False,
        'device': device,
        'ratio_text': args.ratio_text,
        'ratio_image': args.ratio_image,
        'ratio_glob': args.ratio_glob,
        'norm_before': args.norm_before,
        'norm_after': args.norm_after,
        'MIN_VAL': float(args.norm_range.split('|')[0]),
        'MAX_VAL': float(args.norm_range.split('|')[1]),
        'cross_attention': args.cross_attention,
    }
    task_clip_model = TaskCLIP(model_config, normalize_before=model_config['normalize_before'], device=model_config['device'])
    # map_location keeps loading working when the checkpoint was saved on a different device
    task_clip_model.load_state_dict(torch.load(args.eval_model_path, map_location=device))
    task_clip_model.to(device)
    # feed the prompt, box, and global image embeddings into the trained TaskCLIP decoder
with torch.no_grad():
task_clip_model.eval()
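        # box-crop embeddings act as the decoder target sequence; prompt embeddings act as the memory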
        tgt = bbox_embeddings
        memory = text_embeddings
        image_embedding = image_embedding.view(1, -1)
        tgt_new, memory_new, score_res, score_raw = task_clip_model(tgt, memory, image_embedding)
score = score_res.view(-1)
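        # .tolist() returns a plain float when only one box was detected, hence the isinstance check below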
score = score.cpu().squeeze().detach().numpy().tolist()
# post-processing and optimization
predict_res = []
    for i in range(len(bbox_list)):
        predict_res.append({
            "category_id": -1,
            "score": -1,
            "class": int(json_entry['class'][i]),
        })
# same class forward optimization
if isinstance(score, list):
visited = [0]*len(score)
for i, x in enumerate(score):
if visited[i] == 1:
continue
if x > threshold:
visited[i] = 1
predict_res[i]["category_id"] = 1
predict_res[i]["score"] = float(x)
if args.forward:
find_same_class(predict_res, score, visited, i, json_entry['class'], json_entry['confidences'], args.forward_thre)
else:
predict_res[i]["category_id"] = 0
predict_res[i]["score"] = 1 - float(x)
else:
if score > threshold:
predict_res[0]["category_id"] = 1
predict_res[0]["score"] = float(score)
else:
predict_res[0]["category_id"] = 0
predict_res[0]["score"] = 1 - float(score)
# cluster bbox optimization
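    # keep only the detected class whose positive boxes have the highest mean score; demote positives from other classes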
if args.cluster and args.forward and N_seg > 1:
cluster = {}
        for p in predict_res:
            if int(p["category_id"]) == 1:
                cluster.setdefault(p["class"], []).append(p["score"])
# choose one cluster
        if len(cluster) > 1:
            cluster_ave = {c: float(np.mean(scores)) for c, scores in cluster.items()}
            select_class = max(cluster_ave, key=cluster_ave.get)
# remove lower score class
for p in predict_res:
if p["category_id"] == 1 and p["class"] != select_class:
p["category_id"] = 0
score_final = [x["category_id"] for x in predict_res]
# mask = score > threshold
mask = np.array(score_final) == 1
bbox_arr = np.asarray(bbox_list)
bbox_select = bbox_arr[mask]
img_with_bbox = draw_bboxes(ocvimg, bbox_select, (255, 0, 0))
save_image(img_with_bbox, f'./res/{task_id}/{image_name}_reasoning.jpg')