File size: 5,469 Bytes
24c4def
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f07c545
 
 
8a32cc6
24c4def
 
 
 
 
 
 
 
 
55d9644
 
 
24c4def
 
55d9644
 
 
24c4def
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
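'''
Overall approach: call the object detector once per claim and aggregate all
detected objects (near-duplicate boxes should be removed, either by discarding
boxes or by some other method).
1. For each claim, call the detector to get a bounding-box list and a phrase list.
2. Call BLIP-2 the same way Woodpecker does.
3. Call the OCR model the same way as before.
4. When aggregating, merge the bounding boxes (removing near-duplicates).
'''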
import cv2
import yaml
import torch
import os
import shortuuid
from PIL import Image
import numpy as np
from torchvision.ops import box_convert
from pipeline.tool.scene_text_model import *
# import sys
# sys.path.append("pipeline/GroundingDINO")
# from groundingdino.util.inference import load_model, load_image, predict, annotate
from pipeline.GroundingDINO.groundingdino.util.inference import load_model, load_image, predict, annotate



# Confidence cut-offs intended for the GroundingDINO detector API.
BOX_TRESHOLD, TEXT_TRESHOLD = 0.35, 0.25
# Detections whose normalized area falls below this are treated as too small.
AREA_THRESHOLD = 1e-3
# Two boxes whose IoU exceeds this are regarded as the same instance.
IOU_THRESHOLD = 0.95

class GroundingDINO:
    """Wrapper around a GroundingDINO detector plus a ``MAERec`` text recognizer.

    The detector weights and the text recognizer are loaded once in
    ``__init__`` and reused by every ``execute`` call.
    """

    # Lowered box threshold used in scene-text mode (``use_text_rec=True``).
    SCENE_TEXT_BOX_THRESHOLD = 0.2

    def __init__(self, config=None, device="cuda:0"):
        """Load the detector and the text recognizer.

        Args:
            config: Optional configuration object. Accepted so callers can
                write ``GroundingDINO(config=...)``; it is stored on the
                instance but not otherwise used by this class.
            device: Torch device string forwarded to ``predict``.
        """
        self.config = config
        self.device = device
        self.BOX_TRESHOLD = BOX_TRESHOLD
        self.TEXT_TRESHOLD = TEXT_TRESHOLD
        self.text_rec = MAERec()
        # Load the model only once; it is reused across execute() calls.
        self.model = load_model(
            "pipeline/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
            "models/groundingdino_swint_ogc.pth",
        )

    @staticmethod
    def _to_xyxy(boxes, image_source):
        """Convert normalized cxcywh boxes into pixel and normalized xyxy boxes.

        Args:
            boxes: Tensor of boxes in normalized (cx, cy, w, h) form, as
                returned by ``predict``.
            image_source: Image array whose first two dims are (height, width).

        Returns:
            ``(xyxy, normed_xyxy)``: a numpy array of pixel-space
            [x0, y0, x1, y1] boxes, and a list of the same boxes divided by
            the image size, clipped to [0, 1] and rounded to 3 decimals.
        """
        h, w, _ = image_source.shape
        xyxy = box_convert(
            boxes=boxes * torch.Tensor([w, h, w, h]), in_fmt="cxcywh", out_fmt="xyxy"
        ).numpy()
        normed_xyxy = np.around(np.clip(xyxy / np.array([w, h, w, h]), 0.0, 1.0), 3).tolist()
        return xyxy, normed_xyxy

    def execute(self, image_path, content, new_path, use_text_rec):
        """Run detection on one image and save an annotated copy.

        Args:
            image_path: Path of the input image.
            content: Text prompt passed to the detector as ``caption``.
            new_path: Output directory for the annotated image (and, in
                scene-text mode, a per-image sub-directory holding the crops).
            use_text_rec: If True, run in scene-text mode. This uses a lower box
                threshold, crops every box that is not too small, and uses the
                ``MAERec`` output for each crop as its phrase.

        Returns:
            dict with ``boxes`` (normalized xyxy boxes), ``logits``,
            ``phrases`` and ``new_path`` (path of the annotated image). In
            scene-text mode, ``boxes``/``logits``/``phrases`` cover only the
            boxes kept by the area filter and stay index-aligned. In detection
            mode, the dict also holds ``xyxy`` (pixel boxes) and
            ``image_source``.
        """
        image_source, image = load_image(image_path)
        image_name = os.path.basename(image_path)

        if use_text_rec:
            # Lower box threshold for scene text.
            boxes, logits, _ = predict(
                model=self.model,
                image=image,
                caption=content,
                box_threshold=self.SCENE_TEXT_BOX_THRESHOLD,
                text_threshold=self.TEXT_TRESHOLD,
                device=self.device,
            )
            xyxy, normed_xyxy = self._to_xyxy(boxes, image_source)
            cache_dir = os.path.join(new_path, os.path.splitext(image_name)[0])
            os.makedirs(cache_dir, exist_ok=True)

            source_img = Image.fromarray(image_source)
            keep, res_list = [], []
            for idx, (box, norm_box) in enumerate(zip(xyxy, normed_xyxy)):
                # Filter out objects that are too small.
                if (norm_box[2] - norm_box[0]) * (norm_box[3] - norm_box[1]) < AREA_THRESHOLD:
                    continue
                crop_path = os.path.join(cache_dir, f"{shortuuid.uuid()}.jpg")
                source_img.crop(tuple(box.tolist())).save(crop_path)
                _, res = self.text_rec.execute(crop_path)
                print(res)
                keep.append(idx)
                res_list.append(res)

            # Drop the filtered-out boxes so boxes, logits and phrases stay
            # index-aligned, both in the annotation and in the returned dict.
            keep_idx = torch.as_tensor(keep, dtype=torch.long)
            boxes, logits = boxes[keep_idx], logits[keep_idx]
            normed_xyxy = [normed_xyxy[i] for i in keep]

            annotated_frame = annotate(
                image_source=image_source, boxes=boxes, logits=logits, phrases=res_list
            )
            new_image_path = os.path.join(cache_dir, f"{shortuuid.uuid()}.jpg")
            cv2.imwrite(new_image_path, annotated_frame)
            return {
                "boxes": normed_xyxy,
                "logits": logits,
                "phrases": res_list,
                "new_path": new_image_path,
            }

        out_path = os.path.join(new_path, image_name)
        print(content)
        boxes, logits, phrases = predict(
            model=self.model,
            image=image,
            caption=content,
            box_threshold=self.BOX_TRESHOLD,
            text_threshold=self.TEXT_TRESHOLD,
            device=self.device,
        )
        annotated_frame = annotate(
            image_source=image_source, boxes=boxes, logits=logits, phrases=phrases
        )
        cv2.imwrite(out_path, annotated_frame)
        xyxy, normed_xyxy = self._to_xyxy(boxes, image_source)
        return {
            "boxes": normed_xyxy,
            "logits": logits,
            "phrases": phrases,
            "new_path": out_path,
            "xyxy": xyxy,
            "image_source": image_source,
        }
    
        
        
if __name__ == '__main__':
    # Manual smoke test. The detector wrapper needs no configuration object,
    # so it is built with its defaults.
    t = GroundingDINO()
    # Other test images tried:
    #   /newdisk3/wcx/TextVQA/test_images/fca674d065b0ee2c.jpg
    #   /newdisk3/wcx/TextVQA/test_images/6648410adb1b08cb.jpg
    image_path = "/home/wcx/wcx/GroundingDINO/LVLM/cot/img_examples/image.jpg"
    # Scene-text mode example:
    # res = t.execute(image_path=image_path, content="word.number",
    #                 new_path="/home/wcx/wcx/GroundingDINO/LVLM/cot/img_examples/extra/",
    #                 use_text_rec=True)
    # print(res)
    res2 = t.execute(
        image_path,
        content="car.man.glasses.coat",
        new_path="/home/wcx/wcx/GroundingDINO/LVLM/cot/img_examples/extra/",
        use_text_rec=False,
    )
    print(res2)
    # Notes kept from earlier manual runs (presumably the prompt/phrases
    # followed by the returned normalized boxes):
    #   dog cat
    #     [[0.107, 0.005, 0.56, 0.999], [0.597, 0.066, 1.0, 0.953]]
    #   'basketball', 'boy', 'car'
    #     [0.741, 0.179, 0.848, 0.285], [0.773, 0.299, 0.98, 0.828], [0.001, 0.304, 0.992, 0.854]
    #   'worlld
    #     [0.405, 0.504, 0.726, 0.7]
    # Another prompt / image pair:
    #   cloud.agricultural exhibit.music.sky.food vendor.sign.street sign.carnival ride
    #   /val2014/COCO_val2014_000000029056.jpg