Upload folder using huggingface_hub
Browse files
- modelling_medicap.py +2 -12
modelling_medicap.py
CHANGED
@@ -173,7 +173,6 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
|
|
173 |
return_dict=return_dict,
|
174 |
**kwargs_encoder,
|
175 |
) # CvT does not support output_attentions.
|
176 |
-
assert decoder_inputs_embeds.shape[1] == 1
|
177 |
decoder_inputs_embeds = torch.cat([encoder_outputs[0], decoder_inputs_embeds], dim=1)
|
178 |
if decoder_attention_mask is not None:
|
179 |
decoder_attention_mask = torch.cat(
|
@@ -182,8 +181,8 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
|
|
182 |
decoder_attention_mask
|
183 |
],
|
184 |
dim=1,
|
185 |
-
)
|
186 |
-
|
187 |
decoder_outputs = self.decoder(
|
188 |
attention_mask=decoder_attention_mask,
|
189 |
inputs_embeds=decoder_inputs_embeds,
|
@@ -249,15 +248,6 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
|
|
249 |
input_dict['past_key_values'] = decoder_inputs['past_key_values']
|
250 |
input_dict['decoder_attention_mask'] = decoder_inputs['attention_mask'] if 'attention_mask' in decoder_inputs else None
|
251 |
|
252 |
-
# if torch.is_tensor(decoder_attention_mask):
|
253 |
-
# decoder_attention_mask = torch.cat(
|
254 |
-
# [
|
255 |
-
# torch.ones(encoder_outputs[0].shape[:-1], dtype=decoder_attention_mask.dtype, device=self.device),
|
256 |
-
# decoder_attention_mask
|
257 |
-
# ],
|
258 |
-
# dim=1,
|
259 |
-
# )
|
260 |
-
|
261 |
return input_dict
|
262 |
|
263 |
def tokenize_captions_teacher_forcing(
|
|
|
173 |
return_dict=return_dict,
|
174 |
**kwargs_encoder,
|
175 |
) # CvT does not support output_attentions.
|
|
|
176 |
decoder_inputs_embeds = torch.cat([encoder_outputs[0], decoder_inputs_embeds], dim=1)
|
177 |
if decoder_attention_mask is not None:
|
178 |
decoder_attention_mask = torch.cat(
|
|
|
181 |
decoder_attention_mask
|
182 |
],
|
183 |
dim=1,
|
184 |
+
)
|
185 |
+
|
186 |
decoder_outputs = self.decoder(
|
187 |
attention_mask=decoder_attention_mask,
|
188 |
inputs_embeds=decoder_inputs_embeds,
|
|
|
248 |
input_dict['past_key_values'] = decoder_inputs['past_key_values']
|
249 |
input_dict['decoder_attention_mask'] = decoder_inputs['attention_mask'] if 'attention_mask' in decoder_inputs else None
|
250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
return input_dict
|
252 |
|
253 |
def tokenize_captions_teacher_forcing(
|