kobiso committed on
Commit
f9165b2
1 Parent(s): 730fc9d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import json
4
+ import torch
5
+ import kelip
6
+ import gradio as gr
7
+
8
def load_model():
    """Build the KELIP ``ViT-B/32`` model and bundle it with its helpers.

    The model is moved to CUDA when ``torch.cuda.is_available()`` reports a
    GPU (CPU otherwise) and is switched to eval mode before being returned.

    Returns:
        dict: ``'model'``, ``'preprocess_img'`` and ``'tokenizer'`` entries
        holding the three objects returned by ``kelip.build_model``.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"

    net, image_transform, text_tokenizer = kelip.build_model('ViT-B/32')
    net = net.to(target_device)
    net.eval()

    return {
        'model': net,
        'preprocess_img': image_transform,
        'tokenizer': text_tokenizer,
    }
19
+
20
def classify(img, user_text):
    """Score an image against comma-separated candidate captions.

    Reads the model, image preprocessor and tokenizer from the module-level
    ``model_dict``, which must be populated before this function is called.

    Args:
        img: Image accepted by ``model_dict['preprocess_img']`` (the Gradio
            input in this file is configured with ``type="pil"``).
        user_text: Candidate captions separated by commas. Whitespace around
            each caption is trimmed and blank entries are ignored.

    Returns:
        dict: Maps each caption to its softmax probability (a float),
        inserted in descending order of probability. Empty when
        ``user_text`` contains no non-blank caption.
    """
    # Parse captions first so blank input never reaches the tokenizer/model.
    # Stripping keeps "cat, dog" from producing a " dog" label, and dropping
    # empty pieces keeps a trailing comma from adding an empty caption.
    stripped = (piece.strip() for piece in (user_text or '').split(','))
    user_texts = [caption for caption in stripped if caption]
    if not user_texts:
        return {}

    # Same device rule as load_model(), applied to both inputs alike.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model_dict['model']

    input_img = model_dict['preprocess_img'](img).unsqueeze(0).to(device)
    input_texts = model_dict['tokenizer'].encode(user_texts).to(device)

    # Inference only: keep every tensor op out of autograd tracking.
    with torch.no_grad():
        image_features = model.encode_image(input_img)
        text_features = model.encode_text(input_texts)

        # l2 normalize
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        values, indices = similarity[0].topk(len(user_texts))

    return {
        user_texts[index]: value
        for value, index in zip(values.tolist(), indices.tolist())
    }
52
+
53
if __name__ == '__main__':
    # classify() looks up ``model_dict`` as a module-level name, so it must be
    # bound here before the interface starts serving requests.
    model_dict = load_model()

    # Static page text; the description reports which device is in use.
    title = "KELIP"
    demo_status = (
        "Demo is running on GPU"
        if torch.cuda.is_available()
        else "Demo is running on CPU"
    )
    description = f"Details: paper_url. {demo_status}"
    examples = []
    article = ""

    # One image plus a free-text box holding the candidate captions.
    inputs = [
        gr.inputs.Image(type="pil", label="Image"),
        gr.inputs.Textbox(lines=5, label="Caption"),
    ]
    outputs = ['label']

    iface = gr.Interface(
        fn=classify,
        inputs=inputs,
        outputs=outputs,
        examples=examples,
        title=title,
        description=description,
        article=article,
    )
    iface.launch()