ManishThota committed
Commit 9db455c
1 Parent(s): 0f80bf1

Update app.py

Files changed (1):
  1. app.py +93 -15
app.py CHANGED
@@ -2,6 +2,12 @@ import gradio as gr
 from PIL import Image
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
+import magic
+import mimetypes
+import cv2
+import numpy as np
+import io
+import tempfile
 
 
 # # Ensure GPU usage if available
@@ -16,22 +22,80 @@ model = AutoModelForCausalLM.from_pretrained("ManishThota/SparrowVQE",
                                              trust_remote_code=True)
 tokenizer = AutoTokenizer.from_pretrained("ManishThota/SparrowVQE", trust_remote_code=True)
 
-def predict_answer(image, question, max_tokens=100):
-    #Set inputs
-    text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\n{question}? ASSISTANT:"
-    image = image.convert("RGB")
-
-    input_ids = tokenizer(text, return_tensors='pt').input_ids.to(device)
-    image_tensor = model.image_preprocess(image)
-
-    #Generate the answer
-    output_ids = model.generate(
-        input_ids,
-        max_new_tokens=max_tokens,
-        images=image_tensor,
-        use_cache=True)[0]
-
-    return tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
+
+def get_file_type_from_bytes(file_bytes):
+    """Determine whether a file is an image or a video based on its MIME type from bytes."""
+    mime = magic.Magic(mime=True)
+    mimetype = mime.from_buffer(file_bytes)
+    if mimetype.startswith('image'):
+        return 'image'
+    elif mimetype.startswith('video'):
+        return 'video'
+    return 'unknown'
+
+def process_video(video_bytes):
+    """Extract frames from the video, one per second."""
+    # cv2.VideoCapture cannot read from an in-memory buffer such as
+    # io.BytesIO, so spill the bytes to a named temporary file first.
+    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
+        tmp.write(video_bytes)
+        video_path = tmp.name
+    video = cv2.VideoCapture(video_path)
+    fps = video.get(cv2.CAP_PROP_FPS)
+    frames = []
+    success, frame = video.read()
+    while success:
+        frames.append(frame)
+        for _ in range(int(fps)):  # Skip roughly one second of frames
+            success, frame = video.read()
+    video.release()
+    return frames[:4]  # Return the first 4 sampled frames
+
+
+def predict_answer(file, question, max_tokens=100):
+    # Build the prompt once; the image and video branches share it.
+    text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\n{question}? ASSISTANT:"
+
+    file_type = get_file_type_from_bytes(file)
+
+    if file_type == 'image':
+        # Process as an image
+        frame = Image.open(io.BytesIO(file)).convert("RGB")
+        input_ids = tokenizer(text, return_tensors='pt').input_ids.to(device)
+        image_tensor = model.image_preprocess(frame)
+
+        # Generate the answer
+        output_ids = model.generate(
+            input_ids,
+            max_new_tokens=max_tokens,
+            images=image_tensor,
+            use_cache=True)[0]
+
+        return tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
+
+    elif file_type == 'video':
+        # Process as a video: answer the question once per sampled frame
+        frames = process_video(file)
+        answers = []
+        for frame in frames:
+            # OpenCV frames are BGR numpy arrays; convert to an RGB PIL image
+            frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+            input_ids = tokenizer(text, return_tensors='pt').input_ids.to(device)
+            image_tensor = model.image_preprocess(frame)
+
+            # Generate the answer
+            output_ids = model.generate(
+                input_ids,
+                max_new_tokens=max_tokens,
+                images=image_tensor,
+                use_cache=True)[0]
+
+            answer = tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
+            answers.append(answer)
+        return "\n".join(answers)
+
+    else:
+        return "Unsupported file type. Please upload an image or video."
 
 def gradio_predict(image, question, max_tokens):
     answer = predict_answer(image, question, max_tokens)
@@ -50,9 +114,10 @@ def gradio_predict(image, question, max_tokens):
 # Define the Gradio interface
 iface = gr.Interface(
     fn=gradio_predict,
-    inputs=[gr.Image(type="pil", label="Upload or Drag an Image"),
+    inputs=[gr.File(type="binary", label="Upload an Image or Video"),  # binary: predict_answer expects raw bytes
+            # gr.Image(type="pil", label="Upload or Drag an Image"),
             gr.Textbox(label="Question", placeholder="e.g. Can you explain the slide?", scale=4),
-            gr.Slider(2, 500, value=100, label="Token Count", info="Choose between 2 and 500")],
+            gr.Slider(2, 500, value=25, label="Token Count", info="Choose between 2 and 500")],
     outputs=gr.TextArea(label="Answer"),
     # examples=examples,
     title="Super Rapid Annotator - Multimodal vision tool to annotate videos with LLaVA framework",
 