BartekSadlej's picture
Upload model
3bee1db verified
raw history blame
No virus
8.33 kB
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.modeling_outputs import (
BaseModelOutputWithPooling,
MaskedLMOutput,
BaseModelOutput,
SequenceClassifierOutput,
)
from enum import Enum
from .config import ILKTConfig
def cls_pooling(last_hidden_state, attention_mask):
    """Pool a sequence by taking the vector at position 0 of the second axis.

    Args:
        last_hidden_state: Tensor indexed as ``[:, 0, :]``, so it must have
            at least three dimensions (presumably batch, sequence, hidden --
            confirm against the backbone in use).
        attention_mask: Accepted for a uniform pooling signature; not used.

    Returns:
        The slice of ``last_hidden_state`` at index 0 along dimension 1.
    """
    first_position = 0
    pooled = last_hidden_state[:, first_position, :]
    return pooled
def create_head_blocks(
    hidden_size: int,
    n_dense: int,
    use_batch_norm: bool,
    use_layer_norm: bool,
    dropout: float,
    **kwargs,
) -> nn.Module:
    """Build a stack of ``n_dense`` identical dense blocks.

    Each block is ``Linear(hidden_size, hidden_size)``, followed by an
    optional normalisation layer, ``ReLU``, and an optional ``Dropout``.

    Args:
        hidden_size: Input and output width of every linear layer.
        n_dense: Number of blocks; a non-positive value yields an empty stack.
        use_batch_norm: Insert ``BatchNorm1d`` after each linear layer.
        use_layer_norm: Insert ``LayerNorm`` after each linear layer. Only
            consulted when ``use_batch_norm`` is false.
        dropout: Dropout probability; a dropout layer is added only if > 0.
        **kwargs: Ignored. Lets callers splat a whole head config dict that
            also carries keys meant for other code.

    Returns:
        An ``nn.Sequential`` holding all layers in a single flat sequence.
    """

    def _single_block():
        # The linear layer is always first so layer ordering (and parameter
        # initialisation order) is the same for every block.
        layers = [nn.Linear(hidden_size, hidden_size)]

        # Batch norm takes precedence when both normalisation flags are set.
        if use_batch_norm:
            layers.append(nn.BatchNorm1d(hidden_size))
        elif use_layer_norm:
            layers.append(nn.LayerNorm(hidden_size))

        layers.append(nn.ReLU())

        if dropout > 0:
            layers.append(nn.Dropout(dropout))
        return layers

    flat_layers = [layer for _ in range(n_dense) for layer in _single_block()]
    return nn.Sequential(*flat_layers)
class SentenceEmbeddingHead(nn.Module):
    """Turns backbone hidden states into a single pooled embedding per input.

    The config dict is splatted into ``create_head_blocks`` and is also read
    at forward time for the ``"pool_type"`` and ``"normalize_embeddings"``
    keys.
    """

    def __init__(
        self, backbone_hidden_size: int, embedding_head_config: Dict[str, Any]
    ):
        super().__init__()
        self.config = embedding_head_config

        dense_stack = create_head_blocks(
            backbone_hidden_size, **embedding_head_config
        )
        # The dense stack is deliberately wrapped in an outer Sequential:
        # this keeps the module nesting (and therefore parameter names)
        # exactly as they were.
        self.head = nn.Sequential(dense_stack)

    def forward(
        self, backbone_output: BaseModelOutput, attention_mask: torch.Tensor, **kwargs
    ) -> BaseModelOutputWithPooling:
        """Pool, project and optionally L2-normalise the backbone output.

        Args:
            backbone_output: Backbone result; only ``last_hidden_state`` is read.
            attention_mask: Passed through to the pooling function.
            **kwargs: Ignored.

        Returns:
            ``BaseModelOutputWithPooling`` whose ``last_hidden_state`` is the
            backbone's and whose ``pooler_output`` is the embedding.

        Raises:
            NotImplementedError: If the configured pool type is not ``"cls"``.
        """
        if self.config["pool_type"] != "cls":
            raise NotImplementedError(
                f"Pooling type {self.config['pool_type']} not implemented"
            )

        hidden_states = backbone_output.last_hidden_state
        pooled = cls_pooling(hidden_states, attention_mask)
        projected = self.head(pooled)

        if self.config["normalize_embeddings"]:
            # Unit L2 norm along the last (feature) dimension.
            projected = nn.functional.normalize(projected, p=2, dim=-1)

        return BaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=projected,  # type: ignore
        )
