wenhuchen commited on
Commit
bfc56dc
·
1 Parent(s): dfaecf8

update dem

Browse files
Files changed (1) hide show
  1. app.py +67 -45
app.py CHANGED
@@ -3,6 +3,7 @@ import spaces
3
  import os
4
  import time
5
  from PIL import Image
 
6
  from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava, MLlavaForConditionalGeneration
7
  from typing import List
8
  processor = MLlavaProcessor.from_pretrained("TIGER-Lab/Mantis-llava-7b-v1.1")
@@ -54,45 +55,7 @@ def get_chat_images(history):
54
  if isinstance(message[0], tuple):
55
  images.extend(message[0])
56
  return images
57
-
58
- def bot(history):
59
- print(history)
60
- cur_messages = {"text": "", "images": []}
61
- for message in history[::-1]:
62
- if message[1]:
63
- break
64
- if isinstance(message[0], str):
65
- cur_messages["text"] = message[0] + " " + cur_messages["text"]
66
- elif isinstance(message[0], tuple):
67
- cur_messages["images"].extend(message[0])
68
- cur_messages["text"] = cur_messages["text"].strip()
69
- cur_messages["images"] = cur_messages["images"][::-1]
70
- if not cur_messages["text"]:
71
- raise gr.Error("Please enter a message")
72
- if cur_messages['text'].count("<image>") < len(cur_messages['images']):
73
- gr.Warning("The number of images uploaded is more than the number of <image> placeholders in the text. Will automatically prepend <image> to the text.")
74
- cur_messages['text'] = "<image> "* (len(cur_messages['images']) - cur_messages['text'].count("<image>")) + cur_messages['text']
75
- history[-1][0] = cur_messages["text"]
76
- if cur_messages['text'].count("<image>") > len(cur_messages['images']):
77
- gr.Warning("The number of images uploaded is less than the number of <image> placeholders in the text. Will automatically remove extra <image> placeholders from the text.")
78
- cur_messages['text'] = cur_messages['text'][::-1].replace("<image>"[::-1], "", cur_messages['text'].count("<image>") - len(cur_messages['images']))[::-1]
79
- history[-1][0] = cur_messages["text"]
80
-
81
- chat_history = get_chat_history(history)
82
- chat_images = get_chat_images(history)
83
- generation_kwargs = {
84
- "max_new_tokens": 4096,
85
- "temperature": 0.7,
86
- "top_p": 1.0,
87
- "do_sample": True,
88
- }
89
- print(None, chat_images, chat_history, generation_kwargs)
90
- response = generate(None, chat_images, chat_history, **generation_kwargs)
91
-
92
- for _output in response:
93
- history[-1][1] = _output
94
- time.sleep(0.05)
95
- yield history
96
 
97
  def build_demo():
98
  with gr.Blocks() as demo:
@@ -118,14 +81,73 @@ def build_demo():
118
  chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload images. Please use <image> to indicate the position of uploaded images", show_label=True)
119
 
120
  chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  bot_msg = chat_msg.success(bot, chatbot, chatbot, api_name="bot_response")
122
 
123
  chatbot.like(print_like_dislike, None, None)
124
-
125
  with gr.Row():
126
  send_button = gr.Button("Send")
127
  clear_button = gr.ClearButton([chatbot, chat_input])
