Ozan Oktay commited on
Commit
8cc04d6
1 Parent(s): eccac43
Files changed (4) hide show
  1. config.json +31 -0
  2. configuration_cxrbert.py +16 -0
  3. modeling_cxrbert.py +128 -0
  4. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/tmp/hf_demo_may_11th",
3
+ "architectures": [
4
+ "CXRBertModel"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.25,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_cxrbert.CXRBertConfig",
9
+ "AutoModel": "modeling_cxrbert.CXRBertModel"
10
+ },
11
+ "classifier_dropout": null,
12
+ "gradient_checkpointing": false,
13
+ "hidden_act": "gelu",
14
+ "hidden_dropout_prob": 0.25,
15
+ "hidden_size": 768,
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": 3072,
18
+ "layer_norm_eps": 1e-12,
19
+ "max_position_embeddings": 512,
20
+ "model_type": "cxr-bert",
21
+ "num_attention_heads": 12,
22
+ "num_hidden_layers": 12,
23
+ "pad_token_id": 0,
24
+ "position_embedding_type": "absolute",
25
+ "projection_size": 128,
26
+ "torch_dtype": "float32",
27
+ "transformers_version": "4.18.0",
28
+ "type_vocab_size": 2,
29
+ "use_cache": true,
30
+ "vocab_size": 30522
31
+ }
configuration_cxrbert.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # ------------------------------------------------------------------------------------------
5
+
6
+ from typing import Any
7
+
8
+ from transformers import BertConfig
9
+
10
+
11
+ class CXRBertConfig(BertConfig):
12
+ model_type = "cxr-bert"
13
+
14
+ def __init__(self, projection_size: int = 128, **kwargs: Any) -> None:
15
+ super().__init__(**kwargs)
16
+ self.projection_size = projection_size
modeling_cxrbert.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
4
+ # ------------------------------------------------------------------------------------------
5
+
6
+ from typing import Any, Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+ from torch import Tensor as T
12
+ from transformers import BertForMaskedLM
13
+ from transformers.modeling_outputs import ModelOutput
14
+
15
+ from .configuration_cxrbert import CXRBertConfig
16
+
17
+ BERTTupleOutput = Tuple[T, T, T, T, T]
18
+
19
+ class CXRBertOutput(ModelOutput):
20
+ last_hidden_state: torch.FloatTensor
21
+ prediction_logits: torch.FloatTensor
22
+ cls_projected_embedding: Optional[torch.FloatTensor] = None
23
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
24
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
25
+
26
+
27
+ class BertProjectionHead(nn.Module):
28
+ '''
29
+ Projection head to be used with BERT CLS token, it's similar to `BertPredictionHeadTransform` in HuggingFace library.
30
+ :param config: CXRBertConfig
31
+ :return: (batch_size, output_size)
32
+ '''
33
+ def __init__(self, config: CXRBertConfig) -> None:
34
+ super().__init__()
35
+ self.dense_to_hidden = nn.Linear(config.hidden_size, config.projection_size)
36
+ self.transform_act_fn = nn.functional.gelu
37
+ self.LayerNorm = nn.LayerNorm(config.projection_size, eps=1e-12)
38
+ self.dense_to_output = nn.Linear(config.projection_size, config.projection_size)
39
+
40
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
41
+ hidden_states = self.dense_to_hidden(hidden_states)
42
+ hidden_states = self.transform_act_fn(hidden_states)
43
+ hidden_states = self.LayerNorm(hidden_states)
44
+ hidden_states = self.dense_to_output(hidden_states)
45
+
46
+ return hidden_states
47
+
48
+
49
+ class CXRBertModel(BertForMaskedLM):
50
+ """
51
+ Implements the CXR-BERT model outlined in the manuscript:
52
+ Boecking et al. "Making the Most of Text Semantics to Improve Biomedical Vision-Language Processing", 2022
53
+ https://arxiv.org/abs/2204.09817
54
+
55
+ Extends the HuggingFace BertForMaskedLM model by adding a separate projection head. The projection "[CLS]" token is used to align
56
+ the latent vectors of image and text modalities.
57
+ """
58
+
59
+ config_class = CXRBertConfig
60
+
61
+ def __init__(self, config: CXRBertConfig):
62
+ super().__init__(config)
63
+
64
+ self.cls_projection_head = BertProjectionHead(config)
65
+ self.init_weights()
66
+
67
+ def forward(
68
+ self,
69
+ input_ids: torch.Tensor,
70
+ attention_mask: torch.Tensor,
71
+ token_type_ids: Optional[torch.Tensor] = None,
72
+ position_ids: Optional[torch.Tensor] = None,
73
+ head_mask: Optional[torch.Tensor] = None,
74
+ inputs_embeds: Optional[torch.Tensor] = None,
75
+ output_attentions: Optional[bool] = None,
76
+ output_hidden_states: Optional[bool] = None,
77
+ output_cls_projected_embedding: Optional[bool] = None,
78
+ return_dict: Optional[bool] = None,
79
+ **kwargs: Any
80
+ ) -> Union[BERTTupleOutput, CXRBertOutput]:
81
+
82
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
83
+
84
+ bert_for_masked_lm_output = super().forward(input_ids=input_ids,
85
+ attention_mask=attention_mask,
86
+ token_type_ids=token_type_ids,
87
+ position_ids=position_ids,
88
+ head_mask=head_mask,
89
+ inputs_embeds=inputs_embeds,
90
+ output_attentions=output_attentions,
91
+ output_hidden_states=True,
92
+ return_dict=True)
93
+
94
+ last_hidden_state = bert_for_masked_lm_output.hidden_states[-1]
95
+ cls_projected_embedding = self.cls_projection_head(last_hidden_state[:, 0, :]) if output_cls_projected_embedding else None
96
+
97
+ if return_dict:
98
+ return CXRBertOutput(
99
+ last_hidden_state=last_hidden_state,
100
+ prediction_logits=bert_for_masked_lm_output.logits,
101
+ cls_projected_embedding=cls_projected_embedding,
102
+ hidden_states=bert_for_masked_lm_output.hidden_states if output_hidden_states else None,
103
+ attentions=bert_for_masked_lm_output.attentions,
104
+ )
105
+ else:
106
+ return (
107
+ last_hidden_state,
108
+ bert_for_masked_lm_output.logits,
109
+ cls_projected_embedding,
110
+ bert_for_masked_lm_output.hidden_states,
111
+ bert_for_masked_lm_output.attentions,)
112
+
113
+ def get_projected_text_embeddings(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
114
+ """
115
+ Returns l2-normalised projected cls token embeddings for the given input token ids and attention mask.
116
+ The joint latent space is trained using a contrastive objective between image and text data modalities.
117
+
118
+ :param input_ids: (batch_size, sequence_length)
119
+ :param attention_mask: (batch_size, sequence_length)
120
+ :return: (batch_size, projection_size)
121
+ """
122
+
123
+ outputs = self.forward(input_ids=input_ids, attention_mask=attention_mask,
124
+ output_cls_projected_embedding=True, return_dict=True)
125
+ assert isinstance(outputs, CXRBertOutput)
126
+
127
+ normalized_cls_embedding = F.normalize(outputs.cls_projected_embedding, dim=1)
128
+ return normalized_cls_embedding
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a41d2f4b33e5bdbacabecb8d34c4420163e2165b9fdc6915014e8263a0d0782b
3
+ size 438588639