Update app.py
app.py
CHANGED
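This commit swaps the CLIP checkpoint from openai/clip-vit-base-patch16 to the larger laion/CLIP-ViT-H-14-laion2B-s32B-b79K and strips out the SigLIP comparison path: the SiglipModel/AutoProcessor loading, the siglip_detector and postprocess_siglip helpers, and the SigLIP output label in the UI.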
@@ -4,33 +4,16 @@ import numpy as np
 import gradio as gr


-
-clip_checkpoint = "openai/clip-vit-base-patch16"
+clip_checkpoint = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
 clip_detector = pipeline(model=clip_checkpoint, task="zero-shot-image-classification")
-siglip_model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
-siglip_processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")


 def postprocess(output):
     return {out["label"]: float(out["score"]) for out in output}

-def postprocess_siglip(output, labels):
-    return {labels[i]: float(np.array(output[0])[i]) for i in range(len(labels))}
-
-def siglip_detector(image, texts):
-    inputs = siglip_processor(text=texts, images=image, return_tensors="pt",
-                              padding="max_length")
-
-    with torch.no_grad():
-        outputs = siglip_model(**inputs)
-        logits_per_image = outputs.logits_per_image
-        probs = torch.sigmoid(logits_per_image)
-    return probs
-

 def infer(image, candidate_labels):
     candidate_labels = [label.lstrip(" ") for label in candidate_labels.split(",")]
-    siglip_out = siglip_detector(image, candidate_labels)
     clip_out = clip_detector(image, candidate_labels=candidate_labels)
     return postprocess(clip_out), postprocess_siglip(siglip_out, labels=candidate_labels)

@@ -46,14 +29,12 @@ with gr.Blocks() as demo:

     with gr.Column():
         clip_output = gr.Label(label = "CLIP Output", num_top_classes=3)
-        siglip_output = gr.Label(label = "SigLIP Output", num_top_classes=3)

     examples = [["./baklava.jpg", "baklava, souffle, tiramisu"]]
     gr.Examples(
         examples = examples,
         inputs=[image_input, text_input],
         outputs=[clip_output,
-                 siglip_output
         ],
         fn=infer,
         cache_examples=True
@@ -61,7 +42,6 @@ with gr.Blocks() as demo:
     run_button.click(fn=infer,
                      inputs=[image_input, text_input],
                      outputs=[clip_output,
-                              siglip_output
                      ])

 demo.launch()
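Note that the new revision leaves infer ending with return postprocess(clip_out), postprocess_siglip(siglip_out, labels=candidate_labels), even though siglip_out, postprocess_siglip, and the siglip_output label are all removed above; as committed, clicking Run or rebuilding the cached examples would raise a NameError, and the two return values no longer match the single-element outputs=[clip_output] lists. A minimal follow-up sketch, assuming the single-label UI this commit leaves behind:

def infer(image, candidate_labels):
    # Split the comma-separated label string and trim leading spaces.
    candidate_labels = [label.lstrip(" ") for label in candidate_labels.split(",")]
    clip_out = clip_detector(image, candidate_labels=candidate_labels)
    # Return only the CLIP scores; the SigLIP branch no longer exists.
    return postprocess(clip_out)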
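For a quick check of the new checkpoint outside the Space, the same pipeline call can be exercised directly. A minimal sketch; the image path and labels simply mirror the Space's example row:

from transformers import pipeline

# Zero-shot image classification with the checkpoint this commit switches to.
clip_detector = pipeline(
    model="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    task="zero-shot-image-classification",
)

preds = clip_detector("./baklava.jpg", candidate_labels=["baklava", "souffle", "tiramisu"])
# The pipeline returns a list of {"label": ..., "score": ...} dicts,
# which postprocess() flattens into the mapping gr.Label expects.
print({p["label"]: float(p["score"]) for p in preds})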