File size: 5,579 Bytes
7d2c620
 
 
 
 
 
e46dacf
7d2c620
 
 
 
 
e46dacf
7d2c620
 
 
 
 
 
45a3a9d
e46dacf
7d2c620
bc6b231
 
 
 
 
 
 
 
 
 
 
 
 
7d2c620
e46dacf
7d2c620
 
 
 
 
 
e24bf63
7d2c620
 
 
 
 
 
 
 
 
 
 
 
 
 
e46dacf
7d2c620
 
 
 
 
 
 
e24bf63
e46dacf
7d2c620
 
 
e24bf63
e46dacf
7d2c620
5ea165d
7d2c620
 
e46dacf
7d2c620
bc6b231
7d2c620
e46dacf
7d2c620
 
 
 
 
e46dacf
7d2c620
bc6b231
7d2c620
e46dacf
7d2c620
 
 
 
 
e46dacf
7d2c620
bc6b231
7d2c620
e46dacf
7d2c620
 
 
 
bc6b231
 
 
5ea165d
bc6b231
 
 
 
5ea165d
7d2c620
e46dacf
7d2c620
 
 
 
 
bc6b231
7d2c620
bc6b231
7d2c620
bc6b231
 
7d2c620
 
 
 
 
e46dacf
7d2c620
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import tempfile
from collections import defaultdict

import gradio as gr
from inference_sdk import InferenceHTTPClient
from PIL import Image, ImageDraw

# Roboflow API key, supplied through the ROBOFLOW_API_KEY environment variable
# (configured as a Hugging Face Spaces secret). Fail fast at import time when it
# is absent or empty, because every inference call below needs it.
API_KEY = os.environ.get("ROBOFLOW_API_KEY")
if not API_KEY:
    raise ValueError("API Key is missing! Set it in HF Space Secrets.")

# One client for the Roboflow hosted detection endpoint, reused for every tile.
CLIENT = InferenceHTTPClient(
    api_key=API_KEY,
    api_url="https://detect.roboflow.com",
)

MODEL_ID = "hvacsym/5"  # Roboflow model identifier passed to CLIENT.infer
CONFIDENCE_THRESHOLD = 0.2  # predictions scoring below this are discarded
GRID_SIZE = (3, 3)  # (rows, cols) grid the image is tiled into before inference

def format_counts_as_table(counts, pass_num):
    """Render per-component detection counts as a Markdown table for Gradio.

    Args:
        counts: Mapping of component (class) name to detection count. Rows are
            emitted in the mapping's own iteration order.
        pass_num: Inserted verbatim into the heading; an int pass number or a
            label such as ``"Final"``.

    Returns:
        A Markdown string: a heading followed by a two-column table ending in a
        newline, or a one-line "No components detected." heading when
        ``counts`` is empty.
    """
    if counts:
        lines = [
            f"### Pass {pass_num} Detection Results:",
            "",
            "| Component | Count |",
            "|-----------|-------|",
        ]
        lines.extend(f"| {name} | {total} |" for name, total in counts.items())
        return "\n".join(lines) + "\n"
    return f"### Pass {pass_num}: No components detected."

def _segment_boxes(width, height, grid=GRID_SIZE):
    """Yield ``(row, col, (x1, y1, x2, y2))`` tiles that cover the whole image.

    ``grid`` is ``(rows, cols)``. Tile edges are computed as
    ``index * size // count`` so the last row/column always ends exactly at
    the image border; pixels left over when the size is not divisible by the
    grid dimension are included instead of silently skipped.
    """
    rows, cols = grid
    for row in range(rows):
        for col in range(cols):
            yield row, col, (
                col * width // cols,
                row * height // rows,
                (col + 1) * width // cols,
                (row + 1) * height // rows,
            )


def _mask_boxes(image, boxes):
    """Return a copy of ``image`` with every box in ``boxes`` filled white."""
    masked = image.copy()
    draw = ImageDraw.Draw(masked)
    for box in boxes:
        draw.rectangle(box, fill=(255, 255, 255))
    return masked


