Spaces:
Running
Running
import os | |
import torch.nn.functional as F | |
import torchaudio | |
from loguru import logger | |
import gradio as gr | |
from huggingface_hub import hf_hub_download | |
import torch | |
import yaml | |
# ---------- Settings ----------
# Hardware selection: '-1' hides every GPU from CUDA, which puts inference on the CPU.
GPU_ID = '-1'
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
DEVICE = 'cpu' if GPU_ID == '-1' else 'cuda'

# Server-related settings.
# NOTE(review): none of these three is referenced elsewhere in the visible code
# (demo.launch() is called without them) — confirm whether they are still needed.
SERVER_PORT = 42208
SERVER_NAME = "0.0.0.0"
SSL_DIR = './keyble_ssl'

# Audio front-end: read_wav() resamples input to FS Hz and zero-pads it to at
# least MIN_REQUIRED_WAV_LENGTH samples.
FS = 16000
MIN_REQUIRED_WAV_LENGTH = 1040
# Cache of torchaudio Resample transforms, keyed by "<source_rate>-<FS>".
resamplers = {}

# Example clips (currently disabled, together with the gr.Examples UI section).
# EXAMPLE_DIR = './examples'
# en_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "en", '*.wav')))
# jp_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "jp", '*.wav')))
# zh_examples = sorted(glob(os.path.join(EXAMPLE_DIR, "zh", '*.wav')))
# ---------- Logging ----------
# Append (mode='a') to app.log so earlier runs' logs are kept across restarts.
logger.add('app.log', mode='a')
logger.info('============================= App restarted =============================')

# ---------- Download models ----------
logger.info('============================= Download models ===========================')
# Hugging Face Hub repository and in-repo files for the SSL-MOS model trained on
# all training sets.
_SHEET_REPO_ID = "unilight/sheet-models"
_ALL_SETS_FILES = {
    "ckpt": "bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/sslmos+mdf/2337/checkpoint-86000steps.pkl",
    "config": "bvcc+nisqa+pstn+singmos+somos+tencent+tmhint-qi/sslmos+mdf/2337/config.yml",
}
# Display name (shown in the UI model selector) -> {"ckpt": local path, "config": local path}.
# The dict comprehension preserves insertion order, so the checkpoint is downloaded first.
model_paths = {
    "SSL-MOS, all training sets": {
        kind: hf_hub_download(repo_id=_SHEET_REPO_ID, filename=filename)
        for kind, filename in _ALL_SETS_FILES.items()
    },
}
# ---------- Model ----------
# Both dicts are keyed by the display names in `model_paths` (the UI choices).
# `model_configs` keeps each model's own config so predict() does not depend on
# whichever config happened to be loaded last.
models = {}
model_configs = {}
for name, path_dict in model_paths.items():
    logger.info(f'============================= Setting up model for {name} =============')
    checkpoint_path = path_dict["ckpt"]
    config_path = path_dict["config"]
    # NOTE(review): yaml.Loader can construct arbitrary Python objects. The file
    # comes from a fixed Hub repo; consider yaml.safe_load if the config only
    # contains plain YAML types.
    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    if config["model_type"] == "SSLMOS":
        from models.sslmos import SSLMOS

        model = SSLMOS(
            config["model_input"],
            num_listeners=config.get("num_listeners", None),
            num_domains=config.get("num_domains", None),
            **config["model_params"],
        )
    else:
        # Fail loudly. Otherwise `model` would be unbound, or silently left over
        # from the previous loop iteration.
        raise ValueError(
            f"Unsupported model_type {config['model_type']!r} for model {name!r}."
        )
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["model"])
    model = model.eval().to(DEVICE)
    logger.info(f"Loaded model parameters from {checkpoint_path}.")
    models[name] = model
    model_configs[name] = config


def read_wav(wav_path):
    """Load an audio file as a mono waveform at ``FS`` Hz, padded to a minimum length.

    Args:
        wav_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        A tuple ``(waveform, sample_rate)``. ``waveform`` is a 1-D tensor of shape
        ``[T]``, sampled at ``FS``, with at least ``MIN_REQUIRED_WAV_LENGTH``
        samples. ``sample_rate`` is the file's original sampling rate, before
        resampling.
    """
    # Default channels_first=True gives [C, T], so time is the last axis. That is
    # the axis torchaudio's Resample operates on.
    waveform, sample_rate = torchaudio.load(wav_path)
    # Downmix any number of channels to mono: [C, T] -> [T].
    waveform = waveform.mean(dim=0)
    # Resample if needed, reusing one cached transform per source rate.
    if sample_rate != FS:
        resampler_key = f"{sample_rate}-{FS}"
        if resampler_key not in resamplers:
            resamplers[resampler_key] = torchaudio.transforms.Resample(
                sample_rate, FS, dtype=waveform.dtype
            )
        waveform = resamplers[resampler_key](waveform)
    # Always pad to a minimum length with zeros, split across both ends. The
    # right end takes the extra sample when the deficit is odd, so the result
    # reaches the minimum exactly.
    deficit = MIN_REQUIRED_WAV_LENGTH - waveform.shape[0]
    if deficit > 0:
        left = deficit // 2
        waveform = F.pad(waveform, (left, deficit - left), "constant", 0)
    return waveform, sample_rate


