EasyDetect / pipeline /tool /object_detetction_model.py
sunnychenxiwang's picture
Update pipeline/tool/object_detetction_model.py
8a32cc6 verified
raw
history blame
No virus
5.47 kB
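'''
Overall approach: call the object detector once for every claim, then aggregate all
detected objects (near-duplicate bounding boxes should be removed; consider dropping
the duplicate boxes or some other strategy).
1. For each claim, call the detector to obtain a bounding-box list and a phrase list.
2. Call BLIP-2 the same way Woodpecker does.
3. Call the OCR model the same way as before.
4. When aggregating, the bounding boxes must be merged as well (near-duplicates removed).
'''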
import cv2
import yaml
import torch
import os
import shortuuid
from PIL import Image
import numpy as np
from torchvision.ops import box_convert
from pipeline.tool.scene_text_model import *
# import sys
# sys.path.append("pipeline/GroundingDINO")
# from groundingdino.util.inference import load_model, load_image, predict, annotate
from pipeline.GroundingDINO.groundingdino.util.inference import load_model, load_image, predict, annotate
# Module-level thresholds for detection and post-filtering.
BOX_TRESHOLD = 0.35 # box confidence threshold handed to the detector API.
TEXT_TRESHOLD = 0.25 # text (phrase) confidence threshold handed to the detector API.
AREA_THRESHOLD = 0.001 # boxes whose normalized area (fraction of the image) is below this are treated as too small and filtered out.
IOU_THRESHOLD = 0.95 # two boxes with IoU above this are considered the same instance. NOTE(review): not referenced in the visible code.
class GroundingDINO:
    """GroundingDINO detector with optional text recognition on the detected crops.

    The detector weights and the text recognizer are loaded once in ``__init__``
    and reused by every ``execute`` call.
    """

    # Box threshold used in scene-text mode; deliberately lower than the
    # default so that more candidate text regions are returned.
    TEXT_REC_BOX_THRESHOLD = 0.2

    def __init__(self, device="cuda:0"):
        """Load the detector and the text recognizer.

        Args:
            device: Device string forwarded to ``predict``. Defaults to
                ``"cuda:0"``, the value that was previously hard-coded.
        """
        self.BOX_TRESHOLD = BOX_TRESHOLD
        self.TEXT_TRESHOLD = TEXT_TRESHOLD
        self.device = device
        self.text_rec = MAERec()
        # Load the detector only once; paths are relative to the working directory.
        self.model = load_model(
            "pipeline/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
            "models/groundingdino_swint_ogc.pth",
        )

    @staticmethod
    def _image_stem(image_path):
        """Return the image file name without directory and extension."""
        # splitext handles extensions of any length (".jpeg", ".webp", ...),
        # unlike slicing off the last four characters.
        return os.path.splitext(os.path.basename(image_path))[0]

    @staticmethod
    def _normed_area(norm_box):
        """Return the area of a normalized ``[x0, y0, x1, y1]`` box."""
        return (norm_box[2] - norm_box[0]) * (norm_box[3] - norm_box[1])

    @staticmethod
    def _to_xyxy(boxes, image_source):
        """Convert detector boxes to pixel and normalized xyxy coordinates.

        Args:
            boxes: Tensor of normalized ``cxcywh`` boxes as returned by ``predict``.
            image_source: Image array of shape ``(h, w, channels)``.

        Returns:
            ``(xyxy, normed_xyxy)``: a numpy array of pixel-space xyxy boxes and a
            list of the same boxes normalized, clipped to [0, 1] and rounded to
            three decimals.
        """
        h, w, _ = image_source.shape
        pixel_boxes = boxes * torch.Tensor([w, h, w, h])
        xyxy = box_convert(boxes=pixel_boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
        normed_xyxy = np.around(np.clip(xyxy / np.array([w, h, w, h]), 0., 1.), 3).tolist()
        return xyxy, normed_xyxy

    def execute(self, image_path, content, new_path, use_text_rec):
        """Run detection for ``content`` on one image and save an annotated copy.

        Args:
            image_path: Path of the image to analyse.
            content: Caption handed to the detector (the sample calls in this file
                use period-separated phrases such as ``"car.man.glasses.coat"``).
            new_path: Output prefix. It is concatenated verbatim with the image
                name, so it is expected to end with a path separator.
            use_text_rec: When true, every sufficiently large detected region is
                cropped, saved and passed to the text recognizer, and the
                recognized texts replace the detector phrases.

        Returns:
            A dict with ``"boxes"`` (normalized xyxy lists), ``"logits"``,
            ``"phrases"`` and ``"new_path"`` (the annotated image). Without text
            recognition it additionally holds ``"xyxy"`` (pixel boxes) and
            ``"image_source"``. In text-recognition mode ``boxes``, ``logits``
            and ``phrases`` only contain the regions that passed the area filter,
            so the three sequences stay index-aligned.
        """
        image_source, image = load_image(image_path)
        if use_text_rec:
            return self._detect_scene_text(image_path, image_source, image, content, new_path)
        return self._detect_objects(image_path, image_source, image, content, new_path)

    def _detect_scene_text(self, image_path, image_source, image, content, new_path):
        """Detect text regions, recognize each crop and annotate the image."""
        boxes, logits, _ = predict(
            model=self.model,
            image=image,
            caption=content,
            box_threshold=self.TEXT_REC_BOX_THRESHOLD,
            text_threshold=self.TEXT_TRESHOLD,
            device=self.device,
        )
        xyxy, normed_xyxy = self._to_xyxy(boxes, image_source)

        # Crops and the annotated image go to a per-image cache directory.
        cache_dir = new_path + self._image_stem(image_path)
        os.makedirs(cache_dir, exist_ok=True)

        pil_image = Image.fromarray(image_source)  # built once, cropped per box
        kept, texts = [], []
        for idx, (box, norm_box) in enumerate(zip(xyxy, normed_xyxy)):
            # Filter out objects that are too small to carry readable text.
            if self._normed_area(norm_box) < AREA_THRESHOLD:
                continue
            crop_path = os.path.join(cache_dir, f"{shortuuid.uuid()}.jpg")
            pil_image.crop(tuple(box.tolist())).save(crop_path)
            _, text = self.text_rec.execute(crop_path)
            print(text)
            kept.append(idx)
            texts.append(text)

        # Keep boxes/logits only for the regions that produced a text, so that
        # annotate() and the caller never pair a text with the wrong box.
        # assumes logits is a torch tensor like boxes - TODO confirm
        keep_idx = torch.tensor(kept, dtype=torch.long)
        kept_boxes = boxes[keep_idx]
        kept_logits = logits[keep_idx]

        annotated_frame = annotate(image_source=image_source, boxes=kept_boxes,
                                   logits=kept_logits, phrases=texts)
        new_image_path = os.path.join(cache_dir, f"{shortuuid.uuid()}.jpg")
        cv2.imwrite(new_image_path, annotated_frame)
        return {
            "boxes": [normed_xyxy[i] for i in kept],
            "logits": kept_logits,
            "phrases": texts,
            "new_path": new_image_path,
        }

    def _detect_objects(self, image_path, image_source, image, content, new_path):
        """Detect the phrases in ``content`` and annotate the image."""
        out_path = new_path + os.path.basename(image_path)
        print(content)
        boxes, logits, phrases = predict(
            model=self.model,
            image=image,
            caption=content,
            box_threshold=self.BOX_TRESHOLD,
            text_threshold=self.TEXT_TRESHOLD,
            device=self.device,
        )
        annotated_frame = annotate(image_source=image_source, boxes=boxes,
                                   logits=logits, phrases=phrases)
        cv2.imwrite(out_path, annotated_frame)
        xyxy, normed_xyxy = self._to_xyxy(boxes, image_source)
        return {
            "boxes": normed_xyxy,
            "logits": logits,
            "phrases": phrases,
            "new_path": out_path,
            "xyxy": xyxy,
            "image_source": image_source,
        }
if __name__ == '__main__':
    # Manual smoke test: run plain object detection on one sample image and
    # print the result dict.
    # GroundingDINO.__init__ takes no ``config`` argument, so the former
    # ``GroundingDINO(config=config)`` call raised TypeError; the YAML config
    # that was loaded only for that argument is no longer read.
    t = GroundingDINO()
    # Other sample images:
    # /newdisk3/wcx/TextVQA/test_images/fca674d065b0ee2c.jpg
    # /newdisk3/wcx/TextVQA/test_images/6648410adb1b08cb.jpg
    image_path = "/home/wcx/wcx/GroundingDINO/LVLM/cot/img_examples/image.jpg"
    # Scene-text variant of the same call:
    # res = t.execute(image_path=image_path,content="word.number",new_path="/home/wcx/wcx/GroundingDINO/LVLM/cot/img_examples/extra/",use_text_rec=True)
    # print(res)
    res2 = t.execute(
        image_path,
        content="car.man.glasses.coat",
        new_path="/home/wcx/wcx/GroundingDINO/LVLM/cot/img_examples/extra/",
        use_text_rec=False,
    )
    print(res2)
'''
dog cat
[[0.107, 0.005, 0.56, 0.999], [0.597, 0.066, 1.0, 0.953]]
'basketball', 'boy', 'car'
[0.741, 0.179, 0.848, 0.285], [0.773, 0.299, 0.98, 0.828], [0.001, 0.304, 0.992, 0.854]
'worlld
[0.405, 0.504, 0.726, 0.7]
'''
"""
cloud.agricultural exhibit.music.sky.food vendor.sign.street sign.carnival ride
/val2014/COCO_val2014_000000029056.jpg
"""