File size: 5,218 Bytes
b265364
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import numpy as np 
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image
import cv2
import networkx as nx  # <-- Added

def _segment_angle(p1, p2):
    """Return the undirected orientation of segment p1-p2 in degrees, in [0, 180)."""
    return np.degrees(np.arctan2(p2[1] - p1[1], p2[0] - p1[0])) % 180


def _orientation_difference(a, b):
    """Return the smallest gap (0-90 degrees) between two undirected orientations.

    Orientations wrap at 180 degrees, so 2 and 178 are only 4 degrees apart.
    """
    diff = abs(a - b) % 180
    return min(diff, 180 - diff)


def _lines_are_similar(line1, line2, max_distance=10, max_angle_diff=10):
    """Return True if two segments are near-parallel and their midpoints are close.

    Args:
        line1, line2: Segments as ((x1, y1), (x2, y2)).
        max_distance: Midpoints must be strictly closer than this (pixels).
        max_angle_diff: Maximum orientation gap in degrees (wrap-aware).
    """
    (x1, y1), (x2, y2) = line1
    (x3, y3), (x4, y4) = line2
    gap = _orientation_difference(_segment_angle((x1, y1), (x2, y2)),
                                  _segment_angle((x3, y3), (x4, y4)))
    if gap > max_angle_diff:
        return False
    mid1 = ((x1 + x2) / 2, (y1 + y2) / 2)
    mid2 = ((x3 + x4) / 2, (y3 + y4) / 2)
    return np.hypot(mid1[0] - mid2[0], mid1[1] - mid2[1]) < max_distance


def _merged_extent(seed, group):
    """Return the two outermost endpoints of ``group`` along ``seed``'s direction.

    Projecting onto the seed direction preserves the segment's orientation.
    Taking min/max of x and y independently would map a segment running from
    bottom-left to top-right onto the opposite diagonal.
    """
    (sx1, sy1), (sx2, sy2) = seed
    dx, dy = float(sx2) - float(sx1), float(sy2) - float(sy1)
    length = np.hypot(dx, dy)
    # Degenerate (zero-length) seed: fall back to the x axis as the direction.
    ux, uy = (dx / length, dy / length) if length > 0 else (1.0, 0.0)
    points = [point for segment in group for point in segment]
    projections = [float(px) * ux + float(py) * uy for px, py in points]
    start = points[int(np.argmin(projections))]
    end = points[int(np.argmax(projections))]
    return (int(start[0]), int(start[1])), (int(end[0]), int(end[1]))


def _merge_similar_lines(lines, max_distance=10, max_angle_diff=10):
    """Greedily merge segments similar to a seed segment into one segment per group.

    Each not-yet-used segment becomes a seed; every later unused segment that
    is similar to the seed (not transitively) joins its group.
    """
    merged, used = [], set()
    for i, seed in enumerate(lines):
        if i in used:
            continue
        used.add(i)
        group = [seed]
        # All earlier indices are already used (as seeds or group members).
        for j in range(i + 1, len(lines)):
            if j not in used and _lines_are_similar(seed, lines[j], max_distance, max_angle_diff):
                group.append(lines[j])
                used.add(j)
        merged.append(_merged_extent(seed, group))
    return merged


def _find_nearest_symbol(point, symbols, max_dist=15):
    """Return the symbol whose centre is nearest to ``point`` within ``max_dist`` px.

    If no centre is close enough, return the first symbol whose bounding box
    contains ``point``; return None when neither test matches.
    """
    px, py = point
    best, best_dist = None, float('inf')
    for sym in symbols:
        sx, sy = sym['pos']
        d = np.hypot(px - sx, py - sy)
        if d <= max_dist and d < best_dist:
            best, best_dist = sym, d
    if best is not None:
        return best
    for sym in symbols:
        x1, y1, x2, y2 = sym['bbox']
        if x1 <= px <= x2 and y1 <= py <= y2:
            return sym
    return None


