arslanarjumand committed on
Commit
995bf88
1 Parent(s): e1552ea

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +19 -12
model.py CHANGED
@@ -4,6 +4,7 @@ from typing import Optional, Tuple, Union
4
  from torch.nn import MSELoss
5
  import torch
6
  import torch.nn as nn
 
7
 
8
  class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel):
9
  # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Bert,wav2vec2->wav2vec2_bert
@@ -19,7 +20,19 @@ class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel):
19
  if config.use_weighted_layer_sum:
20
  self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
21
  self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
22
- self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # Initialize weights and apply final processing
25
  self.post_init()
@@ -69,20 +82,14 @@ class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel):
69
  else:
70
  hidden_states = outputs[0]
71
 
72
- hidden_states = self.projector(hidden_states)
73
- if attention_mask is None:
74
- pooled_output = hidden_states.mean(dim=1)
75
- else:
76
- padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
77
- hidden_states[~padding_mask] = 0.0
78
- pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
79
-
80
- logits = self.classifier(pooled_output)
81
- logits = nn.functional.relu(logits)
82
 
83
  loss = None
84
  if labels is not None:
85
- loss_fct = MSELoss()
86
  loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1, self.config.num_labels))
87
 
88
  if not return_dict:
 
4
  from torch.nn import MSELoss
5
  import torch
6
  import torch.nn as nn
7
+ import math
8
 
9
  class Wav2Vec2BertForSequenceClassification(Wav2Vec2BertPreTrainedModel):
10
  # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Bert,wav2vec2->wav2vec2_bert
 
20
  if config.use_weighted_layer_sum:
21
  self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
22
  self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
23
+ self.pooled_conv = nn.Sequential(nn.Conv1d(config.hidden_size, config.hidden_size // 2, kernel_size=15, stride=3, padding=30),
24
+ nn.AvgPool1d(2, 2),
25
+ nn.BatchNorm1d(config.hidden_size // 2),
26
+ nn.Conv1d(config.hidden_size // 2, config.classifier_proj_size, kernel_size=7, stride=2, padding=0),
27
+ nn.ReLU()
28
+ )
29
+
30
+ self.classifier = nn.Sequential(nn.Dropout(p=0.091,),
31
+ nn.Linear(config.classifier_proj_size, config.classifier_proj_size // 2),
32
+ nn.ReLU(),
33
+ nn.Linear(config.classifier_proj_size // 2, config.num_labels),
34
+ nn.ReLU(),
35
+ )
36
 
37
  # Initialize weights and apply final processing
38
  self.post_init()
 
82
  else:
83
  hidden_states = outputs[0]
84
 
85
+ hidden_states = hidden_states.permute(0, 2, 1)
86
+ hidden_states = self.pooled_conv(hidden_states)
87
+ hidden_states = torch.mean(hidden_states, dim=2)
88
+ logits = self.classifier(hidden_states)
 
 
 
 
 
 
89
 
90
  loss = None
91
  if labels is not None:
92
+ loss_fct = nn.L1Loss()
93
  loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1, self.config.num_labels))
94
 
95
  if not return_dict: