Text Generation
Transformers
Safetensors
English
llava_phi
custom_code
g-h-chen committed
Commit 9e4a5ec
1 Parent(s): 352d2e9

upload generation_utils.py

Files changed (1)
  1. generation_utils.py +287 -0
generation_utils.py ADDED
@@ -0,0 +1,287 @@
+ from typing import List
+ from queue import Queue
+
+ import torch
+ from PIL import Image
+ from copy import deepcopy
+ import requests
+ import os
+
+ IMAGE_TOKEN_INDEX = -200
+ blacklist = ['<image>', '<s>', '</s>']
+ max_num_images = 3  # phi has a context length limit of 2048 and each image occupies 576 tokens.
+
+ def input_moderation(texts: List[List[str]]):
+     # perform input moderation on each message: strip blacklisted special tokens
+     for text_pair in texts:
+         # in-place operation on the [user, assistant] pair
+         for b in blacklist:
+             text_pair[0] = text_pair[0].replace(b, '')
+             if text_pair[1] is not None:
+                 text_pair[1] = text_pair[1].replace(b, '')
+
+     return texts
+
+ def insert_image_placeholder(t, num_images, placeholder='<image>', sep='\n'):
+     for _ in range(num_images):
+         t = f"{placeholder}{sep}" + t
+     return t
+
+ def get_conv(texts):
+     ret = []
+
+     for conv in texts:
+         ret.append({'from': 'human', 'value': conv[0]})
+         ret.append({'from': 'gpt', 'value': conv[1]})  # this is None for the last one
+
+     return ret
+
+ # copied from llava
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+     # tokenize the prompt piecewise around '<image>' and splice in the image token index
+     prompt_chunks = [tokenizer(chunk, add_special_tokens=False).input_ids for chunk in prompt.split('<image>')]
+
+     def insert_separator(X, sep):
+         return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
+
+     input_ids = []
+     offset = 0
+     if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+         # keep the BOS token once at the very beginning
+         offset = 1
+         input_ids.append(prompt_chunks[0][0])
+
+     for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+         input_ids.extend(x[offset:])
+
+     if return_tensors is not None:
+         if return_tensors == 'pt':
+             return torch.tensor(input_ids, dtype=torch.long)
+         raise ValueError(f'Unsupported tensor type: {return_tensors}')
+     return input_ids
+
+ def preprocess(tokenizer, data: list, return_tensors='pt'):
+     '''
+     Expects a conversation in the following format:
+     [
+         {
+             'from': 'human',
+             'value': xxx,
+         },
+         {
+             'from': 'gpt',
+             'value': xxx
+         }
+     ]
+     '''
+     # needs update
+     if not isinstance(data, list):
+         raise ValueError('must be a list')
+
+     # this is per model (tokenizer)
+     return preprocess_allava(tokenizer, data, return_tensors=return_tensors)
+
+
+ def preprocess_vicuna_v1(self, convs: list, return_tensors) -> list:  # tokenize and concat the conversations
+     # NOTE: unused Vicuna-v1 variant kept from the original class-based implementation;
+     # it still references `self` and is not called anywhere in this file.
+     input_ids = None
+     for ind, conv in enumerate(convs):
+         if ind % 2 == 0:  # human
+             h = conv['value'].strip()
+             h = f"USER: {h} "
+             cur_input_ids = self.tokenizer_image_token(prompt=h, return_tensors=return_tensors)
+
+             if input_ids is None:
+                 input_ids = cur_input_ids
+             else:
+                 input_ids = torch.cat([input_ids, cur_input_ids])
+
+         else:  # gpt
+             g = conv['value']
+             if g is not None:
+                 cur_input_ids = self.tokenizer(f"ASSISTANT: {g}</s>", add_special_tokens=False, max_length=self.maxlen, truncation=True, return_tensors='pt').input_ids[0]
+                 input_ids = torch.cat([input_ids, cur_input_ids])
+             else:
+                 cur_input_ids = self.tokenizer(f"ASSISTANT:", add_special_tokens=False, max_length=self.maxlen, truncation=True, return_tensors='pt').input_ids[0]
+                 input_ids = torch.cat([input_ids, cur_input_ids])
+
+     return input_ids
+
+ def preprocess_allava(tokenizer, convs: list, return_tensors) -> list:  # tokenize and concat the conversations
+     input_ids = None
+
+     for ind, conv in enumerate(convs):
+         if ind % 2 == 0:  # human
+             h = conv['value'].strip()
+             h = f"[INST] {h} [/INST] "
+             cur_input_ids = tokenizer_image_token(prompt=h, tokenizer=tokenizer, return_tensors=return_tensors)
+
+             if input_ids is None:
+                 input_ids = cur_input_ids
+             else:
+                 input_ids = torch.cat([input_ids, cur_input_ids])
+
+         else:  # gpt
+             g = conv['value']
+             if g is not None:  # None marks the last turn, which the model should complete
+                 cur_input_ids = tokenizer(f"{g}{tokenizer.eos_token}", add_special_tokens=False, truncation=True, return_tensors='pt').input_ids[0]
+                 input_ids = torch.cat([input_ids, cur_input_ids])
+
+     return input_ids
+
+
+ # copied from llava
+ def get_image_tensors(processor, images, device):
+     list_image_tensors = []
+     crop_size = processor.crop_size
+     for fp in images:
+         if fp is None:  # None is used as a placeholder
+             list_image_tensors.append(torch.zeros(3, crop_size['height'], crop_size['width']).to(device))
+             continue
+         elif isinstance(fp, str):
+             image = Image.open(fp).convert('RGB')
+         elif isinstance(fp, Image.Image):
+             image = fp  # already an image
+         else:
+             raise TypeError(f'Unsupported type {type(fp)}')
+
+         # pad to a square before resizing; this is the way of preprocessing images we used
+         # in training (image_aspect_ratio == 'pad'), so we impose it here as well
+         def expand2square(pil_img, background_color):
+             width, height = pil_img.size
+             if pil_img.mode == 'L':
+                 pil_img = pil_img.convert('RGB')
+
+             if width == height:
+                 return pil_img
+             elif width > height:
+                 result = Image.new(pil_img.mode, (width, width), background_color)
+                 result.paste(pil_img, (0, (width - height) // 2))
+                 return result
+             else:
+                 result = Image.new(pil_img.mode, (height, height), background_color)
+                 result.paste(pil_img, ((height - width) // 2, 0))
+                 return result
+
+         image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
+         image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+         list_image_tensors.append(image.to(device))
+     return list_image_tensors
+
+
+ def build_allava_input(tokenizer, processor, texts, images, history=None, return_history=False, device='cuda'):
+     '''
+     texts: a string for a single-round query, or a list of [user, assistant] pairs
+     images: a path / URL / PIL.Image, or a list of them (None is allowed as a placeholder)
+     '''
+
+     ############################
+     # 1. preprocess texts
+     ############################
+     if isinstance(texts, str):
+         texts = [[texts, None]]
+     else:
+         assert isinstance(texts, list) and isinstance(texts[0], list), 'texts must be a list of list'
+
+     if history is not None:
+         texts = history + texts  # concat them together
+
+     texts = input_moderation(texts)
+
+     ############################
+     # 2. preprocess images
+     ############################
+     if isinstance(images, str) or isinstance(images, Image.Image):
+         images = [images]
+
+     valid_images = []
+     if images is None:
+         images = [None]
+
+     for img in images:
+         try:
+             if not isinstance(img, Image.Image):  # paths and URLs need to be opened first
+                 if os.path.exists(img):  # make sure that the path exists
+                     img = Image.open(img).convert('RGB')
+                 else:  # else it must be a URL
+                     img = Image.open(requests.get(img, stream=True).raw).convert('RGB')
+
+             valid_images.append(img)
+         except Exception:
+             continue
+
+     images = valid_images
+
+     if images == []:
+         images = [None]
+
+     assert len(images) <= max_num_images, f'Currently at most {max_num_images} images are supported'
+
+     ############################
+     # 3. collate conv
+     ############################
+
+     history = deepcopy(texts)  # history is the texts without <image> placeholders
+
+     # insert <image>
+     image_place_holder_inserted = insert_image_placeholder(texts[0][0], len(images) if None not in images else 0)  # only insert the placeholders for user input at the 1st round
+     texts[0][0] = image_place_holder_inserted
+
+     # collate strings into conv
+     conv = get_conv(texts)
+
+     # make input ids
+     input_ids = preprocess(tokenizer, conv, return_tensors='pt').unsqueeze(0).to(device)
+
+     list_image_tensors = get_image_tensors(processor, images, device)
+     image_tensors = torch.stack(list_image_tensors)
+
+     try:
+         dtype = torch.bfloat16
+         # if your hardware does not support bf16, the following line raises an error
+         torch.tensor(1, dtype=dtype).cuda()
+     except Exception:
+         # default to fp16
+         dtype = torch.float16
+     image_tensors = image_tensors.to(dtype)  # cast image tensors to the chosen dtype
+
+     if return_history:
+         return input_ids, image_tensors, history
+
+     return input_ids, image_tensors, None
+
+
+ class TextIterStreamer:
+     '''Minimal streamer for transformers' generate(streamer=...): generate() calls put()
+     with new token ids and end() when generation finishes; iterating over the streamer
+     yields the cumulative decoded text.'''
+     def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
+         self.tokenizer = tokenizer
+         self.skip_prompt = skip_prompt
+         self.skip_special_tokens = skip_special_tokens
+         self.tokens = []
+         self.text_queue = Queue()
+         self.next_tokens_are_prompt = True
+
+     def put(self, value):
+         if self.skip_prompt and self.next_tokens_are_prompt:
+             self.next_tokens_are_prompt = False
+         else:
+             if len(value.shape) > 1:
+                 value = value[0]
+             self.tokens.extend(value.tolist())
+             # each item on the queue is the full text decoded so far, not just the newest piece
+             self.text_queue.put(
+                 self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))
+
+     def end(self):
+         self.text_queue.put(None)
+
+     def __iter__(self):
+         return self
+
+     def __next__(self):
+         value = self.text_queue.get()
+         if value is None:
+             raise StopIteration()
+         else:
+             return value
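
A minimal usage sketch for the utilities above (not part of the commit). It assumes the model from this repo is loaded with trust_remote_code=True, that generation_utils.py is on the local import path, that the image processor is reachable via the model's vision tower as in LLaVA-style code, and that the custom generate() accepts an images= argument; the repo id, image URL, and generation settings below are placeholders.

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer

from generation_utils import TextIterStreamer, build_allava_input

model_id = 'FreedomIntelligence/ALLaVA-Phi-2-2_7B'  # placeholder repo id
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, torch_dtype='auto', device_map='cuda')
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
processor = model.get_vision_tower().image_processor  # assumption: LLaVA-style vision tower

input_ids, image_tensors, history = build_allava_input(
    tokenizer, processor,
    texts='What is shown in this image?',
    images='https://example.com/sample.jpg',  # placeholder URL or local path
    return_history=True,
)

streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
    inputs=input_ids,
    images=image_tensors.to(model.dtype),  # assumption: the custom generate() takes images=
    streamer=streamer,
    max_new_tokens=512,
    do_sample=False,
)
Thread(target=model.generate, kwargs=generation_kwargs).start()

response = ''
for response in streamer:  # each item is the cumulative decoded text so far
    pass
print(response)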