Skyy93 committed
Commit
3d98c13
1 Parent(s): 1a746f6

Add new examples

app.py CHANGED
@@ -1,9 +1,6 @@
-from urllib.request import urlopen
-import argparse
 import clip
 from PIL import Image
 import pandas as pd
-import time
 import torch
 from dataloader.extract_features_dataloader import transform_resize, question_preprocess
 from model.vqa_model import NetVQA
@@ -30,7 +27,7 @@ class InferenceConfig:
                     5: "color",
                     6: "other"}
     folds = 10
-    tta = False
+
     # Data
     n_classes: int = 5726
 
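Note: the removed tta flag gated a test-time-augmentation path in predict
(also deleted further down) that encoded several augmented views of one image
and fused the per-view CLIP features with a weighted sum. A minimal sketch of
that pattern, with toy tensors standing in for real CLIP features:

    import torch

    # Fuse N per-view features into one, as the removed branch did:
    # elementwise weighting followed by a sum over the view axis.
    n_views, dim = 10, 768
    view_features = torch.randn(n_views, dim)              # one row per augmented view
    weights = torch.ones(n_views).reshape(n_views, 1)      # per-view weights
    fused = (view_features * weights).sum(0).unsqueeze(0)  # back to shape (1, dim)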
@@ -38,7 +35,8 @@ class InferenceConfig:
     class_mapping: str = "./data/annotations/class_mapping.csv"
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
-
+
+
 config = InferenceConfig()
 
 # load class mapping
@@ -48,7 +46,7 @@ for i in range(len(cm)):
     row = cm.iloc[i]
     classid_to_answer[row["class_id"]] = row["answer"]
 
-clip_model, preprocess = clip.load(config.model, download_root=config.checkpoint_root_clip)
+clip_model, preprocess = clip.load(config.model, download_root=config.checkpoint_root_clip, device=config.device)
 
 model = NetVQA(config).to(config.device)
 
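The new device argument makes CLIP's placement explicit. clip.load already
defaults to CUDA when available, so the gain here is consistency: the CLIP
backbone and the NetVQA head both follow config.device. A minimal sketch,
assuming the standard openai-clip package (the model name is illustrative;
the app passes config.model):

    import clip
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load the weights straight onto the chosen device; download_root is optional.
    clip_model, preprocess = clip.load("ViT-B/32", device=device)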
@@ -58,8 +56,8 @@ config.checkpoint_head = "{}/{}.pt".format(config.checkpoint_root_head, config.m
 model_state_dict = torch.load(config.checkpoint_head)
 model.load_state_dict(model_state_dict, strict=True)
 
+model.eval()
 
-#%%
 # Select Preprocessing
 image_transforms = transform_resize(clip_model.visual.input_resolution)
 
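Moving model.eval() up next to load_state_dict makes inference mode part of
checkpoint setup rather than a call buried near the CLIP section (the old call
site is removed in the next hunk). Without it, any Dropout or BatchNorm layers
keep their training behavior and predictions can vary between runs. A
self-contained illustration with a toy head standing in for NetVQA(config):

    import torch
    import torch.nn as nn

    # Stand-in classification head with a Dropout layer.
    model = nn.Sequential(nn.Linear(768, 256), nn.Dropout(0.5), nn.Linear(256, 10))

    model.eval()           # freeze Dropout/BatchNorm behavior for inference
    with torch.no_grad():  # and skip autograd bookkeeping
        logits = model(torch.randn(1, 768))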
@@ -69,30 +67,21 @@ else:
     question_transforms = None
 
 clip_model.eval()
-model.eval()
 
 
 def predict(img, text):
     img = Image.fromarray(img)
-    if config.tta:
-        image_augmentations = []
-        for transform in image_transforms:
-            image_augmentations.append(transform(img))
-        img = torch.stack(image_augmentations, dim=0)
+    img = image_transforms(img)
+    img = img.unsqueeze(dim=0)
+
+    if question_transforms is not None:
+        question = question_transforms(text)
     else:
-        img = image_transforms(img)
-        img = img.unsqueeze(dim=0)
-
-    question = question_transforms(text)
+        question = text
     question_tokens = clip.tokenize(question, truncate=True)
     with torch.no_grad():
         img = img.to(config.device)
         img_feature = clip_model.encode_image(img)
-        if config.tta:
-            weights = torch.tensor(config.features_selection).reshape((len(config.features_selection),1))
-            img_feature = img_feature * weights.to(config.device)
-            img_feature = img_feature.sum(0)
-            img_feature = img_feature.unsqueeze(0)
 
         question_tokens = question_tokens.to(config.device)
         question_feature = clip_model.encode_text(question_tokens)
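With the TTA branches gone, predict is a single path: one resize transform,
one batch dimension, and a None-guard so the question string passes through
unchanged when no question preprocessing is configured. A sketch of the
resulting preamble as a hypothetical helper (the app uses transform_resize
and question_preprocess; the function name here is illustrative):

    def preprocess_inputs(img, text, image_transforms, question_transforms=None):
        x = image_transforms(img)
        x = x.unsqueeze(dim=0)  # add the batch dimension expected by encode_image
        # None-guard: fall back to the raw question text.
        question = question_transforms(text) if question_transforms is not None else text
        return x, question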
@@ -116,6 +105,6 @@ def predict(img, text):
 gr.Interface(fn=predict,
              inputs=[gr.Image(label='Image'), gr.Textbox(label='Question')],
              outputs=[gr.outputs.Label(label='Answer', num_top_classes=5), gr.outputs.Label(label='Answer Category', num_top_classes=7)],
-             examples=[['examples/VizWiz_train_00004056.jpg', 'Is that a beer or a coke?'], ['examples/VizWiz_train_00017146.jpg', 'Can you tell me what\'s on this envelope please?'], ['examples/VizWiz_val_00003077.jpg', 'What is this?']]
+             examples=[['examples/Augustiner.jpg', 'What is this?'], ['examples/VizWiz_test_00006968.jpg', 'Can you tell me the color of the dog?'], ['examples/VizWiz_test_00005604.jpg', 'What drink is this?'], ['examples/VizWiz_test_00006246.jpg', 'Can you please tell me what kind of tea this is?'], ['examples/VizWiz_train_00004056.jpg', 'Is that a beer or a coke?'], ['examples/VizWiz_train_00017146.jpg', 'Can you tell me what\'s on this envelope please?'], ['examples/VizWiz_val_00003077.jpg', 'What is this?']]
 ).launch()
 
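This is the commit's headline change: four new demo examples (the image files
are added below). In Gradio, each inner list of examples supplies one value per
input component, in order, so here each row pairs an image path with a question
string to match the gr.Image and gr.Textbox inputs. For instance:

    # One row per example; paths are relative to the app's working directory.
    examples = [
        ['examples/Augustiner.jpg', 'What is this?'],
        ['examples/VizWiz_test_00006968.jpg', 'Can you tell me the color of the dog?'],
    ]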
examples/Augustiner.jpg ADDED
examples/VizWiz_test_00005604.jpg ADDED
examples/VizWiz_test_00006246.jpg ADDED
examples/VizWiz_test_00006968.jpg ADDED