Spaces:
Runtime error
Runtime error
| import os | |
| import cv2 | |
| import base64 | |
| import json | |
| import pandas as pd | |
| import gradio as gr | |
| import numpy as np | |
| from roboflow import Roboflow | |
| from openai import OpenAI | |
| import re | |
| # ================= CONFIG ================= | |
# Roboflow credentials and model coordinates.
# SECURITY: an API key is committed in source. Prefer setting ROBOFLOW_API_KEY
# in the environment and rotating the embedded key; the literal below is kept
# only as a fallback so existing deployments keep working unchanged.
ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY") or "uP19IAi98TqwLvHmNB8V"
ROBOFLOW_PROJECT = "terminal-block-jtgsl"
ROBOFLOW_VERSION = 1

# Detection confidence as a fraction; callers convert it to a percentage.
CONF_THRESHOLD = 0.30
# NOTE(review): IOU_THRESHOLD is not referenced anywhere in the visible code.
IOU_THRESHOLD = 0.4

# JSON file holding the expected terminal -> wire assignments.
TERMINAL_JSON_PATH = "terminal.json"

# Remote clients are created once at import time and shared by every request.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
rf = Roboflow(api_key=ROBOFLOW_API_KEY)
model = rf.workspace().project(ROBOFLOW_PROJECT).version(ROBOFLOW_VERSION).model
| # ================= LOAD REFERENCE ================= | |
def load_terminal_reference(path=None):
    """Load the expected terminal -> wire mapping from a JSON file.

    The file is expected to hold an object with a ``"terminal_blocks"`` list
    whose entries are objects with ``"terminal"`` and ``"wire"`` keys. Keys and
    values of the returned mapping are stripped and upper-cased. Entries with
    an empty/missing wire, a missing terminal, or a non-object shape are
    skipped individually instead of discarding the whole table.

    Args:
        path: File to read. Defaults to ``TERMINAL_JSON_PATH``.

    Returns:
        A ``dict`` mapping terminal label to wire label; empty when the file
        is missing, unreadable, or not valid JSON of the expected shape.
    """
    if path is None:
        path = TERMINAL_JSON_PATH

    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # A missing reference file is a supported configuration.
        return {}
    except (OSError, ValueError) as exc:
        # ValueError covers both invalid JSON and undecodable bytes.
        print(f"Could not load terminal reference from {path}: {exc}")
        return {}

    if not isinstance(data, dict):
        return {}

    reference = {}
    for entry in data.get("terminal_blocks") or []:
        if not isinstance(entry, dict) or "terminal" not in entry:
            continue
        if not entry.get("wire"):
            continue
        terminal = str(entry["terminal"]).strip().upper()
        reference[terminal] = str(entry["wire"]).strip().upper()
    return reference
# Expected terminal -> wire table, loaded once at import time and read by the
# validation and verification helpers below.
terminal_reference = load_terminal_reference()
def clean_terminal(text):
    """Return ``text`` with every character other than ASCII 0-9 removed."""
    return "".join(ch for ch in text if ch in "0123456789")
def clean_wire(text):
    """Normalise an OCR wire label to upper-case ASCII letters and digits.

    Whitespace and punctuation are dropped. The letters ``O`` and ``I`` are
    rewritten to ``0`` and ``1`` only when they sit between two digits, where
    a letter cannot be intended. Replacing them everywhere would corrupt
    legitimate labels such as ``"IL23CA"`` (-> ``"1L23CA"``) and the
    ``"NONE"`` marker (-> ``"N0NE"``).

    Args:
        text: Raw wire text.

    Returns:
        The cleaned label; may be empty.
    """
    text = re.sub(r'[^A-Z0-9]', '', text.upper())
    # An O/I run at the edge of the digit group is ambiguous (it could belong
    # to the letter prefix or suffix), so it is deliberately left untouched.
    return re.sub(
        r'(?<=[0-9])[OI]+(?=[0-9])',
        lambda m: m.group().replace("O", "0").replace("I", "1"),
        text,
    )
