| """ |
| model.py β GlobalPointer-based NER model on top of BERT |
| |
| Changes vs previous version: |
| [FIX-1] Circle Loss: correct two-term formulation (Su Jianlin style), |
| with margin (m) and scale (gamma) params; no more logaddexp merging. |
| [FIX-2] Numerical safety: negated pos_logits no longer turns -1e9 β +1e9; |
| we apply the mask BEFORE negation. |
| [FIX-3] labels .float() cast inside forward (no silent runtime error / nan). |
| [FIX-4] valid_mask (bool, BΓL) replaces attention_mask for span masking; |
| attention_mask is still passed to the encoder for self-attention. |
| [FIX-5] use_rope flag for GlobalPointer's span-level RoPE (independent of |
| BERT encoder internals). |
| """ |
|
|
| import json |
| from pathlib import Path |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import AutoModel |
|
|
|
|
| |
| |
| |
| |
| |
| |
|
|
class EfficientGlobalPointer(nn.Module):
    """
    EfficientGlobalPointer span scorer (Su Jianlin style).

    Unlike the standard GlobalPointer, the q/k projection is shared across
    all labels (hidden -> 2 * head_size); label specificity comes from a
    cheap per-token bias head (hidden -> 2 * num_labels) that produces a
    start bias and an end bias for every label.  The final score is
    (q @ k^T) / sqrt(D) broadcast over the C labels, plus the two biases.

    Output shape: (B, C, L, L)
    """

    def __init__(
        self,
        hidden_size: int,
        num_labels: int,
        head_size: int = 64,
        use_rope: bool = True,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.num_labels = num_labels
        self.head_size = head_size
        self.use_rope = use_rope

        self.dropout = nn.Dropout(dropout)
        # Shared query/key projection (label-agnostic).
        self.dense_qk = nn.Linear(hidden_size, head_size * 2)
        # Per-label start/end bias projection.
        self.dense_bias = nn.Linear(hidden_size, num_labels * 2)

        if use_rope:
            self.rope = RotaryEmbedding(head_size)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden: encoder output, shape (B, L, H)

        Returns:
            span logits, shape (B, C, L, L)
        """
        batch, seq_len, _ = hidden.shape
        n_labels, dim = self.num_labels, self.head_size

        hidden = self.dropout(hidden)

        # One shared projection, split in half into q and k.
        projected = self.dense_qk(hidden)
        q, k = projected[..., :dim], projected[..., dim:]

        if self.use_rope:
            angles = self.rope(seq_len, hidden.device)
            cos_t = angles.cos()[None, :, :]
            sin_t = angles.sin()[None, :, :]
            q = apply_rotary(q, cos_t, sin_t)
            k = apply_rotary(k, cos_t, sin_t)

        # Label-agnostic pairwise scores, scaled dot product: (B, L, L).
        base = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(dim)

        # (B, L, C, 2): per-token start/end bias for each label.
        bias = self.dense_bias(hidden).view(batch, seq_len, n_labels, 2)
        start_bias = bias[..., 0].permute(0, 2, 1)  # (B, C, L)
        end_bias = bias[..., 1].permute(0, 2, 1)    # (B, C, L)

        # Broadcast base over labels; add the start bias along the row axis
        # (span start) and the end bias along the column axis (span end).
        return (
            base[:, None, :, :]
            + start_bias[:, :, :, None]
            + end_bias[:, :, None, :]
        )
|
|
| |
| |
| |
|
|
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (angle table) for GlobalPointer span scoring."""

    def __init__(self, dim: int):
        super().__init__()
        assert dim % 2 == 0, "RoPE dim must be even"
        # Standard RoPE frequencies: one per pair of channels.
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def forward(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Return the per-position angle table, shape (seq_len, dim).

        The dim/2 base frequencies are repeated twice along the last axis,
        matching the rotate_half convention used by apply_rotary.
        """
        positions = torch.arange(seq_len, device=device).float()
        angles = torch.outer(positions, self.inv_freq)
        return torch.cat([angles, angles], dim=-1)
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map (..., [a, b]) to (..., [-b, a]), where a/b are the two halves of the last axis."""
    mid = x.shape[-1] // 2
    return torch.cat([-x[..., mid:], x[..., :mid]], dim=-1)
|
|
|
|
def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply a rotary position rotation to x.

    x: (..., L, D); cos/sin: angle tables broadcastable against x (e.g. (L, D)).
    """
    # Inline rotate-half: (..., [a, b]) -> (..., [-b, a]).
    half = x.shape[-1] // 2
    rotated = torch.cat([-x[..., half:], x[..., :half]], dim=-1)
    return x * cos + rotated * sin
|
|
|
|
| |
| |
| |
|
|
def multilabel_circle_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    mask2d: torch.Tensor,
    margin: float = 0.25,
    gamma: float = 32.0,
) -> torch.Tensor:
    """
    Su Jianlin-style Circle Loss for multi-label span classification.

        L = log(1 + sum exp(gamma * (s_neg + m)))
          + log(1 + sum exp(-gamma * (s_pos - m)))

    The two logsumexp terms are kept independent (no logaddexp merging), and
    invalid spans are removed with -inf fills so that no finite sentinel
    value can be sign-flipped into a huge positive score.

    Args:
        logits: raw span scores, shape (B, C, L, L)
        labels: float tensor in {0, 1}, same shape as logits
        mask2d: bool (B, 1, L, L) - True where the span is valid
                (upper triangle restricted to valid tokens)
        margin: additive margin m (default 0.25)
        gamma:  temperature / scale (default 32)

    Returns:
        scalar loss averaged over all (batch, label) cells
    """
    B, C, L, _ = logits.shape
    neg_inf = float("-inf")

    full_mask = mask2d.expand(B, C, L, L)
    pos_mask = full_mask & (labels > 0.5)
    neg_mask = full_mask & (labels < 0.5)

    scaled = logits * gamma

    # Negative term: logsumexp over valid negative spans; softplus folds in
    # the leading "1 +", and gamma * margin adds the margin inside the exp.
    neg_term = torch.logsumexp(
        scaled.masked_fill(~neg_mask, neg_inf).view(B, C, -1), dim=-1
    )
    loss_neg = F.softplus(neg_term + gamma * margin)

    # Positive term: identical shape, but over the NEGATED scores of the
    # positive spans (negate the raw scores, then mask, so no -inf sentinel
    # is ever negated).
    pos_term = torch.logsumexp(
        (-scaled).masked_fill(~pos_mask, neg_inf).view(B, C, -1), dim=-1
    )
    loss_pos = F.softplus(pos_term + gamma * margin)

    # An all-masked cell yields logsumexp = -inf and softplus(-inf) = 0,
    # so empty positive/negative sets contribute nothing.
    return (loss_neg + loss_pos).mean()
|
|
|
|
def multilabel_bce_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    mask2d: torch.Tensor,
) -> torch.Tensor:
    """Masked mean binary cross-entropy (with logits) over valid spans.

    logits/labels: (B, C, L, L); mask2d: bool (B, 1, L, L), broadcast over C.
    The per-span losses are averaged over valid spans only; the denominator
    is clamped to 1 so an all-invalid batch cannot divide by zero.
    """
    span_mask = mask2d.expand_as(logits).float()
    per_span = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
    total = (per_span * span_mask).sum()
    return total / span_mask.sum().clamp(min=1)
|
|
|
|
| |
| |
| |
|
|
class GlobalPointer(nn.Module):
    """
    GlobalPointer span scorer.

    Every label gets its own (q, k) projection of the encoder hidden states;
    the per-label score matrix is the scaled dot product q_i . k_j over all
    (start=i, end=j) pairs.  Span-level RoPE is optional and completely
    independent of any rotary mechanism inside the encoder's own
    self-attention layers - both can be active at once.

    Output shape: (B, C, L, L)
    """

    def __init__(
        self,
        hidden_size: int,
        num_labels: int,
        head_size: int = 64,
        use_rope: bool = True,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.num_labels = num_labels
        self.head_size = head_size
        self.use_rope = use_rope

        self.dropout = nn.Dropout(dropout)
        # One (q, k) pair of width head_size per label.
        self.dense = nn.Linear(hidden_size, num_labels * head_size * 2)

        if use_rope:
            self.rope = RotaryEmbedding(head_size)

    def forward(
        self,
        hidden: torch.Tensor,
    ) -> torch.Tensor:
        """hidden: (B, L, H) -> span logits: (B, C, L, L)"""
        batch, seq_len, _ = hidden.shape
        n_labels, dim = self.num_labels, self.head_size

        hidden = self.dropout(hidden)
        proj = self.dense(hidden).view(batch, seq_len, n_labels, dim * 2)
        q, k = proj[..., :dim], proj[..., dim:]

        if self.use_rope:
            angles = self.rope(seq_len, hidden.device)
            cos_t = angles.cos()[None, :, None, :]
            sin_t = angles.sin()[None, :, None, :]
            q = apply_rotary(q, cos_t, sin_t)
            k = apply_rotary(k, cos_t, sin_t)

        # (B, L, C, D) -> (B, C, L, D): labels become a batch dimension.
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)

        # Scaled pairwise scores per label: (B, C, L, L).
        return torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(dim)
|
|
|
|
| |
| |
| |
|
|
| class EcomBertNER(nn.Module): |
| """ |
| BERT encoder + GlobalPointer head for span-based NER. |
| |
| forward() signature: |
| input_ids (B, L) β token ids |
| attention_mask (B, L) β passed to encoder (1=real, 0=pad) |
| labels (B, C, L, L) torch.bool, optional |
| valid_mask (B, L) torch.bool, optional β True = valid token |
| (excludes CLS/SEP/PAD; from dataset collate_fn) |
| |
| If valid_mask is not provided, falls back to attention_mask.bool() |
| (slightly less precise β includes CLS/SEP as negative spans). |
| """ |
|
|
| def __init__( |
| self, |
| model_name: str = "bert-base-chinese", |
| num_labels: int = 23, |
| head_size: int = 64, |
| loss_type: str = "circle", |
| use_rope: bool = True, |
| dropout: float = 0.1, |
| cache_dir: str = None, |
| |
| circle_margin: float = 0.25, |
| circle_gamma: float = 32.0, |
| ): |
| super().__init__() |
| assert loss_type in ("circle", "bce"), \ |
| f"loss_type must be 'circle' or 'bce', got {loss_type!r}" |
|
|
| self.loss_type = loss_type |
| self.circle_margin = circle_margin |
| self.circle_gamma = circle_gamma |
|
|
| self.encoder = AutoModel.from_pretrained( |
| model_name, cache_dir=cache_dir |
| ) |
| hidden_size = self.encoder.config.hidden_size |
|
|
| self.global_pointer = EfficientGlobalPointer( |
| hidden_size = hidden_size, |
| num_labels = num_labels, |
| head_size = head_size, |
| use_rope = use_rope, |
| dropout = dropout, |
| ) |
|
|
| self.model_name = model_name |
| self.num_labels = num_labels |
| self.head_size = head_size |
| self.use_rope = use_rope |
| self.dropout = dropout |
|
|
| |
|
|
| @staticmethod |
| def _build_span_mask( |
| valid_mask: torch.Tensor, |
| ) -> torch.Tensor: |
| """ |
| Returns upper-triangular span mask (B, 1, L, L) where |
| mask[b,0,i,j] = True iff i<=j and both token i and j are valid. |
| """ |
| |
| row = valid_mask[:, None, :, None] |
| col = valid_mask[:, None, None, :] |
| pair_mask = row & col |
|
|
| L = valid_mask.size(1) |
| upper_tri = torch.triu( |
| torch.ones(L, L, dtype=torch.bool, device=valid_mask.device) |
| ) |
|
|
| return pair_mask & upper_tri |
|
|
| |
|
|
| def forward( |
| self, |
| input_ids: torch.Tensor, |
| attention_mask: torch.Tensor, |
| labels: torch.Tensor = None, |
| valid_mask: torch.Tensor = None, |
| ) -> dict: |
| |
| encoder_out = self.encoder( |
| input_ids = input_ids, |
| attention_mask = attention_mask, |
| ) |
| hidden = encoder_out.last_hidden_state |
|
|
| |
| logits = self.global_pointer(hidden) |
|
|
| |
| |
| if valid_mask is None: |
| valid_mask = attention_mask.bool() |
|
|
| mask2d = self._build_span_mask(valid_mask) |
|
|
| |
| logits_masked = logits.masked_fill( |
| ~mask2d.expand_as(logits), -1e4 |
| ) |
|
|
| |
| loss = None |
| if labels is not None: |
| |
| labels_f = labels.float() |
|
|
| if self.loss_type == "circle": |
| loss = multilabel_circle_loss( |
| logits = logits, |
| labels = labels_f, |
| mask2d = mask2d, |
| margin = self.circle_margin, |
| gamma = self.circle_gamma, |
| ) |
| else: |
| loss = multilabel_bce_loss( |
| logits = logits, |
| labels = labels_f, |
| mask2d = mask2d, |
| ) |
|
|
| return { |
| "loss": loss, |
| "logits": logits_masked, |
| } |
|
|
| def save_pretrained(self, save_directory: str | Path, *, extra_config: dict | None = None) -> None: |
| save_dir = Path(save_directory) |
| save_dir.mkdir(parents=True, exist_ok=True) |
|
|
| config = { |
| "architectures": [self.__class__.__name__], |
| "model_name": self.model_name, |
| "num_labels": self.num_labels, |
| "head_size": self.head_size, |
| "loss_type": self.loss_type, |
| "use_rope": self.use_rope, |
| "dropout": self.dropout, |
| "circle_margin": self.circle_margin, |
| "circle_gamma": self.circle_gamma, |
| } |
| if extra_config: |
| config.update(extra_config) |
|
|
| with open(save_dir / "config.json", "w", encoding="utf-8") as f: |
| json.dump(config, f, indent=2, ensure_ascii=False) |
|
|
| torch.save(self.state_dict(), save_dir / "pytorch_model.bin") |
|
|
| @classmethod |
| def from_pretrained( |
| cls, |
| model_dir: str | Path, |
| *, |
| device: torch.device | str | None = None, |
| cache_dir: str | None = None, |
| ) -> tuple["EcomBertNER", dict]: |
| model_dir = Path(model_dir) |
| with open(model_dir / "config.json", "r", encoding="utf-8") as f: |
| cfg = json.load(f) |
|
|
| model = cls( |
| model_name=cfg.get("model_name", "bert-base-chinese"), |
| num_labels=int(cfg.get("num_labels", 23)), |
| head_size=int(cfg.get("head_size", 64)), |
| loss_type=str(cfg.get("loss_type", "circle")), |
| use_rope=bool(cfg.get("use_rope", True)), |
| dropout=float(cfg.get("dropout", 0.1)), |
| cache_dir=cache_dir, |
| circle_margin=float(cfg.get("circle_margin", 0.25)), |
| circle_gamma=float(cfg.get("circle_gamma", 32.0)), |
| ) |
| state = torch.load(model_dir / "pytorch_model.bin", map_location="cpu", weights_only=False) |
| model.load_state_dict(state) |
| if device is not None: |
| model.to(device) |
| model.eval() |
| return model, cfg |