def _run_detection_pass(image, pass_num, workdir):
    """Run tiled inference over ``image`` and draw the detections.

    Each grid tile is saved as a PNG inside ``workdir`` and sent to the
    Roboflow model; predictions below ``CONFIDENCE_THRESHOLD`` are ignored.

    Returns:
        ``(annotated, counts, boxes)``: a copy of ``image`` with a green
        rectangle per detection, a ``defaultdict(int)`` mapping class name to
        detection count, and a list of ``(x_min, y_min, x_max, y_max)`` boxes
        in full-image coordinates.
    """
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    counts = defaultdict(int)
    boxes = []
    width, height = image.size

    for row, col, (x1, y1, x2, y2) in _segment_boxes(width, height):
        # The directory is private to this request, so concurrent users cannot
        # overwrite each other's tiles even though the file names repeat.
        segment_path = os.path.join(workdir, f"segment_{row}_{col}_pass{pass_num}.png")
        image.crop((x1, y1, x2, y2)).save(segment_path)

        result = CLIENT.infer(segment_path, model_id=MODEL_ID)
        predictions = [
            pred for pred in result["predictions"]
            if pred["confidence"] >= CONFIDENCE_THRESHOLD
        ]

        for pred in predictions:
            counts[pred["class"]] += 1
            # (x, y) is treated as the box centre within the tile; shift by the
            # tile origin and use exact half-sizes (no flooring) for the corners.
            cx, cy = x1 + pred["x"], y1 + pred["y"]
            half_w, half_h = pred["width"] / 2, pred["height"] / 2
            box = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
            boxes.append(box)
            draw.rectangle(box, outline="green", width=2)

    return annotated, counts, boxes


def detect_components(image):
    """Detect HVAC components in an uploaded image using three passes.

    Pass 1 runs on the original image. Every box it finds is painted white and
    pass 2 runs on that masked image; pass 3 runs after additionally masking
    the pass-2 boxes. Masking lets later passes pick up components that were
    hidden behind earlier, overlapping detections.

    Args:
        image: PIL image from the Gradio input, or ``None`` when nothing was
            uploaded.

    Returns:
        A 7-tuple matching the Gradio outputs: (pass-1 image, pass-1 table,
        pass-2 image, pass-2 table, pass-3 image, pass-3 table, final table).
        Tables are Markdown strings; the final table sums counts over all
        passes, listing labels in first-seen pass order.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    original_image = image.convert("RGB")

    # Tiles live only for the duration of this request and are removed
    # afterwards, instead of accumulating in the working directory.
    with tempfile.TemporaryDirectory() as workdir:
        image_after_pass1, counts_pass1, boxes_pass1 = _run_detection_pass(
            original_image, 1, workdir
        )
        image_after_removal1 = _mask_boxes(original_image, boxes_pass1)

        image_after_pass2, counts_pass2, boxes_pass2 = _run_detection_pass(
            image_after_removal1, 2, workdir
        )
        image_after_removal2 = _mask_boxes(image_after_removal1, boxes_pass2)

        image_after_pass3, counts_pass3, _ = _run_detection_pass(
            image_after_removal2, 3, workdir
        )

    # Merge in pass order so the final table's row order is deterministic
    # (iterating a set of strings varies with hash randomization).
    final_counts = defaultdict(int)
    for pass_counts in (counts_pass1, counts_pass2, counts_pass3):
        for label, count in pass_counts.items():
            final_counts[label] += count

    return (
        image_after_pass1, format_counts_as_table(counts_pass1, 1),
        image_after_pass2, format_counts_as_table(counts_pass2, 2),
        image_after_pass3, format_counts_as_table(counts_pass3, 3),
        format_counts_as_table(final_counts, "Final"),
    )

# Gradio UI: one PIL image input; for each of the three passes an annotated
# image followed by its Markdown count table, then the combined Markdown table.
# The output order matches the 7-tuple returned by detect_components.
interface = gr.Interface(
    fn=detect_components,
    inputs=gr.Image(type="pil"),
    outputs=[
        component
        for image_label, table_label in (
            ("Detection Pass 1", "Counts After Pass 1"),
            ("Detection Pass 2", "Counts After Pass 2"),
            ("Detection Pass 3", "Counts After Pass 3"),
        )
        for component in (
            gr.Image(type="pil", label=image_label),
            gr.Markdown(label=table_label),
        )
    ]
    + [gr.Markdown(label="Final Detected Components")],
    title="HVAC Component Detector",
    description="Upload an image to detect HVAC components using Roboflow API across three passes.",
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch()