bachpc commited on
Commit
2a27a15
1 Parent(s): 9e9067e

Improve visualization

Browse files
Files changed (1) hide show
  1. app.py +86 -12
app.py CHANGED
@@ -4,10 +4,14 @@ import cv2
4
  import numpy as np
5
  import pandas as pd
6
  import torch
 
7
  # import sys
8
  # import json
9
  from collections import OrderedDict, defaultdict
10
  import xml.etree.ElementTree as ET
 
 
 
11
  from paddleocr import PaddleOCR
12
  import pytesseract
13
  from pytesseract import Output
@@ -80,10 +84,15 @@ def crop_image(pil_img, detection_result, padding=30):
80
  x2 = min(width, int((min_x + w / 2) * width) + padding)
81
  y2 = min(height, int((min_y + h / 2) * height) + padding)
82
  # print(x1, y1, x2, y2)
 
83
  crop_image = image[y1:y2, x1:x2, :]
84
- crop_images.append(cv_to_PIL(crop_image))
 
 
85
 
86
- cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))
 
 
87
 
88
  return crop_images, cv_to_PIL(image)
89
 
@@ -169,15 +178,39 @@ def visualize_ocr(pil_img, ocr_result):
169
  x2 = int(bbox[2])
170
  y2 = int(bbox[3])
171
  cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))
172
- cv2.putText(image, res['text'], (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.5, color=(0, 0, 255))
173
  return cv_to_PIL(image)
174
 
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  def visualize_structure(pil_img, structure_result):
177
  image = PIL_to_cv(pil_img)
178
  width = image.shape[1]
179
  height = image.shape[0]
180
  # print(width, height)
 
 
 
 
181
  for i, result in enumerate(structure_result):
182
  class_id = int(result[5])
183
  score = float(result[4])
@@ -191,24 +224,65 @@ def visualize_structure(pil_img, structure_result):
191
  x2 = int((min_x + w / 2) * width)
192
  y2 = int((min_y + h / 2) * height)
193
  # print(x1, y1, x2, y2)
 
194
 
195
  if score >= structure_class_thresholds[structure_class_names[class_id]]:
196
- cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))
197
  #cv2.putText(image, str(i)+'-'+str(class_id), (x1-10, y1), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0,0,255))
198
 
199
- return cv_to_PIL(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
 
202
  def visualize_cells(pil_img, cells):
203
- image = PIL_to_cv(pil_img)
 
 
204
  for i, cell in enumerate(cells):
205
  bbox = cell['bbox']
206
- x1 = int(bbox[0])
207
- y1 = int(bbox[1])
208
- x2 = int(bbox[2])
209
- y2 = int(bbox[3])
210
- cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))
211
- return cv_to_PIL(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
 
214
  def pytess(cell_pil_img):
 
4
  import numpy as np
5
  import pandas as pd
6
  import torch
7
+ import io
8
  # import sys
9
  # import json
10
  from collections import OrderedDict, defaultdict
11
  import xml.etree.ElementTree as ET
12
+ import matplotlib.pyplot as plt
13
+ import matplotlib.patches as patches
14
+
15
  from paddleocr import PaddleOCR
16
  import pytesseract
17
  from pytesseract import Output
 
84
  x2 = min(width, int((min_x + w / 2) * width) + padding)
85
  y2 = min(height, int((min_y + h / 2) * height) + padding)
86
  # print(x1, y1, x2, y2)
87
+
88
  crop_image = image[y1:y2, x1:x2, :]
89
+ crop_image = cv_to_PIL(crop_image)
90
+ if class_id == 1: # table rotated
91
+ crop_image = crop_image.rotate(270, expand=True)
92
 
93
+ crop_images.append(crop_image)
94
+
95
+ cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 0, 255))
96
 
97
  return crop_images, cv_to_PIL(image)
98
 
 
178
  x2 = int(bbox[2])
179
  y2 = int(bbox[3])
180
  cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))
