DimaKoshman commited on
Commit
6db7b4c
1 Parent(s): 915fb97

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -0
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio
2
+ import transformers
3
+ import types
4
+
5
+
6
+ checkpoint_path = "checkpoint"
7
+ examples_path = "examples"
8
+
9
+ MODEL = types.SimpleNamespace()
10
+ MODEL.donut_processor = transformers.DonutProcessor.from_pretrained(checkpoint_path)
11
+ MODEL.encoder_decoder = transformers.VisionEncoderDecoderModel.from_pretrained(checkpoint_path)
12
+ MODEL.tokenizer = MODEL.donut_processor.tokenizer
13
+
14
+
15
+ def generate_token_strings(images, skip_special_tokens=True) -> list[str]:
16
+ decoder_output = MODEL.encoder_decoder.generate(
17
+ images,
18
+ max_length=MODEL.encoder_decoder.config.decoder.max_length,
19
+ eos_token_id=MODEL.tokenizer.eos_token_id,
20
+ return_dict_in_generate=True,
21
+ )
22
+ return MODEL.tokenizer.batch_decode(
23
+ decoder_output.sequences, skip_special_tokens=skip_special_tokens
24
+ )
25
+
26
+ def predict_string(image) -> str:
27
+ image = MODEL.donut_processor(
28
+ image, random_padding=False, return_tensors="pt"
29
+ ).pixel_values
30
+ string = generate_token_strings(image)[0]
31
+ return string
32
+
33
+
34
+ interface = gradio.Interface(
35
+ title = "Making graphs accessible",
36
+ description = "Generate textual representation of a graph\n"
37
+ "https://www.kaggle.com/competitions/benetech-making-graphs-accessible",
38
+ fn=predict_string,
39
+ inputs="image",
40
+ outputs="text",
41
+ examples=examples_path,
42
+ )
43
+
44
+ interface.launch()