KhaldiAbderrhmane committed (verified)
Commit d9473b2 · 1 parent: a85b3f5

Update modeling_emotion_classifier.py

Files changed (1):
  1. modeling_emotion_classifier.py +134 -4
modeling_emotion_classifier.py CHANGED
@@ -1,20 +1,23 @@
-from transformers import PreTrainedModel, HubertModel
+from transformers import PreTrainedModel, HubertConfig, HubertModel
 import torch.nn as nn
 import torch
 from .configuration_emotion_classifier import EmotionClassifierConfig
 
-
 class EmotionClassifierHuBERT(PreTrainedModel):
     config_class = EmotionClassifierConfig
 
     def __init__(self, config):
         super().__init__(config)
-        self.hubert = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
+
+        # Initialize HuBERT from its config alone, without pre-trained weights
+        hubert_config = HubertConfig.from_pretrained("facebook/hubert-large-ls960-ft")
+        self.hubert = HubertModel(hubert_config)
+
         self.conv1 = nn.Conv1d(in_channels=1024, out_channels=512, kernel_size=3, padding=1)
         self.conv2 = nn.Conv1d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
         self.transformer_encoder = nn.TransformerEncoderLayer(d_model=256, nhead=8)
         self.bilstm = nn.LSTM(input_size=256, hidden_size=config.hidden_size_lstm, num_layers=2, batch_first=True, bidirectional=True)
-        self.fc = nn.Linear(config.hidden_size_lstm * 2, config.num_classes)  # * 2 for bidirectional
+        self.fc = nn.Linear(config.hidden_size_lstm * 2, config.num_classes)
 
     def forward(self, x):
         with torch.no_grad():
@@ -27,3 +30,130 @@ class EmotionClassifierHuBERT(PreTrainedModel):
         x, _ = self.bilstm(x)
         x = self.fc(x[:, -1, :])
         return x
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        config = kwargs.pop("config", None)
+        state_dict = kwargs.pop("state_dict", None)
+        cache_dir = kwargs.pop("cache_dir", None)
+        from_tf = kwargs.pop("from_tf", False)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        output_loading_info = kwargs.pop("output_loading_info", False)
+        local_files_only = kwargs.pop("local_files_only", False)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        mirror = kwargs.pop("mirror", None)
+
+        # Load the config if a configuration was not provided
+        if not isinstance(config, PretrainedConfig):
+            config_path = config if config is not None else pretrained_model_name_or_path
+            config, model_kwargs = cls.config_class.from_pretrained(
+                config_path,
+                *model_args,
+                cache_dir=cache_dir,
+                return_unused_kwargs=True,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                **kwargs,
+            )
+        else:
+            model_kwargs = kwargs
+
+        # Resolve the weights file
+        if pretrained_model_name_or_path is not None:
+            pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+            if os.path.isdir(pretrained_model_name_or_path):
+                if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
+                    # Load from a TF 1.0 checkpoint in priority if from_tf
+                    archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
+                elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
+                    # Load from a TF 2.0 checkpoint in priority if from_tf
+                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
+                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+                    # Load from a PyTorch checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+                else:
+                    raise EnvironmentError(
+                        f"Error no file named {[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + '.index']} found in "
+                        f"directory {pretrained_model_name_or_path} or '{pretrained_model_name_or_path}' is not a directory."
+                    )
+            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
+                archive_file = pretrained_model_name_or_path
+            else:
+                # Build the URL for the weights on the Hub
+                archive_file = hf_bucket_url(
+                    pretrained_model_name_or_path,
+                    filename=WEIGHTS_NAME,
+                    revision=revision,
+                    mirror=mirror,
+                )
+
+            try:
+                # Load from URL or cache
+                resolved_archive_file = cached_path(
+                    archive_file,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                )
+            except EnvironmentError as err:
+                logger.error(err)
+                msg = (
+                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
+                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
+                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
+                )
+                raise EnvironmentError(msg)
+
+            if resolved_archive_file == archive_file:
+                logger.info(f"loading weights file {archive_file}")
+            else:
+                logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
+        else:
+            resolved_archive_file = None
+
+        # Initialize the model
+        model = cls(config)
+
+        if state_dict is None:
+            try:
+                state_dict = torch.load(resolved_archive_file, map_location="cpu")
+            except Exception:
+                raise OSError(
+                    f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
+                    f"at '{resolved_archive_file}'"
+                )
+
+        # Remove the 'module.' prefix from keys if present (happens when using DataParallel)
+        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+
+        # Load only the custom model weights, excluding HuBERT
+        custom_state_dict = {k: v for k, v in state_dict.items() if not k.startswith('hubert.')}
+        missing_keys, unexpected_keys = model.load_state_dict(custom_state_dict, strict=False)
+
+        if len(missing_keys) > 0:
+            logger.warning(f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
+                           f"and are newly initialized: {missing_keys}\n"
+                           f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.")
+        if len(unexpected_keys) > 0:
+            logger.warning(f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
+                           f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
+                           f"This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
+                           f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
+                           f"This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical "
+                           f"(initializing a BertForSequenceClassification model from a BertForSequenceClassification model).")
+
+        if output_loading_info:
+            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
+            return model, loading_info
+
+        return model
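
Note that the new `from_pretrained` override references several names the file never imports (`os`, `PretrainedConfig`, `WEIGHTS_NAME`, `TF_WEIGHTS_NAME`, `TF2_WEIGHTS_NAME`, `is_remote_url`, `hf_bucket_url`, `cached_path`, `logger`), so as committed it would raise `NameError` at runtime. A minimal sketch of the imports it would need, assuming an older transformers 4.x release where `file_utils` still exposes `cached_path` and `hf_bucket_url` (these helpers were removed in later versions):

import os

from transformers import PretrainedConfig
from transformers.file_utils import (
    TF2_WEIGHTS_NAME,
    TF_WEIGHTS_NAME,
    WEIGHTS_NAME,
    cached_path,
    hf_bucket_url,
    is_remote_url,
)
from transformers.utils import logging

# Module-level logger, following the usual transformers convention
logger = logging.get_logger(__name__)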
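Since `__init__` now builds HuBERT from its config alone and the override drops every `hubert.*` key from the checkpoint, a freshly loaded model has a randomly initialized backbone; presumably the pre-trained HuBERT weights are meant to be restored separately. A hypothetical usage sketch (the checkpoint path and the explicit backbone reload are assumptions, not part of this commit):

from transformers import HubertModel

# Placeholder path: substitute the actual classifier checkpoint or repo id
model = EmotionClassifierHuBERT.from_pretrained("path/to/emotion-classifier")

# hubert.* keys were excluded during loading, so restore the backbone
# explicitly (assumption: this mirrors how the model was trained)
model.hubert = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
model.eval()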