alessandro trinca tornidor commited on
Commit
1e30c4b
·
1 Parent(s): 4aab922

feat: use alternate version of init_jit_model to try avoiding PermissionError on HuggingFAce

Browse files
Files changed (1) hide show
  1. aip_trainer/models/models.py +78 -28
aip_trainer/models/models.py CHANGED
@@ -1,17 +1,20 @@
1
  import os
2
  from pathlib import Path
3
  import tempfile
 
4
  import torch.nn as nn
5
  from silero.utils import Decoder
6
 
7
  from aip_trainer import app_logger
8
 
9
 
10
- def silero_stt(language='en',
11
- version='latest',
12
- jit_model='jit',
13
- output_folder: Path | str = None,
14
- **kwargs):
 
 
15
  """Modified Silero Speech-To-Text Model(s) function
16
  language (str): language of the model, now available are ['en', 'de', 'es']
17
  version:
@@ -21,46 +24,93 @@ def silero_stt(language='en',
21
  Please see https://github.com/snakers4/silero-models for usage examples
22
  """
23
  import torch
24
- from omegaconf import OmegaConf
25
- from silero.utils import (init_jit_model,
26
- read_audio,
27
- read_batch,
28
- split_into_batches,
29
- prepare_model_input)
 
30
 
31
- output_folder = Path(output_folder) if output_folder is not None else Path(os.path.dirname(__file__)) / ".." / ".."
32
- models_list_file = output_folder / f'latest_silero_model_{language}.yml'
 
 
 
 
33
  if not os.path.exists(models_list_file):
34
- app_logger.info(f"model yml for '{language}' language, '{version}' version not found, download it in folder {output_folder}...")
 
 
35
  torch.hub.download_url_to_file(
36
- 'https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml',
37
  models_list_file,
38
- progress=True
39
  )
40
- app_logger.info(f"model yml for '{language}' language, '{version}' version in folder {output_folder}: OK!")
 
 
41
  assert os.path.exists(models_list_file)
42
  models = OmegaConf.load(models_list_file)
43
  available_languages = list(models.stt_models.keys())
44
  assert language in available_languages
45
 
46
- model, decoder = init_jit_model(model_url=models.stt_models.get(language).get(version).get(jit_model),
47
- **kwargs)
48
- utils = (read_batch,
49
- split_into_batches,
50
- read_audio,
51
- prepare_model_input)
52
 
53
  return model, decoder, utils
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # second returned type here is the custom class src.silero.utils.Decoder from snakers4/silero-models
57
  def getASRModel(language: str) -> tuple[nn.Module, Decoder]:
58
  tmp_dir = tempfile.gettempdir()
59
- if language == 'de':
60
- model, decoder, _ = silero_stt(language='de', version="v4", jit_model="jit_large", output_folder=tmp_dir)
61
- elif language == 'en':
62
- model, decoder, _ = silero_stt(language='en', output_folder=tmp_dir)
 
 
63
  else:
64
- raise NotImplementedError("currenty works only for 'de' and 'en' languages, not for '{}'.".format(language))
 
 
 
 
65
 
66
  return model, decoder
 
1
  import os
2
  from pathlib import Path
3
  import tempfile
4
+ import torch
5
  import torch.nn as nn
6
  from silero.utils import Decoder
7
 
8
  from aip_trainer import app_logger
9
 
10
 
11
+ def silero_stt(
12
+ language="en",
13
+ version="latest",
14
+ jit_model="jit",
15
+ output_folder: Path | str = None,
16
+ **kwargs,
17
+ ):
18
  """Modified Silero Speech-To-Text Model(s) function
19
  language (str): language of the model, now available are ['en', 'de', 'es']
20
  version:
 
24
  Please see https://github.com/snakers4/silero-models for usage examples
25
  """
26
  import torch
27
+ from omegaconf import OmegaConf
28
+ from silero.utils import (
29
+ read_audio,
30
+ read_batch,
31
+ split_into_batches,
32
+ prepare_model_input,
33
+ )
34
 
35
+ output_folder = (
36
+ Path(output_folder)
37
+ if output_folder is not None
38
+ else Path(os.path.dirname(__file__)) / ".." / ".."
39
+ )
40
+ models_list_file = output_folder / f"latest_silero_model_{language}.yml"
41
  if not os.path.exists(models_list_file):
42
+ app_logger.info(
43
+ f"model yml for '{language}' language, '{version}' version not found, download it in folder {output_folder}..."
44
+ )
45
  torch.hub.download_url_to_file(
46
+ "https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml",
47
  models_list_file,
48
+ progress=True,
49
  )
50
+ app_logger.info(
51
+ f"model yml for '{language}' language, '{version}' version in folder {output_folder}: OK!"
52
+ )
53
  assert os.path.exists(models_list_file)
54
  models = OmegaConf.load(models_list_file)
55
  available_languages = list(models.stt_models.keys())
56
  assert language in available_languages
57
 
58
+ model, decoder = init_jit_model(
59
+ model_url=models.stt_models.get(language).get(version).get(jit_model), output_folder=output_folder, **kwargs
60
+ )
61
+ utils = (read_batch, split_into_batches, read_audio, prepare_model_input)
 
 
62
 
63
  return model, decoder, utils
64
 
65
 
66
+ def init_jit_model(
67
+ model_url: str,
68
+ device: torch.device = torch.device("cpu"),
69
+ output_folder: Path | str = None,
70
+ ):
71
+ torch.set_grad_enabled(False)
72
+
73
+ app_logger.info(
74
+ f"model output_folder exists? '{output_folder is None}' => '{output_folder}' ..."
75
+ )
76
+ model_dir = (
77
+ Path(output_folder)
78
+ if output_folder is not None
79
+ else Path(os.path.dirname(__file__)) / "model"
80
+ )
81
+ os.makedirs(model_dir, exist_ok=True)
82
+ model_path = model_dir / os.path.basename(model_url)
83
+ app_logger.info(
84
+ f"model_path exists? '{os.path.isfile(model_path)}' => '{model_path}' ..."
85
+ )
86
+
87
+ if not os.path.isfile(model_path):
88
+ app_logger.info(
89
+ f"downloading model_path: '{model_path}' ..."
90
+ )
91
+ torch.hub.download_url_to_file(model_url, model_path, progress=True)
92
+ app_logger.info(
93
+ f"model_path {model_path} downloaded!"
94
+ )
95
+ model = torch.jit.load(model_path, map_location=device)
96
+ model.eval()
97
+ return model, Decoder(model.labels)
98
+
99
+
100
  # second returned type here is the custom class src.silero.utils.Decoder from snakers4/silero-models
101
  def getASRModel(language: str) -> tuple[nn.Module, Decoder]:
102
  tmp_dir = tempfile.gettempdir()
103
+ if language == "de":
104
+ model, decoder, _ = silero_stt(
105
+ language="de", version="v4", jit_model="jit_large", output_folder=tmp_dir
106
+ )
107
+ elif language == "en":
108
+ model, decoder, _ = silero_stt(language="en", output_folder=tmp_dir)
109
  else:
110
+ raise NotImplementedError(
111
+ "currenty works only for 'de' and 'en' languages, not for '{}'.".format(
112
+ language
113
+ )
114
+ )
115
 
116
  return model, decoder