khof312 committed
Commit: a84c313
Parent: 30539df

Time execution and fix small bug in STT.

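Each file below gets the same wall-clock timing instrumentation around model loading and inference. A minimal, self-contained sketch of that pattern (the slow_step function is a stand-in for illustration, not part of the commit):

import time

def slow_step():
    # Stand-in for the model loading and inference done in the real modules.
    time.sleep(1)

start_time = time.time()
slow_step()
# Same reporting style as the commit: elapsed seconds, truncated to an int.
print("Time elapsed: ", int(time.time() - start_time), " seconds")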
src/language_id.py CHANGED

@@ -1,4 +1,4 @@
-
+import time
 import librosa
 import torch
 from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
@@ -22,6 +22,8 @@ def identify_language(fp:str) -> str:
     # Ensure replicability
     set_seed(555)
 
+    start_time = time.time()
+
     # Load language ID model
     model_id = "facebook/mms-lid-256" # Need to find the appropriate model for the language -- 256 languages is the first that contains MOS
     processor = AutoFeatureExtractor.from_pretrained(model_id)
@@ -37,5 +39,6 @@ def identify_language(fp:str) -> str:
 
     lang_id = torch.argmax(outputs, dim=-1)[0].item()
     detected_lang = model.config.id2label[lang_id]
-
+
+    print("Time elapsed: ", int(time.time() - start_time), " seconds")
     return detected_lang
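For context, a hedged usage sketch of the function this diff touches; the import path and audio file name are illustrative assumptions, not part of the commit:

# Hypothetical usage of identify_language (module path and sample file assumed).
from src.language_id import identify_language

detected = identify_language("samples/example.wav")  # path to an audio file (assumed)
print(detected)  # e.g. "mos" for Mooré, one of the MMS-LID-256 labels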
src/speech_to_text.py CHANGED

@@ -3,6 +3,7 @@ import librosa
 import torch
 from transformers import Wav2Vec2ForCTC, AutoProcessor
 from transformers import set_seed
+import time
 
 
 def transcribe(fp:str, target_lang:str) -> str:
@@ -23,10 +24,10 @@ def transcribe(fp:str, target_lang:str) -> str:
     '''
     # Ensure replicability
     set_seed(555)
-
+    start_time = time.time()
+
     # Load transcription model
     model_id = "facebook/mms-1b-all"
-    target_lang = "mos"
 
     processor = AutoProcessor.from_pretrained(model_id, target_lang=target_lang)
     model = Wav2Vec2ForCTC.from_pretrained(model_id, target_lang=target_lang, ignore_mismatched_sizes=True)
@@ -42,4 +43,5 @@ def transcribe(fp:str, target_lang:str) -> str:
     ids = torch.argmax(outputs, dim=-1)[0]
     transcript = processor.decode(ids)
 
+    print("Time elapsed: ", int(time.time() - start_time), " seconds")
     return transcript
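The dropped target_lang = "mos" line had been hard-coding the language inside transcribe and overriding the caller's argument; removing it is the "small bug in STT" named in the commit message. A hedged usage sketch (import path and audio path are assumptions):

# Hypothetical usage of transcribe (module path and audio path assumed).
from src.speech_to_text import transcribe

# With the hard-coded override gone, the target_lang argument is honored.
print(transcribe("samples/example.wav", target_lang="mos"))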
src/text_to_speech.py CHANGED

@@ -1,4 +1,4 @@
-
+import time
 import torch
 from transformers import set_seed
 from transformers import VitsTokenizer, VitsModel
@@ -19,6 +19,11 @@ def synthesize_facebook(s:str, iso3:str) -> str:
     synth:str
         The synthesized audio.
     '''
+
+    # Ensure replicability
+    set_seed(555)
+    start_time = time.time()
+
     # Load synthesizer
     tokenizer = VitsTokenizer.from_pretrained(f"facebook/mms-tts-{iso3}")
     model = VitsModel.from_pretrained(f"facebook/mms-tts-{iso3}")
@@ -31,4 +36,5 @@ def synthesize_facebook(s:str, iso3:str) -> str:
 
     synth = outputs.waveform[0]
 
+    print("Time elapsed: ", int(time.time() - start_time), " seconds")
     return synth.numpy()
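A hedged usage sketch of synthesize_facebook (the import path, input text, and output rate are assumptions; MMS-TTS checkpoints typically emit 16 kHz audio):

# Hypothetical usage of synthesize_facebook (module path and text assumed).
import soundfile as sf
from src.text_to_speech import synthesize_facebook

waveform = synthesize_facebook("Hello world", "eng")  # placeholder text and ISO 639-3 code
sf.write("synth.wav", waveform, samplerate=16000)     # assumes the usual 16 kHz MMS-TTS rate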
src/translation.py CHANGED

@@ -2,6 +2,7 @@ import torch
 from transformers import set_seed, pipeline
 from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+import time
 
 ######### HELSINKI NLP ##################
 def translate_helsinki_nlp(s:str, src_iso:str, dest_iso:str)-> str:
@@ -118,6 +119,10 @@ def translate(s, src_iso, dest_iso):
     translation:str
         The translated text, concatenated over different models
     '''
+
+    # Ensure replicability
+    start_time = time.time()
+
     # Translate with Meta NLLB
     translation= "Meta's NLLB translation is:\n\n" + translate_facebook(s, src_iso, dest_iso)
 
@@ -133,5 +138,7 @@ def translate(s, src_iso, dest_iso):
     dest_iso = dest_iso.replace("fra", "fr")
     translation+= "\n\n\nMasakhane's M2M translation is:\n\n" + translate_masakhane(s, src_iso, dest_iso)
 
+    print("Time elapsed: ", int(time.time() - start_time), " seconds")
+
     return translation
 
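A hedged usage sketch of the combined translate wrapper (the import path and language codes are assumptions; the diff suggests ISO 639-3 inputs such as "fra" and "mos"):

# Hypothetical usage of translate (module path and codes assumed).
from src.translation import translate

print(translate("Bonjour, comment allez-vous ?", src_iso="fra", dest_iso="mos"))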