VisualQA / app.py
Pranav4datasc's picture
Upload app.py
605c39b verified
raw
history blame contribute delete
No virus
1.93 kB
import logging
from io import BytesIO

import requests
import streamlit as st
import torch
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering
# set_page_config has to run before any other Streamlit call in the script.
st.set_page_config(layout='wide', page_title='VQA')


@st.cache_resource
def _load_vilt():
    """Load the ViLT VQA processor and model once per server process.

    Streamlit re-executes this script on every widget interaction; without
    the resource cache both ``from_pretrained`` calls would run again on each
    rerun.

    Returns:
        A ``(processor, model)`` tuple for the
        ``dandelin/vilt-b32-finetuned-vqa`` checkpoint.
    """
    checkpoint = "dandelin/vilt-b32-finetuned-vqa"
    vilt_processor = ViltProcessor.from_pretrained(checkpoint)
    vilt_model = ViltForQuestionAnswering.from_pretrained(checkpoint)
    return vilt_processor, vilt_model


# Module-level names kept so the rest of the script can keep using them.
processor, model = _load_vilt()
def get_answer(image, text):
    """Answer a question about an image using the module-level ViLT model.

    Args:
        image: Encoded image bytes (any format PIL can open).
        text: The question to ask about the image.

    Returns:
        The label from ``model.config.id2label`` with the highest logit, or
        ``None`` if decoding the image or running the model raised. The
        failure is logged rather than returned as text, so callers cannot
        mistake an error message for an answer.
    """
    try:
        # Decode the bytes and force 3-channel RGB so grayscale, palette and
        # alpha images all reach the processor in one mode. The context
        # manager releases the decoder handle once the converted copy exists.
        with Image.open(BytesIO(image)) as raw:
            img = raw.convert('RGB')
        encoding = processor(img, text, return_tensors="pt")
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            outputs = model(**encoding)
        idx = outputs.logits.argmax(-1).item()
        return model.config.id2label[idx]
    except Exception:
        # Boundary between the UI and the model: record the traceback and
        # signal failure with None instead of leaking exception text.
        logging.getLogger(__name__).exception("Visual question answering failed")
        return None
# ---- Page layout: sample image on top, uploader on the left, Q&A on the right.
st.title("Visual Question Answering App")
st.write("Upload an image and enter a question to get an answer")
st.caption("Sample image...")
st.image("tulips.jpg", width=600)

col1, col2 = st.columns(2)

with col1:
    uploaded_file = st.file_uploader(
        "Upload your own image or simply drag sample image given above",
        type=['jpg', 'png', 'jpeg'],
    )
    # file_uploader yields None until the user picks a file; only preview
    # once there is something to show.
    if uploaded_file is not None:
        st.image(uploaded_file, use_column_width=True)

with col2:
    question = st.text_input("Question")
    # text_input gives back "" (not None) while empty, so check for actual
    # content; each operand is tested explicitly to avoid the
    # `a and b is not None` precedence trap.
    if uploaded_file is not None and question.strip():
        if st.button("Ask Question"):
            # Pass the upload's raw bytes straight through: get_answer opens
            # them and converts to RGB itself. Re-encoding as JPEG here would
            # raise for images with an alpha channel or palette.
            image_bytes = uploaded_file.getvalue()
            st.info("Your Question is ..." + question)
            answer = get_answer(image_bytes, question)
            if answer is not None:
                st.info("Answer is ..." + answer)
            else:
                st.text("Sorry I am not able to answer that question")