luodian commited on
Commit
e76d30d
·
verified ·
1 Parent(s): af7e468

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +255 -103
app.py CHANGED
@@ -1,25 +1,35 @@
1
-
2
  # from .demo_modelpart import InferenceDemo
3
  import gradio as gr
4
  import os
 
5
  # import time
6
  import cv2
7
 
8
 
9
  # import copy
10
  import torch
11
- import spaces
 
12
  import numpy as np
13
 
14
  from llava import conversation as conversation_lib
15
  from llava.constants import DEFAULT_IMAGE_TOKEN
16
 
17
 
18
- from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
 
 
 
 
 
19
  from llava.conversation import conv_templates, SeparatorStyle
20
  from llava.model.builder import load_pretrained_model
21
  from llava.utils import disable_torch_init
22
- from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
 
 
 
 
23
 
24
  from PIL import Image
25
 
@@ -28,12 +38,19 @@ from PIL import Image
28
  from io import BytesIO
29
  from transformers import TextStreamer
30
 
 
31
  class InferenceDemo(object):
32
- def __init__(self,args,model_path,tokenizer, model, image_processor, context_len) -> None:
 
 
33
  disable_torch_init()
34
 
35
-
36
- self.tokenizer, self.model, self.image_processor, self.context_len = tokenizer, model, image_processor, context_len
 
 
 
 
37
 
38
  if "llama-2" in model_name.lower():
39
  conv_mode = "llava_llama_2"
@@ -41,32 +58,36 @@ class InferenceDemo(object):
41
  conv_mode = "llava_v1"
42
  elif "mpt" in model_name.lower():
43
  conv_mode = "mpt"
44
- elif 'qwen' in model_name.lower():
45
  conv_mode = "qwen_1_5"
46
  else:
47
  conv_mode = "llava_v0"
48
 
49
  if args.conv_mode is not None and conv_mode != args.conv_mode:
50
- print("[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(conv_mode, args.conv_mode, args.conv_mode))
 
 
 
 
51
  else:
52
  args.conv_mode = conv_mode
53
- self.conv_mode=conv_mode
54
  self.conversation = conv_templates[args.conv_mode].copy()
55
  self.num_frames = args.num_frames
56
 
57
 
58
-
59
  def is_valid_video_filename(name):
60
- video_extensions = ['avi', 'mp4', 'mov', 'mkv', 'flv', 'wmv', 'mjpeg']
61
-
62
- ext = name.split('.')[-1].lower()
63
-
64
  if ext in video_extensions:
65
  return True
66
  else:
67
  return False
68
 
69
- def sample_frames(video_file, num_frames) :
 
70
  video = cv2.VideoCapture(video_file)
71
  total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
72
  interval = total_frames // num_frames
@@ -81,18 +102,19 @@ def sample_frames(video_file, num_frames) :
81
  video.release()
82
  return frames
83
 
 
84
  def load_image(image_file):
85
  if image_file.startswith("http") or image_file.startswith("https"):
86
  response = requests.get(image_file)
87
  if response.status_code == 200:
88
  image = Image.open(BytesIO(response.content)).convert("RGB")
89
  else:
90
- print('failed to load the image')
91
  else:
92
- print('Load image from local file')
93
  print(image_file)
94
  image = Image.open(image_file).convert("RGB")
95
-
96
  return image
97
 
98
 
@@ -101,6 +123,8 @@ def clear_history(history):
101
  our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
102
 
103
  return None
 
 
104
  def clear_response(history):
105
  for index_conv in range(1, len(history)):
106
  # loop until get a text response from our model.
@@ -111,165 +135,294 @@ def clear_response(history):
111
  history = history[:-index_conv]
112
  return history, question
113
 
114
- def print_like_dislike(x: gr.LikeData):
115
- print(x.index, x.value, x.liked)
116
 
 
 
117
 
118
 
119
  def add_message(history, message):
120
  # history=[]
121
  global our_chatbot
122
- if len(history)==0:
123
- our_chatbot = InferenceDemo(args,model_path,tokenizer, model, image_processor, context_len)
124
-
 
 
125
  for x in message["files"]:
126
  history.append(((x,), None))
127
  if message["text"] is not None:
128
  history.append((message["text"], None))
129
  return history, gr.MultimodalTextbox(value=None, interactive=False)
130
 
131
- @spaces.GPU
 
132
  def bot(history):
133
- text=history[-1][0]
134
- images_this_term=[]
135
- text_this_term=''
136
  # import pdb;pdb.set_trace()
137
  num_new_images = 0
138
- for i,message in enumerate(history[:-1]):
139
  if type(message[0]) is tuple:
140
  images_this_term.append(message[0][0])
141
  if is_valid_video_filename(message[0][0]):
142
- num_new_images+=our_chatbot.num_frames
143
  else:
144
- num_new_images+=1
145
  else:
146
- num_new_images=0
147
-
148
  # for message in history[-i-1:]:
149
  # images_this_term.append(message[0][0])
150
 
151
- assert len(images_this_term)>0, "must have an image"
152
  # image_files = (args.image_file).split(',')
153
  # image = [load_image(f) for f in images_this_term if f]
154
- image_list=[]
155
  for f in images_this_term:
156
  if is_valid_video_filename(f):
157
- image_list+=sample_frames(f, our_chatbot.num_frames)
158
  else:
159
  image_list.append(load_image(f))
160
- image_tensor = [our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][0].half().to(our_chatbot.model.device) for f in image_list]
 
 
 
 
 
 
 
