qnguyen3 committed on
Commit
c36d5bb
1 Parent(s): a3db70a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -3
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
  from threading import Thread
5
  import re
6
  import time
@@ -22,6 +22,40 @@ model = AutoModelForCausalLM.from_pretrained(
22
  device_map='auto',
23
  trust_remote_code=True)
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  @spaces.GPU
27
  def bot_streaming(message, history):
@@ -60,10 +94,13 @@ def bot_streaming(message, history):
60
  add_generation_prompt=True)
61
  text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
62
  input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)
63
- streamer = TextIteratorStreamer(tokenizer, skip_special_tokens = True)
 
 
 
64
 
65
  image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)
66
- generation_kwargs = dict(input_ids=input_ids, images=image_tensor, streamer=streamer, max_new_tokens=100)
67
  generated_text = ""
68
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
69
  thread.start()
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteria
4
  from threading import Thread
5
  import re
6
  import time
 
22
  device_map='auto',
23
  trust_remote_code=True)
24
 
25
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires when a keyword ends the generated output.

    Each keyword is tokenized once at construction time. On every generation
    step a row is considered finished when either

    * its trailing token ids equal one keyword's token ids, or
    * the decoded text of its most recent tokens contains a keyword string.

    Generation stops only when every row in the batch is finished.

    Args:
        keywords: Iterable of stop strings.
        tokenizer: Tokenizer used both to encode the keywords and to decode
            the generated ids; it must expose ``bos_token_id`` and
            ``batch_decode``.
        input_ids: Prompt ids; only ``input_ids.shape[1]`` is read, as the
            prompt length in tokens.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        super().__init__()
        self.keywords = keywords
        self.tokenizer = tokenizer
        # Prompt length in tokens; used to restrict text matching to new tokens.
        self.start_len = input_ids.shape[1]
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a BOS token the tokenizer prepended so that the ids can be
            # compared against the tail of the generated sequence.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.max_keyword_len = max(self.max_keyword_len, len(cur_keyword_ids))
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the single row in ``output_ids`` hit a keyword.

        Args:
            output_ids: Tensor indexed as ``[row, position]``; only row 0 is
                inspected by the token-id comparison.
            scores: Unused; accepted for interface compatibility.
        """
        # Keep the keyword tensors on the same device as the ids being compared.
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
            if torch.equal(truncated_output_ids, keyword_id):
                return True

        total_len = output_ids.shape[1]
        new_tokens = total_len - self.start_len
        if new_tokens > 0:
            window = min(new_tokens, self.max_keyword_len)
        else:
            # The ids are not longer than the recorded prompt length, so they
            # presumably contain generated tokens only (TODO confirm against
            # the model's generate()). A zero or negative offset would turn
            # the tail slice into "everything" or a positive start index, so
            # fall back to the last max_keyword_len tokens that exist.
            window = min(total_len, self.max_keyword_len)
        # Explicit start index avoids the ``-0:`` pitfall when window is 0.
        tail_ids = output_ids[:, total_len - window:]
        outputs = self.tokenizer.batch_decode(tail_ids, skip_special_tokens=True)[0]
        return any(keyword in outputs for keyword in self.keywords)

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True only when every row of the batch has hit a keyword."""
        return all(
            self.call_for_batch(output_ids[i].unsqueeze(0), scores)
            for i in range(output_ids.shape[0])
        )
58
+
59
 
60
  @spaces.GPU
61
  def bot_streaming(message, history):
 
94
  add_generation_prompt=True)
95
  text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
96
  input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)
97
+ stop_str = '<|im_end|>'
98
+ keywords = [stop_str]
99
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
100
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
101
 
102
  image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)
103
+ generation_kwargs = dict(input_ids=input_ids, images=image_tensor, streamer=streamer, max_new_tokens=100, stopping_criteria=[stopping_criteria])
104
  generated_text = ""
105
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
106
  thread.start()