import torchaudio
import torch
from transformers import (
WhisperProcessor,
AutoProcessor,
AutoModelForSpeechSeq2Seq,
AutoModelForCTC,
Wav2Vec2Processor,
Wav2Vec2ForCTC
)
import numpy as np
import util
# Load the processor and model for every supported checkpoint at import time
models_info = {
"OpenAI-Whisper": {
"processor": WhisperProcessor.from_pretrained("openai/whisper-small", language="uzbek", task="transcribe"),
"model": AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small"),
"ctc_model": False,
"arabic_script": False
},
"Meta-MMS": {
"processor": AutoProcessor.from_pretrained("facebook/mms-1b-all", target_lang='uig-script_arabic'),
"model": AutoModelForCTC.from_pretrained("facebook/mms-1b-all", target_lang='uig-script_arabic', ignore_mismatched_sizes=True),
"ctc_model": True,
"arabic_script": True
},
"Ixxan-FineTuned-Whisper": {
"processor": AutoProcessor.from_pretrained("ixxan/whisper-small-uyghur-common-voice"),
"model": AutoModelForSpeechSeq2Seq.from_pretrained("ixxan/whisper-small-uyghur-common-voice"),
"ctc_model": False,
"arabic_script": False
},
"Ixxan-FineTuned-MMS": {
"processor": Wav2Vec2Processor.from_pretrained("ixxan/wav2vec2-large-mms-1b-uyghur-latin", target_lang='uig-script_latin'),
"model": Wav2Vec2ForCTC.from_pretrained("ixxan/wav2vec2-large-mms-1b-uyghur-latin", target_lang='uig-script_latin'),
"ctc_model": True,
"arabic_script": False
},
}
# def transcribe(audio_data, model_id) -> str:
# if model_id == "Compare All Models":
# return transcribe_all_models(audio_data)
# else:
# return transcribe_with_model(audio_data, model_id)
# def transcribe_all_models(audio_data) -> dict:
# transcriptions = {}
# for model_id in models_info.keys():
# transcriptions[model_id] = transcribe_with_model(audio_data, model_id)
# return transcriptions
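# The commented-out dispatcher above refers to transcribe_with_model, which is
# never defined in this file. A minimal working equivalent, sketched on top of
# the single-model transcribe() defined below (hypothetical helper, not part
# of the original file):
def transcribe_all_models(audio_data) -> dict:
    """Run every registered model on the same clip and collect (Arabic, Latin) pairs."""
    return {model_id: transcribe(audio_data, model_id) for model_id in models_info}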
def transcribe(audio_data, model_id) -> tuple[str, str]:
    """Transcribe audio_data with the selected model; returns (Arabic, Latin) script text."""
    # Load user audio
    if isinstance(audio_data, tuple):
        # Microphone input: Gradio delivers (sample_rate, int16 numpy array).
        sampling_rate, audio_input = audio_data
        # Normalize 16-bit PCM to [-1, 1] and convert to a tensor so the
        # resampler below can consume it.
        audio_input = torch.from_numpy((audio_input / 32768.0).astype(np.float32))
    elif isinstance(audio_data, str):
        # File upload: torchaudio returns a (channels, frames) float tensor.
        audio_input, sampling_rate = torchaudio.load(audio_data)
        if audio_input.shape[0] > 1:
            # Downmix multi-channel recordings to mono.
            audio_input = audio_input.mean(dim=0)
    else:
        return f"<<ERROR: Invalid audio input type: {type(audio_data)}>>", None
    # Optional duration guard (disabled): reject clips longer than 10 seconds.
    # duration = audio_input.shape[-1] / sampling_rate
    # if duration > 10:
    #     return f"<<ERROR: Audio duration ({duration:.2f}s) exceeds 10 seconds. Please upload a shorter audio clip for faster processing.>>", None
model = models_info[model_id]["model"]
processor = models_info[model_id]["processor"]
target_sr = processor.feature_extractor.sampling_rate
ctc_model = models_info[model_id]["ctc_model"]
# Resample if needed
if sampling_rate != target_sr:
resampler = torchaudio.transforms.Resample(sampling_rate, target_sr)
audio_input = resampler(audio_input)
sampling_rate = target_sr
# Preprocess the audio input
inputs = processor(audio_input.squeeze(), sampling_rate=sampling_rate, return_tensors="pt")
# Move model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
inputs = {key: val.to(device) for key, val in inputs.items()}
# Generate transcription
with torch.no_grad():
        if ctc_model:
            # CTC models (MMS/wav2vec2): greedy argmax over frame-level logits.
            logits = model(**inputs).logits
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = processor.batch_decode(predicted_ids)[0]
        else:
            # Seq2seq models (Whisper): autoregressive decoding.
            generated_ids = model.generate(inputs["input_features"], max_length=225)
            transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # Produce both scripts so the caller can show Arabic and Latin side by side.
    if models_info[model_id]["arabic_script"]:
        transcription_arabic = transcription
        transcription_latin = util.ug_arab_to_latn(transcription)
    else:  # Latin-script output
        transcription_arabic = util.ug_latn_to_arab(transcription)
        transcription_latin = transcription
print(model_id, transcription_arabic, transcription_latin)
return transcription_arabic, transcription_latin |
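
if __name__ == "__main__":
    # Hypothetical UI wiring (a sketch; the Space's real entry point may live
    # elsewhere). transcribe() accepts either a (sample_rate, int16 ndarray)
    # tuple, as produced by a Gradio microphone component, or a file path.
    import gradio as gr  # assumption: gradio is installed in the Space

    demo = gr.Interface(
        fn=transcribe,
        inputs=[
            gr.Audio(sources=["microphone", "upload"], label="Audio"),
            gr.Dropdown(choices=list(models_info.keys()), label="Model"),
        ],
        outputs=[
            gr.Textbox(label="Transcription (Arabic script)"),
            gr.Textbox(label="Transcription (Latin script)"),
        ],
    )
    demo.launch()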