KhaldiAbderrhmane committed on
Commit
2df661b
1 Parent(s): 528a5f2

Update modeling_emotion_classifier.py

Files changed (1)
  1. modeling_emotion_classifier.py +5 -147
modeling_emotion_classifier.py CHANGED
@@ -1,35 +1,20 @@
-import os
-import torch
+from transformers import PreTrainedModel, HubertModel
 import torch.nn as nn
-from transformers import PreTrainedModel, HubertConfig, HubertModel,PretrainedConfig
-from transformers.file_utils import (
-    WEIGHTS_NAME,
-    TF2_WEIGHTS_NAME,
-    TF_WEIGHTS_NAME,
-    cached_path,
-    hf_bucket_url,
-    is_remote_url,
-)
-from transformers.utils import logging
+import torch
 from .configuration_emotion_classifier import EmotionClassifierConfig
 
-logger = logging.get_logger(__name__)
 
 class EmotionClassifierHuBERT(PreTrainedModel):
     config_class = EmotionClassifierConfig
 
     def __init__(self, config):
         super().__init__(config)
-
-        # Initialize HuBERT without pre-trained weights
-        hubert_config = HubertConfig.from_pretrained("facebook/hubert-large-ls960-ft")
-        self.hubert = HubertModel(hubert_config)
-
+        self.hubert = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
         self.conv1 = nn.Conv1d(in_channels=1024, out_channels=512, kernel_size=3, padding=1)
         self.conv2 = nn.Conv1d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
         self.transformer_encoder = nn.TransformerEncoderLayer(d_model=256, nhead=8)
         self.bilstm = nn.LSTM(input_size=256, hidden_size=config.hidden_size_lstm, num_layers=2, batch_first=True, bidirectional=True)
-        self.fc = nn.Linear(config.hidden_size_lstm * 2, config.num_classes)
+        self.fc = nn.Linear(config.hidden_size_lstm * 2, config.num_classes)  # * 2 for bidirectional
 
     def forward(self, x):
         with torch.no_grad():
@@ -41,131 +26,4 @@ class EmotionClassifierHuBERT(PreTrainedModel):
         x = self.transformer_encoder(x)
         x, _ = self.bilstm(x)
         x = self.fc(x[:, -1, :])
-        return x
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        config = kwargs.pop("config", None)
-        state_dict = kwargs.pop("state_dict", None)
-        cache_dir = kwargs.pop("cache_dir", None)
-        from_tf = kwargs.pop("from_tf", False)
-        force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", False)
-        proxies = kwargs.pop("proxies", None)
-        output_loading_info = kwargs.pop("output_loading_info", False)
-        local_files_only = kwargs.pop("local_files_only", False)
-        use_auth_token = kwargs.pop("use_auth_token", None)
-        revision = kwargs.pop("revision", None)
-        mirror = kwargs.pop("mirror", None)
-
-        # Load config if we don't provide a configuration
-        if not isinstance(config, EmotionClassifierConfig):
-            config_path = config if config is not None else pretrained_model_name_or_path
-            config, model_kwargs = cls.config_class.from_pretrained(
-                config_path,
-                *model_args,
-                cache_dir=cache_dir,
-                return_unused_kwargs=True,
-                force_download=force_download,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                **kwargs,
-            )
-        else:
-            model_kwargs = kwargs
-
-        # Load model
-        if pretrained_model_name_or_path is not None:
-            pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-            if os.path.isdir(pretrained_model_name_or_path):
-                if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
-                    # Load from a TF 1.0 checkpoint in priority if from_tf
-                    archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
-                elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
-                    # Load from a TF 2.0 checkpoint in priority if from_tf
-                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
-                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
-                    # Load from a PyTorch checkpoint
-                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
-                else:
-                    raise EnvironmentError(
-                        f"Error no file named {[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + '.index']} found in "
-                        f"directory {pretrained_model_name_or_path} or '{pretrained_model_name_or_path}' is not a directory."
-                    )
-            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
-                archive_file = pretrained_model_name_or_path
-            else:
-                # Load from URL or cache
-                archive_file = hf_bucket_url(
-                    pretrained_model_name_or_path,
-                    filename=WEIGHTS_NAME,
-                    revision=revision,
-                    mirror=mirror,
-                )
-
-            try:
-                # Load from URL or cache
-                resolved_archive_file = cached_path(
-                    archive_file,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    proxies=proxies,
-                    resume_download=resume_download,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                )
-            except EnvironmentError as err:
-                logger.error(err)
-                msg = (
-                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
-                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
-                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
-                )
-                raise EnvironmentError(msg)
-
-            if resolved_archive_file == archive_file:
-                logger.info(f"loading weights file {archive_file}")
-            else:
-                logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
-        else:
-            resolved_archive_file = None
-
-        # Initialize the model
-        model = cls(config)
-
-        if state_dict is None:
-            try:
-                state_dict = torch.load(resolved_archive_file, map_location="cpu")
-            except Exception:
-                raise OSError(
-                    f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
-                    f"at '{resolved_archive_file}'"
-                )
-
-        # Remove the prefix 'module' from the keys if present (happens when using DataParallel)
-        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
-
-        # Load only the custom model weights, excluding HuBERT
-        custom_state_dict = {k: v for k, v in state_dict.items() if not k.startswith('hubert.')}
-        missing_keys, unexpected_keys = model.load_state_dict(custom_state_dict, strict=False)
-
-        if len(missing_keys) > 0:
-            logger.warning(f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
-                           f"and are newly initialized: {missing_keys}\n"
-                           f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.")
-        if len(unexpected_keys) > 0:
-            logger.warning(f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
-                           f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
-                           f"This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
-                           f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
-                           f"This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical "
-                           f"(initializing a BertForSequenceClassification model from a BertForSequenceClassification model).")
-
-        if output_loading_info:
-            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
-            return model, loading_info
-
-        return model
+        return x
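
For reference, a minimal usage sketch of the simplified class, not part of the commit. It assumes the repository exposes the custom config and model to transformers via trust_remote_code; the repo id placeholder, the config field values, and the raw 16 kHz waveform input shape are assumptions for illustration, not taken from this diff. With the custom from_pretrained override removed, weight loading falls back to the stock PreTrainedModel loader, and the HuBERT backbone weights are fetched from facebook/hubert-large-ls960-ft inside __init__.

    import torch
    from transformers import AutoConfig, AutoModel

    # Placeholder repo id -- replace with the actual Hub repository hosting these files.
    repo_id = "<user>/<emotion-classifier-repo>"

    # trust_remote_code=True lets transformers import modeling_emotion_classifier.py and
    # configuration_emotion_classifier.py shipped with the repo (assumes an auto_map is registered).
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
    model.eval()

    # Assumption: the model consumes the raw 16 kHz waveform batch that HuBERT expects,
    # shaped (batch_size, num_samples); the middle of forward() is not shown in this diff.
    waveform = torch.randn(1, 16000)
    with torch.no_grad():
        logits = model(waveform)  # expected shape: (1, config.num_classes)

Note that the deleted override used to strip a "module." prefix from DataParallel checkpoints and to skip hubert.* keys when loading; without it, the standard loader applies the checkpoint as-is and the backbone is re-downloaded whenever the class is instantiated.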