# DisambertSingleSense-base / DisamBertSingleSense.py
# Author: PeteBleackley — "End of training" (commit 9f5c6cd, verified)
from collections.abc import Generator, Iterable
from dataclasses import dataclass
from enum import StrEnum
import pprint
import torch
import torch.nn as nn
from transformers import (
AutoConfig,
AutoModel,
BatchEncoding,
ModernBertModel,
PreTrainedConfig,
PreTrainedModel,
PreTrainedTokenizer,
)
from transformers.modeling_outputs import TokenClassifierOutput
# Default batch size for collation/training.
# NOTE(review): not referenced within this file — presumably consumed by the
# training script; confirm before removing.
BATCH_SIZE = 16
class ModelURI(StrEnum):
    """Hugging Face Hub ids of the ModernBERT checkpoints this model can wrap."""

    BASE = "answerdotai/ModernBERT-base"
    LARGE = "answerdotai/ModernBERT-large"
@dataclass(slots=True, frozen=True)
class LexicalExample:
    """One ontology entry: a concept name paired with its textual definition."""

    # Identifier of the sense/concept; used as a lookup key by CandidateLabeller.
    concept: str
    # Natural-language gloss defining the concept.
    definition: str
@dataclass(slots=True, frozen=True)
class PaddedBatch:
    """A padded batch of token ids with its matching attention mask.

    NOTE(review): not referenced elsewhere in this file — presumably used by
    the surrounding training/inference scripts; confirm before removing.
    """

    # Token ids padded to a common length — assumed (batch, seq); TODO confirm.
    input_ids: torch.Tensor
    # Mask distinguishing real tokens from padding, same shape as input_ids.
    attention_mask: torch.Tensor
