Time execution and fix small bug in STT.
Browse files
- src/language_id.py +5 -2
- src/speech_to_text.py +4 -2
- src/text_to_speech.py +7 -1
- src/translation.py +7 -0
src/language_id.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
|
2 |
import librosa
|
3 |
import torch
|
4 |
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
|
@@ -22,6 +22,8 @@ def identify_language(fp:str) -> str:
|
|
22 |
# Ensure replicability
|
23 |
set_seed(555)
|
24 |
|
|
|
|
|
25 |
# Load language ID model
|
26 |
model_id = "facebook/mms-lid-256" # Need to find the appropriate model for the language -- 256 languages is the first that contains MOS
|
27 |
processor = AutoFeatureExtractor.from_pretrained(model_id)
|
@@ -37,5 +39,6 @@ def identify_language(fp:str) -> str:
|
|
37 |
|
38 |
lang_id = torch.argmax(outputs, dim=-1)[0].item()
|
39 |
detected_lang = model.config.id2label[lang_id]
|
40 |
-
|
|
|
41 |
return detected_lang
|
|
|
1 |
+
import time
|
2 |
import librosa
|
3 |
import torch
|
4 |
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
|
|
|
22 |
# Ensure replicability
|
23 |
set_seed(555)
|
24 |
|
25 |
+
start_time = time.time()
|
26 |
+
|
27 |
# Load language ID model
|
28 |
model_id = "facebook/mms-lid-256" # Need to find the appropriate model for the language -- 256 languages is the first that contains MOS
|
29 |
processor = AutoFeatureExtractor.from_pretrained(model_id)
|
|
|
39 |
|
40 |
lang_id = torch.argmax(outputs, dim=-1)[0].item()
|
41 |
detected_lang = model.config.id2label[lang_id]
|
42 |
+
|
43 |
+
print("Time elapsed: ", int(time.time() - start_time), " seconds")
|
44 |
return detected_lang
|
src/speech_to_text.py
CHANGED
@@ -3,6 +3,7 @@ import librosa
|
|
3 |
import torch
|
4 |
from transformers import Wav2Vec2ForCTC, AutoProcessor
|
5 |
from transformers import set_seed
|
|
|
6 |
|
7 |
|
8 |
def transcribe(fp:str, target_lang:str) -> str:
|
@@ -23,10 +24,10 @@ def transcribe(fp:str, target_lang:str) -> str:
|
|
23 |
'''
|
24 |
# Ensure replicability
|
25 |
set_seed(555)
|
26 |
-
|
|
|
27 |
# Load transcription model
|
28 |
model_id = "facebook/mms-1b-all"
|
29 |
-
target_lang = "mos"
|
30 |
|
31 |
processor = AutoProcessor.from_pretrained(model_id, target_lang=target_lang)
|
32 |
model = Wav2Vec2ForCTC.from_pretrained(model_id, target_lang=target_lang, ignore_mismatched_sizes=True)
|
@@ -42,4 +43,5 @@ def transcribe(fp:str, target_lang:str) -> str:
|
|
42 |
ids = torch.argmax(outputs, dim=-1)[0]
|
43 |
transcript = processor.decode(ids)
|
44 |
|
|
|
45 |
return transcript
|
|
|
3 |
import torch
|
4 |
from transformers import Wav2Vec2ForCTC, AutoProcessor
|
5 |
from transformers import set_seed
|
6 |
+
import time
|
7 |
|
8 |
|
9 |
def transcribe(fp:str, target_lang:str) -> str:
|
|
|
24 |
'''
|
25 |
# Ensure replicability
|
26 |
set_seed(555)
|
27 |
+
start_time = time.time()
|
28 |
+
|
29 |
# Load transcription model
|
30 |
model_id = "facebook/mms-1b-all"
|
|
|
31 |
|
32 |
processor = AutoProcessor.from_pretrained(model_id, target_lang=target_lang)
|
33 |
model = Wav2Vec2ForCTC.from_pretrained(model_id, target_lang=target_lang, ignore_mismatched_sizes=True)
|
|
|
43 |
ids = torch.argmax(outputs, dim=-1)[0]
|
44 |
transcript = processor.decode(ids)
|
45 |
|
46 |
+
print("Time elapsed: ", int(time.time() - start_time), " seconds")
|
47 |
return transcript
|
src/text_to_speech.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
|
2 |
import torch
|
3 |
from transformers import set_seed
|
4 |
from transformers import VitsTokenizer, VitsModel
|
@@ -19,6 +19,11 @@ def synthesize_facebook(s:str, iso3:str) -> str:
|
|
19 |
synth:str
|
20 |
The synthesized audio.
|
21 |
'''
|
|
|
|
|
|
|
|
|
|
|
22 |
# Load synthesizer
|
23 |
tokenizer = VitsTokenizer.from_pretrained(f"facebook/mms-tts-{iso3}")
|
24 |
model = VitsModel.from_pretrained(f"facebook/mms-tts-{iso3}")
|
@@ -31,4 +36,5 @@ def synthesize_facebook(s:str, iso3:str) -> str:
|
|
31 |
|
32 |
synth = outputs.waveform[0]
|
33 |
|
|
|
34 |
return synth.numpy()
|
|
|
1 |
+
import time
|
2 |
import torch
|
3 |
from transformers import set_seed
|
4 |
from transformers import VitsTokenizer, VitsModel
|
|
|
19 |
synth:str
|
20 |
The synthesized audio.
|
21 |
'''
|
22 |
+
|
23 |
+
# Ensure replicability
|
24 |
+
set_seed(555)
|
25 |
+
start_time = time.time()
|
26 |
+
|
27 |
# Load synthesizer
|
28 |
tokenizer = VitsTokenizer.from_pretrained(f"facebook/mms-tts-{iso3}")
|
29 |
model = VitsModel.from_pretrained(f"facebook/mms-tts-{iso3}")
|
|
|
36 |
|
37 |
synth = outputs.waveform[0]
|
38 |
|
39 |
+
print("Time elapsed: ", int(time.time() - start_time), " seconds")
|
40 |
return synth.numpy()
|
src/translation.py
CHANGED
@@ -2,6 +2,7 @@ import torch
|
|
2 |
from transformers import set_seed, pipeline
|
3 |
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
|
4 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
|
5 |
|
6 |
######### HELSINKI NLP ##################
|
7 |
def translate_helsinki_nlp(s:str, src_iso:str, dest_iso:str)-> str:
|
@@ -118,6 +119,10 @@ def translate(s, src_iso, dest_iso):
|
|
118 |
translation:str
|
119 |
The translated text, concatenated over different models
|
120 |
'''
|
|
|
|
|
|
|
|
|
121 |
# Translate with Meta NLLB
|
122 |
translation= "Meta's NLLB translation is:\n\n" + translate_facebook(s, src_iso, dest_iso)
|
123 |
|
@@ -133,5 +138,7 @@ def translate(s, src_iso, dest_iso):
|
|
133 |
dest_iso = dest_iso.replace("fra", "fr")
|
134 |
translation+= "\n\n\nMasakhane's M2M translation is:\n\n" + translate_masakhane(s, src_iso, dest_iso)
|
135 |
|
|
|
|
|
136 |
return translation
|
137 |
|
|
|
2 |
from transformers import set_seed, pipeline
|
3 |
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
|
4 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
5 |
+
import time
|
6 |
|
7 |
######### HELSINKI NLP ##################
|
8 |
def translate_helsinki_nlp(s:str, src_iso:str, dest_iso:str)-> str:
|
|
|
119 |
translation:str
|
120 |
The translated text, concatenated over different models
|
121 |
'''
|
122 |
+
|
123 |
+
# Start timer to measure execution time
|
124 |
+
start_time = time.time()
|
125 |
+
|
126 |
# Translate with Meta NLLB
|
127 |
translation= "Meta's NLLB translation is:\n\n" + translate_facebook(s, src_iso, dest_iso)
|
128 |
|
|
|
138 |
dest_iso = dest_iso.replace("fra", "fr")
|
139 |
translation+= "\n\n\nMasakhane's M2M translation is:\n\n" + translate_masakhane(s, src_iso, dest_iso)
|
140 |
|
141 |
+
print("Time elapsed: ", int(time.time() - start_time), " seconds")
|
142 |
+
|
143 |
return translation
|
144 |
|