shikunl commited on
Commit
fb94f78
β€’
1 Parent(s): fb14311
Files changed (3) hide show
  1. app_caption.py +1 -1
  2. app_vqa.py +1 -2
  3. prismer_model.py +17 -20
app_caption.py CHANGED
@@ -28,7 +28,7 @@ def create_demo():
28
  object_detection = gr.Image(label='Object Detection')
29
  ocr = gr.Image(label='OCR Detection')
30
 
31
- inputs = [image, model_name]
32
  outputs = [caption, depth, edge, normals, segmentation, object_detection, ocr]
33
 
34
  # paths = sorted(pathlib.Path('prismer/images').glob('*'))
 
28
  object_detection = gr.Image(label='Object Detection')
29
  ocr = gr.Image(label='OCR Detection')
30
 
31
+ inputs = [image, model_name, 'caption']
32
  outputs = [caption, depth, edge, normals, segmentation, object_detection, ocr]
33
 
34
  # paths = sorted(pathlib.Path('prismer/images').glob('*'))
app_vqa.py CHANGED
@@ -11,7 +11,6 @@ from prismer_model import Model
11
 
12
  def create_demo():
13
  model = Model()
14
- model.mode = 'vqa'
15
  with gr.Row():
16
  with gr.Column():
17
  image = gr.Image(label='Input', type='filepath')
@@ -29,7 +28,7 @@ def create_demo():
29
  object_detection = gr.Image(label='Object Detection')
30
  ocr = gr.Image(label='OCR Detection')
31
 
32
- inputs = [image, model_name, question]
33
  outputs = [answer, depth, edge, normals, segmentation, object_detection, ocr]
34
 
35
  # paths = sorted(pathlib.Path('prismer/images').glob('*'))
 
11
 
12
  def create_demo():
13
  model = Model()
 
14
  with gr.Row():
15
  with gr.Column():
16
  image = gr.Image(label='Input', type='filepath')
 
28
  object_detection = gr.Image(label='Object Detection')
29
  ocr = gr.Image(label='OCR Detection')
30
 
31
+ inputs = [image, model_name, 'vqa', question]
32
  outputs = [answer, depth, edge, normals, segmentation, object_detection, ocr]
33
 
34
  # paths = sorted(pathlib.Path('prismer/images').glob('*'))
prismer_model.py CHANGED
@@ -68,20 +68,16 @@ class Model:
68
  self.config = None
69
  self.model = None
70
  self.tokenizer = None
 
71
  self.exp_name = ''
72
  self.mode = ''
73
 
74
- def set_model(self, exp_name: str) -> None:
75
- if exp_name == self.exp_name:
76
  return
77
 
78
- # remap model name
79
- if self.exp_name == 'Prismer-Base':
80
- self.exp_name = 'prismer_base'
81
- elif self.exp_name == 'Prismer-Large':
82
- self.exp_name = 'prismer_large'
83
-
84
  # load checkpoints
 
85
  if self.mode == 'caption':
86
  config = {
87
  'dataset': 'demo',
@@ -89,12 +85,12 @@ class Model:
89
  'label_path': 'prismer/helpers/labels',
90
  'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
91
  'image_resolution': 480,
92
- 'prismer_model': self.exp_name,
93
  'freeze': 'freeze_vision',
94
  'prefix': '',
95
  }
96
  model = PrismerCaption(config)
97
- state_dict = torch.load(f'prismer/logging/pretrain_{self.exp_name}/pytorch_model.bin', map_location='cuda:0')
98
 
99
  elif self.mode == 'vqa':
100
  config = {
@@ -103,12 +99,12 @@ class Model:
103
  'label_path': 'prismer/helpers/labels',
104
  'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
105
  'image_resolution': 480,
106
- 'prismer_model': self.exp_name,
107
  'freeze': 'freeze_vision',
108
  }
109
 
110
  model = PrismerVQA(config)
111
- state_dict = torch.load(f'prismer/logging/vqa_{self.exp_name}/pytorch_model.bin', map_location='cuda:0')
112
 
113
  model.load_state_dict(state_dict)
114
  model.eval()
@@ -117,10 +113,11 @@ class Model:
117
  self.model = model
118
  self.tokenizer = model.tokenizer
119
  self.exp_name = exp_name
 
120
 
121
  @torch.inference_mode()