class DisamBertSingleSense(PreTrainedModel):
    """Word-sense disambiguation model built on a ModernBERT encoder.

    A mention in the input is delimited by two special marker tokens
    (``config.start_token`` / ``config.end_token``, see add_special_tokens).
    The mention's pooled vector is scored by dot product against the encoded
    gloss of each candidate sense definition, yielding one logit per candidate.
    """

    def __init__(self, config: PreTrainedConfig):
        """Build either from a pretrained base encoder or from config alone.

        Args:
            config: model config carrying the custom flag ``init_basemodel``
                plus (after add_special_tokens) ``start_token``/``end_token``.
        """
        super().__init__(config)
        if config.init_basemodel:
            # Fresh build (see from_base): load the pretrained encoder and
            # grow its vocabulary by two slots for the span-marker tokens.
            self.BaseModel = AutoModel.from_pretrained(config.name_or_path, device_map="auto")
            self.config.vocab_size += 2
            self.BaseModel.resize_token_embeddings(self.config.vocab_size)
        else:
            # Reload path (e.g. from_pretrained on a saved checkpoint): the
            # stored config already reflects the enlarged vocabulary.
            self.BaseModel = ModernBertModel(config)
        # Ensure saved configs take the ModernBertModel branch when reloaded.
        config.init_basemodel = False
        self.loss = nn.CrossEntropyLoss()
        self.post_init()

    @classmethod
    def from_base(cls, base_id: ModelURI) -> "DisamBertSingleSense":
        """Alternate constructor: wrap a pretrained ModernBERT checkpoint."""
        config = AutoConfig.from_pretrained(base_id)
        config.init_basemodel = True
        return cls(config)

    def add_special_tokens(self, start: int, end: int) -> None:
        """Record the ids of the marker tokens that bracket the mention span."""
        self.config.start_token = start
        self.config.end_token = end

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        candidate_tokens: torch.Tensor,
        candidate_attention_masks: torch.Tensor,
        candidate_mapping: torch.Tensor,
        labels: Iterable[int] | None = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> TokenClassifierOutput:
        """Score each sentence's mention against its candidate sense glosses.

        Args:
            input_ids: (batch, seq) token ids; each row is expected to contain
                a mention bracketed by the start/end marker tokens.
            attention_mask: padding mask for ``input_ids``.
            candidate_tokens: token ids of all candidate definitions, flattened
                across the batch (produced by CandidateLabeller).
            candidate_attention_masks: padding mask for ``candidate_tokens``.
            candidate_mapping: 1-D tensor giving, for each flattened gloss row,
                the index of the sentence that owns it.
            labels: per-sentence index of the gold candidate; when given, a
                cross-entropy loss is also returned. (CandidateLabeller passes
                a tensor here despite the ``Iterable`` annotation.)
            output_hidden_states: forward the encoder's hidden states.
            output_attentions: forward the encoder's attention maps.

        Returns:
            TokenClassifierOutput with ``logits`` of shape
            (batch, max_candidates) and optional ``loss``.
        """
        base_model_output = self.BaseModel(
            input_ids,
            attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        token_vectors = base_model_output.last_hidden_state
        # (batch, seq) 0/1 mask selecting the mention span in each sentence.
        selection = torch.zeros_like(input_ids, dtype=token_vectors.dtype)
        # nonzero() yields (row, position) pairs for every marker occurrence;
        # the strict zip requires equal counts of start and end markers and
        # pairs them in flattened order — assumes one mention per sentence (or
        # markers appearing in matched order). TODO confirm with the pipeline.
        starts = (input_ids == self.config.start_token).nonzero()
        ends = (input_ids == self.config.end_token).nonzero()
        for startpos, endpos in zip(starts, ends, strict=True):
            # Inclusive span: the marker tokens themselves are pooled too.
            selection[startpos[0], startpos[1] : endpos[1] + 1] = 1.0
        # Sum of token vectors over the mention span -> (batch, hidden).
        entity_vectors = torch.einsum("ijk,ij->ik", token_vectors, selection)
        gloss_vectors = self.gloss_vectors(
            candidate_tokens, candidate_attention_masks, candidate_mapping
        )
        # Dot product of each mention vector with each of its candidate gloss
        # vectors -> (batch, max_candidates).
        # NOTE(review): zero-padded candidate slots score exactly 0 and still
        # compete in the cross-entropy softmax — presumably acceptable because
        # labels always index real candidates; confirm.
        logits = torch.einsum("ij,ikj->ik", entity_vectors, gloss_vectors)
        return TokenClassifierOutput(
            logits=logits,
            loss=self.loss(logits, labels) if labels is not None else None,
            hidden_states=base_model_output.hidden_states if output_hidden_states else None,
            attentions=base_model_output.attentions if output_attentions else None,
        )

    def gloss_vectors(
        self,
        candidates: torch.Tensor,
        candidate_attention_masks: torch.Tensor,
        candidate_mapping: torch.Tensor,
    ) -> torch.Tensor:
        """Encode all candidate glosses and group them per sentence.

        Returns:
            (batch, max_candidates, hidden) tensor; sentences with fewer
            candidates than the batch maximum are padded with zero vectors.
        """
        # torch.device works as a context manager (torch >= 2.0): tensors
        # created inside default to the model's device.
        with self.device:
            # First-token (CLS-position) hidden state of each definition.
            vectors = self.BaseModel(candidates, candidate_attention_masks).last_hidden_state[:, 0]
            # Group gloss vectors by owning sentence; nonzero() returns a
            # (k, 1) index column, hence the squeeze on dim 1.
            chunks = [
                torch.squeeze(vectors[(candidate_mapping == sentence_index).nonzero()],
                              dim=1)
                for sentence_index in torch.unique(candidate_mapping)
            ]
            maxlen = max(chunk.shape[0] for chunk in chunks)
            # Right-pad each group with zero rows to a common length and stack.
            # NOTE(review): the zeros use torch's default float dtype, which
            # may differ from the encoder dtype under fp16/bf16 — confirm.
            return torch.stack(
                [
                    torch.cat([chunk, torch.zeros((maxlen - chunk.shape[0], self.config.hidden_size))])
                    for chunk in chunks
                ]
            )
class CandidateLabeller:
    """Data collator pairing tokenized sentences with their candidate glosses.

    Every concept definition in the ontology is tokenized once at construction
    time; per batch, __call__ pads the sentence encodings and the glosses of
    all candidate senses into rectangular tensors matching the inputs of
    DisamBertSingleSense.forward.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        ontology: Iterable[LexicalExample],
        device: torch.device,
    ):
        """Cache the tokenized gloss of every concept.

        Args:
            tokenizer: tokenizer matching the model (must support ``pad``).
            ontology: concept/definition pairs; consumed exactly once here.
                Any iterable is accepted (the previous ``Generator`` annotation
                was needlessly narrow and invalid on Python < 3.13).
            device: device on which collated tensors are created.
        """
        self.tokenizer = tokenizer
        self.device = device
        # Tokenize each gloss up front so __call__ only has to pad.
        self.gloss_tokens = {
            example.concept: self.tokenizer(example.definition, padding=True)
            for example in ontology
        }

    def __call__(self, batch: list[dict]) -> dict:
        """Collate a batch of feature dicts into model inputs.

        Each element of ``batch`` must carry "input_ids", "attention_mask"
        and "candidates" (a list of concept names); optionally "label" (the
        gold concept, which must appear in that example's "candidates").

        Returns:
            Dict of padded tensors plus "candidate_mapping" (flattened gloss
            row -> sentence index) and, when labels are present, "labels"
            (index of the gold concept within each sentence's candidates).
        """
        # torch.device works as a context manager (torch >= 2.0): tensors
        # created inside are allocated on self.device by default.
        with self.device:
            encoded = [
                BatchEncoding(
                    {"input_ids": example["input_ids"], "attention_mask": example["attention_mask"]}
                )
                for example in batch
            ]
            tokens = self.tokenizer.pad(encoded, padding=True, return_tensors="pt")
            # One cached gloss encoding per candidate, flattened across the batch.
            candidate_tokens = self.tokenizer.pad(
                [self.gloss_tokens[concept] for example in batch for concept in example["candidates"]],
                padding=True,
                return_attention_mask=True,
                return_tensors="pt",
            )
            result = {
                "input_ids": tokens.input_ids,
                "attention_mask": tokens.attention_mask,
                "candidate_tokens": candidate_tokens.input_ids,
                "candidate_attention_masks": candidate_tokens.attention_mask,
                # Row i of the flattened gloss batch belongs to sentence
                # candidate_mapping[i].
                "candidate_mapping": torch.cat(
                    [
                        torch.tensor([i] * len(example["candidates"]))
                        for (i, example) in enumerate(batch)
                    ]
                ),
            }
            if "label" in batch[0]:
                # Gold class = position of the labelled concept within each
                # example's own candidate list.
                result["labels"] = torch.tensor(
                    [example["candidates"].index(example["label"]) for example in batch]
                )
            return result