ManishThota committed
Commit edff486
1 Parent(s): 5799880

Update app.py

Files changed (1)
  app.py  +13 -54
app.py CHANGED
@@ -2,8 +2,6 @@ import gradio as gr
 from PIL import Image
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
-import magic
-import mimetypes
 import cv2
 import numpy as np
 import io
@@ -22,16 +20,6 @@ model = AutoModelForCausalLM.from_pretrained("ManishThota/SparrowVQE",
 tokenizer = AutoTokenizer.from_pretrained("ManishThota/SparrowVQE", trust_remote_code=True)
 
 
-def get_file_type_from_bytes(file_bytes):
-    """Determine whether a file is an image or a video based on its MIME type from bytes."""
-    mime = magic.Magic(mime=True)
-    mimetype = mime.from_buffer(file_bytes)
-    if mimetype.startswith('image'):
-        return 'image'
-    elif mimetype.startswith('video'):
-        return 'video'
-    return 'unknown'
-
 def process_video(video_bytes):
     """Extracts frames from the video, 1 per second."""
     video = cv2.VideoCapture(io.BytesIO(video_bytes))
@@ -46,15 +34,12 @@ def process_video(video_bytes):
     return frames[:4]  # Return the first 4 frames
 
 
-def predict_answer(file, question, max_tokens=100):
-
-    file_type = get_file_type_from_bytes(file)
+def predict_answer(image, video, question, max_tokens=100):
 
-    if file_type == 'image':
+    if image:
         # Process as an image
-        image = Image.open(io.BytesIO(file))
-        frame = image.convert("RGB")
-        input_ids = tokenizer(text, return_tensors='pt').input_ids.to(device)
+        image = image.convert("RGB")
+        input_ids = tokenizer(question, return_tensors='pt').input_ids.to(device)
         image_tensor = model.image_preprocess(frame)
 
         #Generate the answer
@@ -66,13 +51,13 @@ def predict_answer(file, question, max_tokens=100):
 
         return tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
 
-    elif file_type == 'video':
+    elif video:
         # Process as a video
-        frames = process_video(file)
+        frames = process_video(video)
         answers = []
         for frame in frames:
             frame = Image.open(frame).convert("RGB")
-            input_ids = tokenizer(text, return_tensors='pt').input_ids.to(device)
+            input_ids = tokenizer(question, return_tensors='pt').input_ids.to(device)
             image_tensor = model.image_preprocess(frame)
 
             # Generate the answer
@@ -90,45 +75,19 @@ def predict_answer(file, question, max_tokens=100):
     return "Unsupported file type. Please upload an image or video."
 
 
-
-
-
-# def predict_answer(image, question, max_tokens=100):
-#     #Set inputs
-#     text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\n{question}? ASSISTANT:"
-#     image = image.convert("RGB")
-
-#     input_ids = tokenizer(text, return_tensors='pt').input_ids.to(device)
-#     image_tensor = model.image_preprocess(image)
-
-#     #Generate the answer
-#     output_ids = model.generate(
-#         input_ids,
-#         max_new_tokens=max_tokens,
-#         images=image_tensor,
-#         use_cache=True)[0]
-
-#     return tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
 
-def gradio_predict(image, question, max_tokens):
-    answer = predict_answer(image, question, max_tokens)
+
+def gradio_predict(image, video, question, max_tokens):
+    answer = predict_answer(image, video, question, max_tokens)
     return answer
 
-
-# examples = [["data/week_01_page_024.png", 'Can you explain the slide?',100],
-#             ["data/week_03_page_091.png", 'Can you explain the slide?',100],
-#             ["data/week_01_page_062.png", 'Are the training images labeled?',100],
-#             ["data/week_05_page_027.png", 'What is meant by eigenvalue multiplicity?',100],
-#             ["data/week_05_page_030.png", 'What does K represent?',100],
-#             ["data/week_15_page_046.png", 'How are individual heterogeneous models trained?',100],
-#             ["data/week_15_page_021.png", 'How does Bagging affect error?',100],
-#             ["data/week_15_page_037.png", "What does the '+' and '-' represent?",100]]
+
 
 # Define the Gradio interface
 iface = gr.Interface(
     fn=gradio_predict,
-    inputs=[gr.File(label="Upload an Image or Video"),
-            # gr.Image(type="pil", label="Upload or Drag an Image"),
+    inputs=[gr.Image(type="pil", label="Upload or Drag an Image"),
+            gr.Video(label="upload your video here"),
     gr.Textbox(label="Question", placeholder="e.g. Can you explain the slide?", scale=4),
    gr.Slider(2, 500, value=25, label="Token Count", info="Choose between 2 and 500")],
    outputs=gr.TextArea(label="Answer"),
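
A note on process_video: the commit leaves cv2.VideoCapture(io.BytesIO(video_bytes)) in place, but OpenCV's VideoCapture only accepts a file path, URL, or device index, never a file-like object, so the video branch cannot work as written. Since gr.Video hands the callback a file path by default after this change, a path-based helper along the following lines would likely be closer to working. This is a minimal sketch, not part of the commit; max_frames and the PIL conversion are illustrative choices.

import cv2
from PIL import Image

def process_video(video_path, max_frames=4):
    """Sample roughly one frame per second from a video file path."""
    video = cv2.VideoCapture(video_path)  # VideoCapture needs a path, not bytes
    fps = video.get(cv2.CAP_PROP_FPS) or 30  # fall back if FPS is unreported
    step = max(1, int(round(fps)))
    frames = []
    count = 0
    success, frame = video.read()
    while success and len(frames) < max_frames:
        if count % step == 0:
            # OpenCV decodes to BGR; the model-facing code wants RGB PIL images
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        success, frame = video.read()
        count += 1
    video.release()
    return frames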
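The rewritten predict_answer also carries two leftovers worth flagging. In the image branch the input is renamed to image, yet the unchanged context line still calls model.image_preprocess(frame), so frame is undefined there; and both branches now tokenize the bare question, while the commented-out version removed in this commit shows the chat template the model was prompted with (USER: <image>\n{question}? ASSISTANT:). A consolidated sketch under those observations follows, reusing the module-level tokenizer, model, and device from app.py. The answer_for helper and the joining of per-frame answers are my assumptions: the diff truncates before showing how the video branch returns its answers list.

PROMPT = ("A chat between a curious user and an artificial intelligence assistant. "
          "The assistant gives helpful, detailed, and polite answers to the user's "
          "questions. USER: <image>\n{question}? ASSISTANT:")

def predict_answer(image, video, question, max_tokens=100):
    text = PROMPT.format(question=question)

    def answer_for(frame):
        # Shared per-frame path: tokenize the full prompt, not the bare question
        input_ids = tokenizer(text, return_tensors='pt').input_ids.to(device)
        image_tensor = model.image_preprocess(frame)
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_tokens,
            images=image_tensor,
            use_cache=True)[0]
        return tokenizer.decode(output_ids[input_ids.shape[1]:],
                                skip_special_tokens=True).strip()

    if image is not None:
        return answer_for(image.convert("RGB"))  # was image_preprocess(frame), frame undefined
    elif video is not None:
        # With the path-based process_video sketched above, frames are already
        # PIL images, so the old Image.open(frame) call is unnecessary.
        answers = [answer_for(f.convert("RGB")) for f in process_video(video)]
        # Assumption: the diff never shows how `answers` is returned; joining is a guess.
        return "\n".join(answers)
    return "Unsupported file type. Please upload an image or video."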
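On the interface side, replacing the single gr.File input with separate gr.Image and gr.Video components is what makes the new dispatch possible: Gradio passes None for any component the user leaves empty, a PIL image for gr.Image(type="pil"), and a file-path string for gr.Video. The diff cuts off before the gr.Interface call closes, so the ending below is illustrative only; the stand-in handler just demonstrates the None-dispatch pattern end to end.

import gradio as gr

def gradio_predict(image, video, question, max_tokens):
    # Gradio supplies None for whichever upload the user skipped;
    # `is not None` is safer than truthiness for array-like inputs.
    if image is not None:
        return f"image question: {question!r} (max_tokens={max_tokens})"
    elif video is not None:
        return f"video at {video!r}: {question!r}"
    return "Please upload an image or a video."

iface = gr.Interface(
    fn=gradio_predict,
    inputs=[gr.Image(type="pil", label="Upload or Drag an Image"),
            gr.Video(label="upload your video here"),
            gr.Textbox(label="Question", placeholder="e.g. Can you explain the slide?"),
            gr.Slider(2, 500, value=25, label="Token Count")],
    outputs=gr.TextArea(label="Answer"),
)

if __name__ == "__main__":
    iface.launch()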