anicolson committed on
Commit
4aa737e
1 Parent(s): a9989b7

Upload folder using huggingface_hub

Files changed (1)
  1. modelling_medicap.py +2 -12
modelling_medicap.py CHANGED
@@ -173,7 +173,6 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
                 return_dict=return_dict,
                 **kwargs_encoder,
             )  # CvT does not support output_attentions.
-        assert decoder_inputs_embeds.shape[1] == 1
         decoder_inputs_embeds = torch.cat([encoder_outputs[0], decoder_inputs_embeds], dim=1)
         if decoder_attention_mask is not None:
             decoder_attention_mask = torch.cat(
@@ -182,8 +181,8 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
                     decoder_attention_mask
                 ],
                 dim=1,
-            )
-
+            )
+
         decoder_outputs = self.decoder(
             attention_mask=decoder_attention_mask,
             inputs_embeds=decoder_inputs_embeds,
@@ -249,15 +248,6 @@ class MedICapEncoderDecoderModel(VisionEncoderDecoderModel):
         input_dict['past_key_values'] = decoder_inputs['past_key_values']
         input_dict['decoder_attention_mask'] = decoder_inputs['attention_mask'] if 'attention_mask' in decoder_inputs else None
 
-        # if torch.is_tensor(decoder_attention_mask):
-        #     decoder_attention_mask = torch.cat(
-        #         [
-        #             torch.ones(encoder_outputs[0].shape[:-1], dtype=decoder_attention_mask.dtype, device=self.device),
-        #             decoder_attention_mask
-        #         ],
-        #         dim=1,
-        #     )
-
         return input_dict
 
     def tokenize_captions_teacher_forcing(
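
For context on the lines touched above: in this model's forward pass, the encoder's image features (encoder_outputs[0]) are prepended to the decoder's input embeddings, so the decoder attention mask has to be widened by the same number of positions. The snippet below is a minimal, self-contained sketch of that concatenation with made-up shapes (batch 2, 49 image tokens, 5 caption tokens, hidden size 768); it is not the model's actual forward or prepare_inputs_for_generation code.

import torch

# Illustrative shapes only (assumptions, not taken from the model config).
encoder_hidden_states = torch.randn(2, 49, 768)   # stands in for encoder_outputs[0] from the CvT encoder
decoder_inputs_embeds = torch.randn(2, 5, 768)    # token embeddings for the caption tokens
decoder_attention_mask = torch.ones(2, 5, dtype=torch.long)

# Prepend the image features so the decoder attends over them as a prompt.
decoder_inputs_embeds = torch.cat([encoder_hidden_states, decoder_inputs_embeds], dim=1)

# Widen the attention mask to cover the prepended image positions; this mirrors the
# commented-out block that the commit deletes from prepare_inputs_for_generation.
decoder_attention_mask = torch.cat(
    [
        torch.ones(encoder_hidden_states.shape[:-1], dtype=decoder_attention_mask.dtype),
        decoder_attention_mask,
    ],
    dim=1,
)

print(decoder_inputs_embeds.shape)   # torch.Size([2, 54, 768])
print(decoder_attention_mask.shape)  # torch.Size([2, 54])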
 