# Source snapshot: commit b265364 by asammoud ("Re-add large CSVs using Git LFS"), 5.22 kB.
import numpy as np
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image
import cv2
import networkx as nx # <-- Added
# Detection classes that become graph nodes.
_ALLOWED_SYMBOL_TYPES = frozenset({"connector", "crossing", "border_node"})


def _dist(p1, p2):
    """Return the Euclidean distance between two (x, y) points."""
    return np.hypot(p1[0] - p2[0], p1[1] - p2[1])


def _angle_between(p1, p2):
    """Return the undirected angle of segment p1->p2 in degrees, in [0, 180)."""
    return np.degrees(np.arctan2(p2[1] - p1[1], p2[0] - p1[0])) % 180


def _lines_are_similar(line1, line2, max_distance=10, max_angle_diff=10):
    """Return True if two segments are nearly parallel and their midpoints are close.

    Angles are undirected (modulo 180 degrees), so the difference is taken the
    short way round: 179 and 1 degrees are 2 degrees apart, not 178.
    """
    (x1, y1), (x2, y2) = line1
    (x3, y3), (x4, y4) = line2
    diff = abs(_angle_between((x1, y1), (x2, y2)) - _angle_between((x3, y3), (x4, y4)))
    diff = min(diff, 180 - diff)
    if diff > max_angle_diff:
        return False
    mid1 = ((x1 + x2) / 2, (y1 + y2) / 2)
    mid2 = ((x3 + x4) / 2, (y3 + y4) / 2)
    return _dist(mid1, mid2) < max_distance


def _merge_similar_lines(lines):
    """Greedily collapse groups of similar segments into one segment per group.

    Each not-yet-used segment seeds a group; every later unused segment that is
    similar to the seed joins it. The group is replaced by the diagonal of its
    bounding box that has the same orientation as the seed.

    Returns a list of ((x1, y1), (x2, y2)) tuples of Python ints.
    """
    merged, used = [], set()
    for i, seed in enumerate(lines):
        if i in used:
            continue
        used.add(i)
        group = [seed]
        # Every index before i was already visited by the outer loop (and is
        # therefore in `used`), so only later segments can still join.
        for j in range(i + 1, len(lines)):
            if j not in used and _lines_are_similar(seed, lines[j]):
                group.append(lines[j])
                used.add(j)

        xs, ys = [], []
        for (x1, y1), (x2, y2) in group:
            xs.extend((int(x1), int(x2)))
            ys.extend((int(y1), int(y2)))
        left, right = min(xs), max(xs)
        top, bottom = min(ys), max(ys)

        (sx1, sy1), (sx2, sy2) = seed
        # Image y grows downward. When dx and dy have opposite signs the seed
        # is the anti-diagonal of its box; always emitting (left, top)-(right,
        # bottom) would mirror such lines.
        if (int(sx2) - int(sx1)) * (int(sy2) - int(sy1)) < 0:
            merged.append(((left, bottom), (right, top)))
        else:
            merged.append(((left, top), (right, bottom)))
    return merged


def _point_inside_bbox(px, py, bbox):
    """Return True if (px, py) lies inside the inclusive box (x1, y1, x2, y2)."""
    x1, y1, x2, y2 = bbox
    return x1 <= px <= x2 and y1 <= py <= y2


def _find_nearest_symbol(point, symbols, max_dist=15):
    """Snap a point to a symbol.

    Returns the symbol whose centre is nearest to `point` and no farther than
    `max_dist`. Failing that, returns the first symbol whose bbox contains the
    point, or None if there is no such symbol.
    """
    px, py = point
    nearest, nearest_d = None, float("inf")
    for sym in symbols:
        d = _dist((px, py), sym["pos"])
        if d <= max_dist and d < nearest_d:
            nearest, nearest_d = sym, d
    if nearest is not None:
        return nearest
    for sym in symbols:
        if _point_inside_bbox(px, py, sym["bbox"]):
            return sym
    return None