181
+ cv2.putText(image, res['text'], (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.25, color=(255, 0, 0))
182
  return cv_to_PIL(image)
183
 
184
 
185
+ def get_bbox_decorations(data_type, label):
186
+ if label == 0:
187
+ if data_type == 'detection':
188
+ return 'brown', 0.05, 3, '//'
189
+ else:
190
+ return 'brown', 0, 3, None
191
+ elif label == 1:
192
+ return 'red', 0.15, 2, None
193
+ elif label == 2:
194
+ return 'blue', 0.15, 2, None
195
+ elif label == 3:
196
+ return 'magenta', 0.2, 3, '//'
197
+ elif label == 4:
198
+ return 'cyan', 0.2, 4, '//'
199
+ elif label == 5:
200
+ return 'green', 0.2, 4, '\\\\'
201
+
202
+ return 'gray', 0, 0, None
203
+
204
+
205
  def visualize_structure(pil_img, structure_result):
206
  image = PIL_to_cv(pil_img)
207
  width = image.shape[1]
208
  height = image.shape[0]
209
  # print(width, height)
210
+
211
+ fig, ax = plt.subplots(1)
212
+ ax.imshow(pil_img, interpolation='lanczos')
213
+
214
  for i, result in enumerate(structure_result):
215
  class_id = int(result[5])
216
  score = float(result[4])
 
224
  x2 = int((min_x + w / 2) * width)
225
  y2 = int((min_y + h / 2) * height)
226
  # print(x1, y1, x2, y2)
227
+ bbox = [x1, y1, x2, y2]
228
 
229
  if score >= structure_class_thresholds[structure_class_names[class_id]]:
230
+ #cv2.rectangle(image, (x1, y1), (x2, y2), color=(0, 255, 0))
231
  #cv2.putText(image, str(i)+'-'+str(class_id), (x1-10, y1), cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0,0,255))
232
 
233
+ color, alpha, linewidth, hatch = get_bbox_decorations('recognition', class_id)
234
+ # Fill
235
+ rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
236
+ linewidth=linewidth, alpha=alpha,
237
+ edgecolor='none',facecolor=color,
238
+ linestyle=None)
239
+ ax.add_patch(rect)
240
+ # Hatch
241
+ rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
242
+ linewidth=1, alpha=0.4,
243
+ edgecolor=color,facecolor='none',
244
+ linestyle='--',hatch=hatch)
245
+ ax.add_patch(rect)
246
+ # Edge
247
+ rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1],
248
+ linewidth=linewidth,
249
+ edgecolor=color,facecolor='none',
250
+ linestyle="--")
251
+ ax.add_patch(rect)
252
+
253
+ plt.axis('off')
254
+ img_buf = io.BytesIO()
255
+ plt.savefig(img_buf, bbox_inches='tight', dpi=100)
256
+
257
+ return PIL.Image.open(img_buf)
258
 
259
 
260
  def visualize_cells(pil_img, cells):
261
+ fig, ax = plt.subplots(1)
262
+ ax.imshow(pil_img, interpolation='lanczos')
263
+
264
  for i, cell in enumerate(cells):
265
  bbox = cell['bbox']
266
+ if cell['header']:
267
+ alpha = 0.3
268
+ else:
269
+ alpha = 0.125
270
+ rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=1,
271
+ edgecolor='none',facecolor="magenta", alpha=alpha)
272
+ ax.add_patch(rect)
273
+ rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=1,
274
+ edgecolor="magenta",facecolor='none',linestyle="--",
275
+ alpha=0.08, hatch='///')
276
+ ax.add_patch(rect)
277
+ rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=1,
278
+ edgecolor="magenta",facecolor='none',linestyle="--")
279
+ ax.add_patch(rect)
280
+
281
+ plt.axis('off')
282
+ img_buf = io.BytesIO()
283
+ plt.savefig(img_buf, bbox_inches='tight', dpi=100)
284
+
285
+ return PIL.Image.open(img_buf)
286
 
287
 
288
  def pytess(cell_pil_img):