def is_valid_wire(wire):
    """Tell whether ``wire`` has the expected label shape.

    The shape is 1-3 upper-case letters, then 2-4 digits, then up to 2
    upper-case letters.
    """
    pattern = r'^[A-Z]{1,3}[0-9]{2,4}[A-Z]{0,2}$'
    return re.match(pattern, wire) is not None
def validate_and_fix(t, w):
    """Normalise one OCR (terminal, wire) pair, falling back to the reference.

    Args:
        t: Raw terminal text.
        w: Raw wire text.

    Returns:
        ``(terminal, wire)`` as cleaned strings, or ``(None, None)`` when the
        terminal text contains no digits. When no wire was read the reference
        wire for the terminal is used, or ``"NONE"`` if there is none.
    """
    t = clean_terminal(t)
    if not t:
        return None, None

    # Recognise the "no wire read" markers on the raw text. clean_wire()
    # removes punctuation and may alter letters, so "N/A" and "NONE" cannot be
    # matched reliably once it has run.
    raw = w.strip().upper()
    if raw in ("", "NONE", "N/A"):
        return t, terminal_reference.get(t, "NONE")

    w = clean_wire(w)
    if not w:
        # Nothing but punctuation/whitespace was read.
        return t, terminal_reference.get(t, "NONE")

    # NOTE(review): an implausible reading is replaced by the expected wire,
    # which makes the later verification report a match for it. Confirm this
    # is the intended behaviour.
    if not is_valid_wire(w) and t in terminal_reference:
        w = terminal_reference[t]
    return t, w
| # ================= IMPROVED PREPROCESSING ================= | |
def prepare_for_roboflow(img, max_side=1600):
    """Shrink ``img`` so its longer side is at most ``max_side`` pixels.

    Images that already fit are returned unchanged (same object); images are
    never enlarged.
    """
    height, width = img.shape[:2]
    factor = min(max_side / max(height, width), 1)
    if factor < 1:
        target = (int(width * factor), int(height * factor))
        return cv2.resize(img, target)
    return img
def upscale(img):
    """Enlarge ``img`` with Lanczos interpolation to a height of 800 pixels.

    Both axes use the same factor. Images at least 800 pixels tall are passed
    through ``cv2.resize`` with a factor of 1.0; empty images are returned
    untouched.
    """
    if img.size == 0:
        return img
    height, _ = img.shape[:2]
    factor = 1.0
    if height < 800:
        # Enlarging keeps neighbouring thin strokes (e.g. "11") from merging.
        factor = 800 / height
    return cv2.resize(
        img, None, fx=factor, fy=factor, interpolation=cv2.INTER_LANCZOS4
    )
def enhance_variants(img):
    """Build the image variants sent to OCR for one crop.

    Returns a list of two images: the crop itself, and a version that has been
    converted to grayscale, contrast-equalised with CLAHE, denoised, sharpened
    and converted back to three channels. An empty crop yields an empty list.
    """
    if img.size == 0:
        return []

    grayscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Local contrast boost.
    equalizer = cv2.createCLAHE(clipLimit=4.0, tileGridSize=(12, 12))
    boosted = equalizer.apply(grayscale)

    # Denoise first, then sharpen, so thin characters stay legible.
    smoothed = cv2.fastNlMeansDenoising(boosted, None, 10, 7, 21)
    sharpen_kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    crisp = cv2.filter2D(smoothed, -1, sharpen_kernel)

    return [img, cv2.cvtColor(crisp, cv2.COLOR_GRAY2BGR)]
def img_to_base64(img):
    """Encode ``img`` as a quality-95 JPEG and return it as a base64 string."""
    jpeg_params = [int(cv2.IMWRITE_JPEG_QUALITY), 95]
    _, jpeg_bytes = cv2.imencode(".jpg", img, jpeg_params)
    encoded = base64.b64encode(jpeg_bytes)
    return encoded.decode()
