import math

import torch
import torchmetrics
from huggingface_hub import PyTorchModelHubMixin
from lightning import LightningModule
from transformers import AutoTokenizer, AutoModel

from mentioned.data import DataBlob
|
|
class ModelRegistry:
    """Name -> model-factory registry populated via the ``register`` decorator."""

    _registry = {}

    @classmethod
    def register(cls, name):
        """Return a decorator that files the decorated factory under *name*.

        The factory is returned unchanged, so stacking other decorators works.
        """
        def _add(factory):
            cls._registry[name] = factory
            return factory
        return _add

    @classmethod
    def get(cls, name):
        """Look up the factory registered under *name* (raises KeyError if absent)."""
        return cls._registry[name]
|
|
|
|
class SentenceEncoder(torch.nn.Module):
    """Pretrained transformer backbone that pools subword hidden states into
    one embedding per word (mean over each word's subwords)."""

    def __init__(
        self,
        model_name: str = "distilroberta-base",
        max_length: int = 512,
    ):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            use_fast=True,
        )
        self.encoder = AutoModel.from_pretrained(model_name)
        self.max_length = max_length
        # Backbone hidden size; downstream heads are sized from this.
        self.dim = self.encoder.config.hidden_size
        self.stats = {}

    def forward(self, input_ids, attention_mask, word_ids):
        """Encode subwords and mean-pool them into word embeddings.

        Args:
            input_ids: (batch, n_subwords) token ids.
            attention_mask: (batch, n_subwords) padding mask.
            word_ids: (batch, n_subwords) word index per subword; -1 marks
                special/padding tokens, which never match a word slot.

        Returns:
            (batch, n_words, hidden_dim) word embeddings; word slots with no
            subwords (padding) come out as (near-)zero vectors.
        """
        hidden = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        n_words = word_ids.max() + 1
        word_index = torch.arange(n_words, device=word_ids.device)
        # assignment[b, t, w] == 1 iff subword t belongs to word w; the -1
        # entries for special tokens match no column of word_index.
        assignment = (word_ids.unsqueeze(-1) == word_index).to(hidden.dtype)

        # (batch, n_words, n_subwords) @ (batch, n_subwords, hidden):
        # per-word sum of subword vectors.
        summed = assignment.transpose(1, 2) @ hidden

        # Clamp keeps empty word slots from dividing by zero.
        counts = assignment.sum(dim=1).unsqueeze(-1).clamp(min=1e-9)

        return summed / counts
|
|
|
|
class Detector(torch.nn.Module):
    """Two-layer MLP scoring head used for start / end / label detection."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_classes: int = 1,
    ):
        super().__init__()
        layers = [
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, num_classes),
        ]
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score the trailing feature dimension of *x*.

        Args:
            x: (B, N, input_dim) for start detection,
               (B, N, N, input_dim) for end detection.

        Returns:
            logits with the feature dim replaced by num_classes,
            i.e. (B, N, num_classes) or (B, N, N, num_classes).
        """
        return self.net(x)
|
|
|
|
class MentionDetectorCore(torch.nn.Module):
    """Scores mention start tokens and (start, end) span pairs."""

    def __init__(
        self,
        start_detector: Detector,
        end_detector: Detector,
    ):
        super().__init__()
        self.start_detector = start_detector
        self.end_detector = end_detector

    def forward(self, emb: torch.Tensor):
        """Score starts and all start/end pairs.

        Args:
            emb: (Batch, Seq_Len, Hidden_Dim) word embeddings.

        Returns:
            start_logits: (Batch, Seq_Len) per-token start scores.
            end_logits: (Batch, Seq_Len, Seq_Len) where entry [b, i, j]
                scores a span starting at token i and ending at token j.
        """
        seq_len = emb.size(1)
        start_logits = self.start_detector(emb).squeeze(-1)

        # Materialize every (start, end) pair by broadcasting the embeddings
        # along two new axes and concatenating on the feature dimension.
        as_start = emb.unsqueeze(2).expand(-1, -1, seq_len, -1)
        as_end = emb.unsqueeze(1).expand(-1, seq_len, -1, -1)
        pairs = torch.cat((as_start, as_end), dim=-1)
        end_logits = self.end_detector(pairs).squeeze(-1)

        return start_logits, end_logits
