Files changed (1) hide show
  1. app.py +270 -0
app.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
3
+ from transformers.image_utils import load_image
4
+ from threading import Thread
5
+ import torch
6
+ import pickle as pkl
7
+ import re
8
+ from PIL import Image
9
+ import json
10
+ # import spaces
11
+ from serve_constants import html_header, bibtext, learn_more_markdown, tos_markdown
12
+
13
+
14
+ MODEL_ID = "TIGER-Lab/PixelReasoner-RL-v1"
15
+ example_image = "example_images/1.jpg"
16
+ # "example_images/document.png"
17
+ example_text = "What kind of restaurant is it?"
18
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True,
19
+ # min_pixels=min_pixels,
20
+ max_pixels=512*28*28,
21
+ )
22
+ model = AutoModelForImageTextToText.from_pretrained(
23
+ MODEL_ID,
24
+ trust_remote_code=True,
25
+ torch_dtype=torch.bfloat16
26
+ ).to("cuda").eval()
27
+
28
+ def zoom(image, bbox_2d,padding=(0.1,0.1)):
29
+ """
30
+ Crop the image based on the bounding box coordinates.
31
+ """
32
+ img_x, img_y = image.size
33
+ padding_tr = (600.0/img_x,600.0/img_y)
34
+ padding = (min(padding[0],padding_tr[0]),min(padding[1],padding_tr[1]))
35
+
36
+ if bbox_2d[0] < 1 and bbox_2d[1] < 1 and bbox_2d[2] < 1 and bbox_2d[3] < 1:
37
+ normalized_bbox_2d = (float(bbox_2d[0])-padding[0], float(bbox_2d[1])-padding[1], float(bbox_2d[2])+padding[0], float(bbox_2d[3])+padding[1])
38
+ else:
39
+ normalized_bbox_2d = (float(bbox_2d[0])/img_x-padding[0], float(bbox_2d[1])/img_y-padding[1], float(bbox_2d[2])/img_x+padding[0], float(bbox_2d[3])/img_y+padding[1])
40
+ normalized_x1, normalized_y1, normalized_x2, normalized_y2 = normalized_bbox_2d
41
+ normalized_x1 =min(max(0, normalized_x1), 1)
42
+ normalized_y1 =min(max(0, normalized_y1), 1)
43
+ normalized_x2 =min(max(0, normalized_x2), 1)
44
+ normalized_y2 =min(max(0, normalized_y2), 1)
45
+ cropped_img = image.crop((int(normalized_x1*img_x), int(normalized_y1*img_y), int(normalized_x2*img_x), int(normalized_y2*img_y)))
46
+ w, h = cropped_img.size
47
+ assert w > 28 and h > 28, f"Cropped image is too small: {w}x{h}"
48
+
49
+
50
+ return cropped_img
51
+
52
+
53
+ def execute_tool(images, rawimages, args, toolname, is_video, function=None):
54
+ if toolname=='select_frames':
55
+ tgt = args['target_frames']
56
+ if len(tgt)>8:
57
+ message = f"You have selected {len(tgt)} frames in total. Think again which frames you need to check in details (no more than 8 frames)"
58
+ # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
59
+ ##### controlled modification
60
+ if do_controlled_rectify and np.random.uniform()<0.75:
61
+ if np.random.uniform()<0.25:
62
+ tgt = tgt[:len(tgt)//2]
63
+ elif np.random.uniform()<0.25/0.75:
64
+ tgt = tgt[-len(tgt)//2:]
65
+ elif np.random.uniform()<0.25/0.5:
66
+ tgt = tgt[::2]
67
+ else:
68
+ tgt = np.random.choice(tgt, size=len(tgt)//2, replace=False)
69
+ tgt = sorted(tgt)
70
+ selected_frames = function(images[0], tgt)
71
+ message = tgt
72
+ else:
73
+ selected_frames = []
74
+ # selected_frames = function(images[0], [x-1 for x in tgt][::2]) # video is always in the first item
75
+ elif max(tgt)>len(images[0]):
76
+ message = f"There are {len(images[0])} frames numbered in range [1,{len(images[0])}]. Your selection is out of range."
77
+ selected_frames = []
78
+ else:
79
+ message = ""
80
+ candidates = images[0]
81
+ if not isinstance(candidates, list):
82
+ candidates = [candidates]
83
+ selected_frames = function(candidates, [x-1 for x in tgt]) # video is always in the first item
84
+ return selected_frames, message
85
+ else:
86
+ tgt = args['target_image']
87
+ if is_video:
88
+ if len(images)==1: # there is only
89
+ # we default the candidate images into video frames
90
+ video_frames = images[0]
91
+ index = tgt - 1
92
+ assert index<len(video_frames), f"Incorrect `target_image`. You can only select frames in the given video within [1,{len(video_frames)}]"
93
+ image_to_crop = video_frames[index]
94
+ else: # there are zoomed images after the video; images = [[video], img, img, img]
95
+ cand_images = images[1:]
96
+ index = tgt -1
97
+ assert index<len(cand_images), f"Incorrect `target_image`. You can only select a previous frame within [1,{len(cand_images)}]"
98
+ image_to_crop = cand_images[index]
99
+ else:
100
+ index = tgt-1
101
+ assert index<len(images), f"Incorrect `target_image`. You can only select previous images within [1,{len(images)}]"
102
+
103
+ if index<len(rawimages):
104
+ tmp = rawimages[index]
105
+ else:
106
+ tmp = images[index]
107
+ image_to_crop = tmp
108
+ if function is None: function = zoom
109
+ cropped_image = function(image_to_crop, args['bbox_2d'])
110
+ return cropped_image
111
+
112
+
113
+ def parse_last_tool(output_text):
114
+ # print([output_text])
115
+ return json.loads(output_text.split(tool_start)[-1].split(tool_end)[0])
116
+
117
+ tool_end = '</tool_call>'
118
+ tool_start = '<tool_call>'
119
+
120
+ # @spaces.GPU
121
+ def model_inference(input_dict, history):
122
+ text = input_dict["text"]
123
+ files = input_dict["files"]
124
+
125
+ """
126
+ Create chat history
127
+
128
+ Example history value:
129
+ [
130
+ [('pixel.png',), None],
131
+ ['ignore this image. just say "hi" and nothing else', 'Hi!'],
132
+ ['just say "hi" and nothing else', 'Hi!']
133
+ ]
134
+ """
135
+ all_images = []
136
+ current_message_images = []
137
+ sysprompt = "<|im_start|>system\nYou are a helpful assistant.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{\"type\": \"function\", \"function\": {\"name\": \"crop_image\", \"description\": \"Zoom in on the image based on the bounding box coordinates.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"bbox_2d\": {\"type\": \"array\", \"description\": \"coordinates for bounding box of the area you want to zoom in. minimum value is 0 and maximum value is the width/height of the image.\", \"items\": {\"type\": \"number\"}}, \"target_image\": {\"type\": \"number\", \"description\": \"The index of the image to crop. Index from 1 to the number of images. Choose 1 to operate on original image.\"}}, \"required\": [\"bbox_2d\", \"target_image\"]}}}\n{\"type\": \"function\", \"function\": {\"name\": \"select_frames\", \"description\": \"Select frames from a video.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"target_frames\": {\"type\": \"array\", \"description\": \"List of frame indices to select from the video (no more than 8 frames in total).\", \"items\": {\"type\": \"integer\", \"description\": \"Frame index from 1 to 16.\"}}}, \"required\": [\"target_frames\"]}}}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>"
138
+ messages = [{
139
+ "role": "user",
140
+ "content": sysprompt
141
+ }]
142
+ hint = "\n\nGuidelines: Understand the given visual information and the user query. Determine if it is beneficial to employ the given visual operations (tools). For a video, we can look closer by `select_frames`. For an image, we can look closer by `crop_image`. Reason with the visual information step by step, and put your final answer within \\boxed{}."
143
+ for val in history:
144
+ if val[0]:
145
+ if isinstance(val[0], str):
146
+ messages.append({
147
+ "role": "user",
148
+ "content": [
149
+ *[{"type": "image", "image": image} for image in current_message_images],
150
+ {"type": "text", "text": val[0]},
151
+ ],
152
+ })
153
+ current_message_images = []
154
+
155
+ else:
156
+ # Load messages. These will be appended to the first user text message that comes after
157
+ current_message_images = [load_image(image) for image in val[0]]
158
+ all_images += current_message_images
159
+
160
+ if val[1]:
161
+ messages.append({"role": "assistant", "content": val[1]})
162
+
163
+ imagelist = rawimagelist = current_message_images = [load_image(image) for image in files]
164
+ all_images += current_message_images
165
+ messages.append({
166
+ "role": "user",
167
+ "content": [
168
+ *[{"type": "image", "image": image} for image in current_message_images],
169
+ {"type": "text", "text": text+hint},
170
+ ],
171
+ })
172
+
173
+ print(messages)
174
+ complete_assistant_response_for_gradio = ""
175
+ while True:
176
+ """
177
+ Generate and stream text
178
+ """
179
+ prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
180
+ inputs = processor(
181
+ text=[prompt],
182
+ images=all_images if all_images else None,
183
+ return_tensors="pt",
184
+ padding=True,
185
+ ).to("cuda")
186
+
187
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=False)
188
+ generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, temperature=0.01, top_p=1.0, top_k=1)
189
+ # import pdb; pdb.set_trace()
190
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
191
+ thread.start()
192
+
193
+ # buffer = ""
194
+ # for new_text in streamer:
195
+ # buffer += new_text
196
+ # yield buffer
197
+ # print(buffer)
198
+ current_model_output_segment = "" # Text generated in this specific model call
199
+ for new_text_chunk in streamer:
200
+ current_model_output_segment += new_text_chunk
201
+ # Yield the sum of previously committed full response parts + current streaming segment
202
+ yield complete_assistant_response_for_gradio + current_model_output_segment
203
+ tmp = f"\n<b>Planning Visual Operations ...</b>\n\n"
204
+ yield complete_assistant_response_for_gradio + current_model_output_segment.split(tool_start)[0] + tmp
205
+ thread.join()
206
+
207
+ # Process the full segment (e.g., remove <|im_end|>)
208
+ processed_segment = current_model_output_segment.split("<|im_end|>", 1)[0] if "<|im_end|>" in current_model_output_segment else current_model_output_segment
209
+
210
+ # Append this processed segment to the cumulative display string for Gradio
211
+ complete_assistant_response_for_gradio += processed_segment + "\n\n"
212
+ print(f"this one: {complete_assistant_response_for_gradio}")
213
+ yield complete_assistant_response_for_gradio # Ensure the fully processed segment is yielded to Gradio
214
+
215
+
216
+ # Check for tool call in the *just generated* segment
217
+ qatext_for_tool_check = processed_segment
218
+ require_tool = tool_end in qatext_for_tool_check and tool_start in qatext_for_tool_check
219
+
220
+ if require_tool:
221
+
222
+ tool_params = parse_last_tool(qatext_for_tool_check)
223
+ tool_name = tool_params['name']
224
+ tool_args = tool_params['arguments']
225
+ complete_assistant_response_for_gradio += f"\n<b>Executing Visual Operations ...</b> @{tool_name}({tool_args})\n\n"
226
+ yield complete_assistant_response_for_gradio # Update Gradio display
227
+
228
+ video_flag = False
229
+
230
+ raw_result = execute_tool(imagelist, rawimagelist, tool_args, tool_name, is_video=video_flag)
231
+ print(raw_result)
232
+ proc_img = raw_result
233
+ all_images += [proc_img]
234
+ new_piece = dict(role='user', content=[
235
+ dict(type='text', text="\nHere is the cropped image (Image Size: {}x{}):".format(proc_img.size[0], proc_img.size[1])),
236
+ dict(type='image', image=proc_img)
237
+ ]
238
+ )
239
+ messages.append(new_piece)
240
+
241
+ complete_assistant_response_for_gradio += f"\n<b>Analyzing Operation Result ...</b> @region(size={proc_img.size[0]}x{proc_img.size[1]})\n\n"
242
+ yield complete_assistant_response_for_gradio # Update Gradio display
243
+
244
+
245
+ else:
246
+ break
247
+
248
+ with gr.Blocks() as demo:
249
+ examples = [
250
+ [{"text": example_text, "files": [example_image]}]
251
+ ]
252
+
253
+ gr.HTML(html_header)
254
+
255
+ gr.ChatInterface(
256
+ fn=model_inference,
257
+ description="# **Pixel Reasoner**",
258
+ examples=examples,
259
+ fill_height=True,
260
+ textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
261
+ stop_btn="Stop Generation",
262
+ multimodal=True,
263
+ cache_examples=False,
264
+ )
265
+
266
+ gr.Markdown(tos_markdown)
267
+ gr.Markdown(learn_more_markdown)
268
+ gr.Markdown(bibtext)
269
+
270
+ demo.launch(debug=True)