# Copyright 2021-2023 Xiaomi Corporation

import gradio as gr
import torch
import torchaudio

from ced_model.feature_extraction_ced import CedFeatureExtractor
from ced_model.modeling_ced import CedForAudioClassification

# Load the CED audio classification checkpoint and its matching feature extractor.
model_path = "mispeech/ced-base"
feature_extractor = CedFeatureExtractor.from_pretrained(model_path)
model = CedForAudioClassification.from_pretrained(model_path)


def process(audio_path: str) -> str:
    """Classify an uploaded audio file and return the predicted label."""
    if audio_path is None:
        return "No audio file uploaded."

    audio, sr = torchaudio.load(audio_path)
    if sr != 16000:
        return "The model was trained on 16 kHz audio; please resample your input to 16 kHz mono."

    # Convert the waveform to model input features and run a forward pass.
    inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # Map the highest-scoring class index back to its human-readable label.
    predicted_class_id = torch.argmax(logits, dim=-1).item()
    return model.config.id2label[predicted_class_id]


# A simple Gradio UI: upload an audio file, get the predicted label as text.
iface_audio_file = gr.Interface(
    fn=process,
    inputs=gr.Audio(sources="upload", type="filepath", streaming=False),
    outputs="text",
)

gr.close_all()
iface_audio_file.launch()