# Hugging Face Spaces app (page status banner at capture time: "Sleeping").
import gradio as gr | |
from inference_sdk import InferenceHTTPClient | |
from PIL import Image, ImageDraw | |
import os | |
from collections import defaultdict | |
# The Roboflow key is supplied through the Hugging Face Spaces secret store;
# fail fast at import time if it is absent or empty.
API_KEY = os.environ.get("ROBOFLOW_API_KEY")
if not API_KEY:
    raise ValueError("API Key is missing! Set it in HF Space Secrets.")

# Single shared Roboflow HTTP client used for every inference call.
CLIENT = InferenceHTTPClient(
    api_url="https://detect.roboflow.com",
    api_key=API_KEY,
)

# Roboflow model queried for detections.
MODEL_ID = "hvacsym/5"
# Predictions whose confidence is below this value are discarded.
CONFIDENCE_THRESHOLD = 0.2
# (rows, cols): each image is split into this many tiles before inference.
GRID_SIZE = (3, 3)
def format_counts_as_table(counts, pass_num):
    """Render per-component detection counts as a Markdown table for Gradio.

    Args:
        counts: Mapping of component name -> number of detections. Rows are
            emitted in the mapping's iteration order.
        pass_num: Label shown in the heading (a pass number or e.g. "Final").

    Returns:
        A Markdown string: a heading followed by a two-column table, or a
        single "No components detected." heading when ``counts`` is empty.
    """
    if counts:
        header = (
            f"### Pass {pass_num} Detection Results:\n\n"
            "| Component | Count |\n"
            "|-----------|-------|\n"
        )
        rows = [f"| {name} | {n} |\n" for name, n in counts.items()]
        return header + "".join(rows)
    return f"### Pass {pass_num}: No components detected."
def _grid_cells(width, height):
    """Yield ``(row, col, (x1, y1, x2, y2))`` for every tile of the GRID_SIZE grid.

    Tile edges are computed proportionally (``i * size // n``) so the last
    row/column ends exactly on the image border and no remainder pixels are
    left uninspected. Zero-area tiles (only possible when the image is smaller
    than the grid) are skipped rather than sent to the API.
    """
    rows, cols = GRID_SIZE
    for row in range(rows):
        y1, y2 = row * height // rows, (row + 1) * height // rows
        for col in range(cols):
            x1, x2 = col * width // cols, (col + 1) * width // cols
            if x2 > x1 and y2 > y1:
                yield row, col, (x1, y1, x2, y2)


def _detect_pass(image, pass_num, workdir):
    """Run one tiled inference pass over ``image``.

    Each tile is saved as a PNG inside ``workdir`` (the client is given a file
    path), sent to MODEL_ID, and predictions below CONFIDENCE_THRESHOLD are
    dropped. Surviving boxes are translated from tile to full-image
    coordinates and drawn in green on a copy of ``image``.

    Returns:
        (annotated copy of ``image``,
         defaultdict mapping class name -> detection count,
         list of (x_min, y_min, x_max, y_max) boxes in image coordinates)
    """
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    counts = defaultdict(int)
    boxes = []
    width, height = image.size

    for row, col, (x1, y1, x2, y2) in _grid_cells(width, height):
        tile_path = os.path.join(workdir, f"segment_{row}_{col}_pass{pass_num}.png")
        image.crop((x1, y1, x2, y2)).save(tile_path)
        result = CLIENT.infer(tile_path, model_id=MODEL_ID)

        for pred in result["predictions"]:
            if pred["confidence"] < CONFIDENCE_THRESHOLD:
                continue
            counts[pred["class"]] += 1
            # x/y are treated as the box centre in tile space; true division
            # keeps odd widths/heights exact instead of flooring them.
            half_w, half_h = pred["width"] / 2, pred["height"] / 2
            cx, cy = x1 + pred["x"], y1 + pred["y"]
            box = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
            boxes.append(box)
            draw.rectangle(box, outline="green", width=2)

    return annotated, counts, boxes


def _mask_boxes(image, boxes):
    """Return a copy of ``image`` with every box filled white, hiding already-detected objects from the next pass."""
    masked = image.copy()
    draw = ImageDraw.Draw(masked)
    for box in boxes:
        draw.rectangle(box, fill=(255, 255, 255))
    return masked


def detect_components(image):
    """Detect components in an uploaded image using three successive passes.

    Pass 1 runs tiled inference on the original image. Every box it finds is
    painted white and pass 2 runs on that masked image; pass 3 runs on the
    image with both the pass-1 and pass-2 boxes masked.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submits without uploading anything.

    Returns:
        A 7-tuple in the order of the Interface outputs:
        (pass-1 image, pass-1 table, pass-2 image, pass-2 table,
         pass-3 image, pass-3 table, combined-totals table).
        For a ``None`` input the image slots are ``None`` and every table slot
        holds a prompt asking for an image.
    """
    import tempfile

    if image is None:
        message = "### No image uploaded. Please provide an image to analyse."
        return (None, message, None, message, None, message, message)

    original_image = image.convert("RGB")

    # Tiles go into a private per-call directory, so overlapping requests
    # cannot overwrite each other's files and nothing is left on disk.
    with tempfile.TemporaryDirectory(prefix="hvac_segments_") as workdir:
        image_after_pass1, counts_pass1, boxes_pass1 = _detect_pass(original_image, 1, workdir)
        image_after_removal1 = _mask_boxes(original_image, boxes_pass1)

        image_after_pass2, counts_pass2, boxes_pass2 = _detect_pass(image_after_removal1, 2, workdir)
        image_after_removal2 = _mask_boxes(image_after_removal1, boxes_pass2)

        image_after_pass3, counts_pass3, _ = _detect_pass(image_after_removal2, 3, workdir)

    # Sum all passes in first-seen order so the totals table is stable
    # across runs (set iteration order of strings varies per process).
    final_counts = defaultdict(int)
    for pass_counts in (counts_pass1, counts_pass2, counts_pass3):
        for label, count in pass_counts.items():
            final_counts[label] += count

    return (
        image_after_pass1, format_counts_as_table(counts_pass1, 1),
        image_after_pass2, format_counts_as_table(counts_pass2, 2),
        image_after_pass3, format_counts_as_table(counts_pass3, 3),
        format_counts_as_table(final_counts, "Final"),
    )
# Gradio UI: one uploaded image in; for each of the three passes an
# annotated image followed by its Markdown count table, then the totals.
interface = gr.Interface(
    fn=detect_components,
    inputs=gr.Image(type="pil"),
    outputs=[
        component
        for pass_num in (1, 2, 3)
        for component in (
            gr.Image(type="pil", label=f"Detection Pass {pass_num}"),
            gr.Markdown(label=f"Counts After Pass {pass_num}"),
        )
    ]
    + [gr.Markdown(label="Final Detected Components")],
    title="HVAC Component Detector",
    description="Upload an image to detect HVAC components using Roboflow API across three passes.",
)

# Launch the app only when executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch()