# Spaces:
# Running
# Running
import os

import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
# Speech-recognition checkpoint used by the "Persian ASR" tab.
# It must be a wav2vec2 model that has been *fine-tuned with a CTC head on Persian*.
# The previous default, "facebook/wav2vec2-base", is a pretrained-only English
# checkpoint that ships without a tokenizer/vocabulary, so it cannot produce Persian
# transcriptions (and the processor cannot be built from it).
# Set the ASR_MODEL_NAME environment variable to use a different CTC checkpoint,
# e.g. a smaller one if memory is tight.
model_name = os.environ.get(
    "ASR_MODEL_NAME", "jonatasgrosman/wav2vec2-large-xlsr-53-persian"
)
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
def transcribe_audio(audio):
    """Transcribe an audio file to text with the module-level wav2vec2 CTC model.

    Args:
        audio: Path to an audio file, as supplied by ``gr.Audio(type="filepath")``.
            It is ``None`` when the user clicks "Transcribe" without uploading.

    Returns:
        The transcription string, using greedy (argmax) CTC decoding.

    Raises:
        gr.Error: If no audio was provided. Gradio then shows a readable message
            instead of a traceback from ``torchaudio.load``.
    """
    if audio is None:
        raise gr.Error("Please upload an audio file first.")

    # torchaudio.load returns a (channels, frames) tensor and the file's sample rate.
    waveform, sample_rate = torchaudio.load(audio)

    # Down-mix to mono. A stereo file would otherwise stay 2-D, and the processor
    # would treat each channel as a separate batch item.
    waveform = waveform.mean(dim=0)

    # Resample to whatever rate the model's feature extractor expects.
    target_rate = processor.feature_extractor.sampling_rate
    if sample_rate != target_rate:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=target_rate
        )

    input_values = processor(
        waveform.numpy(), sampling_rate=target_rate, return_tensors="pt"
    ).input_values

    # Inference only: no autograd bookkeeping needed.
    with torch.inference_mode():
        logits = model(input_values).logits

    # Greedy decoding: pick the highest-scoring token for each frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.decode(predicted_ids[0])
# Two-tab Gradio app:
# - a hosted text-generation tab (Gemma via the nebius inference provider);
# - a local speech-to-text tab backed by transcribe_audio.
with gr.Blocks(fill_height=True) as demo:
    # Sidebar: describes the hosted model and offers Hugging Face sign-in.
    # The login button is passed to gr.load below through accept_token.
    with gr.Sidebar():
        gr.Markdown("# Inference Provider")
        gr.Markdown(
            "This Space showcases the google/gemma-2-2b-it model, served by the nebius API. "
            "Sign in with your Hugging Face account to use this API."
        )
        button = gr.LoginButton("Sign in")

    # Tab 1: UI generated by gr.load for the hosted model.
    with gr.Tab("Text Inference"):
        gr.load(
            "models/google/gemma-2-2b-it",
            accept_token=button,
            provider="nebius",
        )

    # Tab 2: upload a file; clicking the button sends its path to transcribe_audio.
    with gr.Tab("Persian ASR"):
        audio_input = gr.Audio(label="Upload Audio", type="filepath")
        text_output = gr.Textbox(label="Transcription")
        transcribe_button = gr.Button("Transcribe")
        transcribe_button.click(
            fn=transcribe_audio,
            inputs=audio_input,
            outputs=text_output,
        )

demo.launch()