Spaces:
Running
Running
Luis J Camargo committed on
Commit ·
9b7b4e8
1
Parent(s): 481d76f
wrong class fix
Browse files
app.py
CHANGED
|
@@ -19,27 +19,55 @@ class WhisperEncoderOnlyConfig(WhisperConfig):
|
|
| 19 |
self.n_super = n_super
|
| 20 |
self.n_code = n_code
|
| 21 |
|
| 22 |
-
class WhisperEncoderOnlyForClassification(
|
| 23 |
config_class = WhisperEncoderOnlyConfig
|
| 24 |
-
|
| 25 |
def __init__(self, config):
|
| 26 |
-
super().__init__()
|
| 27 |
-
|
| 28 |
self.encoder = WhisperEncoder(config)
|
| 29 |
-
|
| 30 |
hidden = config.d_model
|
| 31 |
self.fam_head = nn.Linear(hidden, config.n_fam)
|
| 32 |
self.super_head = nn.Linear(hidden, config.n_super)
|
| 33 |
self.code_head = nn.Linear(hidden, config.n_code)
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
enc_out = self.encoder(input_features=input_features)
|
| 37 |
pooled = enc_out.last_hidden_state.mean(dim=1)
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
return {
|
| 40 |
-
"
|
| 41 |
-
"
|
| 42 |
-
"
|
|
|
|
| 43 |
}
|
| 44 |
|
| 45 |
# === REGISTER MODEL ===
|
|
|
|
| 19 |
self.n_super = n_super
|
| 20 |
self.n_code = n_code
|
| 21 |
|
| 22 |
+
class WhisperEncoderOnlyForClassification(WhisperPreTrainedModel):
    """Whisper encoder followed by three independent linear classifiers.

    The encoder output is averaged over dimension 1 and the pooled vector
    is fed to the ``fam``, ``super`` and ``code`` heads, whose output sizes
    come from ``config.n_fam``, ``config.n_super`` and ``config.n_code``.
    """

    config_class = WhisperEncoderOnlyConfig

    def __init__(self, config):
        """Build the encoder and the three heads from ``config``.

        Args:
            config: Configuration exposing ``d_model``, ``n_fam``,
                ``n_super`` and ``n_code`` in addition to whatever
                ``WhisperEncoder`` itself reads.
        """
        super().__init__(config)

        self.encoder = WhisperEncoder(config)

        # Every head consumes the pooled encoder output, so they all share
        # the same input width.
        feature_dim = config.d_model
        self.fam_head = nn.Linear(feature_dim, config.n_fam)
        self.super_head = nn.Linear(feature_dim, config.n_super)
        self.code_head = nn.Linear(feature_dim, config.n_code)

        # Base-class post-construction hook; called last so that every
        # submodule above already exists.
        self.post_init()

    def get_input_embeddings(self):
        """Return ``None``: this model exposes no token-embedding layer."""
        return None

    def set_input_embeddings(self, value):
        """Accept and discard ``value``; there is no embedding to replace."""
        return None

    def enable_input_require_grads(self):
        """Do nothing.

        NOTE(review): presumably overridden because the base implementation
        relies on ``get_input_embeddings()``, which returns ``None`` here --
        confirm against the installed transformers version.
        """
        return None

    def forward(self, input_features, labels=None):
        """Encode ``input_features`` and classify the pooled representation.

        Args:
            input_features: Passed unchanged to the encoder as
                ``input_features``.
            labels: Optional iterable that unpacks into exactly three
                targets, in the order (fam, super, code).

        Returns:
            dict with keys ``"loss"``, ``"fam_logits"``, ``"super_logits"``
            and ``"code_logits"``. ``"loss"`` is ``None`` when ``labels`` is
            ``None``; otherwise it is the unweighted sum of the three
            cross-entropy losses.
        """
        encoded = self.encoder(input_features=input_features)
        # Collapse dimension 1 by averaging (assumed to be the time/frame
        # axis of the encoder output -- TODO confirm).
        pooled = encoded.last_hidden_state.mean(dim=1)

        fam_logits = self.fam_head(pooled)
        super_logits = self.super_head(pooled)
        code_logits = self.code_head(pooled)

        outputs = {
            "loss": None,
            "fam_logits": fam_logits,
            "super_logits": super_logits,
            "code_logits": code_logits,
        }
        if labels is None:
            return outputs

        fam_labels, super_labels, code_labels = labels
        criterion = nn.CrossEntropyLoss()
        fam_loss = criterion(fam_logits, fam_labels)
        super_loss = criterion(super_logits, super_labels)
        code_loss = criterion(code_logits, code_labels)
        outputs["loss"] = fam_loss + super_loss + code_loss
        return outputs
|
| 72 |
|
| 73 |
# === REGISTER MODEL ===
|