DongfuJiang committed on
Commit 335eee6
1 Parent(s): c862a9f
app.py CHANGED
@@ -4,10 +4,12 @@ import os
 import time
 from PIL import Image
 import functools
-from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava, MLlavaForConditionalGeneration
+from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava_stream, MLlavaForConditionalGeneration
+from models.conversation import conv_templates
 from typing import List
 processor = MLlavaProcessor.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3")
 model = LlavaForConditionalGeneration.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3")
+conv_template = conv_templates['llama_3']
 
 @spaces.GPU
 def generate(text:str, images:List[Image.Image], history: List[dict], **kwargs):
@@ -15,7 +17,7 @@ def generate(text:str, images:List[Image.Image], history: List[dict], **kwargs):
     model = model.to("cuda")
     if not images:
         images = None
-    for text, history in chat_mllava(text, images, model, processor, history=history, stream=True, **kwargs):
+    for text, history in chat_mllava_stream(text, images, model, processor, history=history, **kwargs):
         yield text
 
     return text
@@ -38,15 +40,17 @@ def print_like_dislike(x: gr.LikeData):
 
 def get_chat_history(history):
     chat_history = []
+    user_role = conv_template.roles[0]
+    assistant_role = conv_template.roles[1]
     for i, message in enumerate(history):
         if isinstance(message[0], str):
-            chat_history.append({"role": "user", "text": message[0]})
+            chat_history.append({"role": user_role, "text": message[0]})
             if i != len(history) - 1:
                 assert message[1], "The bot message is not provided, internal error"
-                chat_history.append({"role": "assistant", "text": message[1]})
+                chat_history.append({"role": assistant_role, "text": message[1]})
             else:
                 assert not message[1], "the bot message internal error, get: {}".format(message[1])
-                chat_history.append({"role": "assistant", "text": ""})
+                chat_history.append({"role": assistant_role, "text": ""})
     return chat_history
 
 
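Note: replacing the hard-coded "user"/"assistant" strings with conv_template.roles matters because chat_mllava_stream (see models/mllava/utils.py below) asserts message["role"] in conv.roles and no longer upper-cases roles on the way in. A minimal sketch of the resulting history format, assuming the llama_3 template exposes roles as a (user, assistant) pair; the sample messages are illustrative:

```python
# Minimal sketch (assumes this repo's conv_templates; sample messages are
# illustrative). History entries must already carry the template's own role
# names, since chat_mllava_stream asserts message["role"] in conv.roles.
from models.conversation import conv_templates

conv_template = conv_templates['llama_3']
user_role, assistant_role = conv_template.roles

history = [
    {"role": user_role, "text": "What is shown in the first image?"},
    {"role": assistant_role, "text": "A cat sitting on a windowsill."},
    {"role": user_role, "text": "And in the second?"},
    {"role": assistant_role, "text": ""},  # the turn currently being generated
]
```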
models/mllava/__init__.py CHANGED
@@ -1,4 +1,4 @@
 from .modeling_llava import LlavaForConditionalGeneration, MLlavaForConditionalGeneration
 from .processing_llava import MLlavaProcessor
 from .configuration_llava import LlavaConfig
-from .utils import chat_mllava
+from .utils import chat_mllava, chat_mllava_stream
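With the re-export, the streaming helper is importable from the package root alongside the original entry point:

```python
from models.mllava import chat_mllava, chat_mllava_stream
```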
models/mllava/utils.py CHANGED
@@ -44,7 +44,6 @@ def chat_mllava(
     conv.messages = []
     if history is not None:
         for message in history:
-            message["role"] = message["role"].upper()
             assert message["role"] in conv.roles
             conv.append_message(message["role"], message["text"])
     else:
@@ -105,11 +104,20 @@ def chat_mllava_stream(
 
 
     """
-    conv = default_conv.copy()
+    if "llama-3" in model.language_model.name_or_path.lower():
+        conv = conv_templates['llama_3']
+        terminators = [
+            processor.tokenizer.eos_token_id,
+            processor.tokenizer.convert_tokens_to_ids("<|eot_id|>")
+        ]
+    else:
+        conv = default_conv
+        terminators = None
+    kwargs["eos_token_id"] = terminators
+    conv = conv.copy()
     conv.messages = []
     if history is not None:
         for message in history:
-            message["role"] = message["role"].upper()
             assert message["role"] in conv.roles
             conv.append_message(message["role"], message["text"])
     else:
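The new branch exists because Llama-3 checkpoints end an assistant turn with <|eot_id|>, which is distinct from the tokenizer's default EOS token, so generation must be able to stop on either id. The same pattern in plain transformers, as a standalone sketch (the checkpoint name is illustrative, not from this repo):

```python
# Standalone sketch of the Llama-3 stopping pattern (illustrative checkpoint).
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "meta-llama/Meta-Llama-3-8B-Instruct"  # illustrative, not this repo's model
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

# Stop on either the default EOS or Llama 3's end-of-turn marker.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

inputs = tokenizer("The capital of France is", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=32, eos_token_id=terminators)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```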
requirements.txt CHANGED
@@ -3,4 +3,5 @@ transformers
 Pillow
 gradio
 spaces
-multiprocess
+multiprocess
+flash-attn
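flash-attn is presumably added so the model can be loaded with FlashAttention-2 kernels; this diff does not show where it is wired up, but in transformers the opt-in happens at load time, roughly as below (half precision is required by the flash-attn kernels):

```python
# Hedged sketch: not shown in this diff's app.py. attn_implementation is the
# standard transformers opt-in for FlashAttention-2, which requires fp16/bf16.
import torch
from models.mllava import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "TIGER-Lab/Mantis-8B-siglip-llama3",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
)
```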