Upload 3 files
Browse files- __init__.py +0 -0
- configuration_bert_ordinal.py +107 -0
- modeling_bert_ordinal.py +282 -0
__init__.py
ADDED
|
File without changes
|
configuration_bert_ordinal.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
configuration_bert_ordinal.py
|
| 3 |
+
---------------
|
| 4 |
+
BERT-based ordinal regression model, fully integrated with the HuggingFace
|
| 5 |
+
Transformers API:
|
| 6 |
+
|
| 7 |
+
model.save_pretrained("my-checkpoint/")
|
| 8 |
+
model = BertOrdinal.from_pretrained("my-checkpoint/")
|
| 9 |
+
|
| 10 |
+
Architecture
|
| 11 |
+
------------
|
| 12 |
+
1. A (optionally frozen) BERT backbone.
|
| 13 |
+
2. A projection head on the [CLS] token:
|
| 14 |
+
Linear(hidden_size β hidden_dim) β ReLU β Dropout(p) β Linear(hidden_dim β 1)
|
| 15 |
+
producing a single latent score s β β.
|
| 16 |
+
3. K-1 learnable raw_threshold parameters enforcing monotonicity via
|
| 17 |
+
cumsum(softplus(Β·)).
|
| 18 |
+
4. Cumulative-link probabilities:
|
| 19 |
+
P(Y β€ j | x) = Ο(ΞΈ_j β s)
|
| 20 |
+
|
| 21 |
+
Usage
|
| 22 |
+
-----
|
| 23 |
+
from bert_ordinal import BertOrdinalConfig, BertOrdinal
|
| 24 |
+
|
| 25 |
+
# ββ Create from scratch ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 26 |
+
cfg = BertOrdinalConfig(
|
| 27 |
+
bert_model_name="bert-base-uncased",
|
| 28 |
+
num_classes=3,
|
| 29 |
+
hidden_dim=128,
|
| 30 |
+
dropout=0.1,
|
| 31 |
+
freeze_bert=True,
|
| 32 |
+
)
|
| 33 |
+
model = BertOrdinal(cfg)
|
| 34 |
+
|
| 35 |
+
# ββ Save ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 36 |
+
model.save_pretrained("my-checkpoint/")
|
| 37 |
+
tokenizer.save_pretrained("my-checkpoint/") # keep tokenizer alongside
|
| 38 |
+
|
| 39 |
+
# ββ Reload ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 40 |
+
model = BertOrdinal.from_pretrained("my-checkpoint/")
|
| 41 |
+
tokenizer = AutoTokenizer.from_pretrained("my-checkpoint/")
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
from __future__ import annotations
|
| 45 |
+
from typing import Optional
|
| 46 |
+
from transformers import PretrainedConfig
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
# 1. Config β subclass PretrainedConfig for full HF serialisation
|
| 51 |
+
# ---------------------------------------------------------------------------
|
| 52 |
+
|
| 53 |
+
class BertOrdinalConfig(PretrainedConfig):
    """
    Serialisable configuration for :class:`BertOrdinal`.

    Because this subclasses :class:`~transformers.PretrainedConfig`,
    ``save_pretrained`` emits a ``config.json`` that ``from_pretrained``
    can read back with no extra bookkeeping.

    Parameters
    ----------
    bert_model_name : str
        HuggingFace hub name or local path of the BERT backbone.
    num_classes : int
        Number of ordinal classes K; the model allocates K-1 thresholds.
    hidden_dim : int
        Width of the intermediate layer in the projection head.
    dropout : float
        Dropout probability applied inside the projection head.
    freeze_bert : bool
        If True, backbone parameters are frozen at construction time.
    loss_reduction : str
        ``'mean'`` or ``'sum'``.
    hidden_size : int, optional
        Backbone hidden size. Left ``None`` at first construction; the
        model fills it in after loading BERT so ``from_pretrained`` can
        rebuild the head without re-inspecting the backbone.
    """

    # Written into config.json so AutoConfig knows which class owns it.
    model_type = "bert_ordinal"
    problem_type = "single_label_classification"

    def __init__(
        self,
        bert_model_name: str = "allenai/scibert_scivocab_uncased",
        num_classes: int = 3,
        hidden_dim: int = 256,
        dropout: float = 0.1,
        freeze_bert: bool = True,
        loss_reduction: str = "mean",
        hidden_size: Optional[int] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.bert_model_name = bert_model_name
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.freeze_bert = freeze_bert
        self.loss_reduction = loss_reduction
        # Filled in by BertOrdinal.__init__ once the backbone is loaded.
        self.hidden_size = hidden_size

        # Lets the Auto* factories resolve the custom classes when a saved
        # checkpoint is loaded from a repo containing these module files.
        self.auto_map = {
            "AutoConfig": "configuration_bert_ordinal.BertOrdinalConfig",
            "AutoModel": "modeling_bert_ordinal.BertOrdinal",
            "AutoModelForSequenceClassification": "modeling_bert_ordinal.BertOrdinal",
        }
|
modeling_bert_ordinal.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
modeling_bert_ordinal.py
|
| 3 |
+
---------------
|
| 4 |
+
BERT-based ordinal regression model, fully integrated with the HuggingFace
|
| 5 |
+
Transformers API:
|
| 6 |
+
|
| 7 |
+
model.save_pretrained("my-checkpoint/")
|
| 8 |
+
model = BertOrdinal.from_pretrained("my-checkpoint/")
|
| 9 |
+
|
| 10 |
+
Architecture
|
| 11 |
+
------------
|
| 12 |
+
1. A (optionally frozen) BERT backbone.
|
| 13 |
+
2. A projection head on the [CLS] token:
|
| 14 |
+
Linear(hidden_size β hidden_dim) β ReLU β Dropout(p) β Linear(hidden_dim β 1)
|
| 15 |
+
producing a single latent score s β β.
|
| 16 |
+
3. K-1 learnable raw_threshold parameters enforcing monotonicity via
|
| 17 |
+
cumsum(softplus(Β·)).
|
| 18 |
+
4. Cumulative-link probabilities:
|
| 19 |
+
P(Y β€ j | x) = Ο(ΞΈ_j β s)
|
| 20 |
+
|
| 21 |
+
Usage
|
| 22 |
+
-----
|
| 23 |
+
from bert_ordinal import BertOrdinalConfig, BertOrdinal
|
| 24 |
+
|
| 25 |
+
# ββ Create from scratch ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 26 |
+
cfg = BertOrdinalConfig(
|
| 27 |
+
bert_model_name="bert-base-uncased",
|
| 28 |
+
num_classes=3,
|
| 29 |
+
hidden_dim=128,
|
| 30 |
+
dropout=0.1,
|
| 31 |
+
freeze_bert=True,
|
| 32 |
+
)
|
| 33 |
+
model = BertOrdinal(cfg)
|
| 34 |
+
|
| 35 |
+
# ββ Save ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 36 |
+
model.save_pretrained("my-checkpoint/")
|
| 37 |
+
tokenizer.save_pretrained("my-checkpoint/") # keep tokenizer alongside
|
| 38 |
+
|
| 39 |
+
# ββ Reload ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 40 |
+
model = BertOrdinal.from_pretrained("my-checkpoint/")
|
| 41 |
+
tokenizer = AutoTokenizer.from_pretrained("my-checkpoint/")
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
from __future__ import annotations
|
| 45 |
+
|
| 46 |
+
from dataclasses import dataclass
|
| 47 |
+
from typing import Optional
|
| 48 |
+
|
| 49 |
+
import torch
|
| 50 |
+
import torch.nn as nn
|
| 51 |
+
import torch.nn.functional as F
|
| 52 |
+
from transformers import AutoModel, PreTrainedModel
|
| 53 |
+
from transformers.modeling_outputs import ModelOutput
|
| 54 |
+
|
| 55 |
+
from configuration_bert_ordinal import BertOrdinalConfig
|
| 56 |
+
|
| 57 |
+
# ---------------------------------------------------------------------------
|
| 58 |
+
# 1. Output dataclass
|
| 59 |
+
# ---------------------------------------------------------------------------
|
| 60 |
+
|
| 61 |
+
@dataclass
class BertOrdinalOutput(ModelOutput):
    """
    Output container returned by :class:`BertOrdinal`.

    Attributes
    ----------
    loss : torch.Tensor or None
        Scalar ordinal cross-entropy; present only when ``labels`` were
        supplied to ``forward``.
    logits : torch.Tensor or None
        Latent score from the projection head, shape (B,).
    predictions : torch.Tensor or None
        Predicted class index per example (argmax of ``class_probs``),
        shape (B,).
    cum_probs : torch.Tensor or None
        Cumulative probabilities P(Y <= j | x), shape (B, K-1).
    class_probs : torch.Tensor or None
        Per-class probabilities P(Y = j | x), shape (B, K).
    """

    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    predictions: Optional[torch.Tensor] = None
    cum_probs: Optional[torch.Tensor] = None
    class_probs: Optional[torch.Tensor] = None
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# ---------------------------------------------------------------------------
|
| 89 |
+
# 3. Model β subclass PreTrainedModel for save / from_pretrained
|
| 90 |
+
# ---------------------------------------------------------------------------
|
| 91 |
+
|
| 92 |
+
class BertOrdinal(PreTrainedModel):
    """
    BERT encoder with an ordinal-regression (cumulative-link) head.

    Fully compatible with the HuggingFace checkpoint API::

        model.save_pretrained("my-checkpoint/")
        model = BertOrdinal.from_pretrained("my-checkpoint/")

    What gets saved
    ~~~~~~~~~~~~~~~
    ``save_pretrained`` writes two files:

    * ``config.json`` -- the full :class:`BertOrdinalConfig` (including
      ``bert_model_name`` and the cached backbone ``hidden_size``).
    * ``model.safetensors`` (or ``pytorch_model.bin``) -- a **single flat
      state_dict** containing both the BERT backbone weights and the
      head/threshold parameters.

    ``from_pretrained`` reconstructs the model from the config, loads the
    state_dict, and re-applies the ``freeze_bert`` setting. Because
    ``hidden_size`` is cached in the config, the backbone is instantiated
    from its architecture config alone on reload -- its pretrained weights
    are NOT downloaded a second time (the saved state_dict supplies them).
    """

    config_class = BertOrdinalConfig

    def __init__(self, config: BertOrdinalConfig) -> None:
        super().__init__(config)
        K = config.num_classes

        # -- 1. BERT backbone ------------------------------------------------
        # config.hidden_size is only ever set after a successful first build,
        # so its presence signals that we are being reconstructed by
        # from_pretrained: build the backbone from its architecture config
        # only (no pretrained-weight download) -- from_pretrained overwrites
        # every parameter with the saved state_dict right after __init__.
        if config.hidden_size is not None:
            from transformers import AutoConfig  # local import; module-level one sits below

            backbone_cfg = AutoConfig.from_pretrained(config.bert_model_name)
            self.bert = AutoModel.from_config(backbone_cfg)
        else:
            self.bert = AutoModel.from_pretrained(config.bert_model_name)
        hidden_size: int = self.bert.config.hidden_size

        # Cache so the head can be rebuilt offline after save_pretrained.
        config.hidden_size = hidden_size

        if config.freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

        # -- 2. Projection head: [CLS] representation -> scalar latent score -
        self.head = nn.Sequential(
            nn.Linear(hidden_size, config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, 1),
        )
        self._init_head()

        # -- 3. Ordinal thresholds -------------------------------------------
        # K-1 raw values; monotonicity enforced via cumsum(softplus(.)).
        # NOTE(review): cumsum(softplus) keeps every threshold strictly
        # positive, so the negative half of the linspace(-1, 1) targets is
        # clamped to ~1e-3 below; the realised initial thresholds therefore
        # start near 0, not -1 -- confirm this initialisation is intended.
        self.raw_thresholds = nn.Parameter(torch.zeros(K - 1))
        with torch.no_grad():
            targets = torch.linspace(-1.0, 1.0, K - 1)
            diffs = torch.cat([targets[:1], targets[1:] - targets[:-1]])
            # log(expm1(x)) is the inverse of softplus, so
            # softplus(raw_thresholds) reproduces the (clamped) diffs.
            self.raw_thresholds.copy_(
                torch.log(torch.expm1(diffs.clamp(min=1e-3)))
            )

        # Finalises weight-init bookkeeping required by PreTrainedModel.
        self.post_init()

    # -----------------------------------------------------------------------
    # Helpers
    # -----------------------------------------------------------------------

    def _init_head(self) -> None:
        """Kaiming-initialise the head's linear layers; zero all biases."""
        for m in self.head.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                nn.init.zeros_(m.bias)

    @property
    def thresholds(self) -> torch.Tensor:
        """Monotone thresholds θ₁ ≤ … ≤ θ_{K-1} (shape: K-1)."""
        return torch.cumsum(F.softplus(self.raw_thresholds), dim=0)

    # -----------------------------------------------------------------------
    # Forward
    # -----------------------------------------------------------------------

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> BertOrdinalOutput:
        """
        Parameters
        ----------
        input_ids : (B, L)
        attention_mask : (B, L)
        token_type_ids : (B, L) optional
        labels : (B,) long -- class indices in {0, ..., K-1}

        Returns
        -------
        BertOrdinalOutput
        """
        # -- Encode: take the [CLS] token of the last hidden layer -----------
        bert_kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        if token_type_ids is not None:
            bert_kwargs["token_type_ids"] = token_type_ids

        cls_repr = self.bert(**bert_kwargs).last_hidden_state[:, 0, :]  # (B, H)

        # -- Latent score -----------------------------------------------------
        score = self.head(cls_repr).squeeze(-1)  # (B,)

        # -- Cumulative probs P(Y <= j) = sigmoid(theta_j - score) ------------
        cum_logits = self.thresholds.unsqueeze(0) - score.unsqueeze(1)  # (B, K-1)
        cum_probs = torch.sigmoid(cum_logits)  # (B, K-1)

        # -- Class probs P(Y = j) = P(Y <= j) - P(Y <= j-1) -------------------
        # Pad the CDF with 0 on the left and 1 on the right so adjacent
        # differences yield all K class probabilities in one subtraction.
        B, dev = cum_probs.size(0), cum_probs.device
        F_ = torch.cat(
            [torch.zeros(B, 1, device=dev), cum_probs, torch.ones(B, 1, device=dev)],
            dim=1,
        )  # (B, K+1)
        # Clamp away exact zeros so downstream log() stays finite.
        class_probs = (F_[:, 1:] - F_[:, :-1]).clamp(min=1e-9)  # (B, K)

        # -- Predictions ------------------------------------------------------
        predictions = class_probs.argmax(dim=-1)  # (B,)

        # -- Loss (only when labels supplied) ---------------------------------
        loss: Optional[torch.Tensor] = None
        if labels is not None:
            loss = ordinal_cross_entropy(
                class_probs, labels, reduction=self.config.loss_reduction
            )

        return BertOrdinalOutput(
            loss=loss,
            logits=score,
            predictions=predictions,
            cum_probs=cum_probs,
            class_probs=class_probs,
        )

    # -----------------------------------------------------------------------
    # Convenience
    # -----------------------------------------------------------------------

    @torch.no_grad()
    def predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return predicted class indices (no loss computed)."""
        return self.forward(input_ids, attention_mask, token_type_ids).predictions
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# ---------------------------------------------------------------------------
|
| 255 |
+
# Loss function
|
| 256 |
+
# ---------------------------------------------------------------------------
|
| 257 |
+
|
| 258 |
+
def ordinal_cross_entropy(
    class_probs: torch.Tensor,
    labels: torch.Tensor,
    reduction: str = "mean",
) -> torch.Tensor:
    """
    Ordinal cross-entropy: negative log-likelihood of the true class
    under the cumulative-link class probabilities.

    Parameters
    ----------
    class_probs : (B, K) -- P(Y=j|x); values should be > 0
    labels : (B,) -- ground-truth indices in {0, ..., K-1}
    reduction : 'mean' | 'sum' | 'none'

    Returns
    -------
    torch.Tensor
        Scalar loss for 'mean'/'sum', or per-example losses (B,) for 'none'.
    """
    # Clamp defensively before log(): callers are expected to clamp, but a
    # stray exact zero would otherwise propagate -inf/NaN into the loss.
    log_probs = torch.log(class_probs.clamp(min=1e-12))
    return F.nll_loss(log_probs, labels, reduction=reduction)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
# ---------------------------------------------------------------------------
# Register the model with the Transformers library
# ---------------------------------------------------------------------------
# Maps the custom model_type "bert_ordinal" to these classes so that, once
# this module has been imported, AutoConfig/AutoModel/
# AutoModelForSequenceClassification can resolve checkpoints saved with
# BertOrdinalConfig directly.
from transformers import AutoConfig, AutoModel, AutoModelForSequenceClassification

AutoConfig.register("bert_ordinal", BertOrdinalConfig)
AutoModel.register(BertOrdinalConfig, BertOrdinal)
AutoModelForSequenceClassification.register(BertOrdinalConfig, BertOrdinal)
|