File size: 5,818 Bytes
8cc04d6 f1f6ae0 8cc04d6 b990504 8cc04d6 f1f6ae0 8cc04d6 f1f6ae0 8cc04d6 f1f6ae0 8cc04d6 f1f6ae0 8cc04d6 f1f6ae0 8cc04d6 f1f6ae0 8cc04d6 f1f6ae0 8cc04d6 f1f6ae0 8cc04d6 fce0b41 8cc04d6 f1f6ae0 8cc04d6 f1f6ae0 8cc04d6 f1f6ae0 8cc04d6 f1f6ae0 8cc04d6 f1f6ae0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor as T
from transformers import BertForMaskedLM
from transformers.modeling_outputs import ModelOutput
from .configuration_cxrbert import CXRBertConfig
# Positional (non-dict) return type of ``CXRBertModel.forward``. The entries are, in order:
# (last_hidden_state, logits, cls_projected_embedding, hidden_states, attentions).
# NOTE(review): the alias declares every slot as a tensor, but ``forward`` stores ``None`` in the
# ``cls_projected_embedding`` slot when the projection is not requested -- consider loosening the alias.
BERTTupleOutput = Tuple[T, T, T, T, T]
@dataclass
class CXRBertOutput(ModelOutput):
    """Structured result of ``CXRBertModel.forward`` when a dict-style return is requested.

    Only ``last_hidden_state`` is mandatory; every other field defaults to ``None``.
    """

    # Final entry of the hidden-state tuple produced by the underlying masked-LM model.
    last_hidden_state: torch.FloatTensor
    # Masked-language-modelling logits taken from the underlying model's output.
    logits: Optional[torch.FloatTensor] = None
    # Projection-head output for the first token position; ``None`` unless it was requested.
    cls_projected_embedding: Optional[torch.FloatTensor] = None
    # Full tuple of hidden states; ``None`` unless it was requested.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Attention tensors exactly as reported by the underlying model.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
class BertProjectionHead(nn.Module):
    """Projection head meant to be applied to the BERT "[CLS]" token representation.

    The layer stack is linear -> GELU -> layer norm -> linear, which is similar to
    ``BertPredictionHeadTransform`` in HuggingFace.

    :param config: Configuration for BERT; ``hidden_size`` and ``projection_size`` are read from it.
    """

    def __init__(self, config: CXRBertConfig) -> None:
        super().__init__()
        input_width = config.hidden_size
        projection_width = config.projection_size
        # Submodule attribute names end up in the state-dict keys, so they are kept verbatim.
        self.dense_to_hidden = nn.Linear(input_width, projection_width)
        self.transform_act_fn = nn.functional.gelu
        self.LayerNorm = nn.LayerNorm(projection_width, eps=1e-12)
        self.dense_to_output = nn.Linear(projection_width, projection_width)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project ``hidden_states`` into the projection space.

        :param hidden_states: Tensor whose last dimension is ``config.hidden_size``.
        :return: Tensor whose last dimension is ``config.projection_size``.
        """
        widened = self.dense_to_hidden(hidden_states)
        activated = self.transform_act_fn(widened)
        normalised = self.LayerNorm(activated)
        return self.dense_to_output(normalised)
class CXRBertModel(BertForMaskedLM):
    """
    Implements the CXR-BERT model outlined in the manuscript:
    Boecking et al. "Making the Most of Text Semantics to Improve Biomedical Vision-Language Processing", 2022
    https://link.springer.com/chapter/10.1007/978-3-031-20059-5_1

    Extends the HuggingFace BertForMaskedLM model by adding a separate projection head. The projected "[CLS]" token
    is used to align the latent vectors of image and text modalities.
    """

    config_class = CXRBertConfig  # type: ignore

    def __init__(self, config: CXRBertConfig):
        """Build the masked-LM backbone and attach the "[CLS]" projection head.

        :param config: Model configuration; it is passed to both the parent class and the projection head.
        """
        super().__init__(config)
        self.cls_projection_head = BertProjectionHead(config)
        self.init_weights()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_cls_projected_embedding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Any
    ) -> Union[BERTTupleOutput, CXRBertOutput]:
        """Run the masked-LM model and, optionally, the "[CLS]" projection head.

        ``input_ids``, ``attention_mask``, ``token_type_ids``, ``position_ids``, ``head_mask``, ``inputs_embeds``
        and ``output_attentions`` are forwarded unchanged to ``BertForMaskedLM.forward``.

        :param output_hidden_states: Whether the returned ``CXRBertOutput`` exposes the full tuple of hidden
            states. ``None`` falls back to ``config.output_hidden_states``.
        :param output_cls_projected_embedding: If truthy, apply the projection head to the first token position
            of the last hidden state; otherwise the projected embedding is ``None``.
        :param return_dict: If truthy, return a ``CXRBertOutput``; otherwise a 5-tuple. ``None`` falls back to
            ``config.use_return_dict``.
        :param kwargs: Accepted for signature compatibility only; the values are not used.
        :return: ``CXRBertOutput``, or the tuple
            ``(last_hidden_state, logits, cls_projected_embedding, hidden_states, attentions)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Resolve an unset flag from the config, mirroring how ``return_dict`` is resolved above; previously
        # ``None`` was silently treated as "do not expose hidden states" even if the config asked for them.
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        # Hidden states are always requested from the parent model because the last hidden state is read
        # from them below; the caller's flag only decides whether they are exposed in the dict output.
        bert_for_masked_lm_output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
        )

        last_hidden_state = bert_for_masked_lm_output.hidden_states[-1]
        cls_projected_embedding = (
            self.cls_projection_head(last_hidden_state[:, 0, :]) if output_cls_projected_embedding else None
        )

        if return_dict:
            return CXRBertOutput(
                last_hidden_state=last_hidden_state,
                logits=bert_for_masked_lm_output.logits,
                cls_projected_embedding=cls_projected_embedding,
                hidden_states=bert_for_masked_lm_output.hidden_states if output_hidden_states else None,
                attentions=bert_for_masked_lm_output.attentions,
            )

        # NOTE: unlike the dict output, the positional form carries the full hidden-state tuple regardless of
        # ``output_hidden_states``. This is left untouched so that index-based callers keep working.
        return (
            last_hidden_state,
            bert_for_masked_lm_output.logits,
            cls_projected_embedding,
            bert_for_masked_lm_output.hidden_states,
            bert_for_masked_lm_output.attentions,
        )

    def get_projected_text_embeddings(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, normalize_embeddings: bool = True
    ) -> torch.Tensor:
        """
        Returns the projected cls token embeddings for the given input token ids and attention mask, l2-normalised
        unless ``normalize_embeddings`` is false.
        The joint latent space is trained using a contrastive objective between image and text data modalities.

        :param input_ids: (batch_size, sequence_length)
        :param attention_mask: (batch_size, sequence_length)
        :param normalize_embeddings: Whether to l2-normalise the embeddings.
        :return: (batch_size, projection_size)
        """
        outputs = self.forward(
            input_ids=input_ids, attention_mask=attention_mask, output_cls_projected_embedding=True, return_dict=True
        )
        # Both asserts narrow types for static checkers; ``forward`` is called with ``return_dict=True`` and
        # ``output_cls_projected_embedding=True`` just above.
        assert isinstance(outputs, CXRBertOutput)
        cls_projected_embedding = outputs.cls_projected_embedding
        assert cls_projected_embedding is not None

        if normalize_embeddings:
            return F.normalize(cls_projected_embedding, dim=1)
        return cls_projected_embedding