mshukor committed
Commit ce7469b
1 Parent(s): 902be23
Files changed (1)
  1. app.py +95 -26
app.py CHANGED
@@ -37,6 +37,7 @@ from ruamel.yaml import YAML
 import torch
 import gradio as gr
 
+import torchaudio
 
 yaml=YAML(typ='safe')
 
@@ -82,33 +83,33 @@ msg = model_caption.load_state_dict(state_dict,strict=False)
 
 model_caption.bfloat16()
 
-###### VQA
-config = 'configs/image/ePALM_vqa.yaml'
-config = yaml.load(open(config, 'r'))
-
-start_layer_idx = 19
-end_layer_idx = 31
-low_cpu = True
-model_vqa = ePALM(opt_model_name=text_model,
-                  vision_model_name=vision_model_name,
-                  use_vis_prefix=True,
-                  start_layer_idx=start_layer_idx,
-                  end_layer_idx=end_layer_idx,
-                  return_hidden_state_vision=True,
-                  config=config,
-                  low_cpu=low_cpu
-                  )
-print("Model Built")
-model_vqa.to(device)
+# ###### VQA
+# config = 'configs/image/ePALM_vqa.yaml'
+# config = yaml.load(open(config, 'r'))
+
+# start_layer_idx = 19
+# end_layer_idx = 31
+# low_cpu = True
+# model_vqa = ePALM(opt_model_name=text_model,
+#                   vision_model_name=vision_model_name,
+#                   use_vis_prefix=True,
+#                   start_layer_idx=start_layer_idx,
+#                   end_layer_idx=end_layer_idx,
+#                   return_hidden_state_vision=True,
+#                   config=config,
+#                   low_cpu=low_cpu
+# )
+# print("Model Built")
+# model_vqa.to(device)
 
 
 checkpoint_path = 'checkpoints/float32/ePALM_vqa/checkpoint_best.pth'
 checkpoint = torch.load(checkpoint_path, map_location='cpu')
-state_dict = checkpoint['model']
-msg = model_vqa.load_state_dict(state_dict,strict=False)
+state_dict_vqa = checkpoint['model']
+# msg = model_vqa.load_state_dict(state_dict,strict=False)
 
 
-model_vqa.bfloat16()
+# model_vqa.bfloat16()
 
 
 
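This hunk drops the second ePALM instance and keeps only the raw VQA weights around as state_dict_vqa, to be swapped into the shared captioning model at inference time. A minimal sketch of that one-model / many-state_dicts pattern, assuming one checkpoint path per task (only the VQA path appears in the diff; the dict layout and any extra task names are illustrative, not part of the commit):

# Sketch only -- not part of the commit. Keep per-task weights as plain
# state_dicts so a single shared model can be re-loaded on demand.
import torch

task_checkpoints = {
    # Path taken from the diff; other tasks would be registered the same way.
    'Visual Question Answering': 'checkpoints/float32/ePALM_vqa/checkpoint_best.pth',
}

task_state_dicts = {
    task: torch.load(path, map_location='cpu')['model']
    for task, path in task_checkpoints.items()
}

Loading with map_location='cpu' keeps the spare weights in host memory until they are actually swapped into the model.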
 
@@ -154,13 +155,80 @@ transform = transforms.Compose([
     normalize,
     ])
 
+type_transform = transforms.Lambda(lambda x: x.float().div(255.0))
+test_transform = transforms.Compose([
+    transforms.Resize((image_size,image_size),interpolation=Image.BICUBIC),
+    type_transform,
+    normalize,
+    ])
+from dataset.video_utils import VIDEO_READER_FUNCS
+video_reader = VIDEO_READER_FUNCS['decord']
+
+def read_video(path, num_frames=16):
+
+
+    frames, frame_indices, video_duration = video_reader(
+        path, num_frames, 'rand', max_num_frames=-1
+    )
+    video = test_transform(frames)
+
+    return video
 
+def read_audio(path):
+
+    melbins = 128
+    target_length = 1024
+    skip_norm = False
+    norm_mean = -4.2677393
+    norm_std = 4.5689974
 
+    waveform, sr = torchaudio.load(path)
+    waveform = waveform - waveform.mean()
 
+    # audio
+    fbank = torchaudio.compliance.kaldi.fbank(waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
+                                              window_type='hanning', num_mel_bins=melbins, dither=0.0,
+                                              frame_shift=10)
+
+    n_frames = fbank.shape[0]
 
+    p = target_length - n_frames
 
+    # cut and pad
+    if p > 0:
+        m = torch.nn.ZeroPad2d((0, 0, 0, p))
+        fbank = m(fbank)
+    elif p < 0:
+        fbank = fbank[0:target_length, :]
 
 
+
+
+    # SpecAug, not do for eval set
+
+    fbank = torch.transpose(fbank, 0, 1)
+    # this is just to satisfy new torchaudio version, which only accept [1, freq, time]
+    fbank = fbank.unsqueeze(0)
+
+
+
+    # squeeze it back, it is just a trick to satisfy new torchaudio version
+    fbank = fbank.squeeze(0)
+    fbank = torch.transpose(fbank, 0, 1)
+
+
+    # normalize the input for both training and test
+    if not skip_norm:
+        fbank = (fbank - norm_mean) / (norm_std * 2)
+    # skip normalization the input if you are trying to get the normalization stats.
+    else:
+        pass
+
+
+    audio = fbank
+
+    return audio
+
 do_sample=False
 num_beams=3
 max_length=30
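read_audio follows the familiar Kaldi-style recipe: compute a [n_frames, num_mel_bins] log-mel filterbank with torchaudio.compliance.kaldi.fbank, force it to exactly target_length frames, then normalize with fixed dataset statistics. A self-contained sketch of just the pad-or-crop step (target_length and the 128 mel bins follow the diff; the example input length is made up):

# Sketch only -- mirrors the "cut and pad" block in read_audio above.
import torch

def pad_or_crop(fbank: torch.Tensor, target_length: int = 1024) -> torch.Tensor:
    # fbank: [n_frames, num_mel_bins]
    p = target_length - fbank.shape[0]
    if p > 0:
        # ZeroPad2d padding is (left, right, top, bottom): append p rows of zeros.
        fbank = torch.nn.ZeroPad2d((0, 0, 0, p))(fbank)
    elif p < 0:
        fbank = fbank[:target_length, :]
    return fbank

x = torch.randn(700, 128)        # e.g. ~7 s of 10 ms frames, 128 mel bins
print(pad_or_crop(x).shape)      # torch.Size([1024, 128])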
@@ -188,19 +256,20 @@ def inference(image, audio, video, task_type, instruction):
     elif task_type == 'Visual Question Answering':
         question = instruction+'?'+special_answer_token
         text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
-        model = model_vqa.clone()
+        model_caption = model_caption.load_state_dict(state_dict_vqa,strict=False)
+        model = model_caption.clone()
     elif task_type == 'Visual Question Answering':
         question = instruction+'?'+special_answer_token
         text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
-        model_vqa = model_vqa.load_state_dict(state_dict_video_qa,strict=False)
-        model = model_vqa.clone()
+        model_caption = model_caption.load_state_dict(state_dict_video_qa,strict=False)
+        model = model_caption.clone()
     else:
         raise NotImplemented
 
     if "Video" in task_type:
-        pass
+        image = read_video(image)
     elif "Audio" in task_type:
-        pass
+        image = read_audio(image)
     else:
         image = transform(image)
         image = image.to(device,non_blocking=True).unsqueeze(0)
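One thing to keep in mind with the new task dispatch: in stock PyTorch, nn.Module.load_state_dict modifies the module in place and returns a (missing_keys, unexpected_keys) named tuple rather than the module, and nn.Module has no clone() method (copy.deepcopy is the usual way to duplicate one). A minimal, self-contained sketch of the in-place checkpoint swap this hunk appears to be aiming for; TinyModel is a hypothetical stand-in for ePALM and is not part of the repository:

# Sketch only: swap task-specific weights into one shared module in place.
import copy
import torch
import torch.nn as nn

class TinyModel(nn.Module):          # hypothetical stand-in for ePALM
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

model = TinyModel()

# Pretend these came from torch.load(checkpoint_path, map_location='cpu')['model'].
state_dict_caption = copy.deepcopy(model.state_dict())
state_dict_vqa = {k: v + 1.0 for k, v in model.state_dict().items()}

# load_state_dict mutates the module in place and returns a report of mismatched
# keys; the return value is only useful for sanity-checking the load.
msg = model.load_state_dict(state_dict_vqa, strict=False)
print(msg.missing_keys, msg.unexpected_keys)

# An independent copy, if one is really needed, comes from deepcopy.
model_for_this_request = copy.deepcopy(model)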