JustinLin610 commited on
Commit
582f2a6
·
1 Parent(s): 43773d5

add requirements

Browse files
Files changed (2) hide show
  1. app.py +30 -180
  2. requirements.txt +4 -0
app.py CHANGED
@@ -1,184 +1,35 @@
1
- import os
 
 
2
  import pandas as pd
 
3
 
4
- os.system('cd fairseq;'
5
- 'pip install ./; cd ..')
6
-
7
- os.system('cd ezocr;'
8
- 'pip install .; cd ..')
9
-
10
- import torch
11
- import numpy as np
12
- from fairseq import utils, tasks
13
- from fairseq import checkpoint_utils
14
- from utils.eval_utils import eval_step
15
- from data.mm_data.ocr_dataset import ocr_resize
16
- from tasks.mm_tasks.ocr import OcrTask
17
- from PIL import Image, ImageDraw
18
- from torchvision import transforms
19
- from typing import List, Tuple
20
- import cv2
21
- from easyocrlite import ReaderLite
22
  import gradio as gr
23
-
24
-
25
- # Register refcoco task
26
- tasks.register_task('ocr', OcrTask)
27
-
28
- if not os.path.exists("checkpoints/ocr_general_clean.pt"):
29
- os.system('wget https://shuangqing-multimodal.oss-cn-zhangjiakou.aliyuncs.com/ocr_general_clean.pt; '
30
- 'mkdir -p checkpoints; mv ocr_general_clean.pt checkpoints/ocr_general_clean.pt')
31
-
32
- # turn on cuda if GPU is available
33
- use_cuda = torch.cuda.is_available()
34
- # use fp16 only when GPU is available
35
- use_fp16 = True
36
-
37
- mean = [0.5, 0.5, 0.5]
38
- std = [0.5, 0.5, 0.5]
39
-
40
- Rect = Tuple[int, int, int, int]
41
- FourPoint = Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int], Tuple[int, int]]
42
-
43
-
44
- reader = ReaderLite(gpu=True)
45
- overrides={"eval_cider": False, "beam": 5, "max_len_b": 64, "patch_image_size": 480,
46
- "orig_patch_image_size": 224, "interpolate_position": True,
47
- "no_repeat_ngram_size": 0, "seed": 42}
48
- models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
49
- utils.split_paths('checkpoints/ocr_general_clean.pt'),
50
- arg_overrides=overrides
51
- )
52
-
53
- # Move models to GPU
54
- for model in models:
55
- model.eval()
56
- if use_fp16:
57
- model.half()
58
- if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
59
- model.cuda()
60
- model.prepare_for_inference_(cfg)
61
-
62
- # Initialize generator
63
- generator = task.build_generator(models, cfg.generation)
64
-
65
- bos_item = torch.LongTensor([task.src_dict.bos()])
66
- eos_item = torch.LongTensor([task.src_dict.eos()])
67
- pad_idx = task.src_dict.pad()
68
-
69
-
70
- def four_point_transform(image: np.ndarray, rect: FourPoint) -> np.ndarray:
71
- (tl, tr, br, bl) = rect
72
-
73
- widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
74
- widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
75
- maxWidth = max(int(widthA), int(widthB))
76
-
77
- # compute the height of the new image, which will be the
78
- # maximum distance between the top-right and bottom-right
79
- # y-coordinates or the top-left and bottom-left y-coordinates
80
- heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
81
- heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
82
- maxHeight = max(int(heightA), int(heightB))
83
-
84
- dst = np.array(
85
- [[0, 0], [maxWidth - 1, 0], [maxWidth - 1, maxHeight - 1], [0, maxHeight - 1]],
86
- dtype="float32",
87
- )
88
-
89
- # compute the perspective transform matrix and then apply it
90
- M = cv2.getPerspectiveTransform(rect, dst)
91
- warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
92
-
93
- return warped
94
-
95
-
96
- def get_images(img: str, reader: ReaderLite, **kwargs):
97
- results = reader.process(img, **kwargs)
98
- return results
99
-
100
-
101
- def draw_boxes(image, bounds, color='red', width=4):
102
- draw = ImageDraw.Draw(image)
103
- for i, bound in enumerate(bounds):
104
- p0, p1, p2, p3 = bound
105
- draw.text((p0[0]+5, p0[1]+5), str(i+1), fill=color, align='center')
106
- draw.line([*p0, *p1, *p2, *p3, *p0], fill=color, width=width)
107
- return image
108
-
109
-
110
- def encode_text(text, length=None, append_bos=False, append_eos=False):
111
- s = task.tgt_dict.encode_line(
112
- line=task.bpe.encode(text),
113
- add_if_not_exist=False,
114
- append_eos=False
115
- ).long()
116
- if length is not None:
117
- s = s[:length]
118
- if append_bos:
119
- s = torch.cat([bos_item, s])
120
- if append_eos:
121
- s = torch.cat([s, eos_item])
122
- return s
123
-
124
-
125
- def patch_resize_transform(patch_image_size=480, is_document=False):
126
- _patch_resize_transform = transforms.Compose(
127
- [
128
- lambda image: ocr_resize(
129
- image, patch_image_size, is_document=is_document, split='test',
130
- ),
131
- transforms.ToTensor(),
132
- transforms.Normalize(mean=mean, std=std),
133
- ]
134
- )
135
-
136
- return _patch_resize_transform
137
-
138
-
139
- # Construct input for caption task
140
- def construct_sample(image: Image, patch_image_size=480, is_document=False):
141
- patch_image = patch_resize_transform(patch_image_size, is_document=is_document)(image).unsqueeze(0)
142
- patch_mask = torch.tensor([True])
143
- src_text = encode_text("图片上的文字是什么?", append_bos=True, append_eos=True).unsqueeze(0)
144
- src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])
145
- sample = {
146
- "id":np.array(['42']),
147
- "net_input": {
148
- "src_tokens": src_text,
149
- "src_lengths": src_length,
150
- "patch_images": patch_image,
151
- "patch_masks": patch_mask,
152
- },
153
- "target": None
154
  }
