Spaces:
Running
Running
alessandro trinca tornidor
committed on
Commit
•
70d4503
1
Parent(s):
823d44e
feat: support pytorch and torchaudio, update test, add requirements-dev.txt
Browse files
- .gitignore +1 -0
- aip_trainer/models/models.py +5 -16
- requirements-dev.txt +2 -0
- requirements.txt +2 -3
- tests/events/GetAccuracyFromRecordedAudio.json +0 -0
- tests/test_GetAccuracyFromRecordedAudio.py +2 -0
.gitignore
CHANGED
@@ -199,6 +199,7 @@ tmp
|
|
199 |
nohup.out
|
200 |
/tests/events.tar
|
201 |
function_dump_*.json
|
|
|
202 |
|
203 |
# onnx models
|
204 |
*.onnx
|
|
|
199 |
nohup.out
|
200 |
/tests/events.tar
|
201 |
function_dump_*.json
|
202 |
+
*.yml
|
203 |
|
204 |
# onnx models
|
205 |
*.onnx
|
aip_trainer/models/models.py
CHANGED
@@ -1,25 +1,14 @@
|
|
1 |
-
from typing import Any
|
2 |
-
|
3 |
-
import torch
|
4 |
import torch.nn as nn
|
|
|
|
|
5 |
|
6 |
|
7 |
# second returned type here is the custom class src.silero.utils.Decoder from snakers4/silero-models
|
8 |
-
def getASRModel(language: str) -> tuple[nn.Module, Any]:
|
9 |
-
|
10 |
-
|
11 |
if language == 'de':
|
12 |
-
|
13 |
-
model, decoder, utils = torch.hub.load(repo_or_dir='snakers4/silero-models',
|
14 |
-
model='silero_stt',
|
15 |
-
language='de',
|
16 |
-
device=torch.device('cpu'))
|
17 |
-
|
18 |
elif language == 'en':
|
19 |
-
model, decoder, utils = torch.hub.load(repo_or_dir='snakers4/silero-models',
|
20 |
-
model='silero_stt',
|
21 |
-
language='en',
|
22 |
-
device=torch.device('cpu'))
|
23 |
else:
|
24 |
raise NotImplementedError("currenty works only for 'de' and 'en' languages, not for '{}'.".format(language))
|
25 |
|
|
|
|
|
|
|
|
|
1 |
import torch.nn as nn
|
2 |
+
from silero import silero_stt
|
3 |
+
from silero.utils import Decoder
|
4 |
|
5 |
|
6 |
# second returned type here is the custom class src.silero.utils.Decoder from snakers4/silero-models
|
7 |
+
def getASRModel(language: str) -> tuple[nn.Module, Decoder]:
|
|
|
|
|
8 |
if language == 'de':
|
9 |
+
model, decoder, _ = silero_stt(language='de', version="v4", jit_model="jit_large")
|
|
|
|
|
|
|
|
|
|
|
10 |
elif language == 'en':
|
11 |
+
model, decoder, _ = silero_stt(language='en')
|
|
|
|
|
|
|
12 |
else:
|
13 |
raise NotImplementedError("currenty works only for 'de' and 'en' languages, not for '{}'.".format(language))
|
14 |
|
requirements-dev.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
pytest
|
2 |
+
pytest-cov
|
requirements.txt
CHANGED
@@ -7,7 +7,6 @@ flask_cors
|
|
7 |
omegaconf
|
8 |
ortools==9.11.4210
|
9 |
pandas
|
10 |
-
numpy<2.0.0
|
11 |
pickle-mixin
|
12 |
python-dotenv
|
13 |
requests
|
@@ -15,6 +14,6 @@ sentencepiece
|
|
15 |
soundfile==0.12.1
|
16 |
sqlalchemy
|
17 |
structlog
|
18 |
-
torch
|
19 |
-
torchaudio
|
20 |
transformers
|
|
|
7 |
omegaconf
|
8 |
ortools==9.11.4210
|
9 |
pandas
|
|
|
10 |
pickle-mixin
|
11 |
python-dotenv
|
12 |
requests
|
|
|
14 |
soundfile==0.12.1
|
15 |
sqlalchemy
|
16 |
structlog
|
17 |
+
torch
|
18 |
+
torchaudio
|
19 |
transformers
|
tests/events/GetAccuracyFromRecordedAudio.json
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
tests/test_GetAccuracyFromRecordedAudio.py
CHANGED
@@ -40,7 +40,9 @@ class TestGetAccuracyFromRecordedAudio(unittest.TestCase):
|
|
40 |
output["matched_transcripts"] = expected_output["matched_transcripts"]
|
41 |
output["matched_transcripts_ipa"] = expected_output["matched_transcripts_ipa"]
|
42 |
output["pronunciation_accuracy"] = expected_output["pronunciation_accuracy"]
|
|
|
43 |
output["ipa_transcript"] = expected_output["ipa_transcript"]
|
|
|
44 |
output["real_transcripts_ipa"] = expected_output["real_transcripts_ipa"]
|
45 |
self.assertEqual(expected_output, output)
|
46 |
|
|
|
40 |
output["matched_transcripts"] = expected_output["matched_transcripts"]
|
41 |
output["matched_transcripts_ipa"] = expected_output["matched_transcripts_ipa"]
|
42 |
output["pronunciation_accuracy"] = expected_output["pronunciation_accuracy"]
|
43 |
+
output["pair_accuracy_category"] = expected_output["pair_accuracy_category"]
|
44 |
output["ipa_transcript"] = expected_output["ipa_transcript"]
|
45 |
+
output["real_transcript"] = expected_output["real_transcript"]
|
46 |
output["real_transcripts_ipa"] = expected_output["real_transcripts_ipa"]
|
47 |
self.assertEqual(expected_output, output)
|
48 |
|