|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration |
|
import gradio as gr |
|
|
|
|
|
|
|
|
|
|
|
|
|
_DEPLOT_MODEL_NAME = "google/deplot"

# Holds the loaded (processor, model) pair so the checkpoint is only loaded
# the first time it is needed instead of on every request.
_DEPLOT_CACHE = {}


def _load_deplot():
    """Return the ``(processor, model)`` pair, loading it on first use."""
    if "pair" not in _DEPLOT_CACHE:
        processor = Pix2StructProcessor.from_pretrained(_DEPLOT_MODEL_NAME)
        model = Pix2StructForConditionalGeneration.from_pretrained(
            _DEPLOT_MODEL_NAME
        )
        _DEPLOT_CACHE["pair"] = (processor, model)
    return _DEPLOT_CACHE["pair"]


def _parse_deplot_output(decoded):
    """Split decoded model text into rows, and rows containing ``|`` into cells."""
    rows = []
    # The decoded text uses the literal token "<0x0A>" as its row separator.
    for line in decoded.split("<0x0A>"):
        line = line.strip()
        if "|" in line:
            rows.append([cell.strip() for cell in line.split("|")])
        else:
            rows.append(line)
    return rows


def query(image, user_question):
    """Run the DePlot model on a chart image and return its parsed answer.

    Args:
        image: single image or batch of images. Only the first generated
            prediction is decoded and returned.
        user_question: user prompt question passed to the processor as text.

    Returns:
        A list with one entry per "<0x0A>"-separated line of the decoded
        output. A line containing "|" becomes a list of stripped cell
        strings; any other line is kept as a single stripped string.
    """
    processor, model = _load_deplot()

    inputs = processor(images=image, text=user_question, return_tensors="pt")
    predictions = model.generate(**inputs, max_new_tokens=512)
    decoded = processor.decode(predictions[0], skip_special_tokens=True)

    return _parse_deplot_output(decoded)
|
|
|
|
|
|
|
|
|
|
|
|
|
# Input widgets: the chart image (handed to `query` as a PIL image) and the
# free-text question about it.
_chart_input = gr.Image(label="Upload Here", type="pil")
_question_input = gr.Textbox(label="Question?")

# Example (image path, question) pairs offered in the UI.
_sample_prompts = [
    ["./samples/sample1.png", "Generate underlying data table of the figure"],
    ["./samples/sample2.png", "Is the sum of all 4 places greater than Laos?"],
]

ui = gr.Interface(
    fn=query,
    inputs=[_chart_input, _question_input],
    outputs="list",
    title="Chart Q/A",
    examples=_sample_prompts,
    cache_examples=True,
    allow_flagging="never",
)

# Enable the request queue, then start the app.
ui.queue(api_open=True)
ui.launch(inline=False, share=False, debug=True)
|
|
|
|