155
- return sample
156
-
157
-
158
- # Function to turn FP32 to FP16
159
- def apply_half(t):
160
- if t.dtype is torch.float32:
161
- return t.to(dtype=torch.half)
162
- return t
163
 
164
-
165
- def ocr(img):
166
- out_img = Image.open(img)
167
- results = get_images(img, reader, max_size=4000, text_confidence=0.7, text_threshold=0.4,
168
- link_threshold=0.4, slope_ths=0., add_margin=0.04)
169
- box_list, image_list = zip(*results)
170
- draw_boxes(out_img, box_list)
171
-
172
- ocr_result = []
173
- for i, (box, image) in enumerate(zip(box_list, image_list)):
174
- image = Image.fromarray(image)
175
- sample = construct_sample(image, cfg.task.patch_image_size)
176
- sample = utils.move_to_cuda(sample) if use_cuda else sample
177
- sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample
178
-
179
- with torch.no_grad():
180
- result, scores = eval_step(task, generator, models, sample)
181
- ocr_result.append([str(i+1), result[0]['ocr'].replace(' ', '')])
182
 
183
  result = pd.DataFrame(ocr_result, columns=['Box ID', 'Text'])
184
 
@@ -193,10 +44,9 @@ description = "Gradio Demo for Chinese OCR based on OFA-Base. "\
193
  article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
194
  "Repo</a></p> "
195
  examples = [['shupai.png'], ['chinese.jpg'], ['gaidao.jpeg'],
196
- ['qiaodaima.png'], ['xsd.jpg']]
197
  io = gr.Interface(fn=ocr, inputs=gr.inputs.Image(type='filepath', label='Image'),
198
  outputs=[gr.outputs.Image(type='pil', label='Image'),
199
  gr.outputs.Dataframe(headers=['Box ID', 'Text'], type='pandas', label='OCR Results')],
200
- title=title, description=description, article=article, examples=examples)
201
- io.launch()
202
-
 
1
+ import base64
2
+ import json
3
+ from io import BytesIO
4
  import pandas as pd
5
+ from PIL import Image
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import gradio as gr
8
+ import requests
9
+
10
+
11
+ def ocr(image):
12
+
13
+ image = Image.open(image)
14
+ img_buffer = BytesIO()
15
+ image.save(img_buffer, format=image.format)
16
+ byte_data = img_buffer.getvalue()
17
+ base64_bytes = base64.b64encode(byte_data) # bytes
18
+ base64_str = base64_bytes.decode()
19
+ url = "https://www.modelscope.cn/api/v1/studio/damo/ofa_ocr_pipeline/gradio/api/predict/"
20
+ payload = json.dumps({
21
+ "data": [f"data:image/jpeg;base64,{base64_str}"],
22
+ "dataType": ["image"]
23
+ })
24
+ headers = {
25
+ 'Content-Type': 'application/json'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  }
 
 
 
 
 
 
 
 
27
 
28
+ response = requests.request("POST", url, headers=headers, data=payload)
29
+ jobj = json.loads(response.text)
30
+ out_img_base64 = jobj['Data']['data'][0].replace('data:image/png;base64,','')
31
+ out_img = Image.open(BytesIO(base64.urlsafe_b64decode(out_img_base64)))
32
+ ocr_result = jobj['Data']['data'][1]['data']
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  result = pd.DataFrame(ocr_result, columns=['Box ID', 'Text'])
35
 
 
44
  article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
45
  "Repo</a></p> "
46
  examples = [['shupai.png'], ['chinese.jpg'], ['gaidao.jpeg'],
47
+ ['qiaodaima.png'], ['xsd.jpg']]
48
  io = gr.Interface(fn=ocr, inputs=gr.inputs.Image(type='filepath', label='Image'),
49
  outputs=[gr.outputs.Image(type='pil', label='Image'),
50
  gr.outputs.Dataframe(headers=['Box ID', 'Text'], type='pandas', label='OCR Results')],
51
+ title=title, description=description, article=article)
52
+ io.launch()
 
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ pillow
3
+ pandas
4
+ requests