shikunl commited on
Commit
ef365f5
1 Parent(s): fb94f78
Files changed (3) hide show
  1. app_caption.py +1 -1
  2. app_vqa.py +1 -1
  3. prismer_model.py +13 -9
app_caption.py CHANGED
@@ -28,7 +28,7 @@ def create_demo():
28
  object_detection = gr.Image(label='Object Detection')
29
  ocr = gr.Image(label='OCR Detection')
30
 
31
- inputs = [image, model_name, 'caption']
32
  outputs = [caption, depth, edge, normals, segmentation, object_detection, ocr]
33
 
34
  # paths = sorted(pathlib.Path('prismer/images').glob('*'))
 
28
  object_detection = gr.Image(label='Object Detection')
29
  ocr = gr.Image(label='OCR Detection')
30
 
31
+ inputs = [image, model_name]
32
  outputs = [caption, depth, edge, normals, segmentation, object_detection, ocr]
33
 
34
  # paths = sorted(pathlib.Path('prismer/images').glob('*'))
app_vqa.py CHANGED
@@ -28,7 +28,7 @@ def create_demo():
28
  object_detection = gr.Image(label='Object Detection')
29
  ocr = gr.Image(label='OCR Detection')
30
 
31
- inputs = [image, model_name, 'vqa', question]
32
  outputs = [answer, depth, edge, normals, segmentation, object_detection, ocr]
33
 
34
  # paths = sorted(pathlib.Path('prismer/images').glob('*'))
 
28
  object_detection = gr.Image(label='Object Detection')
29
  ocr = gr.Image(label='OCR Detection')
30
 
31
+ inputs = [image, model_name, question]
32
  outputs = [answer, depth, edge, normals, segmentation, object_detection, ocr]
33
 
34
  # paths = sorted(pathlib.Path('prismer/images').glob('*'))
prismer_model.py CHANGED
@@ -19,6 +19,7 @@ from dataset import create_dataset, create_loader
19
  from dataset.utils import pre_question
20
  from model.prismer_caption import PrismerCaption
21
  from model.prismer_vqa import PrismerVQA
 
22
 
23
 
24
  def download_models() -> None:
@@ -91,7 +92,8 @@ class Model:
91
  }
92
  model = PrismerCaption(config)
93
  state_dict = torch.load(f'prismer/logging/pretrain_{model_name}/pytorch_model.bin', map_location='cuda:0')
94
-
 
95
  elif self.mode == 'vqa':
96
  config = {
97
  'dataset': 'demo',
@@ -105,6 +107,8 @@ class Model:
105
 
106
  model = PrismerVQA(config)
107
  state_dict = torch.load(f'prismer/logging/vqa_{model_name}/pytorch_model.bin', map_location='cuda:0')
 
 
108
 
109
  model.load_state_dict(state_dict)
110
  model.eval()
@@ -116,8 +120,8 @@ class Model:
116
  self.mode = mode
117
 
118
  @torch.inference_mode()
119
- def run_caption_model(self, exp_name: str, mode: str) -> str:
120
- self.set_model(exp_name, mode)
121
  _, test_dataset = create_dataset('caption', self.config)
122
  test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
123
  experts, _ = next(iter(test_loader))
@@ -128,15 +132,15 @@ class Model:
128
  caption = caption.capitalize() + '.'
129
  return caption
130
 
131
- def run_caption(self, image_path: str, model_name: str, mode: str) -> tuple[str | None, ...]:
132
  out_paths = run_experts(image_path)
133
- caption = self.run_caption_model(model_name, mode)
134
  label_prettify(image_path, out_paths)
135
  return caption, *out_paths
136
 
137
  @torch.inference_mode()
138
- def run_vqa_model(self, exp_name: str, mode: str, question: str) -> str:
139
- self.set_model(exp_name, mode)
140
  _, test_dataset = create_dataset('caption', self.config)
141
  test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
142
  experts, _ = next(iter(test_loader))
@@ -148,8 +152,8 @@ class Model:
148
  answer = answer.capitalize() + '.'
149
  return answer
150
 
151
- def run_vqa(self, image_path: str, model_name: str, mode: str, question: str) -> tuple[str | None, ...]:
152
  out_paths = run_experts(image_path)
153
- answer = self.run_vqa_model(model_name, mode, question)
154
  label_prettify(image_path, out_paths)
155
  return answer, *out_paths
 
19
  from dataset.utils import pre_question
20
  from model.prismer_caption import PrismerCaption
21
  from model.prismer_vqa import PrismerVQA
22
+ from model.modules.utils import interpolate_pos_embed
23
 
24
 
25
  def download_models() -> None:
 
92
  }
93
  model = PrismerCaption(config)
94
  state_dict = torch.load(f'prismer/logging/pretrain_{model_name}/pytorch_model.bin', map_location='cuda:0')
95
+ state_dict['expert_encoder.positional_embedding'] = interpolate_pos_embed(state_dict['expert_encoder.positional_embedding'],
96
+ len(model.expert_encoder.positional_embedding))
97
  elif self.mode == 'vqa':
98
  config = {
99
  'dataset': 'demo',
 
107
 
108
  model = PrismerVQA(config)
109
  state_dict = torch.load(f'prismer/logging/vqa_{model_name}/pytorch_model.bin', map_location='cuda:0')
110
+ state_dict['expert_encoder.positional_embedding'] = interpolate_pos_embed(state_dict['expert_encoder.positional_embedding'],
111
+ len(model.expert_encoder.positional_embedding))
112
 
113
  model.load_state_dict(state_dict)
114
  model.eval()
 
120
  self.mode = mode
121
 
122
  @torch.inference_mode()
123
+ def run_caption_model(self, exp_name: str) -> str:
124
+ self.set_model(exp_name, 'caption')
125
  _, test_dataset = create_dataset('caption', self.config)
126
  test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
127
  experts, _ = next(iter(test_loader))
 
132
  caption = caption.capitalize() + '.'
133
  return caption
134
 
135
+ def run_caption(self, image_path: str, model_name: str) -> tuple[str | None, ...]:
136
  out_paths = run_experts(image_path)
137
+ caption = self.run_caption_model(model_name)
138
  label_prettify(image_path, out_paths)
139
  return caption, *out_paths
140
 
141
  @torch.inference_mode()
142
+ def run_vqa_model(self, exp_name: str, question: str) -> str:
143
+ self.set_model(exp_name, 'vqa')
144
  _, test_dataset = create_dataset('caption', self.config)
145
  test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
146
  experts, _ = next(iter(test_loader))
 
152
  answer = answer.capitalize() + '.'
153
  return answer
154
 
155
+ def run_vqa(self, image_path: str, model_name: str, question: str) -> tuple[str | None, ...]:
156
  out_paths = run_experts(image_path)
157
+ answer = self.run_vqa_model(model_name, question)
158
  label_prettify(image_path, out_paths)
159
  return answer, *out_paths