Build / app.py
ManishThota's picture
Update app.py
9db455c verified
raw
history blame
5.77 kB
import io
import mimetypes
import os
import tempfile

import cv2
import gradio as gr
import magic
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
# Use the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 halves memory on the GPU, but half-precision inference on CPU is slow and
# unsupported for some ops on older PyTorch releases, so fall back to float32 there.
model_dtype = torch.float16 if device == "cuda" else torch.float32

MODEL_ID = "ManishThota/SparrowVQE"

# Initialize the model and tokenizer. trust_remote_code lets transformers run the
# modelling/tokenizer code bundled with the checkpoint repository.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=model_dtype,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
def get_file_type_from_bytes(file_bytes):
    """Classify raw file contents as ``'image'``, ``'video'`` or ``'unknown'``.

    The classification uses the MIME type that libmagic sniffs from the bytes
    themselves; no filename or extension is consulted.
    """
    detected = magic.Magic(mime=True).from_buffer(file_bytes)
    for category in ("image", "video"):
        if detected.startswith(category):
            return category
    return "unknown"
def process_video(video_bytes, max_frames=4):
    """Sample roughly one frame per second from an in-memory video.

    Args:
        video_bytes: Raw bytes of the video file.
        max_frames: Maximum number of frames to return (default 4, the limit
            this function previously hard-coded).

    Returns:
        A list of at most ``max_frames`` frames exactly as returned by
        ``cv2.VideoCapture.read`` (OpenCV BGR numpy arrays), taken from the
        start of the video at about one-second intervals. The list is empty
        if the video cannot be opened or decoded.
    """
    # cv2.VideoCapture opens filenames/URLs/devices, not file-like objects, so the
    # upload is spooled to a temporary file first.
    fd, path = tempfile.mkstemp(suffix=".mp4")
    try:
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(video_bytes)
        video = cv2.VideoCapture(path)
        try:
            fps = video.get(cv2.CAP_PROP_FPS)
            # get() reports 0 when the rate is unknown (NaN also fails `> 0`). A step
            # of 0 used to leave the reader stuck on one frame, appending it forever.
            step = max(int(round(fps)), 1) if fps > 0 else 1
            frames = []
            while len(frames) < max_frames:
                success, frame = video.read()
                if not success:
                    break
                frames.append(frame)
                # Skip the rest of this second; grab() advances without decoding.
                for _ in range(step - 1):
                    if not video.grab():
                        break
            return frames
        finally:
            video.release()
    finally:
        # Released above before removal so the file is not held open (Windows).
        os.remove(path)
def _build_prompt(question):
    """Wrap *question* in the single-image chat template used by this app."""
    return (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions. "
        f"USER: <image>\n{question}? ASSISTANT:"
    )


def _answer_image(image, question, max_tokens):
    """Run one generation pass for a single PIL image and return the decoded answer."""
    input_ids = tokenizer(_build_prompt(question), return_tensors="pt").input_ids.to(device)
    image_tensor = model.image_preprocess(image.convert("RGB"))
    output_ids = model.generate(
        input_ids,
        max_new_tokens=int(max_tokens),
        images=image_tensor,
        use_cache=True,
    )[0]
    # The output starts with the prompt tokens; decode only what was generated after them.
    return tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()


def predict_answer(file, question, max_tokens=100):
    """Answer *question* about an uploaded image or video.

    Args:
        file: Raw bytes of the uploaded file.
        question: Free-text question about the file's contents.
        max_tokens: Maximum number of new tokens generated per answer.

    Returns:
        For an image, the model's answer. For a video, one answer per frame
        sampled by ``process_video``, joined with newlines. Otherwise, a
        human-readable message explaining why no answer was produced.
    """
    file_type = get_file_type_from_bytes(file)
    if file_type == "image":
        with Image.open(io.BytesIO(file)) as image:
            return _answer_image(image, question, max_tokens)
    if file_type == "video":
        frames = process_video(file)
        if not frames:
            return "Could not read any frames from the video."
        answers = []
        for frame in frames:
            # OpenCV frames are BGR numpy arrays; PIL expects RGB and needs fromarray,
            # not Image.open (which takes a path or file object).
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            answers.append(_answer_image(image, question, max_tokens))
        return "\n".join(answers)
    return "Unsupported file type. Please upload an image or video."
