File size: 3,786 Bytes
7bcf8d7 5442f52 6f53272 5442f52 4090e0d 7bcf8d7 d4e3aaf 7bcf8d7 1b7f5da 7bcf8d7 6f53272 7bcf8d7 cd75560 7bcf8d7 a45002a 7bcf8d7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import librosa
from transformers import Wav2Vec2ForCTC, AutoProcessor
import torch
import json
from huggingface_hub import hf_hub_download
from torchaudio.models.decoder import ctc_decoder
# Sampling rate (Hz) that transcribe() loads files at and requires of
# microphone input.
ASR_SAMPLING_RATE = 16_000

# Maps the first field of each line in the language list (the language code)
# to the remainder of that line (the display name).
# NOTE(review): the file is named .tsv but fields are split on a single
# space -- confirm against the data file.
ASR_LANGUAGES = {}
with open("data/asr/all_langs.tsv") as f:
    for line in f:
        # Iterating a file yields lines with their terminator attached;
        # strip it so names do not end in "\n".
        line = line.rstrip("\r\n")
        if not line:
            # Tolerate blank lines instead of failing the tuple unpack below.
            continue
        iso, name = line.split(" ", 1)
        ASR_LANGUAGES[iso] = name
# Identifier of the pretrained checkpoint passed to from_pretrained() below.
MODEL_ID = "facebook/mms-1b-all"
# Processor and model are loaded once at import time and reused by every
# transcribe() call, which switches the target language on both per request.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
# lm_decoding_config = {}
# lm_decoding_configfile = hf_hub_download(
# repo_id="facebook/mms-cclms",
# filename="decoding_config.json",
# subfolder="mms-1b-all",
# )
# with open(lm_decoding_configfile) as f:
# lm_decoding_config = json.loads(f.read())
# # allow language model decoding for "eng"
# decoding_config = lm_decoding_config["eng"]
# lm_file = hf_hub_download(
# repo_id="facebook/mms-cclms",
# filename=decoding_config["lmfile"].rsplit("/", 1)[1],
# subfolder=decoding_config["lmfile"].rsplit("/", 1)[0],
# )
# token_file = hf_hub_download(
# repo_id="facebook/mms-cclms",
# filename=decoding_config["tokensfile"].rsplit("/", 1)[1],
# subfolder=decoding_config["tokensfile"].rsplit("/", 1)[0],
# )
# lexicon_file = None
# if decoding_config["lexiconfile"] is not None:
# lexicon_file = hf_hub_download(
# repo_id="facebook/mms-cclms",
# filename=decoding_config["lexiconfile"].rsplit("/", 1)[1],
# subfolder=decoding_config["lexiconfile"].rsplit("/", 1)[0],
# )
# beam_search_decoder = ctc_decoder(
# lexicon=lexicon_file,
# tokens=token_file,
# lm=lm_file,
# nbest=1,
# beam_size=500,
# beam_size_token=50,
# lm_weight=float(decoding_config["lmweight"]),
# word_score=float(decoding_config["wordscore"]),
# sil_score=float(decoding_config["silweight"]),
# blank_token="<s>",
# )
def _select_device():
    """Return the torch device for inference: CUDA first, then MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Older torch builds have no `torch.backends.mps` attribute at all.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available() and mps.is_built():
        return torch.device("mps")
    return torch.device("cpu")


def transcribe(audio_data, lang="eng (English)"):
    """Transcribe speech audio in the requested language.

    Args:
        audio_data: Either a ``(sampling_rate, samples)`` tuple (microphone
            input) or a value accepted by ``librosa.load``, typically a file
            path (file upload).
        lang: Language label whose first whitespace-separated token is the
            language code, e.g. ``"eng (English)"``.

    Returns:
        The text decoded by the processor from the per-frame arg-max of the
        model logits (first item of the batch).

    Raises:
        ValueError: If ``audio_data`` is ``None``.
        AssertionError: If tuple input is not sampled at ``ASR_SAMPLING_RATE``.
    """
    if audio_data is None:
        raise ValueError("No audio input provided")

    if isinstance(audio_data, tuple):
        # Microphone input: samples arrive already decoded.
        sr, audio_samples = audio_data
        if sr != ASR_SAMPLING_RATE:
            # Raised explicitly (not via `assert`) so the check survives
            # `python -O`; the exception type is kept for existing callers.
            raise AssertionError("Invalid sampling rate")
    else:
        # File upload: decode, downmix to mono and resample in one step.
        audio_samples = librosa.load(
            audio_data, sr=ASR_SAMPLING_RATE, mono=True
        )[0]

    # Point both the tokenizer and the model adapter at the target language.
    lang_code = lang.split()[0]
    processor.tokenizer.set_target_lang(lang_code)
    model.load_adapter(lang_code)

    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    )

    device = _select_device()
    model.to(device)
    inputs = inputs.to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Greedy decoding for every language. Language-model beam search is
    # currently disabled in this module; the disabled path was:
    #   beam_search_result = beam_search_decoder(outputs.to("cpu"))
    #   transcription = " ".join(beam_search_result[0][0].words).strip()
    token_ids = torch.argmax(logits, dim=-1)[0]
    return processor.decode(token_ids)
# Example inputs: each row is [audio file path, language label], in the same
# order as transcribe()'s (audio_data, lang) parameters.
# Samples currently disabled:
#   ["assets/tamil.mp3", "tam (Tamil)"],
#   ["assets/burmese.mp3", "mya (Burmese)"],
ASR_EXAMPLES = [["assets/english.mp3", "eng (English)"]]
# Markdown note about the demo's decoding strategy; the text starts and ends
# with a newline.
ASR_NOTE = (
    "\n"
    "The above demo doesn't use beam-search decoding using a language model.\n"
    "Checkout the instructions "
    "[here](https://huggingface.co/facebook/mms-1b-all) "
    "on how to run LM decoding for better accuracy.\n"
)
|