|
|
|
|
|
import gradio as gr |
|
import torch |
|
import torchaudio |
|
|
|
from ced_model.feature_extraction_ced import CedFeatureExtractor |
|
from ced_model.modeling_ced import CedForAudioClassification |
|
|
|
# HuggingFace hub ID of the pretrained CED (audio tagging) checkpoint.
model_path = "mispeech/ced-base"

# Loaded once at import time so every request reuses the same weights.
# NOTE(review): both calls hit the network/disk cache on first run.
feature_extractor = CedFeatureExtractor.from_pretrained(model_path)
model = CedForAudioClassification.from_pretrained(model_path)
|
|
|
|
|
def process(audio_path: str) -> str:
    """Classify an uploaded audio file and return the predicted label.

    Args:
        audio_path: Filesystem path handed over by the Gradio ``Audio``
            component (``type="filepath"``), or ``None`` when nothing was
            uploaded.

    Returns:
        The human-readable label of the top-scoring class, or an error
        message string when no file was provided.
    """
    if audio_path is None:
        return "No audio file uploaded."

    # Fixed: dropped the useless `global model` / `global label_maps`
    # declarations — the function never assigns either name, and
    # `label_maps` is not defined anywhere in this module.
    audio, sr = torchaudio.load(audio_path)

    # The CED checkpoints are trained on 16 kHz mono audio. Instead of
    # rejecting other inputs, adapt them: down-mix multi-channel to mono
    # and resample to the expected rate.
    if audio.shape[0] > 1:
        audio = audio.mean(dim=0, keepdim=True)

    target_sr = 16000
    if sr != target_sr:
        audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=target_sr)
        sr = target_sr

    inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")

    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_ids = torch.argmax(logits, dim=-1).item()
    predicted_label = model.config.id2label[predicted_class_ids]

    return predicted_label
|
|
|
|
|
# Build the web UI: a single file-upload audio input mapped through
# `process` to a plain-text label output.
audio_input = gr.Audio(sources="upload", type="filepath", streaming=False)

iface_audio_file = gr.Interface(fn=process, inputs=audio_input, outputs="text")

# Shut down any interfaces left over from a previous run, then serve.
gr.close_all()
iface_audio_file.launch()
|
|