DongfuJiang committed
Commit 669c11e
1 Parent(s): 75c15ae
Files changed (2):
  1. app.py +18 -9
  2. models/mllava/utils.py +40 -8
app.py CHANGED
@@ -4,7 +4,7 @@ import os
 import time
 from PIL import Image
 import functools
-from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava_stream, MLlavaForConditionalGeneration
+from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava_stream, MLlavaForConditionalGeneration, chat_mllava
 from models.conversation import conv_templates
 from typing import List
 processor = MLlavaProcessor.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3")
@@ -12,7 +12,7 @@ model = LlavaForConditionalGeneration.from_pretrained("TIGER-Lab/Mantis-8B-sigli
 conv_template = conv_templates['llama_3']
 
 @spaces.GPU
-def generate(text:str, images:List[Image.Image], history: List[dict], **kwargs):
+def generate_stream(text:str, images:List[Image.Image], history: List[dict], **kwargs):
     global processor, model
     model = model.to("cuda")
     if not images:
@@ -22,6 +22,15 @@ def generate(text:str, images:List[Image.Image], history: List[dict], **kwargs):
 
     return text
 
+@spaces.GPU
+def generate(text:str, images:List[Image.Image], history: List[dict], **kwargs):
+    global processor, model
+    model = model.to("cuda")
+    if not images:
+        images = None
+    generated_text, history = chat_mllava(text, images, model, processor, history=history, **kwargs)
+    return generated_text
+
 def enable_next_image(uploaded_images, image):
     uploaded_images.append(image)
     return uploaded_images, gr.MultimodalTextbox(value=None, interactive=False)
@@ -87,15 +96,14 @@ def bot(history):
 
     chat_history = get_chat_history(history)
     chat_images = get_chat_images(history)
+
     generation_kwargs = {
         "max_new_tokens": 4096,
-        "temperature": 0.2,
-        "top_p": 1.0,
-        "do_sample": True,
+        "num_beams": 1,
+        "do_sample": False
     }
-    print(None, chat_images, chat_history, generation_kwargs)
-    response = generate(None, chat_images, chat_history, **generation_kwargs)
-
+
+    response = generate_stream(None, chat_images, chat_history, **generation_kwargs)
     for _output in response:
         history[-1][1] = _output
         time.sleep(0.05)
@@ -191,7 +199,8 @@ Mantis is a multimodal conversational AI model that can chat with users about im
 author={Jiang, Dongfu and He, Xuan and Zeng, Huaye and Wei, Con and Ku, Max and Liu, Qian and Chen, Wenhu},
 journal={arXiv preprint arXiv:2405.01483},
 year={2024}
-}```""")
+}
+```""")
     return demo
 
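For context, a minimal sketch of how the two entry points read from the caller's side after this change. The module-level `model` and `processor` from app.py are assumed to be loaded; the image paths are hypothetical, and `generate_stream` is assumed to yield the accumulated partial reply, which matches how `bot()` consumes it.

```python
# Minimal usage sketch (assumes app.py's module-level `model`/`processor`
# are already loaded; image files below are hypothetical).
from PIL import Image

images = [Image.open("left.jpg"), Image.open("right.jpg")]  # hypothetical paths
question = "What is different between the two images?"

# Non-streaming path added in this commit: chat_mllava returns the full reply.
reply = generate(question, images, None,
                 max_new_tokens=4096, num_beams=1, do_sample=False)
print(reply)

# Streaming path (the renamed generate_stream), consumed the way bot() does:
# each yielded value is the partial reply accumulated so far.
for partial in generate_stream(question, images, None,
                               max_new_tokens=4096, num_beams=1, do_sample=False):
    print(partial, end="\r")
```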
models/mllava/utils.py CHANGED
@@ -46,10 +46,27 @@ def chat_mllava(
         for message in history:
             assert message["role"] in conv.roles
             conv.append_message(message["role"], message["text"])
+        if text:
+            assert conv.messages[-1][0] == conv.roles[1], "The last message in the history should be the assistant, if the given text is not empty"
+            conv.append_message(conv.roles[0], text)
+            conv.append_message(conv.roles[1], "")
+            history.append({"role": conv.roles[0], "text": text})
+            history.append({"role": conv.roles[1], "text": ""})
+        else:
+            if conv.messages[-1][0] == conv.roles[1]:
+                assert conv.messages[-1][1] == "", "No user message should be provided"
+            else:
+                assert conv.messages[-1][0] == conv.roles[0], "The last message in the history should be the user, if the given text is empty"
+                conv.append_message(conv.roles[1], "")
+                history.append({"role": conv.roles[1], "text": ""})
     else:
         history = []
-    conv.append_message(conv.roles[0], text)
-    conv.append_message(conv.roles[1], "")
+        history.append({"role": conv.roles[0], "text": text})
+        history.append({"role": conv.roles[1], "text": ""})
+        conv.append_message(conv.roles[0], text)
+        conv.append_message(conv.roles[1], "")
+    assert conv.messages[-1][0] == conv.roles[1] and conv.messages[-1][1] == "", "Format check"
+    assert history[-1]["role"] == conv.roles[1] and history[-1]["text"] == "", "Format check"
 
     prompt = conv.get_prompt()
     if images:
@@ -75,8 +92,7 @@
     generated_ids = output_ids[inputs["input_ids"].shape[-1]:]
     generated_text = processor.decode(generated_ids, skip_special_tokens=True)
 
-    history.append({"role": conv.roles[0], "text": text})
-    history.append({"role": conv.roles[1], "text": generated_text})
+    history[-1]["text"] = generated_text
 
     return generated_text, history
 
@@ -120,10 +136,27 @@ def chat_mllava_stream(
         for message in history:
             assert message["role"] in conv.roles
             conv.append_message(message["role"], message["text"])
+        if text:
+            assert conv.messages[-1][0] == conv.roles[1], "The last message in the history should be the assistant, if the given text is not empty"
+            conv.append_message(conv.roles[0], text)
+            conv.append_message(conv.roles[1], "")
+            history.append({"role": conv.roles[0], "text": text})
+            history.append({"role": conv.roles[1], "text": ""})
+        else:
+            if conv.messages[-1][0] == conv.roles[1]:
+                assert conv.messages[-1][1] == "", "No user message should be provided"
+            else:
+                assert conv.messages[-1][0] == conv.roles[0], "The last message in the history should be the user, if the given text is empty"
+                conv.append_message(conv.roles[1], "")
+                history.append({"role": conv.roles[1], "text": ""})
     else:
         history = []
-    conv.append_message(conv.roles[0], text)
-    conv.append_message(conv.roles[1], "")
+        history.append({"role": conv.roles[0], "text": text})
+        history.append({"role": conv.roles[1], "text": ""})
+        conv.append_message(conv.roles[0], text)
+        conv.append_message(conv.roles[1], "")
+    assert conv.messages[-1][0] == conv.roles[1] and conv.messages[-1][1] == "", "Format check"
+    assert history[-1]["role"] == conv.roles[1] and history[-1]["text"] == "", "Format check"
 
     prompt = conv.get_prompt()
     if images:
@@ -132,6 +165,7 @@
             images[i] = PIL.Image.open(images[i])
 
     inputs = processor(images=images, text=prompt, return_tensors="pt", truncation=True, max_length=max_input_length)
+    print(processor.tokenizer.decode(inputs["input_ids"][0]))
     for k, v in inputs.items():
         if v is not None:
             if isinstance(v, torch.Tensor):
@@ -148,8 +182,6 @@
     inputs.update(kwargs)
     thread = Thread(target=model.generate, kwargs=inputs)
     thread.start()
-    history.append({"role": conv.roles[0], "text": text})
-    history.append({"role": conv.roles[1], "text": ""})
    for _output in streamer:
         history[-1]["text"] += _output
         yield history[-1]["text"], history
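The reworked history handling gives `chat_mllava` and `chat_mllava_stream` two calling modes: a non-empty `text` appends a new user turn (the history, if given, must then end with a completed assistant message), while an empty `text` generates for a history that already ends with a user turn or an empty assistant turn — which is how app.py's `bot()` now calls `generate_stream` with `text=None`. Below is a minimal sketch of both modes; the loading mirrors app.py, and the image path is hypothetical.

```python
# Sketch of both calling modes of chat_mllava after this commit; model and
# processor loading mirrors app.py, the image path is hypothetical.
from PIL import Image
from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava

processor = MLlavaProcessor.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3")
model = LlavaForConditionalGeneration.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3").to("cuda")
images = [Image.open("example.jpg")]  # hypothetical path

# Mode 1: non-empty text starts a new user turn; with history=None the
# function builds the history itself and appends an empty assistant slot.
reply, history = chat_mllava("Describe the image.", images, model, processor,
                             history=None, max_new_tokens=512)

# Mode 2: empty text regenerates into the existing assistant slot; the
# history must already end with a user turn or an empty assistant turn,
# and the new reply is written back in place via history[-1]["text"].
history[-1]["text"] = ""  # discard the previous reply before regenerating
reply, history = chat_mllava("", images, model, processor,
                             history=history, max_new_tokens=512)
```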