Cbelem commited on
Commit
c1f6b2a
Β·
verified Β·
1 Parent(s): 4b01687

Upload 3 files

Browse files
__init__.py ADDED
File without changes
configuration_bert_ordinal.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ bert_ordinal.py
3
+ ---------------
4
+ BERT-based ordinal regression model, fully integrated with the HuggingFace
5
+ Transformers API:
6
+
7
+ model.save_pretrained("my-checkpoint/")
8
+ model = BertOrdinal.from_pretrained("my-checkpoint/")
9
+
10
+ Architecture
11
+ ------------
12
+ 1. A (optionally frozen) BERT backbone.
13
+ 2. A projection head on the [CLS] token:
14
+ Linear(hidden_size β†’ hidden_dim) β†’ ReLU β†’ Dropout(p) β†’ Linear(hidden_dim β†’ 1)
15
+ producing a single latent score s ∈ ℝ.
16
+ 3. K-1 learnable raw_threshold parameters enforcing monotonicity via
17
+ cumsum(softplus(Β·)).
18
+ 4. Cumulative-link probabilities:
19
+ P(Y ≀ j | x) = Οƒ(ΞΈ_j βˆ’ s)
20
+
21
+ Usage
22
+ -----
23
+ from bert_ordinal import BertOrdinalConfig, BertOrdinal
24
+
25
+ # ── Create from scratch ──────────────────────────────────────────────────
26
+ cfg = BertOrdinalConfig(
27
+ bert_model_name="bert-base-uncased",
28
+ num_classes=3,
29
+ hidden_dim=128,
30
+ dropout=0.1,
31
+ freeze_bert=True,
32
+ )
33
+ model = BertOrdinal(cfg)
34
+
35
+ # ── Save ────────────────────────────────────────────────────────────────
36
+ model.save_pretrained("my-checkpoint/")
37
+ tokenizer.save_pretrained("my-checkpoint/") # keep tokenizer alongside
38
+
39
+ # ── Reload ──────────────────────────────────────────────────────────────
40
+ model = BertOrdinal.from_pretrained("my-checkpoint/")
41
+ tokenizer = AutoTokenizer.from_pretrained("my-checkpoint/")
42
+ """
43
+
44
+ from __future__ import annotations
45
+ from typing import Optional
46
+ from transformers import PretrainedConfig
47
+
48
+
49
+ # ---------------------------------------------------------------------------
50
+ # 1. Config β€” subclass PretrainedConfig for full HF serialisation
51
+ # ---------------------------------------------------------------------------
52
+
53
+ class BertOrdinalConfig(PretrainedConfig):
54
+ """
55
+ Configuration for :class:`BertOrdinal`.
56
+
57
+ Because this inherits from :class:`~transformers.PretrainedConfig`,
58
+ ``save_pretrained`` writes a ``config.json`` that ``from_pretrained``
59
+ can read back without any extra bookkeeping.
60
+
61
+ Parameters
62
+ ----------
63
+ bert_model_name : str
64
+ HuggingFace model name or local path for the BERT backbone.
65
+ num_classes : int
66
+ Number of ordinal classes K. Creates K-1 learnable thresholds.
67
+ hidden_dim : int
68
+ Inner dimension of the projection head.
69
+ dropout : float
70
+ Dropout probability inside the projection head.
71
+ freeze_bert : bool
72
+ Freeze backbone weights at construction time.
73
+ loss_reduction : str
74
+ ``'mean'`` or ``'sum'``.
75
+ """
76
+
77
+ # Tells HF which class owns this config (written into config.json).
78
+ model_type = "bert_ordinal"
79
+ problem_type = "single_label_classification"
80
+
81
+ def __init__(
82
+ self,
83
+ bert_model_name: str = "allenai/scibert_scivocab_uncased",
84
+ num_classes: int = 3,
85
+ hidden_dim: int = 256,
86
+ dropout: float = 0.1,
87
+ freeze_bert: bool = True,
88
+ loss_reduction: str = "mean",
89
+ # hidden_size is set automatically by the model after loading BERT;
90
+ # it is stored here so from_pretrained can rebuild the head offline.
91
+ hidden_size: Optional[int] = None,
92
+ **kwargs,
93
+ ) -> None:
94
+ super().__init__(**kwargs)
95
+ self.bert_model_name = bert_model_name
96
+ self.num_classes = num_classes
97
+ self.hidden_dim = hidden_dim
98
+ self.dropout = dropout
99
+ self.freeze_bert = freeze_bert
100
+ self.loss_reduction = loss_reduction
101
+ self.hidden_size = hidden_size # filled in by BertOrdinal.__init__
102
+
103
+ self.auto_map = {
104
+ "AutoConfig": "configuration_bert_ordinal.BertOrdinalConfig",
105
+ "AutoModel": "modeling_bert_ordinal.BertOrdinal",
106
+ "AutoModelForSequenceClassification": "modeling_bert_ordinal.BertOrdinal",
107
+ }
modeling_bert_ordinal.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ bert_ordinal.py
3
+ ---------------
4
+ BERT-based ordinal regression model, fully integrated with the HuggingFace
5
+ Transformers API:
6
+
7
+ model.save_pretrained("my-checkpoint/")
8
+ model = BertOrdinal.from_pretrained("my-checkpoint/")
9
+
10
+ Architecture
11
+ ------------
12
+ 1. A (optionally frozen) BERT backbone.
13
+ 2. A projection head on the [CLS] token:
14
+ Linear(hidden_size β†’ hidden_dim) β†’ ReLU β†’ Dropout(p) β†’ Linear(hidden_dim β†’ 1)
15
+ producing a single latent score s ∈ ℝ.
16
+ 3. K-1 learnable raw_threshold parameters enforcing monotonicity via
17
+ cumsum(softplus(Β·)).
18
+ 4. Cumulative-link probabilities:
19
+ P(Y ≀ j | x) = Οƒ(ΞΈ_j βˆ’ s)
20
+
21
+ Usage
22
+ -----
23
+ from bert_ordinal import BertOrdinalConfig, BertOrdinal
24
+
25
+ # ── Create from scratch ──────────────────────────────────────────────────
26
+ cfg = BertOrdinalConfig(
27
+ bert_model_name="bert-base-uncased",
28
+ num_classes=3,
29
+ hidden_dim=128,
30
+ dropout=0.1,
31
+ freeze_bert=True,
32
+ )
33
+ model = BertOrdinal(cfg)
34
+
35
+ # ── Save ────────────────────────────────────────────────────────────────
36
+ model.save_pretrained("my-checkpoint/")
37
+ tokenizer.save_pretrained("my-checkpoint/") # keep tokenizer alongside
38
+
39
+ # ── Reload ──────────────────────────────────────────────────────────────
40
+ model = BertOrdinal.from_pretrained("my-checkpoint/")
41
+ tokenizer = AutoTokenizer.from_pretrained("my-checkpoint/")
42
+ """
43
+
44
+ from __future__ import annotations
45
+
46
+ from dataclasses import dataclass
47
+ from typing import Optional
48
+
49
+ import torch
50
+ import torch.nn as nn
51
+ import torch.nn.functional as F
52
+ from transformers import AutoModel, PreTrainedModel
53
+ from transformers.modeling_outputs import ModelOutput
54
+
55
+ from configuration_bert_ordinal import BertOrdinalConfig
56
+
57
+ # ---------------------------------------------------------------------------
58
+ # 1. Output dataclass
59
+ # ---------------------------------------------------------------------------
60
+
61
@dataclass
class BertOrdinalOutput(ModelOutput):
    """
    Return type of :class:`BertOrdinal`.

    Attributes
    ----------
    loss : torch.Tensor or None
        Ordinal cross-entropy loss (scalar). Present only when ``labels``
        are supplied to ``forward``.
    logits : torch.Tensor (B,)
        Raw latent score from the projection head (one scalar per example).
    predictions : torch.Tensor (B,)
        Predicted class index — argmax of ``class_probs``.
    cum_probs : torch.Tensor (B, K-1)
        Cumulative probabilities P(Y ≤ j | x).
    class_probs : torch.Tensor (B, K)
        Per-class probabilities P(Y = j | x); the model clamps these to be
        strictly positive (min 1e-9) so their log is always finite.

    NOTE(review): ModelOutput-derived classes are order-sensitive when
    converted to tuples — keep the field order below stable.
    """

    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    predictions: Optional[torch.Tensor] = None
    cum_probs: Optional[torch.Tensor] = None
    class_probs: Optional[torch.Tensor] = None
