# Copyright (c) microsoft # 2022 Chengdong Liang (liangchengdong@mail.nwpu.edu.cn) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import gradio as gr import torchaudio import torchaudio.compliance.kaldi as kaldi import torch import onnxruntime as ort from sklearn.metrics.pairwise import cosine_similarity STYLE = """ """ OUTPUT_OK = (STYLE + """

The speakers are

{:.1f}%

similar

Welcome, human!

(You must get at least 50% to be considered the same person)
""") OUTPUT_FAIL = (STYLE + """

The speakers are

{:.1f}%

similar

Warning! stranger!

(You must get at least 50% to be considered the same person)
""") OUTPUT_ERROR = (STYLE + """

Input Error

{}!

""") def compute_fbank(wav_path, num_bel_bins=80, frame_length=25, frame_shift=10, dither=0.0, resample_rate=16000): """ Extract fbank, simlilar to the one in wespeaker.dataset.processor, While integrating the wave reading and CMN. """ waveform, sample_rate = torchaudio.load(wav_path) # resample if sample_rate != resample_rate: waveform = torchaudio.transforms.Resample( orig_freq=sample_rate, new_freq=resample_rate)(waveform) waveform = waveform * (1 << 15) mat = kaldi.fbank(waveform, num_mel_bins=num_bel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither, sample_frequency=sample_rate, window_type='hamming', use_energy=False) # CMN, without CVN mat = mat - torch.mean(mat, dim=0) return mat class OnnxModel(object): def __init__(self, model_path): so = ort.SessionOptions() so.inter_op_num_threads = 1 so.intra_op_num_threads = 1 self.session = ort.InferenceSession(model_path, sess_options=so) def extract_embedding(self, wav_path): feats = compute_fbank(wav_path) feats = feats.unsqueeze(0).numpy() embeddings = self.session.run(output_names=['embs'], input_feed={'feats': feats}) return embeddings[0] def speaker_verification(audio_path1, audio_path2, lang='CN'): if audio_path1 == None or audio_path2 == None: output = OUTPUT_ERROR.format('Please enter two audios') return output if lang == 'EN': model = OnnxModel('pre_model/voxceleb_resnet34.onnx') elif lang == 'CN': model = OnnxModel('pre_model/cnceleb_resnet34.onnx') else: output = OUTPUT_ERROR.format('Please select a language') return output emb1 = model.extract_embedding(audio_path1) emb2 = model.extract_embedding(audio_path2) cos_score = cosine_similarity(emb1.reshape(1, -1), emb2.reshape(1, -1))[0][0] cos_score = (cos_score + 1) / 2.0 if cos_score >= 0.5: output = OUTPUT_OK.format(cos_score * 100) else: output = OUTPUT_FAIL.format(cos_score * 100) return output # input inputs = [ gr.inputs.Audio(source="microphone", type="filepath", optional=True, label='Speaker#1'), 
# Second recording slot, followed by the language selector.
    gr.inputs.Audio(source="microphone",
                    type="filepath",
                    optional=True,
                    label='Speaker#2'),
    gr.Radio(['EN', 'CN'], label='Language'),
]

# The function's return value is rendered through an HTML output component.
output = gr.outputs.HTML(label="")

# description
description = ("WeSpeaker Demo ! Try it with your own voice !")

# NOTE(review): the first and last pieces of this literal contain raw line
# breaks inside ordinary quotes; this looks like markup that was lost from
# the text -- confirm against the upstream file before relying on it.
article = (
    "

" "Github: Learn more about WeSpeaker" "

")

# Each row supplies the three values passed to speaker_verification:
# first wav path, second wav path, language code.
examples = [
    ['examples/BAC009S0764W0228.wav', 'examples/BAC009S0764W0328.wav', 'CN'],
    ['examples/BAC009S0913W0133.wav', 'examples/BAC009S0764W0228.wav', 'CN'],
    ['examples/00001_spk1.wav', 'examples/00003_spk2.wav', 'EN'],
    ['examples/00010_spk2.wav', 'examples/00024_spk1.wav', 'EN'],
    ['examples/00001_spk1.wav', 'examples/00024_spk1.wav', 'EN'],
    ['examples/00010_spk2.wav', 'examples/00003_spk2.wav', 'EN'],
]

# Wire the verification function to the input/output components above.
interface = gr.Interface(
    fn=speaker_verification,
    inputs=inputs,
    outputs=output,
    title="Speaker verification in WeSpeaker : 基于 WeSpeaker 的说话人确认",
    description=description,
    article=article,
    examples=examples,
    theme="huggingface",
)

# Start serving the demo; this runs at import time, as in the original.
interface.launch(enable_queue=True)