# ced-base / app.py
# Last commit by jimbozhang: "Init commit." (5e9bb10)
# Copyright 2021-2023 Xiaomi Corporation
import gradio as gr
import torch
import torchaudio
from ced_model.feature_extraction_ced import CedFeatureExtractor
from ced_model.modeling_ced import CedForAudioClassification
# Identifier passed to `from_pretrained` for both the feature extractor and
# the classifier, so the two always come from the same checkpoint.
model_path = "mispeech/ced-base"
# Both objects are created once at import time and read by `process` on every
# request, so nothing is reloaded per call.
feature_extractor = CedFeatureExtractor.from_pretrained(model_path)
model = CedForAudioClassification.from_pretrained(model_path)
def process(audio_path: str) -> str:
    """Classify an audio file and return the predicted label.

    Args:
        audio_path: Path to the audio file to classify, or ``None`` when no
            file was provided.

    Returns:
        The entry of ``model.config.id2label`` for the highest-scoring class,
        or a human-readable message when no file was given or the file is not
        sampled at 16 kHz.
    """
    if audio_path is None:
        return "No audio file uploaded."

    audio, sr = torchaudio.load(audio_path)
    if sr != 16000:
        return "Models are trained on 16khz, please sample your input to 16khz mono."

    # torchaudio.load returns a (channels, frames) tensor. Downmix
    # multi-channel input to a single channel so exactly one prediction is
    # produced; presumably the extractor treats the leading dimension as a
    # batch (TODO confirm), in which case stereo input would make the
    # `.item()` call below raise on a multi-element tensor.
    if audio.dim() > 1 and audio.size(0) > 1:
        audio = audio.mean(dim=0, keepdim=True)

    inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")
    with torch.no_grad():  # inference only; no gradients needed
        logits = model(**inputs).logits

    predicted_class_id = torch.argmax(logits, dim=-1).item()
    return model.config.id2label[predicted_class_id]
# Upload-only audio input; the handler receives a path on disk, not samples.
_audio_input = gr.Audio(sources="upload", type="filepath", streaming=False)

# Single-function UI: uploaded file in, predicted label out as plain text.
iface_audio_file = gr.Interface(fn=process, inputs=_audio_input, outputs="text")

# Shut down any interfaces still running before starting this one.
gr.close_all()
iface_audio_file.launch()