micole66 committed on
Commit
d68b3d4
1 Parent(s): 95d4d49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -20
app.py CHANGED
@@ -5,20 +5,23 @@ import torch
5
  import kelip
6
  import gradio as gr
7
 
 
8
  def load_model():
9
- model, preprocess_img, tokenizer = kelip.build_model('ViT-B/32')
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  model = model.to(device)
12
  model.eval()
13
 
14
- model_dict = {'model': model,
15
- 'preprocess_img': preprocess_img,
16
- 'tokenizer': tokenizer
17
- }
 
18
  return model_dict
19
 
 
20
  def classify(img, user_text):
21
- preprocess_img = model_dict['preprocess_img']
22
 
23
  input_img = preprocess_img(img).unsqueeze(0)
24
 
@@ -27,17 +30,17 @@ def classify(img, user_text):
27
 
28
  # extract image features
29
  with torch.no_grad():
30
- image_features = model_dict['model'].encode_image(input_img)
31
 
32
  # extract text features
33
- user_texts = user_text.split(',')
34
- if user_text == '' or user_text.isspace():
35
  user_texts = []
36
 
37
- input_texts = model_dict['tokenizer'].encode(user_texts)
38
  if torch.cuda.is_available():
39
  input_texts = input_texts.cuda()
40
- text_features = model_dict['model'].encode_text(input_texts)
41
 
42
  # l2 normalize
43
  image_features /= image_features.norm(dim=-1, keepdim=True)
@@ -50,28 +53,30 @@ def classify(img, user_text):
50
 
51
  return result
52
 
53
- if __name__ == '__main__':
 
54
  global model_dict
55
 
56
  model_dict = load_model()
57
 
58
- inputs = [gr.inputs.Image(type="pil", label="Image"),
59
- gr.inputs.Textbox(lines=5, label="Caption"),
60
- ]
 
61
 
62
- outputs = ['label']
63
 
64
  title = "KELIP"
65
  description = "Zero-shot classification with KELIP -- Korean and English bilingual contrastive Language-Image Pre-training model that is trained with collected 1.1 billion image-text pairs (708 million Korean and 476 million English).<br> <br><a href='https://arxiv.org/abs/2203.14463' target='_blank'>Arxiv</a> | <a href='https://github.com/navervision/KELIP' target='_blank'>Github</a>"
66
-
67
  article = ""
68
 
69
- iface=gr.Interface(
70
  fn=classify,
71
  inputs=inputs,
72
  outputs=outputs,
73
  title=title,
74
  description=description,
75
- article=article
76
  )
77
- iface.launch()
 
5
  import kelip
6
  import gradio as gr
7
 
8
+
def load_model():
    """Build the KELIP ViT-B/32 model and bundle it with its helpers.

    The model is moved to CUDA when ``torch.cuda.is_available()`` reports a
    GPU (otherwise it stays on the CPU) and is switched to eval mode.

    Returns:
        dict: a mapping with three keys --
            ``"model"``: the model after ``.to(device)`` and ``.eval()``;
            ``"preprocess_img"``: the image preprocessing callable;
            ``"tokenizer"``: the text tokenizer;
        the latter two exactly as returned by ``kelip.build_model``.
    """
    model, preprocess_img, tokenizer = kelip.build_model("ViT-B/32")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()

    model_dict = {
        "model": model,
        "preprocess_img": preprocess_img,
        "tokenizer": tokenizer,
    }
    return model_dict
21
 
22
+
23
  def classify(img, user_text):
24
+ preprocess_img = model_dict["preprocess_img"]
25
 
26
  input_img = preprocess_img(img).unsqueeze(0)
27
 
 
30
 
31
  # extract image features
32
  with torch.no_grad():
33
+ image_features = model_dict["model"].encode_image(input_img)
34
 
35
  # extract text features
36
+ user_texts = user_text.split(",")
37
+ if user_text == "" or user_text.isspace():
38
  user_texts = []
39
 
40
+ input_texts = model_dict["tokenizer"].encode(user_texts)
41
  if torch.cuda.is_available():
42
  input_texts = input_texts.cuda()
43
+ text_features = model_dict["model"].encode_text(input_texts)
44
 
45
  # l2 normalize
46
  image_features /= image_features.norm(dim=-1, keepdim=True)
 
53
 
54
  return result
55
 
56
+
if __name__ == "__main__":
    # classify() looks up `model_dict` as a module-level name, so it has to
    # be bound here before the interface starts serving requests. No `global`
    # statement is needed: this code already runs at module scope.
    model_dict = load_model()

    # NOTE(review): the `gr.inputs.*` namespace is the legacy Gradio component
    # API -- confirm it is still available in the pinned Gradio version.
    inputs = [
        gr.inputs.Image(type="pil", label="Image"),
        gr.inputs.Textbox(lines=5, label="Caption"),
    ]

    outputs = ["label"]

    title = "KELIP"
    description = "Zero-shot classification with KELIP -- Korean and English bilingual contrastive Language-Image Pre-training model that is trained with collected 1.1 billion image-text pairs (708 million Korean and 476 million English).<br> <br><a href='https://arxiv.org/abs/2203.14463' target='_blank'>Arxiv</a> | <a href='https://github.com/navervision/KELIP' target='_blank'>Github</a>"

    article = ""

    # Wire the classifier into a web UI and start serving it.
    iface = gr.Interface(
        fn=classify,
        inputs=inputs,
        outputs=outputs,
        title=title,
        description=description,
        article=article,
    )
    iface.launch()