# Scraped page header (not part of the module):
# akhaliq3 - spaces demo - 5019931
# raw - history blame - No virus - 5.3 kB
import datetime
import logging
import os
import pickle
from typing import Dict, NoReturn
import librosa
import numpy as np
import yaml
def create_logging(log_dir: str, filemode: str) -> logging:
    r"""Set up logging to a freshly numbered log file and to the console.

    The log file is named "{:04d}.log" using the first index (starting at 0)
    for which no file exists yet in ``log_dir``.

    Args:
        log_dir: str, directory to write out logs; created if missing
        filemode: str, mode used to open the log file, e.g., "w"

    Returns:
        logging, the ``logging`` module itself
    """
    os.makedirs(log_dir, exist_ok=True)

    # Probe 0000.log, 0001.log, ... until a file name that is not taken.
    log_index = 0
    log_path = os.path.join(log_dir, "{:04d}.log".format(log_index))
    while os.path.isfile(log_path):
        log_index += 1
        log_path = os.path.join(log_dir, "{:04d}.log".format(log_index))

    file_format = "%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s"
    logging.basicConfig(
        filename=log_path,
        filemode=filemode,
        level=logging.DEBUG,
        format=file_format,
        datefmt="%a, %d %b %Y %H:%M:%S",
    )

    # Print to console: INFO and above, with a shorter format.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(
        logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s")
    )
    logging.getLogger("").addHandler(console_handler)

    return logging
def load_audio(
    audio_path: str,
    mono: bool,
    sample_rate: float,
    offset: float = 0.0,
    duration: float = None,
) -> np.array:
    r"""Load audio with librosa and always return a 2-D array.

    Args:
        audio_path: str, path of the audio file
        mono: bool, forwarded to librosa as ``mono``
        sample_rate: float, forwarded to librosa as ``sr``
        offset: float, forwarded to librosa as ``offset``
        duration: float, forwarded to librosa as ``duration``

    Returns:
        np.array, (channels_num, audio_samples); a 1-D result gets a leading
        axis of size 1
    """
    load_kwargs = dict(sr=sample_rate, mono=mono, offset=offset, duration=duration)
    waveform, _ = librosa.core.load(audio_path, **load_kwargs)
    # (audio_samples,) | (channels_num, audio_samples)

    if waveform.ndim != 1:
        return waveform

    # (1, audio_samples,)
    return waveform[None, :]
def load_random_segment(
    audio_path: str, random_state, segment_seconds: float, mono: bool, sample_rate: int
) -> np.array:
    r"""Randomly select an audio segment from a recording.

    Args:
        audio_path: str, path of the audio file
        random_state: object providing ``uniform(low, high)``
            (presumably an np.random.RandomState; verify against caller)
        segment_seconds: float, length of the segment to load
        mono: bool, forwarded to load_audio
        sample_rate: int, forwarded to load_audio

    Returns:
        np.array, (channels_num, audio_samples)
    """
    # NOTE(review): newer librosa releases rename the ``filename`` keyword of
    # get_duration to ``path`` - confirm against the pinned librosa version.
    duration = librosa.get_duration(filename=audio_path)

    # If the recording is shorter than the requested segment, the latest
    # start would be negative; clamp it so the offset is never negative.
    latest_start = max(0.0, duration - segment_seconds)
    start_time = random_state.uniform(0.0, latest_start)

    audio = load_audio(
        audio_path=audio_path,
        mono=mono,
        sample_rate=sample_rate,
        offset=start_time,
        duration=segment_seconds,
    )
    # (channels_num, audio_samples)

    return audio
def float32_to_int16(x: np.float32) -> np.int16:
    r"""Convert float samples to int16.

    Values are first limited to [-1, 1], then scaled by 32767 and cast to
    np.int16.
    """
    clipped = np.clip(x, -1, 1)
    scaled = clipped * 32767.0
    return scaled.astype(np.int16)
def int16_to_float32(x: np.int16) -> np.float32:
    r"""Convert int16 samples to float32 by dividing by 32767."""
    normalized = x / 32767.0
    return normalized.astype(np.float32)
def read_yaml(config_yaml: str):
    r"""Read a YAML file and return the parsed content.

    Args:
        config_yaml: str, path of the YAML file

    Returns:
        the object produced by yaml.load using yaml.FullLoader
    """
    # NOTE(review): yaml.load with FullLoader should only be used on trusted
    # config files; consider yaml.safe_load if configs can come from outside.
    with open(config_yaml, "r") as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)
