import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Load a Wav2Vec2 CTC checkpoint fine-tuned for Persian (any equivalent
# Persian ASR checkpoint from the Hub can be substituted). The plain
# "facebook/wav2vec2-base" checkpoint is pretrained-only, with no trained
# CTC head, so it cannot produce usable transcriptions.
model_name = "jonatasgrosman/wav2vec2-large-xlsr-53-persian"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
model.eval()

def transcribe_audio(audio):
    # Load the audio file; torchaudio returns a (channels, samples) tensor.
    waveform, sample_rate = torchaudio.load(audio)

    # Downmix to mono so stereo uploads don't break the feature extractor
    # (squeeze() alone leaves a stereo tensor 2-D).
    waveform = waveform.mean(dim=0)

    # Resample to the 16 kHz rate Wav2Vec2 expects, if necessary.
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)

    # Preprocess the audio into model input features.
    input_values = processor(
        waveform.numpy(), return_tensors="pt", sampling_rate=16000
    ).input_values

    # Perform inference without tracking gradients.
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy CTC decoding: pick the most likely token per frame, then let the
    # processor collapse repeats and blanks into text.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    return transcription

with gr.Blocks(fill_height=True) as demo:
    with gr.Sidebar():
        gr.Markdown("# Inference Provider")
        gr.Markdown(
            "This Space showcases the google/gemma-2-2b-it model, served by "
            "the Nebius API. Sign in with your Hugging Face account to use it."
        )
        button = gr.LoginButton("Sign in")
    with gr.Tab("Text Inference"):
        gr.load("models/google/gemma-2-2b-it", accept_token=button, provider="nebius")
    with gr.Tab("Persian ASR"):
        audio_input = gr.Audio(label="Upload Audio", type="filepath")
        text_output = gr.Textbox(label="Transcription")
        transcribe_button = gr.Button("Transcribe")
        transcribe_button.click(transcribe_audio, inputs=audio_input, outputs=text_output)

demo.launch()
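
# Quick smoke test for the ASR path (a sketch; "sample.wav" is a hypothetical
# local recording, not shipped with this Space). Run it in place of
# demo.launch() above, since launch() blocks until the server stops:
#
#     print(transcribe_audio("sample.wav"))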