def _extract_symbols(detections, class_names):
    """Keep detections whose class name is a node type and describe each one.

    Each symbol is a dict with "id" ("{type}_{k}", where k is its index in the
    returned list), "type", "pos" (integer box centre) and "bbox" (int x1, y1, x2, y2).
    """
    symbols = []
    for box, class_id in zip(detections.xyxy, detections.class_id):
        label = class_names[class_id]
        if label not in _ALLOWED_SYMBOL_TYPES:
            continue
        x1, y1, x2, y2 = map(int, box)
        symbols.append({
            "id": f"{label}_{len(symbols)}",
            "type": label,
            "pos": ((x1 + x2) // 2, (y1 + y2) // 2),
            "bbox": (x1, y1, x2, y2),
        })
    return symbols


def _detect_lines(image_bgr):
    """Find straight segments with Canny edges + probabilistic Hough transform."""
    gray = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (3, 3), 0)
    edges = cv2.Canny(blurred, 50, 150, apertureSize=3)
    lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=50, minLineLength=50, maxLineGap=10)
    if lines is None:  # HoughLinesP returns None when nothing is found
        return []
    return [((x1, y1), (x2, y2)) for line in lines for x1, y1, x2, y2 in line]


def _draw_overlay(image_bgr, symbols, segments):
    """Return a copy of the BGR image with symbol boxes, centres, labels and segments drawn."""
    output = image_bgr.copy()
    for sym in symbols:
        x1, y1, x2, y2 = sym["bbox"]
        cv2.rectangle(output, (x1, y1), (x2, y2), (255, 0, 0), 2)
        cv2.circle(output, sym["pos"], 3, (0, 255, 255), -1)
        cv2.putText(output, sym["type"], (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, .6, (255, 0, 0), 1)
    for pt1, pt2 in segments:
        cv2.line(output, pt1, pt2, (0, 100, 255), 2)
    return output


def build_graph(pil_image, detections, annotations, class_names):
    """Build a connectivity graph from detected symbols and straight lines in an image.

    Node symbols are the detections whose class name is "connector", "crossing"
    or "border_node". Straight segments are found with Canny + probabilistic
    Hough and near-duplicates are merged. Each merged segment whose two endpoints
    snap to two different symbols becomes an edge.

    Side effects: renders an annotated copy of the image (st.image) and a
    matplotlib drawing of the graph (st.pyplot) into the Streamlit page.

    Args:
        pil_image: Input image. A PIL image is converted to RGB first, so
            RGBA, greyscale and palette images also work. Anything else goes
            through np.array unchanged and is assumed to be RGB.
        detections: Object with parallel `xyxy` (boxes as x1, y1, x2, y2) and
            `class_id` sequences. Presumably a supervision `Detections`;
            verify against the caller.
        annotations: Not read by this function; kept for call-site compatibility.
        class_names: Indexable lookup from class id to class name.

    Returns:
        networkx.Graph whose nodes are "{type}_{k}" ids with attributes
        `label` (symbol type) and `pos` (centre pixel (x, y)). There is one
        edge per connecting segment.
    """
    if isinstance(pil_image, Image.Image):
        pil_image = pil_image.convert("RGB")  # cvtColor below requires 3 channels
    image_bgr = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

    symbols = _extract_symbols(detections, class_names)
    merged_lines = _merge_similar_lines(_detect_lines(image_bgr))

    # Snap each segment's endpoints once. Keep only segments that join two
    # distinct symbols; the same result feeds both the overlay and the graph.
    connections = []
    for pt1, pt2 in merged_lines:
        sym1 = _find_nearest_symbol(pt1, symbols)
        sym2 = _find_nearest_symbol(pt2, symbols)
        if sym1 is not None and sym2 is not None and sym1["id"] != sym2["id"]:
            connections.append((pt1, pt2, sym1["id"], sym2["id"]))

    overlay = _draw_overlay(image_bgr, symbols, [(p1, p2) for p1, p2, _, _ in connections])
    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favour of use_container_width; kept as-is for the pinned version.
    st.image(Image.fromarray(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)),
             caption="Graph: Merged Lines + Detected Symbols",
             use_column_width=True)

    G = nx.Graph()
    for sym in symbols:
        G.add_node(sym["id"], label=sym["type"], pos=sym["pos"])
    for _, _, id1, id2 in connections:
        G.add_edge(id1, id2)

    fig, ax = plt.subplots(figsize=(8, 8))
    pos = {node: data["pos"] for node, data in G.nodes(data=True)}
    labels = {node: data["label"] for node, data in G.nodes(data=True)}
    nx.draw(G, pos, labels=labels, node_size=700, node_color='lightblue',
            font_size=8, with_labels=True, ax=ax)
    # Positions are image pixels (y grows downward); flip so the plot matches the image.
    ax.invert_yaxis()
    ax.set_title("Extracted Graph from Detected Symbols and Lines")
    st.pyplot(fig)
    plt.close(fig)  # free the figure; Streamlit reruns would otherwise accumulate them
    return G