anicolson committed
Commit 0842d56
1 Parent(s): 797f142

Upload model

Files changed (1):
  modelling_variable.py  +1 -30
modelling_variable.py CHANGED
@@ -205,35 +205,6 @@ class VariableCXREncoderDecoderModel(VisionEncoderDecoderModel):
                 **kwargs_encoder,
             ) # CvT does not support output_attentions.
 
-            # Stack visual features from each study:
-            mbatch_size = len(set(dicom_study_ids))
-            max_images = dicom_study_ids.count(max(dicom_study_ids, key=dicom_study_ids.count))
-            feature_size = encoder_outputs.projected_last_hidden_state.shape[-1]
-            spatial_positions = encoder_outputs.projected_last_hidden_state.shape[-2]
-
-            # Create attention mask and visual features:
-            self.encoder_attention_mask = torch.zeros(mbatch_size, max_images * spatial_positions).to(self.device)
-            visual_features = torch.zeros(
-                mbatch_size,
-                max_images * spatial_positions,
-                feature_size,
-                dtype=encoder_outputs.projected_last_hidden_state.dtype,
-            ).to(self.device)
-
-            # There has to be a better way to do the following:
-            row_count, column_count = 0, 0
-            previous = dicom_study_ids[0]
-            for i, j in enumerate(dicom_study_ids):
-                if j != previous:
-                    row_count += 1
-                    column_count = 0
-                self.encoder_attention_mask[row_count, column_count:column_count + spatial_positions] = 1.0
-                visual_features[row_count, column_count:column_count + spatial_positions] = encoder_outputs.projected_last_hidden_state[i]
-                column_count += spatial_positions
-                previous = j
-
-            encoder_outputs.projected_last_hidden_state = visual_features
-
         elif isinstance(encoder_outputs, tuple):
             encoder_outputs = BaseModelOutput(*encoder_outputs)
 
@@ -243,7 +214,7 @@ class VariableCXREncoderDecoderModel(VisionEncoderDecoderModel):
             input_ids=decoder_input_ids,
             attention_mask=decoder_attention_mask,
             encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=self.encoder_attention_mask,
+            encoder_attention_mask=encoder_outputs.attention_mask,
             inputs_embeds=decoder_inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
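For reference, the deleted block grouped the per-image encoder features by study and padded them to the longest study, producing a 0/1 attention mask over the padded positions. Below is a minimal standalone sketch of what that loop computed; the function name stack_study_features and the (num_images, spatial_positions, feature_size) input shape are assumptions for illustration, and like the removed code it assumes images from the same study are adjacent in dicom_study_ids:

import torch

def stack_study_features(features: torch.Tensor, dicom_study_ids: list):
    """Group per-image features by study and pad to the longest study.

    features: (num_images, spatial_positions, feature_size).
    Returns (num_studies, max_images * spatial_positions, feature_size)
    and a 0/1 attention mask marking the non-padded positions.
    """
    num_studies = len(set(dicom_study_ids))
    # The padded width is set by the study with the most images:
    max_images = max(dicom_study_ids.count(s) for s in set(dicom_study_ids))
    _, spatial_positions, feature_size = features.shape

    # new_zeros keeps the dtype and device of the input features:
    stacked = features.new_zeros(num_studies, max_images * spatial_positions, feature_size)
    mask = features.new_zeros(num_studies, max_images * spatial_positions)

    row, col, previous = 0, 0, dicom_study_ids[0]
    for i, study_id in enumerate(dicom_study_ids):
        if study_id != previous:  # new study: move to the next row, restart at column 0
            row += 1
            col = 0
        mask[row, col:col + spatial_positions] = 1.0
        stacked[row, col:col + spatial_positions] = features[i]
        col += spatial_positions
        previous = study_id
    return stacked, mask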
 
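Net effect of the change: the cross-attention mask is no longer built inside forward() and stored on the module as self.encoder_attention_mask; the decoder now reads it from encoder_outputs.attention_mask, which implies the encoder is expected to return the mask alongside its projected features. A hedged sketch of an output container that would satisfy the new decoder call (the class name and fields are assumptions; the repository's actual encoder output class may differ):

import torch
from dataclasses import dataclass
from transformers.utils import ModelOutput

@dataclass
class ProjectedEncoderOutput(ModelOutput):
    # Assumed fields, matching the attributes the diff accesses:
    projected_last_hidden_state: torch.FloatTensor = None
    attention_mask: torch.FloatTensor = None

# For example, built from the stacking sketch above:
#   stacked, mask = stack_study_features(per_image_features, dicom_study_ids)
#   encoder_outputs = ProjectedEncoderOutput(
#       projected_last_hidden_state=stacked,
#       attention_mask=mask,
#   )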