# Author: hackergeek — commit 5292e3e ("Update app.py"), verified.
import os

import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
# Load the Wav2Vec2 CTC model and processor used by the "Persian ASR" tab.
#
# The checkpoint can be overridden with the ASR_MODEL_NAME environment
# variable; without it the original default below is used.
#
# NOTE(review): "facebook/wav2vec2-base" is a self-supervised pretrained
# checkpoint (English audio, no tokenizer, no fine-tuned CTC head), not a
# Persian speech-recognition model. Loading the processor may fail for it,
# and the CTC head would be randomly initialised, so transcriptions are not
# expected to be meaningful Persian text. Set ASR_MODEL_NAME to a Wav2Vec2
# checkpoint fine-tuned with CTC on Persian — TODO pick and confirm one.
model_name = os.environ.get("ASR_MODEL_NAME", "facebook/wav2vec2-base")
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
def transcribe_audio(audio):
    """Transcribe an audio file with the module-level Wav2Vec2 CTC model.

    Args:
        audio: Path to an audio file (the Gradio ``Audio`` input is created
            with ``type="filepath"``), or ``None``/empty when the user clicks
            "Transcribe" without uploading anything.

    Returns:
        The greedy (argmax) CTC decoding of the audio as a string, or ``""``
        when no audio was provided.
    """
    # Gradio hands us None when nothing was uploaded; torchaudio.load would
    # raise on it, so return an empty transcription instead.
    if not audio:
        return ""

    target_rate = 16000  # the processor is told the audio is 16 kHz

    # torchaudio.load returns (waveform[channels, frames], sample_rate).
    waveform, sample_rate = torchaudio.load(audio)

    # Down-mix to mono. For mono files this equals the old .squeeze(); for
    # stereo files the old code produced shape (2, frames), which the
    # processor treats as a batch of two clips, and only channel 0 was decoded.
    waveform = waveform.mean(dim=0)

    # Resample only when needed (resampling to the same rate is a no-op).
    if sample_rate != target_rate:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=target_rate
        )

    input_values = processor(
        waveform.numpy(), return_tensors="pt", sampling_rate=target_rate
    ).input_values

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy CTC decoding: most likely token per frame, then let the
    # processor collapse repeats/blanks into text.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.decode(predicted_ids[0])
# Gradio UI:
# - a sidebar with a Hugging Face sign-in button,
# - a "Text Inference" tab that loads google/gemma-2-2b-it via the nebius
#   provider,
# - a "Persian ASR" tab that runs transcribe_audio on an uploaded file.
with gr.Blocks(fill_height=True) as demo:
    with gr.Sidebar():
        gr.Markdown("# Inference Provider")
        gr.Markdown("This Space showcases the google/gemma-2-2b-it model, served by the nebius API. Sign in with your Hugging Face account to use this API.")
        button = gr.LoginButton("Sign in")
    with gr.Tab("Text Inference"):
        # The LoginButton is passed as accept_token so the signed-in user's
        # Hugging Face account is used for the provider-backed model.
        gr.load("models/google/gemma-2-2b-it", accept_token=button, provider="nebius")
    with gr.Tab("Persian ASR"):
        # type="filepath" makes the component pass a file path (or None)
        # to transcribe_audio.
        audio_input = gr.Audio(label="Upload Audio", type="filepath")
        text_output = gr.Textbox(label="Transcription")
        transcribe_button = gr.Button("Transcribe")
        transcribe_button.click(transcribe_audio, inputs=audio_input, outputs=text_output)

# Start the server only when executed as a script, so importing this module
# (e.g. from tests or other tooling) does not block on demo.launch().
if __name__ == "__main__":
    demo.launch()