pamessina committed
Commit b3ccc39
1 Parent(s): 1768035

try copying original CXR-BERT's source code to see if it works

config.json CHANGED
@@ -1,13 +1,12 @@
 {
-  "_name_or_path": "microsoft/BiomedVLP-CXR-BERT-specialized",
-  "_commit_hash": "6cfc310817fb7d86762d888ced1e3709c57ac578",
+  "_name_or_path": "pamessina/CXRFE",
   "architectures": [
     "CXRBertModel"
   ],
   "attention_probs_dropout_prob": 0.25,
   "auto_map": {
-    "AutoConfig": "microsoft/BiomedVLP-CXR-BERT-specialized--configuration_cxrbert.CXRBertConfig",
-    "AutoModel": "microsoft/BiomedVLP-CXR-BERT-specialized--modeling_cxrbert.CXRBertModel"
+    "AutoConfig": "pamessina/CXRFE--configuration_cxrbert.CXRBertConfig",
+    "AutoModel": "pamessina/CXRFE--modeling_cxrbert.CXRBertModel"
   },
   "classifier_dropout": null,
   "gradient_checkpointing": false,
configuration_cxrbert.py ADDED
@@ -0,0 +1,29 @@
+# Source: https://huggingface.co/microsoft/BiomedVLP-CXR-BERT-specialized/blob/refs%2Fpr%2F5/configuration_cxrbert.py
+# Date: 2024-07-01
+
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
+from typing import Any
+
+from transformers import BertConfig, BertTokenizer
+
+
+class CXRBertConfig(BertConfig):
+    """
+    Config class for CXR-BERT model.
+    :param projection_size: Dimensionality of the joint latent space.
+    """
+
+    model_type = "cxr-bert"
+
+    def __init__(self, projection_size: int = 128, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.projection_size = projection_size
+
+
+class CXRBertTokenizer(BertTokenizer):
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
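
CXRBertConfig only adds a projection_size hyperparameter (default 128) on top of the standard BertConfig; everything else is forwarded unchanged. A small illustrative sketch, assuming configuration_cxrbert.py is importable from the local working directory:

from configuration_cxrbert import CXRBertConfig

# projection_size sets the width of the joint image-text latent space;
# all other keyword arguments go straight to BertConfig.
config = CXRBertConfig(projection_size=128, hidden_size=768)
print(config.model_type)       # "cxr-bert"
print(config.projection_size)  # 128
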
modeling_cxrbert.py ADDED
@@ -0,0 +1,132 @@
+# Source: https://huggingface.co/microsoft/BiomedVLP-CXR-BERT-specialized/blob/refs%2Fpr%2F5/modeling_cxrbert.py
+# Date: 2024-07-01
+
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
+from typing import Any, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch import Tensor as T
+from transformers import BertForMaskedLM
+from transformers.modeling_outputs import ModelOutput
+
+from .configuration_cxrbert import CXRBertConfig
+
+from dataclasses import dataclass  # manually added due to this bug: https://github.com/huggingface/transformers/issues/30412
+
+BERTTupleOutput = Tuple[T, T, T, T, T]
+
+@dataclass  # manually added due to this bug: https://github.com/huggingface/transformers/issues/30412
+class CXRBertOutput(ModelOutput):
+    last_hidden_state: torch.FloatTensor = None  # None added. Not present in the original code
+    logits: torch.FloatTensor = None  # None added. Not present in the original code
+    cls_projected_embedding: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class BertProjectionHead(nn.Module):
+    '''
+    Projection head to be used with BERT CLS token, it's similar to `BertPredictionHeadTransform` in HuggingFace library.
+    :param config: CXRBertConfig
+    :return: (batch_size, output_size)
+    '''
+    def __init__(self, config: CXRBertConfig) -> None:
+        super().__init__()
+        self.dense_to_hidden = nn.Linear(config.hidden_size, config.projection_size)
+        self.transform_act_fn = nn.functional.gelu
+        self.LayerNorm = nn.LayerNorm(config.projection_size, eps=1e-12)
+        self.dense_to_output = nn.Linear(config.projection_size, config.projection_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense_to_hidden(hidden_states)
+        hidden_states = self.transform_act_fn(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states)
+        hidden_states = self.dense_to_output(hidden_states)
+
+        return hidden_states
+
+
+class CXRBertModel(BertForMaskedLM):
+    """
+    Implements the CXR-BERT model outlined in the manuscript:
+    Boecking et al. "Making the Most of Text Semantics to Improve Biomedical Vision-Language Processing", 2022
+    https://arxiv.org/abs/2204.09817
+    Extends the HuggingFace BertForMaskedLM model by adding a separate projection head. The projection "[CLS]" token is used to align
+    the latent vectors of image and text modalities.
+    """
+
+    config_class = CXRBertConfig
+
+    def __init__(self, config: CXRBertConfig):
+        super().__init__(config)
+
+        self.cls_projection_head = BertProjectionHead(config)
+        self.init_weights()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_cls_projected_embedding: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs: Any
+    ) -> Union[BERTTupleOutput, CXRBertOutput]:
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        bert_for_masked_lm_output = super().forward(input_ids=input_ids,
+                                                    attention_mask=attention_mask,
+                                                    token_type_ids=token_type_ids,
+                                                    position_ids=position_ids,
+                                                    head_mask=head_mask,
+                                                    inputs_embeds=inputs_embeds,
+                                                    output_attentions=output_attentions,
+                                                    output_hidden_states=True,
+                                                    return_dict=True)
+
+        last_hidden_state = bert_for_masked_lm_output.hidden_states[-1]
+        cls_projected_embedding = self.cls_projection_head(last_hidden_state[:, 0, :]) if output_cls_projected_embedding else None
+
+        if return_dict:
+            return CXRBertOutput(
+                last_hidden_state=last_hidden_state,
+                logits=bert_for_masked_lm_output.logits,
+                cls_projected_embedding=cls_projected_embedding,
+                hidden_states=bert_for_masked_lm_output.hidden_states if output_hidden_states else None,
+                attentions=bert_for_masked_lm_output.attentions,
+            )
+        else:
+            return (
+                last_hidden_state,
+                bert_for_masked_lm_output.logits,
+                cls_projected_embedding,
+                bert_for_masked_lm_output.hidden_states,
+                bert_for_masked_lm_output.attentions,)
+
+    def get_projected_text_embeddings(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+        """
+        Returns l2-normalised projected cls token embeddings for the given input token ids and attention mask.
+        The joint latent space is trained using a contrastive objective between image and text data modalities.
+        :param input_ids: (batch_size, sequence_length)
+        :param attention_mask: (batch_size, sequence_length)
+        :return: (batch_size, projection_size)
+        """
+
+        outputs = self.forward(input_ids=input_ids, attention_mask=attention_mask,
+                               output_cls_projected_embedding=True, return_dict=True)
+        assert isinstance(outputs, CXRBertOutput)
+
+        normalized_cls_embedding = F.normalize(outputs.cls_projected_embedding, dim=1)
+        return normalized_cls_embedding
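
get_projected_text_embeddings is the natural entry point for the contrastive text embeddings: it runs the forward pass with output_cls_projected_embedding=True and l2-normalises the projected [CLS] vector. A usage sketch, assuming the tokenizer and model load from this repo with trust_remote_code=True and the default projection_size of 128 (not verified against this commit); the report strings are placeholder examples:

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("pamessina/CXRFE", trust_remote_code=True)
model = AutoModel.from_pretrained("pamessina/CXRFE", trust_remote_code=True)
model.eval()

reports = ["No acute cardiopulmonary process.", "There is a left pleural effusion."]
inputs = tokenizer(reports, padding=True, return_tensors="pt")

with torch.no_grad():
    # Each row is an l2-normalised vector in the joint latent space,
    # shape (batch_size, projection_size), i.e. (2, 128) here.
    embeddings = model.get_projected_text_embeddings(
        input_ids=inputs.input_ids, attention_mask=inputs.attention_mask
    )
print(embeddings.shape)
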
tokenizer_config.json CHANGED
@@ -43,7 +43,7 @@
   },
   "auto_map": {
     "AutoTokenizer": [
-      "microsoft/BiomedVLP-CXR-BERT-specialized--configuration_cxrbert.CXRBertTokenizer",
+      "pamessina/CXRFE--configuration_cxrbert.CXRBertTokenizer",
       null
     ]
   },
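
As with the model, the tokenizer's "auto_map" now resolves to the CXRBertTokenizer class defined in configuration_cxrbert.py inside this repo. A quick check, assuming remote code loading is enabled (illustrative, untested):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("pamessina/CXRFE", trust_remote_code=True)
print(type(tokenizer).__name__)  # expected: CXRBertTokenizer
tokens = tokenizer("Heart size is normal.", return_tensors="pt")
print(tokens.input_ids.shape)
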