Spaces:
Sleeping
Sleeping
Fix labels
Browse files- app_caption.py +2 -2
- prismer_model.py +32 -14
app_caption.py
CHANGED
@@ -11,11 +11,11 @@ from prismer_model import Model
|
|
11 |
|
12 |
def create_demo():
|
13 |
model = Model()
|
14 |
-
|
15 |
with gr.Row():
|
16 |
with gr.Column():
|
17 |
-
model_name = gr.Dropdown(label='Model', choices=['Prismer-Base'], value='Prismer-Base')
|
18 |
image = gr.Image(label='Input', type='filepath')
|
|
|
19 |
run_button = gr.Button('Run')
|
20 |
with gr.Column(scale=1.5):
|
21 |
caption = gr.Text(label='Caption')
|
|
|
11 |
|
12 |
def create_demo():
|
13 |
model = Model()
|
14 |
+
model.mode = 'caption'
|
15 |
with gr.Row():
|
16 |
with gr.Column():
|
|
|
17 |
image = gr.Image(label='Input', type='filepath')
|
18 |
+
model_name = gr.Dropdown(label='Model', choices=['Prismer-Base, Prismer-Large'], value='Prismer-Base')
|
19 |
run_button = gr.Button('Run')
|
20 |
with gr.Column(scale=1.5):
|
21 |
caption = gr.Text(label='Caption')
|
prismer_model.py
CHANGED
@@ -58,7 +58,7 @@ def run_experts(image_path: str) -> tuple[str | None, ...]:
|
|
58 |
|
59 |
keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
|
60 |
results = [pathlib.Path('prismer/helpers/labels') / key / 'helpers/images/image.png' for key in keys]
|
61 |
-
return tuple(path.as_posix()
|
62 |
|
63 |
|
64 |
class Model:
|
@@ -67,24 +67,42 @@ class Model:
|
|
67 |
self.model = None
|
68 |
self.tokenizer = None
|
69 |
self.exp_name = ''
|
|
|
70 |
|
71 |
def set_model(self, exp_name: str) -> None:
|
72 |
if exp_name == self.exp_name:
|
73 |
return
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
model.load_state_dict(state_dict)
|
89 |
model.eval()
|
90 |
|
|
|
58 |
|
59 |
keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
|
60 |
results = [pathlib.Path('prismer/helpers/labels') / key / 'helpers/images/image.png' for key in keys]
|
61 |
+
return tuple(path.as_posix() for path in results)
|
62 |
|
63 |
|
64 |
class Model:
|
|
|
67 |
self.model = None
|
68 |
self.tokenizer = None
|
69 |
self.exp_name = ''
|
70 |
+
self.mode = ''
|
71 |
|
72 |
def set_model(self, exp_name: str) -> None:
|
73 |
if exp_name == self.exp_name:
|
74 |
return
|
75 |
|
76 |
+
if self.mode == 'caption':
|
77 |
+
config = {
|
78 |
+
'dataset': 'demo',
|
79 |
+
'data_path': 'prismer/helpers',
|
80 |
+
'label_path': 'prismer/helpers/labels',
|
81 |
+
'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
|
82 |
+
'image_resolution': 480,
|
83 |
+
'prismer_model': 'prismer_base' if self.exp_name == 'Prismer-Base' else 'prismer_large',
|
84 |
+
'freeze': 'freeze_vision',
|
85 |
+
'prefix': 'A picture of',
|
86 |
+
}
|
87 |
+
|
88 |
+
model = PrismerCaption(config)
|
89 |
+
state_dict = torch.load(f'prismer/logging/caption_{exp_name}/pytorch_model.bin', map_location='cuda:0')
|
90 |
+
|
91 |
+
elif self.mode == 'vqa':
|
92 |
+
config = {
|
93 |
+
'dataset': 'demo',
|
94 |
+
'data_path': 'prismer/helpers',
|
95 |
+
'label_path': 'prismer/helpers/labels',
|
96 |
+
'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
|
97 |
+
'image_resolution': 480,
|
98 |
+
'prismer_model': 'prismer_base' if self.exp_name == 'Prismer-Base' else 'prismer_large',
|
99 |
+
'freeze': 'freeze_vision',
|
100 |
+
'prefix': 'A picture of',
|
101 |
+
}
|
102 |
+
|
103 |
+
model = PrismerCaption(config)
|
104 |
+
state_dict = torch.load(f'prismer/logging/caption_{exp_name}/pytorch_model.bin', map_location='cuda:0')
|
105 |
+
|
106 |
model.load_state_dict(state_dict)
|
107 |
model.eval()
|
108 |
|