import datetime
import json
import logging
import os
import pickle
from typing import Dict

import librosa
import numpy as np
import torch
import torch.nn as nn
import yaml

from models.audiosep import AudioSep, get_model_class


def ignore_warnings():
    r"""Install warning filters for two known, noisy third-party warnings."""
    import warnings

    # Ignore UserWarning from torch.meshgrid
    warnings.filterwarnings('ignore', category=UserWarning, module='torch.functional')

    # Refined regex pattern to capture variations in the warning message
    pattern = r"Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: \['lm_head\..*'\].*"
    warnings.filterwarnings('ignore', message=pattern)


def create_logging(log_dir, filemode):
    r"""Configure root logging to a new numbered file plus the console.

    The log file is named "NNNN.log", where NNNN is the first zero-padded
    index for which no file exists yet in `log_dir`.

    Args:
        log_dir: directory holding the log files; created if missing
        filemode: file mode forwarded to `logging.basicConfig`

    Returns:
        the `logging` module itself
    """
    os.makedirs(log_dir, exist_ok=True)

    # Find the first unused log index.
    i1 = 0
    while os.path.isfile(os.path.join(log_dir, "{:04d}.log".format(i1))):
        i1 += 1

    log_path = os.path.join(log_dir, "{:04d}.log".format(i1))
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s",
        datefmt="%a, %d %b %Y %H:%M:%S",
        filename=log_path,
        filemode=filemode,
    )

    # Print to console
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s")
    console.setFormatter(formatter)
    logging.getLogger("").addHandler(console)

    return logging


def float32_to_int16(x: np.ndarray) -> np.ndarray:
    r"""Clip values to [-1, 1], scale by 32767 and cast to int16."""
    x = np.clip(x, a_min=-1, a_max=1)
    return (x * 32767.0).astype(np.int16)


def int16_to_float32(x: np.ndarray) -> np.ndarray:
    r"""Divide values by 32767 and cast to float32.

    The input must be a NumPy value: the result of the division needs an
    `astype` method.
    """
    return (x / 32767.0).astype(np.float32)


def parse_yaml(config_yaml: str) -> Dict:
    r"""Parse yaml file.

    Args:
        config_yaml (str): config yaml path

    Returns:
        yaml_dict (Dict): parsed yaml file
    """
    # NOTE(review): FullLoader is not meant for untrusted input; confirm that
    # config files always come from a trusted source.
    with open(config_yaml, "r") as fr:
        return yaml.load(fr, Loader=yaml.FullLoader)


def get_audioset632_id_to_lb(ontology_path: str) -> Dict:
    r"""Get AudioSet 632 classes ID to label mapping.

    Args:
        ontology_path (str): path of a JSON file holding a list of entries,
            each with an "id" and a "name" key

    Returns:
        audioset632_id_to_lb (Dict): mapping from entry "id" to entry "name"
    """
    with open(ontology_path) as f:
        data_list = json.load(f)

    return {e["id"]: e["name"] for e in data_list}


def load_pretrained_panns(
    model_type: str,
    checkpoint_path: str,
    freeze: bool
) -> nn.Module:
    r"""Load pretrained audio neural networks (PANNs).

    Args:
        model_type: str, e.g., "Cnn14"
        checkpoint_path, str, e.g., "Cnn14_mAP=0.431.pth"; a falsy value
            skips weight loading
        freeze: bool, if True all parameters get requires_grad = False

    Returns:
        model: nn.Module

    Raises:
        NotImplementedError: if `model_type` is not a supported name
    """
    # NOTE(review): Cnn14 and Cnn14_DecisionLevelMax are not imported in the
    # visible part of this module; confirm where they are meant to come from,
    # otherwise these branches raise NameError.
    if model_type == "Cnn14":
        Model = Cnn14

    elif model_type == "Cnn14_DecisionLevelMax":
        Model = Cnn14_DecisionLevelMax

    else:
        raise NotImplementedError

    model = Model(sample_rate=32000, window_size=1024, hop_size=320,
                  mel_bins=64, fmin=50, fmax=14000, classes_num=527)

    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])

    if freeze:
        for param in model.parameters():
            param.requires_grad = False

    return model


def energy(x):
    r"""Return the mean of the squared elements of a tensor."""
    return torch.mean(x ** 2)


def magnitude_to_db(x):
    r"""Convert a scalar magnitude to decibels, flooring the input at 1e-10."""
    eps = 1e-10
    return 20. * np.log10(max(x, eps))


def db_to_magnitude(x):
    r"""Convert decibels to magnitude: 10 ** (x / 20)."""
    return 10. ** (x / 20)


def ids_to_hots(ids, classes_num, device):
    r"""Build a multi-hot vector of length `classes_num`.

    Args:
        ids: iterable of indexes to set to 1
        classes_num: length of the output vector
        device: device the output tensor is moved to

    Returns:
        hots: 1-D tensor with ones at the given indexes, zeros elsewhere
    """
    hots = torch.zeros(classes_num).to(device)
    for class_id in ids:
        hots[class_id] = 1
    return hots