| # ================= PIPELINE LOGIC ================= | |
def verify(terminal, wire):
    """Compare a read wire label against the reference for ``terminal``.

    Both arguments are stripped and upper-cased before comparison.

    Returns:
        ``"UNKNOWN"`` when the terminal is not in the reference table,
        ``"MATCH"`` when the reading agrees with it,
        ``"MISSING (Exp <wire>)"`` when no wire was read but one is expected,
        ``"MISMATCH (Exp <wire>)"`` otherwise.
    """
    key = terminal.strip().upper()
    observed = wire.strip().upper()

    if key not in terminal_reference:
        return "UNKNOWN"
    expected = terminal_reference[key]

    # These readings all mean "no wire was read".
    if observed in ("NONE", "EMPTY", "N/A", ""):
        if expected == "NONE":
            return "MATCH"
        return f"MISSING (Exp {expected})"

    if expected == observed:
        return "MATCH"
    return f"MISMATCH (Exp {expected})"
def fix_missing_wire(terminal, wire):
    """Substitute the reference wire when OCR produced no wire label.

    Both arguments are stripped and upper-cased. A reading of ``"NONE"``,
    ``""`` or ``"N/A"`` is replaced by the reference wire for the terminal
    when one exists; any other reading is returned in normalised form.
    """
    key = terminal.strip().upper()
    reading = wire.strip().upper()

    ocr_failed = reading in ("NONE", "", "N/A")
    if ocr_failed and key in terminal_reference:
        return terminal_reference[key]
    return reading
def group_by_columns(detections, threshold=30):
    """Cluster detections into columns by the x coordinate of their centre.

    Detections are visited in ascending x order. Each one joins the first
    existing column whose first member lies less than ``threshold`` pixels
    away horizontally; otherwise it starts a new column.

    Args:
        detections: Iterable of dicts carrying a ``"center"`` ``(x, y)`` pair.
        threshold: Maximum horizontal distance to a column's first member.

    Returns:
        A list of columns, each a list of detection dicts.
    """
    ordered = sorted(detections, key=lambda d: d["center"][0])
    columns = []
    for det in ordered:
        det_x = det["center"][0]
        for column in columns:
            anchor_x = column[0]["center"][0]
            if abs(anchor_x - det_x) < threshold:
                column.append(det)
                break
        else:
            # No existing column was close enough.
            columns.append([det])
    return columns
