k1ngtai commited on
Commit
723cd11
·
1 Parent(s): 5361747

Update asr.py

Browse files
Files changed (1) hide show
  1. asr.py +103 -14
asr.py CHANGED
@@ -1,9 +1,18 @@
1
  import librosa
2
  from transformers import Wav2Vec2ForCTC, AutoProcessor
3
  import torch
 
 
 
 
4
 
5
  ASR_SAMPLING_RATE = 16_000
6
 
 
 
 
 
 
7
 
8
  MODEL_ID = "facebook/mms-1b-all"
9
 
@@ -11,21 +20,68 @@ processor = AutoProcessor.from_pretrained(MODEL_ID)
11
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
12
 
13
 
14
- def transcribe(microphone, file_upload, lang):
 
 
 
 
 
15
 
16
- warn_output = ""
17
- if (microphone is not None) and (file_upload is not None):
18
- warn_output = (
19
- "WARNING: You've uploaded an audio file and used the microphone. "
20
- "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
21
- )
22
- elif (microphone is None) and (file_upload is None):
23
- return "ERROR: You have to either use the microphone or upload an audio file"
24
 
25
- audio_fp = microphone if microphone is not None else file_upload
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  audio_samples = librosa.load(audio_fp, sr=ASR_SAMPLING_RATE, mono=True)[0]
27
 
28
- lang_code = lang.split(":")[0]
29
  processor.tokenizer.set_target_lang(lang_code)
30
  model.load_adapter(lang_code)
31
 
@@ -33,9 +89,42 @@ def transcribe(microphone, file_upload, lang):
33
  audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
34
  )
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  with torch.no_grad():
37
  outputs = model(**inputs).logits
38
 
39
- ids = torch.argmax(outputs, dim=-1)[0]
40
- transcription = processor.decode(ids)
41
- return warn_output + transcription
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import librosa
2
  from transformers import Wav2Vec2ForCTC, AutoProcessor
3
  import torch
4
+ import json
5
+
6
+ from huggingface_hub import hf_hub_download
7
+ from torchaudio.models.decoder import ctc_decoder
8
 
9
  ASR_SAMPLING_RATE = 16_000
10
 
11
+ ASR_LANGUAGES = {}
12
+ with open(f"data/asr/all_langs.tsv") as f:
13
+ for line in f:
14
+ iso, name = line.split(" ", 1)
15
+ ASR_LANGUAGES[iso] = name
16
 
17
  MODEL_ID = "facebook/mms-1b-all"
18
 
 
20
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
21
 
22
 
23
+ # lm_decoding_config = {}
24
+ # lm_decoding_configfile = hf_hub_download(
25
+ # repo_id="facebook/mms-cclms",
26
+ # filename="decoding_config.json",
27
+ # subfolder="mms-1b-all",
28
+ # )
29
 
30
+ # with open(lm_decoding_configfile) as f:
31
+ # lm_decoding_config = json.loads(f.read())
32
+
33
+ # # allow language model decoding for "eng"
34
+
35
+ # decoding_config = lm_decoding_config["eng"]
 
 
36
 
37
+ # lm_file = hf_hub_download(
38
+ # repo_id="facebook/mms-cclms",
39
+ # filename=decoding_config["lmfile"].rsplit("/", 1)[1],
40
+ # subfolder=decoding_config["lmfile"].rsplit("/", 1)[0],
41
+ # )
42
+ # token_file = hf_hub_download(
43
+ # repo_id="facebook/mms-cclms",
44
+ # filename=decoding_config["tokensfile"].rsplit("/", 1)[1],
45
+ # subfolder=decoding_config["tokensfile"].rsplit("/", 1)[0],
46
+ # )
47
+ # lexicon_file = None
48
+ # if decoding_config["lexiconfile"] is not None:
49
+ # lexicon_file = hf_hub_download(
50
+ # repo_id="facebook/mms-cclms",
51
+ # filename=decoding_config["lexiconfile"].rsplit("/", 1)[1],
52
+ # subfolder=decoding_config["lexiconfile"].rsplit("/", 1)[0],
53
+ # )
54
+
55
+ # beam_search_decoder = ctc_decoder(
56
+ # lexicon=lexicon_file,
57
+ # tokens=token_file,
58
+ # lm=lm_file,
59
+ # nbest=1,
60
+ # beam_size=500,
61
+ # beam_size_token=50,
62
+ # lm_weight=float(decoding_config["lmweight"]),
63
+ # word_score=float(decoding_config["wordscore"]),
64
+ # sil_score=float(decoding_config["silweight"]),
65
+ # blank_token="<s>",
66
+ # )
67
+
68
+
69
+ def transcribe(
70
+ audio_source=None, microphone=None, file_upload=None, lang="eng (English)"
71
+ ):
72
+ if type(microphone) is dict:
73
+ # HACK: microphone variable is a dict when running on examples
74
+ microphone = microphone["name"]
75
+ audio_fp = (
76
+ file_upload if "upload" in str(audio_source or "").lower() else microphone
77
+ )
78
+
79
+ if audio_fp is None:
80
+ return "ERROR: You have to either use the microphone or upload an audio file"
81
+
82
  audio_samples = librosa.load(audio_fp, sr=ASR_SAMPLING_RATE, mono=True)[0]
83
 
84
+ lang_code = lang.split()[0]
85
  processor.tokenizer.set_target_lang(lang_code)
86
  model.load_adapter(lang_code)
87
 
 
89
  audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
90
  )
91
 
92
+ # set device
93
+ if torch.cuda.is_available():
94
+ device = torch.device("cuda")
95
+ elif (
96
+ hasattr(torch.backends, "mps")
97
+ and torch.backends.mps.is_available()
98
+ and torch.backends.mps.is_built()
99
+ ):
100
+ device = torch.device("mps")
101
+ else:
102
+ device = torch.device("cpu")
103
+
104
+ model.to(device)
105
+ inputs = inputs.to(device)
106
+
107
  with torch.no_grad():
108
  outputs = model(**inputs).logits
109
 
110
+ if lang_code != "eng" or True:
111
+ ids = torch.argmax(outputs, dim=-1)[0]
112
+ transcription = processor.decode(ids)
113
+ else:
114
+ assert False
115
+ # beam_search_result = beam_search_decoder(outputs.to("cpu"))
116
+ # transcription = " ".join(beam_search_result[0][0].words).strip()
117
+
118
+ return transcription
119
+
120
+
121
+ ASR_EXAMPLES = [
122
+ [None, "assets/english.mp3", None, "eng (English)"],
123
+ # [None, "assets/tamil.mp3", None, "tam (Tamil)"],
124
+ # [None, "assets/burmese.mp3", None, "mya (Burmese)"],
125
+ ]
126
+
127
+ ASR_NOTE = """
128
+ The above demo doesn't use beam-search decoding using a language model.
129
+ Checkout the instructions [here](https://huggingface.co/facebook/mms-1b-all) on how to run LM decoding for better accuracy.
130
+ """