122
- def run_caption_model(self, exp_name: str) -> str:
123
- self.set_model(exp_name)
124
  _, test_dataset = create_dataset('caption', self.config)
125
  test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
126
  experts, _ = next(iter(test_loader))
@@ -131,15 +128,15 @@ class Model:
131
  caption = caption.capitalize() + '.'
132
  return caption
133
 
134
- def run_caption(self, image_path: str, model_name: str) -> tuple[str | None, ...]:
135
  out_paths = run_experts(image_path)
136
- caption = self.run_caption_model(model_name)
137
  label_prettify(image_path, out_paths)
138
  return caption, *out_paths
139
 
140
  @torch.inference_mode()
141
- def run_vqa_model(self, exp_name: str, question: str) -> str:
142
- self.set_model(exp_name)
143
  _, test_dataset = create_dataset('caption', self.config)
144
  test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
145
  experts, _ = next(iter(test_loader))
@@ -151,8 +148,8 @@ class Model:
151
  answer = answer.capitalize() + '.'
152
  return answer
153
 
154
- def run_vqa(self, image_path: str, model_name: str, question: str) -> tuple[str | None, ...]:
155
  out_paths = run_experts(image_path)
156
- answer = self.run_vqa_model(model_name, question)
157
  label_prettify(image_path, out_paths)
158
  return answer, *out_paths
 
68
  self.config = None
69
  self.model = None
70
  self.tokenizer = None
71
+ self.model_name = ''
72
  self.exp_name = ''
73
  self.mode = ''
74
 
75
+ def set_model(self, exp_name: str, mode: str) -> None:
76
+ if exp_name == self.exp_name and mode == self.mode:
77
  return
78
 
 
 
 
 
 
 
79
  # load checkpoints
80
+ model_name = exp_name.lower().replace('-', '_')
81
  if self.mode == 'caption':
82
  config = {
83
  'dataset': 'demo',
 
85
  'label_path': 'prismer/helpers/labels',
86
  'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
87
  'image_resolution': 480,
88
+ 'prismer_model': model_name,
89
  'freeze': 'freeze_vision',
90
  'prefix': '',
91
  }
92
  model = PrismerCaption(config)
93
+ state_dict = torch.load(f'prismer/logging/pretrain_{model_name}/pytorch_model.bin', map_location='cuda:0')
94
 
95
  elif self.mode == 'vqa':
96
  config = {
 
99
  'label_path': 'prismer/helpers/labels',
100
  'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
101
  'image_resolution': 480,
102
+ 'prismer_model': model_name,
103
  'freeze': 'freeze_vision',
104
  }
105
 
106
  model = PrismerVQA(config)
107
+ state_dict = torch.load(f'prismer/logging/vqa_{model_name}/pytorch_model.bin', map_location='cuda:0')
108
 
109
  model.load_state_dict(state_dict)
110
  model.eval()
 
113
  self.model = model
114
  self.tokenizer = model.tokenizer
115
  self.exp_name = exp_name
116
+ self.mode = mode
117
 
118
  @torch.inference_mode()
119
+ def run_caption_model(self, exp_name: str, mode: str) -> str:
120
+ self.set_model(exp_name, mode)
121
  _, test_dataset = create_dataset('caption', self.config)
122
  test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
123
  experts, _ = next(iter(test_loader))
 
128
  caption = caption.capitalize() + '.'
129
  return caption
130
 
131
+ def run_caption(self, image_path: str, model_name: str, mode: str) -> tuple[str | None, ...]:
132
  out_paths = run_experts(image_path)
133
+ caption = self.run_caption_model(model_name, mode)
134
  label_prettify(image_path, out_paths)
135
  return caption, *out_paths
136
 
137
  @torch.inference_mode()
138
+ def run_vqa_model(self, exp_name: str, mode: str, question: str) -> str:
139
+ self.set_model(exp_name, mode)
140
  _, test_dataset = create_dataset('caption', self.config)
141
  test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
142
  experts, _ = next(iter(test_loader))
 
148
  answer = answer.capitalize() + '.'
149
  return answer
150
 
151
+ def run_vqa(self, image_path: str, model_name: str, mode: str, question: str) -> tuple[str | None, ...]:
152
  out_paths = run_experts(image_path)
153
+ answer = self.run_vqa_model(model_name, mode, question)
154
  label_prettify(image_path, out_paths)
155
  return answer, *out_paths