128
-
129
  send_button.click(
130
  add_message, [chatbot, chat_input], [chatbot, chat_input]
131
  ).then(
@@ -134,6 +156,10 @@ def build_demo():
134
 
135
  gr.Examples(
136
  examples=[
 
 
 
 
137
  {
138
  "text": "<image> <image> <image> Which image shows a different mood of character from the others?",
139
  "files": ["./examples/image12.jpg", "./examples/image13.jpg", "./examples/image14.jpg"]
@@ -142,10 +168,6 @@ def build_demo():
142
  "text": "<image> <image> What's the difference between these two images? Please describe as much as you can.",
143
  "files": ["./examples/image1.jpg", "./examples/image2.jpg"]
144
  },
145
- {
146
- "text": "<image> <image> How many dices are there in image 1 and image 2 respectively?",
147
- "files": ["./examples/image10.jpg", "./examples/image15.jpg"]
148
- },
149
  {
150
  "text": "<image> <image> Which image shows an older dog?",
151
  "files": ["./examples/image8.jpg", "./examples/image9.jpg"]
 
3
  import os
4
  import time
5
  from PIL import Image
6
+ import functools
7
  from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava, MLlavaForConditionalGeneration
8
  from typing import List
9
  processor = MLlavaProcessor.from_pretrained("TIGER-Lab/Mantis-llava-7b-v1.1")
 
55
  if isinstance(message[0], tuple):
56
  images.extend(message[0])
57
  return images
58
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  def build_demo():
61
  with gr.Blocks() as demo:
 
81
  chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload images. Please use <image> to indicate the position of uploaded images", show_label=True)
82
 
83
  chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
84
+
85
+ with gr.Accordion(label='Advanced options', open=False):
86
+ temperature = gr.Slider(
87
+ label='Temperature',
88
+ minimum=0.1,
89
+ maximum=2.0,
90
+ step=0.1,
91
+ value=0.2,
92
+ interactive=True
93
+ )
94
+ top_p = gr.Slider(
95
+ label='Top-p',
96
+ minimum=0.05,
97
+ maximum=1.0,
98
+ step=0.05,
99
+ value=1.0,
100
+ interactive=True
101
+ )
102
+
103
+ def bot(history):
104
+ print(history)
105
+ cur_messages = {"text": "", "images": []}
106
+ for message in history[::-1]:
107
+ if message[1]:
108
+ break
109
+ if isinstance(message[0], str):
110
+ cur_messages["text"] = message[0] + " " + cur_messages["text"]
111
+ elif isinstance(message[0], tuple):
112
+ cur_messages["images"].extend(message[0])
113
+ cur_messages["text"] = cur_messages["text"].strip()
114
+ cur_messages["images"] = cur_messages["images"][::-1]
115
+ if not cur_messages["text"]:
116
+ raise gr.Error("Please enter a message")
117
+ if cur_messages['text'].count("<image>") < len(cur_messages['images']):
118
+ gr.Warning("The number of images uploaded is more than the number of <image> placeholders in the text. Will automatically prepend <image> to the text.")
119
+ cur_messages['text'] = "<image> "* (len(cur_messages['images']) - cur_messages['text'].count("<image>")) + cur_messages['text']
120
+ history[-1][0] = cur_messages["text"]
121
+ if cur_messages['text'].count("<image>") > len(cur_messages['images']):
122
+ gr.Warning("The number of images uploaded is less than the number of <image> placeholders in the text. Will automatically remove extra <image> placeholders from the text.")
123
+ cur_messages['text'] = cur_messages['text'][::-1].replace("<image>"[::-1], "", cur_messages['text'].count("<image>") - len(cur_messages['images']))[::-1]
124
+ history[-1][0] = cur_messages["text"]
125
+
126
+ chat_history = get_chat_history(history)
127
+ chat_images = get_chat_images(history)
128
+ generation_kwargs = {
129
+ "max_new_tokens": 4096,
130
+ "temperature": temperature,
131
+ "top_p": top_p,
132
+ "do_sample": True,
133
+ }
134
+ print(None, chat_images, chat_history, generation_kwargs)
135
+ response = generate(None, chat_images, chat_history, **generation_kwargs)
136
+
137
+ for _output in response:
138
+ history[-1][1] = _output
139
+ time.sleep(0.05)
140
+ yield history
141
+
142
+
143
  bot_msg = chat_msg.success(bot, chatbot, chatbot, api_name="bot_response")
144
 
145
  chatbot.like(print_like_dislike, None, None)
146
+
147
  with gr.Row():
148
  send_button = gr.Button("Send")
149
  clear_button = gr.ClearButton([chatbot, chat_input])
150
+
151
  send_button.click(
152
  add_message, [chatbot, chat_input], [chatbot, chat_input]
153
  ).then(
 
156
 
157
  gr.Examples(
158
  examples=[
159
+ {
160
+ "text": "<image> <image> How many dices are there in image 1 and image 2 respectively?",
161
+ "files": ["./examples/image10.jpg", "./examples/image15.jpg"]
162
+ },
163
  {
164
  "text": "<image> <image> <image> Which image shows a different mood of character from the others?",
165
  "files": ["./examples/image12.jpg", "./examples/image13.jpg", "./examples/image14.jpg"]
 
168
  "text": "<image> <image> What's the difference between these two images? Please describe as much as you can.",
169
  "files": ["./examples/image1.jpg", "./examples/image2.jpg"]
170
  },
 
 
 
 
171
  {
172
  "text": "<image> <image> Which image shows an older dog?",
173
  "files": ["./examples/image8.jpg", "./examples/image9.jpg"]