Spaces:
Running
Running
File size: 8,494 Bytes
936f6fa 58f8c9b 936f6fa 58f8c9b 936f6fa 58f8c9b 936f6fa 58f8c9b 936f6fa c932300 936f6fa 25ce57e 09a4143 c932300 09a4143 7cec2b9 c932300 09a4143 4a76f9b 58f8c9b 6354b1b 25ce57e 58f8c9b 6354b1b 25ce57e 58f8c9b 6354b1b 4a76f9b 936f6fa 58f8c9b 936f6fa |
|
import os
import librosa
import soundfile as sf
import resampy
import numpy as np
from scores.srmr.srmr import SRMR
from scores.dnsmos.dnsmos import DNSMOS
from scores.pesq import PESQ
from scores.nb_pesq import NB_PESQ
from scores.sisdr import SISDR
from scores.stoi import STOI
from scores.fwsegsnr import FWSEGSNR
from scores.lsd import LSD
from scores.bsseval import BSSEval
from scores.snr import SNR
from scores.ssnr import SSNR
from scores.llr import LLR
from scores.csig import CSIG
from scores.cbak import CBAK
from scores.covl import COVL
from scores.mcd import MCD
def compute_mean_results(*results, round_digits=None):
    """Average a sequence of (possibly nested) score dictionaries key by key.

    The first dictionary defines which keys are averaged and whether a key
    holds a nested dictionary (averaged recursively) or a leaf value
    (averaged arithmetically). Dictionaries that lack a key -- e.g. a file
    for which a score produced no result -- are skipped for that key instead
    of raising ``KeyError``.

    Args:
        *results: dictionaries sharing the same (nested) key structure.
        round_digits: if not None, the number of decimal places passed to
            ``round()`` for every leaf mean.

    Returns:
        A dictionary with the same structure as ``results[0]`` holding the
        mean values, or an empty dictionary when no results are given.
    """
    if not results:
        return {}
    mean_result = {}
    for key, first_value in results[0].items():
        # Only average over the dictionaries that actually contain this key.
        values = [d[key] for d in results if key in d]
        if isinstance(first_value, dict):
            mean_result[key] = compute_mean_results(*values, round_digits=round_digits)
        else:
            mean = sum(values) / len(values)
            mean_result[key] = round(mean, round_digits) if round_digits is not None else mean
    return mean_result
