Spaces:
Sleeping
Sleeping
import cv2 | |
import numpy as np | |
from doctr.models import ocr_predictor | |
from .models import Text, Point, Place, Transition, Arc # Import the Text class | |
from .commons import filter_enclosed_contours, minmaxToContours, remove_contours, find_closest_distance_to_contour | |
def _geometry_to_absolute_coords(relative_geom: tuple[tuple[float, float], tuple[float, float]], | |
img_width: int, img_height: int) -> tuple[tuple[int, int], tuple[int, int]]: | |
"""Converts doctr's relative coordinates to absolute integer coordinates.""" | |
(xmin_rel, ymin_rel), (xmax_rel, ymax_rel) = relative_geom | |
xmin_abs = int(xmin_rel * img_width) | |
ymin_abs = int(ymin_rel * img_height) | |
xmax_abs = int(xmax_rel * img_width) | |
ymax_abs = int(ymax_rel * img_height) | |
return ((xmin_abs, ymin_abs), (xmax_abs, ymax_abs)) | |
def detect_text(img_color_resized: np.ndarray, config: dict) -> list[Text]: | |
""" | |
Detects text using doctr and returns a list of Text objects with absolute coordinates. | |
(Implementation remains the same) | |
""" | |
predictor_params = config.get('text_detection', {}) | |
predictor = ocr_predictor( | |
det_arch='db_resnet50', | |
reco_arch='crnn_vgg16_bn', | |
pretrained=True, | |
) | |
predictor.det_predictor.model.postprocessor.bin_thresh = predictor_params.get('bin_thresh', 0.3) | |
predictor.det_predictor.model.postprocessor.box_thresh = predictor_params.get('box_thresh', 0.1) | |
out = predictor([img_color_resized]) | |
detected_texts: list[Text] = [] | |
img_height, img_width = img_color_resized.shape[:2] | |
if out.pages: | |
for block in out.pages[0].blocks: | |
for line in block.lines: | |
for word in line.words: | |
abs_geom = _geometry_to_absolute_coords(word.geometry, img_width, img_height) | |
text_obj = Text(value=word.value, | |
geometry_abs=abs_geom, | |
confidence=word.confidence) | |
detected_texts.append(text_obj) | |
return detected_texts | |
def get_img_no_text(preprocessed_img: np.ndarray, detected_texts: list[Text]) -> np.ndarray: | |
""" | |
Removes text from the preprocessed image by finding contours within text bounding boxes | |
and applying a mask using bitwise_and (similar to original notebook). | |
Args: | |
preprocessed_img: The thresholded image (e.g., from Otsu). | |
detected_texts: List of Text objects with absolute coordinates. | |
Returns: | |
The image with text contours removed (blacked out). | |
""" | |
if not detected_texts: | |
### throw an error | |
raise ValueError("No detected texts to process.") | |
img_contours_list, _ = cv2.findContours(preprocessed_img, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) | |
bbox_to_contours = [minmaxToContours([text.pt1.x, text.pt1.y, text.pt2.x, text.pt2.y]) for text in detected_texts] | |
text_contours = filter_enclosed_contours(img_contours_list, bbox_to_contours, include_border=True) | |
img_no_text = remove_contours(preprocessed_img, text_contours) | |
return img_no_text | |
# Helper function to get the center of a Text object's bounding box | |
def get_text_center(text_obj: Text) -> Point: | |
"""Calculates the center point of a Text object's bounding box.""" | |
center_x = (text_obj.pt1.x + text_obj.pt2.x) / 2.0 | |
center_y = (text_obj.pt1.y + text_obj.pt2.y) / 2.0 | |
return Point(center_x, center_y) | |
def link_text_to_elements( | |
detected_text_list: list[Text], | |
places_list: list[Place], | |
transitions_list: list[Transition], | |
arcs_list: list[Arc], | |
config: dict | |
): | |
""" | |
Associates Text objects with the closest Place, Transition, or Arc | |
if the distance from the text's center to the element is within a given threshold. | |
The association is done by appending the Text object to the `text` list | |
of the corresponding Place, Transition, or Arc. | |
""" | |
distance_threshold = config.get('connection_processing', {}).get('text_linking_threshold', 25.0 ) | |
# Clear any previous text associations from elements | |
for element_list in [places_list, transitions_list, arcs_list]: | |
for element in element_list: | |
element.text = [] | |
for text_obj in detected_text_list: | |
text_center = get_text_center(text_obj) | |
min_overall_distance = float('inf') | |
closest_element_overall = None | |
# 1. Check Places | |
for place in places_list: | |
dist_to_place_center = text_center.get_distance_between_points(place.center) | |
distance = max(0, dist_to_place_center - place.radius) # Distance to circumference | |
if distance < min_overall_distance: | |
min_overall_distance = distance | |
closest_element_overall = place | |
# 2. Check Transitions | |
for transition in transitions_list: | |
contour_to_use = None | |
if transition.original_detection_data is not None and \ | |
isinstance(transition.original_detection_data, np.ndarray) and \ | |
transition.original_detection_data.shape[0] > 0: | |
contour_to_use = transition.original_detection_data | |
elif transition.points and len(transition.points) > 0: # Fallback to box_points | |
contour_to_use = np.array([p.get_numpy_array() for p in transition.points], dtype=np.int32).reshape((-1, 1, 2)) | |
if contour_to_use is None or contour_to_use.shape[0] == 0: | |
continue | |
distance = find_closest_distance_to_contour(text_center, contour_to_use) | |
if distance < min_overall_distance: | |
min_overall_distance = distance | |
closest_element_overall = transition | |
# 3. Check Arcs | |
for arc in arcs_list: | |
arc_contour_for_dist_calc = None | |
# Prioritize arc.points if available, as it represents the path | |
if arc.points and len(arc.points) >= 1: | |
arc_contour_for_dist_calc = np.array([p.get_numpy_array() for p in arc.points], dtype=np.int32).reshape((-1, 1, 2)) | |
# Fallback if arc.points is empty but start/end points are defined (simple line arc) | |
elif arc.start_point and arc.end_point: | |
arc_contour_for_dist_calc = np.array([ | |
arc.start_point.get_numpy_array(), | |
arc.end_point.get_numpy_array() | |
], dtype=np.int32).reshape((-1, 1, 2)) | |
# If arc is defined by arc.lines (more complex, potentially disjoint segments) | |
# This path is less common if arc.points is expected to be canonical. | |
elif arc.lines: | |
current_arc_min_dist_lines = float('inf') | |
for line_segment in arc.lines: | |
dist_to_segment = line_segment.distance_point_to_segment(text_center) | |
current_arc_min_dist_lines = min(current_arc_min_dist_lines, dist_to_segment) | |
if current_arc_min_dist_lines < min_overall_distance: | |
min_overall_distance = current_arc_min_dist_lines | |
closest_element_overall = arc | |
continue # Skip contour-based distance if lines were processed | |
if arc_contour_for_dist_calc is not None and arc_contour_for_dist_calc.shape[0] > 0: | |
distance = find_closest_distance_to_contour(text_center, arc_contour_for_dist_calc) | |
if distance < min_overall_distance: | |
min_overall_distance = distance | |
closest_element_overall = arc | |
# Associate text with the overall closest element if within threshold | |
if closest_element_overall is not None and min_overall_distance <= distance_threshold: | |
closest_element_overall.text.append(text_obj) | |
# print(f"Associated '{text_obj.value}' (center: {text_center}) with {closest_element_overall.__class__.__name__} id={id(closest_element_overall)} (dist: {min_overall_distance:.2f})") | |
# else: | |
# print(f"Text '{text_obj.value}' (center: {text_center}) not associated, min_dist {min_overall_distance:.2f} > threshold {distance_threshold}") | |