class MLMHead(nn.Module):
    """Masked-language-modelling head applied to every position.

    Consists of the dense stack from ``create_head_blocks`` followed by a
    linear projection from the backbone hidden size to the vocabulary size.
    """

    def __init__(
        self,
        backbone_hidden_size: int,
        vocab_size: int,
        mlm_head_config: Dict[str, Any],
    ):
        super().__init__()
        self.config = mlm_head_config

        # Build order matters for reproducible initialisation: dense stack
        # first, vocabulary projection second.
        dense_stack = create_head_blocks(backbone_hidden_size, **mlm_head_config)
        vocab_projection = nn.Linear(backbone_hidden_size, vocab_size)
        self.head = nn.Sequential(dense_stack, vocab_projection)

    def forward(
        self,
        backbone_output: BaseModelOutput,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> MaskedLMOutput:
        """Score every position over the vocabulary and optionally compute a loss.

        Args:
            backbone_output: Backbone result; only ``last_hidden_state`` is read.
            attention_mask: Unused here; kept for a uniform head signature.
            labels: Optional targets. When given, logits are flattened to
                ``(-1, vocab)`` and labels to ``(-1,)`` before cross-entropy
                with default settings.
            **kwargs: Ignored.

        Returns:
            ``MaskedLMOutput`` with ``logits`` and ``loss`` (``None`` without labels).
        """
        logits = self.head(backbone_output.last_hidden_state)

        if labels is None:
            return MaskedLMOutput(loss=None, logits=logits)

        flat_logits = logits.view(-1, logits.size(-1))
        flat_labels = labels.view(-1)
        mlm_loss = nn.functional.cross_entropy(flat_logits, flat_labels)
        return MaskedLMOutput(loss=mlm_loss, logits=logits)
class CLSHead(nn.Module):
    """Sequence-classification head operating on a pooled representation.

    Consists of the dense stack from ``create_head_blocks`` followed by a
    linear projection from the backbone hidden size to ``n_classes``. The
    config dict is also read at forward time for the ``"pool_type"`` key.
    """

    def __init__(
        self,
        backbone_hidden_size: int,
        n_classes: int,
        cls_head_config: Dict[str, Any],
    ):
        super().__init__()
        self.config = cls_head_config

        # Build order matters for reproducible initialisation: dense stack
        # first, class projection second.
        dense_stack = create_head_blocks(backbone_hidden_size, **cls_head_config)
        class_projection = nn.Linear(backbone_hidden_size, n_classes)
        self.head = nn.Sequential(dense_stack, class_projection)

    def forward(
        self,
        backbone_output: BaseModelOutput,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> SequenceClassifierOutput:
        """Pool the backbone output, classify it and optionally compute a loss.

        Args:
            backbone_output: Backbone result; only ``last_hidden_state`` is read.
            attention_mask: Passed through to the pooling function.
            labels: Optional class targets. When given, logits are flattened
                to ``(-1, n_classes)`` and labels to ``(-1,)`` before
                cross-entropy with default settings.
            **kwargs: Ignored.

        Returns:
            ``SequenceClassifierOutput`` with ``logits`` and ``loss``
            (``None`` without labels).

        Raises:
            NotImplementedError: If the configured pool type is not ``"cls"``.
        """
        if self.config["pool_type"] != "cls":
            raise NotImplementedError(
                f"Pooling type {self.config['pool_type']} not implemented"
            )

        pooled = cls_pooling(backbone_output.last_hidden_state, attention_mask)
        logits = self.head(pooled)

        if labels is None:
            return SequenceClassifierOutput(loss=None, logits=logits)

        flat_logits = logits.view(-1, logits.size(-1))
        flat_labels = labels.view(-1)
        cls_loss = nn.functional.cross_entropy(flat_logits, flat_labels)
        return SequenceClassifierOutput(loss=cls_loss, logits=logits)
class ForwardRouting(Enum):
    """Selects which output path a model's ``forward`` should take.

    Each member's value is the name of the corresponding model method.
    """

    # Pooled sentence embedding path.
    GET_SENTENCE_EMBEDDING = "get_sentence_embedding"

    # Masked-language-modelling path.
    GET_MLM_OUTPUT = "get_mlm_output"

    # Named classification-head path.
    GET_CLS_OUTPUT = "get_cls_output"
class ILKTModel(PreTrainedModel):
    """Pretrained backbone encoder with three kinds of task heads.

    The model owns:
      * ``backbone``: loaded with ``AutoModel.from_pretrained`` from the
        settings in ``config.backbone_config``;
      * ``embedding_head``: a ``SentenceEmbeddingHead``;
      * ``mlm_head``: an ``MLMHead`` sized to the backbone vocabulary;
      * ``cls_heads``: an ``nn.ModuleDict`` of ``CLSHead`` modules keyed by
        head name, built from ``config.cls_heads``.

    ``forward`` dispatches to one of the ``get_*`` methods according to
    ``forward_routing``.
    """

    config_class = ILKTConfig

    def __init__(self, config: ILKTConfig):
        """Load the backbone and build all heads.

        Args:
            config: Model configuration. ``config.backbone_config`` is
                splatted into ``AutoConfig.from_pretrained`` and must contain
                ``"pretrained_model_name_or_path"``. ``config.cls_heads`` is
                iterated as ``(n_classes, name)`` pairs.
        """
        super().__init__(config)

        backbone_config = AutoConfig.from_pretrained(**config.backbone_config)
        pretrained_model_name_or_path = config.backbone_config[
            "pretrained_model_name_or_path"
        ]
        self.backbone = AutoModel.from_pretrained(
            pretrained_model_name_or_path, config=backbone_config
        )

        backbone_hidden_size = backbone_config.hidden_size
        # Mirror the backbone width on this model's own config.
        self.config.hidden_size = backbone_hidden_size
        backbone_vocab_size = backbone_config.vocab_size

        self.embedding_head = SentenceEmbeddingHead(
            backbone_hidden_size, config.embedding_head_config
        )
        self.mlm_head = MLMHead(
            backbone_hidden_size, backbone_vocab_size, config.mlm_head_config
        )
        # One classification head per configured (n_classes, name) pair.
        self.cls_heads = nn.ModuleDict(
            {
                name: CLSHead(backbone_hidden_size, n_classes, config.cls_head_config)
                for n_classes, name in config.cls_heads
            }
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        forward_routing: ForwardRouting = ForwardRouting.GET_SENTENCE_EMBEDDING,
        **kwargs,
    ):
        """Run the backbone and the head selected by ``forward_routing``.

        Args:
            input_ids: Token ids passed to the backbone.
            attention_mask: Attention mask passed to the backbone and the head.
            token_type_ids: Optional segment ids. Forwarded to the backbone
                only when provided, so backbones whose ``forward`` has no such
                parameter are not handed an unexpected keyword.
            forward_routing: A ``ForwardRouting`` member, or its string value
                (e.g. ``"get_mlm_output"``).
            **kwargs: Extra arguments for the MLM and classification paths
                (e.g. ``labels``, ``head_name``). Not forwarded on the
                sentence-embedding path.

        Returns:
            The output object of the selected ``get_*`` method.

        Raises:
            ValueError: If ``forward_routing`` is not a known routing.
        """
        # Accept either an enum member or its raw string value; keep the
        # original error message for anything unrecognised.
        try:
            routing = ForwardRouting(forward_routing)
        except ValueError:
            raise ValueError(f"Unknown forward routing {forward_routing}") from None

        backbone_kwargs = {}
        if token_type_ids is not None:
            backbone_kwargs["token_type_ids"] = token_type_ids

        if routing is ForwardRouting.GET_SENTENCE_EMBEDDING:
            return self.get_sentence_embedding(
                input_ids, attention_mask, **backbone_kwargs
            )
        elif routing is ForwardRouting.GET_MLM_OUTPUT:
            return self.get_mlm_output(
                input_ids, attention_mask, **backbone_kwargs, **kwargs
            )
        elif routing is ForwardRouting.GET_CLS_OUTPUT:
            return self.get_cls_output(
                input_ids, attention_mask, **backbone_kwargs, **kwargs
            )
        else:
            # Only reachable if a new enum member is added without a branch.
            raise ValueError(f"Unknown forward routing {forward_routing}")

    def get_sentence_embedding(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
    ):
        """Encode the inputs and return the sentence-embedding head output.

        ``**kwargs`` are passed to both the backbone and the embedding head.
        """
        backbone_output: BaseModelOutput = self.backbone(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )
        embedding_output = self.embedding_head(
            backbone_output, attention_mask, **kwargs
        )
        return embedding_output

    def get_mlm_output(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Encode the inputs and return the MLM head output.

        ``labels`` is given to the head only; ``**kwargs`` are passed to both
        the backbone and the head.
        """
        backbone_output: BaseModelOutput = self.backbone(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )
        mlm_output = self.mlm_head(backbone_output, attention_mask, labels, **kwargs)
        return mlm_output

    def get_cls_output(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        head_name: str,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Encode the inputs and return the output of one classification head.

        Args:
            input_ids: Token ids passed to the backbone.
            attention_mask: Attention mask passed to the backbone and the head.
            head_name: Key of the head in ``self.cls_heads``.
            labels: Optional targets, given to the head only.
            **kwargs: Passed to both the backbone and the head.

        Raises:
            ValueError: If ``head_name`` is not a registered head.
        """
        # Fail fast: reject an unknown head before paying for a backbone pass.
        if head_name not in self.cls_heads:
            raise ValueError(f"Head {head_name} not found in model")

        backbone_output: BaseModelOutput = self.backbone(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )
        cls_output = self.cls_heads[head_name](
            backbone_output, attention_mask, labels, **kwargs
        )
        return cls_output