mshukor committed on
Commit
902be23
1 Parent(s): 78ad2cd
Files changed (1)
  1. app.py +41 -9
app.py CHANGED
@@ -80,13 +80,9 @@ checkpoint = torch.load(checkpoint_path, map_location='cpu')
 state_dict = checkpoint['model']
 msg = model_caption.load_state_dict(state_dict,strict=False)
 
-
-
+model_caption.bfloat16()
 
 ###### VQA
-
-
-
 config = 'configs/image/ePALM_vqa.yaml'
 config = yaml.load(open(config, 'r'))
 
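Note: nn.Module.bfloat16() casts all floating-point parameters and buffers in place and returns the module itself, so the bare call added here is sufficient; no reassignment is needed. A minimal sketch of an equivalent spelling (my wording, not the committed code):

    # both lines do the same thing: cast the module's params/buffers to bfloat16 in place
    model_caption.bfloat16()
    model_caption = model_caption.to(torch.bfloat16)  # Module.to() also mutates in place and returns self

Depending on how the model consumes its inputs, floating-point tensors passed at generation time may also need to be cast to the same dtype.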
 
@@ -112,6 +108,28 @@ state_dict = checkpoint['model']
 msg = model_vqa.load_state_dict(state_dict,strict=False)
 
 
+model_vqa.bfloat16()
+
+
+
+# Video Captioning
+checkpoint_path = 'checkpoints/float32/ePALM_video_caption_msrvtt/checkpoint_best.pth'
+# checkpoint_path = '/data/mshukor/logs/eplam/models/accelerate/ePALM_pt_L_acc_caption/checkpoint_best.pth'
+checkpoint = torch.load(checkpoint_path, map_location='cpu')
+state_dict_video_caption = checkpoint['model']
+
+# Video QA
+checkpoint_path = 'checkpoints/float32/ePALM_video_qa_msrvtt/checkpoint_best.pth'
+# checkpoint_path = '/data/mshukor/logs/eplam/models/accelerate/ePALM_pt_L_acc_caption/checkpoint_best.pth'
+checkpoint = torch.load(checkpoint_path, map_location='cpu')
+state_dict_video_qa = checkpoint['model']
+
+
+# Audio Captioning
+checkpoint_path = 'checkpoints/float32/ePALM_audio_caption/checkpoint_best.pth'
+# checkpoint_path = '/data/mshukor/logs/eplam/models/accelerate/ePALM_pt_L_acc_caption/checkpoint_best.pth'
+checkpoint = torch.load(checkpoint_path, map_location='cpu')
+state_dict_audio_caption = checkpoint['model']
 
 
 
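Note: this hunk preloads three extra state dicts into module-level variables (state_dict_video_caption, state_dict_video_qa, state_dict_audio_caption) that are reused later in inference(). A possible tidier arrangement, sketched under the assumption that the keys match the Gradio task labels used in the last hunk (this is not the committed code):

    # hypothetical lookup table keyed by the task label from the UI dropdown
    extra_state_dicts = {
        'Video Captioning': state_dict_video_caption,
        'Video Question Answering': state_dict_video_qa,  # assumed label, see the note on the last hunk
        'Audio Captioning': state_dict_audio_caption,
    }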
 
@@ -148,8 +166,7 @@ num_beams=3
 max_length=30
 
 
-model_caption.bfloat16()
-model_vqa.bfloat16()
+
 
 
 def inference(image, audio, video, task_type, instruction):
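Since the bfloat16 casts now happen right after each load_state_dict call, the repeated load-then-cast blocks could be folded into one helper. A hypothetical sketch (load_into is my name, not from the repository):

    import torch

    def load_into(model, checkpoint_path):
        # load a checkpoint into an existing model and cast it to bfloat16 in place
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        msg = model.load_state_dict(checkpoint['model'], strict=False)  # report of missing/unexpected keys
        model.bfloat16()
        return model
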
@@ -157,11 +174,26 @@ def inference(image, audio, video, task_type, instruction):
     if task_type == 'Image Captioning':
         text = ['']
         text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
-        model = model_caption
+        model = model_caption.clone()
+    elif task_type == 'Video Captioning':
+        text = ['']
+        text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
+        model_caption = model_caption.load_state_dict(state_dict_video_caption,strict=False)
+        model = model_caption.clone()
+    elif task_type == 'Audio Captioning':
+        text = ['']
+        text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
+        model_caption = model_caption.load_state_dict(state_dict_audio_caption,strict=False)
+        model = model_caption.clone()
+    elif task_type == 'Visual Question Answering':
+        question = instruction+'?'+special_answer_token
+        text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
+        model = model_vqa.clone()
     elif task_type == 'Visual Question Answering':
         question = instruction+'?'+special_answer_token
         text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
-        model = model_vqa
+        model_vqa = model_vqa.load_state_dict(state_dict_video_qa,strict=False)
+        model = model_vqa.clone()
     else:
         raise NotImplemented
 
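Two issues in the new branches are worth flagging. nn.Module.load_state_dict() loads weights in place and returns only a report of missing/unexpected keys (which is how msg = model_caption.load_state_dict(...) is used earlier in the file), so rebinding model_caption or model_vqa to its return value discards the model. nn.Module also has no clone() method (that is a tensor method), so model_caption.clone() would raise AttributeError at request time; copy.deepcopy(model) would be needed if a separate copy is really wanted. Finally, the second elif task_type == 'Visual Question Answering': branch duplicates the condition above it and is unreachable; since it loads state_dict_video_qa, it presumably was meant to be video question answering. A minimal corrected sketch under those assumptions (not the committed code):

    if task_type == 'Video Captioning':
        text_input = tokenizer([''], padding='longest', return_tensors="pt").to(device)
        model_caption.load_state_dict(state_dict_video_caption, strict=False)  # loads in place
        model = model_caption
    elif task_type == 'Video Question Answering':  # assumed label for the duplicated branch
        question = instruction + '?' + special_answer_token
        text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
        model_vqa.load_state_dict(state_dict_video_qa, strict=False)
        model = model_vqa
    else:
        raise NotImplementedError(task_type)  # NotImplemented itself is not an exception class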
 
 