tinyllava committed on
Commit 34c6dcb
1 Parent(s): 83fe555

Delete model_utils.py

Files changed (1)
  1. model_utils.py +0 -543
model_utils.py DELETED
@@ -1,543 +0,0 @@
-
-import requests
-from PIL import Image
-import torch
-from io import BytesIO
-import base64
-import time
-import torch
-from transformers import StoppingCriteria
-
-import math
-import ast
-
-# Model Constants
-IGNORE_INDEX = -100
-IMAGE_TOKEN_INDEX = -200
-DEFAULT_IMAGE_TOKEN = "<image>"
-DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
-DEFAULT_IM_START_TOKEN = "<im_start>"
-DEFAULT_IM_END_TOKEN = "<im_end>"
-IMAGE_PLACEHOLDER = "<image-placeholder>"
-import dataclasses
-from enum import auto, Enum
-from typing import List, Tuple
-
-
-class SeparatorStyle(Enum):
-    """Different separator style."""
-    SINGLE = auto()
-    TWO = auto()
-    MPT = auto()
-    PLAIN = auto()
-    LLAMA_2 = auto()
-    TINY_LLAMA = auto()
-    QWEN_2 = auto()
-
-
-@dataclasses.dataclass
-class Conversation:
-    """A class that keeps all conversation history."""
-    system: str
-    roles: List[str]
-    messages: List[List[str]]
-    offset: int
-    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
-    sep: str = "###"
-    sep2: str = None
-    version: str = "Unknown"
-
-    skip_next: bool = False
-
-    def get_prompt(self):
-        messages = self.messages
-        if len(messages) > 0 and type(messages[0][1]) is tuple:
-            messages = self.messages.copy()
-            init_role, init_msg = messages[0].copy()
-            init_msg = init_msg[0].replace("<image>", "").strip()
-            if 'mmtag' in self.version:
-                messages[0] = (init_role, init_msg)
-                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
-                messages.insert(1, (self.roles[1], "Received."))
-            else:
-                messages[0] = (init_role, "<image>\n" + init_msg)
-
-        if self.sep_style == SeparatorStyle.SINGLE:
-            ret = self.system + self.sep
-            for role, message in messages:
-                if message:
-                    if type(message) is tuple:
-                        message, _, _ = message
-                    ret += role + ": " + message + self.sep
-                else:
-                    ret += role + ":"
-        elif self.sep_style == SeparatorStyle.TWO:
-            seps = [self.sep, self.sep2]
-            ret = self.system + seps[0]
-            for i, (role, message) in enumerate(messages):
-                if message:
-                    if type(message) is tuple:
-                        message, _, _ = message
-                    ret += role + ": " + message + seps[i % 2]
-                else:
-                    ret += role + ":"
-        elif self.sep_style == SeparatorStyle.MPT:
-            ret = self.system + self.sep
-            for role, message in messages:
-                if message:
-                    if type(message) is tuple:
-                        message, _, _ = message
-                    ret += role + message + self.sep
-                else:
-                    ret += role
-        elif self.sep_style == SeparatorStyle.LLAMA_2:
-            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
-            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
-            ret = ""
-
-            for i, (role, message) in enumerate(messages):
-                if i == 0:
-                    assert message, "first message should not be none"
-                    assert role == self.roles[0], "first message should come from user"
-                if message:
-                    if type(message) is tuple:
-                        message, _, _ = message
-                    if i == 0: message = wrap_sys(self.system) + message
-                    if i % 2 == 0:
-                        message = wrap_inst(message)
-                        ret += self.sep + message
-                    else:
-                        ret += " " + message + " " + self.sep2
-                else:
-                    ret += ""
-            ret = ret.lstrip(self.sep)
-        elif self.sep_style == SeparatorStyle.TINY_LLAMA:
-            sep = "</s>"
-            wrap_sys = lambda msg: f"<|system|>\n{msg}\n"
-            wrap_user = lambda msg: f"<|user|>\n{msg}\n"
-            wrap_assistant = lambda msg: f"<|assistant|>\n{msg}"
-            ret = ""
-
-            for i, (role, message) in enumerate(messages):
-                if i == 0:
-                    assert message, "first message should not be none"
-                    assert role == self.roles[0], "first message should come from user"
-                if message:
-                    if type(message) is tuple:
-                        message, _, _ = message
-                    if i % 2 == 0:
-                        message = wrap_user(message)
-                        if i == 0:
-                            message = wrap_sys(self.system) + message
-                        ret += self.sep + message
-                    else:
-                        message = wrap_assistant(message) + self.sep2
-                        ret += message
-                else:
-                    ret += "<|assistant|>\n"
-            ret = ret.lstrip(self.sep)
-        elif self.sep_style == SeparatorStyle.QWEN_2:
-            ret = self.system + self.sep
-            for role, message in messages:
-                if message:
-                    if type(message) is tuple:
-                        message, _, _ = message
-                    ret += role + message + self.sep
-                else:
-                    ret += role
-        elif self.sep_style == SeparatorStyle.PLAIN:
-            seps = [self.sep, self.sep2]
-            ret = self.system
-            for i, (role, message) in enumerate(messages):
-                if message:
-                    if type(message) is tuple:
-                        message, _, _ = message
-                    ret += message + seps[i % 2]
-                else:
-                    ret += ""
-        else:
-            raise ValueError(f"Invalid style: {self.sep_style}")
-
-        return ret
-
-    def append_message(self, role, message):
-        self.messages.append([role, message])
-
-    def get_images(self, return_pil=False):
-        images = []
-        for i, (role, msg) in enumerate(self.messages[self.offset:]):
-            if i % 2 == 0:
-                if type(msg) is tuple:
-                    import base64
-                    from io import BytesIO
-                    from PIL import Image
-                    msg, image, image_process_mode = msg
-                    if image_process_mode == "Pad":
-                        def expand2square(pil_img, background_color=(122, 116, 104)):
-                            width, height = pil_img.size
-                            if width == height:
-                                return pil_img
-                            elif width > height:
-                                result = Image.new(pil_img.mode, (width, width), background_color)
-                                result.paste(pil_img, (0, (width - height) // 2))
-                                return result
-                            else:
-                                result = Image.new(pil_img.mode, (height, height), background_color)
-                                result.paste(pil_img, ((height - width) // 2, 0))
-                                return result
-                        image = expand2square(image)
-                    elif image_process_mode in ["Default", "Crop"]:
-                        pass
-                    elif image_process_mode == "Resize":
-                        image = image.resize((336, 336))
-                    else:
-                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
-                    max_hw, min_hw = max(image.size), min(image.size)
-                    aspect_ratio = max_hw / min_hw
-                    max_len, min_len = 800, 400
-                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
-                    longest_edge = int(shortest_edge * aspect_ratio)
-                    W, H = image.size
-                    if longest_edge != max(image.size):
-                        if H > W:
-                            H, W = longest_edge, shortest_edge
-                        else:
-                            H, W = shortest_edge, longest_edge
-                        image = image.resize((W, H))
-                    if return_pil:
-                        images.append(image)
-                    else:
-                        buffered = BytesIO()
-                        image.save(buffered, format="PNG")
-                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
-                        images.append(img_b64_str)
-        return images
-
-    def to_gradio_chatbot(self):
-        ret = []
-        for i, (role, msg) in enumerate(self.messages[self.offset:]):
-            if i % 2 == 0:
-                if type(msg) is tuple:
-                    import base64
-                    from io import BytesIO
-                    msg, image, image_process_mode = msg
-                    max_hw, min_hw = max(image.size), min(image.size)
-                    aspect_ratio = max_hw / min_hw
-                    max_len, min_len = 800, 400
-                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
-                    longest_edge = int(shortest_edge * aspect_ratio)
-                    W, H = image.size
-                    if H > W:
-                        H, W = longest_edge, shortest_edge
-                    else:
-                        H, W = shortest_edge, longest_edge
-                    image = image.resize((W, H))
-                    buffered = BytesIO()
-                    image.save(buffered, format="JPEG")
-                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
-                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
-                    msg = img_str + msg.replace('<image>', '').strip()
-                    ret.append([msg, None])
-                else:
-                    ret.append([msg, None])
-            else:
-                ret[-1][-1] = msg
-        return ret
-
-    def copy(self):
-        return Conversation(
-            system=self.system,
-            roles=self.roles,
-            messages=[[x, y] for x, y in self.messages],
-            offset=self.offset,
-            sep_style=self.sep_style,
-            sep=self.sep,
-            sep2=self.sep2,
-            version=self.version)
-
-    def dict(self):
-        if len(self.get_images()) > 0:
-            return {
-                "system": self.system,
-                "roles": self.roles,
-                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
-                "offset": self.offset,
-                "sep": self.sep,
-                "sep2": self.sep2,
-            }
-        return {
-            "system": self.system,
-            "roles": self.roles,
-            "messages": self.messages,
-            "offset": self.offset,
-            "sep": self.sep,
-            "sep2": self.sep2,
-        }
-
-
-
-
-conv_phi_v0 = Conversation(
-    system="A chat between a curious user and an artificial intelligence assistant. "
-           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
-    roles=("USER", "ASSISTANT"),
-    version="phi",
-    messages=(),
-    offset=0,
-    sep_style=SeparatorStyle.TWO,
-    sep=" ",
-    sep2="<|endoftext|>",
-)
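
# Usage sketch (not from the deleted file): a minimal example of the template above,
# using only names defined in this module; the "<image>" placeholder is resolved
# later by tokenizer_image_token/process_images further down.
conv = conv_phi_v0.copy()
conv.append_message(conv.roles[0], "<image>\nWhat is shown in this picture?")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
# With SeparatorStyle.TWO, sep=" " and sep2="<|endoftext|>", the prompt is roughly:
#   "<system prompt> USER: <image>\nWhat is shown in this picture? ASSISTANT:"
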
-
-
-
-def select_best_resolution(original_size, possible_resolutions):
-    """
-    Selects the best resolution from a list of possible resolutions based on the original size.
-
-    Args:
-        original_size (tuple): The original size of the image in the format (width, height).
-        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
-
-    Returns:
-        tuple: The best fit resolution in the format (width, height).
-    """
-    original_width, original_height = original_size
-    best_fit = None
-    max_effective_resolution = 0
-    min_wasted_resolution = float('inf')
-
-    for width, height in possible_resolutions:
-        scale = min(width / original_width, height / original_height)
-        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
-        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
-        wasted_resolution = (width * height) - effective_resolution
-
-        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
-            max_effective_resolution = effective_resolution
-            min_wasted_resolution = wasted_resolution
-            best_fit = (width, height)
-
-    return best_fit
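
# Worked example (not from the deleted file) of the selection rule above, with two
# illustrative candidate resolutions for a 1024x768 input:
#   (672, 672):  scale 0.65625 -> keeps 672x504 -> effective 338688, wasted 112896
#   (1344, 336): scale 0.4375  -> keeps 448x336 -> effective 150528, wasted 301056
# so the square candidate wins on effective resolution.
assert select_best_resolution((1024, 768), [(672, 672), (1344, 336)]) == (672, 672)
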
-
-
-## added by llava-1.6
-def resize_and_pad_image(image, target_resolution):
-    """
-    Resize and pad an image to a target resolution while maintaining aspect ratio.
-
-    Args:
-        image (PIL.Image.Image): The input image.
-        target_resolution (tuple): The target resolution (width, height) of the image.
-
-    Returns:
-        PIL.Image.Image: The resized and padded image.
-    """
-    original_width, original_height = image.size
-    target_width, target_height = target_resolution
-
-    scale_w = target_width / original_width
-    scale_h = target_height / original_height
-
-    if scale_w < scale_h:
-        new_width = target_width
-        new_height = min(math.ceil(original_height * scale_w), target_height)
-    else:
-        new_height = target_height
-        new_width = min(math.ceil(original_width * scale_h), target_width)
-
-    # Resize the image
-    resized_image = image.resize((new_width, new_height))
-
-    new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
-    paste_x = (target_width - new_width) // 2
-    paste_y = (target_height - new_height) // 2
-    new_image.paste(resized_image, (paste_x, paste_y))
-
-    return new_image
-
-
-## added by llava-1.6
-def divide_to_patches(image, patch_size):
-    """
-    Divides an image into patches of a specified size.
-
-    Args:
-        image (PIL.Image.Image): The input image.
-        patch_size (int): The size of each patch.
-
-    Returns:
-        list: A list of PIL.Image.Image objects representing the patches.
-    """
-    patches = []
-    width, height = image.size
-    for i in range(0, height, patch_size):
-        for j in range(0, width, patch_size):
-            box = (j, i, j + patch_size, i + patch_size)
-            patch = image.crop(box)
-            patches.append(patch)
-
-    return patches
-
-
-## added by llava-1.6
-def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
-    """
-    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
-
-    Args:
-        image_size (tuple): The size of the input image in the format (width, height).
-        grid_pinpoints (str): A string representation of a list of possible resolutions.
-        patch_size (int): The size of each image patch.
-
-    Returns:
-        tuple: The shape of the image patch grid in the format (width, height).
-    """
-    if type(grid_pinpoints) is list:
-        possible_resolutions = grid_pinpoints
-    else:
-        possible_resolutions = ast.literal_eval(grid_pinpoints)
-    width, height = select_best_resolution(image_size, possible_resolutions)
-    return width // patch_size, height // patch_size
-
-
-## added by llava-1.6
-def process_anyres_image(image, processor, grid_pinpoints):
-    """
-    Process an image with variable resolutions.
-
-    Args:
-        image (PIL.Image.Image): The input image to be processed.
-        processor: The image processor object.
-        grid_pinpoints (str): A string representation of a list of possible resolutions.
-
-    Returns:
-        torch.Tensor: A tensor containing the processed image patches.
-    """
-    if type(grid_pinpoints) is list:
-        possible_resolutions = grid_pinpoints
-    else:
-        possible_resolutions = ast.literal_eval(grid_pinpoints)
-    best_resolution = select_best_resolution(image.size, possible_resolutions)
-    image_padded = resize_and_pad_image(image, best_resolution)
-
-    patches = divide_to_patches(image_padded, processor.crop_size['height'])
-
-    image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
-
-    image_patches = [image_original_resize] + patches
-    image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
-                     for image_patch in image_patches]
-    return torch.stack(image_patches, dim=0)
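
# Hedged sketch (not from the deleted file) of the anyres path above. It assumes a
# CLIP-style processor whose crop_size['height'] and size['shortest_edge'] are both
# 336 (e.g. CLIP ViT-L/14-336); `pil_image` and `clip_processor` are placeholder
# names, and the pinpoint list is illustrative.
pinpoints = "[(336, 672), (672, 336), (672, 672), (1008, 336), (336, 1008)]"
grid = get_anyres_image_grid_shape((1024, 768), pinpoints, 336)   # -> (2, 2)
tiles = process_anyres_image(pil_image, clip_processor, pinpoints)
# For a 1024x768 `pil_image` the best fit is (672, 672): the padded image is cut
# into four 336x336 tiles stacked after one globally resized view, so
# tiles.shape == torch.Size([5, 3, 336, 336]).
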
-
-
-def load_image_from_base64(image):
-    return Image.open(BytesIO(base64.b64decode(image)))
-
-
-def expand2square(pil_img, background_color):
-    width, height = pil_img.size
-    if width == height:
-        return pil_img
-    elif width > height:
-        result = Image.new(pil_img.mode, (width, width), background_color)
-        result.paste(pil_img, (0, (width - height) // 2))
-        return result
-    else:
-        result = Image.new(pil_img.mode, (height, height), background_color)
-        result.paste(pil_img, ((height - width) // 2, 0))
-        return result
-
-
-def process_images(images, image_processor, model_cfg):
-    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
-    new_images = []
-    if image_aspect_ratio == 'pad':
-        for image in images:
-            image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
-            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
-            new_images.append(image)
-    elif image_aspect_ratio == "anyres":
-        for image in images:
-            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
-            new_images.append(image)
-    else:
-        return image_processor(images, return_tensors='pt')['pixel_values']
-    if all(x.shape == new_images[0].shape for x in new_images):
-        new_images = torch.stack(new_images, dim=0)
-    return new_images
-
-
-def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
-    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
-
-    def insert_separator(X, sep):
-        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
-
-    input_ids = []
-    offset = 0
-    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
-        offset = 1
-        input_ids.append(prompt_chunks[0][0])
-
-    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
-        input_ids.extend(x[offset:])
-
-    if return_tensors is not None:
-        if return_tensors == 'pt':
-            return torch.tensor(input_ids, dtype=torch.long)
-        raise ValueError(f'Unsupported tensor type: {return_tensors}')
-    return input_ids
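
# Hedged sketch (not from the deleted file): `tokenizer` stands for any Hugging Face
# tokenizer; the exact ids are tokenizer-specific.
ids = tokenizer_image_token("USER: <image>\nDescribe the image. ASSISTANT:",
                            tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
# The prompt is split on "<image>", each chunk is tokenized, and a single
# IMAGE_TOKEN_INDEX (-200) sentinel is spliced between chunks (after the BOS token,
# if the tokenizer emits one); the model later swaps that sentinel for the projected
# vision features.
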
-
-
-def get_model_name_from_path(model_path):
-    model_path = model_path.strip("/")
-    model_paths = model_path.split("/")
-    if model_paths[-1].startswith('checkpoint-'):
-        return model_paths[-2] + "_" + model_paths[-1]
-    else:
-        return model_paths[-1]
-
-
-class KeywordsStoppingCriteria(StoppingCriteria):
-    def __init__(self, keywords, tokenizer, input_ids):
-        self.keywords = keywords
-        self.keyword_ids = []
-        self.max_keyword_len = 0
-        for keyword in keywords:
-            cur_keyword_ids = tokenizer(keyword).input_ids
-            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
-                cur_keyword_ids = cur_keyword_ids[1:]
-            if len(cur_keyword_ids) > self.max_keyword_len:
-                self.max_keyword_len = len(cur_keyword_ids)
-            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
-        self.tokenizer = tokenizer
-        self.start_len = input_ids.shape[1]
-
-    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
-        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
-        for keyword_id in self.keyword_ids:
-            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
-                return True
-        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
-        for keyword in self.keywords:
-            if keyword in outputs:
-                return True
-        return False
-
-    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        outputs = []
-        for i in range(output_ids.shape[0]):
-            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
-        return all(outputs)
-
-
-
-def load_image(image_file):
-    if image_file.startswith("http") or image_file.startswith("https"):
-        response = requests.get(image_file)
-        image = Image.open(BytesIO(response.content)).convert("RGB")
-    else:
-        image = Image.open(image_file).convert("RGB")
-    return image
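
Taken together, the deleted helpers cover the usual TinyLLaVA inference path: build a prompt from the conversation template, splice the image sentinel into the token ids, preprocess the image, and stop generation on the separator. The sketch below is not part of the commit; it assumes the module is importable as model_utils (its name before this deletion), that `model`, `tokenizer`, and `image_processor` come from the project's own loading code, and that `generate()` accepts an `images=` keyword in the LLaVA style, which is an assumption here.

import torch
from model_utils import (IMAGE_TOKEN_INDEX, KeywordsStoppingCriteria, conv_phi_v0,
                         load_image, process_images, tokenizer_image_token)

# Build the USER/ASSISTANT prompt with the "<image>" placeholder.
conv = conv_phi_v0.copy()
conv.append_message(conv.roles[0], "<image>\nWhat is unusual about this image?")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

# Fetch and preprocess the image; the URL is only illustrative.
image = load_image("https://example.com/view.jpg")
image_tensor = process_images([image], image_processor, model.config)

# Replace "<image>" with the -200 sentinel and add a batch dimension.
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX,
                                  return_tensors='pt').unsqueeze(0)

# Stop once the TWO-style separator ("<|endoftext|>") appears in the output.
stopping = KeywordsStoppingCriteria([conv.sep2], tokenizer, input_ids)
with torch.inference_mode():
    output_ids = model.generate(input_ids, images=image_tensor,
                                max_new_tokens=256, stopping_criteria=[stopping])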