fruk19 committed on
Commit
1453fa3
1 Parent(s): 60df3f6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -0
app.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
4
+ import re
5
+ import json
6
+ from huggingface_hub import HfApi
7
+ import os
8
# Optional model locations read from the environment. They are only echoed to
# the log; the checkpoint that is actually loaded is the PATH_MODEL constant.
p1 = os.environ.get("PATH_MODEL")
p2 = os.environ.get("PATH_MODEL_v2")
print(p1, p2)

# Hub identifier of the Donut checkpoint served by this app.
PATH_MODEL = "fruk19/donut_nfact_v4"

# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the processor and the encoder-decoder model once at import time, move
# the model to the chosen device and put it in evaluation mode.
processor = DonutProcessor.from_pretrained(PATH_MODEL)
model = VisionEncoderDecoderModel.from_pretrained(PATH_MODEL)
model.to(device)
model.eval()
18
+
19
def predict(test_image):
    """Run the Donut model on one image and return the parsed result.

    Args:
        test_image: The image to parse, in a form accepted by the
            module-level ``processor`` (the Gradio input in this file is
            configured with ``type="pil"``).

    Returns:
        Whatever ``processor.token2json`` builds from the generated token
        sequence after the special tokens and the task prompt are stripped.

    Raises:
        ValueError: If ``test_image`` is ``None``, e.g. the form was
            submitted without an image.
    """
    # Fail early with a readable message instead of an obscure error from
    # deep inside the image processor.
    if test_image is None:
        raise ValueError("No image provided: upload an image before submitting.")

    pixel_values = processor(test_image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    # The decoder is primed with the task start token of this checkpoint.
    task_prompt = "<s_nfact>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids
    decoder_input_ids = decoder_input_ids.to(device)

    # Autoregressively generate the output sequence. Greedy decoding
    # (num_beams=1) is used, so the beam-search-only ``early_stopping`` flag
    # is intentionally not passed: it has no effect here and only triggers a
    # configuration warning in recent transformers releases.
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # Turn the generated tokens into JSON: drop EOS/PAD markers, then remove
    # the first tag (the task start token) before parsing.
    seq = processor.batch_decode(outputs.sequences)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "")
    seq = seq.replace(processor.tokenizer.pad_token, "")
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()
    return processor.token2json(seq)
47
+
48
# Web UI: one image input, the prediction rendered as text, and four bundled
# example images. ``gr.Image`` is used instead of the deprecated
# ``gr.inputs.Image`` namespace, which newer Gradio releases no longer ship.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    examples=["image_0.png", "image_1.png", "image_2.png", "image_3.png"],
)

demo.launch()