modeling: sync xLSTMForSequenceClassification with Patrick's codebase from https://github.com/HallerPatrick/helibrunna/blob/a1b377271867d5f23201ccacb55e017749aba487/model/modeling_xlstm.py
modeling_xlstm.py CHANGED (+83 -1)
@@ -2,8 +2,9 @@ from typing import Optional, Sequence, Tuple, Union
 
 import torch
 from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers import PreTrainedModel
-from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
+from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from xlstm.components.init import small_init_init_
 from xlstm.utils import WeightDecayOptimGroupMixin
 from xlstm.xlstm_block_stack import xLSTMBlockStack as _xLSTMBlockStack
@@ -212,3 +213,84 @@ class xLSTMForCausalLM(xLSTMPreTrainedModel, WeightDecayOptimGroupMixin):
             "input_ids": input_ids.to(self.device),
         }
         return model_inputs
+
+
+class xLSTMForSequenceClassification(xLSTMPreTrainedModel):
+
+    def __init__(self, config: xLSTMConfig, **kwargs):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+        self.model = xLSTMModel(config)
+        self.classifier = nn.Linear(config.embedding_dim, config.num_labels, bias=False)
+
+        self.init_weights()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        labels: Optional[torch.LongTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        output = self.model(
+            input_ids,
+            output_hidden_states=output_hidden_states,
+        )
+
+        hidden_state = output[0]
+
+        logits = self.classifier(hidden_state)
+        batch_size = input_ids.shape[0]
+
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
+            else:
+                sequence_lengths = -1
+
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+        loss = None
+
+        if labels is not None:
+            labels = labels.to(logits.device)
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+
+        if not return_dict:
+            output = (pooled_logits,) + output[1:]
+            return ((loss,) + output) if loss is not None else output
+
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            hidden_states=output.hidden_states,
+        )
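
For readers syncing this change, here is a standalone sketch of the last-token pooling the new head relies on. It is not part of the commit; the pad_token_id and tensor values below are made-up placeholders.

import torch

# Illustration only (not from the commit): how the head locates the last real token
# of each right-padded sequence. argmax finds the first pad position; subtracting 1
# gives the last non-pad index, and the modulo keeps the index valid for rows that
# contain no pad token at all (argmax returns 0 there, so -1 wraps to the last column).
pad_token_id = 0                      # placeholder pad id
input_ids = torch.tensor([
    [5, 6, 7, 0, 0],                  # real length 3 -> pool position 2
    [8, 9, 1, 2, 3],                  # no padding    -> pool position 4
])

sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]

logits = torch.randn(2, 5, 2)         # (batch, seq_len, num_labels), as produced by the classifier
pooled = logits[torch.arange(2), sequence_lengths]
print(sequence_lengths.tolist())      # [2, 4]
print(pooled.shape)                   # torch.Size([2, 2]), one logit vector per sequence

This mirrors the usual Hugging Face *ForSequenceClassification pooling convention, which is also why the class raises an error when batch_size > 1 and no pad_token_id is configured.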
|