161
 
162
  image_tensor = torch.stack(image_tensor)
163
- image_token = DEFAULT_IMAGE_TOKEN*num_new_images
164
  # if our_chatbot.model.config.mm_use_im_start_end:
165
  # inp = DEFAULT_IM_START_TOKEN + image_token + DEFAULT_IM_END_TOKEN + "\n" + inp
166
  # else:
167
- inp=text
168
- inp = image_token+ "\n" + inp
169
  our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
170
  # image = None
171
  our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
172
  prompt = our_chatbot.conversation.get_prompt()
173
 
174
- input_ids = tokenizer_image_token(prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(our_chatbot.model.device)
175
- stop_str = our_chatbot.conversation.sep if our_chatbot.conversation.sep_style != SeparatorStyle.TWO else our_chatbot.conversation.sep2
 
 
 
 
 
 
 
 
 
 
176
  keywords = [stop_str]
177
- stopping_criteria = KeywordsStoppingCriteria(keywords, our_chatbot.tokenizer, input_ids)
178
- streamer = TextStreamer(our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True)
 
 
 
 
179
  print(our_chatbot.model.device)
180
  print(input_ids.device)
181
  print(image_tensor.device)
182
  # import pdb;pdb.set_trace()
183
  with torch.inference_mode():
184
- output_ids = our_chatbot.model.generate(input_ids, images=image_tensor, do_sample=True, temperature=0.2, max_new_tokens=1024, streamer=streamer, use_cache=False, stopping_criteria=[stopping_criteria])
 
 
 
 
 
 
 
 
 
185
 
186
  outputs = our_chatbot.tokenizer.decode(output_ids[0]).strip()
187
  if outputs.endswith(stop_str):
188
- outputs = outputs[:-len(stop_str)]
189
  our_chatbot.conversation.messages[-1][-1] = outputs
190
-
191
- history[-1]=[text,outputs]
192
-
193
  return history
 
 
194
  txt = gr.Textbox(
195
  scale=4,
196
  show_label=False,
197
  placeholder="Enter text and press enter.",
198
  container=False,
199
  )
200
- with gr.Blocks() as demo:
201
-
 
 
 
202
  # Informations
203
- title_markdown = ("""
204
  # LLaVA-NeXT Interleave
205
  [[Blog]](https://llava-vl.github.io/blog/2024-06-16-llava-next-interleave/) [[Code]](https://github.com/LLaVA-VL/LLaVA-NeXT) [[Model]](https://huggingface.co/lmms-lab/llava-next-interleave-7b)
206
- """)
207
- tos_markdown = ("""
208
  ### TODO!. Terms of use
209
  By using this service, users are required to agree to the following terms:
210
  The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
211
  Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
212
  For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
213
- """)
214
- learn_more_markdown = ("""
215
  ### TODO!. License
216
  The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
217
- """)
218
  models = [
219
  "LLaVA-Interleave-7B",
220
  ]
221
  cur_dir = os.path.dirname(os.path.abspath(__file__))
222
  gr.Markdown(title_markdown)
223
- with gr.Row():
224
- with gr.Column(scale=1):
225
- with gr.Row():
226
- chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image","video"], placeholder="Enter message or upload file...", show_label=False)
227
- with gr.Row():
228
- gr.Examples(examples=[
229
- [{"files": [f"{cur_dir}/examples/code1.jpeg",f"{cur_dir}/examples/code2.jpeg"], "text": "Please pay attention to the movement of the object from the first image to the second image, then write a HTML code to show this movement."}],
230
- [{"files": [f"{cur_dir}/examples/shub.jpg",f"{cur_dir}/examples/shuc.jpg",f"{cur_dir}/examples/shud.jpg"], "text": "what is fun about the images?"}],
231
- [{"files": [f"{cur_dir}/examples/iphone-15-price-1024x576.jpg",f"{cur_dir}/examples/dynamic-island-1024x576.jpg",f"{cur_dir}/examples/iphone-15-colors-1024x576.jpg",f"{cur_dir}/examples/Iphone-15-Usb-c-charger-1024x576.jpg",f"{cur_dir}/examples/A-17-processors-1024x576.jpg"], "text": "The images are the PPT of iPhone 15 review. can you summarize the main information?"}],
232
- [{"files": [f"{cur_dir}/examples/fangao3.jpeg",f"{cur_dir}/examples/fangao2.jpeg",f"{cur_dir}/examples/fangao1.jpeg"], "text": "Do you kown who draw these paintings?"}],
233
- [{"files": [f"{cur_dir}/examples/oprah-winfrey-resume.png",f"{cur_dir}/examples/steve-jobs-resume.jpg"], "text": "Hi, there are two candidates, can you provide a brief description for each of them for me?"}],
234
- [{"files": [f"{cur_dir}/examples/original_bench.jpeg",f"{cur_dir}/examples/changed_bench.jpeg"], "text": "How to edit image1 to make it look like image2?"}],
235
- [{"files": [f"{cur_dir}/examples/twitter2.jpeg",f"{cur_dir}/examples/twitter3.jpeg",f"{cur_dir}/examples/twitter4.jpeg"], "text": "Please write a twitter blog post with the images."}],
236
- # [{"files": [f"{cur_dir}/examples/twitter3.jpeg",f"{cur_dir}/examples/twitter4.jpeg"], "text": "Please write a twitter blog post with the images."}],
237
- # [{"files": [f"playground/demo/examples/lion1_.mp4",f"playground/demo/examples/lion2_.mp4"], "text": "The input contains two videos, the first half is the first video and the second half is the second video. What is the difference between the two videos?"}],
238
-
239
- ], inputs=[chat_input], label="Compare images: ",examples_per_page=3)
240
- with gr.Column(scale=2):
241
- with gr.Row():
242
- chatbot = gr.Chatbot(
243
- [],
244
- elem_id="chatbot",
245
- bubble_full_width=False
246
- )
247
- with gr.Row():
248
- upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
249
- downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
250
- flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
251
- #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=True)
252
- regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
253
- clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
254
- chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name="bot_response")
256
  bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