|
|
|
|
class MentionLabeler(torch.nn.Module):
    """Assigns entity-class scores to every (start, end) token pair."""

    def __init__(self, classifier: Detector):
        super().__init__()
        self.classifier = classifier

    def forward(self, emb: torch.Tensor):
        """Classify all (start, end) pairs.

        Args:
            emb: (Batch, Seq_Len, Hidden_Dim) word embeddings.

        Returns:
            logits: (Batch, Seq_Len, Seq_Len, num_classes); entry [b, i, j]
                holds class scores for the span (i, j). A trailing size-1
                class dim is squeezed away.
        """
        seq_len = emb.size(1)

        # Same pairing scheme as MentionDetectorCore: broadcast, then concat.
        as_start = emb.unsqueeze(2).expand(-1, -1, seq_len, -1)
        as_end = emb.unsqueeze(1).expand(-1, seq_len, -1, -1)
        pairs = torch.cat((as_start, as_end), dim=-1)
        logits = self.classifier(pairs).squeeze(-1)

        return logits
| |
|
|
class LitMentionDetector(LightningModule, PyTorchModelHubMixin):
    """LightningModule for span-based mention detection and (optionally)
    entity-type labeling on top of a frozen sentence encoder.

    Batches carry a ``task_id``: 0 = plain mention detection,
    1 = entity mentions (detection + labeling).
    """

    def __init__(
        self,
        tokenizer,
        encoder: torch.nn.Module,
        mention_detector: torch.nn.Module,
        mention_labeler: torch.nn.Module | None = None,
        label2id: dict | None = None,
        lr: float = 2e-5,
        threshold: float = 0.5,
    ):
        """
        Args:
            tokenizer: HF tokenizer matching ``encoder``.
            encoder: produces word embeddings; frozen during training.
            mention_detector: scores start tokens and (start, end) pairs.
            mention_labeler: optional entity classifier over span pairs.
            label2id: label-name -> class-id mapping (required with labeler).
            lr: AdamW learning rate.
            threshold: sigmoid probability cut-off for start/span decisions.

        Raises:
            ValueError: if ``mention_labeler`` is given without ``label2id``.
        """
        super().__init__()
        # Only keep scalar/config hyperparameters. The tokenizer and the
        # sub-modules do not belong in hparams (previously the tokenizer
        # object was serialized into checkpoint hparams); module weights are
        # still saved via the regular state_dict.
        self.save_hyperparameters(
            ignore=["tokenizer", "encoder", "mention_detector", "mention_labeler"]
        )
        self.tokenizer = tokenizer
        self.encoder = encoder
        # The encoder acts as a frozen feature extractor.
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.mention_detector = mention_detector
        self.mention_labeler = mention_labeler
        self.label2id = label2id
        # Per-element BCE so padding can be masked out before reduction.
        self.loss_fn = torch.nn.BCEWithLogitsLoss(reduction="none")

        self.val_f1_start = torchmetrics.classification.BinaryF1Score()
        self.val_f1_end = torchmetrics.classification.BinaryF1Score()
        self.val_f1_mention = torchmetrics.classification.BinaryF1Score()

        if mention_labeler is not None:
            if label2id is None:
                raise ValueError("Need label2id!")
            num_classes = len(self.label2id)
            self.val_f1_entity_start = torchmetrics.classification.BinaryF1Score()
            self.val_f1_entity_end = torchmetrics.classification.BinaryF1Score()
            self.val_f1_entity_mention = torchmetrics.classification.BinaryF1Score()
            self.val_f1_entity_labels = torchmetrics.classification.MulticlassF1Score(
                num_classes=num_classes,
                average="macro",
            )
            self.entity_loss = torch.nn.CrossEntropyLoss()
            # Scale the entity loss so a uniform random guess over
            # num_classes labels (expected CE: log num_classes) has roughly
            # the same magnitude as a random binary decision (expected BCE:
            # log 2). Guard num_classes <= 1, where log(num_classes) <= 0
            # would divide by zero; fall back to weight 1.0.
            if num_classes > 1:
                self.entity_weight = math.log(2.0) / math.log(num_classes)
            else:
                self.entity_weight = 1.0

    def encode(self, docs: list[list[str]]):
        """Tokenize pre-split docs and return per-word embeddings.

        Handles the non-vectorized tokenization and calls the vectorized
        encoder.

        Args:
            docs: batch of sentences, each a list of words.

        Returns:
            (batch, max_words, hidden_dim) word embeddings.
        """
        device = next(self.parameters()).device
        inputs = self.tokenizer(
            docs,
            is_split_into_words=True,
            return_tensors="pt",
            truncation=True,
            max_length=self.encoder.max_length,
            padding=True,
            return_attention_mask=True,
        )
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        # Map each subword to its word index; special/padding tokens get -1
        # so the encoder's pooling ignores them.
        batch_word_ids = []
        for i in range(len(docs)):
            w_ids = [w if w is not None else -1 for w in inputs.word_ids(batch_index=i)]
            batch_word_ids.append(torch.tensor(w_ids))

        word_ids_tensor = torch.stack(batch_word_ids).to(device)
        word_embeddings = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask, word_ids=word_ids_tensor
        )
        return word_embeddings

    def forward_detector(self, emb: torch.Tensor):
        """Run the span detector; returns (start_logits, end_logits)."""
        start_logits, end_logits = self.mention_detector(emb)
        return start_logits, end_logits

    def forward_labeler(self, emb: torch.Tensor):
        """Run the entity labeler; returns (B, N, N, num_classes) logits."""
        entity_logits = self.mention_labeler(emb)
        return entity_logits

    def _compute_start_loss(self, start_logits, batch):
        """Masked mean BCE over per-token start predictions."""
        targets = batch["starts"].float()
        mask = batch["token_mask"].bool()
        return self.loss_fn(start_logits, targets)[mask].mean()

    def _compute_end_loss(self, end_logits, batch):
        """Masked mean BCE over (start, end) pair predictions."""
        targets = batch["spans"].float()
        mask = batch["span_loss_mask"].bool()
        raw_loss = self.loss_fn(end_logits, targets)
        relevant_loss = raw_loss[mask]

        if relevant_loss.numel() == 0:
            # Zero loss that stays connected to the graph so backward works.
            return end_logits.sum() * 0
        return relevant_loss.mean()

    def _compute_entity_loss(self, entity_logits, batch):
        """Cross-entropy over gold spans only.

        Args:
            entity_logits: [batch, max_len, max_len, num_classes].
            batch: must carry ``gold_labels``, a list (one per sentence) of
                {(start, end): label_str} dicts.
        """
        preds = []
        targets = []

        for b, labels_dict in enumerate(batch["gold_labels"]):
            for (s, e), label_str in labels_dict.items():
                # Gold spans can fall outside max_len after truncation; skip.
                if s < entity_logits.size(1) and e < entity_logits.size(2):
                    label_id = self.label2id[label_str]
                    preds.append(entity_logits[b, s, e])
                    targets.append(label_id)

        if not targets:
            # Graph-connected zero loss (see _compute_end_loss).
            return entity_logits.sum() * 0

        preds_tensor = torch.stack(preds)
        targets_tensor = torch.tensor(targets, device=entity_logits.device)

        return self.entity_loss(preds_tensor, targets_tensor)

    def training_step(self, batch, batch_idx):
        """Detection loss for every batch; entity loss added for task 1."""
        emb = self.encode(batch["sentences"])
        start_logits, end_logits = self.forward_detector(emb)
        loss_start = self._compute_start_loss(start_logits, batch)
        loss_end = self._compute_end_loss(end_logits, batch)
        total_loss = loss_start + loss_end
        log_metrics = {
            "train_start_loss": loss_start,
            "train_end_loss": loss_end,
        }
        # task_id is assumed homogeneous within a batch; the first entry
        # decides whether the labeling head is trained.
        if batch["task_id"][0] == 1:
            entity_logits = self.forward_labeler(emb)
            loss_entity = self._compute_entity_loss(entity_logits, batch)
            log_metrics["train_entity_loss"] = loss_entity
            total_loss = total_loss + self.entity_weight * loss_entity

        log_metrics["train_loss"] = total_loss
        self.log_dict(log_metrics, prog_bar=True)
        return total_loss

    def validation_step(self, batch, batch_idx):
        """Update task-specific F1 metrics and log the detection loss."""
        emb = self.encode(batch["sentences"])
        start_logits, end_logits = self.forward_detector(emb)

        token_mask = batch["token_mask"].bool()
        span_loss_mask = batch["span_loss_mask"].bool()

        is_start = (torch.sigmoid(start_logits) > self.hparams.threshold).int()
        is_end = (torch.sigmoid(end_logits) > self.hparams.threshold).int()

        # A (start, end) pair is only evaluated when both tokens are real
        # and start <= end (upper triangle).
        valid_pair_mask = token_mask.unsqueeze(2) & token_mask.unsqueeze(1)
        upper_tri = torch.triu(torch.ones_like(end_logits), diagonal=0).bool()
        mention_eval_mask = valid_pair_mask & upper_tri

        # A predicted mention needs both a predicted start and a predicted
        # span cell.
        pred_spans = (is_start.unsqueeze(2) & is_end)[mention_eval_mask]
        target_spans = batch["spans"][mention_eval_mask].int()

        log_stats = {}

        if batch["task_id"][0] == 0:
            # Task 0: plain mention detection.
            if token_mask.any():
                self.val_f1_start.update(is_start[token_mask], batch["starts"][token_mask].int())

            if span_loss_mask.any():
                self.val_f1_end.update(is_end[span_loss_mask], batch["spans"][span_loss_mask].int())

            if mention_eval_mask.any():
                self.val_f1_mention.update(pred_spans, target_spans)

            log_stats["val_f1_mention"] = self.val_f1_mention

        elif batch["task_id"][0] == 1:
            # Task 1: entity mentions — separate metric set.
            if token_mask.any():
                self.val_f1_entity_start.update(is_start[token_mask], batch["starts"][token_mask].int())

            if span_loss_mask.any():
                self.val_f1_entity_end.update(is_end[span_loss_mask], batch["spans"][span_loss_mask].int())

            if mention_eval_mask.any():
                self.val_f1_entity_mention.update(pred_spans, target_spans)

            log_stats["val_f1_entity_mention"] = self.val_f1_entity_mention

            if self.mention_labeler is not None:
                # Label accuracy is measured on gold spans only, decoupling
                # it from span-detection quality.
                entity_logits = self.forward_labeler(emb)
                gold_preds, gold_targets = [], []

                for b, labels_dict in enumerate(batch["gold_labels"]):
                    for (s, e), label_str in labels_dict.items():
                        if s < entity_logits.size(1) and e < entity_logits.size(2):
                            gold_preds.append(torch.argmax(entity_logits[b, s, e], dim=-1))
                            gold_targets.append(self.label2id[label_str])

                if gold_targets:
                    self.val_f1_entity_labels.update(
                        torch.stack(gold_preds),
                        torch.tensor(gold_targets, device=emb.device)
                    )
                    log_stats["val_f1_entity_labels"] = self.val_f1_entity_labels

        loss_start = self._compute_start_loss(start_logits, batch)
        loss_end = self._compute_end_loss(end_logits, batch)
        log_stats["val_loss"] = loss_start + loss_end

        self.log_dict(log_stats, prog_bar=True, on_epoch=True, batch_size=len(batch["sentences"]))

    @torch.no_grad()
    def predict_mentions(
        self, sentences: list[list[str]], batch_size: int = 2
    ) -> list[list[tuple[int, int]]]:
        """Predict mention spans as (start, end) word indices, end inclusive.

        Args:
            sentences: pre-split sentences (list of words each).
            batch_size: number of sentences encoded per forward pass.

        Returns:
            One list of (start, end) tuples per input sentence.
        """
        self.eval()
        all_results = []
        thresh = self.hparams.threshold
        for i in range(0, len(sentences), batch_size):
            batch_sentences = sentences[i:i + batch_size]
            emb = self.encode(batch_sentences)
            start_logits, end_logits = self.forward_detector(emb)
            is_start = torch.sigmoid(start_logits) > thresh
            is_span = torch.sigmoid(end_logits) > thresh

            # Only spans with start <= end are valid.
            N = end_logits.size(1)
            upper_tri = torch.triu(
                torch.ones((N, N), device=self.device), diagonal=0
            ).bool()
            pred_mask = is_start.unsqueeze(2) & is_span & upper_tri

            indices = pred_mask.nonzero()

            batch_results = [[] for _ in range(len(batch_sentences))]
            for b_idx, s_idx, e_idx in indices:
                b, s, e = b_idx.item(), s_idx.item(), e_idx.item()
                # Padded word positions of shorter sentences in the batch
                # can still score above threshold; drop spans that end
                # outside the actual sentence.
                if e < len(batch_sentences[b]):
                    batch_results[b].append((s, e))

            all_results.extend(batch_results)

        return all_results

    def test_step(self, batch, batch_idx):
        """Test shares the validation logic and metrics."""
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        """AdamW over all (non-frozen) parameters."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
|
|
|
|
@ModelRegistry.register("model_v1")
def make_model_v1(data: DataBlob, model_name="distilroberta-base"):
    """Build a mention-detection-only model (no entity labeling).

    Args:
        data: dataset blob (unused here; kept for a uniform factory signature).
        model_name: HF model id for the sentence encoder.
    """
    encoder = SentenceEncoder(model_name).train()
    # Size the heads from the encoder instead of hard-coding 768, so
    # backbones with other hidden sizes also work.
    dim = encoder.dim
    # Reuse the tokenizer the encoder already loaded rather than fetching
    # the same pretrained tokenizer a second time.
    tokenizer = encoder.tokenizer
    start_detector = Detector(dim, dim)
    end_detector = Detector(dim * 2, dim)
    mention_detector = MentionDetectorCore(start_detector, end_detector)
    return LitMentionDetector(tokenizer, encoder, mention_detector)
|
|
|
|
@ModelRegistry.register("model_v2")
def make_model_v2(data: DataBlob, model_name="distilroberta-base"):
    """Build a mention detector with an entity-labeling head.

    Args:
        data: dataset blob providing ``label2id`` for the labeler.
        model_name: HF model id for the sentence encoder.
    """
    label2id = data.label2id
    encoder = SentenceEncoder(model_name).train()
    # Size the heads from the encoder instead of hard-coding 768, so
    # backbones with other hidden sizes also work.
    dim = encoder.dim
    # Reuse the tokenizer the encoder already loaded rather than fetching
    # the same pretrained tokenizer a second time.
    tokenizer = encoder.tokenizer
    start_detector = Detector(dim, dim)
    end_detector = Detector(dim * 2, dim)
    classifier = Detector(dim * 2, dim, num_classes=len(label2id))
    mention_detector = MentionDetectorCore(start_detector, end_detector)
    mention_labeler = MentionLabeler(classifier)
    return LitMentionDetector(
        tokenizer,
        encoder,
        mention_detector,
        mention_labeler,
        label2id,
    )
|
|