iakarshu commited on
Commit
847d138
1 Parent(s): 3f82d7f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -0
app.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """LiLT For Deployment
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1ol6RWyff15SF6ZJPf47X5380hBTEDiUH
8
+ """
9
+
10
+ # ## Installing the dependencies (might take some time)
11
+
12
+ # !pip install -q pytesseract
13
+ # !sudo apt install -q tesseract-ocr
14
+ # !pip install -q transformers
15
+ # !pip install -q pytorch-lightning
16
+ # !pip install -q einops
17
+ # !pip install -q tqdm
18
+ # !pip install -q gradio
19
+ # !pip install -q Pillow==7.1.2
20
+ # !pip install -q wandb
21
+ # !pip install -q gdown
22
+ # !pip install -q torchmetrics
23
+
24
+ ## Requirements.txt
25
+ import os
26
+ os.system('pip install pyyaml==5.1')
27
+ ## install PyTesseract
28
+ os.system('pip install -q pytesseract')
29
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
30
+
31
+ import pandas as pd
32
+ import os
33
+ from PIL import Image
34
+ from transformers import RobertaTokenizer
35
+ import torch
36
+ from torch.utils.data import Dataset, DataLoader
37
+ import torch.nn as nn
38
+ import pytorch_lightning as pl
39
+
40
+ from dataset import create_features
41
+ from modeling import LiLT
42
+ from utils import LiLTPL
43
+
44
+ import gdown
45
+ import gradio as gr
46
+
47
+ seed = 42
48
+
49
+ ## One can change this configuration and try out new combination
50
+ config = {
51
+ "hidden_dropout_prob": 0.1,
52
+ "hidden_size_t": 768,
53
+ "hidden_size" : 768,
54
+ "hidden_size_l": 768 // 6,
55
+ "intermediate_ff_size_factor": 4,
56
+ "max_2d_position_embeddings": 1001,
57
+ "max_seq_len_l": 512,
58
+ "max_seq_len_t" : 512,
59
+ "num_attention_heads": 12,
60
+ "num_hidden_layers": 12,
61
+ 'dim_head' : 64,
62
+ "shape_size": 96,
63
+ "vocab_size": 50265,
64
+ "eps": 1e-12,
65
+ "fine_tune" : True
66
+ }
67
+
68
+ id2label = ['scientific_report',
69
+ 'resume',
70
+ 'memo',
71
+ 'file_folder',
72
+ 'specification',
73
+ 'news_article',
74
+ 'letter',
75
+ 'form',
76
+ 'budget',
77
+ 'handwritten',
78
+ 'email',
79
+ 'invoice',
80
+ 'presentation',
81
+ 'scientific_publication',
82
+ 'questionnaire',
83
+ 'advertisement']
84
+
85
+ ## Defining tokenizer
86
+ tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
87
+
88
+ url = 'https://drive.google.com/uc?id=1eRV4fS_LFwI5MHqcRwLUNQZgewxI6Se_'
89
+ output = 'lilt_ckpt.ckpt'
90
+ gdown.download(url, output, quiet=False)
91
+
92
+ class RVLCDIPData(Dataset):
93
+
94
+ def __init__(self, image_list, label_list, tokenizer, max_len = 512, size = 1000):
95
+
96
+ self.image_list = image_list
97
+ self.label_list = label_list
98
+ self.tokenizer = tokenizer
99
+ self.max_seq_length = max_len
100
+ self.size = size
101
+
102
+ def __len__(self):
103
+ return len(self.image_list)
104
+
105
+ def __getitem__(self, idx):
106
+ img_path = self.image_list[idx]
107
+ label = self.label_list[idx]
108
+
109
+ boxes, words, normal_box = create_features(
110
+ img_path = img_path,
111
+ tokenizer = self.tokenizer,
112
+ max_seq_length = self.max_seq_length,
113
+ size = self.size,
114
+ use_ocr = True,
115
+ )
116
+
117
+ final_encoding = {'input_boxes': boxes, 'input_words': words}
118
+ final_encoding['label'] = torch.as_tensor(label).long()
119
+
120
+ return final_encoding
121
+
122
+ lilt = LiLTPL(config)
123
+ # path_to_weights = 'drive/MyDrive/docformer_rvl_checkpoint/docformer_v1.ckpt'
124
+ lilt.load_from_checkpoint('lilt_ckpt.ckpt')
125
+
126
+ ## Taken from LayoutLMV2 space
127
+
128
+ image = gr.inputs.Image(type="pil")
129
+ label = gr.outputs.Label(num_top_classes=5)
130
+ examples = [['00093726.png'], ['00866042.png']]
131
+ title = "Interactive demo: LiLT for Image Classification"
132
+ description = "Demo for classifying document images with LiLT model. To use it, \
133
+ simply upload an image or use the example images below and click 'submit' to let the model predict the 5 most probable Document classes. \
134
+ Results will show up in a few seconds."
135
+
136
+ def classify_image(image):
137
+
138
+ image.save('sample_img.png')
139
+ boxes, words, normal_box = create_features(
140
+ img_path = 'sample_img.png',
141
+ tokenizer = tokenizer,
142
+ max_seq_length = 512,
143
+ size = 1000,
144
+ use_ocr = True,
145
+ )
146
+
147
+ final_encoding = {'input_boxes': boxes.unsqueeze(0), 'input_words': words.unsqueeze(0)}
148
+ output = lilt.forward(final_encoding)
149
+ output = output[0].softmax(axis = -1)
150
+
151
+ final_pred = {}
152
+ for i, score in enumerate(output):
153
+ score = output[i]
154
+ final_pred[id2label[i]] = score.detach().cpu().tolist()
155
+
156
+ return final_pred
157
+
158
+ gr.Interface(fn=classify_image, inputs=image, outputs=label, title=title, description=description, examples=examples, enable_queue=True).launch(debug=True)
159
+