shwu committed
Commit: d6e13c5
Parent(s): 51e5ad2
feat: new generation code
Files changed:
- modeling_blip2chatglm.py +90 -82
- modeling_chatglm.py +2 -4
modeling_blip2chatglm.py
CHANGED
@@ -189,9 +189,10 @@ class Blip2ChatGLMForConditionalGeneration(Blip2ForConditionalGeneration):
     def prepare_inputs_for_chat(
         self,
         tokenizer: PreTrainedTokenizer,
-        histories: List[List[Tuple[Union[str, Tuple[str, torch.Tensor]], str]]],
+        batch_messages: List[List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]]],
         max_length: int,
+        user_role: str = "问",
+        bot_role: str = "答",
     ):
         device = self.device
         nvtokens = self.config.num_query_tokens
@@ -199,80 +200,76 @@ class Blip2ChatGLMForConditionalGeneration(Blip2ForConditionalGeneration):
         all_images = []
         all_image_slots = []
         all_input_ids = []
-            input_ids = []
-            if isinstance(query, tuple):
-                qtext, qimg = query
-                # image slot, embedding will be replaced by image embeddings
-                input_ids.extend([tokenizer.unk_token_id] * nvtokens)
-            else:
-                qtext = query
-                qimg = None
-            input_ids += tokenizer(qtext + f"\n答:").input_ids
-            if qimg is not None:
-                all_images.append(qimg)
-                image_slots.append(
-                    len(input_ids) - slot_offset
-                ) # count from backward
-            for ri, (q, r) in enumerate(reversed(history)):
-                if len(input_ids) >= max_length:
-                    break
-                i = len(history) - ri - 1
-                cur_input_ids: List[int] = tokenizer(
-                    f"[Round {i}]\n问:", add_special_tokens=False
-                ).input_ids
-                qtext = q
-                qimg = None
-                cur_input_ids += tokenizer(
-                    qtext + f"\n答:{r}\n", add_special_tokens=False
-                ).input_ids
+        for messages in batch_messages:
+            images = []
             image_slots = []
+            input_ids = []
+
+            round_roles = [set()]
+            for role, qtext, qimgs in messages:
+                if role in round_roles[-1]:
+                    # a new round (not the first round)
+                    input_ids += tokenizer(
+                        f"\n[Round {len(round_roles)}]\n{role}:",
+                        add_special_tokens=False,
                     ).input_ids
+                    round_roles.append({role})
+                else:
+                    round_roles[-1].add(role)
+                    input_ids += tokenizer(
+                        # For first role, no new line
+                        f"\n{role}:" if len(input_ids) != 0 else f"{role}:", add_special_tokens=False
                    ).input_ids
+                cur_index = 0
+                for qimg, img_idx in qimgs:
+                    if img_idx > cur_index:
+                        input_ids += tokenizer(
+                            qtext[cur_index:img_idx], add_special_tokens=False
+                        ).input_ids
+                        cur_index = img_idx
                     # image slot, embedding will be replaced by image embeddings
+                    image_slots.append(len(input_ids))
+                    input_ids += [tokenizer.unk_token_id] * nvtokens
+                    images.append(qimg)
+                input_ids += tokenizer(
+                    qtext[cur_index:], add_special_tokens=False
+                ).input_ids
+            if len(round_roles) == 1:
+                # only 1 round
+                if len(round_roles[0]) == 1 and user_role in round_roles[0]:
+                    # only user role
+                    input_ids += tokenizer("").input_ids
                else:
+                    input_ids += tokenizer(f"\n{bot_role}:").input_ids
+            else:
+                # add tag for round 0
+                input_ids = (
+                    tokenizer(f"[Round 0]\n", add_special_tokens=False).input_ids
+                    + input_ids
+                )
+                input_ids += tokenizer(f"\n{bot_role}:").input_ids

            if len(input_ids) >= max_length:
+                image_slots_after_truncate = []
+                images_after_truncate = []
+                truncate_index = len(input_ids) - max_length
+                for image_slot, image in zip(image_slots, images):
+                    # truncate from left
+                    if len(input_ids) - image_slot < max_length:
+                        image_slots_after_truncate.append(image_slot)
+                        images_after_truncate.append(image)
+                    elif len(input_ids) - (image_slot + nvtokens) < max_length:
+                        # in-contact image slot is not allowed
+                        truncate_index = max(truncate_index, image_slot + nvtokens)
+                for i, image_slot in enumerate(image_slots_after_truncate):
+                    image_slots_after_truncate[i] = image_slot - truncate_index
+                input_ids = input_ids[truncate_index:]
+                image_slots = image_slots_after_truncate
+                images = images_after_truncate
+
+            # print(tokenizer.convert_ids_to_tokens(input_ids))
+
+            all_images.extend(images)
            all_image_slots.append(image_slots)
            all_input_ids.append(input_ids)

@@ -316,9 +313,12 @@ class Blip2ChatGLMForConditionalGeneration(Blip2ForConditionalGeneration):
             input_ids[i][-len(ids) :] = torch.as_tensor(ids, dtype=torch.long)
         input_ids = input_ids.to(device)
         inputs_embeds = self.language_model.transformer.word_embeddings(input_ids)
+        if all_vtokens is not None:
+            for i, (image_slots, vtokens) in enumerate(
+                zip(all_image_slots, all_vtokens)
+            ):
+                for slot, vimg in zip(image_slots, vtokens):
+                    inputs_embeds[i][slot : slot + nvtokens, :] = vimg

         return input_ids, inputs_embeds

@@ -326,22 +326,25 @@ class Blip2ChatGLMForConditionalGeneration(Blip2ForConditionalGeneration):
     def batch_chat(
         self,
         tokenizer: PreTrainedTokenizer,
-        histories: List[List[Tuple[Union[str, Tuple[str, torch.Tensor]], str]]],
+        batch_messages: List[List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]]],
         max_length: int = 2048,
         num_beams=1,
         do_sample=True,
         top_p=0.7,
         temperature=0.95,
+        user_role: str = "问",
+        bot_role: str = "答",
         **kwargs,
     ):
         input_ids, inputs_embeds = self.prepare_inputs_for_chat(
-            tokenizer,
+            tokenizer=tokenizer,
+            batch_messages=batch_messages,
+            max_length=max_length,
+            user_role=user_role,
+            bot_role=bot_role,
         )

-        logits_processor = LogitsProcessorList()
+        logits_processor = LogitsProcessorList()
         logits_processor.append(InvalidScoreLogitsProcessor())
         gen_kwargs = {
             "max_length": max_length,

@@ -367,17 +370,22 @@ class Blip2ChatGLMForConditionalGeneration(Blip2ForConditionalGeneration):
     def stream_chat(
         self,
         tokenizer: PreTrainedTokenizer,
-        history: List[Tuple[Union[str, Tuple[str, torch.Tensor]], str]],
+        messages: List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]],
         num_beams=5,
-        max_length=
+        max_length=512,
         top_p=0.9,
         do_sample=True,
         temperature=1,
+        user_role: str = "问",
+        bot_role: str = "答",
         **kwargs,
     ):
         input_ids, inputs_embeds = self.prepare_inputs_for_chat(
-            tokenizer,
+            tokenizer=tokenizer,
+            batch_messages=[messages],
+            max_length=max_length,
+            user_role=user_role,
+            bot_role=bot_role,
         )

         logits_processor = LogitsProcessorList()
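For orientation, a minimal usage sketch of the new chat interface follows. Only the batch_messages structure and the stream_chat/batch_chat signatures come from the hunks above; the model and tokenizer loading, the image preprocessing, the example texts, and the output variable names are assumptions, and these hunks do not show what the two methods return.

# Hedged sketch, not part of the commit. Assumes `model` is a loaded
# Blip2ChatGLMForConditionalGeneration, `tokenizer` its ChatGLM tokenizer,
# and `pixel_values` a preprocessed image tensor.
# Each message is (role, text, [(image_tensor, char_index), ...]); char_index
# is the position in `text` where the image's query-token slot is spliced in.
messages = [
    ("问", "What is in this picture?", [(pixel_values, 0)]),  # user turn, image before the text
    ("答", "A cat sitting on a sofa.", []),                   # earlier bot reply, no image
    ("问", "What is it doing?", []),                          # follow-up user question
]

# One conversation through the streaming entry point, using its new defaults.
stream = model.stream_chat(tokenizer, messages, max_length=512, num_beams=5, top_p=0.9)

# Several conversations at once through batch_chat.
outputs = model.batch_chat(tokenizer, [messages], max_length=2048, top_p=0.7, temperature=0.95)
# How `stream` and `outputs` are consumed is not shown in these hunks.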
modeling_chatglm.py
CHANGED
@@ -970,6 +970,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):

         if attention_mask is None:
             attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
+        else:
+            attention_mask = attention_mask.to(hidden_states.device)

         for i, layer in enumerate(self.layers):

@@ -1095,10 +1097,6 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 [position_ids, new_position_id], dim=-1
             )

-        # set to None as prepare_inputs_for_generation use past for input embeds
-        if "inputs_embeds" in model_kwargs:
-            model_kwargs["inputs_embeds"] = None
-
         return model_kwargs

     def prepare_inputs_for_generation(
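A brief hedged illustration of the first change above: when the caller passes an attention_mask that lives on a different device than the transformer's hidden_states (for example a CPU-built mask while the layers sit on a GPU), the new else branch moves it over before the layer loop. The snippet below is a standalone sketch with assumed shapes, not code from the repo. The second change, dropping the inputs_embeds reset, leaves precomputed embeddings in model_kwargs between generation steps, presumably so the embedding-based chat path added in modeling_blip2chatglm.py keeps working together with past.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
hidden_states = torch.randn(4, 1, 32, device=device)  # lives on the model's device
attention_mask = torch.zeros(1, 1, 4, 4).bool()        # built on the CPU by the caller

# Mirrors the new branch: align the mask with the hidden states before use.
if attention_mask is not None:
    attention_mask = attention_mask.to(hidden_states.device)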