anicolson committed
Commit 6778ac9 · 1 Parent(s): 4cdadba

Upload model
Files changed (1)
  1. modelling_variable.py +13 -36
modelling_variable.py CHANGED
@@ -57,55 +57,36 @@ class VariableCvtWithProjectionHead(transformers.CvtPreTrainedModel):
     def forward(
         self,
         pixel_values: Optional[torch.Tensor] = None,
-        dicom_study_ids: Optional[Any] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, ModelOutputWithProjectionEmbedding]:

         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        # Flatten the batch and study_id dimensions:
         outputs = self.cvt(
-            pixel_values,
+            pixel_values.view(-1, *pixel_values.shape[2:]),
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )

-        projection = self.projection_head(
-            torch.permute(torch.flatten(outputs.last_hidden_state, 2), [0, 2, 1]),
-        )
+        # Flatten h x w:
+        last_hidden_state = torch.flatten(outputs.last_hidden_state, 2)

-        # Stack visual features from each study:
-        mbatch_size = len(set(dicom_study_ids))
-        max_images = dicom_study_ids.count(max(dicom_study_ids, key=dicom_study_ids.count))
-        feature_size = projection.shape[-1]
-        spatial_positions = projection.shape[-2]
-
-        # Create attention mask and visual features:
-        attention_mask = torch.zeros(mbatch_size, max_images * spatial_positions).to(self.device)
-        visual_features = torch.zeros(
-            mbatch_size,
-            max_images * spatial_positions,
-            feature_size,
-            dtype=projection.dtype,
-        ).to(self.device)
+        # Project the features for each spatial position to the decoder's hidden size:
+        projection = self.projection_head(torch.permute(last_hidden_state, [0, 2, 1]))
+
+        # Concatenate the features for each chest X-ray:
+        projection = projection.view(pixel_values.shape[0], -1, projection.shape[-1])

-        # There has to be a better way to do the following:
-        row_count, column_count = 0, 0
-        previous = dicom_study_ids[0]
-        for i, j in enumerate(dicom_study_ids):
-            if j != previous:
-                row_count += 1
-                column_count = 0
-            attention_mask[row_count, column_count:column_count + spatial_positions] = 1.0
-            visual_features[row_count, column_count:column_count + spatial_positions] = projection[i]
-            column_count += spatial_positions
-            previous = j
+        # Derive the attention mask from the pixel values:
+        attention_mask = (pixel_values[:, :, 0, 0, 0] != 0.0).repeat_interleave(last_hidden_state.shape[-1], dim=1)

         if not return_dict:
-            return visual_features
+            return projection

         return ModelOutputWithProjectionEmbedding(
-            projected_last_hidden_state=visual_features, attention_mask=attention_mask,
+            projected_last_hidden_state=projection, attention_mask=attention_mask,
         )

@@ -171,7 +152,6 @@ class VariableCXREncoderDecoderModel(VisionEncoderDecoderModel):
     def forward(
         self,
         pixel_values: Optional[torch.FloatTensor] = None,
-        dicom_study_ids: Optional[Any] = None,
         decoder_input_ids: Optional[torch.LongTensor] = None,
         decoder_attention_mask: Optional[torch.BoolTensor] = None,
         encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,

@@ -199,7 +179,6 @@ class VariableCXREncoderDecoderModel(VisionEncoderDecoderModel):

         encoder_outputs = self.encoder(
             pixel_values,
-            dicom_study_ids=dicom_study_ids,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             **kwargs_encoder,

@@ -276,12 +255,10 @@ class VariableCXREncoderDecoderModel(VisionEncoderDecoderModel):
             'attention_mask': attention_mask,
             'decoder_attention_mask': decoder_attention_mask,
             'decoder_input_ids': decoder_inputs['input_ids'],
-            'dicom_study_ids': kwargs['dicom_study_ids'],
             'decoder_token_type_ids': token_type_ids,
             'encoder_outputs': encoder_outputs,
             'past_key_values': decoder_inputs['past_key_values'],
             'use_cache': use_cache,
-            'dicom_study_ids': kwargs['dicom_study_ids'],
         }
         return input_dict
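For context on what the rewritten VariableCvtWithProjectionHead.forward now appears to expect, below is a minimal, self-contained sketch of the same reshaping and mask derivation on dummy tensors. It assumes pixel_values is a 5-D tensor of shape [batch, max_images, channels, height, width] in which studies with fewer images are padded with all-zero images (that padding is what the pixel_values[:, :, 0, 0, 0] != 0.0 test detects). The toy Conv2d standing in for self.cvt, the Linear standing in for self.projection_head, and all sizes are illustrative assumptions, not the model's actual configuration.

# Illustrative sketch only: assumed shapes and toy stand-ins for self.cvt and
# self.projection_head; mirrors the reshaping/masking introduced in this commit.
import torch

batch_size, max_images = 2, 3           # assumption: each study zero-padded to 3 images
channels, height, width = 3, 64, 64     # assumed image size
encoder_dim, decoder_dim = 8, 16        # assumed encoder/decoder feature sizes

# Zero-padded batch: study 0 has 3 real images, study 1 has only 1.
pixel_values = torch.rand(batch_size, max_images, channels, height, width)
pixel_values[1, 1:] = 0.0

cvt = torch.nn.Conv2d(channels, encoder_dim, kernel_size=16, stride=16)  # toy encoder -> [N, 8, 4, 4]
projection_head = torch.nn.Linear(encoder_dim, decoder_dim)              # toy projection head

# Flatten the batch and study_id dimensions so each image is encoded separately:
last_hidden_state = cvt(pixel_values.view(-1, *pixel_values.shape[2:]))

# Flatten h x w into spatial positions: [batch * max_images, encoder_dim, 16]
last_hidden_state = torch.flatten(last_hidden_state, 2)

# Project each spatial position to the decoder's hidden size:
projection = projection_head(torch.permute(last_hidden_state, [0, 2, 1]))

# Concatenate the spatial positions of all images belonging to a study:
projection = projection.view(pixel_values.shape[0], -1, projection.shape[-1])

# Padding images are all zeros, so their first pixel is 0.0; repeat each
# per-image flag once per spatial position to mask the padded positions.
attention_mask = (pixel_values[:, :, 0, 0, 0] != 0.0).repeat_interleave(last_hidden_state.shape[-1], dim=1)

print(projection.shape)        # torch.Size([2, 48, 16]): 3 images x 16 positions each
print(attention_mask.sum(1))   # tensor([48, 16]): only positions from real images are unmasked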