class ScoresList:
    """Ordered collection of score objects that are evaluated together.

    Scores are appended with ``+=`` (see ``__add__``); the whole collection is
    evaluated by calling the instance (see ``__call__``). Each score object
    must expose a ``name`` attribute and a
    ``scoring(data, window, score_rate, round_digits)`` method.
    """

    def __init__(self):
        # Score objects, in the order they were added.
        self.scores = []

    def __add__(self, score):
        """Append ``score`` in place and return ``self`` (supports ``scores += X``).

        NOTE: this mutates the left operand, so ``a + b`` also modifies ``a``.
        """
        self.scores.append(score)
        return self

    def __str__(self):
        return 'Scores: ' + ' '.join(x.name for x in self.scores)

    def _score_data(self, data, window, score_rate, round_digits):
        """Run every score on ``data``; scores that return ``None`` are omitted."""
        results = {}
        for score in self.scores:
            result_score = score.scoring(data, window, score_rate, round_digits)
            if result_score is not None:
                results[score.name] = result_score
        return results

    def __call__(self, test_path, reference_path, window=None, score_rate=None, return_mean=False, round_digits=None):
        """Compute all registered scores.

        Args:
            test_path: one of
                * a ``(sampling_rate, audio_array)`` tuple -- scored directly
                  (``reference_path`` is ignored in this mode);
                * a directory -- every ``.wav`` (or, if none, ``.flac``) file
                  in it is scored, paired with the file of the same name in
                  ``reference_path`` when that is given;
                * a single audio file path.
            reference_path: reference file or directory, or None.
            window: float, the window length in seconds used for scoring
                (passed through to each score).
            score_rate: the sampling rate specified for scoring; defaults to
                16000.
            return_mean: in directory mode, add a ``'Mean_Score'`` entry with
                the per-key mean over all files. It is ignored in the other
                modes, where the results are per-metric values that cannot be
                averaged meaningfully.
            round_digits: decimal places passed to each score and to the mean.

        Returns:
            Directory mode: ``{audio_file_name: {score_name: result}}`` (plus
            ``'Mean_Score'`` if requested). Other modes:
            ``{score_name: result}``. Returns None when ``test_path`` is None
            or a directory contains no audio files.
        """
        if score_rate is None:
            score_rate = 16000
        if test_path is None:
            print('Please provide audio path for test_path')
            return None

        if isinstance(test_path, tuple):
            sr, audio = test_path
            if sr != score_rate:
                audio = resampy.resample(audio, sr, score_rate, axis=0)
            data = {'audio': [audio], 'rate': score_rate}
            return self._score_data(data, window, score_rate, round_digits)

        if os.path.isdir(test_path):
            audio_list = self.get_audio_list(test_path)
            if audio_list is None:
                return None
            results = {}
            for audio_id in audio_list:
                ref = os.path.join(reference_path, audio_id) if reference_path is not None else None
                data = self.audio_reader(os.path.join(test_path, audio_id), ref)
                results[audio_id] = self._score_data(data, window, score_rate, round_digits)
            if return_mean and results:
                results['Mean_Score'] = compute_mean_results(*results.values(), round_digits=round_digits)
            return results

        if os.path.isfile(test_path):
            data = self.audio_reader(test_path, reference_path)
            return self._score_data(data, window, score_rate, round_digits)

        print(f'{test_path} is neither a file nor a directory, scoring ended!')
        return {}

    def get_audio_list(self, path):
        """Return the file names of the audio files found under ``path``.

        ``.wav`` files are searched first; ``.flac`` files are used only if no
        ``.wav`` file exists. When neither is found, a message is printed and
        None is returned.

        NOTE(review): librosa's ``find_files`` searches subdirectories by
        default, but only the base name is kept here. Files in subfolders will
        therefore not be found again when re-joined to ``path``.
        """
        path_list = librosa.util.find_files(path, ext="wav")
        if not path_list:
            path_list = librosa.util.find_files(path, ext="flac")
        if not path_list:
            print(f'No audio files found in {path}, scoring ended!')
            return None
        # os.path.basename honours the platform separator, unlike split('/').
        return [os.path.basename(audio_path) for audio_path in path_list]

    def audio_reader(self, test_path, reference_path):
        """Load the test (and optional reference) file as equal-length mono signals.

        Only the first channel of multi-channel files is kept. If the two files
        have different sampling rates, the higher-rate file is resampled down
        to the lower rate. Shorter signals are zero-padded at the end to the
        length of the longest one.

        Args:
            test_path: path of the audio file to evaluate.
            reference_path: path of the reference audio file, or None.

        Returns:
            dict with ``'audio'``: list of 1-D arrays ``[test]`` or
            ``[test, reference]``, and ``'rate'``: the common sampling rate.
        """
        # always_2d=True yields (frames, channels); keep channel 0 as (frames, 1).
        audio_test, rate_test = sf.read(test_path, always_2d=True)
        if audio_test.shape[1] > 1:
            audio_test = audio_test[..., 0, None]
        rate = rate_test
        audios = [audio_test]

        if reference_path is not None:
            audio_ref, rate_ref = sf.read(reference_path, always_2d=True)
            if audio_ref.shape[1] > 1:
                audio_ref = audio_ref[..., 0, None]
            if rate_test != rate_ref:
                rate = min(rate_test, rate_ref)
                if rate_test != rate:
                    audio_test = resampy.resample(audio_test, rate_test, rate, axis=0)
                if rate_ref != rate:
                    audio_ref = resampy.resample(audio_ref, rate_ref, rate, axis=0)
            audios = [audio_test, audio_ref]

        maxlen = max(audio.shape[0] for audio in audios)
        padded = []
        for audio in audios:
            mono = audio[..., 0]
            if mono.shape[0] != maxlen:
                new = np.zeros((maxlen,))
                new[:mono.shape[0]] = mono
                mono = new
            padded.append(mono)

        return {'audio': padded, 'rate': rate}
def SpeechScore(scores=''):
    """Build a ScoresList containing the requested scores.

    Parameters:
    ----------
    scores: str or list of str
        Names of the scores to load, matched case-insensitively. Supported
        names: srmr, pesq, nb_pesq, stoi, sisdr, fwsegsnr, lsd, bsseval,
        dnsmos, snr, ssnr, llr, csig, cbak, covl, mcd. A single string is
        treated as one score name; the empty string (the default) or an empty
        list loads nothing. Unknown names print a warning and are skipped.

    Returns:
    --------
    A ScoresList object that can be called to compute the selected scores.
    """
    # Name -> class; classes are instantiated only when requested.
    available = {
        'srmr': SRMR,
        'pesq': PESQ,
        'nb_pesq': NB_PESQ,
        'stoi': STOI,
        'sisdr': SISDR,
        'fwsegsnr': FWSEGSNR,
        'lsd': LSD,
        'bsseval': BSSEval,
        'dnsmos': DNSMOS,
        'snr': SNR,
        'ssnr': SSNR,
        'llr': LLR,
        'csig': CSIG,
        'cbak': CBAK,
        'covl': COVL,
        'mcd': MCD,
    }
    # Iterating a bare string would yield single characters, so wrap it.
    if isinstance(scores, str):
        scores = [scores] if scores else []

    score_cls = ScoresList()
    for score in scores:
        cls = available.get(score.lower())
        if cls is None:
            print(f"score '{score}' is pending implementation...")
            continue
        score_cls += cls()
    return score_cls
|