257
 
258
- chatbot.like(print_like_dislike, None, None)
259
- clear_btn.click(fn=clear_history, inputs=[chatbot], outputs=[chatbot], api_name="clear_all")
260
-
261
-
262
-
263
-
264
 
265
 
266
  demo.queue()
267
  if __name__ == "__main__":
268
  import argparse
 
269
  argparser = argparse.ArgumentParser()
270
  argparser.add_argument("--server_name", default="0.0.0.0", type=str)
271
  argparser.add_argument("--port", default="6123", type=str)
272
- argparser.add_argument("--model_path", default="lmms-lab/llava-next-interleave-qwen-7b", type=str)
 
 
273
  # argparser.add_argument("--model-path", type=str, default="facebook/opt-350m")
274
  argparser.add_argument("--model-base", type=str, default=None)
275
  argparser.add_argument("--num-gpus", type=int, default=1)
@@ -280,15 +433,14 @@ if __name__ == "__main__":
280
  argparser.add_argument("--load-8bit", action="store_true")
281
  argparser.add_argument("--load-4bit", action="store_true")
282
  argparser.add_argument("--debug", action="store_true")
283
-
284
  args = argparser.parse_args()
285
  model_path = args.model_path
286
- filt_invalid="cut"
287
- model_name = get_model_name_from_path(args.model_path)
288
- tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
289
- model=model.to(torch.device('cuda'))
290
  our_chatbot = None
