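# Hugging Face file-viewer header captured with this source: author Yehor (avatar),
# commit d4bbf90 ("Replace the code"), raw / history / blame links,
# "No virus" scan badge, file size 3.67 kB.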
import os
import time

import gradio as gr
import librosa
import torch
from transformers import AutoModelForCTC, Wav2Vec2BertProcessor
# Hub checkpoint loaded below through AutoModelForCTC (a CTC speech-recognition head).
model_name = "Yehor/w2v-bert-2.0-uk"
# Device the model and the input features are moved to.
device = "cpu"
# Longest accepted input, in seconds (shown in the UI and enforced by `inference`).
max_duration = 30

# Loaded once at import time; every request reuses these module-level objects.
asr_model = AutoModelForCTC.from_pretrained(model_name)
asr_model.to(device)  # nn.Module.to() moves parameters in place
processor = Wav2Vec2BertProcessor.from_pretrained(model_name)
# Example clips offered in the "Choose an example audio" gallery:
# sample_1.wav ... sample_6.wav (relative paths; presumably shipped next to this script).
audio_samples = [f"sample_{index}.wav" for index in range(1, 7)]
# Markdown rendered at the top of the page: title and short model overview.
description_head = "\n".join(
    [
        "# Speech-to-Text for Ukrainian",
        "## Overview",
        "This space uses https://huggingface.co/Yehor/w2v-bert-2.0-uk model that solves",
        "a Speech-to-Text task for the Ukrainian language.",
    ]
)

# Markdown rendered at the bottom of the page: community links.
description_foot = "\n".join(
    [
        "## Community",
        "- Join our Discord server - https://discord.gg/yVAjkBgmt4 - where we're talking about Data Science,",
        "Machine Learning, Deep Learning, and Artificial Intelligence.",
        "- Join our Speech Recognition Group in Telegram: https://t.me/speech_recognition_uk",
    ]
)
def inference(audio_path, progress=gr.Progress()):
    """Transcribe one audio file and return a Markdown report for the UI.

    Args:
        audio_path: Path of the audio file produced by the ``gr.Audio``
            component (``type="filepath"``); ``None`` when "Recognize" is
            clicked without choosing or recording any audio.
        progress: Progress tracker that Gradio injects for this event.

    Returns:
        A Markdown string with the file name, the recognized text, the audio
        duration in seconds and the real-time factor (processing time divided
        by audio duration).

    Raises:
        gr.Error: if no audio was provided, if the audio is longer than
            ``max_duration`` seconds, or if it contains no samples.
    """
    # An empty Audio widget yields None; fail with a readable message
    # instead of a librosa traceback.
    if not audio_path:
        raise gr.Error("Please provide an audio file or record one with the microphone.")

    gr.Info("Starting process", duration=2)
    progress(0, desc="Starting")

    duration = librosa.get_duration(path=audio_path)
    if duration > max_duration:
        # Report the configured limit rather than a hard-coded number.
        raise gr.Error(f"The duration of the file exceeds {max_duration} seconds.")

    # A list so the loop below can process several files in one call.
    paths = [audio_path]

    results = []
    for path in progress.tqdm(paths, desc="Recognizing...", unit="file"):
        t0 = time.perf_counter()

        # Load as mono, resampled to 16 kHz (the rate handed to the processor).
        audio_input, sampling_rate = librosa.load(path, mono=True, sr=16_000)
        audio_duration = len(audio_input) / sampling_rate
        if audio_duration == 0:
            # Nothing to recognize, and the real-time factor would divide by zero.
            raise gr.Error("The audio file is empty.")

        features = processor(
            [audio_input], sampling_rate=sampling_rate, return_tensors="pt"
        ).input_features.to(device)

        with torch.inference_mode():
            logits = asr_model(features).logits

        # Pick the highest-scoring token at each position, then map ids back to text.
        predicted_ids = torch.argmax(logits, dim=-1)
        predictions = processor.batch_decode(predicted_ids)

        elapsed_time = time.perf_counter() - t0
        results.append(
            {
                # basename works for both "/" and OS-native separators.
                "path": os.path.basename(path),
                "transcription": "\n".join(predictions),
                "audio_duration": round(audio_duration, 2),
                "rtf": round(elapsed_time / audio_duration, 4),
            }
        )

    gr.Info("Finished...", duration=2)

    # One Markdown section per file; blank lines put each field in its own paragraph.
    # Double-quoted f-strings keep the single-quoted keys valid before Python 3.12.
    sections = [
        f"**{result['path']}**\n\n"
        f"> {result['transcription']}\n\n"
        f"**Audio duration**: {result['audio_duration']}\n\n"
        f"**Real-Time Factor**: {result['rtf']}"
        for result in results
    ]
    return "\n\n".join(sections)
# Page layout: header text, an input/output row, the Recognize button,
# example clips, and community links at the bottom.
demo = gr.Blocks(
    title="Speech-to-Text for Ukrainian",
    analytics_enabled=False,
)

with demo:
    gr.Markdown(description_head)
    gr.Markdown(f"## Demo (max. duration: **{max_duration}** seconds)")

    with gr.Row():
        # type="filepath" makes Gradio pass a path string (or None) to `inference`.
        audio_file = gr.Audio(label="Audio file", type="filepath")
        transcription = gr.Markdown(
            label="Transcription",
            # Adjacent literals are concatenated, so the first one must end with a
            # space; previously the UI showed "button,upload".
            value="Recognized text will appear here. Use **an example file** below the Recognize button, "
            "upload **your audio file**, or use **the microphone** to record something...",
        )

    gr.Button("Recognize").click(inference, inputs=audio_file, outputs=transcription)

    with gr.Row():
        # Clicking an example only fills the audio input; recognition still
        # runs when the button is pressed.
        gr.Examples(
            label="Choose an example audio", inputs=audio_file, examples=audio_samples
        )

    gr.Markdown(description_foot)


if __name__ == "__main__":
    demo.launch()