Upload model
modelling_variable.py  +1 -30
modelling_variable.py  CHANGED
@@ -205,35 +205,6 @@ class VariableCXREncoderDecoderModel(VisionEncoderDecoderModel):
                 **kwargs_encoder,
             ) # CvT does not support output_attentions.
 
-            # Stack visual features from each study:
-            mbatch_size = len(set(dicom_study_ids))
-            max_images = dicom_study_ids.count(max(dicom_study_ids, key=dicom_study_ids.count))
-            feature_size = encoder_outputs.projected_last_hidden_state.shape[-1]
-            spatial_positions = encoder_outputs.projected_last_hidden_state.shape[-2]
-
-            # Create attention mask and visual features:
-            self.encoder_attention_mask = torch.zeros(mbatch_size, max_images * spatial_positions).to(self.device)
-            visual_features = torch.zeros(
-                mbatch_size,
-                max_images * spatial_positions,
-                feature_size,
-                dtype=encoder_outputs.projected_last_hidden_state.dtype,
-            ).to(self.device)
-
-            # There has to be a better way to do the following:
-            row_count, column_count = 0, 0
-            previous = dicom_study_ids[0]
-            for i, j in enumerate(dicom_study_ids):
-                if j != previous:
-                    row_count += 1
-                    column_count = 0
-                self.encoder_attention_mask[row_count, column_count:column_count + spatial_positions] = 1.0
-                visual_features[row_count, column_count:column_count + spatial_positions] = encoder_outputs.projected_last_hidden_state[i]
-                column_count += spatial_positions
-                previous = j
-
-            encoder_outputs.projected_last_hidden_state = visual_features
-
         elif isinstance(encoder_outputs, tuple):
             encoder_outputs = BaseModelOutput(*encoder_outputs)
 
@@ -243,7 +214,7 @@ class VariableCXREncoderDecoderModel(VisionEncoderDecoderModel):
             input_ids=decoder_input_ids,
             attention_mask=decoder_attention_mask,
             encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=self.encoder_attention_mask,
+            encoder_attention_mask=encoder_outputs.attention_mask,
             inputs_embeds=decoder_inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
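
The second hunk swaps the mask built by that loop for encoder_outputs.attention_mask, so the call site now expects the encoder's output object to carry the per-study mask itself. A minimal sketch of the fields that expectation implies is given below; ProjectedEncoderOutput is a hypothetical name, not a class defined in this repository.

from dataclasses import dataclass
from typing import Optional

import torch
from transformers.utils import ModelOutput

@dataclass
class ProjectedEncoderOutput(ModelOutput):
    # Fields the updated forward() reads from encoder_outputs:
    projected_last_hidden_state: Optional[torch.FloatTensor] = None  # stacked per-study visual features
    attention_mask: Optional[torch.FloatTensor] = None  # 1.0 for real spatial positions, 0.0 for padding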