Upload model
modelling_variable.py +13 -36
modelling_variable.py
CHANGED
@@ -57,55 +57,36 @@ class VariableCvtWithProjectionHead(transformers.CvtPreTrainedModel):
     def forward(
         self,
         pixel_values: Optional[torch.Tensor] = None,
-        dicom_study_ids: Optional[Any] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, ModelOutputWithProjectionEmbedding]:
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        # Flatten the batch and study_id dimensions:
         outputs = self.cvt(
-            pixel_values,
+            pixel_values.view(-1, *pixel_values.shape[2:]),
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
 
-
-
-        )
+        # Flatten h x w:
+        last_hidden_state = torch.flatten(outputs.last_hidden_state, 2)
 
-        #
-
-
-
-
-
-        # Create attention mask and visual features:
-        attention_mask = torch.zeros(mbatch_size, max_images * spatial_positions).to(self.device)
-        visual_features = torch.zeros(
-            mbatch_size,
-            max_images * spatial_positions,
-            feature_size,
-            dtype=projection.dtype,
-        ).to(self.device)
+        # Project the features for each spatial position to the decoder's hidden size:
+        projection = self.projection_head(torch.permute(last_hidden_state, [0, 2, 1]))
+
+        # Concatenate the features for each chest X-ray:
+        projection = projection.view(pixel_values.shape[0], -1, projection.shape[-1])
 
-        #
-
-        previous = dicom_study_ids[0]
-        for i, j in enumerate(dicom_study_ids):
-            if j != previous:
-                row_count += 1
-                column_count = 0
-            attention_mask[row_count, column_count:column_count + spatial_positions] = 1.0
-            visual_features[row_count, column_count:column_count + spatial_positions] = projection[i]
-            column_count += spatial_positions
-            previous = j
+        # Derive the attention mask from the pixel values:
+        attention_mask = (pixel_values[:, :, 0, 0, 0] != 0.0).repeat_interleave(last_hidden_state.shape[-1], dim=1)
 
         if not return_dict:
-            return
+            return projection
 
         return ModelOutputWithProjectionEmbedding(
-            projected_last_hidden_state=
+            projected_last_hidden_state=projection, attention_mask=attention_mask,
         )
 
 
@@ -171,7 +152,6 @@ class VariableCXREncoderDecoderModel(VisionEncoderDecoderModel):
     def forward(
         self,
         pixel_values: Optional[torch.FloatTensor] = None,
-        dicom_study_ids: Optional[Any] = None,
         decoder_input_ids: Optional[torch.LongTensor] = None,
         decoder_attention_mask: Optional[torch.BoolTensor] = None,
         encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
@@ -199,7 +179,6 @@ class VariableCXREncoderDecoderModel(VisionEncoderDecoderModel):
 
         encoder_outputs = self.encoder(
             pixel_values,
-            dicom_study_ids=dicom_study_ids,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             **kwargs_encoder,
@@ -276,12 +255,10 @@ class VariableCXREncoderDecoderModel(VisionEncoderDecoderModel):
             'attention_mask': attention_mask,
             'decoder_attention_mask': decoder_attention_mask,
             'decoder_input_ids': decoder_inputs['input_ids'],
-            'dicom_study_ids': kwargs['dicom_study_ids'],
             'decoder_token_type_ids': token_type_ids,
             'encoder_outputs': encoder_outputs,
             'past_key_values': decoder_inputs['past_key_values'],
             'use_cache': use_cache,
-            'dicom_study_ids': kwargs['dicom_study_ids'],
         }
         return input_dict
 
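For readers following the first hunk: the rewritten encoder forward appears to assume that pixel_values arrives as a zero-padded 5-D tensor of shape [mini-batch, max_images, channels, height, width], flattens the image dimension into the batch dimension for the CvT backbone, projects and concatenates the spatial features of every image in a study, and masks the positions that come from all-zero padding images. The sketch below traces only that shape flow; every size, and the nn.Linear used in place of self.cvt and self.projection_head, is an illustrative assumption rather than the repository's actual configuration.

# Illustrative sketch only: dummy sizes and a dummy nn.Linear stand in for the
# repository's CvT backbone (self.cvt) and projection head (self.projection_head).
import torch
import torch.nn as nn

mbatch_size, max_images = 2, 3           # studies per mini-batch, images per study (zero-padded)
channels, height, width = 3, 384, 384    # assumed image size
encoder_dim, h, w = 384, 12, 12          # assumed CvT feature size and spatial grid
decoder_hidden_size = 768                # assumed decoder hidden size

# Zero-padded 5-D input: studies with fewer than max_images images are padded with all-zero images.
pixel_values = 0.1 + torch.rand(mbatch_size, max_images, channels, height, width)
pixel_values[0, 2] = 0.0                 # study 0 has only 2 images
pixel_values[1, 1:] = 0.0                # study 1 has only 1 image

# 1. Flatten the batch and study_id dimensions before the backbone:
flat = pixel_values.view(-1, *pixel_values.shape[2:])                 # [B*N, C, H, W]

# Stand-in for outputs.last_hidden_state returned by self.cvt(...):
last_hidden_state = torch.rand(flat.shape[0], encoder_dim, h, w)      # [B*N, D, h, w]

# 2. Flatten h x w into a single spatial axis:
last_hidden_state = torch.flatten(last_hidden_state, 2)               # [B*N, D, h*w]

# 3. Project each spatial position to the decoder's hidden size:
projection_head = nn.Linear(encoder_dim, decoder_hidden_size)         # stand-in for self.projection_head
projection = projection_head(torch.permute(last_hidden_state, [0, 2, 1]))       # [B*N, h*w, H_dec]

# 4. Concatenate the features of all images in each study:
projection = projection.view(pixel_values.shape[0], -1, projection.shape[-1])   # [B, N*h*w, H_dec]

# 5. Mask the positions that come from all-zero padding images, as the new code does:
attention_mask = (pixel_values[:, :, 0, 0, 0] != 0.0).repeat_interleave(last_hidden_state.shape[-1], dim=1)

print(projection.shape)       # torch.Size([2, 432, 768])
print(attention_mask.shape)   # torch.Size([2, 432])
print(attention_mask.sum(1))  # tensor([288, 144]): 2 and 1 real images x 144 positions

If that reading is right, the design choice is to pack a variable number of chest X-rays per study into a fixed-size tensor and recover the padding from the pixel values themselves, which is why dicom_study_ids no longer needs to be threaded through forward and prepare_inputs_for_generation.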