import datetime
import json
import logging
import os
import pickle
from typing import Dict

import librosa
import numpy as np
import torch
import torch.nn as nn
import yaml

from models.audiosep import AudioSep, get_model_class


def ignore_warnings():
    import warnings

    # Suppress UserWarnings emitted from torch.functional.
    warnings.filterwarnings('ignore', category=UserWarning, module='torch.functional')

    # Suppress the transformers notice about unused RobertaModel lm_head weights.
    pattern = r"Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: \['lm_head\..*'\].*"
    warnings.filterwarnings('ignore', message=pattern)


def create_logging(log_dir, filemode):
    r"""Set up logging to an auto-numbered file in ``log_dir`` and to the console."""
    os.makedirs(log_dir, exist_ok=True)

    # Pick the first unused index so existing log files are never overwritten.
    i1 = 0
    while os.path.isfile(os.path.join(log_dir, "{:04d}.log".format(i1))):
        i1 += 1
    log_path = os.path.join(log_dir, "{:04d}.log".format(i1))

    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s",
        datefmt="%a, %d %b %Y %H:%M:%S",
        filename=log_path,
        filemode=filemode,
    )

    # Echo INFO and above to the console as well.
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s")
    console.setFormatter(formatter)
    logging.getLogger("").addHandler(console)

    return logging


def float32_to_int16(x: np.ndarray) -> np.ndarray:
    r"""Convert float32 audio in [-1, 1] to int16."""
    x = np.clip(x, a_min=-1, a_max=1)
    return (x * 32767.0).astype(np.int16)


def int16_to_float32(x: np.ndarray) -> np.ndarray:
    r"""Convert int16 audio to float32 in [-1, 1]."""
    return (x / 32767.0).astype(np.float32)
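

# Illustrative round-trip check (not part of the original API): quantizing to
# int16 and back is lossy only up to one quantization step of 1/32767.
def _demo_int16_roundtrip():
    x = np.array([-1.5, -0.5, 0.0, 0.5, 1.5], dtype=np.float32)
    y = int16_to_float32(float32_to_int16(x))
    # Out-of-range samples are clipped to [-1, 1]; in-range samples survive
    # within quantization error.
    assert np.allclose(y, np.clip(x, -1, 1), atol=1.0 / 32767.0)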


def parse_yaml(config_yaml: str) -> Dict:
    r"""Parse yaml file.

    Args:
        config_yaml (str): config yaml path

    Returns:
        yaml_dict (Dict): parsed yaml file
    """
    with open(config_yaml, "r") as fr:
        return yaml.load(fr, Loader=yaml.FullLoader)


def get_audioset632_id_to_lb(ontology_path: str) -> Dict:
    r"""Get the AudioSet 632-class ID-to-label mapping."""
    audioset632_id_to_lb = {}

    with open(ontology_path) as f:
        data_list = json.load(f)

    for e in data_list:
        audioset632_id_to_lb[e["id"]] = e["name"]

    return audioset632_id_to_lb


def load_pretrained_panns(
    model_type: str,
    checkpoint_path: str,
    freeze: bool,
) -> nn.Module:
    r"""Load pretrained audio neural networks (PANNs).

    Args:
        model_type (str): e.g., "Cnn14"
        checkpoint_path (str): e.g., "Cnn14_mAP=0.431.pth"
        freeze (bool): freeze the loaded parameters if True

    Returns:
        model: nn.Module
    """
    # NOTE: Cnn14 and Cnn14_DecisionLevelMax are not defined in this module.
    # They are assumed to be importable from a PANNs implementation, e.g.
    # `from panns_inference.models import Cnn14, Cnn14_DecisionLevelMax` or the
    # models from the audioset_tagging_cnn repository; adjust the import to
    # your project layout.
    if model_type == "Cnn14":
        Model = Cnn14
    elif model_type == "Cnn14_DecisionLevelMax":
        Model = Cnn14_DecisionLevelMax
    else:
        raise NotImplementedError(model_type)

    model = Model(sample_rate=32000, window_size=1024, hop_size=320,
        mel_bins=64, fmin=50, fmax=14000, classes_num=527)

    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        model.load_state_dict(checkpoint["model"])

    if freeze:
        for param in model.parameters():
            param.requires_grad = False

    return model
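

# Illustrative usage (assumptions: a PANNs implementation is importable and the
# checkpoint path below is hypothetical).
def _demo_load_pretrained_panns():
    model = load_pretrained_panns(
        model_type="Cnn14",
        checkpoint_path="Cnn14_mAP=0.431.pth",  # hypothetical checkpoint path
        freeze=True,
    )
    model.eval()
    # PANNs Cnn14 expects 32 kHz mono waveforms shaped (batch, samples).
    waveform = torch.zeros(1, 32000)
    with torch.no_grad():
        output = model(waveform)
    return output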


def energy(x):
    return torch.mean(x ** 2)


def magnitude_to_db(x):
    eps = 1e-10
    return 20. * np.log10(max(x, eps))


def db_to_magnitude(x):
    return 10. ** (x / 20)


def ids_to_hots(ids, classes_num, device):
    r"""Convert a list of class IDs to a multi-hot vector."""
    hots = torch.zeros(classes_num).to(device)
    for class_id in ids:
        hots[class_id] = 1
    return hots
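

# Illustrative example (not part of the original API): a multi-hot target
# vector for AudioSet-style multi-label classification.
def _demo_ids_to_hots():
    hots = ids_to_hots(ids=[0, 3], classes_num=5, device='cpu')
    # tensor([1., 0., 0., 1., 0.])
    return hots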


def calculate_sdr(
    ref: np.ndarray,
    est: np.ndarray,
    eps=1e-10,
) -> float:
    r"""Calculate SDR between reference and estimation.

    Args:
        ref (np.ndarray): reference signal
        est (np.ndarray): estimated signal
    """
    reference = ref
    noise = est - reference

    # Clip both powers away from zero to avoid log10 of 0.
    numerator = np.clip(a=np.mean(reference ** 2), a_min=eps, a_max=None)
    denominator = np.clip(a=np.mean(noise ** 2), a_min=eps, a_max=None)

    sdr = 10. * np.log10(numerator / denominator)

    return sdr
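

# Worked example (not part of the original API): an estimate equal to the
# reference plus noise at 1/10 of its RMS amplitude gives an SDR near 20 dB.
def _demo_calculate_sdr():
    rng = np.random.default_rng(0)
    ref = rng.standard_normal(32000)
    est = ref + 0.1 * rng.standard_normal(32000)
    sdr = calculate_sdr(ref=ref, est=est)
    # Noise power is ~0.01 of signal power, so 10 * log10(1 / 0.01) ~= 20 dB.
    return sdr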


def calculate_sisdr(ref, est):
    r"""Calculate SI-SDR (scale-invariant SDR) between reference and estimation.

    Args:
        ref (np.ndarray): reference signal
        est (np.ndarray): estimated signal
    """
    eps = np.finfo(ref.dtype).eps

    reference = ref.copy()
    estimate = est.copy()

    reference = reference.reshape(reference.size, 1)
    estimate = estimate.reshape(estimate.size, 1)

    Rss = np.dot(reference.T, reference)

    # Optimal scaling of the reference onto the estimate.
    a = (eps + np.dot(reference.T, estimate)) / (Rss + eps)

    e_true = a * reference
    e_res = estimate - e_true

    Sss = (e_true ** 2).sum()
    Snn = (e_res ** 2).sum()

    sisdr = 10 * np.log10((eps + Sss) / (eps + Snn))

    return sisdr
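

# Illustrative example (not part of the original API): unlike plain SDR,
# SI-SDR is unchanged when the estimate is rescaled.
def _demo_sisdr_scale_invariance():
    rng = np.random.default_rng(0)
    ref = rng.standard_normal(32000)
    est = ref + 0.1 * rng.standard_normal(32000)
    a = calculate_sisdr(ref, est)
    b = calculate_sisdr(ref, 0.5 * est)  # rescaling does not change SI-SDR
    assert np.isclose(a, b)
    return a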


class StatisticsContainer(object):
    r"""Accumulate evaluation statistics and pickle them to disk."""

    def __init__(self, statistics_path):
        self.statistics_path = statistics_path

        # Time-stamped backup so an interrupted run never corrupts the only copy.
        self.backup_statistics_path = "{}_{}.pkl".format(
            os.path.splitext(self.statistics_path)[0],
            datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
        )

        self.statistics_dict = {"balanced_train": [], "test": []}

    def append(self, steps, statistics, split, flush=True):
        statistics["steps"] = steps
        self.statistics_dict[split].append(statistics)

        if flush:
            self.flush()

    def flush(self):
        with open(self.statistics_path, "wb") as f:
            pickle.dump(self.statistics_dict, f)
        with open(self.backup_statistics_path, "wb") as f:
            pickle.dump(self.statistics_dict, f)
        logging.info(" Dump statistics to {}".format(self.statistics_path))
        logging.info(" Dump statistics to {}".format(self.backup_statistics_path))
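

# Illustrative usage (hypothetical path): append per-split metrics during
# training and flush them to disk.
def _demo_statistics_container():
    container = StatisticsContainer("workspace/statistics.pkl")  # hypothetical path
    container.append(steps=1000, statistics={"sdr": 5.2}, split="test", flush=False)
    container.flush()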


def get_mean_sdr_from_dict(sdris_dict):
    mean_sdr = np.nanmean(list(sdris_dict.values()))
    return mean_sdr


def remove_silence(audio: np.ndarray, sample_rate: int) -> np.ndarray:
    r"""Remove silent frames (100 ms windows whose peak amplitude is below 0.02)."""
    window_size = int(sample_rate * 0.1)
    threshold = 0.02

    frames = librosa.util.frame(x=audio, frame_length=window_size, hop_length=window_size).T

    new_frames = get_active_frames(frames, threshold)

    new_audio = new_frames.flatten()

    return new_audio
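

# Illustrative example (not part of the original API): one second of silence
# followed by one second of tone; the silent second is dropped.
def _demo_remove_silence():
    sr = 32000
    t = np.arange(sr) / sr
    tone = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)
    audio = np.concatenate([np.zeros(sr, dtype=np.float32), tone])
    trimmed = remove_silence(audio, sample_rate=sr)
    # Roughly one second (the tone) remains.
    return trimmed.shape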


def get_active_frames(frames: np.ndarray, threshold: float) -> np.ndarray:
    r"""Keep frames whose peak absolute amplitude exceeds the threshold."""
    # Peak absolute amplitude of each frame.
    frame_peaks = np.max(np.abs(frames), axis=-1)

    active_indexes = np.where(frame_peaks > threshold)[0]

    new_frames = frames[active_indexes]

    return new_frames


def repeat_to_length(audio: np.ndarray, segment_samples: int) -> np.ndarray:
    r"""Tile audio until it is at least `segment_samples` long, then truncate."""
    repeats_num = (segment_samples // audio.shape[-1]) + 1
    audio = np.tile(audio, repeats_num)[0 : segment_samples]

    return audio
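

# Illustrative example (not part of the original API): a 3-sample clip tiled
# out to 8 samples.
def _demo_repeat_to_length():
    audio = np.array([1., 2., 3.])
    out = repeat_to_length(audio, segment_samples=8)
    # array([1., 2., 3., 1., 2., 3., 1., 2.])
    return out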


def calculate_segmentwise_sdr(ref, est, hop_samples, return_sdr_list=False):
    r"""Median SDR over non-overlapping segments of `hop_samples` samples.

    Args:
        ref (np.ndarray): reference signal, (channels_num, samples_num)
        est (np.ndarray): estimated signal, (channels_num, samples_num)
    """
    min_len = min(ref.shape[-1], est.shape[-1])
    pointer = 0
    sdrs = []

    while pointer + hop_samples < min_len:
        sdr = calculate_sdr(
            ref=ref[:, pointer : pointer + hop_samples],
            est=est[:, pointer : pointer + hop_samples],
        )
        sdrs.append(sdr)
        pointer += hop_samples

    sdr = np.nanmedian(sdrs)

    if return_sdr_list:
        return sdr, sdrs
    else:
        return sdr
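

# Illustrative example (not part of the original API): segment-wise SDR on a
# mono signal stored as (1, samples).
def _demo_segmentwise_sdr():
    rng = np.random.default_rng(0)
    ref = rng.standard_normal((1, 32000))
    est = ref + 0.1 * rng.standard_normal((1, 32000))
    median_sdr, sdrs = calculate_segmentwise_sdr(
        ref=ref, est=est, hop_samples=8000, return_sdr_list=True,
    )
    # Three full 8000-sample segments fit before pointer + hop reaches min_len.
    return median_sdr, len(sdrs)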


def loudness(data, input_loudness, target_loudness):
    """Loudness normalize a signal.

    Normalize an input signal to a target loudness in dB LUFS.

    Parameters
    ----------
    data : torch.Tensor
        Input multichannel audio data.
    input_loudness : float
        Loudness of the input in dB LUFS.
    target_loudness : float
        Target loudness of the output in dB LUFS.

    Returns
    -------
    output : torch.Tensor
        Loudness normalized output data.
    """
    # Apply the gain required to move from the input to the target loudness.
    delta_loudness = target_loudness - input_loudness
    gain = torch.pow(10.0, delta_loudness / 20.0)

    output = gain * data

    return output
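

# Illustrative example (not part of the original API): boost a signal measured
# at -30 LUFS up to -23 LUFS, i.e. a gain of 10^(7/20) ~= 2.24x. Measuring the
# input loudness itself would need a loudness meter such as pyloudnorm.
def _demo_loudness():
    data = torch.ones(2, 32000) * 0.1  # hypothetical 2-channel signal
    output = loudness(data, input_loudness=-30.0, target_loudness=-23.0)
    # Amplitude grows by ~2.24x.
    return output.abs().max()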


def load_ss_model(
    configs: Dict,
    checkpoint_path: str,
    query_encoder: nn.Module,
) -> nn.Module:
    r"""Load a trained universal source separation model.

    Args:
        configs (Dict): parsed config, see parse_yaml
        checkpoint_path (str): path of the checkpoint to load
        query_encoder (nn.Module): encoder for the separation query

    Returns:
        pl_model: pl.LightningModule
    """
    ss_model_type = configs["model"]["model_type"]
    input_channels = configs["model"]["input_channels"]
    output_channels = configs["model"]["output_channels"]
    condition_size = configs["model"]["condition_size"]

    # Build the separation backbone from the config.
    SsModel = get_model_class(model_type=ss_model_type)

    ss_model = SsModel(
        input_channels=input_channels,
        output_channels=output_channels,
        condition_size=condition_size,
    )

    # Restore the Lightning module; training-only arguments are None because
    # the model is only used for inference here.
    pl_model = AudioSep.load_from_checkpoint(
        checkpoint_path=checkpoint_path,
        strict=False,
        ss_model=ss_model,
        waveform_mixer=None,
        query_encoder=query_encoder,
        loss_function=None,
        optimizer_type=None,
        learning_rate=None,
        lr_lambda_func=None,
        map_location=torch.device('cpu'),
    )

    return pl_model
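

# Illustrative usage (hypothetical paths; assumes a query encoder as used
# elsewhere in this repository is passed in).
def _demo_load_ss_model(query_encoder: nn.Module):
    configs = parse_yaml("config/audiosep_base.yaml")  # hypothetical config path
    pl_model = load_ss_model(
        configs=configs,
        checkpoint_path="checkpoint/audiosep_base.ckpt",  # hypothetical checkpoint
        query_encoder=query_encoder,
    )
    pl_model.eval()
    return pl_model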