DimaKoshman's picture
Create app.py
6db7b4c
raw
history blame
No virus
1.34 kB
import gradio
import transformers
import types
# Location of the fine-tuned Donut weights/processor config; presumably a local
# directory shipped alongside this script (from_pretrained also accepts Hub ids).
checkpoint_path = "checkpoint"
# Passed to gradio.Interface(examples=...) below; presumably a directory of sample images.
examples_path = "examples"

# Single container for everything loaded from the checkpoint, so the functions
# below can reach the processor, the model and the tokenizer through one global.
MODEL = types.SimpleNamespace(
    donut_processor=transformers.DonutProcessor.from_pretrained(checkpoint_path),
    encoder_decoder=transformers.VisionEncoderDecoderModel.from_pretrained(
        checkpoint_path
    ),
)
# The tokenizer is the one bundled inside the Donut processor; exposed separately
# for convenient access to eos_token_id and batch_decode.
MODEL.tokenizer = MODEL.donut_processor.tokenizer
def generate_token_strings(images, skip_special_tokens=True) -> list[str]:
    """Generate token sequences for a batch of images and decode them to text.

    Args:
        images: Input passed straight to ``MODEL.encoder_decoder.generate``;
            presumably the ``pixel_values`` tensor produced by
            ``MODEL.donut_processor`` (see ``predict_string``).
        skip_special_tokens: Forwarded to ``MODEL.tokenizer.batch_decode``.

    Returns:
        The decoded strings, one per generated sequence.
    """
    model = MODEL.encoder_decoder
    tokenizer = MODEL.tokenizer

    # Generation length is capped by the decoder's configured max_length and
    # stops at the tokenizer's EOS token; the dict form exposes `.sequences`.
    generated = model.generate(
        images,
        max_length=model.config.decoder.max_length,
        eos_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
    )
    token_ids = generated.sequences
    return tokenizer.batch_decode(token_ids, skip_special_tokens=skip_special_tokens)
def predict_string(image) -> str:
    """Convert a single graph image into its generated textual representation.

    This is the Gradio callback: ``image`` is whatever the ``"image"`` input
    component delivers (presumably a numpy array — confirm against the Gradio
    version in use), or ``None`` when the user submits without an image.

    Args:
        image: The input image, or ``None``.

    Returns:
        The decoded model output for the image, or ``""`` when no image was given.
    """
    if image is None:
        # Nothing to run the model on; return an empty result instead of letting
        # the Donut processor fail on a None input.
        return ""
    # Keep the raw image and the model-ready tensor in separate names.
    # random_padding is disabled so inference preprocessing is deterministic.
    pixel_values = MODEL.donut_processor(
        image, random_padding=False, return_tensors="pt"
    ).pixel_values
    # The batch has exactly one image, so take the first decoded string.
    return generate_token_strings(pixel_values)[0]
# Web UI: a single image input mapped through predict_string to a text output,
# with sample inputs taken from examples_path.
interface = gradio.Interface(
    fn=predict_string,
    inputs="image",
    outputs="text",
    examples=examples_path,
    title="Making graphs accessible",
    description=(
        "Generate textual representation of a graph\n"
        "https://www.kaggle.com/competitions/benetech-making-graphs-accessible"
    ),
)
# Start the server as soon as the module runs (the script's entry point).
interface.launch()