mx262 committed on
Commit
44b62eb
1 Parent(s): cce75f9

Upload internvl_chat.py

Files changed (1)
  1. internvl_chat.py +277 -0
internvl_chat.py ADDED
@@ -0,0 +1,277 @@
import torch
from transformers import AutoTokenizer, AutoModel, CLIPImageProcessor
import warnings
from PIL import Image
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE
import pandas as pd
import string
import torchvision.transforms as T
import transformers

from torchvision.transforms.functional import InterpolationMode
import random

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

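# Standard ImageNet-normalized preprocessing at a fixed square input size.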
def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

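# Pick the (columns, rows) tiling whose aspect ratio is closest to the image's;
# on a tie, prefer the larger grid when the image area exceeds half of that
# grid's pixel budget.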
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

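# First tiling pass: resize the image to the best-matching grid of square tiles,
# crop the tiles, optionally append a global thumbnail, and return the chosen grid.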
def dynamic_preprocess(image, min_num=5, max_num=6, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images, target_aspect_ratio

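# Second tiling pass: same cropping scheme, but candidate grids that evenly
# divide the first-pass grid in both dimensions are filtered out, so these
# crops complement the first pass rather than duplicating it.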
def dynamic_preprocess2(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, prior_aspect_ratio=None):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    new_target_ratios = []
    if prior_aspect_ratio is not None:
        for i in target_ratios:
            if prior_aspect_ratio[0] % i[0] != 0 or prior_aspect_ratio[1] % i[1] != 0:
                new_target_ratios.append(i)
            else:
                continue
    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, new_target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

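# Load an image from disk and run the first tiling pass; returns the stacked,
# normalized tile tensors and the grid that was chosen.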
def load_image(image_file, input_size=448, min_num=1, max_num=6):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images, target_aspect_ratio = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values, target_aspect_ratio

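# Same as load_image, but runs the second tiling pass conditioned on the grid
# chosen by the first pass.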
def load_image2(image_file, input_size=448, target_aspect_ratio=(1, 1), min_num=1, max_num=6):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess2(image, image_size=input_size, prior_aspect_ratio=target_aspect_ratio, use_thumbnail=True, min_num=min_num, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values

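# BaseModel wrapper: builds dataset-specific prompts and feeds the model a
# concatenation of tiles produced by the two preprocessing passes above.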
class InternVLChat(BaseModel):

    INSTALL_REQ = False
    INTERLEAVE = False

    def __init__(self, model_path='OpenGVLab/InternVL-Chat-V1-5', load_in_8bit=False, **kwargs):
        assert model_path is not None
        assert version_cmp(transformers.__version__, '4.36.2', 'ge')
        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
        device = torch.cuda.current_device()
        self.device = device
        self.model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16,
                                               trust_remote_code=True,
                                               load_in_8bit=load_in_8bit).eval()
        if not load_in_8bit:
            self.model = self.model.to(device)
        self.image_size = self.model.config.vision_config.image_size

        if 'V1-1' in model_path:
            kwargs_default = dict(do_sample=False, max_new_tokens=1024, top_p=None, num_beams=5)
        else:
            kwargs_default = dict(do_sample=False, max_new_tokens=1024, top_p=None, num_beams=1)
        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default
        warnings.warn(f'Received kwargs: {self.kwargs}; they will be used as the generation config.')

    def use_custom_prompt(self, dataset):
        return True

    def build_multi_choice_prompt(self, line, dataset=None):
        question = line['question']
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        if hint is not None:
            question = hint + '\n' + question

        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        for key, item in options.items():
            question += f'\n{key}. {item}'
        prompt = question

        if len(options):
            prompt += '\n请直接回答选项字母。' if cn_string(
                prompt) else "\nAnswer with the option's letter from the given choices directly."
        else:
            prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'

        return prompt

    def build_prompt(self, line, dataset=None):
        assert self.use_custom_prompt(dataset)
        assert dataset is None or isinstance(dataset, str)
        tgt_path = self.dump_image(line, dataset)

        if 'V1-1' in self.model_path:
            kwargs_default = dict(do_sample=False, max_new_tokens=1024, top_p=None, num_beams=5)
        else:
            kwargs_default = dict(do_sample=False, max_new_tokens=1024, top_p=None, num_beams=1)
        self.kwargs = kwargs_default
        if dataset is not None and listinstr(['MME'], dataset):
            question = line['question']
            prompt = question + ' Answer the question using a single word or phrase.'
            if 'V1-2' not in self.model_path:
                self.kwargs = dict(do_sample=True, max_new_tokens=5, top_k=50, num_beams=5, top_p=0.9)
        elif dataset is not None and listinstr(['HallusionBench'], dataset):
            question = line['question']
            prompt = question + ' Please answer yes or no. Answer the question using a single word or phrase.'
        elif dataset is not None and DATASET_TYPE(dataset) == 'multi-choice':
            prompt = self.build_multi_choice_prompt(line, dataset)
        elif dataset is not None and DATASET_TYPE(dataset) == 'VQA':
            if 'MathVista' in dataset:
                prompt = line['question']
            elif listinstr(['LLaVABench'], dataset):
                question = line['question']
                prompt = question + '\nAnswer this question in detail.'
            elif listinstr(['MMVet'], dataset):
                prompt = line['question']
            else:
                question = line['question']
                prompt = question + '\nAnswer the question using a single word or phrase.'
        else:
            prompt = line['question']

        message = [dict(type='text', value=prompt)]
        message.extend([dict(type='image', value=s) for s in tgt_path])

        return message

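    # Per-dataset tile budgets: min_num/max_num control the first tiling pass,
    # min_num2/max_num2 the second; the resulting tiles are concatenated and
    # sent through model.chat in a single call.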
    def generate(self, message, dataset=None):
        prompt, image_path = self.message_to_promptimg(message)
        if dataset is not None and listinstr(['ChartQA_TEST'], dataset):
            self.max_num = 12
            self.max_num2 = 3
            # min_num/min_num2 are not specified for this dataset; 1 is an
            # assumed default so load_image below does not hit an unset attribute.
            self.min_num = 1
            self.min_num2 = 1
        elif dataset is not None and listinstr(['DocVQA_VAL', 'DocVQA_TEST', 'TextVQA_VAL'], dataset):
            self.max_num = 23
            self.max_num2 = 15
            self.min_num = 14
            self.min_num2 = 5
        elif dataset is not None and listinstr(['InfoVQA_VAL', 'InfoVQA_TEST'], dataset):
            self.max_num = 23
            self.max_num2 = 5
            self.min_num = 15
            self.min_num2 = 3
        elif dataset is not None and listinstr(['OCRBench'], dataset):
            self.max_num = 24
            self.max_num2 = 8
            self.min_num = 9
            self.min_num2 = 5
        else:
            self.max_num = 8
            self.max_num2 = 4
            self.min_num = 3
            self.min_num2 = 1
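        # Run both tiling passes, drop the trailing thumbnail tile from each, and
        # keep a single first-pass thumbnail at the end of the tile sequence.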
        pixel_values, target_aspect_ratio = load_image(image_path, min_num=self.min_num, max_num=self.max_num)
        pixel_values = pixel_values.cuda().to(torch.bfloat16)
        pixel_values2 = load_image2(image_path, target_aspect_ratio=target_aspect_ratio, min_num=self.min_num2, max_num=self.max_num2)
        pixel_values2 = pixel_values2.cuda().to(torch.bfloat16)
        pixel_values = torch.cat((pixel_values[:-1], pixel_values2[:-1], pixel_values[-1:]), 0)

        with torch.no_grad():
            response = self.model.chat(self.tokenizer, pixel_values=pixel_values, target_aspect_ratio=target_aspect_ratio,
                                       question=prompt, generation_config=self.kwargs)
        response = response.split('[UNUSED_TOKEN_145]')[0]

        return response

    def generate_inner(self, message, dataset=None):
        return self.generate(message, dataset)
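
# --- Hypothetical usage sketch ------------------------------------------------
# Assumes this module lives inside a VLMEvalKit-style package so that the
# relative imports above resolve, and that 'demo.jpg' is a local image path
# (both the model name and the file are placeholders, not values from this file):
#
#     model = InternVLChat(model_path='OpenGVLab/InternVL-Chat-V1-5')
#     message = [
#         dict(type='text', value='Describe this image.'),
#         dict(type='image', value='demo.jpg'),
#     ]
#     print(model.generate_inner(message))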