"""Gradio demo: generate a textual representation of a graph image.

Loads a Donut processor and a vision encoder-decoder model from a local
checkpoint directory and serves a small web UI around them.
"""

import types

import gradio
import transformers

# Directory passed to ``from_pretrained`` for both the processor and the model.
checkpoint_path = "checkpoint"
# Passed to ``gradio.Interface(examples=...)`` below.
examples_path = "examples"

# Single shared namespace holding the loaded processor, model and tokenizer so
# they are loaded once at import time and reused for every request.
MODEL = types.SimpleNamespace()
MODEL.donut_processor = transformers.DonutProcessor.from_pretrained(
    checkpoint_path
)
MODEL.encoder_decoder = transformers.VisionEncoderDecoderModel.from_pretrained(
    checkpoint_path
)
MODEL.tokenizer = MODEL.donut_processor.tokenizer


def generate_token_strings(images, skip_special_tokens=True) -> list[str]:
    """Run generation on ``images`` and decode the result to strings.

    Args:
        images: Input forwarded unchanged to ``MODEL.encoder_decoder.generate``
            (presumably the ``pixel_values`` tensor produced by the Donut
            processor -- confirm against callers).
        skip_special_tokens: Forwarded to ``tokenizer.batch_decode``.

    Returns:
        One decoded string per generated sequence.
    """
    decoder_output = MODEL.encoder_decoder.generate(
        images,
        # Generation length limit is taken from the decoder's own config.
        max_length=MODEL.encoder_decoder.config.decoder.max_length,
        eos_token_id=MODEL.tokenizer.eos_token_id,
        # Needed so the result exposes ``.sequences`` instead of a bare tensor.
        return_dict_in_generate=True,
    )
    return MODEL.tokenizer.batch_decode(
        decoder_output.sequences, skip_special_tokens=skip_special_tokens
    )


def predict_string(image) -> str:
    """Return the generated text for a single input image.

    Args:
        image: The image supplied by the Gradio ``"image"`` input.

    Returns:
        The first decoded string produced for ``image``.

    Raises:
        gradio.Error: If no image was provided.
    """
    # Without this check a missing image surfaces as an opaque processor
    # failure; report it to the user with a readable message instead.
    if image is None:
        raise gradio.Error("Please provide an image.")

    pixel_values = MODEL.donut_processor(
        image, random_padding=False, return_tensors="pt"
    ).pixel_values
    # A single image is processed, so only the first decoded string is used.
    return generate_token_strings(pixel_values)[0]


interface = gradio.Interface(
    title="Making graphs accessible",
    description="Generate textual representation of a graph\n"
    "https://www.kaggle.com/competitions/benetech-making-graphs-accessible",
    fn=predict_string,
    inputs="image",
    outputs="text",
    examples=examples_path,
)

# Launched unconditionally at module level, as in the original script.
interface.launch()