291
  # import pdb;pdb.set_trace()
292
  # try:
293
- demo.launch()
294
-
 
 
1
  # from .demo_modelpart import InferenceDemo
2
  import gradio as gr
3
  import os
4
+
5
  # import time
6
  import cv2
7
 
8
 
9
  # import copy
10
  import torch
11
+
12
+ # import spaces
13
  import numpy as np
14
 
15
  from llava import conversation as conversation_lib
16
  from llava.constants import DEFAULT_IMAGE_TOKEN
17
 
18
 
19
+ from llava.constants import (
20
+ IMAGE_TOKEN_INDEX,
21
+ DEFAULT_IMAGE_TOKEN,
22
+ DEFAULT_IM_START_TOKEN,
23
+ DEFAULT_IM_END_TOKEN,
24
+ )
25
  from llava.conversation import conv_templates, SeparatorStyle
26
  from llava.model.builder import load_pretrained_model
27
  from llava.utils import disable_torch_init
28
+ from llava.mm_utils import (
29
+ tokenizer_image_token,
30
+ get_model_name_from_path,
31
+ KeywordsStoppingCriteria,
32
+ )
33
 
34
  from PIL import Image
35
 
 
38
  from io import BytesIO
39
  from transformers import TextStreamer
40
 
41
+
42
  class InferenceDemo(object):
43
+ def __init__(
44
+ self, args, model_path, tokenizer, model, image_processor, context_len
45
+ ) -> None:
46
  disable_torch_init()
47
 
48
+ self.tokenizer, self.model, self.image_processor, self.context_len = (
49
+ tokenizer,
50
+ model,
51
+ image_processor,
52
+ context_len,
53
+ )
54
 
55
  if "llama-2" in model_name.lower():
56
  conv_mode = "llava_llama_2"
 
58
  conv_mode = "llava_v1"
59
  elif "mpt" in model_name.lower():
60
  conv_mode = "mpt"
61
+ elif "qwen" in model_name.lower():
62
  conv_mode = "qwen_1_5"
63
  else:
64
  conv_mode = "llava_v0"
65
 
66
  if args.conv_mode is not None and conv_mode != args.conv_mode:
67
+ print(
68
+ "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
69
+ conv_mode, args.conv_mode, args.conv_mode
70
+ )
71
+ )
72
  else:
73
  args.conv_mode = conv_mode
74
+ self.conv_mode = conv_mode
75
  self.conversation = conv_templates[args.conv_mode].copy()
76
  self.num_frames = args.num_frames
77
 
78
 
 
79
  def is_valid_video_filename(name):
80
+ video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
81
+
82
+ ext = name.split(".")[-1].lower()
83
+
84
  if ext in video_extensions:
85
  return True
86
  else:
87
  return False
88
 
89
+
90
+ def sample_frames(video_file, num_frames):
91
  video = cv2.VideoCapture(video_file)
92
  total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
93
  interval = total_frames // num_frames
 
102
  video.release()
103
  return frames
104
 
105
+
106
  def load_image(image_file):
107
  if image_file.startswith("http") or image_file.startswith("https"):
