PyTorch
ssl-aasist
custom_code
ash56 commited on
Commit
f73208e
·
verified ·
1 Parent(s): b41b9d8

Update model_hf.py

Browse files
Files changed (1) hide show
  1. model_hf.py +6 -5
model_hf.py CHANGED
@@ -9,6 +9,7 @@ from torch import Tensor
9
  from .config_ssl import SSLConfig
10
  from huggingface_hub import hf_hub_download
11
  from transformers import Wav2Vec2ForPreTraining
 
12
 
13
  ___author__ = "Hemlata Tak"
14
  __email__ = "tak@eurecom.fr"
@@ -23,11 +24,11 @@ class SSLModel(nn.Module):
23
  super(SSLModel, self).__init__()
24
  # eliminate fairseq dependency
25
  # facebook/wav2vec2-xls-r-300m
26
- # repo_id = "facebook/wav2vec2-xlsr-300m"
27
- model = Wav2Vec2.from_pretrained("facebook/wav2vec2-xls-r-300m")
28
- # cp_path = hf_hub_download(repo_id=repo_id, filename=fname) # Change the pre-trained XLSR model path.
29
- # model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
30
- self.model = model
31
  self.model_device=device
32
  self.out_dim = 1024
33
  return
 
9
  from .config_ssl import SSLConfig
10
  from huggingface_hub import hf_hub_download
11
  from transformers import Wav2Vec2ForPreTraining
12
+ import fairseq
13
 
14
  ___author__ = "Hemlata Tak"
15
  __email__ = "tak@eurecom.fr"
 
24
  super(SSLModel, self).__init__()
25
  # eliminate fairseq dependency
26
  # facebook/wav2vec2-xls-r-300m
27
+ repo_id = "ash56/ssl-aasist"
28
+ fname = "xlsr2_300m.pt"
29
+ cp_path = hf_hub_download(repo_id=repo_id, filename=fname) # Change the pre-trained XLSR model path.
30
+ model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
31
+ self.model = model[0]
32
  self.model_device=device
33
  self.out_dim = 1024
34
  return