Update model_hf.py
Browse files- model_hf.py +4 -2
model_hf.py
CHANGED
@@ -8,6 +8,7 @@ import torch.nn.functional as F
|
|
8 |
from torch import Tensor
|
9 |
import fairseq
|
10 |
from .config_ssl import SSLConfig
|
|
|
11 |
|
12 |
___author__ = "Hemlata Tak"
|
13 |
__email__ = "tak@eurecom.fr"
|
@@ -20,8 +21,9 @@ __email__ = "tak@eurecom.fr"
|
|
20 |
class SSLModel(nn.Module):
|
21 |
def __init__(self,device):
|
22 |
super(SSLModel, self).__init__()
|
23 |
-
|
24 |
-
|
|
|
25 |
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
|
26 |
self.model = model[0]
|
27 |
self.model_device=device
|
|
|
8 |
from torch import Tensor
|
9 |
import fairseq
|
10 |
from .config_ssl import SSLConfig
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
|
13 |
___author__ = "Hemlata Tak"
|
14 |
__email__ = "tak@eurecom.fr"
|
|
|
21 |
class SSLModel(nn.Module):
|
22 |
def __init__(self,device):
|
23 |
super(SSLModel, self).__init__()
|
24 |
+
repo_id = 'ash56/ssl-aasist'
|
25 |
+
fname = 'xlsr2_300m.pt'
|
26 |
+
cp_path = hf_hub_download(repo_id=repo_id, filename=fname) # Change the pre-trained XLSR model path.
|
27 |
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
|
28 |
self.model = model[0]
|
29 |
self.model_device=device
|