import os
import datetime
import json
import logging
import pickle
from typing import Dict

import librosa
import numpy as np
import torch
import torch.nn as nn
import yaml

from models.audiosep import AudioSep, get_model_class


def ignore_warnings():
    import warnings

    # Suppress UserWarnings raised from torch.functional (e.g., STFT-related warnings).
    warnings.filterwarnings('ignore', category=UserWarning, module='torch.functional')

    # Suppress the Hugging Face notice about unused 'lm_head' weights when
    # initializing RobertaModel from the roberta-base checkpoint.
    pattern = r"Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: \['lm_head\..*'\].*"
    warnings.filterwarnings('ignore', message=pattern)


def create_logging(log_dir, filemode):
    os.makedirs(log_dir, exist_ok=True)
    i1 = 0

    # Find the first unused log file name: 0000.log, 0001.log, ...
    while os.path.isfile(os.path.join(log_dir, "{:04d}.log".format(i1))):
        i1 += 1

    log_path = os.path.join(log_dir, "{:04d}.log".format(i1))
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s",
        datefmt="%a, %d %b %Y %H:%M:%S",
        filename=log_path,
        filemode=filemode,
    )

    # Also echo INFO-level messages to the console.
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s")
    console.setFormatter(formatter)
    logging.getLogger("").addHandler(console)

    return logging


def float32_to_int16(x: np.ndarray) -> np.ndarray:
    x = np.clip(x, a_min=-1, a_max=1)
    return (x * 32767.0).astype(np.int16)


def int16_to_float32(x: np.ndarray) -> np.ndarray:
    return (x / 32767.0).astype(np.float32)
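
# Usage sketch for the two converters above (illustrative values): the pair is an
# approximate round trip, and out-of-range samples are clipped to [-1, 1].
# >>> x = np.array([0.0, 0.5, -1.5], dtype=np.float32)
# >>> y = float32_to_int16(x)       # int16 samples; -1.5 is clipped to -1.0
# >>> x_hat = int16_to_float32(y)   # float32 in [-1, 1], close to np.clip(x, -1, 1)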


def parse_yaml(config_yaml: str) -> Dict:
    r"""Parse yaml file.

    Args:
        config_yaml (str): config yaml path

    Returns:
        yaml_dict (Dict): parsed yaml file
    """
    with open(config_yaml, "r") as fr:
        return yaml.load(fr, Loader=yaml.FullLoader)


def get_audioset632_id_to_lb(ontology_path: str) -> Dict:
    r"""Get AudioSet 632 classes ID to label mapping."""
    audioset632_id_to_lb = {}

    with open(ontology_path) as f:
        data_list = json.load(f)

    for e in data_list:
        audioset632_id_to_lb[e["id"]] = e["name"]

    return audioset632_id_to_lb


def load_pretrained_panns(
    model_type: str,
    checkpoint_path: str,
    freeze: bool
) -> nn.Module:
    r"""Load pretrained audio neural networks (PANNs).

    Args:
        model_type: str, e.g., "Cnn14"
        checkpoint_path: str, e.g., "Cnn14_mAP=0.431.pth"
        freeze: bool

    Returns:
        model: nn.Module
    """
    # NOTE: Cnn14 and Cnn14_DecisionLevelMax are PANNs model classes and must be
    # importable in this module (they are not imported above).
    if model_type == "Cnn14":
        Model = Cnn14

    elif model_type == "Cnn14_DecisionLevelMax":
        Model = Cnn14_DecisionLevelMax

    else:
        raise NotImplementedError

    model = Model(sample_rate=32000, window_size=1024, hop_size=320,
        mel_bins=64, fmin=50, fmax=14000, classes_num=527)

    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])

    if freeze:
        for param in model.parameters():
            param.requires_grad = False

    return model


def energy(x):
    return torch.mean(x ** 2)


def magnitude_to_db(x):
    eps = 1e-10
    return 20. * np.log10(max(x, eps))


def db_to_magnitude(x):
    return 10. ** (x / 20)
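
# Usage sketch for the dB helpers (illustrative): a magnitude of 0.5 is about -6.02 dB,
# and the two functions invert each other for positive scalars (non-positive inputs
# are floored at eps = 1e-10).
# >>> magnitude_to_db(0.5)                    # ~ -6.02
# >>> db_to_magnitude(magnitude_to_db(0.5))   # ~ 0.5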


def ids_to_hots(ids, classes_num, device):
    hots = torch.zeros(classes_num).to(device)
    for id in ids:
        hots[id] = 1
    return hots
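
# Usage sketch (illustrative): build a multi-hot target vector from class indices.
# >>> ids_to_hots([1, 3], classes_num=5, device='cpu')
# tensor([0., 1., 0., 1., 0.])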


def calculate_sdr(
    ref: np.ndarray,
    est: np.ndarray,
    eps=1e-10
) -> float:
    r"""Calculate SDR between reference and estimation.

    Args:
        ref (np.ndarray): reference signal
        est (np.ndarray): estimated signal
    """
    reference = ref
    noise = est - reference

    # SDR = 10 * log10(signal power / error power), with both powers floored at eps.
    numerator = np.clip(a=np.mean(reference ** 2), a_min=eps, a_max=None)
    denominator = np.clip(a=np.mean(noise ** 2), a_min=eps, a_max=None)
    sdr = 10. * np.log10(numerator / denominator)

    return sdr
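
# Usage sketch (illustrative): a perfect estimate gives a very large SDR, while adding
# noise with 1/100 of the signal power gives roughly 20 dB.
# >>> ref = np.random.randn(16000)
# >>> est = ref + 0.1 * np.random.randn(16000)
# >>> calculate_sdr(ref=ref, est=est)   # ~ 20 dB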


def calculate_sisdr(ref, est):
    r"""Calculate SI-SDR between reference and estimation.

    Args:
        ref (np.ndarray): reference signal
        est (np.ndarray): estimated signal
    """
    eps = np.finfo(ref.dtype).eps

    reference = ref.copy()
    estimate = est.copy()
    reference = reference.reshape(reference.size, 1)
    estimate = estimate.reshape(estimate.size, 1)

    Rss = np.dot(reference.T, reference)

    # Optimal scaling of the reference that best matches the estimate.
    a = (eps + np.dot(reference.T, estimate)) / (Rss + eps)

    e_true = a * reference
    e_res = estimate - e_true

    Sss = (e_true ** 2).sum()
    Snn = (e_res ** 2).sum()

    sisdr = 10 * np.log10((eps + Sss) / (eps + Snn))

    return sisdr
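
# Usage sketch (illustrative): unlike plain SDR, SI-SDR is invariant to rescaling the
# estimate, because the reference is optimally rescaled before computing the error.
# >>> ref = np.random.randn(16000)
# >>> est = ref + 0.1 * np.random.randn(16000)
# >>> np.isclose(calculate_sisdr(ref, est), calculate_sisdr(ref, 2.0 * est))   # True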


class StatisticsContainer(object):
    def __init__(self, statistics_path):
        self.statistics_path = statistics_path

        self.backup_statistics_path = "{}_{}.pkl".format(
            os.path.splitext(self.statistics_path)[0],
            datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
        )

        self.statistics_dict = {"balanced_train": [], "test": []}

    def append(self, steps, statistics, split, flush=True):
        statistics["steps"] = steps
        self.statistics_dict[split].append(statistics)

        if flush:
            self.flush()

    def flush(self):
        with open(self.statistics_path, "wb") as f:
            pickle.dump(self.statistics_dict, f)
        with open(self.backup_statistics_path, "wb") as f:
            pickle.dump(self.statistics_dict, f)
        logging.info(" Dump statistics to {}".format(self.statistics_path))
        logging.info(" Dump statistics to {}".format(self.backup_statistics_path))


def get_mean_sdr_from_dict(sdris_dict):
    mean_sdr = np.nanmean(list(sdris_dict.values()))
    return mean_sdr


def remove_silence(audio: np.ndarray, sample_rate: int) -> np.ndarray:
    r"""Remove silent frames."""
    window_size = int(sample_rate * 0.1)
    threshold = 0.02

    # Split the audio into non-overlapping 100 ms frames, shape: (frames_num, window_size).
    frames = librosa.util.frame(x=audio, frame_length=window_size, hop_length=window_size).T

    new_frames = get_active_frames(frames, threshold)
    new_audio = new_frames.flatten()

    return new_audio


def get_active_frames(frames: np.ndarray, threshold: float) -> np.ndarray:
    r"""Get active frames."""
    # A frame is active if its peak absolute amplitude exceeds the threshold.
    energy = np.max(np.abs(frames), axis=-1)
    active_indexes = np.where(energy > threshold)[0]
    new_frames = frames[active_indexes]

    return new_frames
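
# Usage sketch for remove_silence (illustrative): one second of silence followed by one
# second of noise; the silent frames are dropped, leaving roughly one second of audio.
# >>> sr = 32000
# >>> audio = np.concatenate([np.zeros(sr), 0.5 * np.random.randn(sr)])
# >>> remove_silence(audio, sample_rate=sr).shape   # (32000,)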


def repeat_to_length(audio: np.ndarray, segment_samples: int) -> np.ndarray:
    r"""Repeat audio to length."""
    repeats_num = (segment_samples // audio.shape[-1]) + 1
    audio = np.tile(audio, repeats_num)[0 : segment_samples]

    return audio
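
# Usage sketch (illustrative): tile a short clip until it covers the requested length.
# >>> repeat_to_length(np.array([1.0, 2.0, 3.0]), segment_samples=7)
# array([1., 2., 3., 1., 2., 3., 1.])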


def calculate_segmentwise_sdr(ref, est, hop_samples, return_sdr_list=False):
    # Compute SDR over consecutive non-overlapping segments of (channels, samples)
    # arrays and summarize with the median, ignoring NaN segments.
    min_len = min(ref.shape[-1], est.shape[-1])
    pointer = 0
    sdrs = []

    while pointer + hop_samples < min_len:
        sdr = calculate_sdr(
            ref=ref[:, pointer : pointer + hop_samples],
            est=est[:, pointer : pointer + hop_samples],
        )
        sdrs.append(sdr)
        pointer += hop_samples

    sdr = np.nanmedian(sdrs)

    if return_sdr_list:
        return sdr, sdrs
    else:
        return sdr
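
# Usage sketch (illustrative, assuming 32 kHz audio): per-second SDRs and their median
# for a pair of (channels, samples) arrays.
# >>> ref = np.random.randn(1, 32000 * 10)
# >>> est = ref + 0.1 * np.random.randn(1, 32000 * 10)
# >>> median_sdr, sdr_list = calculate_segmentwise_sdr(ref, est, hop_samples=32000, return_sdr_list=True)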


def loudness(data, input_loudness, target_loudness):
    """Loudness normalize a signal.

    Normalize an input signal to a target loudness in dB LUFS.

    Params
    ------
    data : torch.Tensor
        Input multichannel audio data.
    input_loudness : float
        Loudness of the input in dB LUFS.
    target_loudness : float
        Target loudness of the output in dB LUFS.

    Returns
    -------
    output : torch.Tensor
        Loudness normalized output data.
    """
    delta_loudness = target_loudness - input_loudness
    gain = torch.pow(10.0, delta_loudness / 20.0)

    output = gain * data

    return output
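
# Usage sketch (illustrative): the applied gain is 10 ** ((target - input) / 20), so
# going from -30 to -24 dB LUFS multiplies the waveform by 10 ** (6 / 20) ~ 2.0.
# >>> x = torch.randn(2, 32000)
# >>> y = loudness(x, input_loudness=torch.tensor(-30.0), target_loudness=torch.tensor(-24.0))
# >>> torch.allclose(y, x * 10 ** (6 / 20))   # expect True (up to float precision)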


def load_ss_model(
    configs: Dict,
    checkpoint_path: str,
    query_encoder: nn.Module
) -> nn.Module:
    r"""Load trained universal source separation model.

    Args:
        configs (Dict): model configuration
        checkpoint_path (str): path of the checkpoint to load
        query_encoder (nn.Module): query encoder module

    Returns:
        pl_model: pl.LightningModule
    """
    ss_model_type = configs["model"]["model_type"]
    input_channels = configs["model"]["input_channels"]
    output_channels = configs["model"]["output_channels"]
    condition_size = configs["model"]["condition_size"]

    # Build the separation backbone described by the config.
    SsModel = get_model_class(model_type=ss_model_type)
    ss_model = SsModel(
        input_channels=input_channels,
        output_channels=output_channels,
        condition_size=condition_size,
    )

    # Restore the Lightning module; training-only arguments are not needed here.
    pl_model = AudioSep.load_from_checkpoint(
        checkpoint_path=checkpoint_path,
        strict=False,
        ss_model=ss_model,
        waveform_mixer=None,
        query_encoder=query_encoder,
        loss_function=None,
        optimizer_type=None,
        learning_rate=None,
        lr_lambda_func=None,
        map_location=torch.device('cpu'),
    )

    return pl_model
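
# Usage sketch (paths are hypothetical): build a query encoder elsewhere, then restore
# a separation checkpoint described by its yaml config.
# >>> configs = parse_yaml('config/audiosep_base.yaml')
# >>> pl_model = load_ss_model(configs,
# ...                          checkpoint_path='checkpoint/audiosep_base.ckpt',
# ...                          query_encoder=query_encoder)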