# def predict_answer(image, question, max_tokens=100):
# #Set inputs
# text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\n{question}? ASSISTANT:"
# image = image.convert("RGB")
# input_ids = tokenizer(text, return_tensors='pt').input_ids.to(device)
# image_tensor = model.image_preprocess(image)
# #Generate the answer
# output_ids = model.generate(
# input_ids,
# max_new_tokens=max_tokens,
# images=image_tensor,
# use_cache=True)[0]
# return tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
def gradio_predict(image, question, max_tokens):
    """Gradio callback: load the uploaded file and answer *question* about it.

    Args:
        image: Value produced by the ``gr.File`` input. This may be raw
            ``bytes``, a filepath (``str``/``os.PathLike``), or a tempfile-like
            object whose ``.name`` is the path, depending on the component's
            ``type`` and the Gradio version. It is ``None`` when nothing was
            uploaded.
        question: Question text from the textbox.
        max_tokens: Token budget from the slider.

    Returns:
        The answer text, or a prompt to upload a file when none was given.
    """
    if image is None:
        return "Please upload an image or video."
    if isinstance(image, (bytes, bytearray)):
        file_bytes = bytes(image)
    else:
        # Path objects also have a `.name`, but it is only the basename, so check
        # for path types before falling back to the tempfile wrapper's `.name`.
        path = image if isinstance(image, (str, os.PathLike)) else image.name
        with open(path, "rb") as handle:
            file_bytes = handle.read()
    # The slider value may arrive as a float; generate() needs an integer budget.
    return predict_answer(file_bytes, question, int(max_tokens))
# examples = [["data/week_01_page_024.png", 'Can you explain the slide?',100],
# ["data/week_03_page_091.png", 'Can you explain the slide?',100],
# ["data/week_01_page_062.png", 'Are the training images labeled?',100],
# ["data/week_05_page_027.png", 'What is meant by eigenvalue multiplicity?',100],
# ["data/week_05_page_030.png", 'What does K represent?',100],
# ["data/week_15_page_046.png", 'How are individual heterogeneous models trained?',100],
# ["data/week_15_page_021.png", 'How does Bagging affect error?',100],
# ["data/week_15_page_037.png", "What does the '+' and '-' represent?",100]]
# Define the Gradio interface.
iface = gr.Interface(
    fn=gradio_predict,
    inputs=[
        # type="binary" makes Gradio pass the raw file bytes to the callback, which is
        # the form predict_answer's MIME sniffing and decoding operate on.
        gr.File(label="Upload an Image or Video", type="binary"),
        # gr.Image(type="pil", label="Upload or Drag an Image"),
        gr.Textbox(label="Question", placeholder="e.g. Can you explain the slide?", scale=4),
        # step=1 keeps the token budget integral.
        gr.Slider(2, 500, value=25, step=1, label="Token Count", info="Choose between 2 and 500"),
    ],
    outputs=gr.TextArea(label="Answer"),
    # examples=examples,
    title="Super Rapid Annotator - Multimodal vision tool to annotate videos with LLaVA framework",
    # description="An interactive chat model that can answer questions about images in an Academic context. \n We can input images, and the system will analyze them to provide information about their contents. I've utilized this capability by feeding slides from PowerPoint presentations used in classes and the lecture content passed as text. Consequently, the model now mimics the behavior and responses of my professors. So, if I present any PowerPoint slide, it explains it just like my professor would, further it can be personalized.",
)
# Launch the app.
iface.queue().launch(debug=True)