merve (HF staff) committed
Commit 99ad72b
1 Parent(s): 9bbf94b

Update app.py

Files changed (1)
  1. app.py +58 -26
app.py CHANGED
@@ -6,15 +6,20 @@ import time
 from PIL import Image
 import torch
 import cv2
-import spaces
-model_id = "llava-hf/llava-interleave-qwen-7b-hf"
+import spaces
+
+model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
 
 processor = LlavaProcessor.from_pretrained(model_id)
 
 model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16)
 model.to("cuda")
 
-def sample_frames(video_file, num_frames) :
+
+def replace_video_with_images(text, frames):
+    return text.replace("<video>", "<image>" * frames)
+
+def sample_frames(video_file, num_frames):
     video = cv2.VideoCapture(video_file)
     total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
     interval = total_frames // num_frames
@@ -31,9 +36,16 @@ def sample_frames(video_file, num_frames) :
 
 @spaces.GPU
 def bot_streaming(message, history):
-    if message["files"]:
-        image = message["files"][-1]
-
+
+    txt = message.text
+    ext_buffer = f"user\n{txt} assistant"
+
+    if message.files:
+        if len(message.files) == 1:
+            image = [message.files[0].path]
+        # interleaved images or video
+        elif len(message.files) > 1:
+            image = [msg.path for msg in message.files]
     else:
         # if there's no image uploaded for this turn, look for images in the past turns
         # kept inside tuples, take the last one
@@ -41,28 +53,44 @@ def bot_streaming(message, history):
             if type(hist[0])==tuple:
                 image = hist[0][0]
 
-    txt = message["text"]
-    img = message["files"]
-    ext_buffer =f"'user\ntext': '{txt}', 'files': '{img}' assistant"
-
-    if image is None:
+    if message.files is None:
         gr.Error("You need to upload an image or video for LLaVA to work.")
 
     video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg")
     image_extensions = Image.registered_extensions()
    image_extensions = tuple([ex for ex, f in image_extensions.items()])
-
-    if image.endswith(video_extensions):
-        image = sample_frames(image, 12)
-        image_tokens = "<image>" * 13
-        prompt = f"<|im_start|>user {image_tokens}\n{message}<|im_end|><|im_start|>assistant"
+    if len(image) == 1:
+        if image[0].endswith(video_extensions):
+
+            image = sample_frames(image[0], 12)
+            image_tokens = "<image>" * 13
+            prompt = f"<|im_start|>user {image_tokens}\n{message.text}<|im_end|><|im_start|>assistant"
+        elif image[0].endswith(image_extensions):
+            image = Image.open(image[0]).convert("RGB")
+            prompt = f"<|im_start|>user <image>\n{message.text}<|im_end|><|im_start|>assistant"
+
+    elif len(image) > 1:
+        image_list = []
+        user_prompt = message.text
+
+        for img in image:
+            if img.endswith(image_extensions):
+                img = Image.open(img).convert("RGB")
+                image_list.append(img)
+
+            elif img.endswith(video_extensions):
+                frames = sample_frames(img, 6)
+                for frame in frames:
+                    image_list.append(frame)
 
-    elif image.endswith(image_extensions):
-        image = Image.open(image).convert("RGB")
-        prompt = f"<|im_start|>user <image>\n{message}<|im_end|><|im_start|>assistant"
+        toks = "<image>" * len(image_list)
+        prompt = "<|im_start|>user"+ toks + f"\n{user_prompt}<|im_end|><|im_start|>assistant"
+
+        image = image_list
+
 
     inputs = processor(prompt, image, return_tensors="pt").to("cuda", torch.float16)
-    streamer = TextIteratorStreamer(processor, **{"skip_special_tokens": True})
+    streamer = TextIteratorStreamer(processor, **{"max_new_tokens": 200, "skip_special_tokens": True})
     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=100)
     generated_text = ""
 
@@ -75,15 +103,19 @@ def bot_streaming(message, history):
     for new_text in streamer:
 
         buffer += new_text
-        print(buffer)
+
         generated_text_without_prompt = buffer[len(ext_buffer):]
         time.sleep(0.01)
         yield generated_text_without_prompt
 
 
-demo = gr.ChatInterface(fn=bot_streaming, title="LLaVA Interleave", examples=[{"text": "What is on the flower?", "files":["./bee.jpg"]},
-        {"text": "How to make this pastry?", "files":["./baklava.png"]},
-        {"text": "What type of cats are these?", "files":["./cats.mp4"]}],
-        description="Try [LLaVA Interleave](https://huggingface.co/docs/transformers/main/en/model_doc/llava) in this demo (more specifically, the [Qwen-1.5-7B variant](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf)). Upload an image or a video, and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
-        stop_btn="Stop Generation", multimodal=True)
+demo = gr.ChatInterface(fn=bot_streaming, title="LLaVA Interleave", examples=[
+    {"text": "What are these cats doing?", "files":["./cats.mp4"]},
+    {"text": "The input contains two videos, are the cats in this video and this video doing the same thing?", "files":["./cats_1.mp4", "./cats_2.mp4"]},
+    {"text": "What is on the flower?", "files":["./bee.jpg"]},
+    {"text": "There are two images in the input. What is the relationship between this image and this image?", "files":["./bee.jpg", "./depth-bee.png"]},
+    {"text": "How to make this pastry?", "files":["./baklava.png"]}],
+    textbox=gr.MultimodalTextbox(file_count="multiple"),
+    description="Try [LLaVA Interleave](https://huggingface.co/docs/transformers/main/en/model_doc/llava) in this demo (more specifically, the [Qwen-1.5-7B variant](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf)). Upload an image or a video, and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error. ",
+    stop_btn="Stop Generation", multimodal=True)
 demo.launch(debug=True)
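The hunks above only show the first three lines of sample_frames; the rest of the function is unchanged by this commit and therefore not included in the diff. A minimal sketch of what that frame-sampling loop plausibly does (read every interval-th frame and convert it from OpenCV's BGR arrays to RGB PIL images), written here as an assumption rather than taken from the commit:

    import cv2
    from PIL import Image

    def sample_frames(video_file, num_frames):
        # open the video and keep roughly evenly spaced frames
        video = cv2.VideoCapture(video_file)
        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        interval = total_frames // num_frames
        frames = []
        for i in range(total_frames):
            ret, frame = video.read()
            if not ret:
                continue
            if i % interval == 0:
                # OpenCV returns BGR arrays; the processor expects RGB PIL images
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        video.release()
        return frames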
 
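The prompt strings built by the new code can be traced directly from the added lines. The replace_video_with_images helper added in the first hunk has no call site visible in these hunks, but its effect follows from its one-line body; the two-image prompt below is a hypothetical input, not an example from the commit:

    # the new helper expands a <video> placeholder into one <image> token per sampled frame
    def replace_video_with_images(text, frames):
        return text.replace("<video>", "<image>" * frames)

    print(replace_video_with_images("<video>\nWhat are these cats doing?", 3))
    # <image><image><image>
    # What are these cats doing?

    # the multi-file branch builds its prompt the same way; with a hypothetical two-image input:
    user_prompt = "What is the relationship between this image and this image?"
    toks = "<image>" * 2
    prompt = "<|im_start|>user"+ toks + f"\n{user_prompt}<|im_end|><|im_start|>assistant"
    print(prompt)
    # <|im_start|>user<image><image>
    # What is the relationship between this image and this image?<|im_end|><|im_start|>assistant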
 
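The unchanged lines between the last two hunks, where generation is actually launched, are not shown in this diff. With TextIteratorStreamer the usual pattern (assumed here, reusing the model, streamer, and generation_kwargs names from bot_streaming above) is to run model.generate on a background thread so the for new_text in streamer loop can yield partial text as tokens arrive:

    from threading import Thread

    # start generation in the background; the streamer receives decoded text chunks as they are produced
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text  # accumulate the partial output as it streams in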
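To try the new llava-hf/llava-interleave-qwen-0.5b-hf checkpoint outside the Space, a minimal single-image sketch following the same prompt format the app uses (the image path is a stand-in, and keyword arguments are used so the snippet does not depend on the processor's positional argument order):

    from transformers import LlavaProcessor, LlavaForConditionalGeneration
    from PIL import Image
    import torch

    model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
    processor = LlavaProcessor.from_pretrained(model_id)
    model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

    # one <image> token per image, in the Qwen-style chat template the app uses
    image = Image.open("bee.jpg").convert("RGB")
    prompt = "<|im_start|>user <image>\nWhat is on the flower?<|im_end|><|im_start|>assistant"

    inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda", torch.float16)
    output = model.generate(**inputs, max_new_tokens=100)
    print(processor.decode(output[0], skip_special_tokens=True))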