File size: 1,763 Bytes
ed309ba
 
 
ebf76ce
ed309ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebf76ce
ed309ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebf76ce
ed309ba
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import gradio as gr, numpy as np
from utils import SAM, GD
from groundingdino.util.utils import clean_text
from PIL import Image
import cv2, torch

def _mask_to_box(mask):
    """Return the tight ``[x1, y1, x2, y2]`` pixel box around a 2-D mask, or ``None``.

    ``None`` is returned for an all-false mask, where taking ``min``/``max`` over
    the empty set of coordinates would otherwise raise ``ValueError``.
    """
    ys, xs = np.nonzero(mask)
    if ys.size == 0:
        return None
    return np.array([xs.min(), ys.min(), xs.max(), ys.max()])


def pipeline(image, prompt):
    """Segment *image* with SAM, ground the mask boxes with GroundingDINO, and draw them.

    Args:
        image: PIL image coming from the Gradio ``gr.Image(type="pil")`` input.
        prompt: label text from the Gradio textbox (the UI asks for
            comma-separated labels); forwarded unchanged as the caption for
            every box.

    Returns:
        A new PIL RGB image with a red rectangle and label for each box returned
        by the grounding step. The input image is not modified. If the prompt is
        blank, or SAM yields no non-empty mask, the un-annotated image is
        returned.
    """
    # Normalise to a 3-channel RGB array so RGBA/L/P uploads don't break the
    # channel handling or the RGB colour used for drawing below.
    rgb = np.array(image.convert("RGB"))
    # cv2 drawing functions need a NumPy array, not a PIL image.
    annotated = rgb.copy()

    # Nothing to ground without at least one non-blank label.
    if not prompt or not prompt.strip():
        return Image.fromarray(annotated)

    # 1. Segment with SAM.
    # Assuming utils.SAM is a segment_anything SamPredictor (the predict()
    # signature matches): set_image expects RGB by default, so pass RGB directly.
    # NOTE(review): with no box/point prompts SamPredictor returns a single mask
    # rather than "segment everything"; SamAutomaticMaskGenerator may be what is
    # intended here - confirm what utils.SAM wraps.
    SAM.set_image(rgb)
    masks, _, _ = SAM.predict(box=None, point_coords=None, point_labels=None, multimask_output=False)

    boxes = [box for box in map(_mask_to_box, masks) if box is not None]
    if not boxes:
        return Image.fromarray(annotated)

    # 2. Zero-shot grounding with GroundingDINO.
    # NOTE(review): upstream groundingdino `Model.predict_with_caption` takes
    # (image, caption, box_threshold, text_threshold) and returns a tuple; this
    # call assumes utils.GD is a wrapper that accepts `captions`/`boxes` and
    # returns a dict with "boxes" and "captions" keys - confirm.
    dino_out = GD.predict_with_caption(
        image=rgb,
        captions=[prompt] * len(boxes),
        boxes=np.vstack(boxes),
    )
    for box, text in zip(dino_out["boxes"], dino_out["captions"]):
        x1, y1, x2, y2 = map(int, box)
        cv2.rectangle(annotated, (x1, y1), (x2, y2), (255, 0, 0), 2)
        # Keep the label baseline inside the image when the box touches the top edge.
        text_y = max(y1 - 6, 15)
        cv2.putText(annotated, clean_text(text), (x1, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)

    return Image.fromarray(annotated)

def _build_demo():
    """Build the Gradio interface that wraps ``pipeline``.

    Components are created in the same order the interface lists them:
    photo upload, label textbox, then the annotated-image output.
    """
    photo_input = gr.Image(type="pil")
    # Default labels (Italian): "sheet metal, circular hole, screw, bolt, groove".
    labels_input = gr.Textbox(value="lamiera, foro circolare, vite, bullone, scanalatura")
    annotated_output = gr.Image(type="pil")

    # Italian user-facing description: "Upload a photo of end-of-life mechanical
    # components and type the labels you want to search for (comma separated).
    # The system segments with SAM and does zero-shot grounding with GroundingDINO."
    description_text = (
        "Carica una foto di componenti meccanici a fine vita e scrivi le etichette "
        "che vuoi cercare (separate da virgole). Il sistema segmenta con SAM e fa "
        "grounding zero‑shot con GroundingDINO."
    )

    return gr.Interface(
        fn=pipeline,
        inputs=[photo_input, labels_input],
        outputs=annotated_output,
        title="Zero‑Shot Mechanical Part Finder",
        description=description_text,
    )


demo = _build_demo()

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()