86
+
87
+
88
+ # ---------------------------------------------------------------------------
89
+ # 3. Model β€” subclass PreTrainedModel for save / from_pretrained
90
+ # ---------------------------------------------------------------------------
91
+
92
class BertOrdinal(PreTrainedModel):
    """
    BERT encoder with an ordinal-regression head.

    Fully compatible with the HuggingFace checkpoint API::

        model.save_pretrained("my-checkpoint/")
        model = BertOrdinal.from_pretrained("my-checkpoint/")

    What gets saved
    ~~~~~~~~~~~~~~~
    ``save_pretrained`` writes two files:

    * ``config.json`` — the full :class:`BertOrdinalConfig` (including
      ``bert_model_name``, ``hidden_size``, thresholds shape, …).
    * ``model.safetensors`` (or ``pytorch_model.bin``) — a **single flat
      state_dict** containing both the BERT backbone weights and the
      head/threshold parameters.

    ``from_pretrained`` reconstructs the model from the config (which
    already has ``hidden_size`` cached), loads the state_dict, and
    re-applies the ``freeze_bert`` setting. On that reload path the
    backbone is built from its architecture config alone, so pretrained
    backbone weights are never downloaded a second time (the backbone
    config itself must be in the local HF cache or on disk).
    """

    config_class = BertOrdinalConfig

    def __init__(self, config: BertOrdinalConfig) -> None:
        super().__init__(config)
        K = config.num_classes

        # ── 1. BERT backbone ────────────────────────────────────────────────
        # config.hidden_size is None only on the very first construction.
        # When it is already set we are being rebuilt by from_pretrained
        # after a save: build the backbone from its architecture config
        # instead of re-downloading pretrained weights, because
        # from_pretrained overwrites every parameter with the saved
        # state_dict anyway. (AutoConfig is imported at module bottom.)
        if config.hidden_size is not None:
            self.bert = AutoModel.from_config(
                AutoConfig.from_pretrained(config.bert_model_name)
            )
        else:
            self.bert = AutoModel.from_pretrained(config.bert_model_name)
        hidden_size: int = self.bert.config.hidden_size

        # Cache so the head can be rebuilt offline after save_pretrained.
        config.hidden_size = hidden_size

        if config.freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

        # ── 2. Projection head ──────────────────────────────────────────────
        # Linear(H → hidden_dim) → ReLU → Dropout → Linear(hidden_dim → 1),
        # producing a single latent score per example.
        self.head = nn.Sequential(
            nn.Linear(hidden_size, config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, 1),
        )
        self._init_head()

        # ── 3. Ordinal thresholds ───────────────────────────────────────────
        # K-1 raw values; monotonicity enforced via cumsum(softplus(·)).
        # Initialised so the effective thresholds start spread over the
        # target range linspace(-1, 1): raw = softplus⁻¹(diffs).
        self.raw_thresholds = nn.Parameter(torch.zeros(K - 1))
        with torch.no_grad():
            targets = torch.linspace(-1.0, 1.0, K - 1)
            diffs = torch.cat([targets[:1], targets[1:] - targets[:-1]])
            self.raw_thresholds.copy_(
                torch.log(torch.expm1(diffs.clamp(min=1e-3)))
            )

        # Finalises weight init bookkeeping required by PreTrainedModel.
        self.post_init()

    # -----------------------------------------------------------------------
    # Helpers
    # -----------------------------------------------------------------------

    def _init_head(self) -> None:
        # Kaiming init matches the ReLU nonlinearity used in the head.
        for m in self.head.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                nn.init.zeros_(m.bias)

    @property
    def thresholds(self) -> torch.Tensor:
        """Monotone thresholds θ₁ ≤ … ≤ θ_{K-1} (shape: K-1)."""
        return torch.cumsum(F.softplus(self.raw_thresholds), dim=0)

    # -----------------------------------------------------------------------
    # Forward
    # -----------------------------------------------------------------------

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> BertOrdinalOutput:
        """
        Parameters
        ----------
        input_ids : (B, L)
        attention_mask : (B, L)
        token_type_ids : (B, L) optional
        labels : (B,) long — class indices in {0, …, K-1}; when given, the
            ordinal cross-entropy loss is computed and returned.

        Returns
        -------
        BertOrdinalOutput
        """
        # ── Encode: [CLS] token representation ─────────────────────────────
        bert_kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        if token_type_ids is not None:
            bert_kwargs["token_type_ids"] = token_type_ids

        cls_repr = self.bert(**bert_kwargs).last_hidden_state[:, 0, :]  # (B, H)

        # ── Latent score ────────────────────────────────────────────────────
        score = self.head(cls_repr).squeeze(-1)  # (B,)

        # ── Cumulative probs P(Y ≤ j) = σ(θ_j − score) ────────────────────
        cum_logits = self.thresholds.unsqueeze(0) - score.unsqueeze(1)  # (B, K-1)
        cum_probs = torch.sigmoid(cum_logits)                            # (B, K-1)

        # ── Class probs P(Y = j) = P(Y ≤ j) − P(Y ≤ j-1) ─────────────────
        # Pad the CDF with 0 on the left and 1 on the right, then difference.
        B, dev = cum_probs.size(0), cum_probs.device
        F_ = torch.cat(
            [torch.zeros(B, 1, device=dev), cum_probs, torch.ones(B, 1, device=dev)],
            dim=1,
        )  # (B, K+1)
        # Clamp keeps log(class_probs) finite in the loss.
        class_probs = (F_[:, 1:] - F_[:, :-1]).clamp(min=1e-9)  # (B, K)

        # ── Predictions ──────────────────────────────────────────────────────
        predictions = class_probs.argmax(dim=-1)  # (B,)

        # ── Loss ─────────────────────────────────────────────────────────────
        loss: Optional[torch.Tensor] = None
        if labels is not None:
            loss = ordinal_cross_entropy(
                class_probs, labels, reduction=self.config.loss_reduction
            )

        return BertOrdinalOutput(
            loss=loss,
            logits=score,
            predictions=predictions,
            cum_probs=cum_probs,
            class_probs=class_probs,
        )

    # -----------------------------------------------------------------------
    # Convenience
    # -----------------------------------------------------------------------

    @torch.no_grad()
    def predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return predicted class indices (no loss computed)."""
        return self.forward(input_ids, attention_mask, token_type_ids).predictions
252
+
253
+
254
+ # ---------------------------------------------------------------------------
255
+ # Loss function
256
+ # ---------------------------------------------------------------------------
257
+
258
def ordinal_cross_entropy(
    class_probs: torch.Tensor,
    labels: torch.Tensor,
    reduction: str = "mean",
) -> torch.Tensor:
    """
    Negative log-likelihood of the true class under the ordinal model.

    Since ``class_probs`` already holds per-class probabilities, taking the
    log and delegating to ``F.nll_loss`` yields the standard cross-entropy
    of the ground-truth label.

    Parameters
    ----------
    class_probs : torch.Tensor, shape (B, K)
        P(Y=j|x) per example; caller guarantees all entries are > 0.
    labels : torch.Tensor, shape (B,)
        Ground-truth indices in {0, …, K-1}.
    reduction : str
        One of ``'mean'``, ``'sum'``, ``'none'``.
    """
    log_probs = torch.log(class_probs)
    return F.nll_loss(log_probs, labels, reduction=reduction)
273
+
274
+
275
# ---------------------------------------------------------------------------
# Register the model with the Transformers library
# ---------------------------------------------------------------------------
from transformers import AutoConfig, AutoModel, AutoModelForSequenceClassification

# Side effect of importing this module: the custom architecture becomes
# discoverable through the Auto* factories, so AutoModel.from_pretrained on a
# saved checkpoint (whose config.json has model_type "bert_ordinal") resolves
# to BertOrdinalConfig / BertOrdinal without extra wiring.
AutoConfig.register("bert_ordinal", BertOrdinalConfig)
AutoModel.register(BertOrdinalConfig, BertOrdinal)
AutoModelForSequenceClassification.register(BertOrdinalConfig, BertOrdinal)