KB-VQA-E / app.py
m7mdal7aj's picture
Update app.py
94087a7 verified
raw
history blame
No virus
2.56 kB
import streamlit as st
import torch
import bitsandbytes
import accelerate
import scipy
from PIL import Image
import torch.nn as nn
from transformers import Blip2Processor, Blip2ForConditionalGeneration, InstructBlipProcessor, InstructBlipForConditionalGeneration
from My_Model.object_detection import ObjectDetector
def load_caption_model(blip2=False, instructblip=True):
if blip2:
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True,torch_dtype=torch.float16)
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", load_in_8bit=True,torch_dtype=torch.float16)
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
model.to('cuda')
#model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")
if instructblip:
model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", load_in_8bit=True,torch_dtype=torch.float16)
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
model.to('cuda')
processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b", load_in_8bit=True,torch_dtype=torch.float16)
return model, processor
def answer_question(image, question, model, processor):
image = Image.open(image)
inputs = processor(image, question, return_tensors="pt").to("cuda", torch.float16)
if isinstance(model, torch.nn.DataParallel):
# Use the 'module' attribute to access the original model
out = model.module.generate(**inputs, max_length=100, min_length=20)
else:
out = model.generate(**inputs, max_length=100, min_length=20)
answer = processor.decode(out[0], skip_special_tokens=True).strip()
return answer
st.title("Image Question Answering")
# File uploader for the image
image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
# Text input for the question
question = st.text_input("Enter your question about the image:")
if st.button("Get Answer"):
if image is not None and question:
# Display the image
st.image(image, use_column_width=True)
# Get and display the answer
model, processor = load_caption_model()
answer = answer_question(image, question, model, processor)
st.write(answer)
else:
st.write("Please upload an image and enter a question.")