merve (HF staff) committed
Commit: 5e733a4
1 Parent: 37ea8d2

Create app.py

Files changed (1):
  app.py +105 -0

app.py ADDED
from transformers import AutoTokenizer, CLIPProcessor, SiglipModel, AutoProcessor
from PIL import Image
from modeling_nllb_clip import NLLBCLIPModel
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer, util
import gradio as gr


## NLLB-CLIP inference
# NLLB-CLIP pairs the CLIP ViT-B/32 image processor with the NLLB-200 tokenizer.
nllb_clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32").image_processor
nllb_clip_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
nllb_clip_model = NLLBCLIPModel.from_pretrained("visheratin/nllb-clip-base")


def nllb_clip_inference(image, labels):
    labels = [label.strip() for label in labels.split(",")]
    image_inputs = nllb_clip_processor(images=image, return_tensors="pt")
    text_inputs = nllb_clip_tokenizer(labels, padding="longest", return_tensors="pt")
    outputs = nllb_clip_model(
        input_ids=text_inputs.input_ids,
        attention_mask=text_inputs.attention_mask,
        pixel_values=image_inputs.pixel_values,
    )
    # Softmax over the label axis turns the text-image logits into probabilities.
    probs = F.softmax(outputs["logits_per_text"], dim=0).detach().numpy()
    return {labels[i]: float(probs[i][0]) for i in range(len(labels))}


## SentenceTransformers multilingual CLIP-ViT-B-32 inference
img_model = SentenceTransformer("clip-ViT-B-32")
text_model = SentenceTransformer("sentence-transformers/clip-ViT-B-32-multilingual-v1")


def infer_st(image, texts):
    texts = [text.strip() for text in texts.split(",")]
    img_embeddings = img_model.encode(image)
    text_embeddings = text_model.encode(texts)
    # Cosine similarity between every label embedding and the image embedding.
    cos_sim = util.cos_sim(text_embeddings, img_embeddings)
    return {texts[i]: float(cos_sim[i][0]) for i in range(len(texts))}


## SigLIP inference
siglip_model = SiglipModel.from_pretrained("google/siglip-base-patch16-256-multilingual")
siglip_processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-256-multilingual")


def postprocess_siglip(output, labels):
    return {labels[i]: float(output[0][i]) for i in range(len(labels))}


def siglip_detector(image, texts):
    inputs = siglip_processor(
        text=texts, images=image, return_tensors="pt", padding="max_length"
    )
    with torch.no_grad():
        outputs = siglip_model(**inputs)
        # SigLIP is trained with a sigmoid loss, so each label gets an independent probability.
        probs = torch.sigmoid(outputs.logits_per_image)
    return probs


def infer_siglip(image, candidate_labels):
    candidate_labels = [label.strip() for label in candidate_labels.split(",")]
    siglip_out = siglip_detector(image, candidate_labels)
    return postprocess_siglip(siglip_out, labels=candidate_labels)


def infer(image, labels):
    st_out = infer_st(image, labels)
    nllb_out = nllb_clip_inference(image, labels)
    siglip_out = infer_siglip(image, labels)
    return st_out, siglip_out, nllb_out


with gr.Blocks() as demo:
    gr.Markdown("# Compare Multilingual Zero-shot Image Classification")
    gr.Markdown("Compare the performance of SigLIP and other models on zero-shot classification in this Space 👇")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            text_input = gr.Textbox(label="Input a list of labels")
            run_button = gr.Button("Run", visible=True)

        with gr.Column():
            st_output = gr.Label(label="CLIP-ViT Multilingual Output", num_top_classes=3)
            siglip_output = gr.Label(label="SigLIP Output", num_top_classes=3)
            nllb_output = gr.Label(label="NLLB-CLIP Output", num_top_classes=3)

    examples = [["../cat.jpg", "eine Katze, köpek, un oiseau"]]
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
        outputs=[st_output, siglip_output, nllb_output],
        fn=infer,
        cache_examples=True,
    )
    run_button.click(
        fn=infer,
        inputs=[image_input, text_input],
        outputs=[st_output, siglip_output, nllb_output],
    )

demo.launch()
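For a quick sanity check outside the Gradio UI, something along these lines should work once the model-loading code above has run (a sketch only: the image path is a placeholder and the labels just mirror the demo example):

from PIL import Image

# Illustrative call of infer(); "cat.jpg" is a hypothetical local path, and the labels
# are the demo example ("a cat" in German, "dog" in Turkish, "a bird" in French).
image = Image.open("cat.jpg").convert("RGB")
st_scores, siglip_scores, nllb_scores = infer(image, "eine Katze, köpek, un oiseau")
print(st_scores, siglip_scores, nllb_scores)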