merve HF staff commited on
Commit
59e8091
·
verified ·
1 Parent(s): c053e1a

Upload 5 files

Browse files
Files changed (6) hide show
  1. .gitattributes +2 -0
  2. app.py +89 -0
  3. baklava.png +3 -0
  4. bee.jpg +3 -0
  5. cats.mp4 +0 -0
  6. requirements.txt +5 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ baklava.png filter=lfs diff=lfs merge=lfs -text
37
+ bee.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import LlavaProcessor, LlavaForConditionalGeneration, TextIteratorStreamer
3
+ from threading import Thread
4
+ import re
5
+ import time
6
+ from PIL import Image
7
+ import torch
8
+ import cv2
9
+ import spaces
10
+ model_id = "llava-hf/llava-interleave-qwen-7b-hf"
11
+
12
+ processor = LlavaProcessor.from_pretrained(model_id)
13
+
14
+ model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16)
15
+ model.to("cuda")
16
+
17
+ def sample_frames(video_file, num_frames) :
18
+ video = cv2.VideoCapture(video_file)
19
+ total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
20
+ interval = total_frames // num_frames
21
+ frames = []
22
+ for i in range(total_frames):
23
+ ret, frame = video.read()
24
+ pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
25
+ if not ret:
26
+ continue
27
+ if i % interval == 0:
28
+ frames.append(pil_img)
29
+ video.release()
30
+ return frames
31
+
32
+ @spaces.GPU
33
+ def bot_streaming(message, history):
34
+ if message["files"]:
35
+ image = message["files"][-1]
36
+
37
+ else:
38
+ # if there's no image uploaded for this turn, look for images in the past turns
39
+ # kept inside tuples, take the last one
40
+ for hist in history:
41
+ if type(hist[0])==tuple:
42
+ image = hist[0][0]
43
+
44
+ txt = message["text"]
45
+ img = message["files"]
46
+ ext_buffer =f"'user\ntext': '{txt}', 'files': '{img}' assistantAnswer:"
47
+
48
+ if image is None:
49
+ gr.Error("You need to upload an image or video for LLaVA to work.")
50
+
51
+ video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg")
52
+ image_extensions = Image.registered_extensions()
53
+ image_extensions = tuple([ex for ex, f in image_extensions.items()])
54
+
55
+ if image.endswith(video_extensions):
56
+ image = sample_frames(image, 5)
57
+ image_tokens = "<image>" * 5
58
+ prompt = f"<|im_start|>user {image_tokens}\n{message}<|im_end|><|im_start|>assistant"
59
+
60
+ elif image.endswith(image_extensions):
61
+ image = Image.open(image).convert("RGB")
62
+ prompt = f"<|im_start|>user <image>\n{message}<|im_end|><|im_start|>assistant"
63
+
64
+ inputs = processor(prompt, image, return_tensors="pt").to("cuda", torch.float16)
65
+ streamer = TextIteratorStreamer(processor, **{"skip_special_tokens": True})
66
+ generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=100)
67
+ generated_text = ""
68
+
69
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
70
+ thread.start()
71
+
72
+
73
+
74
+ buffer = ""
75
+ for new_text in streamer:
76
+
77
+ buffer += new_text
78
+ print(buffer)
79
+ generated_text_without_prompt = buffer[len(ext_buffer):]
80
+ time.sleep(0.01)
81
+ yield generated_text_without_prompt
82
+
83
+
84
+ demo = gr.ChatInterface(fn=bot_streaming, title="LLaVA Interleave", examples=[{"text": "What is on the flower?", "files":["./bee.jpg"]},
85
+ {"text": "How to make this pastry?", "files":["./baklava.png"]},
86
+ {"text": "What type of cats are these?", "files":["./cats.mp4"]}],
87
+ description="Try [LLaVA Interleave](https://huggingface.co/docs/transformers/main/en/model_doc/llava) in this demo (more specifically, the [Qwen-1.5-7B variant](https://huggingface.co/llava-hf/llava-interleave-qwen-7b-hf)). Upload an image or a video, and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
88
+ stop_btn="Stop Generation", multimodal=True)
89
+ demo.launch(debug=True)
baklava.png ADDED

Git LFS Details

  • SHA256: 7839e93dd753e5356176bf70d38c43bc56355099d8891ead7aaa342029369268
  • Pointer size: 132 Bytes
  • Size of remote file: 2.04 MB
bee.jpg ADDED

Git LFS Details

  • SHA256: 8b21ba78250f852ca5990063866b1ace6432521d0251bde7f8de783b22c99a6d
  • Pointer size: 132 Bytes
  • Size of remote file: 5.37 MB
cats.mp4 ADDED
Binary file (115 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ git+https://github.com/huggingface/transformers.git
3
+ spaces
4
+ opencv-python
5
+ accelerate