Luis J Camargo commited on
Commit
9b7b4e8
·
1 Parent(s): 481d76f

wrong class fix

Browse files
Files changed (1) hide show
  1. app.py +39 -11
app.py CHANGED
@@ -19,27 +19,55 @@ class WhisperEncoderOnlyConfig(WhisperConfig):
19
  self.n_super = n_super
20
  self.n_code = n_code
21
 
22
- class WhisperEncoderOnlyForClassification(nn.Module):
23
  config_class = WhisperEncoderOnlyConfig
24
-
25
  def __init__(self, config):
26
- super().__init__()
27
- self.config = config
28
  self.encoder = WhisperEncoder(config)
29
-
30
  hidden = config.d_model
31
  self.fam_head = nn.Linear(hidden, config.n_fam)
32
  self.super_head = nn.Linear(hidden, config.n_super)
33
  self.code_head = nn.Linear(hidden, config.n_code)
34
-
35
- def forward(self, input_features):
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  enc_out = self.encoder(input_features=input_features)
37
  pooled = enc_out.last_hidden_state.mean(dim=1)
38
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  return {
40
- "fam_logits": self.fam_head(pooled),
41
- "super_logits": self.super_head(pooled),
42
- "code_logits": self.code_head(pooled),
 
43
  }
44
 
45
  # === REGISTER MODEL ===
 
19
  self.n_super = n_super
20
  self.n_code = n_code
21
 
22
+ class WhisperEncoderOnlyForClassification(WhisperPreTrainedModel):
23
  config_class = WhisperEncoderOnlyConfig
24
+
25
  def __init__(self, config):
26
+ super().__init__(config)
27
+
28
  self.encoder = WhisperEncoder(config)
29
+
30
  hidden = config.d_model
31
  self.fam_head = nn.Linear(hidden, config.n_fam)
32
  self.super_head = nn.Linear(hidden, config.n_super)
33
  self.code_head = nn.Linear(hidden, config.n_code)
34
+
35
+ self.post_init()
36
+
37
+ def get_input_embeddings(self):
38
+ """Whisper doesn't have token embeddings"""
39
+ return None
40
+
41
+ def set_input_embeddings(self, value):
42
+ """Ignore"""
43
+ pass
44
+
45
+ def enable_input_require_grads(self):
46
+ return
47
+
48
+ def forward(self, input_features, labels=None):
49
  enc_out = self.encoder(input_features=input_features)
50
  pooled = enc_out.last_hidden_state.mean(dim=1)
51
+
52
+ fam_logits = self.fam_head(pooled)
53
+ super_logits = self.super_head(pooled)
54
+ code_logits = self.code_head(pooled)
55
+
56
+ loss = None
57
+ if labels is not None:
58
+ fam_labels, super_labels, code_labels = labels
59
+ loss_fn = nn.CrossEntropyLoss()
60
+ loss = (
61
+ loss_fn(fam_logits, fam_labels) +
62
+ loss_fn(super_logits, super_labels) +
63
+ loss_fn(code_logits, code_labels)
64
+ )
65
+
66
  return {
67
+ "loss": loss,
68
+ "fam_logits": fam_logits,
69
+ "super_logits": super_logits,
70
+ "code_logits": code_logits,
71
  }
72
 
73
  # === REGISTER MODEL ===