TuringsSolutions committed on
Commit
1f434b1
1 Parent(s): 4d0b6cf

Update app.py

Files changed (1)
  1. app.py +83 -68
app.py CHANGED
@@ -1,79 +1,94 @@
 import gradio as gr
-from huggingface_hub import InferenceClient
-from transformers import LlavaProcessor, LlavaForConditionalGeneration, TextIteratorStreamer
-from PIL import Image
-from threading import Thread
+import torch
+from llava.model.builder import load_pretrained_model
+from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from llava.conversation import conv_templates
+import copy
+from decord import VideoReader, cpu
+import numpy as np

-# Initialize model and processor
-model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
-processor = LlavaProcessor.from_pretrained(model_id)
-model = LlavaForConditionalGeneration.from_pretrained(model_id).to("cpu")
+# Load the model
+pretrained = "lmms-lab/LLaVA-Video-7B-Qwen2"
+model_name = "llava_qwen"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+device_map = "auto"

-# Initialize inference clients
-client_mistral = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
+print("Loading model...")
+tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, torch_dtype="bfloat16", device_map=device_map)
+model.eval()
+print("Model loaded successfully!")

-def llava(inputs):
-    """Processes an image and text input using Llava."""
-    try:
-        image = Image.open(inputs["files"][0]).convert("RGB")
-        prompt = f"<|im_start|>user <image>\n{inputs['text']}<|im_end|>"
-        processed = processor(prompt, image, return_tensors="pt").to("cpu")
-        return processed
-    except Exception as e:
-        print(f"Error in llava function: {e}")
-        return None
+def load_video(video_path, max_frames_num, fps=1, force_sample=False):
+    if max_frames_num == 0:
+        return np.zeros((1, 336, 336, 3))
+    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+    total_frame_num = len(vr)
+    video_time = total_frame_num / vr.get_avg_fps()
+    fps = round(vr.get_avg_fps()/fps)
+    frame_idx = [i for i in range(0, len(vr), fps)]
+    frame_time = [i/fps for i in frame_idx]
+    if len(frame_idx) > max_frames_num or force_sample:
+        sample_fps = max_frames_num
+        uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
+        frame_idx = uniform_sampled_frames.tolist()
+        frame_time = [i/vr.get_avg_fps() for i in frame_idx]
+    frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
+    spare_frames = vr.get_batch(frame_idx).asnumpy()
+    return spare_frames, frame_time, video_time

-def respond(message, history):
-    """Generate a response based on text or image input."""
-    try:
-        if "files" in message and message["files"]:
-            # Handle image + text input
-            inputs = llava(message)
-            if inputs is None:
-                raise ValueError("Failed to process image input")
-
-            streamer = TextIteratorStreamer(skip_prompt=True, skip_special_tokens=True)
-            thread = Thread(target=model.generate, kwargs=dict(inputs=inputs, max_new_tokens=512, streamer=streamer))
-            thread.start()
-
-            buffer = ""
-            for new_text in streamer:
-                buffer += new_text
-                history[-1][1] = buffer
-                yield history, history
-        else:
-            # Handle text-only input
-            user_message = message["text"]
-            history.append([user_message, None])
-            prompt = [{"role": "user", "content": msg[0]} for msg in history if msg[0]]
-            response = client_mistral.chat_completion(prompt, max_tokens=200)
-            bot_message = response["choices"][0]["message"]["content"]
-            history[-1][1] = bot_message
-            yield history, history
-    except Exception as e:
-        print(f"Error in respond function: {e}")
-        history[-1][1] = f"An error occurred: {str(e)}"
-        yield history, history
+def process_video(video_path, question):
+    max_frames_num = 64
+    video, frame_time, video_time = load_video(video_path, max_frames_num, 1, force_sample=True)
+    video = image_processor.preprocess(video, return_tensors="pt")["pixel_values"].to(device).bfloat16()
+    video = [video]
+
+    conv_template = "qwen_1_5"
+    time_instruction = f"The video lasts for {video_time:.2f} seconds, and {len(video[0])} frames are uniformly sampled from it. These frames are located at {frame_time}. Please answer the following questions related to this video."
+
+    full_question = DEFAULT_IMAGE_TOKEN + f"{time_instruction}\n{question}"
+
+    conv = copy.deepcopy(conv_templates[conv_template])
+    conv.append_message(conv.roles[0], full_question)
+    conv.append_message(conv.roles[1], None)
+    prompt_question = conv.get_prompt()
+
+    input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
+
+    with torch.no_grad():
+        output = model.generate(
+            input_ids,
+            images=video,
+            modalities=["video"],
+            do_sample=False,
+            temperature=0,
+            max_new_tokens=4096,
+        )
+
+    response = tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip()
+    return response
+
+def gradio_interface(video_file, question):
+    if video_file is None:
+        return "Please upload a video file."
+    response = process_video(video_file, question)
+    return response

 # Set up Gradio interface
 with gr.Blocks() as demo:
-    chatbot = gr.Chatbot()
+    gr.Markdown("# 🌋📹 LLaVA-Video Chatbot")
     with gr.Row():
         with gr.Column():
-            text_input = gr.Textbox(placeholder="Enter your message...")
-            file_input = gr.File(label="Upload an image")
-
-            def handle_text(text, history=[]):
-                """Handle text input and generate responses."""
-                return respond({"text": text}, history)
-
-            def handle_file_upload(files, history=[]):
-                """Handle file uploads and generate responses."""
-                return respond({"files": files, "text": "Describe this image."}, history)
-
-            # Connect components to callbacks
-            text_input.submit(handle_text, [text_input, chatbot], [chatbot, chatbot])
-            file_input.change(handle_file_upload, [file_input, chatbot], [chatbot, chatbot])
+            video_input = gr.Video()
+            question_input = gr.Textbox(label="User Question", placeholder="Ask a question about the video...")
+            submit_button = gr.Button("Ask LLaVA-Video")
+            output = gr.Textbox(label="LLaVA-Video Response")
+
+            submit_button.click(
+                fn=gradio_interface,
+                inputs=[video_input, question_input],
+                outputs=output
+            )

-# Launch the Gradio app
-demo.launch()
+if __name__ == "__main__":
+    demo.launch(show_error=True)
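
For reference, a minimal sketch of how the updated app could be queried remotely once the Space is running, using gradio_client. This is an assumption-laden example, not part of the commit: the Space id is a placeholder, the `/gradio_interface` endpoint name is assumed from the handler's function name (the click event above sets no explicit api_name), and a recent gradio_client (>= 1.0) is assumed so media inputs are wrapped with handle_file.

from gradio_client import Client, handle_file

# Hypothetical Space id; replace with the repo that actually hosts this app.
client = Client("your-username/llava-video-chatbot")

# Arguments are passed in component order: video_input (gr.Video), then question_input (gr.Textbox).
result = client.predict(
    handle_file("sample_video.mp4"),
    "What happens in this video?",
    api_name="/gradio_interface",  # assumed default endpoint name
)
print(result)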