fdschmidt93 committed on
Commit
1057217
1 Parent(s): 65a0eff

feat(sequence_clf): add all tasks (regression, multi-class, multi-label)

Browse files
modeling_seamless_m4t_v2_speech_encoder.py CHANGED
@@ -89,10 +89,30 @@ class SeamlessM4Tv2ForAudioClassification(SeamlessM4Tv2PreTrainedModel):
89
  outputs.last_hidden_state, attention_mask
90
  )
91
  logits = self.score(hidden_states)
 
92
  if labels is not None:
93
- loss = F.cross_entropy(logits, labels)
94
- else:
95
- loss = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  return SequenceClassifierOutput(
97
  loss=loss, # type: ignore
98
  logits=logits,
 
89
  outputs.last_hidden_state, attention_mask
90
  )
91
  logits = self.score(hidden_states)
92
+
93
  if labels is not None:
94
+ # move labels to correct device to enable model parallelism
95
+ labels = labels.to(logits.device)
96
+ if self.config.problem_type is None:
97
+ if self.num_labels == 1:
98
+ self.config.problem_type = "regression"
99
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
100
+ self.config.problem_type = "single_label_classification"
101
+ else:
102
+ self.config.problem_type = "multi_label_classification"
103
+ if self.config.problem_type == "regression":
104
+ loss_fct = F.mse_loss
105
+ if self.num_labels == 1:
106
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
107
+ else:
108
+ loss = loss_fct(logits, labels)
109
+ elif self.config.problem_type == "single_label_classification":
110
+ loss_fct = F.cross_entropy
111
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
112
+ elif self.config.problem_type == "multi_label_classification":
113
+ loss_fct = F.binary_cross_entropy_with_logits
114
+ loss = loss_fct(logits, labels)
115
+
116
  return SequenceClassifierOutput(
117
  loss=loss, # type: ignore
118
  logits=logits,