108
  response = requests.get(image_file)
109
  if response.status_code == 200:
110
  image = Image.open(BytesIO(response.content)).convert("RGB")
111
  else:
112
+ print("failed to load the image")
113
  else:
114
+ print("Load image from local file")
115
  print(image_file)
116
  image = Image.open(image_file).convert("RGB")
117
+
118
  return image
119
 
120
 
 
123
  our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
124
 
125
  return None
126
+
127
+
128
  def clear_response(history):
129
  for index_conv in range(1, len(history)):
130
  # loop until get a text response from our model.
 
135
  history = history[:-index_conv]
136
  return history, question
137
 
 
 
138
 
139
+ # def print_like_dislike(x: gr.LikeData):
140
+ # print(x.index, x.value, x.liked)
141
 
142
 
143
  def add_message(history, message):
144
  # history=[]
145
  global our_chatbot
146
+ if len(history) == 0:
147
+ our_chatbot = InferenceDemo(
148
+ args, model_path, tokenizer, model, image_processor, context_len
149
+ )
150
+
151
  for x in message["files"]:
152
  history.append(((x,), None))
153
  if message["text"] is not None:
154
  history.append((message["text"], None))
155
  return history, gr.MultimodalTextbox(value=None, interactive=False)
156
 
157
+
158
+ # @spaces.GPU
159
  def bot(history):
160
+ text = history[-1][0]
161
+ images_this_term = []
162
+ text_this_term = ""
163
  # import pdb;pdb.set_trace()
164
  num_new_images = 0
165
+ for i, message in enumerate(history[:-1]):
166
  if type(message[0]) is tuple:
167
  images_this_term.append(message[0][0])
168
  if is_valid_video_filename(message[0][0]):
169
+ num_new_images += our_chatbot.num_frames
170
  else:
171
+ num_new_images += 1
172
  else:
173
+ num_new_images = 0
174
+
175
  # for message in history[-i-1:]:
176
  # images_this_term.append(message[0][0])
177
 
178
+ assert len(images_this_term) > 0, "must have an image"
179
  # image_files = (args.image_file).split(',')
180
  # image = [load_image(f) for f in images_this_term if f]
181
+ image_list = []
182
  for f in images_this_term:
183
  if is_valid_video_filename(f):
184
+ image_list += sample_frames(f, our_chatbot.num_frames)
185
  else:
186
  image_list.append(load_image(f))
187
+ image_tensor = [
188
+ our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][
189
+ 0
190
+ ]
191
+ .half()
192
+ .to(our_chatbot.model.device)
193
+ for f in image_list
194
+ ]
195
 
196
  image_tensor = torch.stack(image_tensor)
197
+ image_token = DEFAULT_IMAGE_TOKEN * num_new_images
198
  # if our_chatbot.model.config.mm_use_im_start_end:
199
  # inp = DEFAULT_IM_START_TOKEN + image_token + DEFAULT_IM_END_TOKEN + "\n" + inp
200
  # else:
201
+ inp = text
202
+ inp = image_token + "\n" + inp
203
  our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
204
  # image = None
205
  our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
206
  prompt = our_chatbot.conversation.get_prompt()
207
 
208
+ input_ids = (
209
+ tokenizer_image_token(
210
+ prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
211
+ )
212
+ .unsqueeze(0)
213
+ .to(our_chatbot.model.device)
214
+ )
215
+ stop_str = (
216
+ our_chatbot.conversation.sep
217
+ if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
218
+ else our_chatbot.conversation.sep2
219
+ )
220
  keywords = [stop_str]
221
+ stopping_criteria = KeywordsStoppingCriteria(
222
+ keywords, our_chatbot.tokenizer, input_ids
223
+ )
224
+ streamer = TextStreamer(
225
+ our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
226
+ )
227
  print(our_chatbot.model.device)
228
  print(input_ids.device)
229
  print(image_tensor.device)
230
  # import pdb;pdb.set_trace()
231
  with torch.inference_mode():
