alessandro trinca tornidor commited on
Commit
70d4503
1 Parent(s): 823d44e

feat: support pytorch and torchaudio, update test, add requirements-dev.txt

Browse files
.gitignore CHANGED
@@ -199,6 +199,7 @@ tmp
199
  nohup.out
200
  /tests/events.tar
201
  function_dump_*.json
 
202
 
203
  # onnx models
204
  *.onnx
 
199
  nohup.out
200
  /tests/events.tar
201
  function_dump_*.json
202
+ *.yml
203
 
204
  # onnx models
205
  *.onnx
aip_trainer/models/models.py CHANGED
@@ -1,25 +1,14 @@
1
- from typing import Any
2
-
3
- import torch
4
  import torch.nn as nn
 
 
5
 
6
 
7
  # second returned type here is the custom class src.silero.utils.Decoder from snakers4/silero-models
8
- def getASRModel(language: str) -> tuple[nn.Module, Any]:
9
-
10
-
11
  if language == 'de':
12
-
13
- model, decoder, utils = torch.hub.load(repo_or_dir='snakers4/silero-models',
14
- model='silero_stt',
15
- language='de',
16
- device=torch.device('cpu'))
17
-
18
  elif language == 'en':
19
- model, decoder, utils = torch.hub.load(repo_or_dir='snakers4/silero-models',
20
- model='silero_stt',
21
- language='en',
22
- device=torch.device('cpu'))
23
  else:
24
  raise NotImplementedError("currenty works only for 'de' and 'en' languages, not for '{}'.".format(language))
25
 
 
 
 
 
1
  import torch.nn as nn
2
+ from silero import silero_stt
3
+ from silero.utils import Decoder
4
 
5
 
6
  # second returned type here is the custom class src.silero.utils.Decoder from snakers4/silero-models
7
+ def getASRModel(language: str) -> tuple[nn.Module, Decoder]:
 
 
8
  if language == 'de':
9
+ model, decoder, _ = silero_stt(language='de', version="v4", jit_model="jit_large")
 
 
 
 
 
10
  elif language == 'en':
11
+ model, decoder, _ = silero_stt(language='en')
 
 
 
12
  else:
13
  raise NotImplementedError("currenty works only for 'de' and 'en' languages, not for '{}'.".format(language))
14
 
requirements-dev.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ pytest
2
+ pytest-cov
requirements.txt CHANGED
@@ -7,7 +7,6 @@ flask_cors
7
  omegaconf
8
  ortools==9.11.4210
9
  pandas
10
- numpy<2.0.0
11
  pickle-mixin
12
  python-dotenv
13
  requests
@@ -15,6 +14,6 @@ sentencepiece
15
  soundfile==0.12.1
16
  sqlalchemy
17
  structlog
18
- torch==1.13.1
19
- torchaudio==0.13.1
20
  transformers
 
7
  omegaconf
8
  ortools==9.11.4210
9
  pandas
 
10
  pickle-mixin
11
  python-dotenv
12
  requests
 
14
  soundfile==0.12.1
15
  sqlalchemy
16
  structlog
17
+ torch
18
+ torchaudio
19
  transformers
tests/events/GetAccuracyFromRecordedAudio.json CHANGED
The diff for this file is too large to render. See raw diff
 
tests/test_GetAccuracyFromRecordedAudio.py CHANGED
@@ -40,7 +40,9 @@ class TestGetAccuracyFromRecordedAudio(unittest.TestCase):
40
  output["matched_transcripts"] = expected_output["matched_transcripts"]
41
  output["matched_transcripts_ipa"] = expected_output["matched_transcripts_ipa"]
42
  output["pronunciation_accuracy"] = expected_output["pronunciation_accuracy"]
 
43
  output["ipa_transcript"] = expected_output["ipa_transcript"]
 
44
  output["real_transcripts_ipa"] = expected_output["real_transcripts_ipa"]
45
  self.assertEqual(expected_output, output)
46
 
 
40
  output["matched_transcripts"] = expected_output["matched_transcripts"]
41
  output["matched_transcripts_ipa"] = expected_output["matched_transcripts_ipa"]
42
  output["pronunciation_accuracy"] = expected_output["pronunciation_accuracy"]
43
+ output["pair_accuracy_category"] = expected_output["pair_accuracy_category"]
44
  output["ipa_transcript"] = expected_output["ipa_transcript"]
45
+ output["real_transcript"] = expected_output["real_transcript"]
46
  output["real_transcripts_ipa"] = expected_output["real_transcripts_ipa"]
47
  self.assertEqual(expected_output, output)
48