import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Load a Wav2Vec2 CTC model and processor for Persian speech recognition.
# Note: a CTC-fine-tuned checkpoint is required; pretrained-only bases such as
# facebook/wav2vec2-base ship no tokenizer or trained CTC head.
model_name = "m3hrdadfi/wav2vec2-large-xlsr-persian"  # example Persian fine-tune
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
model.eval()

def transcribe_audio(audio):
    if audio is None:
        return "Please upload or record an audio clip first."

    # Load the audio file, downmix to mono, and resample to 16 kHz
    waveform, sample_rate = torchaudio.load(audio)
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)

    # Preprocess the audio into model input features
    input_values = processor(waveform.squeeze().numpy(), return_tensors="pt", sampling_rate=16000).input_values

    # Perform inference without tracking gradients
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy CTC decoding: pick the most likely token at each frame, then collapse
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])

    return transcription

with gr.Blocks(fill_height=True) as demo:
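    # Sidebar: Hugging Face login; the resulting token authorizes the hosted inference tab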
    with gr.Sidebar():
        gr.Markdown("# Inference Provider")
        gr.Markdown("This Space showcases the google/gemma-2-2b-it model, served by the nebius API. Sign in with your Hugging Face account to use this API.")
        button = gr.LoginButton("Sign in")
    
    with gr.Tab("Text Inference"):
        gr.load("models/google/gemma-2-2b-it", accept_token=button, provider="nebius")
    
    with gr.Tab("Persian ASR"):
        audio_input = gr.Audio(label="Upload Audio", type="filepath")
        text_output = gr.Textbox(label="Transcription")
        transcribe_button = gr.Button("Transcribe")
        transcribe_button.click(transcribe_audio, inputs=audio_input, outputs=text_output)

demo.launch()