JustinLin610 commited on
Commit
0c80503
•
1 Parent(s): 204969e
Files changed (1) hide show
  1. app.py +30 -31
app.py CHANGED
@@ -111,12 +111,33 @@ def patch_resize_transform(patch_image_size=480, is_document=False):
111
  return _patch_resize_transform
112
 
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  # Construct input for caption task
115
  def construct_sample(task, image: Image, patch_image_size=480):
116
- bos_item = torch.LongTensor([task.src_dict.bos()])
117
- eos_item = torch.LongTensor([task.src_dict.eos()])
118
- pad_idx = task.src_dict.pad()
119
-
120
  patch_image = patch_resize_transform(patch_image_size)(image).unsqueeze(0)
121
  patch_mask = torch.tensor([True])
122
  src_text = encode_text(task, "图片上的文字是什么?", append_bos=True, append_eos=True).unsqueeze(0)
@@ -141,35 +162,11 @@ def apply_half(t):
141
  return t
142
 
143
 
144
- def ocr(ckpt, img, out_img):
145
- reader = ReaderLite()
146
- overrides={"eval_cider":False, "beam":8, "max_len_b":128, "patch_image_size":480, "orig_patch_image_size":224, "no_repeat_ngram_size":0, "seed":7}
147
- models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
148
- utils.split_paths(ckpt),
149
- arg_overrides=overrides
150
- )
151
-
152
- # Move models to GPU
153
- for model in models:
154
- model.eval()
155
- if use_fp16:
156
- model.half()
157
- if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
158
- model.cuda()
159
- model.prepare_for_inference_(cfg)
160
-
161
- # Initialize generator
162
- generator = task.build_generator(models, cfg.generation)
163
-
164
- bos_item = torch.LongTensor([task.src_dict.bos()])
165
- eos_item = torch.LongTensor([task.src_dict.eos()])
166
- pad_idx = task.src_dict.pad()
167
-
168
  orig_image = Image.open(img)
169
  results = get_images(img, reader)
170
  box_list, image_list = zip(*results)
171
  draw_boxes(orig_image, box_list)
172
- orig_image.save(out_img)
173
 
174
  ocr_result = []
175
  for box, image in zip(box_list, image_list):
@@ -183,7 +180,8 @@ def ocr(ckpt, img, out_img):
183
  ocr_result.append(result[0]['ocr'].replace(' ', ''))
184
 
185
  result = '\n'.join(ocr_result)
186
- return result
 
187
 
188
 
189
  title = "OFA-OCR"
@@ -192,7 +190,8 @@ description = "Gradio Demo for OFA-OCR. Upload your own image or click any one o
192
  article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
193
  "Repo</a></p> "
194
  examples = [['lihe.png'], ['chinese.jpg'], ['paibian.jpeg'], ['shupai.png'], ['zuowen.jpg']]
195
- io = gr.Interface(fn=ocr, inputs=gr.inputs.Image(type='pil'), outputs=gr.outputs.Textbox(label="Caption"),
 
196
  title=title, description=description, article=article, examples=examples,
197
  allow_flagging=False, allow_screenshot=False)
198
  io.launch(cache_examples=True)
 
111
  return _patch_resize_transform
112
 
113
 
114
+ reader = ReaderLite()
115
+ overrides={"eval_cider":False, "beam":8, "max_len_b":128, "patch_image_size":480,
116
+ "orig_patch_image_size":224, "no_repeat_ngram_size":0, "seed":7}
117
+ models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
118
+ utils.split_paths('checkpoints/ocr.pt'),
119
+ arg_overrides=overrides
120
+ )
121
+
122
+ # Move models to GPU
123
+ for model in models:
124
+ model.eval()
125
+ if use_fp16:
126
+ model.half()
127
+ if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
128
+ model.cuda()
129
+ model.prepare_for_inference_(cfg)
130
+
131
+ # Initialize generator
132
+ generator = task.build_generator(models, cfg.generation)
133
+
134
+ bos_item = torch.LongTensor([task.src_dict.bos()])
135
+ eos_item = torch.LongTensor([task.src_dict.eos()])
136
+ pad_idx = task.src_dict.pad()
137
+
138
+
139
  # Construct input for caption task
140
  def construct_sample(task, image: Image, patch_image_size=480):
 
 
 
 
141
  patch_image = patch_resize_transform(patch_image_size)(image).unsqueeze(0)
142
  patch_mask = torch.tensor([True])
143
  src_text = encode_text(task, "图片上的文字是什么?", append_bos=True, append_eos=True).unsqueeze(0)
 
162
  return t
163
 
164
 
165
+ def ocr(img):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  orig_image = Image.open(img)
167
  results = get_images(img, reader)
168
  box_list, image_list = zip(*results)
169
  draw_boxes(orig_image, box_list)
 
170
 
171
  ocr_result = []
172
  for box, image in zip(box_list, image_list):
 
180
  ocr_result.append(result[0]['ocr'].replace(' ', ''))
181
 
182
  result = '\n'.join(ocr_result)
183
+
184
+ return orig_image, result
185
 
186
 
187
  title = "OFA-OCR"
 
190
  article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
191
  "Repo</a></p> "
192
  examples = [['lihe.png'], ['chinese.jpg'], ['paibian.jpeg'], ['shupai.png'], ['zuowen.jpg']]
193
+ io = gr.Interface(fn=ocr, inputs=gr.inputs.Image(type='pil'),
194
+ outputs=[gr.outputs.Image(type='pil'), gr.outputs.Textbox(label="OCR result")],
195
  title=title, description=description, article=article, examples=examples,
196
  allow_flagging=False, allow_screenshot=False)
197
  io.launch(cache_examples=True)