Edgar404 committed on
Commit
545dc78
•
1 Parent(s): dd4d174

Adding the necessary points

Browse files
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Candy Prototype
3
- emoji: 🌖
4
- colorFrom: pink
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.36.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
+ title: Donut Prototype
3
+ emoji: πŸƒ
4
+ colorFrom: red
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 4.24.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Demo.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1Icb8zeoaudyTDOKM1QySNay1cXzltRAp
8
+ """
9
+
10
+ import gradio as gr
11
+ from PIL import Image
12
+ import re
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from warnings import simplefilter
17
+
18
+ simplefilter('ignore')
19
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
20
+
21
+ # Setting up the model
22
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
23
+
24
# Base checkpoint: model and processor come from the same repository and
# are loaded with the library's default dtype.
print('Loading the base model ....')
base_checkpoint = 'Edgar404/donut-shivi-recognition'
base_model = VisionEncoderDecoderModel.from_pretrained(base_checkpoint)
base_processor = DonutProcessor.from_pretrained(base_checkpoint)
print('Loading complete')

# Optimized checkpoint: the weights are explicitly loaded in bfloat16.
print('Loading the optimized model ....')
optimized_checkpoint = 'Edgar404/donut-shivi-cheques_KD_320'
optimized_model = VisionEncoderDecoderModel.from_pretrained(
    optimized_checkpoint,
    torch_dtype=torch.bfloat16,
)
optimized_processor = DonutProcessor.from_pretrained(optimized_checkpoint)
print('Loading complete')
33
+
34
+ # setting
35
+
36
+
37
def process_image(image, mode='optimized'):
    """Run Donut document parsing (OCR) on an image and return the result.

    Parameters
    ----------
    image : PIL image or numpy array
        Machine-readable image to parse.
    mode : str, optional
        ``'optimized'`` selects ``optimized_model`` / ``optimized_processor``;
        any other value selects ``base_model`` / ``base_processor``.

    Returns
    -------
    The value produced by ``processor.token2json`` for the decoded sequence.
    """
    use_optimized = mode == 'optimized'
    model = optimized_model if use_optimized else base_model
    processor = optimized_processor if use_optimized else base_processor

    model.to(device)
    model.eval()

    # The input must have the same dtype as the model weights. The previous
    # expression `mode == 'optimized' & device == 'cuda'` evaluated
    # `'optimized' & device` first (`&` binds tighter than `==`), which raises
    # TypeError for strings. Reading the dtype from the model also keeps the
    # bfloat16-loaded optimized model consistent with its input on CPU.
    d_type = model.dtype

    # The task prompt is passed as the decoder's start tokens.
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    pixel_values = processor(image, return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device, dtype=d_type),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        # Forbid generation of the unknown token.
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    # Strip end-of-sequence and padding markers from the decoded text.
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    # Remove only the first <...> tag (count=1).
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()

    return processor.token2json(sequence)
75
+
76
+
77
def image_classifier(image, mode):
    """Gradio callback: parse ``image`` with the model selected by ``mode``.

    Forwards both arguments to ``process_image`` and returns its result
    unchanged.
    """
    parsed = process_image(image, mode)
    return parsed
79
+
80
+
81
+
82
# Ten bundled sample images (test_0 .. test_9), each paired with the "base" mode.
# NOTE(review): the commit's file list shows the images under
# test_images/test_images/ -- confirm that these paths resolve.
examples_list = [[f'./test_images/test_{idx}.jpg', 'base'] for idx in range(10)]

# Radio button that chooses which model handles the request.
mode_selector = gr.Radio(["base", "optimized"], label="mode")

demo = gr.Interface(
    fn=image_classifier,
    inputs=["image", mode_selector],
    outputs="text",
    examples=examples_list,
)

demo.launch(share=True, debug=True)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ pillow
4
+ gradio
5
+ peft
test_images/test_images/test_0.jpg ADDED
test_images/test_images/test_1.jpg ADDED
test_images/test_images/test_2.jpg ADDED
test_images/test_images/test_3.jpg ADDED
test_images/test_images/test_4.jpg ADDED
test_images/test_images/test_5.jpg ADDED
test_images/test_images/test_6.jpg ADDED
test_images/test_images/test_7.jpg ADDED
test_images/test_images/test_8.jpg ADDED
test_images/test_images/test_9.jpg ADDED