Upload modeling_jasper_vl.py
Browse files- modeling_jasper_vl.py +1 -20
modeling_jasper_vl.py
CHANGED
@@ -1132,7 +1132,6 @@ class JasperVL(PreTrainedModel):
|
|
1132 |
self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(
|
1133 |
(self.config.num_img_tokens, config.text_config.hidden_size)
|
1134 |
)
|
1135 |
-
self.vec_dropout = nn.Dropout1d(config.vector_dropout_p, inplace=True)
|
1136 |
|
1137 |
self.vector_linear_12288 = nn.Linear(config.text_config.hidden_size, 12288, bias=True)
|
1138 |
self.vector_linear_1024 = nn.Linear(config.text_config.hidden_size, 1024, bias=True)
|
@@ -1164,25 +1163,11 @@ class JasperVL(PreTrainedModel):
|
|
1164 |
)[0]
|
1165 |
else:
|
1166 |
inputs_embeds = self.model.embed_tokens(input_ids)
|
1167 |
-
# print("inputs_embeds.shape", inputs_embeds.shape)
|
1168 |
B, N, C = inputs_embeds.shape
|
1169 |
inputs_embeds = inputs_embeds.reshape(B * N, C)
|
1170 |
-
|
1171 |
-
vit_embeds = self.vision_model(pixel_values=pixel_values)["last_hidden_state"]
|
1172 |
-
# print("vit_embeds.shape", vit_embeds.shape)
|
1173 |
vit_embeds = self.adaptive_avg_pool2d(vit_embeds)
|
1174 |
-
# print("vit_embeds_adapt.shape", vit_embeds.shape)
|
1175 |
-
# 拼接start 和 end
|
1176 |
-
# vit_embeds = torch.cat(
|
1177 |
-
# (self.vs_token_emb.expand((B, 1, C)), vit_embeds, self.ve_token_emb.expand((B, 1, C))),
|
1178 |
-
# dim=1,
|
1179 |
-
# )
|
1180 |
-
# print("vit_embeds_adapt_cat.shape", vit_embeds.shape)
|
1181 |
-
# TODO vis start和 vis end都用img_token_id代替来简化代码
|
1182 |
selected = (input_ids.reshape(B * N) == self.config.img_token_id)
|
1183 |
-
# print("selected.shape", selected.shape)
|
1184 |
-
# print("selected[:4]", selected[:4])
|
1185 |
-
# print("selected[285:305]", selected[285:305])
|
1186 |
inputs_embeds[selected] = vit_embeds.reshape(-1, C)
|
1187 |
inputs_embeds = inputs_embeds.reshape(B, N, C)
|
1188 |
last_hidden_state = self.model(
|
@@ -1190,9 +1175,6 @@ class JasperVL(PreTrainedModel):
|
|
1190 |
attention_mask=attention_mask,
|
1191 |
)[0]
|
1192 |
last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
|
1193 |
-
# 默认padding side是right.保留第一个避免全0
|
1194 |
-
|
1195 |
-
self.vec_dropout(last_hidden[:, 1:, :])
|
1196 |
mean_last_hidden = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
1197 |
|
1198 |
vectors_12288 = self.vector_linear_12288(mean_last_hidden)
|
@@ -1210,4 +1192,3 @@ class JasperVL(PreTrainedModel):
|
|
1210 |
"sentence_embedding": sentence_embedding,
|
1211 |
"all_vectors": [vectors_12288, vectors_1024, vectors_512, vectors_256],
|
1212 |
}
|
1213 |
-
|
|
|
1132 |
self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(
|
1133 |
(self.config.num_img_tokens, config.text_config.hidden_size)
|
1134 |
)
|
|
|
1135 |
|
1136 |
self.vector_linear_12288 = nn.Linear(config.text_config.hidden_size, 12288, bias=True)
|
1137 |
self.vector_linear_1024 = nn.Linear(config.text_config.hidden_size, 1024, bias=True)
|
|
|
1163 |
)[0]
|
1164 |
else:
|
1165 |
inputs_embeds = self.model.embed_tokens(input_ids)
|
|
|
1166 |
B, N, C = inputs_embeds.shape
|
1167 |
inputs_embeds = inputs_embeds.reshape(B * N, C)
|
1168 |
+
vit_embeds = self.vision_model(pixel_values=pixel_values, return_dict=True)["last_hidden_state"]
|
|
|
|
|
1169 |
vit_embeds = self.adaptive_avg_pool2d(vit_embeds)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1170 |
selected = (input_ids.reshape(B * N) == self.config.img_token_id)
|
|
|
|
|
|
|
1171 |
inputs_embeds[selected] = vit_embeds.reshape(-1, C)
|
1172 |
inputs_embeds = inputs_embeds.reshape(B, N, C)
|
1173 |
last_hidden_state = self.model(
|
|
|
1175 |
attention_mask=attention_mask,
|
1176 |
)[0]
|
1177 |
last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
|
|
|
|
|
|
|
1178 |
mean_last_hidden = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
1179 |
|
1180 |
vectors_12288 = self.vector_linear_12288(mean_last_hidden)
|
|
|
1192 |
"sentence_embedding": sentence_embedding,
|
1193 |
"all_vectors": [vectors_12288, vectors_1024, vectors_512, vectors_256],
|
1194 |
}
|
|