Nikhil Agarwal commited on
Commit
cd33cdc
1 Parent(s): 71b21e9

Add application file

Browse files
Files changed (1) hide show
  1. app.py +82 -0
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import re
4
+ import torch
5
+ import cv2
6
+ import numpy as np
7
+
8
+ from PIL import Image
9
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
10
+
11
+ title = "OCR using Donut"
12
+ description = """
13
+ This demo application uses `naver-clova-ix/donut-base` model to extract text from images.
14
+ """
15
+ article = "Check out [naver-clova-ix/donut-base](https://huggingface.co/naver-clova-ix/donut-base) documentation that this demo is based off of."
16
+
17
+ checkpoint = "naver-clova-ix/donut-base"
18
+
19
+ processor = DonutProcessor.from_pretrained(checkpoint)
20
+ model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
21
+
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+ model.to(device)
24
+
25
+ # prepare decoder inputs
26
+ task_prompt = "<s_synthdog>"
27
+ decoder_input_ids = processor.tokenizer(
28
+ task_prompt, add_special_tokens=False, return_tensors="pt"
29
+ ).input_ids
30
+
31
+
32
+ def convert_image_GRAY2BGR(image):
33
+ if len(np.asarray(image).shape) != 3:
34
+ image = cv2.cvtColor(np.array(image), cv2.COLOR_GRAY2BGR)
35
+ image = Image.fromarray(np.uint8(image))
36
+ return image
37
+
38
+
39
+ def predict(image):
40
+ image = convert_image_GRAY2BGR(image)
41
+ pixel_values = processor(image, return_tensors="pt").pixel_values
42
+
43
+ outputs = model.generate(
44
+ pixel_values.to(device),
45
+ decoder_input_ids=decoder_input_ids.to(device),
46
+ max_length=model.decoder.config.max_position_embeddings,
47
+ early_stopping=True,
48
+ pad_token_id=processor.tokenizer.pad_token_id,
49
+ eos_token_id=processor.tokenizer.eos_token_id,
50
+ use_cache=True,
51
+ num_beams=1,
52
+ bad_words_ids=[[processor.tokenizer.unk_token_id]],
53
+ return_dict_in_generate=True,
54
+ )
55
+
56
+ sequence = processor.batch_decode(outputs.sequences)[0]
57
+ sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
58
+ processor.tokenizer.pad_token, ""
59
+ )
60
+ sequence = re.sub(
61
+ r"<.*?>", "", sequence, count=1
62
+ ).strip() # remove first task start token
63
+ return processor.token2json(sequence)["text_sequence"]
64
+
65
+
66
+ # We instantiate the Textbox class
67
+ input_textbox = gr.Textbox(
68
+ label="Type your prompt here:", placeholder="John Doe", lines=2
69
+ )
70
+
71
+ gr.Interface(
72
+ fn=predict,
73
+ inputs="image",
74
+ outputs="text",
75
+ title=title,
76
+ description=description,
77
+ article=article,
78
+ examples=[
79
+ os.path.join(os.path.dirname(__file__), "../data/sample/sample-1.png"),
80
+ os.path.join(os.path.dirname(__file__), "../data/sample/lorem_ipsum.png"),
81
+ ],
82
+ ).launch()