Spaces:
Running
Running
arslanarjumand
committed on
Commit
•
995bf88
1
Parent(s):
e1552ea
Update model.py
Browse files
model.py
CHANGED
@@ -4,6 +4,7 @@ from typing import Optional, Tuple, Union
|
|
4 |
from torch.nn import MSELoss
|
5 |
import torch
|
6 |
import torch.nn as nn
|
|
|
7 |
|
8 |
class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel):
|
9 |
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Bert,wav2vec2->wav2vec2_bert
|
@@ -19,7 +20,19 @@ class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel):
|
|
19 |
if config.use_weighted_layer_sum:
|
20 |
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
21 |
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
22 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
# Initialize weights and apply final processing
|
25 |
self.post_init()
|
@@ -69,20 +82,14 @@ class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel):
|
|
69 |
else:
|
70 |
hidden_states = outputs[0]
|
71 |
|
72 |
-
hidden_states =
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
|
77 |
-
hidden_states[~padding_mask] = 0.0
|
78 |
-
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
79 |
-
|
80 |
-
logits = self.classifier(pooled_output)
|
81 |
-
logits = nn.functional.relu(logits)
|
82 |
|
83 |
loss = None
|
84 |
if labels is not None:
|
85 |
-
loss_fct =
|
86 |
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1, self.config.num_labels))
|
87 |
|
88 |
if not return_dict:
|
|
|
4 |
from torch.nn import MSELoss
|
5 |
import torch
|
6 |
import torch.nn as nn
|
7 |
+
import math
|
8 |
|
9 |
class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel):
|
10 |
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Bert,wav2vec2->wav2vec2_bert
|
|
|
20 |
if config.use_weighted_layer_sum:
|
21 |
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
22 |
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
23 |
+
self.pooled_conv = nn.Sequential(nn.Conv1d(config.hidden_size, config.hidden_size // 2, kernel_size=15, stride=3, padding=30),
|
24 |
+
nn.AvgPool1d(2, 2),
|
25 |
+
nn.BatchNorm1d(config.hidden_size // 2),
|
26 |
+
nn.Conv1d(config.hidden_size // 2, config.classifier_proj_size, kernel_size=7, stride=2, padding=0),
|
27 |
+
nn.ReLU()
|
28 |
+
)
|
29 |
+
|
30 |
+
self.classifier = nn.Sequential(nn.Dropout(p=0.091,),
|
31 |
+
nn.Linear(config.classifier_proj_size, config.classifier_proj_size // 2),
|
32 |
+
nn.ReLU(),
|
33 |
+
nn.Linear(config.classifier_proj_size // 2, config.num_labels),
|
34 |
+
nn.ReLU(),
|
35 |
+
)
|
36 |
|
37 |
# Initialize weights and apply final processing
|
38 |
self.post_init()
|
|
|
82 |
else:
|
83 |
hidden_states = outputs[0]
|
84 |
|
85 |
+
hidden_states = hidden_states.permute(0, 2, 1)
|
86 |
+
hidden_states = self.pooled_conv(hidden_states)
|
87 |
+
hidden_states = torch.mean(hidden_states, dim=2)
|
88 |
+
logits = self.classifier(hidden_states)
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
loss = None
|
91 |
if labels is not None:
|
92 |
+
loss_fct = nn.L1Loss()
|
93 |
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1, self.config.num_labels))
|
94 |
|
95 |
if not return_dict:
|