fdschmidt93
commited on
Commit
•
ee8c043
1
Parent(s):
ccaa121
chore: formatting
Browse files
modeling_seamless_m4t_v2_speech_encoder.py
CHANGED
@@ -14,7 +14,11 @@ from .configuration_seamless_m4t_v2_speech_encoder import (
|
|
14 |
)
|
15 |
from transformers.modeling_outputs import SequenceClassifierOutput
|
16 |
|
17 |
-
from transformers.models.auto import
|
|
|
|
|
|
|
|
|
18 |
|
19 |
|
20 |
class SeamlessM4Tv2SpeechEncoder(SeamlessM4Tv2SpeechEncoder):
|
@@ -88,7 +92,9 @@ class SeamlessM4Tv2ForAudioClassification(SeamlessM4Tv2PreTrainedModel):
|
|
88 |
if self.config.problem_type is None:
|
89 |
if self.num_labels == 1:
|
90 |
self.config.problem_type = "regression"
|
91 |
-
elif self.num_labels > 1 and (
|
|
|
|
|
92 |
self.config.problem_type = "single_label_classification"
|
93 |
else:
|
94 |
self.config.problem_type = "multi_label_classification"
|
|
|
14 |
)
|
15 |
from transformers.modeling_outputs import SequenceClassifierOutput
|
16 |
|
17 |
+
from transformers.models.auto import (
|
18 |
+
AutoModel,
|
19 |
+
AutoModelForAudioClassification,
|
20 |
+
AutoModelForSequenceClassification,
|
21 |
+
)
|
22 |
|
23 |
|
24 |
class SeamlessM4Tv2SpeechEncoder(SeamlessM4Tv2SpeechEncoder):
|
|
|
92 |
if self.config.problem_type is None:
|
93 |
if self.num_labels == 1:
|
94 |
self.config.problem_type = "regression"
|
95 |
+
elif self.num_labels > 1 and (
|
96 |
+
labels.dtype == torch.long or labels.dtype == torch.int
|
97 |
+
):
|
98 |
self.config.problem_type = "single_label_classification"
|
99 |
else:
|
100 |
self.config.problem_type = "multi_label_classification"
|