#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu,
#                                        Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """ | |
| Calculate UTMOS score with automatic Mean Opinion Score (MOS) prediction system | |
| adapted from https://huggingface.co/spaces/sarulab-speech/UTMOS-demo | |
| # Download model checkpoints | |
| wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt -P model/huggingface/utmos/utmos.pt | |
| wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt -P model/huggingface/utmos/wav2vec_small.pt | |
| """ | |
import argparse
import logging
import os

import fairseq
import librosa
import numpy as np
import pytorch_lightning as pl
import soundfile as sf
import torch
import torch.nn as nn
from tqdm import tqdm

logging.basicConfig(level=logging.INFO)


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wav-path", type=str, help="path of the evaluated speech directory"
    )
    parser.add_argument(
        "--utmos-model-path",
        type=str,
        default="model/huggingface/utmos/utmos.pt",
        help="path of the UTMOS model",
    )
    parser.add_argument(
        "--ssl-model-path",
        type=str,
        default="model/huggingface/utmos/wav2vec_small.pt",
        help="path of the wav2vec SSL model",
    )
    return parser


class UTMOSScore:
    """Predicting score for each audio clip."""

    def __init__(self, utmos_model_path, ssl_model_path):
        self.sample_rate = 16000
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        self.model = (
            BaselineLightningModule.load_from_checkpoint(
                utmos_model_path, ssl_model_path=ssl_model_path
            )
            .eval()
            .to(self.device)
        )

    def score(self, wavs: torch.Tensor) -> torch.Tensor:
        """
        Args:
            wavs: waveforms to be evaluated. A 1-D tensor [time] or a 2-D
                tensor [1, time] is scored as a single audio clip; a 3-D
                tensor [batch, 1, time] is scored as a batch.
        """
        if len(wavs.shape) == 1:
            out_wavs = wavs.unsqueeze(0).unsqueeze(0)
        elif len(wavs.shape) == 2:
            out_wavs = wavs.unsqueeze(0)
        elif len(wavs.shape) == 3:
            out_wavs = wavs
        else:
            raise ValueError("Dimension of input tensor needs to be <= 3.")
        bs = out_wavs.shape[0]
        batch = {
            "wav": out_wavs,
            "domains": torch.zeros(bs, dtype=torch.int).to(self.device),
            # Fixed listener id used by the UTMOS demo this script is adapted from.
            "judge_id": torch.ones(bs, dtype=torch.int).to(self.device) * 288,
        }
        with torch.no_grad():
            output = self.model(batch)
        # Average the frame-level predictions over time, then map the
        # normalized output back to the 1-5 MOS scale.
        return output.mean(dim=1).squeeze(1).cpu().detach() * 2 + 3

    def score_dir(self, wav_dir, dtype="float32"):
        def _load_speech_task(fname, sample_rate):
            wav_data, sr = sf.read(fname, dtype=dtype)
            if sr != sample_rate:
                wav_data = librosa.resample(wav_data, orig_sr=sr, target_sr=sample_rate)
            wav_data = torch.from_numpy(wav_data)
            return wav_data

        score_lst = []
        for fname in tqdm(os.listdir(wav_dir)):
            speech = _load_speech_task(os.path.join(wav_dir, fname), self.sample_rate)
            speech = speech.to(self.device)
            with torch.no_grad():
                score = self.score(speech)
            score_lst.append(score.item())
        return np.mean(score_lst)

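
# Minimal usage sketch (the file name "sample.wav" is illustrative):
#   scorer = UTMOSScore(
#       utmos_model_path="model/huggingface/utmos/utmos.pt",
#       ssl_model_path="model/huggingface/utmos/wav2vec_small.pt",
#   )
#   wav, _ = librosa.load("sample.wav", sr=16000)
#   mos = scorer.score(torch.from_numpy(wav).to(scorer.device))  # shape: [1]
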
def load_ssl_model(ckpt_path="wav2vec_small.pt"):
    SSL_OUT_DIM = 768
    model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [ckpt_path]
    )
    ssl_model = model[0]
    ssl_model.remove_pretraining_modules()
    return SSL_model(ssl_model, SSL_OUT_DIM)


class BaselineLightningModule(pl.LightningModule):
    def __init__(self, ssl_model_path):
        super().__init__()
        self.construct_model(ssl_model_path)
        self.save_hyperparameters()

    def construct_model(self, ssl_model_path):
        self.feature_extractors = nn.ModuleList(
            [
                load_ssl_model(ckpt_path=ssl_model_path),
                DomainEmbedding(3, 128),
            ]
        )
        output_dim = sum(
            feature_extractor.get_output_dim()
            for feature_extractor in self.feature_extractors
        )
        output_layers = [
            LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim)
        ]
        output_dim = output_layers[-1].get_output_dim()
        output_layers.append(
            Projection(
                hidden_dim=2048,
                activation=torch.nn.ReLU(),
                range_clipping=False,
                input_dim=output_dim,
            )
        )
        self.output_layers = nn.ModuleList(output_layers)

    def forward(self, inputs):
        outputs = {}
        for feature_extractor in self.feature_extractors:
            outputs.update(feature_extractor(inputs))
        x = outputs
        for output_layer in self.output_layers:
            x = output_layer(x, inputs)
        return x


class SSL_model(nn.Module):
    def __init__(self, ssl_model, ssl_out_dim) -> None:
        super().__init__()
        self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim

    def forward(self, batch):
        wav = batch["wav"]
        wav = wav.squeeze(1)  # [batches, wav_len]
        res = self.ssl_model(wav, mask=False, features_only=True)
        x = res["x"]
        return {"ssl-feature": x}

    def get_output_dim(self):
        return self.ssl_out_dim


class DomainEmbedding(nn.Module):
    def __init__(self, n_domains, domain_dim) -> None:
        super().__init__()
        self.embedding = nn.Embedding(n_domains, domain_dim)
        self.output_dim = domain_dim

    def forward(self, batch):
        return {"domain-feature": self.embedding(batch["domains"])}

    def get_output_dim(self):
        return self.output_dim


class LDConditioner(nn.Module):
    """Condition the SSL output on a listener (judge) embedding."""

    def __init__(self, input_dim, judge_dim, num_judges=None):
        super().__init__()
        self.input_dim = input_dim
        self.judge_dim = judge_dim
        self.num_judges = num_judges
        assert num_judges is not None
        self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
        # concat [self.output_layer, phoneme features]
        self.decoder_rnn = nn.LSTM(
            input_size=self.input_dim + self.judge_dim,
            hidden_size=512,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.out_dim = self.decoder_rnn.hidden_size * 2

    def get_output_dim(self):
        return self.out_dim

    def forward(self, x, batch):
        judge_ids = batch["judge_id"]
        if "phoneme-feature" in x.keys():
            concatenated_feature = torch.cat(
                (
                    x["ssl-feature"],
                    x["phoneme-feature"]
                    .unsqueeze(1)
                    .expand(-1, x["ssl-feature"].size(1), -1),
                ),
                dim=2,
            )
        else:
            concatenated_feature = x["ssl-feature"]
        if "domain-feature" in x.keys():
            concatenated_feature = torch.cat(
                (
                    concatenated_feature,
                    x["domain-feature"]
                    .unsqueeze(1)
                    .expand(-1, concatenated_feature.size(1), -1),
                ),
                dim=2,
            )
        if judge_ids is not None:
            concatenated_feature = torch.cat(
                (
                    concatenated_feature,
                    self.judge_embedding(judge_ids)
                    .unsqueeze(1)
                    .expand(-1, concatenated_feature.size(1), -1),
                ),
                dim=2,
            )
        decoder_output, (h, c) = self.decoder_rnn(concatenated_feature)
        return decoder_output

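
# Shape sketch (derived from the configuration above): wav2vec features are
# 768-dim and the domain and judge embeddings are 128-dim each, so the BLSTM
# consumes [batch, frames, 768 + 128 + 128] and returns [batch, frames, 1024]
# (512 hidden units per direction).
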
class Projection(nn.Module):
    def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
        super().__init__()
        self.range_clipping = range_clipping
        output_dim = 1
        if range_clipping:
            self.proj = nn.Tanh()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            activation,
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim),
        )
        self.output_dim = output_dim

    def forward(self, x, batch):
        output = self.net(x)
        # range clipping
        if self.range_clipping:
            return self.proj(output) * 2.0 + 3
        else:
            return output

    def get_output_dim(self):
        return self.output_dim


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    utmos = UTMOSScore(
        utmos_model_path=args.utmos_model_path,
        ssl_model_path=args.ssl_model_path,
    )
    score = utmos.score_dir(args.wav_path)
    logging.info(f"UTMOS score: {score:.2f}")