def calculate_sdr(
    ref: np.ndarray,
    est: np.ndarray,
    eps=1e-10
) -> float:
    r"""Calculate SDR between reference and estimation.

    Args:
        ref (np.ndarray), reference signal
        est (np.ndarray), estimated signal
        eps: lower bound applied to both the signal and the noise power

    Returns:
        sdr: 10 * log10(reference power / noise power), in dB
    """
    reference = ref
    noise = est - reference

    # Both powers are floored at eps to avoid log of zero and division by zero.
    numerator = np.clip(a=np.mean(reference ** 2), a_min=eps, a_max=None)
    denominator = np.clip(a=np.mean(noise ** 2), a_min=eps, a_max=None)

    sdr = 10. * np.log10(numerator / denominator)
    return sdr


def calculate_sisdr(ref, est):
    r"""Calculate scale-invariant SDR (SI-SDR) between reference and estimation.

    The estimate is split into a scaled copy of the reference and a residual;
    the result is the energy ratio of the two, in dB.

    Args:
        ref (np.ndarray), reference signal
        est (np.ndarray), estimated signal

    Returns:
        sisdr: SI-SDR in dB
    """
    eps = np.finfo(ref.dtype).eps

    # Work on flattened column-vector copies so the inputs are not modified.
    reference = ref.copy()
    estimate = est.copy()
    reference = reference.reshape(reference.size, 1)
    estimate = estimate.reshape(estimate.size, 1)

    Rss = np.dot(reference.T, reference)

    # get the scaling factor for clean sources
    a = (eps + np.dot(reference.T, estimate)) / (Rss + eps)

    e_true = a * reference
    e_res = estimate - e_true

    Sss = (e_true**2).sum()
    Snn = (e_res**2).sum()

    sisdr = 10 * np.log10((eps + Sss) / (eps + Snn))
    return sisdr


class StatisticsContainer(object):
    r"""Accumulate per-split statistics and pickle them to disk."""

    def __init__(self, statistics_path):
        r"""Set up the output paths and the empty statistics dictionary.

        Args:
            statistics_path: path of the main pickle file; a second,
                timestamped copy is written next to it
        """
        self.statistics_path = statistics_path

        self.backup_statistics_path = "{}_{}.pkl".format(
            os.path.splitext(self.statistics_path)[0],
            datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
        )

        self.statistics_dict = {"balanced_train": [], "test": []}

    def append(self, steps, statistics, split, flush=True):
        r"""Record one statistics entry for a split.

        Args:
            steps: step count, stored in `statistics` under the "steps" key
            statistics: dict of statistics; modified in place
            split: key of `statistics_dict` to append to
            flush: if True, write everything to disk right away
        """
        statistics["steps"] = steps
        self.statistics_dict[split].append(statistics)

        if flush:
            self.flush()

    def flush(self):
        r"""Pickle the statistics to both the main and the backup path."""
        # Context managers make sure both files are closed (and therefore
        # fully written) before the dump is reported as done.
        with open(self.statistics_path, "wb") as f:
            pickle.dump(self.statistics_dict, f)

        with open(self.backup_statistics_path, "wb") as f:
            pickle.dump(self.statistics_dict, f)

        logging.info(" Dump statistics to {}".format(self.statistics_path))
        logging.info(" Dump statistics to {}".format(self.backup_statistics_path))


def get_mean_sdr_from_dict(sdris_dict):
    r"""Return the mean of the dictionary values, ignoring NaNs."""
    mean_sdr = np.nanmean(list(sdris_dict.values()))
    return mean_sdr


def remove_silence(audio: np.ndarray, sample_rate: int) -> np.ndarray:
    r"""Remove silent frames.

    The audio is cut into non-overlapping 0.1 s frames; frames whose peak
    absolute amplitude does not exceed 0.02 are dropped and the remaining
    frames are concatenated.

    Args:
        audio (np.ndarray): input waveform
        sample_rate (int): sample rate of `audio`

    Returns:
        new_audio (np.ndarray): flattened concatenation of the kept frames
    """
    window_size = int(sample_rate * 0.1)
    threshold = 0.02

    frames = librosa.util.frame(x=audio, frame_length=window_size, hop_length=window_size).T
    # shape: (frames_num, window_size)

    new_frames = get_active_frames(frames, threshold)
    # shape: (new_frames_num, window_size)

    new_audio = new_frames.flatten()
    # shape: (new_audio_samples,)

    return new_audio
def get_active_frames(frames: np.ndarray, threshold: float) -> np.ndarray:
    r"""Get active frames.

    A frame is active when its peak absolute value along the last axis is
    strictly greater than `threshold`.

    Args:
        frames (np.ndarray): frames, indexed by frame on the first axis
        threshold (float): peak-amplitude threshold

    Returns:
        new_frames (np.ndarray): the active frames, in their original order
    """
    energy = np.max(np.abs(frames), axis=-1)
    # shape: (frames_num,) for 2-D input

    active_indexes = np.where(energy > threshold)[0]
    # shape: (new_frames_num,)

    new_frames = frames[active_indexes]
    # shape: (new_frames_num, ...) -- trailing axes are unchanged

    return new_frames


