date3k2 committed on
Commit
2d6f13c
1 Parent(s): e6b0814

Update hf_mamba_classification.py

Browse files
Files changed (1) hide show
  1. hf_mamba_classification.py +29 -3
hf_mamba_classification.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  from torch import nn
3
- from torch.nn import CrossEntropyLoss
4
  from transformers.models.mamba.modeling_mamba import (
5
  MambaPreTrainedModel,
6
  MambaModel,
@@ -44,7 +44,9 @@ class MambaSequenceClassifierOutput(ModelOutput):
44
 
45
  loss: Optional[torch.FloatTensor] = None
46
  logits: torch.FloatTensor = None
 
47
  cache_params: Optional[List[torch.FloatTensor]] = None
 
48
  hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
49
 
50
 
@@ -149,8 +151,32 @@ class MambaForSequenceClassification(MambaPreTrainedModel):
149
  torch.arange(batch_size, device=logits.device), sequence_lengths
150
  ]
151
 
152
- loss_fct = CrossEntropyLoss()
153
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
  if not return_dict:
156
  output = (pooled_logits,) + mamba_outputs[1:]
 
1
  import torch
2
  from torch import nn
3
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
4
  from transformers.models.mamba.modeling_mamba import (
5
  MambaPreTrainedModel,
6
  MambaModel,
 
44
 
45
  loss: Optional[torch.FloatTensor] = None
46
  logits: torch.FloatTensor = None
47
+ # cache_params: Optional[MambaCache] = None,
48
  cache_params: Optional[List[torch.FloatTensor]] = None
49
+ # cache_params: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
50
  hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
51
 
52
 
 
151
  torch.arange(batch_size, device=logits.device), sequence_lengths
152
  ]
153
 
154
+ loss = None
155
+ if labels is not None:
156
+ if self.config.problem_type is None:
157
+ if self.num_labels == 1:
158
+ self.config.problem_type = "regression"
159
+ elif self.num_labels > 1 and (
160
+ labels.dtype == torch.long or labels.dtype == torch.int
161
+ ):
162
+ self.config.problem_type = "single_label_classification"
163
+ else:
164
+ self.config.problem_type = "multi_label_classification"
165
+
166
+ if self.config.problem_type == "regression":
167
+ loss_fct = MSELoss()
168
+ if self.num_labels == 1:
169
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
170
+ else:
171
+ loss = loss_fct(pooled_logits, labels)
172
+ elif self.config.problem_type == "single_label_classification":
173
+ loss_fct = CrossEntropyLoss()
174
+ loss = loss_fct(
175
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
176
+ )
177
+ elif self.config.problem_type == "multi_label_classification":
178
+ loss_fct = BCEWithLogitsLoss()
179
+ loss = loss_fct(pooled_logits, labels)
180
 
181
  if not return_dict:
182
  output = (pooled_logits,) + mamba_outputs[1:]