7eu7d7 committed
Commit 3e9654b · verified · 1 Parent(s): 3768ea9

Update app.py

Files changed (1):
  1. app.py +1 -21
app.py CHANGED
@@ -4,33 +4,16 @@ import numpy as np
 import gradio as gr
 
 
-siglip_checkpoint = "nielsr/siglip-base-patch16-224"
-clip_checkpoint = "openai/clip-vit-base-patch16"
+clip_checkpoint = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
 clip_detector = pipeline(model=clip_checkpoint, task="zero-shot-image-classification")
-siglip_model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
-siglip_processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
 
 
 def postprocess(output):
     return {out["label"]: float(out["score"]) for out in output}
 
-def postprocess_siglip(output, labels):
-    return {labels[i]: float(np.array(output[0])[i]) for i in range(len(labels))}
-
-def siglip_detector(image, texts):
-    inputs = siglip_processor(text=texts, images=image, return_tensors="pt",
-                              padding="max_length")
-
-    with torch.no_grad():
-        outputs = siglip_model(**inputs)
-    logits_per_image = outputs.logits_per_image
-    probs = torch.sigmoid(logits_per_image)
-    return probs
-
 
 def infer(image, candidate_labels):
     candidate_labels = [label.lstrip(" ") for label in candidate_labels.split(",")]
-    siglip_out = siglip_detector(image, candidate_labels)
     clip_out = clip_detector(image, candidate_labels=candidate_labels)
     return postprocess(clip_out), postprocess_siglip(siglip_out, labels=candidate_labels)
 
@@ -46,14 +29,12 @@ with gr.Blocks() as demo:
 
     with gr.Column():
         clip_output = gr.Label(label = "CLIP Output", num_top_classes=3)
-        siglip_output = gr.Label(label = "SigLIP Output", num_top_classes=3)
 
     examples = [["./baklava.jpg", "baklava, souffle, tiramisu"]]
     gr.Examples(examples = examples,
                 inputs=[image_input, text_input],
                 outputs=[clip_output,
-                         siglip_output
                         ],
                 fn=infer,
                 cache_examples=True
@@ -61,7 +42,6 @@ with gr.Blocks() as demo:
     run_button.click(fn=infer,
                      inputs=[image_input, text_input],
                      outputs=[clip_output,
-                              siglip_output
                              ])
 
 demo.launch()
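
For reference, a minimal sketch of app.py as it stands after this commit. One caveat: the committed infer still ends with "return postprocess(clip_out), postprocess_siglip(siglip_out, labels=candidate_labels)" even though the SigLIP functions were removed, so calling it would raise a NameError; the single-value return below is an assumed fix, not part of the commit. Likewise, the definitions of image_input, text_input, and run_button fall outside the shown hunks, so their exact form here is an assumption.

from transformers import pipeline
import gradio as gr

clip_checkpoint = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
clip_detector = pipeline(model=clip_checkpoint, task="zero-shot-image-classification")


def postprocess(output):
    # The zero-shot pipeline returns a list of {"label": ..., "score": ...}
    # dicts; gr.Label expects a {label: score} mapping.
    return {out["label"]: float(out["score"]) for out in output}


def infer(image, candidate_labels):
    candidate_labels = [label.lstrip(" ") for label in candidate_labels.split(",")]
    clip_out = clip_detector(image, candidate_labels=candidate_labels)
    # Assumed fix: the committed file still references postprocess_siglip
    # and siglip_out here, which no longer exist after this change.
    return postprocess(clip_out)


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # These three widgets are defined outside the shown hunks;
            # their arguments here are assumptions for completeness.
            image_input = gr.Image(type="pil")
            text_input = gr.Textbox(label="Comma-separated candidate labels")
            run_button = gr.Button("Run")
        with gr.Column():
            clip_output = gr.Label(label="CLIP Output", num_top_classes=3)

    examples = [["./baklava.jpg", "baklava, souffle, tiramisu"]]
    gr.Examples(examples=examples,
                inputs=[image_input, text_input],
                outputs=[clip_output],
                fn=infer,
                cache_examples=True)
    run_button.click(fn=infer,
                     inputs=[image_input, text_input],
                     outputs=[clip_output])

demo.launch()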
 