Spaces:
Running
Running
alessandro trinca tornidor
committed on
Commit
·
1e30c4b
1
Parent(s):
4aab922
feat: use alternate version of init_jit_model to try avoiding PermissionError on Hugging Face
Browse files- aip_trainer/models/models.py +78 -28
aip_trainer/models/models.py
CHANGED
@@ -1,17 +1,20 @@
|
|
1 |
import os
|
2 |
from pathlib import Path
|
3 |
import tempfile
|
|
|
4 |
import torch.nn as nn
|
5 |
from silero.utils import Decoder
|
6 |
|
7 |
from aip_trainer import app_logger
|
8 |
|
9 |
|
10 |
-
def silero_stt(
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
15 |
"""Modified Silero Speech-To-Text Model(s) function
|
16 |
language (str): language of the model, now available are ['en', 'de', 'es']
|
17 |
version:
|
@@ -21,46 +24,93 @@ def silero_stt(language='en',
|
|
21 |
Please see https://github.com/snakers4/silero-models for usage examples
|
22 |
"""
|
23 |
import torch
|
24 |
-
from omegaconf import OmegaConf
|
25 |
-
from silero.utils import (
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
30 |
|
31 |
-
output_folder =
|
32 |
-
|
|
|
|
|
|
|
|
|
33 |
if not os.path.exists(models_list_file):
|
34 |
-
app_logger.info(
|
|
|
|
|
35 |
torch.hub.download_url_to_file(
|
36 |
-
|
37 |
models_list_file,
|
38 |
-
progress=True
|
39 |
)
|
40 |
-
app_logger.info(
|
|
|
|
|
41 |
assert os.path.exists(models_list_file)
|
42 |
models = OmegaConf.load(models_list_file)
|
43 |
available_languages = list(models.stt_models.keys())
|
44 |
assert language in available_languages
|
45 |
|
46 |
-
model, decoder = init_jit_model(
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
read_audio,
|
51 |
-
prepare_model_input)
|
52 |
|
53 |
return model, decoder, utils
|
54 |
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
# second returned type here is the custom class src.silero.utils.Decoder from snakers4/silero-models
|
57 |
def getASRModel(language: str) -> tuple[nn.Module, Decoder]:
|
58 |
tmp_dir = tempfile.gettempdir()
|
59 |
-
if language ==
|
60 |
-
model, decoder, _ = silero_stt(
|
61 |
-
|
62 |
-
|
|
|
|
|
63 |
else:
|
64 |
-
raise NotImplementedError(
|
|
|
|
|
|
|
|
|
65 |
|
66 |
return model, decoder
|
|
|
1 |
import os
|
2 |
from pathlib import Path
|
3 |
import tempfile
|
4 |
+
import torch
|
5 |
import torch.nn as nn
|
6 |
from silero.utils import Decoder
|
7 |
|
8 |
from aip_trainer import app_logger
|
9 |
|
10 |
|
11 |
+
def silero_stt(
|
12 |
+
language="en",
|
13 |
+
version="latest",
|
14 |
+
jit_model="jit",
|
15 |
+
output_folder: Path | str = None,
|
16 |
+
**kwargs,
|
17 |
+
):
|
18 |
"""Modified Silero Speech-To-Text Model(s) function
|
19 |
language (str): language of the model, now available are ['en', 'de', 'es']
|
20 |
version:
|
|
|
24 |
Please see https://github.com/snakers4/silero-models for usage examples
|
25 |
"""
|
26 |
import torch
|
27 |
+
from omegaconf import OmegaConf
|
28 |
+
from silero.utils import (
|
29 |
+
read_audio,
|
30 |
+
read_batch,
|
31 |
+
split_into_batches,
|
32 |
+
prepare_model_input,
|
33 |
+
)
|
34 |
|
35 |
+
output_folder = (
|
36 |
+
Path(output_folder)
|
37 |
+
if output_folder is not None
|
38 |
+
else Path(os.path.dirname(__file__)) / ".." / ".."
|
39 |
+
)
|
40 |
+
models_list_file = output_folder / f"latest_silero_model_{language}.yml"
|
41 |
if not os.path.exists(models_list_file):
|
42 |
+
app_logger.info(
|
43 |
+
f"model yml for '{language}' language, '{version}' version not found, download it in folder {output_folder}..."
|
44 |
+
)
|
45 |
torch.hub.download_url_to_file(
|
46 |
+
"https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml",
|
47 |
models_list_file,
|
48 |
+
progress=True,
|
49 |
)
|
50 |
+
app_logger.info(
|
51 |
+
f"model yml for '{language}' language, '{version}' version in folder {output_folder}: OK!"
|
52 |
+
)
|
53 |
assert os.path.exists(models_list_file)
|
54 |
models = OmegaConf.load(models_list_file)
|
55 |
available_languages = list(models.stt_models.keys())
|
56 |
assert language in available_languages
|
57 |
|
58 |
+
model, decoder = init_jit_model(
|
59 |
+
model_url=models.stt_models.get(language).get(version).get(jit_model), output_folder=output_folder, **kwargs
|
60 |
+
)
|
61 |
+
utils = (read_batch, split_into_batches, read_audio, prepare_model_input)
|
|
|
|
|
62 |
|
63 |
return model, decoder, utils
|
64 |
|
65 |
|
66 |
+
def init_jit_model(
    model_url: str,
    device: torch.device = torch.device("cpu"),
    output_folder: Path | str | None = None,
):
    """Download (if not already present) and load a TorchScript model.

    The model file is stored as ``<model_dir>/<basename of model_url>``,
    where ``model_dir`` is ``output_folder`` when given, otherwise a
    ``model`` folder next to this module. The download is skipped when a
    file already exists at that path.

    Args:
        model_url: URL of the TorchScript model file to download.
        device: device passed to ``torch.jit.load`` as ``map_location``.
        output_folder: folder where the model file is stored; ``None``
            selects the ``model`` folder next to this module.

    Returns:
        A ``(model, decoder)`` tuple: the loaded model switched to eval
        mode and a ``Decoder`` built from ``model.labels``.
    """
    # NOTE: this disables autograd globally for the process, not only for
    # the model loaded here.
    torch.set_grad_enabled(False)

    # Report whether a folder was actually provided (the previous message
    # printed the inverse of what its label says).
    app_logger.info(
        f"model output_folder exists? '{output_folder is not None}' => '{output_folder}' ..."
    )
    model_dir = (
        Path(output_folder)
        if output_folder is not None
        else Path(os.path.dirname(__file__)) / "model"
    )
    model_dir.mkdir(parents=True, exist_ok=True)
    model_path = model_dir / os.path.basename(model_url)
    already_downloaded = model_path.is_file()
    app_logger.info(
        f"model_path exists? '{already_downloaded}' => '{model_path}' ..."
    )

    if not already_downloaded:
        app_logger.info(
            f"downloading model_path: '{model_path}' ..."
        )
        torch.hub.download_url_to_file(model_url, model_path, progress=True)
        app_logger.info(
            f"model_path {model_path} downloaded!"
        )

    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    return model, Decoder(model.labels)
|
98 |
+
|
99 |
+
|
100 |
# second returned type here is the custom class src.silero.utils.Decoder from snakers4/silero-models
|
101 |
def getASRModel(language: str) -> tuple[nn.Module, Decoder]:
    """Return the Silero speech-to-text model and decoder for a language.

    Model files are stored in the system temporary directory
    (``tempfile.gettempdir()``), which is passed to ``silero_stt`` as
    ``output_folder``.

    Args:
        language: language code; only ``"de"`` and ``"en"`` are handled.

    Returns:
        The ``(model, decoder)`` pair returned by ``silero_stt``; its third
        element (the utils tuple) is discarded.

    Raises:
        NotImplementedError: if ``language`` is neither ``"de"`` nor ``"en"``.
    """
    # Reject unsupported languages up front, before touching the filesystem.
    if language not in ("de", "en"):
        raise NotImplementedError(
            f"currently works only for 'de' and 'en' languages, not for '{language}'."
        )

    tmp_dir = tempfile.gettempdir()
    if language == "de":
        # German pins an explicit model version and the large jit variant.
        model, decoder, _ = silero_stt(
            language="de", version="v4", jit_model="jit_large", output_folder=tmp_dir
        )
    else:
        # English relies on the silero_stt defaults for version and variant.
        model, decoder, _ = silero_stt(language="en", output_folder=tmp_dir)

    return model, decoder
|