def check_configs_gramma(configs: Dict) -> None:
    r"""Check if the grammar of the config dictionary for training is legal.

    Every source type listed under any entry of
    configs['train']['augmentations'] must also appear in
    configs['train']['input_source_types'].

    Args:
        configs: Dict, must contain configs['train']['input_source_types']
            and configs['train']['augmentations']

    Raises:
        ValueError: if an augmentation refers to a source type that is not
            one of the input source types
    """
    train_configs = configs['train']
    input_source_types = train_configs['input_source_types']

    for augmentation_type, augmentation_dict in train_configs['augmentations'].items():
        for source_type in augmentation_dict:
            if source_type not in input_source_types:
                # ValueError is a subclass of Exception, so callers that
                # catch Exception keep working.
                raise ValueError(
                    "The source type '{}' in configs['train']['augmentations']['{}'] "
                    "must be one of input_source_types {}".format(
                        source_type, augmentation_type, input_source_types
                    )
                )
def magnitude_to_db(x: float) -> float:
    r"""Convert a magnitude to decibels: 20 * log10(x).

    Inputs below 1e-10 are raised to 1e-10 first, so the result is never
    lower than -200 dB and log10 is never applied to zero or a negative
    number. NumPy arrays are accepted as well and converted element-wise.
    """
    eps = 1e-10
    # np.maximum instead of the builtin max, which raises on arrays.
    return 20.0 * np.log10(np.maximum(x, eps))
def db_to_magnitude(x: float) -> float:
    r"""Convert decibels to a magnitude: 10 ** (x / 20)."""
    exponent = x / 20
    return 10.0 ** exponent
def get_pitch_shift_factor(shift_pitch: float) -> float:
    r"""The factor of the audio length to be scaled.

    Computed as 2 ** (shift_pitch / 12).
    """
    octaves = shift_pitch / 12
    return 2 ** octaves
class StatisticsContainer(object):
    r"""Collect per-split statistics and pickle them to disk."""

    def __init__(self, statistics_path):
        r"""Set up the output paths and an empty statistics dictionary.

        Args:
            statistics_path: str, path of the main pickle file
        """
        self.statistics_path = statistics_path

        # Backup file: same path without extension, plus a timestamp taken
        # when the container is created.
        self.backup_statistics_path = "{}_{}.pkl".format(
            os.path.splitext(self.statistics_path)[0],
            datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
        )

        self.statistics_dict = {"train": [], "test": []}

    def append(self, steps, statistics, split):
        r"""Store the statistics of one evaluation.

        Args:
            steps: value stored under statistics["steps"]
            statistics: dict; modified in place (the "steps" key is set)
            split: str, key of statistics_dict, i.e., "train" | "test"
        """
        statistics["steps"] = steps
        self.statistics_dict[split].append(statistics)

    def dump(self):
        r"""Pickle statistics_dict to both the main and the backup path."""
        # Context managers make sure both files are flushed and closed.
        with open(self.statistics_path, "wb") as main_file:
            pickle.dump(self.statistics_dict, main_file)

        with open(self.backup_statistics_path, "wb") as backup_file:
            pickle.dump(self.statistics_dict, backup_file)

        logging.info("    Dump statistics to {}".format(self.statistics_path))
        logging.info("    Dump statistics to {}".format(self.backup_statistics_path))

    # Disabled resume helper, kept for reference:
    #
    # def load_state_dict(self, resume_steps):
    #     self.statistics_dict = pickle.load(open(self.statistics_path, "rb"))
    #     resume_statistics_dict = {"train": [], "test": []}
    #     for key in self.statistics_dict.keys():
    #         for statistics in self.statistics_dict[key]:
    #             if statistics["steps"] <= resume_steps:
    #                 resume_statistics_dict[key].append(statistics)
    #     self.statistics_dict = resume_statistics_dict
def calculate_sdr(ref: np.array, est: np.array) -> float:
    r"""Calculate the signal-to-distortion ratio (SDR) in dB.

    The distortion is est - ref. Both mean powers are limited from below
    to 1e-8 before taking log10.

    Args:
        ref: np.array, reference signal
        est: np.array, estimated signal

    Returns:
        float, 10 * (log10(reference power) - log10(distortion power))
    """
    distortion = est - ref

    reference_power = np.clip(np.mean(ref ** 2), 1e-8, np.inf)
    distortion_power = np.clip(np.mean(distortion ** 2), 1e-8, np.inf)

    return 10.0 * (np.log10(reference_power) - np.log10(distortion_power))