from collections.abc import Generator, Iterable
from dataclasses import dataclass
from enum import StrEnum
import pprint

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    BatchEncoding,
    ModernBertModel,
    PreTrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)
from transformers.modeling_outputs import TokenClassifierOutput
| |
|
| | BATCH_SIZE = 16 |
| |
|
| |
|
| | class ModelURI(StrEnum): |
| | BASE = "answerdotai/ModernBERT-base" |
| | LARGE = "answerdotai/ModernBERT-large" |
| |
|
| |
|
| | @dataclass(slots=True, frozen=True) |
| | class LexicalExample: |
| | concept: str |
| | definition: str |
| |
|
| |
|
| | @dataclass(slots=True, frozen=True) |
| | class PaddedBatch: |
| | input_ids: torch.Tensor |
| | attention_mask: torch.Tensor |
| |
|
| |
|
class DisamBertSingleSense(PreTrainedModel):
    """Word-sense disambiguation model on a shared ModernBERT backbone.

    For each input sentence, an entity span is delimited by two special
    marker tokens.  The span's token vectors are sum-pooled into a single
    entity vector; every candidate sense gloss is encoded with the same
    backbone and represented by its first ([CLS]) token vector.  The logit
    for a candidate is the dot product of entity and gloss vectors, and a
    cross-entropy loss over candidates is computed when labels are given.
    """

    def __init__(self, config: PreTrainedConfig):
        """Build the model, either from a pretrained base or from config.

        When ``config.init_basemodel`` is true the backbone is downloaded
        from ``config.name_or_path`` and its embedding matrix is grown by
        two rows for the entity start/end marker tokens.  Otherwise (e.g.
        reloading a fine-tuned checkpoint whose vocab is already resized) a
        ModernBERT is built directly from the config.
        """
        super().__init__(config)
        if config.init_basemodel:
            self.BaseModel = AutoModel.from_pretrained(config.name_or_path, device_map="auto")
            # Reserve two ids for the entity start/end marker tokens.
            self.config.vocab_size += 2
            self.BaseModel.resize_token_embeddings(self.config.vocab_size)
        else:
            self.BaseModel = ModernBertModel(config)
        # Clear the flag unconditionally: a config saved after the
        # from-pretrained path must not re-download the base weights (and
        # grow the vocabulary a second time) when the checkpoint is reloaded.
        config.init_basemodel = False

        self.loss = nn.CrossEntropyLoss()
        self.post_init()

    @classmethod
    def from_base(cls, base_id: ModelURI):
        """Alternate constructor: start from a pretrained base checkpoint."""
        config = AutoConfig.from_pretrained(base_id)
        config.init_basemodel = True
        return cls(config)

    def add_special_tokens(self, start: int, end: int):
        """Record the token ids used to mark the entity span boundaries."""
        self.config.start_token = start
        self.config.end_token = end

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        candidate_tokens: torch.Tensor,
        candidate_attention_masks: torch.Tensor,
        candidate_mapping: torch.Tensor,
        labels: torch.Tensor | None = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> TokenClassifierOutput:
        """Score each candidate gloss against the marked entity span.

        Args:
            input_ids: padded sentence token ids; each row must contain
                exactly one start marker and one matching end marker
                (``zip(..., strict=True)`` below raises otherwise).
            attention_mask: padding mask for ``input_ids``.
            candidate_tokens: padded token ids of all candidate glosses in
                the batch, flattened across sentences.
            candidate_attention_masks: padding mask for ``candidate_tokens``.
            candidate_mapping: for each gloss row, the index of the sentence
                it belongs to.
            labels: per-sentence index of the correct candidate; when given,
                a cross-entropy loss is returned.

        Returns:
            TokenClassifierOutput with per-candidate logits (padded gloss
            slots score via zero vectors) and optional loss/states.
        """
        base_model_output = self.BaseModel(
            input_ids,
            attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        token_vectors = base_model_output.last_hidden_state
        # Build a 0/1 mask selecting the marked span (markers inclusive).
        selection = torch.zeros_like(input_ids, dtype=token_vectors.dtype)
        starts = (input_ids == self.config.start_token).nonzero()
        ends = (input_ids == self.config.end_token).nonzero()
        for startpos, endpos in zip(starts, ends, strict=True):
            selection[startpos[0], startpos[1] : endpos[1] + 1] = 1.0
        # Sum-pool the span's token vectors into one entity vector per row.
        entity_vectors = torch.einsum("ijk,ij->ik", token_vectors, selection)
        gloss_vectors = self.gloss_vectors(
            candidate_tokens, candidate_attention_masks, candidate_mapping
        )
        # Dot product of entity vector with each of its candidate glosses.
        logits = torch.einsum("ij,ikj->ik", entity_vectors, gloss_vectors)

        return TokenClassifierOutput(
            logits=logits,
            loss=self.loss(logits, labels) if labels is not None else None,
            hidden_states=base_model_output.hidden_states if output_hidden_states else None,
            attentions=base_model_output.attentions if output_attentions else None,
        )

    def gloss_vectors(self, candidates, candidate_attention_masks, candidate_mapping):
        """Encode all candidate glosses and regroup them per sentence.

        Returns a tensor of shape (sentences, max_candidates, hidden) where
        sentences with fewer candidates are padded with zero vectors, which
        contribute a logit of 0 in the caller's dot product.
        """
        with self.device:
            # [CLS]-style pooling: first token's hidden state per gloss.
            vectors = self.BaseModel(candidates, candidate_attention_masks).last_hidden_state[:, 0]
            chunks = [
                torch.squeeze(vectors[(candidate_mapping == sentence_index).nonzero()],
                              dim=1)
                for sentence_index in torch.unique(candidate_mapping)
            ]
            maxlen = max(chunk.shape[0] for chunk in chunks)
            # Match the backbone's dtype so padding works under fp16/bf16.
            return torch.stack(
                [
                    torch.cat(
                        [
                            chunk,
                            torch.zeros(
                                (maxlen - chunk.shape[0], self.config.hidden_size),
                                dtype=chunk.dtype,
                            ),
                        ]
                    )
                    for chunk in chunks
                ]
            )
| |
|
| |
|
class CandidateLabeller:
    """Collate function: turn raw examples into model-ready padded tensors.

    Gloss definitions from the ontology are tokenised once up front and
    cached by concept, so batching only needs to pad and stack them.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, ontology: Iterable[LexicalExample], device: torch.device):
        """Pre-tokenise every gloss in *ontology* with *tokenizer*.

        Args:
            tokenizer: tokenizer matching the model's backbone.
            ontology: stream of (concept, definition) pairs; consumed once.
            device: device on which the batch tensors are created.
        """
        self.tokenizer = tokenizer
        self.device = device
        # concept -> tokenised gloss, reused across all batches.
        self.gloss_tokens = {
            example.concept: self.tokenizer(example.definition, padding=True)
            for example in ontology
        }

    def __call__(self, batch: list[dict]) -> dict:
        """Collate a list of example dicts into one padded batch.

        Each example must carry ``input_ids``, ``attention_mask`` and a
        ``candidates`` list of concept keys; an optional ``label`` (the
        correct concept) is converted to its index within ``candidates``.
        """
        with self.device:
            encoded = [
                BatchEncoding(
                    {"input_ids": example["input_ids"], "attention_mask": example["attention_mask"]}
                )
                for example in batch
            ]
            tokens = self.tokenizer.pad(encoded, padding=True, return_tensors="pt")
            # Flatten every sentence's candidate glosses into one padded block.
            candidate_tokens = self.tokenizer.pad(
                [self.gloss_tokens[concept] for example in batch for concept in example["candidates"]],
                padding=True,
                return_attention_mask=True,
                return_tensors="pt",
            )
            result = {
                "input_ids": tokens.input_ids,
                "attention_mask": tokens.attention_mask,
                "candidate_tokens": candidate_tokens.input_ids,
                "candidate_attention_masks": candidate_tokens.attention_mask,
                # For every gloss row, the index of the sentence it belongs to.
                "candidate_mapping": torch.cat(
                    [
                        torch.full((len(example["candidates"]),), i)
                        for (i, example) in enumerate(batch)
                    ]
                ),
            }
            if "label" in batch[0]:
                # Gold label as position within the candidate list; raises
                # ValueError if the label is missing from the candidates.
                result["labels"] = torch.tensor(
                    [example["candidates"].index(example["label"]) for example in batch]
                )
            return result
| |
|