vineelpratap commited on
Commit
69b07b9
1 Parent(s): 6003020

Rename asr_lm.py to asr_lm_eng.py

Browse files
Files changed (2) hide show
  1. asr_lm.py +0 -0
  2. asr_lm_eng.py +136 -0
asr_lm.py DELETED
File without changes
asr_lm_eng.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ from transformers import Wav2Vec2ForCTC, AutoProcessor
3
+ import torch
4
+ import numpy as np
5
+ from pathlib import Path
6
+
7
+ from huggingface_hub import hf_hub_download
8
+ from torchaudio.models.decoder import ctc_decoder
9
+
10
+ ASR_SAMPLING_RATE = 16_000
11
+
12
+ ASR_LANGUAGES = {}
13
+ with open(f"data/asr/all_langs.tsv") as f:
14
+ for line in f:
15
+ iso, name = line.split(" ", 1)
16
+ ASR_LANGUAGES[iso.strip()] = name.strip()
17
+
18
+ MODEL_ID = "facebook/mms-1b-all"
19
+
20
+ processor = AutoProcessor.from_pretrained(MODEL_ID)
21
+ model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
22
+
23
+
24
+ # lm_decoding_config = {}
25
+ # lm_decoding_configfile = hf_hub_download(
26
+ # repo_id="facebook/mms-cclms",
27
+ # filename="decoding_config.json",
28
+ # subfolder="mms-1b-all",
29
+ # )
30
+
31
+ # with open(lm_decoding_configfile) as f:
32
+ # lm_decoding_config = json.loads(f.read())
33
+
34
+ # # allow language model decoding for "eng"
35
+
36
+ # decoding_config = lm_decoding_config["eng"]
37
+
38
+ # lm_file = hf_hub_download(
39
+ # repo_id="facebook/mms-cclms",
40
+ # filename=decoding_config["lmfile"].rsplit("/", 1)[1],
41
+ # subfolder=decoding_config["lmfile"].rsplit("/", 1)[0],
42
+ # )
43
+ # token_file = hf_hub_download(
44
+ # repo_id="facebook/mms-cclms",
45
+ # filename=decoding_config["tokensfile"].rsplit("/", 1)[1],
46
+ # subfolder=decoding_config["tokensfile"].rsplit("/", 1)[0],
47
+ # )
48
+ # lexicon_file = None
49
+ # if decoding_config["lexiconfile"] is not None:
50
+ # lexicon_file = hf_hub_download(
51
+ # repo_id="facebook/mms-cclms",
52
+ # filename=decoding_config["lexiconfile"].rsplit("/", 1)[1],
53
+ # subfolder=decoding_config["lexiconfile"].rsplit("/", 1)[0],
54
+ # )
55
+
56
+ # beam_search_decoder = ctc_decoder(
57
+ # lexicon=lexicon_file,
58
+ # tokens=token_file,
59
+ # lm=lm_file,
60
+ # nbest=1,
61
+ # beam_size=500,
62
+ # beam_size_token=50,
63
+ # lm_weight=float(decoding_config["lmweight"]),
64
+ # word_score=float(decoding_config["wordscore"]),
65
+ # sil_score=float(decoding_config["silweight"]),
66
+ # blank_token="<s>",
67
+ # )
68
+
69
+
70
+ def transcribe(audio_data=None, lang="eng (English)"):
71
+
72
+ if not audio_data:
73
+ return "<<ERROR: Empty Audio Input>>"
74
+
75
+ if isinstance(audio_data, tuple):
76
+ # microphone
77
+ sr, audio_samples = audio_data
78
+ audio_samples = (audio_samples / 32768.0).astype(np.float32)
79
+ if sr != ASR_SAMPLING_RATE:
80
+ audio_samples = librosa.resample(
81
+ audio_samples, orig_sr=sr, target_sr=ASR_SAMPLING_RATE
82
+ )
83
+ else:
84
+ # file upload
85
+
86
+ if not isinstance(audio_data, str):
87
+ return "<<ERROR: Invalid Audio Input Instance: {}>>".format(type(audio_data))
88
+ audio_samples = librosa.load(audio_data, sr=ASR_SAMPLING_RATE, mono=True)[0]
89
+
90
+ lang_code = lang.split()[0]
91
+ processor.tokenizer.set_target_lang(lang_code)
92
+ model.load_adapter(lang_code)
93
+
94
+ inputs = processor(
95
+ audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
96
+ )
97
+
98
+ # set device
99
+ if torch.cuda.is_available():
100
+ device = torch.device("cuda")
101
+ elif (
102
+ hasattr(torch.backends, "mps")
103
+ and torch.backends.mps.is_available()
104
+ and torch.backends.mps.is_built()
105
+ ):
106
+ device = torch.device("mps")
107
+ else:
108
+ device = torch.device("cpu")
109
+
110
+ model.to(device)
111
+ inputs = inputs.to(device)
112
+
113
+ with torch.no_grad():
114
+ outputs = model(**inputs).logits
115
+
116
+ if lang_code != "eng" or True:
117
+ ids = torch.argmax(outputs, dim=-1)[0]
118
+ transcription = processor.decode(ids)
119
+ else:
120
+ assert False
121
+ # beam_search_result = beam_search_decoder(outputs.to("cpu"))
122
+ # transcription = " ".join(beam_search_result[0][0].words).strip()
123
+
124
+ return transcription
125
+
126
+
127
+ ASR_EXAMPLES = [
128
+ ["upload/english.mp3", "eng (English)"],
129
+ # ["upload/tamil.mp3", "tam (Tamil)"],
130
+ # ["upload/burmese.mp3", "mya (Burmese)"],
131
+ ]
132
+
133
+ ASR_NOTE = """
134
+ The above demo doesn't use beam-search decoding using a language model.
135
+ Checkout the instructions [here](https://huggingface.co/facebook/mms-1b-all) on how to run LM decoding for better accuracy.
136
+ """