Knightmovies committed on
Commit 3520813 · verified · 1 Parent(s): 5cac9c5

Update app.py

Files changed (1)
  1. app.py +168 -71
app.py CHANGED
--- a/app.py
@@ -10,14 +10,10 @@ from scipy.spatial import distance as dist
  # ==============================================================================
  # App Configuration & Styling
  # ==============================================================================
- st.set_page_config(
-     page_title="Document AI Toolkit",
-     page_icon="🤖",
-     layout="wide"
- )

- # Inject CSS for a centered, fixed-width layout
- st.markdown("""
      <style>
      .main .block-container {
          max-width: 900px;
@@ -27,20 +23,29 @@ st.markdown("""
          padding-bottom: 2rem;
      }
      </style>
- """, unsafe_allow_html=True)

  # ==============================================================================
  # Model Loading (Cached)
  # ==============================================================================
  @st.cache_resource
  def load_model():
-     """Loads the Table Transformer model and processor."""
-     return TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition"), DetrImageProcessor.from_pretrained("microsoft/table-transformer-structure-recognition")

  model, processor = load_model()

  # ==============================================================================
- # Core Image Processing Functions (Unchanged)
  # ==============================================================================
  def order_points(pts):
      xSorted = pts[np.argsort(pts[:, 0]), :]
@@ -51,87 +56,151 @@ def order_points(pts):
      (br, tr) = rightMost[np.argsort(D)[::-1], :]
      return np.array([tl, tr, br, bl], dtype="float32")

- def perspective_transform(image, pts):
-     rect = order_points(pts)
-     (tl, tr, br, bl) = rect
-     widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
-     widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
-     maxWidth = max(int(widthA), int(widthB))
-     heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
-     heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
-     maxHeight = max(int(heightA), int(heightB))
-     dst = np.array([[0, 0], [maxWidth - 1, 0], [maxWidth - 1, maxHeight - 1], [0, maxHeight - 1]], dtype="float32")
-     M = cv2.getPerspectiveTransform(rect, dst)
-     return cv2.warpPerspective(image, M, (maxWidth, maxHeight))

  def find_and_straighten_document(image):
      gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
-     _, mask = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY)
-     contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-     if not contours: return None
-     page_contour = max(contours, key=cv2.contourArea)
-     if cv2.contourArea(page_contour) < (image.shape[0] * image.shape[1] * 0.1): return None
-     box = cv2.boxPoints(cv2.minAreaRect(page_contour))
-     return perspective_transform(image, box)

  def correct_orientation(image):
-     """Robust orientation correction using a cascade approach."""
      try:
          osd = pytesseract.image_to_osd(image, output_type=pytesseract.Output.DICT, timeout=5)
-         rotation = osd['rotate']
-         if rotation > 0:
-             angle_map = {90: cv2.ROTATE_90_COUNTERCLOCKWISE, 180: cv2.ROTATE_180, 270: cv2.ROTATE_90_CLOCKWISE}
              return cv2.rotate(image, angle_map[rotation])
          return image
      except Exception:
          gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
-         thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
-         orientations = {0: thresh, 90: cv2.rotate(thresh, cv2.ROTATE_90_CLOCKWISE), 180: cv2.rotate(thresh, cv2.ROTATE_180), 270: cv2.rotate(thresh, cv2.ROTATE_90_COUNTERCLOCKWISE)}
-         best_rotation, max_horizontal_boxes = 0, -1
-         for angle, rotated_img in orientations.items():
              try:
-                 data = pytesseract.image_to_data(rotated_img, output_type=pytesseract.Output.DICT, timeout=5)
-                 horizontal_boxes = sum(1 for i, conf in enumerate(data['conf']) if int(conf) > 10 and data['width'][i] > data['height'][i])
-                 if horizontal_boxes > max_horizontal_boxes:
-                     max_horizontal_boxes, best_rotation = horizontal_boxes, angle
              except Exception:
-                 continue
-         angle_map = {90: cv2.ROTATE_90_CLOCKWISE, 180: cv2.ROTATE_180, 270: cv2.ROTATE_90_COUNTERCLOCKWISE}
-         return cv2.rotate(image, angle_map[best_rotation]) if best_rotation > 0 else image

  def extract_and_draw_table_structure(image_bgr):
-     """Finds and draws table structure using OpenCV."""
      image_pil = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
      inputs = processor(images=image_pil, return_tensors="pt")
-     with torch.no_grad():
          outputs = model(**inputs)
-     target_sizes = torch.tensor([image_pil.size[::-1]])
      results = processor.post_process_object_detection(outputs, threshold=0.6, target_sizes=target_sizes)[0]
-     img_with_boxes = image_bgr.copy()
      colors = {"table row": (0, 255, 0), "table column": (255, 0, 0), "table": (255, 0, 255)}
      for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
-         class_name = model.config.id2label[label.item()]
-         if class_name in colors:
-             xmin, ymin, xmax, ymax = [int(val) for val in box.tolist()]
-             cv2.rectangle(img_with_boxes, (xmin, ymin), (xmax, ymax), colors[class_name], 2)
-     return img_with_boxes

  # ==============================================================================
  # Streamlit UI
  # ==============================================================================

- # --- Session State Management ---
  if "stage" not in st.session_state:
      st.session_state.stage = "upload"
      st.session_state.original_image = None
      st.session_state.processed_image = None
      st.session_state.annotated_image = None

- # --- Sidebar Controls ---
  with st.sidebar:
      st.title("🤖 Document AI Toolkit")
      st.markdown("---")
-
      if st.button("🔄 Start Over", use_container_width=True):
          for key in list(st.session_state.keys()):
              del st.session_state[key]
@@ -139,31 +208,37 @@ with st.sidebar:

      if st.session_state.stage == "upload":
          st.header("Step 1: Upload Image")
-         uploaded_file = st.file_uploader("Upload your document", type=["jpg", "jpeg", "png"], label_visibility="collapsed")
          if uploaded_file:
              file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
              st.session_state.original_image = cv2.imdecode(file_bytes, 1)
              st.session_state.stage = "processing"
              st.rerun()
      elif st.session_state.stage == "processing":
          st.header("Step 2: Pre-process")
          if st.button("▶️ Start Pre-processing", use_container_width=True, type="primary"):
-             with st.spinner("Straightening & correcting orientation..."):
                  original_image = st.session_state.original_image
-                 straightened = find_and_straighten_document(original_image)
-                 image_to_orient = straightened if straightened is not None and straightened.size > 0 else original_image
-                 st.session_state.processed_image = correct_orientation(image_to_orient)
                  st.session_state.stage = "analysis"
                  st.rerun()
      elif st.session_state.stage == "analysis":
          st.header("Step 3: Analyze Table")
          if st.button("📊 Find Table Structure", use_container_width=True, type="primary"):
              with st.spinner("Running Table Transformer model..."):
-                 st.session_state.annotated_image = extract_and_draw_table_structure(st.session_state.processed_image)
                  st.session_state.stage = "done"
                  st.rerun()

- # --- Main Panel Display ---
  st.title("Document Processing Workflow")

  # Step 1: Upload
@@ -172,17 +247,27 @@ with expander1:
      if st.session_state.original_image is None:
          st.info("Please upload a document image using the sidebar to begin.")
      else:
-         st.image(cv2.cvtColor(st.session_state.original_image, cv2.COLOR_BGR2RGB), use_container_width=True)
          st.success("Image uploaded successfully.")

  # Step 2: Pre-process
  if st.session_state.original_image is not None:
-     expander2 = st.expander("Step 2: Pre-process Document", expanded=(st.session_state.stage == "processing" or st.session_state.stage == "analysis"))
      with expander2:
          if st.session_state.processed_image is None:
              st.info("Click 'Start Pre-processing' in the sidebar.")
          else:
-             st.image(cv2.cvtColor(st.session_state.processed_image, cv2.COLOR_BGR2RGB), caption="Straightened & Oriented", use_container_width=True)
              st.success("Pre-processing complete.")

  # Step 3: Analysis
@@ -194,9 +279,21 @@ if st.session_state.processed_image is not None:
          else:
              tab1, tab2 = st.tabs(["✅ Corrected Document", "📊 Table Structure"])
              with tab1:
-                 st.image(cv2.cvtColor(st.session_state.processed_image, cv2.COLOR_BGR2RGB), use_container_width=True)
                  _, buf = cv2.imencode(".jpg", st.session_state.processed_image)
-                 st.download_button("📥 Download Clean Image", data=buf.tobytes(), file_name="corrected.jpg", mime="image/jpeg", use_container_width=True)
              with tab2:
-                 st.image(cv2.cvtColor(st.session_state.annotated_image, cv2.COLOR_BGR2RGB), use_container_width=True)
-                 st.success("Analysis complete.")

+++ b/app.py
  # ==============================================================================
  # App Configuration & Styling
  # ==============================================================================
+ st.set_page_config(page_title="Document AI Toolkit", page_icon="🤖", layout="wide")

+ st.markdown(
+     """
      <style>
      .main .block-container {
          max-width: 900px;

          padding-bottom: 2rem;
      }
      </style>
+     """,
+     unsafe_allow_html=True,
+ )

  # ==============================================================================
  # Model Loading (Cached)
  # ==============================================================================
  @st.cache_resource
  def load_model():
+     model = TableTransformerForObjectDetection.from_pretrained(
+         "microsoft/table-transformer-structure-recognition"
+     )
+     processor = DetrImageProcessor.from_pretrained(
+         "microsoft/table-transformer-structure-recognition"
+     )
+     model.eval()
+     return model, processor
+

  model, processor = load_model()
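Note: the drawing code further down keys on the label strings "table", "table row", and "table column". A quick standalone check of what this checkpoint's config actually exposes (a reference sketch, not part of the commit) would be:

# Standalone sanity check of the structure-recognition label set; assumes only
# the same checkpoint name as used above.
from transformers import TableTransformerForObjectDetection

m = TableTransformerForObjectDetection.from_pretrained(
    "microsoft/table-transformer-structure-recognition"
)
print(m.config.id2label)  # should include "table", "table row", "table column" among others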
  # ==============================================================================
+ # Core Image Processing Functions
  # ==============================================================================
  def order_points(pts):
      xSorted = pts[np.argsort(pts[:, 0]), :]

      (br, tr) = rightMost[np.argsort(D)[::-1], :]
      return np.array([tl, tr, br, bl], dtype="float32")

+
+ def _four_point_warp(image, pts):
+     pts = order_points(pts.astype("float32"))
+     (tl, tr, br, bl) = pts
+     widthA = np.linalg.norm(br - bl)
+     widthB = np.linalg.norm(tr - tl)
+     heightA = np.linalg.norm(tr - br)
+     heightB = np.linalg.norm(tl - bl)
+     maxW, maxH = int(max(widthA, widthB)), int(max(heightA, heightB))
+     dst = np.array([[0, 0], [maxW - 1, 0], [maxW - 1, maxH - 1], [0, maxH - 1]], dtype="float32")
+     M = cv2.getPerspectiveTransform(pts, dst)
+     return cv2.warpPerspective(image, M, (maxW, maxH))
+
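_four_point_warp is the standard four-point perspective rectification. A minimal self-contained sketch of the same idea on hand-picked corners (illustrative values only, independent of app.py):

# Four-point warp on a synthetic white image; corner coordinates are made up.
import cv2
import numpy as np

img = np.full((400, 600, 3), 255, np.uint8)
corners = np.array([[120, 80], [480, 60], [500, 340], [100, 360]], dtype="float32")  # tl, tr, br, bl
w = int(max(np.linalg.norm(corners[2] - corners[3]), np.linalg.norm(corners[1] - corners[0])))
h = int(max(np.linalg.norm(corners[1] - corners[2]), np.linalg.norm(corners[0] - corners[3])))
dst = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1]], dtype="float32")
warped = cv2.warpPerspective(img, cv2.getPerspectiveTransform(corners, dst), (w, h))
print(warped.shape)  # (h, w, 3)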
  def find_and_straighten_document(image):
+     """Find 4 page corners; fall back to minAreaRect if needed."""
+     gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+     gray = cv2.GaussianBlur(gray, (5, 5), 0)
+     edges = cv2.Canny(gray, 50, 150)
+     edges = cv2.dilate(edges, np.ones((3, 3), np.uint8), iterations=1)
+
+     cnts, _ = cv2.findContours(edges, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
+     cnts = sorted(cnts, key=cv2.contourArea, reverse=True)[:10]
+
+     for c in cnts:
+         peri = cv2.arcLength(c, True)
+         approx = cv2.approxPolyDP(c, 0.02 * peri, True)
+         if len(approx) == 4:
+             return _four_point_warp(image, approx.reshape(4, 2))
+
+     if cnts:
+         box = cv2.boxPoints(cv2.minAreaRect(max(cnts, key=cv2.contourArea)))
+         return _four_point_warp(image, box)
+
+     return image
+
+
+ def deskew_slight(image):
+     """Remove small residual tilt so rows/cols are parallel to axes."""
      gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+     thr = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
+     coords = np.column_stack(np.where(thr == 0))  # use ink pixels
+     if len(coords) < 100:
+         return image
+     angle = cv2.minAreaRect(coords)[-1]
+     angle = -(90 + angle) if angle < -45 else -angle
+     if abs(angle) < 0.3:
+         return image
+     (h, w) = image.shape[:2]
+     M = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1.0)
+     return cv2.warpAffine(image, M, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE)
+
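One caveat on deskew_slight: the angle convention of cv2.minAreaRect reportedly changed around OpenCV 4.5 (roughly [-90, 0) before, (0, 90] after), which is what the angle < -45 branch compensates for. A quick local check (sketch only):

# Print the minAreaRect angle convention of the installed OpenCV build.
import cv2

pts = cv2.boxPoints(((100, 100), (200, 50), 3.0))  # corners of a box rotated by 3 degrees
print(cv2.__version__, cv2.minAreaRect(pts)[-1])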
  def correct_orientation(image):
+     """Robust orientation correction using pytesseract + fallback."""
      try:
          osd = pytesseract.image_to_osd(image, output_type=pytesseract.Output.DICT, timeout=5)
+         rotation = int(osd.get("rotate", 0))
+         if rotation:
+             # pytesseract's 'rotate' is the CLOCKWISE angle to correct the image.
+             angle_map = {
+                 90: cv2.ROTATE_90_CLOCKWISE,
+                 180: cv2.ROTATE_180,
+                 270: cv2.ROTATE_90_COUNTERCLOCKWISE,
+             }
              return cv2.rotate(image, angle_map[rotation])
          return image
      except Exception:
+         # Fallback: choose the rotation with the most horizontal text boxes
          gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+         thr = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
+         rots = {
+             0: thr,
+             90: cv2.rotate(thr, cv2.ROTATE_90_CLOCKWISE),
+             180: cv2.rotate(thr, cv2.ROTATE_180),
+             270: cv2.rotate(thr, cv2.ROTATE_90_COUNTERCLOCKWISE),
+         }
+         best = 0
+         best_count = -1
+         for ang, img in rots.items():
              try:
+                 data = pytesseract.image_to_data(img, output_type=pytesseract.Output.DICT, timeout=5)
+                 cnt = sum(
+                     1
+                     for i, c in enumerate(data["conf"])
+                     if str(c).isdigit()
+                     and int(c) > 10
+                     and data["width"][i] > data["height"][i]
+                 )
+                 if cnt > best_count:
+                     best, best_count = ang, cnt
              except Exception:
+                 pass
+         if best:
+             angle_map = {
+                 90: cv2.ROTATE_90_CLOCKWISE,
+                 180: cv2.ROTATE_180,
+                 270: cv2.ROTATE_90_COUNTERCLOCKWISE,
+             }
+             return cv2.rotate(image, angle_map[best])
+         return image
+
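The happy path above depends entirely on Tesseract's orientation-and-script detection (OSD). To see what the OSD dict contains for a given scan (sketch only; "sample.jpg" is a placeholder path and a local Tesseract install is assumed):

# Inspect the OSD result that correct_orientation() consumes; "rotate" is the key used above.
import cv2
import pytesseract

img = cv2.imread("sample.jpg")  # placeholder
osd = pytesseract.image_to_osd(img, output_type=pytesseract.Output.DICT)
print(osd)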
  def extract_and_draw_table_structure(image_bgr):
+     """Run TableTransformer and draw table/table row/table column boxes."""
      image_pil = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
      inputs = processor(images=image_pil, return_tensors="pt")
+
+     with torch.inference_mode():
          outputs = model(**inputs)
+
+     h, w = image_bgr.shape[:2]
+     target_sizes = torch.tensor([[h, w]], dtype=torch.float32)
      results = processor.post_process_object_detection(outputs, threshold=0.6, target_sizes=target_sizes)[0]
+
+     img = image_bgr.copy()
      colors = {"table row": (0, 255, 0), "table column": (255, 0, 0), "table": (255, 0, 255)}
      for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+         cls = model.config.id2label[label.item()]
+         if cls in colors:
+             xmin, ymin, xmax, ymax = [int(round(v)) for v in box.tolist()]
+             xmin = max(0, min(xmin, w - 1))
+             xmax = max(0, min(xmax, w - 1))
+             ymin = max(0, min(ymin, h - 1))
+             ymax = max(0, min(ymax, h - 1))
+             cv2.rectangle(img, (xmin, ymin), (xmax, ymax), colors[cls], 2)
+     return img
+
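Beyond drawing rectangles, the same post-processed results dict can be summarised directly. A small follow-on sketch (assumes the results object and model.config.id2label from the function above):

# Count kept detections per class from a post_process_object_detection() result.
from collections import Counter

def summarize_structure(results, id2label):
    names = [id2label[label.item()] for label in results["labels"]]
    return Counter(names)  # e.g. counts of "table row" / "table column" boxes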
  # ==============================================================================
  # Streamlit UI
  # ==============================================================================

+ # Session state
  if "stage" not in st.session_state:
      st.session_state.stage = "upload"
      st.session_state.original_image = None
      st.session_state.processed_image = None
      st.session_state.annotated_image = None

+ # Sidebar
  with st.sidebar:
      st.title("🤖 Document AI Toolkit")
      st.markdown("---")
+
      if st.button("🔄 Start Over", use_container_width=True):
          for key in list(st.session_state.keys()):
              del st.session_state[key]

      if st.session_state.stage == "upload":
          st.header("Step 1: Upload Image")
+         uploaded_file = st.file_uploader(
+             "Upload your document", type=["jpg", "jpeg", "png"], label_visibility="collapsed"
+         )
          if uploaded_file:
              file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
              st.session_state.original_image = cv2.imdecode(file_bytes, 1)
              st.session_state.stage = "processing"
              st.rerun()
+
      elif st.session_state.stage == "processing":
          st.header("Step 2: Pre-process")
          if st.button("▶️ Start Pre-processing", use_container_width=True, type="primary"):
+             with st.spinner("Correcting orientation, straightening & deskewing..."):
                  original_image = st.session_state.original_image
+                 oriented = correct_orientation(original_image)
+                 straightened = find_and_straighten_document(oriented)
+                 st.session_state.processed_image = deskew_slight(straightened)
                  st.session_state.stage = "analysis"
                  st.rerun()
+
      elif st.session_state.stage == "analysis":
          st.header("Step 3: Analyze Table")
          if st.button("📊 Find Table Structure", use_container_width=True, type="primary"):
              with st.spinner("Running Table Transformer model..."):
+                 st.session_state.annotated_image = extract_and_draw_table_structure(
+                     st.session_state.processed_image
+                 )
                  st.session_state.stage = "done"
                  st.rerun()
+ # Main panel
  st.title("Document Processing Workflow")

  # Step 1: Upload

      if st.session_state.original_image is None:
          st.info("Please upload a document image using the sidebar to begin.")
      else:
+         st.image(
+             cv2.cvtColor(st.session_state.original_image, cv2.COLOR_BGR2RGB),
+             use_container_width=True,
+         )
          st.success("Image uploaded successfully.")

  # Step 2: Pre-process
  if st.session_state.original_image is not None:
+     expander2 = st.expander(
+         "Step 2: Pre-process Document",
+         expanded=(st.session_state.stage == "processing" or st.session_state.stage == "analysis"),
+     )
      with expander2:
          if st.session_state.processed_image is None:
              st.info("Click 'Start Pre-processing' in the sidebar.")
          else:
+             st.image(
+                 cv2.cvtColor(st.session_state.processed_image, cv2.COLOR_BGR2RGB),
+                 caption="Oriented • Straightened • Deskewed",
+                 use_container_width=True,
+             )
              st.success("Pre-processing complete.")

  # Step 3: Analysis

          else:
              tab1, tab2 = st.tabs(["✅ Corrected Document", "📊 Table Structure"])
              with tab1:
+                 st.image(
+                     cv2.cvtColor(st.session_state.processed_image, cv2.COLOR_BGR2RGB),
+                     use_container_width=True,
+                 )
                  _, buf = cv2.imencode(".jpg", st.session_state.processed_image)
+                 st.download_button(
+                     "📥 Download Clean Image",
+                     data=buf.tobytes(),
+                     file_name="corrected.jpg",
+                     mime="image/jpeg",
+                     use_container_width=True,
+                 )
              with tab2:
+                 st.image(
+                     cv2.cvtColor(st.session_state.annotated_image, cv2.COLOR_BGR2RGB),
+                     use_container_width=True,
+                 )
+                 st.success("Analysis complete.")
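Both panes start at line 10 of app.py, so the file's import block sits above the hunks and is not shown. A plausible reconstruction from the names used (an assumption, not part of this commit; only the scipy import is confirmed, via the first hunk header) would be:

# Presumed app.py header (lines 1-9, outside the diff); inferred from usage.
import cv2
import numpy as np
import pytesseract
import streamlit as st
import torch
from PIL import Image
from scipy.spatial import distance as dist
from transformers import DetrImageProcessor, TableTransformerForObjectDetection

With those imports in place, the app runs as usual via streamlit run app.py.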