# Chart_QnA / app.py
# prattay's picture
# Update app.py
# 33fb0ac
#!/usr/bin/env python
# coding: utf-8
# In[1]:
# import required libraries
import functools

from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
import gradio as gr
# In[2]:
# DePlot inference pipeline used by the Gradio interface below.

# Hugging Face Hub id of the checkpoint used for every query.
DEPLOT_MODEL_NAME = "google/deplot"

# Upper bound on the number of tokens generated per answer.
MAX_NEW_TOKENS = 512

# In the decoded model text, rows are separated by the literal string "<0x0A>"
# (0x0A is the newline character) and cells within a row by "|".
_ROW_SEPARATOR = "<0x0A>"
_CELL_SEPARATOR = "|"


@functools.lru_cache(maxsize=None)
def _load_deplot(model_name):
    """Load and memoize the (processor, model) pair for *model_name*.

    ``from_pretrained`` is expensive, so it runs once per model name instead
    of on every request (previously it ran on every call to ``query``).
    """
    processor = Pix2StructProcessor.from_pretrained(model_name)
    model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
    return processor, model


def parse_deplot_output(text):
    """Split the model's linearised table text into rows of cells.

    Rows are separated by "<0x0A>" and cells by "|". Surrounding whitespace
    is stripped from every row and cell. Blank rows are dropped. A row with
    no "|" becomes a single-cell row, so the result is always a list of lists.

    >>> parse_deplot_output("TITLE | Sales <0x0A> 2019 | 10")
    [['TITLE', 'Sales'], ['2019', '10']]
    """
    rows = []
    for line in text.split(_ROW_SEPARATOR):
        line = line.strip()
        if not line:
            continue
        # str.split on a line without "|" yields [line], a one-cell row.
        rows.append([cell.strip() for cell in line.split(_CELL_SEPARATOR)])
    return rows


def query(image, user_question):
    """Run the DePlot model on a chart image and return its answer as rows.

    Args:
        image: image input accepted by ``Pix2StructProcessor``. The Gradio
            input in this app supplies a PIL image.
        user_question: text prompt passed to the processor with the image.

    Returns:
        list[list[str]]: the first generated sequence, decoded and split
        into rows of cells by ``parse_deplot_output``.
    """
    processor, model = _load_deplot(DEPLOT_MODEL_NAME)
    inputs = processor(images=image, text=user_question, return_tensors="pt")
    predictions = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    # Only the first generated sequence is decoded, even for a batch input.
    result = processor.decode(predictions[0], skip_special_tokens=True)
    return parse_deplot_output(result)
# In[ ]:
# Gradio front end: a chart image and a free-text prompt are passed to
# query(), and its return value is shown through the "list" output shortcut.
chart_input = gr.Image(label="Upload Here", type="pil")
question_input = gr.Textbox(label="Question?")

# (image path, prompt) pairs shown under the form. cache_examples=True below
# asks Gradio to run query() on them ahead of time. A third sample,
# ("./samples/sample3.webp", "What are the 2020 net sales?"), is disabled.
sample_examples = [
    ["./samples/sample1.png", "Generate underlying data table of the figure"],
    ["./samples/sample2.png", "Is the sum of all 4 places greater than Laos?"],
]

ui = gr.Interface(
    fn=query,
    inputs=[chart_input, question_input],
    outputs="list",
    examples=sample_examples,
    title="Chart Q/A",
    cache_examples=True,
    allow_flagging='never',
)

# Turn on request queuing with the API left open, then start the server
# without a public share link.
ui.queue(api_open=True)
ui.launch(inline=False, share=False, debug=True)