shikunl committed on
Commit
6eaf487
•
1 Parent(s): 63bc825

Fix labels

Browse files
Files changed (2) hide show
  1. app_caption.py +2 -2
  2. prismer_model.py +32 -14
app_caption.py CHANGED
@@ -11,11 +11,11 @@ from prismer_model import Model
11
 
12
  def create_demo():
13
  model = Model()
14
-
15
  with gr.Row():
16
  with gr.Column():
17
- model_name = gr.Dropdown(label='Model', choices=['Prismer-Base'], value='Prismer-Base')
18
  image = gr.Image(label='Input', type='filepath')
 
19
  run_button = gr.Button('Run')
20
  with gr.Column(scale=1.5):
21
  caption = gr.Text(label='Caption')
 
11
 
12
  def create_demo():
13
  model = Model()
14
+ model.mode = 'caption'
15
  with gr.Row():
16
  with gr.Column():
 
17
  image = gr.Image(label='Input', type='filepath')
18
+ model_name = gr.Dropdown(label='Model', choices=['Prismer-Base, Prismer-Large'], value='Prismer-Base')
19
  run_button = gr.Button('Run')
20
  with gr.Column(scale=1.5):
21
  caption = gr.Text(label='Caption')
prismer_model.py CHANGED
@@ -58,7 +58,7 @@ def run_experts(image_path: str) -> tuple[str | None, ...]:
58
 
59
  keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
60
  results = [pathlib.Path('prismer/helpers/labels') / key / 'helpers/images/image.png' for key in keys]
61
- return tuple(path.as_posix() if path.exists() else None for path in results)
62
 
63
 
64
  class Model:
@@ -67,24 +67,42 @@ class Model:
67
  self.model = None
68
  self.tokenizer = None
69
  self.exp_name = ''
 
70
 
71
  def set_model(self, exp_name: str) -> None:
72
  if exp_name == self.exp_name:
73
  return
74
 
75
- config = {
76
- 'dataset': 'demo',
77
- 'data_path': 'prismer/helpers',
78
- 'label_path': 'prismer/helpers/labels',
79
- 'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
80
- 'image_resolution': 480,
81
- 'prismer_model': 'prismer_base' if self.exp_name == 'Prismer-Base' else 'prismer_large',
82
- 'freeze': 'freeze_vision',
83
- 'prefix': 'A picture of',
84
- }
85
-
86
- model = PrismerCaption(config)
87
- state_dict = torch.load(f'prismer/logging/caption_{exp_name}/pytorch_model.bin', map_location='cuda:0')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  model.load_state_dict(state_dict)
89
  model.eval()
90
 
 
58
 
59
  keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
60
  results = [pathlib.Path('prismer/helpers/labels') / key / 'helpers/images/image.png' for key in keys]
61
+ return tuple(path.as_posix() for path in results)
62
 
63
 
64
  class Model:
 
67
  self.model = None
68
  self.tokenizer = None
69
  self.exp_name = ''
70
+ self.mode = ''
71
 
72
  def set_model(self, exp_name: str) -> None:
73
  if exp_name == self.exp_name:
74
  return
75
 
76
+ if self.mode == 'caption':
77
+ config = {
78
+ 'dataset': 'demo',
79
+ 'data_path': 'prismer/helpers',
80
+ 'label_path': 'prismer/helpers/labels',
81
+ 'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
82
+ 'image_resolution': 480,
83
+ 'prismer_model': 'prismer_base' if self.exp_name == 'Prismer-Base' else 'prismer_large',
84
+ 'freeze': 'freeze_vision',
85
+ 'prefix': 'A picture of',
86
+ }
87
+
88
+ model = PrismerCaption(config)
89
+ state_dict = torch.load(f'prismer/logging/caption_{exp_name}/pytorch_model.bin', map_location='cuda:0')
90
+
91
+ elif self.mode == 'vqa':
92
+ config = {
93
+ 'dataset': 'demo',
94
+ 'data_path': 'prismer/helpers',
95
+ 'label_path': 'prismer/helpers/labels',
96
+ 'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
97
+ 'image_resolution': 480,
98
+ 'prismer_model': 'prismer_base' if self.exp_name == 'Prismer-Base' else 'prismer_large',
99
+ 'freeze': 'freeze_vision',
100
+ 'prefix': 'A picture of',
101
+ }
102
+
103
+ model = PrismerCaption(config)
104
+ state_dict = torch.load(f'prismer/logging/caption_{exp_name}/pytorch_model.bin', map_location='cuda:0')
105
+
106
  model.load_state_dict(state_dict)
107
  model.eval()
108