"""Bi-encoder embedding models built on BERT.

Each tower maps token ids to a fixed-size embedding; a masked-LM variant
additionally exposes per-token vocabulary scores for an auxiliary MLM loss.
"""
from typing import Optional, Tuple

import torch
from transformers import BertConfig, BertModel, BertPreTrainedModel, PreTrainedModel
from transformers.models.bert.modeling_bert import BertOnlyMLMHead


class BertEmbeddingConfig(BertConfig):
    """BertConfig extended with embedding-head settings.

    The extra fields below are supplied as keyword arguments at construction
    time; ``PretrainedConfig`` stores unrecognized kwargs as attributes, so
    they survive ``save_pretrained``/``from_pretrained`` round-trips.
    """

    # Dimensionality of the output embedding (required).
    n_output_dims: int
    # Distance function the embedding targets: "euclidean" or "angular".
    distance_func: str = "euclidean"


class BiEncoderConfig(BertEmbeddingConfig):
    """Bi-encoder configuration with a separate maximum length per tower."""

    max_length1: int
    max_length2: int


class BiEncoder(PreTrainedModel):
    """Two independent BERT embedding towers, one per input side."""

    config_class = BiEncoderConfig

    def __init__(self, config: BiEncoderConfig):
        super().__init__(config)
        # Each tower gets its own config copy with max_position_embeddings
        # set to that side's maximum input length.
        config1 = _replace_max_length(config, "max_length1")
        self.bert1 = BertForEmbedding(config1)
        config2 = _replace_max_length(config, "max_length2")
        self.bert2 = BertForEmbedding(config2)
        self.post_init()

    def forward(self, x1, x2):
        y1 = self.forward1(x1)
        y2 = self.forward2(x2)
        return {"y1": y1, "y2": y2}

    def forward1(self, x1):
        # Only input_ids are forwarded, so padding tokens are not masked out.
        y1 = self.bert1(input_ids=x1["input_ids"])
        return y1

    def forward2(self, x2):
        y2 = self.bert2(input_ids=x2["input_ids"])
        return y2


class BiEncoderWithMaskedLM(PreTrainedModel):
    """Bi-encoder whose towers also expose masked-LM logits.

    The per-tower ``BertOnlyMLMHead`` scores allow an auxiliary MLM loss to be
    trained alongside the embedding objective.
    """

    config_class = BiEncoderConfig

    def __init__(self, config: BiEncoderConfig):
        super().__init__(config=config)
        config1 = _replace_max_length(config, "max_length1")
        self.bert1 = BertForEmbedding(config1)
        self.lm_head1 = BertOnlyMLMHead(config=config1)

        config2 = _replace_max_length(config, "max_length2")
        self.bert2 = BertForEmbedding(config2)
        self.lm_head2 = BertOnlyMLMHead(config=config2)
        self.post_init()

    def forward(self, x1, x2):
        # Embeddings plus each tower's last hidden state for the MLM heads.
        y1, state1 = self.bert1.forward_with_state(input_ids=x1["input_ids"])
        y2, state2 = self.bert2.forward_with_state(input_ids=x2["input_ids"])
        scores1 = self.lm_head1(state1)  # (batch, seq_len, vocab_size)
        scores2 = self.lm_head2(state2)
        return {"y1": y1, "y2": y2, "scores1": scores1, "scores2": scores2}
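
# A minimal sketch of how the MLM scores could feed a loss (assumes labels
# use -100 on unmasked positions, the convention of transformers'
# DataCollatorForLanguageModeling; `labels1` is hypothetical):
#
#     loss1 = torch.nn.functional.cross_entropy(
#         scores1.view(-1, scores1.size(-1)),
#         labels1.view(-1),
#         ignore_index=-100,
#     )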


def _replace_max_length(config, length_key):
    """Clone ``config`` into a ``BertEmbeddingConfig``, renaming ``length_key``
    to ``max_position_embeddings``."""
    fields = config.__dict__.copy()
    fields["max_position_embeddings"] = fields.pop(length_key)
    # The other tower's max_length* key stays in `fields`; PretrainedConfig
    # simply stores it as an extra attribute on the clone.
    return BertEmbeddingConfig(**fields)
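
# Illustrative example (values are arbitrary):
#
#     cfg = BiEncoderConfig(n_output_dims=32, max_length1=16, max_length2=8)
#     tower_cfg = _replace_max_length(cfg, "max_length1")
#     assert tower_cfg.max_position_embeddings == 16
#     assert tower_cfg.n_output_dims == 32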


class L2Norm(torch.nn.Module):
    """Scales vectors to unit L2 norm, e.g. [3.0, 4.0] -> [0.6, 0.8]."""

    # An nn.Module so that, like the Tanh alternative, it is registered as a
    # submodule of the embedding model.
    def forward(self, x):
        return x / torch.norm(x, p=2, dim=-1, keepdim=True)


class BertForEmbedding(BertPreTrainedModel):
    """BERT encoder that maps token ids to a fixed-size embedding.

    The pooled [CLS] representation is projected to ``n_output_dims`` and
    passed through an activation chosen for the configured distance function.
    """

    config_class = BertEmbeddingConfig

    def __init__(self, config: BertEmbeddingConfig):
        super().__init__(config)
        self.bert = BertModel(config)
        self.fc = torch.nn.Linear(config.hidden_size, config.n_output_dims)
        self.activation = _get_activation(config.distance_func)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.Tensor:
        embedding, _ = self.forward_with_state(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return embedding

    def forward_with_state(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns ``(embedding, last_hidden_state)``.

        Note: the pooler output is read by attribute below, so passing
        ``return_dict=False`` is not supported.
        """
        encoded = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        pooler_output = encoded.pooler_output  # pooled [CLS] representation
        embedding = self.activation(self.fc(pooler_output))
        return embedding, encoded.last_hidden_state


def _get_activation(distance_func: str):
    # Tanh bounds each embedding coordinate to (-1, 1) for euclidean distance;
    # L2 normalization puts embeddings on the unit sphere for angular distance.
    if distance_func == "euclidean":
        return torch.nn.Tanh()
    if distance_func == "angular":
        return L2Norm()
    raise NotImplementedError(f"Unknown distance_func: {distance_func!r}")
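

# Minimal smoke test sketch: builds a tiny bi-encoder and checks output
# shapes. All sizes below are illustrative, not defaults of this module.
if __name__ == "__main__":
    config = BiEncoderConfig(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=128,
        n_output_dims=32,
        max_length1=16,
        max_length2=8,
    )
    model = BiEncoder(config)
    x1 = {"input_ids": torch.randint(0, config.vocab_size, (4, 16))}
    x2 = {"input_ids": torch.randint(0, config.vocab_size, (4, 8))}
    out = model(x1, x2)
    print(out["y1"].shape, out["y2"].shape)  # torch.Size([4, 32]) torch.Size([4, 32])

    mlm = BiEncoderWithMaskedLM(config)
    mlm_out = mlm(x1, x2)
    print(mlm_out["scores1"].shape)  # torch.Size([4, 16, 1000])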