ariG23498 (HF Staff) committed
Commit a151d2d · verified · 1 Parent(s): 01bf494

Create app.py

Files changed (1)
app.py (+128, -0)
app.py ADDED
@@ -0,0 +1,128 @@
+ # Copyright (c) Alibaba Cloud.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ import os
+ import re
+ import gradio as gr
+ import dashscope
+ from dashscope import MultiModalConversation
+ from argparse import ArgumentParser
+ from http import HTTPStatus
+ from urllib3.exceptions import HTTPError
+
+ # Set API key
+ API_KEY = os.environ['API_KEY']
+ dashscope.api_key = API_KEY
+
+ # Constants
+ REVISION = 'v1.0.4'
+ BOX_TAG_PATTERN = r"<box>\((\d+),(\d+),(\d+),(\d+)\)</box>"
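+ # e.g. "<box>(10,20,110,220)</box>" yields groups ('10', '20', '110', '220')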
+
+ def _get_args():
+     parser = ArgumentParser()
+     parser.add_argument("--revision", type=str, default=REVISION)
+     parser.add_argument("--share", action="store_true", default=False,
+                         help="Create a publicly shareable link for the interface.")
+     parser.add_argument("--server-port", type=int, default=7860,
+                         help="Demo server port.")
+     parser.add_argument("--server-name", type=str, default="127.0.0.1",
+                         help="Demo server name.")
+     return parser.parse_args()
+
+ def parse_bounding_boxes(text):
+     """Parse bounding box coordinates from model output."""
+     matches = re.findall(BOX_TAG_PATTERN, text)
+     bboxes = []
+     for match in matches:
+         x1, y1, x2, y2 = map(int, match)
+         bboxes.append(((x1, y1, x2, y2), "Object"))
+     return bboxes
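+
+ # Example: parse_bounding_boxes("<box>(10,20,110,220)</box>")
+ # returns [((10, 20, 110, 220), "Object")]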
+
+ def predict(image, prompt):
+     """Process image and prompt to get bounding boxes and return annotated image."""
+     if image is None or not prompt:
+         return None, "Please upload an image and provide a prompt."
+
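+     # DashScope multimodal messages carry a content list mixing image and text
+     # parts; a local image file is referenced with a file:// URI.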
+     # Prepare message for the model
+     messages = [{
+         'role': 'user',
+         'content': [
+             {'image': f'file://{image}'},
+             {'text': prompt}
+         ]
+     }]
+
+     # Call the Qwen2.5-VL model
+     try:
+         responses = MultiModalConversation.call(
+             model='qwen2.5-vl-32b-instruct',
+             messages=messages,
+             stream=False
+         )
+         if responses.status_code != HTTPStatus.OK:
+             return None, f"Error: {responses.message}"
+
+         # Extract response text
+         response = responses.output.choices[0].message.content
+         response_text = ''.join([ele['text'] if 'text' in ele else ele.get('box', '') for ele in response])
+
+         # Parse bounding boxes
+         bboxes = parse_bounding_boxes(response_text)
+         if not bboxes:
+             return None, "No bounding boxes detected."
+
+         # Return the image and annotations for AnnotatedImage
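+         # gr.AnnotatedImage expects (base_image, [(annotation, label), ...]),
+         # where each annotation here is an (x1, y1, x2, y2) pixel box.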
+         return (image, bboxes), None
+
+     except HTTPError as e:
+         return None, f"HTTP Error: {str(e)}"
+     except Exception as e:
+         return None, f"Error: {str(e)}"
+
+ def clear_inputs():
+     """Reset the input fields."""
+     # One value per bound output: image, prompt, annotated image, status.
+     return None, "", None, ""
+
+ def _launch_demo(args):
+     with gr.Blocks() as demo:
+         gr.Markdown("""<center><font size=3>Qwen2.5-VL-32B-Instruct Bounding Box Demo</font></center>""")
+
+         with gr.Row():
+             with gr.Column():
+                 image_input = gr.Image(type="filepath", label="Upload Image")
+                 prompt_input = gr.Textbox(
+                     lines=2,
+                     label="Prompt",
+                     value="Detect all objects in the provided image and output their bounding box coordinates in the format <box>(x1,y1,x2,y2)</box>. Do not include any other text or descriptions. If multiple objects are detected, list each bounding box in a new <box> tag."
+                 )
+                 submit_btn = gr.Button("🚀 Submit")
+                 clear_btn = gr.Button("🧹 Clear")
+             with gr.Column():
+                 output_image = gr.AnnotatedImage(label="Annotated Image")
+                 error_message = gr.Textbox(label="Status", interactive=False)
+
+         # Bind actions
+         submit_btn.click(
+             fn=predict,
+             inputs=[image_input, prompt_input],
+             outputs=[output_image, error_message]
+         )
+         clear_btn.click(
+             fn=clear_inputs,
+             inputs=[],
+             outputs=[image_input, prompt_input, output_image, error_message]
+         )
+
+     demo.launch(
+         share=args.share,
+         server_port=args.server_port,
+         server_name=args.server_name
+     )
+
+ def main():
+     args = _get_args()
+     _launch_demo(args)
+
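+ # Typical local run (hypothetical key value shown):
+ #   API_KEY=sk-xxxx python app.py --server-name 0.0.0.0 --server-port 7860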
+ if __name__ == '__main__':
+     main()