def repeat_to_length(audio: np.ndarray, segment_samples: int) -> np.ndarray:
    r"""Repeat audio to length.

    The audio is tiled enough times to cover `segment_samples` and the
    result is then cut with `[0 : segment_samples]`.

    Args:
        audio (np.ndarray): input waveform
        segment_samples (int): requested output length

    Returns:
        audio (np.ndarray): tiled and truncated waveform

    Raises:
        ZeroDivisionError: if the last axis of `audio` is empty
    """
    repeats_num = (segment_samples // audio.shape[-1]) + 1
    audio = np.tile(audio, repeats_num)[0 : segment_samples]

    return audio


def calculate_segmentwise_sdr(ref, est, hop_samples, return_sdr_list=False):
    r"""Calculate the median SDR over consecutive non-overlapping segments.

    Args:
        ref: reference signal, indexed as [:, samples]
        est: estimated signal, indexed as [:, samples]
        hop_samples: length of each segment in samples
        return_sdr_list: if True, also return the per-segment SDR values

    Returns:
        sdr, or (sdr, sdrs) when `return_sdr_list` is True, where sdr is the
        NaN-ignoring median of the per-segment values
    """
    min_len = min(ref.shape[-1], est.shape[-1])

    pointer = 0
    sdrs = []

    # The strict "<" means a segment ending exactly at min_len is not scored.
    while pointer + hop_samples < min_len:
        sdr = calculate_sdr(
            ref=ref[:, pointer : pointer + hop_samples],
            est=est[:, pointer : pointer + hop_samples],
        )
        sdrs.append(sdr)
        pointer += hop_samples

    sdr = np.nanmedian(sdrs)

    if return_sdr_list:
        return sdr, sdrs
    else:
        return sdr


def loudness(data, input_loudness, target_loudness):
    """ Loudness normalize a signal.

    Normalize an input signal to a user loudness in dB LKFS.

    Params
    -------
    data : torch.Tensor
        Input multichannel audio data.
    input_loudness : float or torch.Tensor
        Loudness of the input in dB LUFS.
    target_loudness : float or torch.Tensor
        Target loudness of the output in dB LUFS.

    Returns
    -------
    output : torch.Tensor
        Loudness normalized output data.
    """
    # calculate the gain needed to scale to the desired loudness level
    delta_loudness = target_loudness - input_loudness

    # torch.pow has no overload taking two Python scalars, so plain float
    # loudness values go through Python exponentiation instead.
    if torch.is_tensor(delta_loudness):
        gain = torch.pow(10.0, delta_loudness / 20.0)
    else:
        gain = 10.0 ** (delta_loudness / 20.0)

    output = gain * data

    # check for potentially clipped samples
    # if torch.max(torch.abs(output)) >= 1.0:
    #     warnings.warn("Possible clipped samples in output.")

    return output


def load_ss_model(
    configs: Dict,
    checkpoint_path: str,
    query_encoder: nn.Module
) -> nn.Module:
    r"""Load trained universal source separation model.

    The checkpoint is loaded with `strict=False` and mapped to the CPU.

    Args:
        configs (Dict): must provide configs["model"] with the keys
            "model_type", "input_channels", "output_channels" and
            "condition_size"
        checkpoint_path (str): path of the checkpoint to load
        query_encoder (nn.Module): query encoder passed on to AudioSep

    Returns:
        pl_model: pl.LightningModule
    """
    ss_model_type = configs["model"]["model_type"]
    input_channels = configs["model"]["input_channels"]
    output_channels = configs["model"]["output_channels"]
    condition_size = configs["model"]["condition_size"]

    # Initialize separation model
    SsModel = get_model_class(model_type=ss_model_type)

    ss_model = SsModel(
        input_channels=input_channels,
        output_channels=output_channels,
        condition_size=condition_size,
    )

    # Load PyTorch Lightning model
    pl_model = AudioSep.load_from_checkpoint(
        checkpoint_path=checkpoint_path,
        strict=False,
        ss_model=ss_model,
        waveform_mixer=None,
        query_encoder=query_encoder,
        loss_function=None,
        optimizer_type=None,
        learning_rate=None,
        lr_lambda_func=None,
        map_location=torch.device('cpu'),
    )

    return pl_model


# NOTE(review): this repeats the parse_yaml defined earlier in this module;
# being later in the file, this is the definition that is in effect.
def parse_yaml(config_yaml: str) -> Dict:
    r"""Parse yaml file.

    Args:
        config_yaml (str): config yaml path

    Returns:
        yaml_dict (Dict): parsed yaml file
    """
    with open(config_yaml, "r") as fr:
        return yaml.load(fr, Loader=yaml.FullLoader)