mshukor committed
Commit 78ad2cd
1 Parent(s): 85d478c
Files changed (1)
  1. app.py +62 -15
app.py CHANGED
@@ -48,22 +48,21 @@ device_type = 'cuda' if use_cuda else 'cpu'
 
 ## Load model
 
+### Captioning
 config = 'configs/image/ePALM_caption.yaml'
 # config = yaml.load(open(config, 'r'), Loader=yaml.Loader)
 config = yaml.load(open(config, 'r'))
 
-
 text_model = 'facebook/opt-2.7b'
 vision_model_name = 'vit_base_patch16_224'
 
 # text_model = 'facebook/opt-6.7b'
 # vision_model_name = 'vit_large_patch16_224'
 
-
 start_layer_idx = 19
 end_layer_idx = 31
 low_cpu = True
-model = ePALM(opt_model_name=text_model,
+model_caption = ePALM(opt_model_name=text_model,
         vision_model_name=vision_model_name,
         use_vis_prefix=True,
         start_layer_idx=start_layer_idx,
@@ -73,14 +72,48 @@ model = ePALM(opt_model_name=text_model,
         low_cpu=low_cpu
         )
 print("Model Built")
-model.to(device)
-
+model_caption.to(device)
 
 checkpoint_path = 'checkpoints/float32/ePALM_caption/checkpoint_best.pth'
 # checkpoint_path = '/data/mshukor/logs/eplam/models/accelerate/ePALM_pt_L_acc_caption/checkpoint_best.pth'
 checkpoint = torch.load(checkpoint_path, map_location='cpu')
 state_dict = checkpoint['model']
-msg = model.load_state_dict(state_dict,strict=False)
+msg = model_caption.load_state_dict(state_dict,strict=False)
+
+
+
+
+###### VQA
+
+
+
+config = 'configs/image/ePALM_vqa.yaml'
+config = yaml.load(open(config, 'r'))
+
+start_layer_idx = 19
+end_layer_idx = 31
+low_cpu = True
+model_vqa = ePALM(opt_model_name=text_model,
+        vision_model_name=vision_model_name,
+        use_vis_prefix=True,
+        start_layer_idx=start_layer_idx,
+        end_layer_idx=end_layer_idx,
+        return_hidden_state_vision=True,
+        config=config,
+        low_cpu=low_cpu
+        )
+print("Model Built")
+model_vqa.to(device)
+
+
+checkpoint_path = 'checkpoints/float32/ePALM_vqa/checkpoint_best.pth'
+checkpoint = torch.load(checkpoint_path, map_location='cpu')
+state_dict = checkpoint['model']
+msg = model_vqa.load_state_dict(state_dict,strict=False)
+
+
+
+
 
 
 ## Load tokenizer
@@ -88,7 +121,10 @@ tokenizer = AutoTokenizer.from_pretrained(text_model, use_fast=False)
 eos_token = tokenizer.eos_token
 pad_token = tokenizer.pad_token
 
+special_answer_token = '</a>'
 
+special_tokens_dict = {'additional_special_tokens': [special_answer_token]}
+tokenizer.add_special_tokens(special_tokens_dict)
 
 
 image_size = 224
@@ -112,7 +148,8 @@ num_beams=3
 max_length=30
 
 
-model.bfloat16()
+model_caption.bfloat16()
+model_vqa.bfloat16()
 
 
 def inference(image, audio, video, task_type, instruction):
@@ -120,6 +157,11 @@ def inference(image, audio, video, task_type, instruction):
     if task_type == 'Image Captioning':
         text = ['']
         text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
+        model = model_caption
+    elif task_type == 'Visual Question Answering':
+        question = instruction+'?'+special_answer_token
+        text_input = tokenizer(question, padding='longest', return_tensors="pt").to(device)
+        model = model_vqa
     else:
         raise NotImplemented
 
@@ -139,10 +181,15 @@ def inference(image, audio, video, task_type, instruction):
     out = model(image=image, text=text_input, mode='generate', return_dict=True, max_length=max_length,
                 do_sample=do_sample, num_beams=num_beams)
 
-    out_decode = []
-    for i, o in enumerate(out):
-        res = tokenizer.decode(o)
-        response = res.split('</s>')[1].replace(pad_token, '').replace('</s>', '').replace(eos_token, '') # skip_special_tokens=True
+
+    if 'Captioning' in task_type:
+        for i, o in enumerate(out):
+            res = tokenizer.decode(o)
+            response = res.split('</s>')[1].replace(pad_token, '').replace('</s>', '').replace(eos_token, '') # skip_special_tokens=True
+    else:
+        for o in out:
+            o_list = o.tolist()
+            response = tokenizer.decode(o_list).split(special_answer_token)[1].replace(pad_token, '').replace('</s>', '').replace(eos_token, '') # skip_special_tokens=True
 
     return response
 
@@ -152,14 +199,14 @@ outputs = ['text']
 examples = [
     ['examples/images/soccer.jpg', None, None, 'Image Captioning', None],
    ['examples/images/ski.jpg', None, None, 'Visual Question Answering', 'what does the woman wearing black do?'],
-    ['examples/images/banana.jpg', None, None, 'Visual Grounding', 'the detached banana'],
-    ['examples/images/skateboard.jpg', None, None, 'General', 'which region does the text " a yellow bird " describe?'],
-    ['examples/images/baseball.jpg', None, None, 'General', 'what color is the left car?'],
+    ['examples/images/banana.jpg', None, None, 'Image Captioning', None],
+    ['examples/images/skateboard.jpg', None, None, 'Visual Question Answering', 'what is on top of the skateboard?'],
+    ['examples/images/baseball.jpg', None, None, 'Image Captioning', None],
     [None, None, 'examples/videos/video7014.mp4', 'Video Captioning', None],
     [None, None, 'examples/videos/video7017.mp4', 'Video Captioning', None],
     [None, None, 'examples/videos/video7019.mp4', 'Video Captioning', None],
     [None, None, 'examples/videos/video7021.mp4', 'Video Captioning', None],
-    [None, None, 'examples/videos/video7021.mp4', 'General Video', "What is this sport?"],
+    [None, None, 'examples/videos/video7021.mp4', 'Video Captioning', None],
     [None, 'examples/audios/6cS0FsUM-cQ.wav', None, 'Audio Captioning', None],
     [None, 'examples/audios/AJtNitYMa1I.wav', None, 'Audio Captioning', None],
 ]
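
A note on the new decoding path: the VQA branch prompts the model with the question followed by the added '</a>' special answer token, then splits the decoded generation on that token to keep only the answer. Below is a minimal, illustrative sketch of just that string handling, not part of the commit; the fake_generation string and the expected answer are made-up stand-ins for what model_vqa would actually produce, and only the tokenizer calls are real Hugging Face APIs.

from transformers import AutoTokenizer

# Same tokenizer setup as app.py after this commit.
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-2.7b', use_fast=False)
special_answer_token = '</a>'
tokenizer.add_special_tokens({'additional_special_tokens': [special_answer_token]})

# Prompt construction, as in the new 'Visual Question Answering' branch.
# Note that app.py appends '?' unconditionally, so an instruction that already
# ends with '?' gets a doubled question mark in the prompt.
instruction = 'what is on top of the skateboard?'   # from the updated examples list
question = instruction + '?' + special_answer_token

# Stand-in for the generated sequence: the prompt followed by an answer and EOS.
fake_generation = '</s>' + question + ' a yellow bird</s>'
generated_ids = tokenizer(fake_generation, add_special_tokens=False).input_ids

# Answer extraction, mirroring the new else-branch in inference().
decoded = tokenizer.decode(generated_ids)
response = (decoded.split(special_answer_token)[1]
            .replace(tokenizer.pad_token, '')
            .replace('</s>', '')
            .replace(tokenizer.eos_token, ''))
print(response.strip())   # roughly: a yellow bird

Everything before the '</a>' delimiter (the question prompt) is discarded by the split, which is why the VQA prompt has to end with that token; the captioning branch instead splits on the '</s>' marker because its prompt is empty.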