"""Multi-label span head on top of a Hugging Face encoder backbone."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.utils import ModelOutput


def hidden_size_from_config(config) -> int:
    """Return the encoder width, falling back to `dim` for configs (e.g. DistilBERT) that use that name."""
    hidden_size = getattr(config, "hidden_size", None)
    if hidden_size is None:
        hidden_size = getattr(config, "dim", None)
    if hidden_size is None:
        raise ValueError("Config defines neither `hidden_size` nor `dim`.")
    return int(hidden_size)
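

# Illustrative note (assumes transformers' AutoConfig.for_model helper; not used by
# this module itself): DistilBERT configs call the hidden width `dim`, so
# hidden_size_from_config(AutoConfig.for_model("distilbert")) returns 768, while
# BERT-style configs resolve `hidden_size` directly.
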
@dataclass
class MultilabelSpanOutput(ModelOutput):
    """Output with per-label start/end logits of shape (batch, seq_len, num_span_labels)."""

    loss: Optional[torch.Tensor] = None
    start_logits: Optional[torch.Tensor] = None
    end_logits: Optional[torch.Tensor] = None


class IrishCoreSpanHeadModel(PreTrainedModel):
    """Encoder backbone plus independent per-label start/end token classifiers."""

    # AutoConfig lets the head wrap any backbone; AutoModel resolves the concrete encoder.
    config_class = AutoConfig
    base_model_prefix = "encoder"

    def __init__(self, config):
        super().__init__(config)
        if not hasattr(config, "num_span_labels"):
            raise ValueError("Config must define `num_span_labels`.")
        num_span_labels = int(config.num_span_labels)
        self.encoder = AutoModel.from_config(config)
        hidden_size = hidden_size_from_config(config)
        # Prefer DistilBERT's `seq_classif_dropout`, then a generic `dropout`, then 0.1.
        dropout = float(getattr(config, "seq_classif_dropout", getattr(config, "dropout", 0.1)))
        self.dropout = nn.Dropout(dropout)
        self.start_classifier = nn.Linear(hidden_size, num_span_labels)
        self.end_classifier = nn.Linear(hidden_size, num_span_labels)
        # Span boundaries are sparse per token, so positives are up-weighted in the BCE loss.
        pos_weight = float(getattr(config, "span_positive_weight", 6.0))
        self.register_buffer("loss_pos_weight", torch.full((num_span_labels,), pos_weight), persistent=False)
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        start_positions=None,
        end_positions=None,
        token_mask=None,
        **kwargs,
    ) -> MultilabelSpanOutput:
        """Score every (token, label) pair as a span start and a span end.

        `start_positions` and `end_positions` are multi-hot targets of shape
        (batch, seq_len, num_span_labels); when both are given, a masked
        BCE-with-logits loss is returned alongside the logits.
        """
        encoder_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            **kwargs,
        }
        # DistilBERT's forward has no token_type_ids argument and RoBERTa ignores segment ids.
        if token_type_ids is not None and getattr(self.config, "model_type", "") not in {"distilbert", "roberta"}:
            encoder_kwargs["token_type_ids"] = token_type_ids
        outputs = self.encoder(**encoder_kwargs)
        hidden = self.dropout(outputs.last_hidden_state)
        # Independent heads: one start score and one end score per label at each position.
        start_logits = self.start_classifier(hidden)
        end_logits = self.end_classifier(hidden)

        loss = None
        if start_positions is not None and end_positions is not None:
            if token_mask is None:
                token_mask = attention_mask
            if token_mask is None:
                # Neither mask was supplied: treat every position as valid.
                token_mask = torch.ones(start_logits.shape[:2], device=start_logits.device)
            # Broadcast the (batch, seq_len) mask across the label dimension.
            mask = token_mask.float().unsqueeze(-1)
            pos_weight = self.loss_pos_weight.to(start_logits.device)
            bce = nn.BCEWithLogitsLoss(reduction="none", pos_weight=pos_weight)
            start_loss = bce(start_logits, start_positions.float()) * mask
            end_loss = bce(end_logits, end_positions.float()) * mask
            # Mean over unmasked (token, label) pairs; clamp_min guards against an all-zero mask.
            denom = mask.sum().clamp_min(1.0) * start_logits.shape[-1]
            loss = (start_loss.sum() + end_loss.sum()) / (2.0 * denom)

        return MultilabelSpanOutput(loss=loss, start_logits=start_logits, end_logits=end_logits)
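

if __name__ == "__main__":
    # Minimal smoke-test sketch. Assumptions: the backbone type ("distilbert") and
    # the label count (4) are illustrative only, and the encoder is randomly
    # initialised via AutoModel.from_config rather than loaded from a checkpoint.
    config = AutoConfig.for_model("distilbert", num_span_labels=4)
    model = IrishCoreSpanHeadModel(config).eval()

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

    # Multi-hot targets: a start/end indicator per label at every token position,
    # here marking a single hypothetical span for label 1 from token 3 to token 5.
    start_positions = torch.zeros(batch_size, seq_len, 4)
    end_positions = torch.zeros(batch_size, seq_len, 4)
    start_positions[0, 3, 1] = 1.0
    end_positions[0, 5, 1] = 1.0

    with torch.no_grad():
        out = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            start_positions=start_positions,
            end_positions=end_positions,
        )
    print(out.start_logits.shape)  # torch.Size([2, 16, 4])
    print(out.loss)  # scalar masked, pos-weighted BCE loss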