shikunl committed
Commit a8208b6
1 Parent(s): ef365f5
Files changed (3)
  1. app_caption.py +1 -1
  2. app_vqa.py +2 -2
  3. prismer_model.py +1 -1
app_caption.py CHANGED
@@ -15,7 +15,7 @@ def create_demo():
     with gr.Row():
         with gr.Column():
             image = gr.Image(label='Input', type='filepath')
-            model_name = gr.Dropdown(label='Model', choices=['Prismer-Base, Prismer-Large'], value='Prismer-Base')
+            model_name = gr.Dropdown(label='Model', choices=['Prismer-Base', 'Prismer-Large'], value='Prismer-Base')
             run_button = gr.Button('Run')
         with gr.Column(scale=1.5):
             caption = gr.Text(label='Model Prediction')
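The fix above splits what was a single string, 'Prismer-Base, Prismer-Large', into two separate dropdown choices. A minimal, self-contained sketch of the corrected widget (the surrounding create_demo layout is omitted; only the gr.Dropdown arguments come from the diff):

import gradio as gr

# Minimal sketch: two separate strings give two selectable models, whereas the
# old single string would have rendered as one option literally named
# 'Prismer-Base, Prismer-Large'.
with gr.Blocks() as demo:
    model_name = gr.Dropdown(label='Model',
                             choices=['Prismer-Base', 'Prismer-Large'],
                             value='Prismer-Base')

if __name__ == '__main__':
    demo.launch()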
app_vqa.py CHANGED
@@ -44,9 +44,9 @@ def create_demo():
     gr.Examples(examples=examples,
                 inputs=inputs,
                 outputs=outputs,
-                fn=model.run_vqa_model)
+                fn=model.run_vqa)
 
-    run_button.click(fn=model.run_vqa_model, inputs=inputs, outputs=outputs)
+    run_button.click(fn=model.run_vqa, inputs=inputs, outputs=outputs)
 
 
 if __name__ == '__main__':
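Both hunks rename the callback so that the cached examples and the Run button call the same method, model.run_vqa. The sketch below reproduces that wiring with a hypothetical stand-in function; its signature and the 'demo.jpg' example path are assumptions, not taken from the repository.

import gradio as gr

# Hypothetical stand-in for Model.run_vqa; the real method's signature may differ.
def run_vqa(image_path: str, model_name: str, question: str) -> str:
    return f'{model_name} would answer "{question}" for {image_path}'

with gr.Blocks() as demo:
    image = gr.Image(label='Input', type='filepath')
    model_name = gr.Dropdown(label='Model', choices=['Prismer-Base', 'Prismer-Large'], value='Prismer-Base')
    question = gr.Text(label='Question')
    answer = gr.Text(label='Model Prediction')
    run_button = gr.Button('Run')

    inputs = [image, model_name, question]
    # The examples and the click handler must point at the same callable,
    # which is what the rename ensures.
    gr.Examples(examples=[['demo.jpg', 'Prismer-Base', 'What is in the picture?']],  # placeholder example
                inputs=inputs, outputs=answer, fn=run_vqa)
    run_button.click(fn=run_vqa, inputs=inputs, outputs=answer)

if __name__ == '__main__':
    demo.launch()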
prismer_model.py CHANGED
@@ -145,7 +145,7 @@ class Model:
         test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
         experts, _ = next(iter(test_loader))
         question = pre_question(question)
-        answer = self.model(experts, question, train=False, inference='generate')
+        answer = self.model(experts, [question], train=False, inference='generate')
         answer = self.tokenizer(answer, max_length=30, padding='max_length', return_tensors='pt').input_ids
         answer = answer.to(experts['rgb'].device)[0]
         answer = self.tokenizer.decode(answer, skip_special_tokens=True)
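Wrapping the question in a list makes it a one-element batch, matching the single-image batch produced by test_loader (batch_size=1). A small sketch of the failure mode, using a hypothetical batched function in place of the real Prismer forward call:

from typing import List

# Hypothetical stand-in for a batched VQA forward pass: it expects one
# question per sample in the batch.
def answer_batch(questions: List[str]) -> List[str]:
    return [f'answer to: {q}' for q in questions]

question = 'what is the man holding?'
print(answer_batch([question]))  # one-element batch -> one answer
print(answer_batch(question))    # bare string is iterated character by character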