Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
8c692eb
1
Parent(s):
1ebfb13
Unique colors
Browse files
app.py
CHANGED
@@ -8,9 +8,16 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
8 |
dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
|
9 |
dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
@spaces.GPU
|
12 |
def infer(img, text_queries, score_threshold, model):
|
13 |
-
|
14 |
if model == "dino":
|
15 |
queries = ""
|
16 |
for query in text_queries:
|
@@ -36,9 +43,9 @@ def infer(img, text_queries, score_threshold, model):
|
|
36 |
if score < score_threshold:
|
37 |
continue
|
38 |
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
return result_labels
|
43 |
|
44 |
def query_image(img, text_queries, dino_threshold):
|
@@ -47,8 +54,9 @@ def query_image(img, text_queries, dino_threshold):
|
|
47 |
annotations = []
|
48 |
for box, label in dino_output:
|
49 |
annotations.append({"label": label, "coordinates": {"x": box[0], "y": box[1], "width": box[2] - box[0], "height": box[3] - box[1]}})
|
50 |
-
|
51 |
-
|
|
|
52 |
|
53 |
dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
|
54 |
dino_output = gr.AnnotatedImage(label="Grounding DINO Output")
|
|
|
8 |
dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
|
9 |
dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)
|
10 |
|
11 |
+
def generate_colors(labels):
|
12 |
+
import random
|
13 |
+
random.seed(42)
|
14 |
+
colors = {}
|
15 |
+
for label in labels:
|
16 |
+
colors[label] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
|
17 |
+
return colors
|
18 |
+
|
19 |
@spaces.GPU
|
20 |
def infer(img, text_queries, score_threshold, model):
|
|
|
21 |
if model == "dino":
|
22 |
queries = ""
|
23 |
for query in text_queries:
|
|
|
43 |
if score < score_threshold:
|
44 |
continue
|
45 |
|
46 |
+
# Only include the labels that are part of the input text_queries
|
47 |
+
if model == "dino" and label in text_queries:
|
48 |
+
result_labels.append((box, label))
|
49 |
return result_labels
|
50 |
|
51 |
def query_image(img, text_queries, dino_threshold):
|
|
|
54 |
annotations = []
|
55 |
for box, label in dino_output:
|
56 |
annotations.append({"label": label, "coordinates": {"x": box[0], "y": box[1], "width": box[2] - box[0], "height": box[3] - box[1]}})
|
57 |
+
|
58 |
+
colors = generate_colors(text_queries)
|
59 |
+
return (img, {"boxes": annotations, "colors": colors})
|
60 |
|
61 |
dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
|
62 |
dino_output = gr.AnnotatedImage(label="Grounding DINO Output")
|