yuzaa commited on
Commit
3b1479a
1 Parent(s): c67ef57

support stream and system prompt (#15)

Browse files

- support stream and sysyem prompt (9a7eb1ec0769ca938d7120f42edb9b942a1d431c)

Files changed (2) hide show
  1. README.md +19 -2
  2. modeling_minicpmv.py +44 -4
README.md CHANGED
@@ -377,10 +377,27 @@ res = model.chat(
377
  image=image,
378
  msgs=msgs,
379
  tokenizer=tokenizer,
380
- sampling=True,
381
- temperature=0.7
 
382
  )
383
  print(res)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  ```
385
 
386
  Please look at [GitHub](https://github.com/OpenBMB/MiniCPM-V) for more detail about usage.
 
377
  image=image,
378
  msgs=msgs,
379
  tokenizer=tokenizer,
380
+ sampling=True, # if sampling=False, beam_search will be used by default
381
+ temperature=0.7
382
+ # system_prompt='' # pass system_prompt if needed
383
  )
384
  print(res)
385
+
386
+ ## if you want to use streaming, please make sure sampling=True and stream=True
387
+ ## the model.chat will return a generator
388
+ res = model.chat(
389
+ image=image,
390
+ msgs=msgs,
391
+ tokenizer=tokenizer,
392
+ sampling=True,
393
+ temperature=0.7,
394
+ stream=True
395
+ )
396
+
397
+ generated_text = ""
398
+ for new_text in res:
399
+ generated_text += new_text
400
+ print(new_text, flush=True, end='')
401
  ```
402
 
403
  Please look at [GitHub](https://github.com/OpenBMB/MiniCPM-V) for more detail about usage.
modeling_minicpmv.py CHANGED
@@ -3,10 +3,11 @@ from typing import List, Optional
3
  import json
4
  import torch
5
  import torchvision
 
6
  from copy import deepcopy
7
  from PIL import Image
8
  from torchvision import transforms
9
- from transformers import LlamaTokenizer, LlamaPreTrainedModel, LlamaForCausalLM, AutoModel, PreTrainedTokenizerFast
10
  from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer
11
 
12
  from .configuration_minicpm import MiniCPMVConfig
@@ -218,6 +219,25 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
218
  **kwargs
219
  )
220
  return self._decode_text(output, tokenizer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
  def _decode_text(self, result_ids, tokenizer):
223
  result_text = []
@@ -294,6 +314,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
294
  max_inp_length: Optional[int] = None,
295
  vision_hidden_states=None,
296
  return_vision_hidden_states=False,
 
297
  **kwargs
298
  ):
299
 
@@ -326,7 +347,10 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
326
  vision_hidden_states,
327
  ) = self.get_vllm_embedding(model_inputs)
328
 
329
- result = self._decode(model_inputs["inputs_embeds"], tokenizer, **kwargs)
 
 
 
330
 
331
  if return_vision_hidden_states:
332
  return result, vision_hidden_states
@@ -342,6 +366,8 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
342
  max_new_tokens=1024,
343
  sampling=True,
344
  max_inp_length=2048,
 
 
345
  **kwargs
346
  ):
347
  if isinstance(msgs, str):
@@ -349,6 +375,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
349
 
350
  copy_msgs = deepcopy(msgs)
351
  assert len(copy_msgs) > 0, 'msgs is empty'
 
352
 
353
  if image is not None and isinstance(copy_msgs[0]['content'], str):
354
  copy_msgs[0]['content'] = [image, copy_msgs[0]['content']]
@@ -393,6 +420,10 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
393
  if tgt_sizes:
394
  tgt_sizes = torch.vstack(tgt_sizes)
395
 
 
 
 
 
396
  input_ids = tokenizer.apply_chat_template(copy_msgs, tokenize=True, add_generation_prompt=False)
397
 
398
  if sampling:
@@ -423,11 +454,20 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
423
  max_new_tokens=max_new_tokens,
424
  vision_hidden_states=vision_hidden_states,
425
  return_vision_hidden_states=True,
 