232
+ output_ids = our_chatbot.model.generate(
233
+ input_ids,
234
+ images=image_tensor,
235
+ do_sample=True,
236
+ temperature=0.2,
237
+ max_new_tokens=1024,
238
+ streamer=streamer,
239
+ use_cache=False,
240
+ stopping_criteria=[stopping_criteria],
241
+ )
242
 
243
  outputs = our_chatbot.tokenizer.decode(output_ids[0]).strip()
244
  if outputs.endswith(stop_str):
245
+ outputs = outputs[: -len(stop_str)]
246
  our_chatbot.conversation.messages[-1][-1] = outputs
247
+
248
+ history[-1] = [text, outputs]
249
+
250
  return history
251
+
252
+
253
  txt = gr.Textbox(
254
  scale=4,
255
  show_label=False,
256
  placeholder="Enter text and press enter.",
257
  container=False,
258
  )
259
+
260
+ with gr.Blocks(
261
+ css=".message-wrap.svelte-1lcyrx4>div.svelte-1lcyrx4 img {min-width: 40px}",
262
+ ) as demo:
263
+
264
  # Informations
265
+ title_markdown = """
266
  # LLaVA-NeXT Interleave
267
  [[Blog]](https://llava-vl.github.io/blog/2024-06-16-llava-next-interleave/) [[Code]](https://github.com/LLaVA-VL/LLaVA-NeXT) [[Model]](https://huggingface.co/lmms-lab/llava-next-interleave-7b)
268
+ """
269
+ tos_markdown = """
270
  ### TODO!. Terms of use
271
  By using this service, users are required to agree to the following terms:
272
  The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
273
  Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
274
  For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
275
+ """
276
+ learn_more_markdown = """
277
  ### TODO!. License
278
  The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
279
+ """
280
  models = [
281
  "LLaVA-Interleave-7B",
282
  ]
283
  cur_dir = os.path.dirname(os.path.abspath(__file__))
284
  gr.Markdown(title_markdown)
