anicolson committed
Commit 6e83530
1 Parent(s): ee5c3e0

Upload model

Files changed (2)
  1. config.json +3 -0
  2. modelling_single.py +410 -0
config.json CHANGED
@@ -3,6 +3,9 @@
   "architectures": [
     "SingleCXREncoderDecoderModel"
   ],
+  "auto_map": {
+    "AutoModel": "modelling_single.SingleCXREncoderDecoderModel"
+  },
   "decoder": {
     "_name_or_path": "",
     "add_cross_attention": true,
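The new auto_map entry registers the custom class with the Auto API, so the checkpoint can be loaded straight from the Hub once remote code execution is allowed. A minimal sketch, assuming a placeholder repository id:

from transformers import AutoModel

# 'user/repo' is a placeholder; trust_remote_code=True lets Transformers import
# SingleCXREncoderDecoderModel from modelling_single.py in the repository.
model = AutoModel.from_pretrained('user/repo', trust_remote_code=True)
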
modelling_single.py ADDED
@@ -0,0 +1,410 @@
+import os
+from typing import Any, Optional, Tuple, Union
+
+import torch
+import transformers
+from torch.nn import CrossEntropyLoss
+from transformers import VisionEncoderDecoderModel
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.vision_encoder_decoder.configuration_vision_encoder_decoder import \
+    VisionEncoderDecoderConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class CvtWithProjectionHeadConfig(transformers.CvtConfig):
+    def __init__(self, projection_size: int = None, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.projection_size = projection_size
+
+
+class ModelOutputWithProjectionEmbedding(transformers.modeling_outputs.ModelOutput):
+    projected_last_hidden_state: torch.FloatTensor
+
+
+class CvtProjectionHead(torch.nn.Module):
+
+    def __init__(self, config) -> None:
+        super().__init__()
+
+        # https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/models/cvt/modeling_cvt.py#L657
+        self.layer_norm = torch.nn.LayerNorm(config.embed_dim[-1], eps=config.layer_norm_eps)
+
+        # No bias as following layer normalisation with bias:
+        self.projection = torch.nn.Linear(config.embed_dim[-1], config.projection_size, bias=False)
+
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.layer_norm(x)
+        x = self.projection(x)
+        return x
+
+
+class CvtWithProjectionHead(transformers.CvtPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.cvt = transformers.CvtModel(config, add_pooling_layer=False)
+        self.projection_head = CvtProjectionHead(config)
+
+        # Initialize weights and apply final processing:
+        self.post_init()
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, ModelOutputWithProjectionEmbedding]:
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.cvt(
+            pixel_values,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        projection = self.projection_head(
+            torch.permute(torch.flatten(outputs.last_hidden_state, 2), [0, 2, 1]),
+        )
+
+        if not return_dict:
+            return projection
+
+        return ModelOutputWithProjectionEmbedding(
+            projected_last_hidden_state=projection,
+        )
+
+
+class SingleCXREncoderDecoderModel(VisionEncoderDecoderModel):
+
+    config_class = VisionEncoderDecoderConfig
+    base_model_prefix = "vision_encoder_decoder"
+    main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
+
+    def __init__(
+        self,
+        config: Optional[PretrainedConfig] = None,
+        encoder: Optional[PreTrainedModel] = None,
+        decoder: Optional[PreTrainedModel] = None,
+    ):
+
+        if config is None and (encoder is None or decoder is None):
+            raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
+        if config is None:
+            config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
+        else:
+            if not isinstance(config, self.config_class):
+                raise ValueError(f"Config: {config} has to be of type {self.config_class}")
+
+        config.tie_word_embeddings = False
+
+        # initialize with config
+        PreTrainedModel.__init__(self, config)
+
+        # Encoder:
+        if encoder is None:
+            encoder = CvtWithProjectionHead(config=config.encoder)
+
+        # Decoder:
+        if decoder is None:
+            decoder = transformers.BertLMHeadModel(config=config.decoder)
+
+        self.encoder = encoder
+        self.decoder = decoder
+
+        if self.encoder.config.to_dict() != self.config.encoder.to_dict():
+            logger.warning(
+                f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
+                f" {self.config.encoder}"
+            )
+        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
+            logger.warning(
+                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+                f" {self.config.decoder}"
+            )
+
+        self.encoder.config = self.config.encoder
+        self.decoder.config = self.config.decoder
+
+        # config.add_cross_attention = True
+        # config.is_decoder = True
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
+
+        kwargs_decoder = {
+            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+        }
+
+        if encoder_outputs is None:
+            if pixel_values is None:
+                raise ValueError("You have to specify pixel_values")
+
+            encoder_outputs = self.encoder(
+                pixel_values,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                **kwargs_encoder,
+            )  # CvT does not support output_attentions.
+        elif isinstance(encoder_outputs, tuple):
+            encoder_outputs = BaseModelOutput(*encoder_outputs)
+
+        encoder_hidden_states = encoder_outputs[0]
+        encoder_attention_mask = None
+
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            inputs_embeds=decoder_inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
+            past_key_values=past_key_values,
+            return_dict=return_dict,
+            **kwargs_decoder,
+        )
+
+        # Loss:
+        loss = None
+        if labels is not None:
+            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))
+
+        if not return_dict:
+            if loss is not None:
+                return (loss,) + decoder_outputs + encoder_outputs
+            else:
+                return decoder_outputs + encoder_outputs
+
+        return Seq2SeqLMOutput(
+            loss=loss,
+            logits=decoder_outputs.logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.projected_last_hidden_state,
+            # encoder_hidden_states=encoder_outputs.hidden_states,
+            # encoder_attentions=encoder_outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        special_token_ids,
+        past_key_values=None,
+        attention_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs,
+    ):
+        """
+        Modification of:
+            https://github.com/huggingface/transformers/blob/main/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py#L660
+        """
+
+        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
+        decoder_attention_mask = decoder_inputs['attention_mask'] if 'attention_mask' in decoder_inputs else None
+
+        if not past_key_values:
+            token_type_ids = self.token_ids_to_token_type_ids(input_ids, special_token_ids)
+        else:
+            token_type_ids = self.token_ids_to_token_type_ids_past(input_ids, special_token_ids)
+
+        input_dict = {
+            'attention_mask': attention_mask,
+            'decoder_attention_mask': decoder_attention_mask,
+            'decoder_input_ids': decoder_inputs['input_ids'],
+            'decoder_token_type_ids': token_type_ids,
+            'encoder_outputs': encoder_outputs,
+            'past_key_values': decoder_inputs['past_key_values'],
+            'use_cache': use_cache,
+        }
+        return input_dict
+
+    def token_ids_to_token_type_ids(self, token_ids, special_token_ids, token_type_id_sections=None):
+        """
+        Extract token type identifiers from the token identifiers.
+
+        Argument/s:
+            token_ids - token identifiers.
+            special_token_ids - special token identifiers that indicate the separation between sections.
+            token_type_id_sections - token type identifier for each section.
+
+        Returns:
+            token_type_ids - token type identifiers.
+        """
+
+        token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))
+
+        mbatch_size, seq_len = token_ids.shape
+        token_type_ids = torch.full_like(token_ids, token_type_id_sections[0], dtype=torch.long, device=token_ids.device)
+
+        for i, j in enumerate(special_token_ids):
+            # Find first occurrence of special tokens that indicate the boundary between sections:
+            cols = (token_ids == j).int().argmax(dim=1)
+            rows = torch.arange(mbatch_size, device=token_ids.device)
+
+            # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
+            cols += 1
+
+            # Ensure that the column index is not out of bounds. If 0, then token_id not present.
+            # This is safe as index 0 is always a special token (now equal to 1 due to +1):
+            rows = rows[torch.logical_and(cols != 1, cols < seq_len)]
+            cols = cols[torch.logical_and(cols != 1, cols < seq_len)]
+
+            # Indices that correspond to the second sequence:
+            if rows.nelement() != 0:
+                ids = torch.stack([
+                    torch.stack([x, z]) for (x, y) in zip(rows, cols) for z in torch.arange(
+                        y, seq_len, device=token_ids.device,
+                    )
+                ])
+
+                token_type_ids[ids[:, 0], ids[:, 1]] = token_type_id_sections[i + 1]
+
+        return token_type_ids
+
+    def token_ids_to_token_type_ids_past(self, token_ids, special_token_ids, token_type_id_sections=None):
+        """
+        Extract token type identifiers from the token identifiers if past != None.
+
+        Argument/s:
+            token_ids - token identifiers.
+            special_token_ids - special token identifiers that indicate the separation between sections.
+
+        Returns:
+            token_type_ids - token type identifiers.
+        """
+
+        token_type_id_sections = token_type_id_sections if token_type_id_sections is not None else list(range(len(special_token_ids) + 1))
+        token_type_ids = torch.full([token_ids.shape[0], 1], token_type_id_sections[0], dtype=torch.long, device=token_ids.device)
+
+        # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
+        token_ids = token_ids[:, :-1]
+
+        for i, j in enumerate(special_token_ids):
+
+            # Find first occurrence of special token, which indicates the boundary between sections:
+            exists = torch.any(token_ids == j, dim=1, keepdim=True)
+            token_type_ids[exists] = token_type_id_sections[i + 1]
+
+        return token_type_ids
+
+    def tokenize_report_teacher_forcing(self, findings: str, impression: str, tokenizer):
+        """
+        Tokenize the reports and create the inputs and targets for teacher forcing.
+
+        Argument/s:
+            findings - findings section.
+            impression - impression section.
+            return_token_type_ids - return the token type identifiers.
+            tokenizer - Hugging Face tokenizer.
+
+        Returns:
+            decoder_input_ids - the token identifiers for the input of the decoder.
+            decoder_attention_mask - the attention mask for the decoder_input_ids.
+            label_ids - the label token identifiers for the decoder.
+        """
+
+        # Prepare the sections for the tokenizer by placing special tokens between each section:
+        report = [f'{tokenizer.bos_token}{i}{tokenizer.sep_token}{j}{tokenizer.eos_token}' for i, j in
+                  zip(findings, impression)]
+
+        # Tokenize the report:
+        tokenized = tokenizer(
+            report,
+            padding='longest',
+            truncation=True,
+            max_length=self.decoder_max_len + 1,  # +1 to account for the bias between input and target.
+            return_tensors='pt',
+            return_token_type_ids=False,
+            add_special_tokens=False,
+        ).to(self.device)
+
+        # Modify for language modelling:
+        batch_dict = {
+
+            # Labels for the decoder (shifted right by one for autoregression):
+            'label_ids': tokenized['input_ids'][:, 1:].detach().clone(),
+
+            # Remove last token identifier to match the sequence length of the labels:
+            'decoder_input_ids': tokenized['input_ids'][:, :-1],
+
+            # Attention mask for the decoder_input_ids (remove first token so that the eos_token_id is not considered):
+            'decoder_attention_mask': tokenized['attention_mask'][:, 1:],
+        }
+
+        return batch_dict
+
+    def split_and_decode_sections(self, token_ids, special_token_ids, tokenizer):
+        """
+        Split the token identifiers into sections, then convert the token identifiers into strings.
+
+        Argument/s:
+            token_ids - token identifiers.
+            special_token_ids - special token identifiers that indicate the end of each section.
+            tokenizer - Hugging Face tokenizer.
+
+        Returns:
+            sections - a tuple with a list of decoded strings for each section.
+        """
+
+        _, seq_len = token_ids.shape
+
+        # The number of sections is the same as the number of special_token_ids:
+        num_sections = len(special_token_ids)
+
+        sections = {k: [] for k in range(num_sections)}
+
+        for i in token_ids:
+            prev_col = 0
+            for j, k in enumerate(special_token_ids):
+
+                # The maximum sequence length was exceeded, thus no more tokens:
+                if prev_col >= seq_len:
+                    sections[j].append('')
+                    continue
+
+                # Find first occurrence of special tokens that indicate the boundary between sections:
+                col = (i == k).int().argmax().item()
+
+                # If equal to 0, token was not found, set the column to the sequence length (as the decoder exceeded
+                # the maximum sequence length):
+                if col == 0:
+                    col = seq_len
+
+                # Extract section token identifiers:
+                section_token_ids = i[prev_col:col]
+                prev_col = col
+                section_string = tokenizer.decode(section_token_ids, skip_special_tokens=True)
+
+                sections[j].append(section_string)
+
+        return tuple(sections.values())
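
For context, a minimal, untested usage sketch of the uploaded class follows (not part of the commit). The repository id, tokenizer special tokens, image resolution, and the decoder_max_len value are assumptions for illustration; decoder_max_len is used by tokenize_report_teacher_forcing but is not set anywhere in modelling_single.py, so it is assigned manually here.

import torch
from transformers import AutoModel, AutoTokenizer

ckpt = 'user/repo'  # Hypothetical Hub repository id containing this commit.

# The auto_map entry added to config.json resolves AutoModel to
# modelling_single.SingleCXREncoderDecoderModel when trust_remote_code=True.
model = AutoModel.from_pretrained(ckpt, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(ckpt)

# Assumption: set externally by the training/inference harness.
model.decoder_max_len = 256

# Dummy mini-batch; 384x384 is an assumed CvT input resolution.
pixel_values = torch.randn(1, 3, 384, 384)

# Teacher-forcing pass: findings and impression are lists of report sections.
batch = model.tokenize_report_teacher_forcing(
    findings=['Heart size is normal.'],
    impression=['No acute cardiopulmonary abnormality.'],
    tokenizer=tokenizer,
)
outputs = model(
    pixel_values=pixel_values,
    decoder_input_ids=batch['decoder_input_ids'],
    decoder_attention_mask=batch['decoder_attention_mask'],
    labels=batch['label_ids'],
)
print(float(outputs.loss))

# Generation: special_token_ids is forwarded to prepare_inputs_for_generation,
# and the sep/eos tokens mark the findings/impression section boundaries.
output_ids = model.generate(
    pixel_values=pixel_values,
    special_token_ids=[tokenizer.sep_token_id],
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    max_length=256,
)
findings, impression = model.split_and_decode_sections(
    output_ids,
    special_token_ids=[tokenizer.sep_token_id, tokenizer.eos_token_id],
    tokenizer=tokenizer,
)
print(findings[0], impression[0])

The special_token_ids keyword is what prepare_inputs_for_generation uses to derive BERT token type identifiers for each report section during decoding.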