infgrad committed on
Commit
d215750
1 Parent(s): 9ec969b

Upload modeling_jasper_vl.py

Browse files
Files changed (1) hide show
  1. modeling_jasper_vl.py +1 -20
modeling_jasper_vl.py CHANGED
@@ -1132,7 +1132,6 @@ class JasperVL(PreTrainedModel):
1132
  self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(
1133
  (self.config.num_img_tokens, config.text_config.hidden_size)
1134
  )
1135
- self.vec_dropout = nn.Dropout1d(config.vector_dropout_p, inplace=True)
1136
 
1137
  self.vector_linear_12288 = nn.Linear(config.text_config.hidden_size, 12288, bias=True)
1138
  self.vector_linear_1024 = nn.Linear(config.text_config.hidden_size, 1024, bias=True)
@@ -1164,25 +1163,11 @@ class JasperVL(PreTrainedModel):
1164
  )[0]
1165
  else:
1166
  inputs_embeds = self.model.embed_tokens(input_ids)
1167
- # print("inputs_embeds.shape", inputs_embeds.shape)
1168
  B, N, C = inputs_embeds.shape
1169
  inputs_embeds = inputs_embeds.reshape(B * N, C)
1170
-
1171
- vit_embeds = self.vision_model(pixel_values=pixel_values)["last_hidden_state"]
1172
- # print("vit_embeds.shape", vit_embeds.shape)
1173
  vit_embeds = self.adaptive_avg_pool2d(vit_embeds)
1174
- # print("vit_embeds_adapt.shape", vit_embeds.shape)
1175
- # 拼接start 和 end
1176
- # vit_embeds = torch.cat(
1177
- # (self.vs_token_emb.expand((B, 1, C)), vit_embeds, self.ve_token_emb.expand((B, 1, C))),
1178
- # dim=1,
1179
- # )
1180
- # print("vit_embeds_adapt_cat.shape", vit_embeds.shape)
1181
- # TODO vis start和 vis end都用img_token_id代替来简化代码
1182
  selected = (input_ids.reshape(B * N) == self.config.img_token_id)
1183
- # print("selected.shape", selected.shape)
1184
- # print("selected[:4]", selected[:4])
1185
- # print("selected[285:305]", selected[285:305])
1186
  inputs_embeds[selected] = vit_embeds.reshape(-1, C)
1187
  inputs_embeds = inputs_embeds.reshape(B, N, C)
1188
  last_hidden_state = self.model(
@@ -1190,9 +1175,6 @@ class JasperVL(PreTrainedModel):
1190
  attention_mask=attention_mask,
1191
  )[0]
1192
  last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
1193
- # 默认padding side是right.保留第一个避免全0
1194
-
1195
- self.vec_dropout(last_hidden[:, 1:, :])
1196
  mean_last_hidden = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
1197
 
1198
  vectors_12288 = self.vector_linear_12288(mean_last_hidden)
@@ -1210,4 +1192,3 @@ class JasperVL(PreTrainedModel):
1210
  "sentence_embedding": sentence_embedding,
1211
  "all_vectors": [vectors_12288, vectors_1024, vectors_512, vectors_256],
1212
  }
1213
-
 
1132
  self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(
1133
  (self.config.num_img_tokens, config.text_config.hidden_size)
1134
  )
 
1135
 
1136
  self.vector_linear_12288 = nn.Linear(config.text_config.hidden_size, 12288, bias=True)
1137
  self.vector_linear_1024 = nn.Linear(config.text_config.hidden_size, 1024, bias=True)
 
1163
  )[0]
1164
  else:
1165
  inputs_embeds = self.model.embed_tokens(input_ids)
 
1166
  B, N, C = inputs_embeds.shape
1167
  inputs_embeds = inputs_embeds.reshape(B * N, C)
1168
+ vit_embeds = self.vision_model(pixel_values=pixel_values, return_dict=True)["last_hidden_state"]
 
 
1169
  vit_embeds = self.adaptive_avg_pool2d(vit_embeds)
 
 
 
 
 
 
 
 
1170
  selected = (input_ids.reshape(B * N) == self.config.img_token_id)
 
 
 
1171
  inputs_embeds[selected] = vit_embeds.reshape(-1, C)
1172
  inputs_embeds = inputs_embeds.reshape(B, N, C)
1173
  last_hidden_state = self.model(
 
1175
  attention_mask=attention_mask,
1176
  )[0]
1177
  last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
 
 
 
1178
  mean_last_hidden = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
1179
 
1180
  vectors_12288 = self.vector_linear_12288(mean_last_hidden)
 
1192
  "sentence_embedding": sentence_embedding,
1193
  "all_vectors": [vectors_12288, vectors_1024, vectors_512, vectors_256],
1194
  }