def _extract_symbols(detections, class_names, allowed_types):
    """Return symbol dicts (id, type, pos, bbox) for detections with an allowed label.

    ``id`` is ``"<type>_<n>"`` where n is the symbol's index in the returned list.
    """
    symbols = []
    for box, class_id in zip(detections.xyxy, detections.class_id):
        label = class_names[class_id]
        if label not in allowed_types:
            continue
        x1, y1, x2, y2 = map(int, box)
        symbols.append({
            "id": f"{label}_{len(symbols)}",
            "type": label,
            "pos": ((x1 + x2) // 2, (y1 + y2) // 2),
            "bbox": (x1, y1, x2, y2),
        })
    return symbols


def _detect_segments(image_bgr):
    """Return probabilistic-Hough segments of a BGR image as ((x1, y1), (x2, y2)) tuples."""
    gray = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (3, 3), 0)
    edges = cv2.Canny(blurred, 50, 150, apertureSize=3)
    hough = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=50, minLineLength=50, maxLineGap=10)
    if hough is None:
        return []
    return [((x1, y1), (x2, y2)) for x1, y1, x2, y2 in hough.reshape(-1, 4)]


def build_graph(pil_image, detections, annotations, class_names):
    """Extract a symbol-connectivity graph from a diagram image and show it in Streamlit.

    Symbols labelled "connector", "crossing" or "border_node" are taken from
    ``detections``. Straight segments are found with Canny + probabilistic
    Hough and near-duplicates are merged. Each merged segment whose two
    endpoints resolve to two different symbols becomes a graph edge.

    Args:
        pil_image: Source image. PIL images are converted to RGB first; other
            inputs go through ``np.array`` unchanged and must be RGB(A).
        detections: Object exposing parallel ``xyxy`` (x1, y1, x2, y2 boxes)
            and ``class_id`` (keys into ``class_names``) sequences.
        annotations: Unused; kept for call-site compatibility.
        class_names: Indexable mapping from class id to label string.

    Returns:
        networkx.Graph with one node per kept symbol, id ``"<type>_<n>"``,
        carrying ``label`` (symbol type) and ``pos`` (pixel centre) attributes.

    Side effects:
        Renders the annotated image (``st.image``) and the graph plot
        (``st.pyplot``) into the current Streamlit page.
    """
    if isinstance(pil_image, Image.Image):
        # Grayscale ("L") / palette ("P") images become 2-D arrays that
        # cvtColor(RGB2BGR) rejects; normalise to 3-channel RGB first.
        pil_image = pil_image.convert("RGB")
    image_cv = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

    symbols = _extract_symbols(detections, class_names, {"connector", "crossing", "border_node"})
    merged_lines = _merge_similar_lines(_detect_segments(image_cv))

    # Resolve both endpoints once; the result drives both drawing and graph edges.
    connections = []
    for pt1, pt2 in merged_lines:
        sym1 = _find_nearest_symbol(pt1, symbols)
        sym2 = _find_nearest_symbol(pt2, symbols)
        if sym1 is not None and sym2 is not None and sym1 is not sym2:
            connections.append((pt1, pt2, sym1['id'], sym2['id']))

    # Draw symbols and accepted segments on a copy of the image.
    output = image_cv.copy()
    for sym in symbols:
        x1, y1, x2, y2 = sym["bbox"]
        cv2.rectangle(output, (x1, y1), (x2, y2), (255, 0, 0), 2)
        cv2.circle(output, sym["pos"], 3, (0, 255, 255), -1)
        cv2.putText(output, sym["type"], (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, .6, (255, 0, 0), 1)
    for pt1, pt2, _, _ in connections:
        cv2.line(output, pt1, pt2, (0, 100, 255), 2)

    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favour of use_container_width; kept for compatibility with older versions.
    st.image(Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB)),
             caption="Graph: Merged Lines + Detected Symbols",
             use_column_width=True)

    G = nx.Graph()
    for sym in symbols:
        G.add_node(sym['id'], label=sym['type'], pos=sym['pos'])
    G.add_edges_from((id1, id2) for _, _, id1, id2 in connections)

    fig, ax = plt.subplots(figsize=(8, 8))
    nx.draw(G, nx.get_node_attributes(G, 'pos'), labels=nx.get_node_attributes(G, 'label'),
            node_size=700, node_color='lightblue', font_size=8, with_labels=True, ax=ax)
    # Positions are pixel coordinates (y grows downward); flip so the plot
    # has the same orientation as the image shown above.
    ax.invert_yaxis()
    ax.set_title("Extracted Graph from Detected Symbols and Lines")
    st.pyplot(fig)
    # Release the figure; otherwise every Streamlit rerun leaks one.
    plt.close(fig)

    return G