def predict(model_name, wav_file):
    """Predict the quality score of one recording with the selected model.

    Args:
        model_name: Key into ``models`` (one of the ``model_paths`` display names).
        wav_file: Path to the audio file to assess (from ``gr.Audio(type='filepath')``).

    Returns:
        Element ``[0]`` of the model's ``outputs["scores"]``, converted to NumPy.

    Raises:
        gr.Error: If no valid model is selected or no audio was provided.
        ValueError: If the model's config has an unsupported ``inference_mode``.
    """
    # Surface missing UI inputs as readable messages rather than raw KeyError /
    # torchaudio failures.
    if model_name not in models:
        raise gr.Error("Please select a model.")
    if wav_file is None:
        raise gr.Error("Please record or upload a speech file.")
    model = models[model_name]
    model_config = model_configs[model_name]

    x, _ = read_wav(wav_file)
    logger.info('wav file loaded')
    # Set up model input: add a batch axis -> [1, T], plus its length tensor.
    model_input = x.unsqueeze(0).to(DEVICE)
    model_lengths = model_input.new_tensor([model_input.size(1)]).long()
    inputs = {
        model_config["model_input"]: model_input,
        model_config["model_input"] + "_lengths": model_lengths,
    }
    inference_mode = model_config["inference_mode"]
    with torch.no_grad():
        if inference_mode == "mean_listener":
            outputs = model.mean_listener_inference(inputs)
        elif inference_mode == "mean_net":
            outputs = model.mean_net_inference(inputs)
        else:
            raise ValueError(f"Unsupported inference_mode {inference_mode!r}.")
    pred_mean_scores = outputs["scores"].cpu().detach().numpy()[0]
    return pred_mean_scores
# Gradio UI: audio input and model selector on the left, predicted score on the right.
with gr.Blocks(title="SHEET: Speech Human Evaluation Estimation Toolkit demo") as demo:
    gr.Markdown(
        """
# Demo for SHEET: Speech Human Evaluation Estimation Toolkit
### [[Paper (arXiv)]](https://arxiv.org/abs/2411.03715) [[Code]](https://github.com/unilight/sheet)
**SHEET** is a subjective speech quality assessment (SSQA) toolkit designed to conduct SSQA research. It was specifically designed to interact with MOS-Bench, a collection of datasets to benchmark SSQA models.
In this demo, you can record your own voice or upload speech files to assess the quality.
"""
    )
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Record your speech here!")
            input_wav = gr.Audio(label="Input speech", type='filepath')
            gr.Markdown("## Select a model!")
            model_choices = list(model_paths.keys())
            # Preselect the first model so pressing "Evaluate!" without touching
            # the selector does not send None to predict().
            model_name = gr.Radio(
                label="Model",
                choices=model_choices,
                value=model_choices[0] if model_choices else None,
            )
            evaluate_btn = gr.Button(value="Evaluate!")
            # gr.Markdown("### You can use these examples if using a microphone is too troublesome!")
            # gr.Markdown("I recorded the samples using my Macbook Pro, so there might be some noises.")
            # gr.Examples(
            #     examples=en_examples,
            #     inputs=input_wav,
            #     label="English examples"
            # )
            # gr.Examples(
            #     examples=jp_examples,
            #     inputs=input_wav,
            #     label="Japanese examples"
            # )
            # gr.Examples(
            #     examples=zh_examples,
            #     inputs=input_wav,
            #     label="Mandarin examples"
            # )
        with gr.Column():
            gr.Markdown("## The predicted score is here:")
            output_score = gr.Textbox(label="Prediction", interactive=False)
    # Wire the button: predict(model_name, input_wav) -> output_score.
    evaluate_btn.click(predict, [model_name, input_wav], output_score)
def _serve():
    """Run the Gradio app in the foreground and always close it on exit.

    A Ctrl-C (KeyboardInterrupt) is caught and printed rather than propagated.
    """
    try:
        demo.launch(debug=True)
    except KeyboardInterrupt as interrupt:
        print(interrupt)
    finally:
        demo.close()


if __name__ == '__main__':
    _serve()