DongfuJiang committed on
Commit f16e094
1 Parent(s): 8fbc209
.gitignore ADDED
@@ -0,0 +1,159 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
app.py CHANGED
@@ -1,24 +1,120 @@
  import gradio as gr
  import spaces
+ import time
  from PIL import Image
  from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava, MLlavaForConditionalGeneration
  from typing import List
- processor = MLlavaProcessor()
- model = LlavaForConditionalGeneration.from_pretrained("MFuyu/mllava_v2_4096")
+ processor = MLlavaProcessor.from_pretrained("MFuyu/mllava_llava_debug_nlvr2_v5_4096")
+ model = LlavaForConditionalGeneration.from_pretrained("MFuyu/mllava_llava_debug_nlvr2_v5_4096")

  @spaces.GPU
- def generate(text:str, images:List[Image.Image], history: List[dict]):
+ def generate(text:str, images:List[Image.Image], history: List[dict], **kwargs):
+     global processor, model
      model = model.to("cuda")
-
-     for text, history in chat_mllava(text, images, model, processor, history=history, stream=True):
-         yield text, history
+     if not images:
+         images = None
+     for text, history in chat_mllava(text, images, model, processor, history=history, stream=True, **kwargs):
+         yield text

- def build_demo():
-
+     return text
+
+ def enable_next_image(uploaded_images, image):
+     uploaded_images.append(image)
+     return uploaded_images, gr.MultimodalTextbox(value=None, interactive=False)
+
+ def add_message(history, message):
+     if message["files"]:
+         for file in message["files"]:
+             history.append([(file,), None])
+     if message["text"]:
+         history.append([message["text"], None])
+     return history, gr.MultimodalTextbox(value=None)
+
+ def print_like_dislike(x: gr.LikeData):
+     print(x.index, x.value, x.liked)
+
+
+ def get_chat_history(history):
+     chat_history = []
+     for i, message in enumerate(history):
+         if isinstance(message[0], str):
+             chat_history.append({"role": "user", "text": message[0]})
+             if i != len(history) - 1:
+                 assert message[1], "The bot message is not provided, internal error"
+                 chat_history.append({"role": "assistant", "text": message[1]})
+             else:
+                 assert not message[1], "the bot message internal error, get: {}".format(message[1])
+                 chat_history.append({"role": "assistant", "text": ""})
+     return chat_history
+
+ def get_chat_images(history):
+     images = []
+     for message in history:
+         if isinstance(message[0], tuple):
+             images.extend(message[0])
+     return images
+
+ def bot(history):
+     print(history)
+     cur_messages = {"text": "", "images": []}
+     for message in history[::-1]:
+         if message[1]:
+             break
+         if isinstance(message[0], str):
+             cur_messages["text"] = message[0] + " " + cur_messages["text"]
+         elif isinstance(message[0], tuple):
+             cur_messages["images"].extend(message[0])
+     cur_messages["text"] = cur_messages["text"].strip()
+     cur_messages["images"] = cur_messages["images"][::-1]
+     if not cur_messages["text"]:
+         raise gr.Error("Please enter a message")
+     if cur_messages['text'].count("<image>") < len(cur_messages['images']):
+         gr.Warning("The number of images uploaded is more than the number of <image> placeholders in the text. Will automatically prepend <image> to the text.")
+         cur_messages['text'] = "<image> " * (len(cur_messages['images']) - cur_messages['text'].count("<image>")) + cur_messages['text']
+         history[-1][0] = cur_messages["text"]
+     if cur_messages['text'].count("<image>") > len(cur_messages['images']):
+         gr.Warning("The number of images uploaded is less than the number of <image> placeholders in the text. Will automatically remove extra <image> placeholders from the text.")
+         cur_messages['text'] = cur_messages['text'][::-1].replace("<image>"[::-1], "", cur_messages['text'].count("<image>") - len(cur_messages['images']))[::-1]
+         history[-1][0] = cur_messages["text"]
+
+     chat_history = get_chat_history(history)
+     chat_images = get_chat_images(history)
+     generation_kwargs = {
+         "max_new_tokens": 4096,
+         "temperature": 0.7,
+         "top_p": 1.0,
+         "do_sample": True,
+     }
+     print(None, chat_images, chat_history, generation_kwargs)
+     response = generate(None, chat_images, chat_history, **generation_kwargs)
+
+     for _output in response:
+         history[-1][1] = _output
+         time.sleep(0.05)
+         yield history
+
+ def build_demo():
+     with gr.Blocks() as demo:
+         chatbot = gr.Chatbot(line_breaks=True)
+         chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload images. Please use <image> to indicate the position of uploaded images", show_label=True)
+
+         chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
+         bot_msg = chat_msg.success(bot, chatbot, chatbot, api_name="bot_response")
+
+         chatbot.like(print_like_dislike, None, None)
+
+         with gr.Row():
+             send_button = gr.Button("Send")
+             clear_button = gr.ClearButton([chatbot, chat_input])
+
+         send_button.click(
+             add_message, [chatbot, chat_input], [chatbot, chat_input]
+         ).then(
+             bot, chatbot, chatbot, api_name="bot_response"
+         )
+     return demo


  if __name__ == "__main__":
-     processor = MLlavaProcessor()
-     model = LlavaForConditionalGeneration.from_pretrained("MFuyu/mllava_v2_4096")
      demo = build_demo()
      demo.launch()
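
For reference, a minimal standalone sketch of the <image> placeholder balancing that the new bot() performs before calling generate(). This is an editor's illustration, not part of the commit; the sample text and image list are made up.

# Sketch of bot()'s <image> placeholder balancing (sample inputs are hypothetical).
text = "compare <image> with"
images = ["a.png", "b.png"]

missing = len(images) - text.count("<image>")
if missing > 0:
    # fewer placeholders than images: prepend one "<image> " per missing image
    text = "<image> " * missing + text

extra = text.count("<image>") - len(images)
if extra > 0:
    # more placeholders than images: drop the last `extra` placeholders by
    # replacing on the reversed string, the same trick bot() uses
    text = text[::-1].replace("<image>"[::-1], "", extra)[::-1]

print(text)  # -> "<image> compare <image> with"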
models/mllava/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/models/mllava/__pycache__/__init__.cpython-39.pyc and b/models/mllava/__pycache__/__init__.cpython-39.pyc differ
 
models/mllava/__pycache__/configuration_llava.cpython-39.pyc CHANGED
Binary files a/models/mllava/__pycache__/configuration_llava.cpython-39.pyc and b/models/mllava/__pycache__/configuration_llava.cpython-39.pyc differ
 
models/mllava/__pycache__/modeling_llava.cpython-39.pyc CHANGED
Binary files a/models/mllava/__pycache__/modeling_llava.cpython-39.pyc and b/models/mllava/__pycache__/modeling_llava.cpython-39.pyc differ
 
models/mllava/__pycache__/processing_llava.cpython-39.pyc CHANGED
Binary files a/models/mllava/__pycache__/processing_llava.cpython-39.pyc and b/models/mllava/__pycache__/processing_llava.cpython-39.pyc differ
 
models/mllava/__pycache__/utils.cpython-39.pyc CHANGED
Binary files a/models/mllava/__pycache__/utils.cpython-39.pyc and b/models/mllava/__pycache__/utils.cpython-39.pyc differ
 
models/mllava/utils.py CHANGED
@@ -3,11 +3,11 @@ import torch
  from .modeling_llava import LlavaForConditionalGeneration
  from .processing_llava import MLlavaProcessor
  from ..conversation import conv_mllava_v1_mmtag as default_conv
- from typing import List, Tuple
+ from typing import List, Tuple, Union, Tuple

  def chat_mllava(
      text:str,
-     images: List[PIL.Image.Image],
+     images: List[Union[PIL.Image.Image, str]],
      model:LlavaForConditionalGeneration,
      processor:MLlavaProcessor,
      max_input_length:int=None,
@@ -38,13 +38,26 @@ def chat_mllava(
              conv.append_message(message["role"], message["text"])
      else:
          history = []
-     conv.append_message(conv.roles[0], text)
-     conv.append_message(conv.roles[1], "")
+
+     if text is not None:
+         conv.append_message(conv.roles[0], text)
+         conv.append_message(conv.roles[1], "")
+         history.append({"role": conv.roles[0], "text": text})
+         history.append({"role": conv.roles[1], "text": ""})
+     else:
+         assert history, "The history should not be empty if the text is None"
+         assert history[-1]['role'] == conv.roles[1], "The last message in the history should be the assistant, an empty message"
+         assert history[-2]['text'], "The last user message in the history should not be empty"
+         assert history[-1]['text'] == "", "The last assistant message in the history should be empty"

      prompt = conv.get_prompt()
+     if images:
+         for i in range(len(images)):
+             if isinstance(images[i], str):
+                 images[i] = PIL.Image.open(images[i])

      inputs = processor(images=images, text=prompt, return_tensors="pt", truncation=True, max_length=max_input_length)
-     inputs = {k: v.to(model.device) for k, v in inputs.items()}
+     inputs = {k: v.to(model.device) if v is not None else v for k, v in inputs.items()}

      if stream:
          from transformers import TextIteratorStreamer
@@ -54,8 +67,6 @@ def chat_mllava(
          inputs.update(kwargs)
          thread = Thread(target=model.generate, kwargs=inputs)
          thread.start()
-         history.append({"role": conv.roles[0], "text": text})
-         history.append({"role": conv.roles[1], "text": ""})
          for _output in streamer:
              history[-1]["text"] += _output
              yield history[-1]["text"], history
@@ -67,7 +78,6 @@ def chat_mllava(
          generated_ids = output_ids[inputs["input_ids"].shape[-1]:]
          generated_text = processor.decode(generated_ids, skip_special_tokens=True)

-         history.append({"role": conv.roles[0], "text": text})
-         history.append({"role": conv.roles[1], "text": generated_text})
+         history[-1]["text"] = history[-1]["text"].strip()

          return generated_text, history
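
For reference, a minimal usage sketch of the updated chat_mllava signature: images may now be file paths as well as PIL images, extra keyword arguments are forwarded to generation, and stream=True yields (partial_text, history) pairs. This is an editor's illustration, not part of the commit; the checkpoint name is the one loaded in app.py, "example.jpg" is a placeholder path, and a CUDA device is assumed.

from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava

processor = MLlavaProcessor.from_pretrained("MFuyu/mllava_llava_debug_nlvr2_v5_4096")
model = LlavaForConditionalGeneration.from_pretrained("MFuyu/mllava_llava_debug_nlvr2_v5_4096").to("cuda")

history = None
# image paths are opened inside chat_mllava; partial_text grows as tokens stream in
for partial_text, history in chat_mllava(
    "<image> Describe this picture.",
    ["example.jpg"],
    model,
    processor,
    history=history,
    stream=True,
    max_new_tokens=256,
):
    print(partial_text, end="\r")
print()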
requirements.txt CHANGED
@@ -1 +1,6 @@
- git
+ torch
+ transformers
+ Pillow
+ gradio
+ spaces
+ multiprocess