KhaldiAbderrhmane commited on
Commit
d181956
1 Parent(s): 9efc15b

Update modeling_emotion_classifier.py

Browse files
Files changed (1) hide show
  1. modeling_emotion_classifier.py +16 -4
modeling_emotion_classifier.py CHANGED
@@ -1,8 +1,20 @@
1
- from transformers import PreTrainedModel, HubertConfig, HubertModel,PretrainedConfig
2
- import torch.nn as nn
3
  import torch
 
 
 
 
 
 
 
 
 
 
 
4
  from .configuration_emotion_classifier import EmotionClassifierConfig
5
 
 
 
6
  class EmotionClassifierHuBERT(PreTrainedModel):
7
  config_class = EmotionClassifierConfig
8
 
@@ -47,7 +59,7 @@ class EmotionClassifierHuBERT(PreTrainedModel):
47
  mirror = kwargs.pop("mirror", None)
48
 
49
  # Load config if we don't provide a configuration
50
- if not isinstance(config, PretrainedConfig):
51
  config_path = config if config is not None else pretrained_model_name_or_path
52
  config, model_kwargs = cls.config_class.from_pretrained(
53
  config_path,
@@ -156,4 +168,4 @@ class EmotionClassifierHuBERT(PreTrainedModel):
156
  loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
157
  return model, loading_info
158
 
159
- return model
 
1
+ import os
 
2
  import torch
3
+ import torch.nn as nn
4
+ from transformers import PreTrainedModel, HubertConfig, HubertModel
5
+ from transformers.file_utils import (
6
+ WEIGHTS_NAME,
7
+ TF2_WEIGHTS_NAME,
8
+ TF_WEIGHTS_NAME,
9
+ cached_path,
10
+ hf_bucket_url,
11
+ is_remote_url,
12
+ )
13
+ from transformers.utils import logging
14
  from .configuration_emotion_classifier import EmotionClassifierConfig
15
 
16
+ logger = logging.get_logger(__name__)
17
+
18
  class EmotionClassifierHuBERT(PreTrainedModel):
19
  config_class = EmotionClassifierConfig
20
 
 
59
  mirror = kwargs.pop("mirror", None)
60
 
61
  # Load config if we don't provide a configuration
62
+ if not isinstance(config, EmotionClassifierConfig):
63
  config_path = config if config is not None else pretrained_model_name_or_path
64
  config, model_kwargs = cls.config_class.from_pretrained(
65
  config_path,
 
168
  loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
169
  return model, loading_info
170
 
171
+ return model