from typing import Optional
import numpy as np
import gradio as gr
import spaces
import supervision as sv
import torch
from PIL import Image
from io import BytesIO
import PIL.Image
import requests
import cv2
import json
import time
import os

from utils.florence import load_florence_model, run_florence_inference, \
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference

DEVICE = torch.device("cuda")
# DEVICE = torch.device("cpu")

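# run inference under a global bfloat16 autocast context and enable TF32 matmuls
# on GPUs with compute capability >= 8 (Ampere or newer)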
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True


FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)

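# download an image from a URL and return it as a PIL Image, or None if the request fails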
def fetch_image_from_url(image_url):
    try:
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        return img
    except Exception as e:
        print(f"failed to fetch image from url {image_url}: {e}")
        return None

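# context manager that logs the start time, end time, and elapsed duration of a named activity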
class calculateDuration:
    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        self.start_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.start_time))
        print(f"Activity: {self.activity_name}, Start time: {self.start_time_formatted}")
        return self
    
    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        self.end_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.end_time))
        
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
        
        print(f"Activity: {self.activity_name}, End time: {self.end_time_formatted}")


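# Main entry point: @spaces.GPU requests a GPU for the call (on ZeroGPU Spaces),
# inference_mode disables gradient tracking, and autocast runs the models in bfloat16.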
@spaces.GPU()
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=0, merge_masks=False, return_rectangles=False, invert_mask=False) -> Optional[list]:
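    """
    Generate binary masks for the objects described by the prompts.

    image_input: PIL image uploaded in the UI (replaced by the downloaded image when image_url is set).
    image_url: optional URL; when provided, the image is fetched and used instead of the upload.
    task_prompt: Florence-2 task token, e.g. '<CAPTION_TO_PHRASE_GROUNDING>'.
    text_prompt: optional text passed to Florence-2 for grounding / open-vocabulary tasks.
    dilate: kernel size in pixels used to dilate each SAM mask; 0 disables dilation.
    merge_masks: merge the individual masks into a single combined mask.
    return_rectangles: skip SAM and return filled bounding-box rectangles instead.
    invert_mask: invert black and white in every returned mask.

    Returns a list of uint8 mask images for the gallery, or None on error.
    """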
    
    if not image_input and not image_url:
        gr.Info("Please upload an image or provide an image URL.")
        return None
    
    if not task_prompt:
        gr.Info("Please enter a task prompt.")
        return None
   
    if image_url:
        with calculateDuration("Download Image"):
            print("fetching image from url:", image_url)
            image_input = fetch_image_from_url(image_url)
            if image_input is None:
                gr.Info("Failed to fetch the image from the provided URL.")
                return None
            print("image fetched successfully")

    # run Florence-2 to parse the prompt into detections
    with calculateDuration("FLORENCE"):
        print(task_prompt, text_prompt)
        _, result = run_florence_inference(
            model=FLORENCE_MODEL,
            processor=FLORENCE_PROCESSOR,
            device=DEVICE,
            image=image_input,
            task=task_prompt,
            text=text_prompt
        )
    with calculateDuration("sv.Detections"):
        # convert the Florence-2 result into supervision Detections
        detections = sv.Detections.from_lmm(
            lmm=sv.LMM.FLORENCE_2,
            result=result,
            resolution_wh=image_input.size
        )
    # json_result = json.dumps([])
    # print(detections)
    images = []
    if return_rectangles:
        with calculateDuration("generate rectangle mask"):
            # create filled rectangle masks from the detected bounding boxes
            (image_width, image_height) = image_input.size
            bboxes = detections.xyxy
            merge_mask_image = np.zeros((image_height, image_width), dtype=np.uint8)
            # sort from left to right
            bboxes = sorted(bboxes, key=lambda bbox: bbox[0])
            for bbox in bboxes:
                x1, y1, x2, y2 = map(int, bbox)
                cv2.rectangle(merge_mask_image, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED)
                clip_mask = np.zeros((image_height, image_width), dtype=np.uint8)
                cv2.rectangle(clip_mask, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED)
                images.append(clip_mask)
            if merge_masks:
                images = [merge_mask_image] + images
    else:
        with calculateDuration("generate segment mask"):
            # use SAM to generate segmentation masks for the detections
            detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
            if len(detections) == 0:
                gr.Info("No objects detected.")
                return None
            print("mask generated:", len(detections.mask))
            kernel_size = dilate
            kernel = np.ones((kernel_size, kernel_size), np.uint8)

            for i in range(len(detections.mask)):
                mask = detections.mask[i].astype(np.uint8) * 255
                if dilate > 0:
                    mask = cv2.dilate(mask, kernel, iterations=1)
                images.append(mask)

            if merge_masks:
                merged_mask = np.zeros_like(images[0], dtype=np.uint8)
                for mask in images:
                    merged_mask = cv2.bitwise_or(merged_mask, mask)
                images = [merged_mask]
    if invert_mask:
        with calculateDuration("invert mask colors"):
            images = [cv2.bitwise_not(mask) for mask in images]

    return images


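# return a short human-readable description of the selected Florence-2 task prompt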
def update_task_info(task_prompt):
    task_info = {
        '<OD>': "Object Detection: Detect objects in the image.",
        '<CAPTION_TO_PHRASE_GROUNDING>': "Phrase Grounding: Link phrases in captions to corresponding regions in the image.",
        '<DENSE_REGION_CAPTION>': "Dense Region Captioning: Generate captions for different regions in the image.",
        '<REGION_PROPOSAL>': "Region Proposal: Propose potential regions of interest in the image.",
        '<OCR_WITH_REGION>': "OCR with Region: Extract text and its bounding regions from the image.",
        '<REFERRING_EXPRESSION_SEGMENTATION>': "Referring Expression Segmentation: Segment the region referred to by a natural language expression.",
        '<REGION_TO_SEGMENTATION>': "Region to Segmentation: Convert region proposals into detailed segmentations.",
        '<OPEN_VOCABULARY_DETECTION>': "Open Vocabulary Detection: Detect objects based on open vocabulary concepts.",
        '<REGION_TO_CATEGORY>': "Region to Category: Assign categories to proposed regions.",
        '<REGION_TO_DESCRIPTION>': "Region to Description: Generate descriptive text for specified regions."
    }
    return task_info.get(task_prompt, "Select a task to see its description.")



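# Gradio UI: image / URL inputs, Florence-2 task and text prompts, mask options, and a gallery output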
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(type='pil', label='Upload image')
            image_url = gr.Textbox(label='Image URL', placeholder='Enter an image URL (optional)', info="The image_url parameter allows you to input a URL pointing to an image.")
            task_prompt = gr.Dropdown(['<OD>', '<CAPTION_TO_PHRASE_GROUNDING>', '<DENSE_REGION_CAPTION>', '<REGION_PROPOSAL>', '<OCR_WITH_REGION>', '<REFERRING_EXPRESSION_SEGMENTATION>', '<REGION_TO_SEGMENTATION>', '<OPEN_VOCABULARY_DETECTION>', '<REGION_TO_CATEGORY>', '<REGION_TO_DESCRIPTION>'], value="<CAPTION_TO_PHRASE_GROUNDING>", label="Task Prompt", info="check doc at [Florence](https://huggingface.co/microsoft/Florence-2-large)")
            text_prompt = gr.Textbox(label='Text prompt', placeholder='Enter text prompts')
            submit_button = gr.Button(value='Submit', variant='primary')

            with gr.Accordion("Advanced Settings", open=False):
                dilate = gr.Slider(label="dilate mask", minimum=0, maximum=50, value=10, step=1, info="The dilate parameter controls the expansion of the mask's white areas by a specified number of pixels. Increasing this value will enlarge the white regions, which can help in smoothing out the mask's edges or covering more area in the segmentation.")
                merge_masks = gr.Checkbox(label="Merge masks", value=False, info="The merge_masks parameter combines all the individual masks into a single mask. When enabled, the separate masks generated for different objects or regions will be merged into one unified mask, which can simplify further processing or visualization.")
                return_rectangles = gr.Checkbox(label="Return Rectangles", value=False, info="The return_rectangles parameter, when enabled, generates masks as filled white rectangles corresponding to the bounding boxes of detected objects, rather than detailed contours or segments. This option is useful for simpler, box-based visualizations.")
                invert_mask = gr.Checkbox(label="invert mask", value=False, info="The invert_mask option allows you to reverse the colors of the generated mask, changing black areas to white and white areas to black. This can be useful for visualizing or processing the mask in a different context.")
            
        with gr.Column():
            image_gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", columns=[3], rows=[1], object_fit="contain", height="auto")
            # json_result = gr.Code(label="JSON Result", language="json")
    
   
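    # fetch the image automatically whenever a URL is entered and place it in the image input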
    image_url.change(
        fn=fetch_image_from_url,
        inputs=[image_url],
        outputs=[image]
    )
    
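    # run the Florence-2 + SAM pipeline on submit and show the resulting masks in the gallery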
    submit_button.click(
        fn=process_image,
        inputs=[image, image_url, task_prompt, text_prompt, dilate, merge_masks, return_rectangles, invert_mask],
        outputs=[image_gallery],
        show_api=False
    )

demo.queue()
demo.launch(debug=True, show_error=True)