jadechoghari commited on
Commit
744d366
·
verified ·
1 Parent(s): f231447

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +80 -322
inference.py CHANGED
@@ -1,340 +1,98 @@
1
- import torch
 
 
2
  from PIL import Image
3
- from conversation import conv_templates
4
- from builder import load_pretrained_model
5
- from functools import partial
6
- from typing import Optional, Callable
7
- import ast
8
- import math
9
- import numpy as np
10
- DEFAULT_REGION_FEA_TOKEN = "<region_fea>"
11
- DEFAULT_IMAGE_TOKEN = "<image>"
12
- DEFAULT_IM_START_TOKEN = "<im_start>"
13
- DEFAULT_IM_END_TOKEN = "<im_end>"
14
- VOCAB_IMAGE_W = 1000 # 224
15
- VOCAB_IMAGE_H = 1000 # 224
16
- IMAGE_TOKEN_INDEX = -200
17
 
18
-
19
- # define the task categories
20
- box_in_tasks = ['widgetcaptions', 'taperception', 'ocr', 'icon_recognition', 'widget_classification', 'example_0']
21
- box_out_tasks = ['widget_listing', 'find_text', 'find_icons', 'find_widget', 'conversation_interaction']
22
- no_box_tasks = ['screen2words', 'detailed_description', 'conversation_perception', 'gpt4']
23
-
24
- def get_bbox_coor(box, ratio_w, ratio_h):
25
- return box[0] * ratio_w, box[1] * ratio_h, box[2] * ratio_w, box[3] * ratio_h
26
-
27
- def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
28
- if '<image>' in prompt:
29
- prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
30
- input_ids = []
31
- for i, chunk in enumerate(prompt_chunks):
32
- input_ids.extend(chunk)
33
- if i < len(prompt_chunks) - 1:
34
- input_ids.append(image_token_index)
35
- else:
36
- input_ids = tokenizer(prompt).input_ids
37
- # if return_tensors == 'pt':
38
- # import torch
39
- # input_ids = torch.tensor(input_ids).unsqueeze(0)
40
-
41
- return input_ids
42
-
43
-
44
- def expand2square(pil_img, background_color):
45
- width, height = pil_img.size
46
- if width == height:
47
- return pil_img
48
- elif width > height:
49
- result = Image.new(pil_img.mode, (width, width), background_color)
50
- result.paste(pil_img, (0, (width - height) // 2))
51
- return result
52
- else:
53
- result = Image.new(pil_img.mode, (height, height), background_color)
54
- result.paste(pil_img, ((height - width) // 2, 0))
55
- return result
56
-
57
- def select_best_resolution(original_size, possible_resolutions):
58
- """
59
- Selects the best resolution from a list of possible resolutions based on the original size.
60
-
61
- Args:
62
- original_size (tuple): The original size of the image in the format (width, height).
63
- possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
64
-
65
- Returns:
66
- tuple: The best fit resolution in the format (width, height).
67
- """
68
- original_width, original_height = original_size
69
- best_fit = None
70
- max_effective_resolution = 0
71
- min_wasted_resolution = float('inf')
72
-
73
- for width, height in possible_resolutions:
74
- scale = min(width / original_width, height / original_height)
75
- downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
76
- effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
77
- wasted_resolution = (width * height) - effective_resolution
78
-
79
- if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
80
- max_effective_resolution = effective_resolution
81
- min_wasted_resolution = wasted_resolution
82
- best_fit = (width, height)
83
-
84
- return best_fit
85
-
86
- def divide_to_patches(image, patch_size):
87
  """
88
- Divides an image into patches of a specified size.
89
-
90
- Args:
91
- image (PIL.Image.Image): The input image.
92
- patch_size (int): The size of each patch.
93
-
94
- Returns:
95
- list: A list of PIL.Image.Image objects representing the patches.
96
  """
97
- patches = []
98
- width, height = image.size
99
- for i in range(0, height, patch_size):
100
- for j in range(0, width, patch_size):
101
- box = (j, i, j + patch_size, i + patch_size)
102
- patch = image.crop(box)
103
- patches.append(patch)
104
 
105
- return patches
106
- def resize_and_pad_image(image, target_resolution, is_pad=False):
107
- """
108
- Resize and pad an image to a target resolution while maintaining aspect ratio.
109
- Args:
110
- image (PIL.Image.Image): The input image.
111
- target_resolution (tuple): The target resolution (width, height) of the image.
112
- Returns:
113
- PIL.Image.Image: The resized and padded image.
114
- """
115
- original_width, original_height = image.size
116
- target_width, target_height = target_resolution
117
 
118
- if is_pad:
119
- scale_w = target_width / original_width
120
- scale_h = target_height / original_height
 
 
121
 
122
- if scale_w < scale_h:
123
- new_width = target_width
124
- new_height = min(math.ceil(original_height * scale_w), target_height)
125
- else:
126
- new_height = target_height
127
- new_width = min(math.ceil(original_width * scale_h), target_width)
128
 
129
- # Resize the image
130
- resized_image = image.resize((new_width, new_height))
131
 
132
- new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
133
- paste_x = (target_width - new_width) // 2
134
- paste_y = (target_height - new_height) // 2
135
- new_image.paste(resized_image, (paste_x, paste_y))
136
- else:
137
- new_image = image.resize((target_width, target_height))
138
 
139
- return new_image
140
 
141
- def process_anyres_image(image, processor, grid_pinpoints, image_process_func: Optional[Callable] = None):
142
  """
143
- Process an image with variable resolutions.
144
-
145
- Args:
146
- image (PIL.Image.Image): The input image to be processed.
147
- processor: The image processor object.
148
- grid_pinpoints (str): A string representation of a list of possible resolutions.
149
-
150
- Returns:
151
- torch.Tensor: A tensor containing the processed image patches.
152
  """
153
- if type(grid_pinpoints) is list:
154
- possible_resolutions = grid_pinpoints
155
- else:
156
- possible_resolutions = ast.literal_eval(grid_pinpoints)
 
 
 
157
 
158
- best_resolution = select_best_resolution(image.size, possible_resolutions)
159
-
160
- # FIXME: not sure if do_pad or undo_pad may affect the referring side
161
- image_padded = resize_and_pad_image(image, best_resolution, is_pad=False)
162
-
163
- patches = divide_to_patches(image_padded, processor.crop_size['height'])
164
-
165
- if image_process_func:
166
- resized_image_h, resized_image_w = image_process_func.keywords['size']
167
- image_original_resize = image.resize((resized_image_w, resized_image_h))
168
- image_patches = [image_original_resize] + patches
169
- image_patches = [image_process_func(image_patch)['pixel_values'][0]
170
- for image_patch in image_patches]
171
- else:
172
- image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
173
- image_patches = [image_original_resize] + patches
174
- image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
175
- for image_patch in image_patches]
176
-
177
- return torch.stack(image_patches, dim=0)
178
-
179
-
180
- def process_images(images, image_processor, model_cfg, image_process_func: Optional[Callable] = None):
181
- image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
182
- new_images = []
183
- if image_aspect_ratio == 'pad':
184
- for image in images:
185
- image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
186
- image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
187
- new_images.append(image)
188
- elif image_aspect_ratio == "anyres":
189
- # image_processor(images, return_tensors='pt', do_resize=True, do_center_crop=False, size=[image_h, image_w])['pixel_values']
190
- for image in images:
191
- image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints, image_process_func=image_process_func)
192
- new_images.append(image)
193
- else:
194
- return image_processor(images, return_tensors='pt')['pixel_values']
195
- if all(x.shape == new_images[0].shape for x in new_images):
196
- new_images = torch.stack(new_images, dim=0)
197
- return new_images
198
- # function to generate the mask
199
- def generate_mask_for_feature(coor, raw_w, raw_h, mask=None):
200
- """
201
- Generates a region mask based on provided coordinates.
202
- Handles both point and box input.
203
- """
204
- if mask is not None:
205
- assert mask.shape[0] == raw_w and mask.shape[1] == raw_h
206
- coor_mask = np.zeros((raw_w, raw_h))
207
-
208
- # if it's a point (2 coordinates)
209
- if len(coor) == 2:
210
- span = 5 # Define the span for the point
211
- x_min = max(0, coor[0] - span)
212
- x_max = min(raw_w, coor[0] + span + 1)
213
- y_min = max(0, coor[1] - span)
214
- y_max = min(raw_h, coor[1] + span + 1)
215
- coor_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1
216
- assert (coor_mask == 1).any(), f"coor: {coor}, raw_w: {raw_w}, raw_h: {raw_h}"
217
-
218
- # if it's a box (4 coordinates)
219
- elif len(coor) == 4:
220
- coor_mask[coor[0]:coor[2]+1, coor[1]:coor[3]+1] = 1
221
- if mask is not None:
222
- coor_mask = coor_mask * mask
223
-
224
- # convert to torch tensor and ensure it contains non-zero values
225
- coor_mask = torch.from_numpy(coor_mask)
226
- assert len(coor_mask.nonzero()) != 0, "Generated mask is empty :("
227
-
228
-
229
- return coor_mask
230
-
231
-
232
- def infer_single_prompt(image_path, prompt, model_path, region=None, model_name="ferret_gemma", conv_mode="ferret_gemma_instruct", add_region_feature=False):
233
- img = Image.open(image_path).convert('RGB')
234
-
235
- # this loads the model, image processor and tokenizer
236
- tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)
237
- # define the image size required by clip
238
- image_size = {"height": 336, "width": 336}
239
-
240
- if "<image>" in prompt:
241
- prompt = prompt.split('\n')[1]
242
-
243
- if model.config.mm_use_im_start_end:
244
- prompt = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + prompt
245
- else:
246
- prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
247
-
248
- # generate the prompt per template requirement
249
- conv = conv_templates[conv_mode].copy()
250
- conv.append_message(conv.roles[0], prompt)
251
- conv.append_message(conv.roles[1], None)
252
- prompt_input = conv.get_prompt()
253
-
254
- input_ids = tokenizer(prompt_input, return_tensors='pt')['input_ids'].cuda()
255
-
256
- # raw_w, raw_h = img.size # check if shouldnt be width and height
257
- raw_w = image_size["width"]
258
- raw_h = image_size["height"]
259
- if model.config.image_aspect_ratio == "square_nocrop":
260
- image_tensor = image_processor.preprocess(img, return_tensors='pt', do_resize=True,
261
- do_center_crop=False, size=[raw_h, raw_w])['pixel_values'][0]
262
- elif model.config.image_aspect_ratio == "anyres":
263
- image_process_func = partial(image_processor.preprocess, return_tensors='pt', do_resize=True, do_center_crop=False, size=[raw_h, raw_h])
264
- image_tensor = process_images([img], image_processor, model.config, image_process_func=image_process_func)[0]
265
- else:
266
- image_tensor = process_images([img], image_processor, model.config)[0]
267
-
268
- images = image_tensor.unsqueeze(0).to(torch.float16).cuda()
269
 
 
 
270
 
271
-
272
- # region mask logic (if region is provided)
273
- region_masks = None
274
- if add_region_feature and region is not None:
275
- # box_in is true
276
- raw_w, raw_h = img.size
277
- ratio_w = VOCAB_IMAGE_W * 1.0 / raw_w
278
- ratio_h = VOCAB_IMAGE_H * 1.0 / raw_h
279
- # preprocess the region
280
- box_x1, box_y1, box_x2, box_y2 = region
281
- box_x1_textvocab, box_y1_textvocab, box_x2_textvocab, box_y2_textvocab = get_bbox_coor(box=region, ratio_h=ratio_h, ratio_w=ratio_w)
282
- region_coordinate_raw = [box_x1, box_y1, box_x2, box_y2]
283
-
284
- region_masks = generate_mask_for_feature(region_coordinate_raw, raw_w, raw_h).unsqueeze(0).cuda().half()
285
- region_masks = [[region_mask_i.cuda().half() for region_mask_i in region_masks]]
286
- prompt_input = prompt_input.replace("<bbox_location0>", f"[{box_x1_textvocab}, {box_y1_textvocab}, {box_x2_textvocab}, {box_y2_textvocab}] {DEFAULT_REGION_FEA_TOKEN}")
287
-
288
- # tokenize prompt
289
- # input_ids = tokenizer(prompt_input, return_tensors='pt')['input_ids'].cuda()
290
-
291
 
292
-
293
- # generate model output
294
- with torch.inference_mode():
295
- # Use region_masks in model's forward call
296
- model.orig_forward = model.forward
297
- model.forward = partial(
298
- model.orig_forward,
299
- region_masks=region_masks
300
- )
301
- # explcit add of attention mask
302
- output_ids = model.generate(
303
- input_ids,
304
- images=images,
305
- max_new_tokens=1024,
306
- num_beams=1,
307
- region_masks=region_masks, # pass the region mask to the model
308
- image_sizes=[img.size]
309
- )
310
- model.forward = model.orig_forward
311
-
312
- # we decode the output
313
- output_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
314
- return output_text.strip()
 
 
 
 
 
 
 
 
 
 
 
315
 
316
- # We also define a task-specific inference function
317
- def infer_ui_task(image_path, prompt, model_path, task, region=None, add_region_feature=False):
318
- # region = torch.tensor(region).cuda()
319
- """
320
- Handles task types: box_in_tasks, box_out_tasks, no_box_tasks.
321
- """
322
- if region is not None:
323
- add_region_feature=True
324
- if task in box_in_tasks and region is None:
325
- raise ValueError(f"Task {task} requires a bounding box region.")
326
-
327
- if task in box_in_tasks:
328
- print(f"Processing {task} with bounding box region.")
329
- return infer_single_prompt(image_path, prompt, model_path, region, add_region_feature=add_region_feature)
330
-
331
- elif task in box_out_tasks:
332
- print(f"Processing {task} without bounding box region.")
333
- return infer_single_prompt(image_path, prompt, model_path)
334
-
335
- elif task in no_box_tasks:
336
- print(f"Processing {task} without image or bounding box.")
337
- return infer_single_prompt(image_path, prompt, model_path)
338
-
339
- else:
340
- raise ValueError(f"Unknown task type: {task}")
 
1
+ import subprocess
2
+ import os
3
+ import subprocess
4
  from PIL import Image
5
+ import re
6
+ import json
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ def process_inference_results(results):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  """
10
+ Process the inference results by:
11
+ 1. Adding bounding boxes on the image based on the coordinates in 'text'.
12
+ 2. Extracting and returning the text prompt.
13
+
14
+ :param results: List of inference results with bounding boxes in 'text'.
15
+ :return: (image, text)
 
 
16
  """
17
+ processed_images = []
18
+ extracted_texts = []
 
 
 
 
 
19
 
20
+ for result in results:
21
+ image_path = result['image_path']
22
+ img = Image.open(image_path).convert("RGB")
 
 
 
 
 
 
 
 
 
23
 
24
+ # this no more than extracts bounding box coordinates from the 'text'
25
+ bbox_str = re.search(r'\[\[([0-9,\s]+)\]\]', result['text'])
26
+ if bbox_str:
27
+ bbox = [int(coord) for coord in bbox_str.group(1).split(',')]
28
+ x1, y1, x2, y2 = bbox
29
 
30
+ # Draw the bounding box on the image (optional if needed later)
31
+ # draw = ImageDraw.Draw(img)
32
+ # draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
 
 
 
33
 
34
+ extracted_texts.append(result['text'])
 
35
 
36
+ processed_images.append(img)
 
 
 
 
 
37
 
38
+ return processed_images[0], extracted_texts[0]
39
 
40
+ def inference_and_run(image_path, prompt, conv_mode="ferret_gemma_instruct", model_path="jadechoghari/Ferret-UI-Gemma2b", box=None):
41
  """
42
+ Run the inference and capture the errors for debugging.
 
 
 
 
 
 
 
 
43
  """
44
+ data_input = [{
45
+ "id": 0,
46
+ "image": os.path.basename(image_path),
47
+ "image_h": Image.open(image_path).height,
48
+ "image_w": Image.open(image_path).width,
49
+ "conversations": [{"from": "human", "value": f"<image>\n{prompt}"}]
50
+ }]
51
 
52
+ if box:
53
+ data_input[0]["box_x1y1x2y2"] = [[box]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
+ with open("eval.json", "w") as json_file:
56
+ json.dump(data_input, json_file)
57
 
58
+ print("eval.json file created successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ cmd = [
61
+ "python", "-m", "model_UI",
62
+ "--model_path", model_path,
63
+ "--data_path", "eval.json",
64
+ "--image_path", ".",
65
+ "--answers_file", "eval_output.jsonl",
66
+ "--num_beam", "1",
67
+ "--max_new_tokens", "1024",
68
+ "--conv_mode", conv_mode
69
+ ]
70
+
71
+ if box:
72
+ cmd.extend(["--region_format", "box", "--add_region_feature"])
73
+
74
+ result = subprocess.run(cmd, check=True, capture_output=True, text=True)
75
+ print(f"Subprocess output:\n{result.stdout}")
76
+ print(f"Subprocess error (if any):\n{result.stderr}")
77
+ print(f"Inference completed. Output written to eval_output.jsonl")
78
+
79
+ output_folder = 'eval_output.jsonl'
80
+ if os.path.exists(output_folder):
81
+ json_files = [f for f in os.listdir(output_folder) if f.endswith(".jsonl")]
82
+ if json_files:
83
+ output_file_path = os.path.join(output_folder, json_files[0])
84
+ with open(output_file_path, "r") as output_file:
85
+ results = [json.loads(line) for line in output_file]
86
+
87
+ return process_inference_results(results)
88
+ else:
89
+ print("No output JSONL files found.")
90
+ return None, None
91
+ else:
92
+ print("Output folder not found.")
93
+ return None, None
94
 
95
+ except subprocess.CalledProcessError as e:
96
+ print(f"Error occurred during inference:\n{e}")
97
+ print(f"Subprocess output:\n{e.output}")
98
+ return None, None