285
+ with gr.Column():
286
+ with gr.Row():
287
+ chatbot = gr.Chatbot([], elem_id="chatbot", bubble_full_width=False)
288
+
289
+ with gr.Row():
290
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
291
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
292
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
293
+ # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=True)
294
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
295
+ clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
296
+
297
+ chat_input = gr.MultimodalTextbox(
298
+ interactive=True,
299
+ file_types=["image", "video"],
300
+ placeholder="Enter message or upload file...",
301
+ show_label=False,
302
+ )
303
+
304
+ print(cur_dir)
305
+ gr.Examples(
306
+ examples=[
307
+ # [
308
+ # {
309
+ # "text": "<image> <image> <image> Which image shows a different mood of character from the others?",
310
+ # "files": [f"{cur_dir}/examples/examples_image12.jpg", f"{cur_dir}/examples/examples_image13.jpg", f"{cur_dir}/examples/examples_image14.jpg"]
311
+ # },
312
+ # {
313
+ # "text": "Please pay attention to the movement of the object from the first image to the second image, then write a HTML code to show this movement.",
314
+ # "files": [
315
+ # f"{cur_dir}/examples/code1.jpeg",
316
+ # f"{cur_dir}/examples/code2.jpeg",
317
+ # ],
318
+ # }
319
+ # ],
320
+ [
321
+ {
322
+ "files": [
323
+ f"{cur_dir}/examples/shub.jpg",
324
+ f"{cur_dir}/examples/shuc.jpg",
325
+ f"{cur_dir}/examples/shud.jpg",
326
+ ],
327
+ "text": "what is fun about the images?",
328
+ }
329
+ ],
330
+ [
331
+ {
332
+ "files": [
333
+ f"{cur_dir}/examples/iphone-15-price-1024x576.jpg",
334
+ f"{cur_dir}/examples/dynamic-island-1024x576.jpg",
335
+ f"{cur_dir}/examples/iphone-15-colors-1024x576.jpg",
336
+ f"{cur_dir}/examples/Iphone-15-Usb-c-charger-1024x576.jpg",
337
+ f"{cur_dir}/examples/A-17-processors-1024x576.jpg",
338
+ ],
339
+ "text": "The images are the PPT of iPhone 15 review. can you summarize the main information?",
340
+ }
341
+ ],
342
+ [
343
+ {
344
+ "files": [
345
+ f"{cur_dir}/examples/fangao3.jpeg",
346
+ f"{cur_dir}/examples/fangao2.jpeg",
347
+ f"{cur_dir}/examples/fangao1.jpeg",
348
+ ],
349
+ "text": "Do you kown who draw these paintings?",
350
+ }
351
+ ],
352
+ [
353
+ {
354
+ "files": [
355
+ f"{cur_dir}/examples/oprah-winfrey-resume.png",
356
+ f"{cur_dir}/examples/steve-jobs-resume.jpg",
357
+ ],
358
+ "text": "Hi, there are two candidates, can you provide a brief description for each of them for me?",
359
+ }
360
+ ],
361
+ [
362
+ {
363
+ "files": [
364
+ f"{cur_dir}/examples/original_bench.jpeg",
365
+ f"{cur_dir}/examples/changed_bench.jpeg",
366
+ ],
367
+ "text": "How to edit image1 to make it look like image2?",
368
+ }
369
+ ],
370
+ [
371
+ {
372
+ "files": [
373
+ f"{cur_dir}/examples/twitter2.jpeg",
374
+ f"{cur_dir}/examples/twitter3.jpeg",
375
+ f"{cur_dir}/examples/twitter4.jpeg",
376
+ ],
377
+ "text": "Please write a twitter blog post with the images.",
378
+ }
379
+ ],
380
+ [
381
+ {
382
+ "files": [
383
+ f"{cur_dir}/examples/twitter3.jpeg",
384
+ f"{cur_dir}/examples/twitter4.jpeg",
385
+ ],
386
+ "text": "Please write a twitter blog post with the images.",
387
+ }
388
+ ],
389
+ [
390
+ {
391
+ "files": [
392
+ f"playground/demo/examples/lion1_.mp4",
393
+ f"playground/demo/examples/lion2_.mp4",
394
+ ],
395
+ "text": "The input contains two videos, the first half is the first video and the second half is the second video. What is the difference between the two videos?",
396
+ }
397
+ ],
398
+ ],
399
+ inputs=[chat_input],
400
+ label="Compare images: ",
401
+ examples_per_page=3,
402
+ )
403
+
404
+ chat_msg = chat_input.submit(
405
+ add_message, [chatbot, chat_input], [chatbot, chat_input]
406
+ )
407
  bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name="bot_response")
408
  bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
409
 
410
+ # chatbot.like(print_like_dislike, None, None)
411
+ clear_btn.click(
412
+ fn=clear_history, inputs=[chatbot], outputs=[chatbot], api_name="clear_all"
413
+ )
 
 
414
 
415
 
416
  demo.queue()
417
  if __name__ == "__main__":
418
  import argparse
419
+
420
  argparser = argparse.ArgumentParser()
421
  argparser.add_argument("--server_name", default="0.0.0.0", type=str)
422
  argparser.add_argument("--port", default="6123", type=str)
423
+ argparser.add_argument(
424
+ "--model_path", default="lmms-lab/llava-next-interleave-qwen-7b", type=str
425
+ )
426
  # argparser.add_argument("--model-path", type=str, default="facebook/opt-350m")
427
  argparser.add_argument("--model-base", type=str, default=None)
428
  argparser.add_argument("--num-gpus", type=int, default=1)
 
433
  argparser.add_argument("--load-8bit", action="store_true")
434
  argparser.add_argument("--load-4bit", action="store_true")
435
  argparser.add_argument("--debug", action="store_true")
436
+
437
  args = argparser.parse_args()
438
  model_path = args.model_path
439
+ filt_invalid = "cut"
440
+ # model_name = get_model_name_from_path(args.model_path)
441
+ # tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
442
+ # model=model.to(torch.device('cuda'))
443
  our_chatbot = None
444
  # import pdb;pdb.set_trace()
445
  # try:
446
+ demo.launch()