426
  **generation_config
427
  )
428
- answer = res[0]
429
 
430
- return answer
 
 
 
 
 
 
 
 
 
431
 
432
 
433
  class PreTrainedTokenizerFastWrapper(PreTrainedTokenizerFast):
 
3
  import json
4
  import torch
5
  import torchvision
6
+ from threading import Thread
7
  from copy import deepcopy
8
  from PIL import Image
9
  from torchvision import transforms
10
+ from transformers import LlamaTokenizer, LlamaPreTrainedModel, LlamaForCausalLM, AutoModel, PreTrainedTokenizerFast, TextIteratorStreamer
11
  from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer
12
 
13
  from .configuration_minicpm import MiniCPMVConfig
 
219
  **kwargs
220
  )
221
  return self._decode_text(output, tokenizer)
222
+
223
+ def _decode_stream(self, inputs_embeds, tokenizer, **kwargs):
224
+ terminators = [
225
+ tokenizer.eos_token_id,
226
+ tokenizer.convert_tokens_to_ids("<|eot_id|>")
227
+ ]
228
+ streamer = TextIteratorStreamer(tokenizer=tokenizer)
229
+ generation_kwargs = {
230
+ 'inputs_embeds': inputs_embeds,
231
+ 'pad_token_id': 0,
232
+ 'eos_token_id': terminators,
233
+ 'streamer': streamer
234
+ }
235
+ generation_kwargs.update(kwargs)
236
+
237
+ thread = Thread(target=self.llm.generate, kwargs=generation_kwargs)
238
+ thread.start()
239
+
240
+ return streamer
241
 
242
  def _decode_text(self, result_ids, tokenizer):
243
  result_text = []
 
314
  max_inp_length: Optional[int] = None,
315
  vision_hidden_states=None,
316
  return_vision_hidden_states=False,
317
+ stream=False,
318
  **kwargs
319
  ):
320
 
 
347
  vision_hidden_states,
348
  ) = self.get_vllm_embedding(model_inputs)
349
 
350
+ if stream:
351
+ result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs)
352
+ else:
353
+ result = self._decode(model_inputs["inputs_embeds"], tokenizer, **kwargs)
354
 
355
  if return_vision_hidden_states:
356
  return result, vision_hidden_states
 
366
  max_new_tokens=1024,
367
  sampling=True,
368
  max_inp_length=2048,
369
+ system_prompt='',
370
+ stream=False,
371
  **kwargs
372
  ):
373
  if isinstance(msgs, str):
 
375
 
376
  copy_msgs = deepcopy(msgs)
377
  assert len(copy_msgs) > 0, 'msgs is empty'
378
+ assert sampling or not stream, 'if use stream mode, make sure sampling=True'
379
 
380
  if image is not None and isinstance(copy_msgs[0]['content'], str):
381
  copy_msgs[0]['content'] = [image, copy_msgs[0]['content']]
 
420
  if tgt_sizes:
421
  tgt_sizes = torch.vstack(tgt_sizes)
422
 
423
+ if system_prompt:
424
+ sys_msg = {'role': 'system', 'content': system_prompt}
425
+ copy_msgs = [sys_msg] + copy_msgs
426
+
427
  input_ids = tokenizer.apply_chat_template(copy_msgs, tokenize=True, add_generation_prompt=False)
428
 
429
  if sampling:
 
454
  max_new_tokens=max_new_tokens,
455
  vision_hidden_states=vision_hidden_states,
456
  return_vision_hidden_states=True,
457
+ stream=stream,
458
  **generation_config
459
  )
 
460
 
461
+ if stream:
462
+ def stream_gen():
463
+ for text in res:
464
+ text = text.replace(tokenizer.eot_token, '').replace(tokenizer.eos_token, '')
465
+ yield text
466
+ return stream_gen()
467
+
468
+ else:
469
+ answer = res[0]
470
+ return answer
471
 
472
 
473
  class PreTrainedTokenizerFastWrapper(PreTrainedTokenizerFast):