Joyantac33 committed
Commit 2c17cdc
1 Parent(s): fa81749

Upload 7 files

Files changed (8)
  1. .gitattributes +1 -0
  2. README.md +4 -4
  3. app.py +52 -47
  4. example_1.png +0 -0
  5. example_2.jpeg +0 -0
  6. gitattributes.txt +1 -5
  7. requirements.txt +2 -4
  8. waiting-ticket.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  sample_image_cord_test_receipt_00004.png filter=lfs diff=lfs merge=lfs -text
+ waiting-ticket.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,10 +1,10 @@
  ---
- title: Donut Base Finetuned Cord V2
+ title: Donut Docvqa
  emoji: 🍩
- colorFrom: blue
- colorTo: gray
+ colorFrom: gray
+ colorTo: pink
  sdk: gradio
- sdk_version: 3.0.26
+ sdk_version: 3.1.4
  app_file: app.py
  pinned: false
  ---
app.py CHANGED
@@ -1,52 +1,57 @@
- """
- Donut
- Copyright (c) 2022-present NAVER Corp.
- MIT License
-
- https://github.com/clovaai/donut
- """
+ import re
  import gradio as gr
- import torch
- from PIL import Image
-
- from donut import DonutModel
-
-
- def _init_weights(DonutModel, module):
-     pass
-
- def demo_process(input_img):
-     global pretrained_model, task_prompt, task_name
-     # input_img = Image.fromarray(input_img)
-     output = pretrained_model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
-     return output

- task_prompt = f"<s_cord-v2>"
-
- image = Image.open("./sample_image_cord_test_receipt_00004.png")
- image.save("cord_sample_receipt1.png")
- image = Image.open("./sample_image_cord_test_receipt_00012.png")
- image.save("cord_sample_receipt2.png")
-
- DonutModel._init_weights= _init_weights
-
- pretrained_model = DonutModel.from_pretrained("naver-clova-ix/donut-base-finetuned-zhtrainticket",ignore_mismatched_sizes=True)
- pretrained_model.eval()
+ import torch
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
+
+ processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
+ model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+
+ def process_document(image, question):
+     # prepare encoder inputs
+     pixel_values = processor(image, return_tensors="pt").pixel_values
+
+     # prepare decoder inputs
+     task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
+     prompt = task_prompt.replace("{user_input}", question)
+     decoder_input_ids = processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
+
+     # generate answer
+     outputs = model.generate(
+         pixel_values.to(device),
+         decoder_input_ids=decoder_input_ids.to(device),
+         max_length=model.decoder.config.max_position_embeddings,
+         early_stopping=True,
+         pad_token_id=processor.tokenizer.pad_token_id,
+         eos_token_id=processor.tokenizer.eos_token_id,
+         use_cache=True,
+         num_beams=1,
+         bad_words_ids=[[processor.tokenizer.unk_token_id]],
+         return_dict_in_generate=True,
+     )
+
+     # postprocess
+     sequence = processor.batch_decode(outputs.sequences)[0]
+     sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
+     sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
+
+     return processor.token2json(sequence)
+
+ description = "Gradio Demo for Donut, an instance of `VisionEncoderDecoderModel` fine-tuned on DocVQA (document visual question answering). To use it, simply upload your image and type a question and click 'submit', or click one of the examples to load them. Read more at the links below."
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"

  demo = gr.Interface(
-     fn=demo_process,
-     inputs= gr.inputs.Image(type="pil"),
+     fn=process_document,
+     inputs=["image", "text"],
      outputs="json",
-     title=f"Donut 🍩 demonstration for `cord-v2` task",
-     description="""This model is trained with 800 Indonesian receipt images of CORD dataset. <br>
- Demonstrations for other types of documents/tasks are available at https://github.com/clovaai/donut <br>
- More CORD receipt images are available at https://huggingface.co/datasets/naver-clova-ix/cord-v2
-
- More details are available at:
- - Paper: https://arxiv.org/abs/2111.15664
- - GitHub: https://github.com/clovaai/donut""",
-     examples=[["cord_sample_receipt1.png"], ["cord_sample_receipt2.png"]],
-     cache_examples=False,
- )
-
- demo.launch()
+     title="Demo: Donut 🍩 for DocVQA",
+     description=description,
+     article=article,
+     enable_queue=True,
+     examples=[["example_1.png", "When is the coffee break?"], ["example_2.jpeg", "What's the population of Stoddard?"]],
+     cache_examples=False)
+
+ demo.launch()
example_1.png ADDED
example_2.jpeg ADDED
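example_1.png and example_2.jpeg are the documents wired into the `examples` list of the new Gradio interface above. As a quick sanity check, the new `process_document` function can also be called directly, outside the web UI; the following is a minimal sketch, assuming it is appended to app.py before `demo.launch()` (so the processor and model defined above are in scope) and that example_1.png sits in the working directory:

from PIL import Image

# Load one of the example documents added in this commit and ask the same
# question used in the Gradio examples list.
image = Image.open("example_1.png").convert("RGB")
result = process_document(image, "When is the coffee break?")
print(result)  # expected to be a dict along the lines of {"question": "...", "answer": "..."}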
gitattributes.txt CHANGED
@@ -2,13 +2,11 @@
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
  *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
  *.ftz filter=lfs diff=lfs merge=lfs -text
  *.gz filter=lfs diff=lfs merge=lfs -text
  *.h5 filter=lfs diff=lfs merge=lfs -text
  *.joblib filter=lfs diff=lfs merge=lfs -text
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
  *.model filter=lfs diff=lfs merge=lfs -text
  *.msgpack filter=lfs diff=lfs merge=lfs -text
  *.npy filter=lfs diff=lfs merge=lfs -text
@@ -16,13 +14,12 @@
  *.onnx filter=lfs diff=lfs merge=lfs -text
  *.ot filter=lfs diff=lfs merge=lfs -text
  *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
  *.pickle filter=lfs diff=lfs merge=lfs -text
  *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
@@ -32,4 +29,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
- sample_image_cord_test_receipt_00004.png filter=lfs diff=lfs merge=lfs -text
 
requirements.txt CHANGED
@@ -1,5 +1,3 @@
  torch
- donut-python
- gradio
- transformers==4.24.0
- timm==0.6.13
+ git+https://github.com/huggingface/transformers.git
+ sentencepiece
 
 
waiting-ticket.png ADDED

Git LFS Details

  • SHA256: 921932cd4e5b7279e46a4baebe39e2f2faea452aca14717ab51644786b8a37d2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB