shikunl commited on
Commit
fb14311
β€’
1 Parent(s): 53b7b42
Files changed (2) hide show
  1. label_prettify.py +1 -1
  2. prismer_model.py +8 -6
label_prettify.py CHANGED
@@ -87,7 +87,7 @@ def ocr_detection_prettify(rgb_path, file_name):
87
  ocr_labels_dict = torch.load(file_name.replace('.png', '.pt'))
88
 
89
  plt.imshow(rgb)
90
- plt.imshow((1 - ocr_labels) < 1, cmap='gray', alpha=0.8)
91
 
92
  for i in np.unique(ocr_labels)[:-1]:
93
  text_idx_all = np.where(ocr_labels == i)
 
87
  ocr_labels_dict = torch.load(file_name.replace('.png', '.pt'))
88
 
89
  plt.imshow(rgb)
90
+ plt.imshow(ocr_labels, cmap='gray', alpha=0.8)
91
 
92
  for i in np.unique(ocr_labels)[:-1]:
93
  text_idx_all = np.where(ocr_labels == i)
prismer_model.py CHANGED
@@ -75,11 +75,13 @@ class Model:
75
  if exp_name == self.exp_name:
76
  return
77
 
 
78
  if self.exp_name == 'Prismer-Base':
79
- model_name = 'prismer_base'
80
  elif self.exp_name == 'Prismer-Large':
81
- model_name = 'prismer_large'
82
 
 
83
  if self.mode == 'caption':
84
  config = {
85
  'dataset': 'demo',
@@ -87,12 +89,12 @@ class Model:
87
  'label_path': 'prismer/helpers/labels',
88
  'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
89
  'image_resolution': 480,
90
- 'prismer_model': model_name,
91
  'freeze': 'freeze_vision',
92
  'prefix': '',
93
  }
94
  model = PrismerCaption(config)
95
- state_dict = torch.load(f'prismer/logging/pretrain_{model_name}/pytorch_model.bin', map_location='cuda:0')
96
 
97
  elif self.mode == 'vqa':
98
  config = {
@@ -101,12 +103,12 @@ class Model:
101
  'label_path': 'prismer/helpers/labels',
102
  'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
103
  'image_resolution': 480,
104
- 'prismer_model': model_name,
105
  'freeze': 'freeze_vision',
106
  }
107
 
108
  model = PrismerVQA(config)
109
- state_dict = torch.load(f'prismer/logging/vqa_{model_name}/pytorch_model.bin', map_location='cuda:0')
110
 
111
  model.load_state_dict(state_dict)
112
  model.eval()
 
75
  if exp_name == self.exp_name:
76
  return
77
 
78
+ # remap model name
79
  if self.exp_name == 'Prismer-Base':
80
+ self.exp_name = 'prismer_base'
81
  elif self.exp_name == 'Prismer-Large':
82
+ self.exp_name = 'prismer_large'
83
 
84
+ # load checkpoints
85
  if self.mode == 'caption':
86
  config = {
87
  'dataset': 'demo',
 
89
  'label_path': 'prismer/helpers/labels',
90
  'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
91
  'image_resolution': 480,
92
+ 'prismer_model': self.exp_name,
93
  'freeze': 'freeze_vision',
94
  'prefix': '',
95
  }
96
  model = PrismerCaption(config)
97
+ state_dict = torch.load(f'prismer/logging/pretrain_{self.exp_name}/pytorch_model.bin', map_location='cuda:0')
98
 
99
  elif self.mode == 'vqa':
100
  config = {
 
103
  'label_path': 'prismer/helpers/labels',
104
  'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
105
  'image_resolution': 480,
106
+ 'prismer_model': self.exp_name,
107
  'freeze': 'freeze_vision',
108
  }
109
 
110
  model = PrismerVQA(config)
111
+ state_dict = torch.load(f'prismer/logging/vqa_{self.exp_name}/pytorch_model.bin', map_location='cuda:0')
112
 
113
  model.load_state_dict(state_dict)
114
  model.eval()