|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import streamlit as st |
|
|
from PIL import Image |
|
|
import cv2 |
|
|
import networkx as nx |
|
|
|
|
|
def build_graph(pil_image, detections, annotations, class_names): |
|
|
def dist(p1, p2): |
|
|
return np.hypot(p1[0] - p2[0], p1[1] - p2[1]) |
|
|
|
|
|
def angle_between(p1, p2): |
|
|
return np.degrees(np.arctan2(p2[1] - p1[1], p2[0] - p1[0])) % 180 |
|
|
|
|
|
def lines_are_similar(line1, line2, max_distance=10, max_angle_diff=10): |
|
|
(x1, y1), (x2, y2) = line1 |
|
|
(x3, y3), (x4, y4) = line2 |
|
|
angle1 = angle_between((x1, y1), (x2, y2)) |
|
|
angle2 = angle_between((x3, y3), (x4, y4)) |
|
|
if abs(angle1 - angle2) > max_angle_diff: |
|
|
return False |
|
|
mid1 = ((x1 + x2) / 2, (y1 + y2) / 2) |
|
|
mid2 = ((x3 + x4) / 2, (y3 + y4) / 2) |
|
|
return dist(mid1, mid2) < max_distance |
|
|
|
|
|
def merge_similar_lines(lines): |
|
|
if not lines: |
|
|
return [] |
|
|
merged, used = [], set() |
|
|
for i, l1 in enumerate(lines): |
|
|
if i in used: continue |
|
|
group = [l1]; used.add(i) |
|
|
for j, l2 in enumerate(lines): |
|
|
if j != i and j not in used and lines_are_similar(l1, l2): |
|
|
group.append(l2); used.add(j) |
|
|
x_coords, y_coords = [], [] |
|
|
for (x1, y1), (x2, y2) in group: |
|
|
x_coords.extend([x1, x2]) |
|
|
y_coords.extend([y1, y2]) |
|
|
merged.append(((int(min(x_coords)), int(min(y_coords))), (int(max(x_coords)), int(max(y_coords))))) |
|
|
return merged |
|
|
|
|
|
def point_inside_bbox(px, py, bbox): |
|
|
x1, y1, x2, y2 = bbox |
|
|
return x1 <= px <= x2 and y1 <= py <= y2 |
|
|
|
|
|
def find_nearest_symbol(point, symbols, max_dist=15): |
|
|
px, py = point |
|
|
nearest_sym, nearest_dist = None, float('inf') |
|
|
for sym in symbols: |
|
|
sx, sy = sym['pos'] |
|
|
d = dist((px, py), (sx, sy)) |
|
|
if d < nearest_dist and d <= max_dist: |
|
|
nearest_sym, nearest_dist = sym, d |
|
|
if nearest_sym is None: |
|
|
for sym in symbols: |
|
|
if point_inside_bbox(px, py, sym['bbox']): |
|
|
nearest_sym = sym |
|
|
break |
|
|
return nearest_sym |
|
|
|
|
|
|
|
|
image_cv = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR) |
|
|
|
|
|
|
|
|
allowed_types = {"connector", "crossing", "border_node"} |
|
|
symbols = [] |
|
|
for idx, (box, class_id) in enumerate(zip(detections.xyxy, detections.class_id)): |
|
|
label = class_names[class_id] |
|
|
if label in allowed_types: |
|
|
x1, y1, x2, y2 = map(int, box) |
|
|
cx, cy = (x1 + x2) // 2, (y1 + y2) // 2 |
|
|
symbols.append({ |
|
|
"id": f"{label}_{idx}", |
|
|
"type": label, |
|
|
"pos": (cx, cy), |
|
|
"bbox": (x1, y1, x2, y2) |
|
|
}) |
|
|
|
|
|
|
|
|
gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY) |
|
|
blurred = cv2.GaussianBlur(gray, (3, 3), 0) |
|
|
edges = cv2.Canny(blurred, 50, 150, apertureSize=3) |
|
|
lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=50, minLineLength=50, maxLineGap=10) |
|
|
detected_lines = [((x1, y1), (x2, y2)) for line in lines for x1, y1, x2, y2 in line] if lines is not None else [] |
|
|
|
|
|
merged_lines = merge_similar_lines(detected_lines) |
|
|
|
|
|
filtered_lines = [] |
|
|
for pt1, pt2 in merged_lines: |
|
|
sym1 = find_nearest_symbol(pt1, symbols) |
|
|
sym2 = find_nearest_symbol(pt2, symbols) |
|
|
if sym1 and sym2 and sym1 != sym2: |
|
|
filtered_lines.append((pt1, pt2)) |
|
|
|
|
|
|
|
|
output = image_cv.copy() |
|
|
for sym in symbols: |
|
|
x1, y1, x2, y2 = sym["bbox"] |
|
|
cx, cy = sym["pos"] |
|
|
cv2.rectangle(output, (x1, y1), (x2, y2), (255, 0, 0), 2) |
|
|
cv2.circle(output, (cx, cy), 3, (0, 255, 255), -1) |
|
|
cv2.putText(output, sym["type"], (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, .6, (255, 0, 0), 1) |
|
|
|
|
|
for (x1, y1), (x2, y2) in filtered_lines: |
|
|
cv2.line(output, (x1, y1), (x2, y2), (0, 100, 255), 2) |
|
|
|
|
|
st.image(Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB)), |
|
|
caption="Graph: Merged Lines + Detected Symbols", |
|
|
use_column_width=True) |
|
|
|
|
|
|
|
|
|
|
|
for i, sym in enumerate(symbols): |
|
|
sym['id'] = f"{sym['type']}_{i}" |
|
|
|
|
|
|
|
|
G = nx.Graph() |
|
|
for sym in symbols: |
|
|
G.add_node(sym['id'], label=sym['type'], pos=sym['pos']) |
|
|
|
|
|
for pt1, pt2 in filtered_lines: |
|
|
sym1 = find_nearest_symbol(pt1, symbols) |
|
|
sym2 = find_nearest_symbol(pt2, symbols) |
|
|
if sym1 and sym2 and sym1['id'] != sym2['id']: |
|
|
G.add_edge(sym1['id'], sym2['id']) |
|
|
|
|
|
|
|
|
fig, ax = plt.subplots(figsize=(8, 8)) |
|
|
pos = {node: data['pos'] for node, data in G.nodes(data=True)} |
|
|
labels = {node: data['label'] for node, data in G.nodes(data=True)} |
|
|
nx.draw(G, pos, labels=labels, node_size=700, node_color='lightblue', |
|
|
font_size=8, with_labels=True, ax=ax) |
|
|
ax.set_title("Extracted Graph from Detected Symbols and Lines") |
|
|
st.pyplot(fig) |
|
|
|
|
|
|
|
|
return G |
|
|
|