| def run_pipeline(image): | |
| if image is None: | |
| return None, pd.DataFrame() | |
| img = prepare_for_roboflow(image) | |
| H, W = img.shape[:2] | |
| # ================= DETECTION ================= | |
| preds = model.predict(img, confidence=int(CONF_THRESHOLD * 100)).json()["predictions"] | |
| wires, t_nums, w_nums, terms = [], [], [], [] | |
| for p in preds: | |
| x, y, w, h = map(int, [p["x"], p["y"], p["width"], p["height"]]) | |
| det = { | |
| "class": p["class"], | |
| "bbox": ( | |
| max(0, x - w // 2), | |
| max(0, y - h // 2), | |
| min(W, x + w // 2), | |
| min(H, y + h // 2) | |
| ), | |
| "center": (x, y) | |
| } | |
| if p["class"] == "Wire": | |
| wires.append(det) | |
| elif p["class"] == "Terminal Number": | |
| t_nums.append(det) | |
| elif p["class"] == "Wire Number": | |
| w_nums.append(det) | |
| elif p["class"] == "Terminal": | |
| terms.append(det) | |
| # ================= 🔥 NEW COLUMN GROUPING ================= | |
| columns = group_by_columns(t_nums + w_nums + terms, threshold=30) | |
| ocr_regions = [] | |
| for i, col in enumerate(columns): | |
| x1 = min(d["bbox"][0] for d in col) | |
| y1 = min(d["bbox"][1] for d in col) | |
| x2 = max(d["bbox"][2] for d in col) | |
| y2 = max(d["bbox"][3] for d in col) | |
| pad = 10 | |
| ocr_regions.append({ | |
| "union_bbox": ( | |
| max(0, x1 - pad), | |
| max(0, y1 - pad), | |
| min(W, x2 + pad), | |
| min(H, y2 + pad) | |
| ), | |
| "id": i | |
| }) | |
| # ================= GPT PROMPT ================= | |
| content = [{ | |
| "type": "text", | |
| "text": """ | |
| STRICT RULES: | |
| - One ID = one vertical column | |
| - Terminal = number below screws | |
| - Wire = text on white sleeve (ILxxx) | |
| - NEVER merge columns | |
| - NEVER skip digits | |
| - If unclear return NONE | |
| Output STRICT JSON: | |
| [{"id":0,"terminal":"77","wire":"IL23CA"}] | |
| """ | |
| }] | |
| # ================= IMAGE PREP ================= | |
| for region in ocr_regions: | |
| x1, y1, x2, y2 = region["union_bbox"] | |
| roi = img[y1:y2, x1:x2] | |
| roi = upscale(roi) | |
| content.append({"type": "text", "text": f"id:{region['id']}"}) | |
| for v in enhance_variants(roi): | |
| content.append({ | |
| "type": "image_url", | |
| "image_url": {"url": f"data:image/jpeg;base64,{img_to_base64(v)}"} | |
| }) | |
| results = [] | |
| # ================= GPT OCR ================= | |
| try: | |
| response = client.chat.completions.create( | |
| model="gpt-4o", | |
| messages=[{"role": "user", "content": content}], | |
| temperature=0 | |
| ) | |
| res_text = response.choices[0].message.content | |
| match = re.search(r'\[.*\]', res_text, re.DOTALL) | |
| if match: | |
| parsed = json.loads(match.group()) | |
| for item in parsed: | |
| idx = item.get("id") | |
| if idx is not None and idx < len(ocr_regions): | |
| t = str(item.get("terminal", "")).strip() | |
| w = str(item.get("wire", "")).strip() | |
| t, w = validate_and_fix(t, w) | |
| w = fix_missing_wire(t, w) | |
| results.append({ | |
| "Terminal": t, | |
| "Wire": w, | |
| "Verification": verify(t, w), | |
| "bbox": ocr_regions[idx]["union_bbox"] | |
| }) | |
| except Exception as e: | |
| print(f"Error: {e}") | |
| # ================= SORT ================= | |
| def safe_int(x): | |
| digits = ''.join(filter(str.isdigit, x)) | |
| return int(digits) if digits else 999 | |
| results = sorted(results, key=lambda x: safe_int(x["Terminal"])) | |
| # ================= VISUAL ================= | |
| vis = img.copy() | |
| for r in results: | |
| x1, y1, x2, y2 = r["bbox"] | |
| color = (0, 255, 0) if "MATCH" in r["Verification"] else (0, 0, 255) | |
| cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2) | |
| cv2.putText( | |
| vis, | |
| f"T:{r['Terminal']}", | |
| (x1, y1 - 10), | |
| cv2.FONT_HERSHEY_SIMPLEX, | |
| 0.6, | |
| color, | |
| 2 | |
| ) | |
| return vis, pd.DataFrame(results).drop(columns=["bbox"], errors="ignore") | |
| # ================= UI ================= | |
# Gradio front end: one image in, an annotated image plus a results table out.
with gr.Blocks(title="Terminal Assembly Inspector") as demo:
    gr.Markdown(value="## Terminal Detector ")
    with gr.Row():
        img_in = gr.Image(type="numpy", label="Input Rail")
        img_out = gr.Image(label="Detections (Red = Error)")
    btn = gr.Button(value="Analyze Entire Rail", variant="primary")
    table = gr.Dataframe(headers=["Terminal", "Wire", "Verification"])
    # Wire the button to the pipeline; outputs follow its return order.
    btn.click(fn=run_pipeline, inputs=[img_in], outputs=[img_out, table])

demo.launch()