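# Gradio demo for the UniChart-ChartQA model: given a chart image and a
# natural-language question, the model generates a textual answer.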
import gradio as gr
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch
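# Download two example chart images from the ChartQA dataset for the demo examples below.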
torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/test/png/multi_col_40777.png', 'chart_example_1.png')
torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/test/png/OECD_SECONDARY_GRADUATION_RATE_ESP_ITA_MEX_000019.png', 'chart_example_2.png')
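# Load the UniChart checkpoint fine-tuned on ChartQA and its Donut processor,
# then move the model to the GPU if one is available.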
model_name = "ahmed-masry/unichart-chartqa-960"
model = VisionEncoderDecoderModel.from_pretrained(model_name)
processor = DonutProcessor.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def predict(image, input_prompt):
    # Wrap the question in the ChartQA task prompt; the model generates the
    # answer after the <s_answer> token.
    input_prompt = "<chartqa> " + input_prompt + " <s_answer>"
    decoder_input_ids = processor.tokenizer(input_prompt, add_special_tokens=False, return_tensors="pt").input_ids
    pixel_values = processor(image, return_tensors="pt").pixel_values
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=4,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    # Decode the generated ids, strip special tokens, and keep only the text
    # that follows the <s_answer> marker.
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = sequence.split("<s_answer>")[1].strip()
    return sequence
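# Gradio UI: a chart image and a question go in, the model's answer comes out.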
image = gr.components.Image(type="pil", label="Chart Image")
input_prompt = gr.components.Textbox(label="Question")
model_output = gr.components.Textbox(label="Model Output")
examples = [["chart_example_1.png", "What is the lowest value in blue bar?"],
["chart_example_2.png", "Which country has highest secondary graduation rate in 2018?"]]
title = "Interactive Gradio Demo for UniChart-ChartQA model"
interface = gr.Interface(fn=predict,
                         inputs=[image, input_prompt],
                         outputs=model_output,
                         examples=examples,
                         title=title,
                         theme='gradio/soft')
interface.launch()