Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	| """ | |
| Implementation of differentiable mastering effects based on DASP-pytorch and torchcomp libraries | |
| - Distortion | |
| - Multiband Compressor | |
| - Limiter | |
| DASP-pytorch: https://github.com/csteinmetz1/dasp-pytorch | |
| torchcomp: https://github.com/yoyololicon/torchcomp | |
| """ | |
| import dasp_pytorch | |
| from dasp_pytorch.modules import Processor | |
| import torchcomp | |
| import torch | |
| import torch.nn.functional as F | |
| import torch.nn as nn | |
| import numpy as np | |
| import time | |
| EPS = 1e-6 | |
class Distortion(Processor):
    """Processor front-end for the module-level ``distortion`` function.

    Binds ``distortion`` as ``process_fn`` and declares the (min, max) range
    of each of its two effect parameters in ``param_ranges``.
    """

    def __init__(
        self,
        sample_rate: int,
        min_gain_db: float = 0.0,
        max_gain_db: float = 24.0,
    ):
        """Record the sample rate and the parameter ranges.

        Args:
            sample_rate (int): Audio sample rate.
            min_gain_db (float): Lower bound of the ``drive_db`` range.
            max_gain_db (float): Upper bound of the ``drive_db`` range.
        """
        super().__init__()
        self.sample_rate = sample_rate
        self.process_fn = distortion

        # Keys are inserted in the same order as the keyword parameters of
        # `distortion` (drive first, then the dry/wet weight).
        ranges = {}
        ranges["drive_db"] = (min_gain_db, max_gain_db)
        ranges["parallel_weight_factor"] = (0.2, 0.7)
        self.param_ranges = ranges
        self.num_params = len(ranges)
def distortion(x: torch.Tensor,
               sample_rate: int,
               drive_db: torch.Tensor,
               parallel_weight_factor: torch.Tensor):
    """Simple soft-clipping (tanh) distortion with drive control and dry/wet blend.

    Args:
        x (torch.Tensor): Input audio tensor with shape (bs, chs, seq_len)
        sample_rate (int): Audio sample rate. Not used by the computation; kept
            so the signature matches the other effect functions in this file.
        drive_db (torch.Tensor): Drive in dB with shape (bs)
        parallel_weight_factor (torch.Tensor): Weight of the distorted signal in
            the output mix; the input signal is weighted by
            ``1 - parallel_weight_factor``. Reshaped to (-1, 1, 1) for
            broadcasting.

    Returns:
        torch.Tensor: Output audio tensor with shape (bs, chs, seq_len)
    """
    # Unpacking also enforces that x is 3-dimensional.
    bs, _, _ = x.size()
    wet_weight = parallel_weight_factor.view(-1, 1, 1)

    # One drive value per batch item, shared across channels and time:
    # convert dB to a linear gain before the tanh soft clipper.
    linear_drive = 10 ** (drive_db.view(bs, 1, 1) / 20.0)
    x_dist = torch.tanh(x * linear_drive)

    # Parallel computation: blend distorted (wet) and original (dry) signals.
    return wet_weight * x_dist + (1 - wet_weight) * x
class Multiband_Compressor(Processor):
    """Processor front-end for the module-level ``multiband_compressor`` function.

    Declares, in ``param_ranges``, the (min, max) range of the two crossover
    cutoffs, the dry/wet weight, and six dynamics parameters for each of the
    three bands (``low_shelf``, ``mid_band``, ``high_shelf``).

    NOTE(review): the ``*_attack_ms_comp`` and ``*_release_ms_comp`` arguments
    are accepted but never read; the per-band ``at`` / ``rt`` ranges are taken
    from the ``*_attack_ms_exp`` / ``*_release_ms_exp`` arguments.
    """

    def __init__(
        self,
        sample_rate: int,
        min_threshold_db_comp: float = -60.0,
        max_threshold_db_comp: float = 0.0-EPS,
        min_ratio_comp: float = 1.0+EPS,
        max_ratio_comp: float = 20.0,
        min_attack_ms_comp: float = 5.0,
        max_attack_ms_comp: float = 100.0,
        min_release_ms_comp: float = 5.0,
        max_release_ms_comp: float = 100.0,
        min_threshold_db_exp: float = -60.0,
        max_threshold_db_exp: float = 0.0-EPS,
        min_ratio_exp: float = 0.0+EPS,
        max_ratio_exp: float = 1.0-EPS,
        min_attack_ms_exp: float = 5.0,
        max_attack_ms_exp: float = 100.0,
        min_release_ms_exp: float = 5.0,
        max_release_ms_exp: float = 100.0,
    ):
        """Record the sample rate and build the ordered parameter-range table."""
        super().__init__()
        self.sample_rate = sample_rate
        self.process_fn = multiband_compressor

        # Every band shares the same six ranges; only the key prefix differs.
        per_band_ranges = (
            ("comp_thresh", (min_threshold_db_comp, max_threshold_db_comp)),
            ("comp_ratio", (min_ratio_comp, max_ratio_comp)),
            ("exp_thresh", (min_threshold_db_exp, max_threshold_db_exp)),
            ("exp_ratio", (min_ratio_exp, max_ratio_exp)),
            ("at", (min_attack_ms_exp, max_attack_ms_exp)),
            ("rt", (min_release_ms_exp, max_release_ms_exp)),
        )

        # Global parameters come first, then the bands in low / mid / high
        # order, matching the keyword order of `multiband_compressor`.
        ranges = {
            "low_cutoff": (20, 300),
            "high_cutoff": (2000, 12000),
            "parallel_weight_factor": (0.2, 0.7),
        }
        for band in ("low_shelf", "mid_band", "high_shelf"):
            for suffix, bounds in per_band_ranges:
                ranges[f"{band}_{suffix}"] = bounds

        self.param_ranges = ranges
        self.num_params = len(ranges)
def linkwitz_riley_4th_order(
        x: torch.Tensor,
        cutoff_freq: torch.Tensor,
        sample_rate: float,
        filter_type: str):
    """Filter ``x`` by applying the same biquad twice in series.

    The biquad is designed by ``dasp_pytorch.signal.biquad`` with 0 dB gain and
    a Q factor of 1/sqrt(2), then run twice through
    ``dasp_pytorch.signal.sosfilt_via_fsm``.

    Args:
        x (torch.Tensor): Signal to filter. Presumably (bs, chs, seq_len), as
            in the callers in this file -- confirm against dasp_pytorch.
        cutoff_freq (torch.Tensor): Cutoff frequency; the Q and gain tensors are
            created with this tensor's shape.
        sample_rate (float): Audio sample rate.
        filter_type (str): Filter type forwarded to ``dasp_pytorch.signal.biquad``
            (this file uses "low_pass" and "high_pass").

    Returns:
        torch.Tensor: The signal after both filter passes.
    """
    target_device = x.device

    # Allocate the design parameters directly on the signal's device instead
    # of building them on the CPU and copying them over on every call.
    q_factor = (
        torch.ones(cutoff_freq.shape, device=target_device)
        / torch.sqrt(torch.tensor([2.0], device=target_device))
    )
    gain_db = torch.zeros(cutoff_freq.shape, device=target_device)

    b, a = dasp_pytorch.signal.biquad(
        gain_db,
        cutoff_freq,
        q_factor,
        sample_rate,
        filter_type
    )

    # Numerator and denominator coefficients side by side form one
    # second-order section; unsqueeze adds the section axis.
    sos = torch.cat((b, a), dim=-1).unsqueeze(1)

    # Cascade the same section twice (second-order -> fourth-order response).
    x = dasp_pytorch.signal.sosfilt_via_fsm(sos, x)
    return dasp_pytorch.signal.sosfilt_via_fsm(sos, x)
def _compress_band(band, label, sample_rate,
                   comp_thresh, comp_ratio, exp_thresh, exp_ratio,
                   attack_ms, release_ms):
    """Apply torchcomp's compressor/expander gain to one band, best-effort.

    The gain is computed from the absolute value of the channel sum and is
    applied identically to every channel. If anything in the computation
    raises, the band is returned unprocessed and a message is printed.
    """
    try:
        gain = torchcomp.compexp_gain(
            band.sum(axis=1).abs(),
            comp_thresh=comp_thresh,
            comp_ratio=comp_ratio,
            exp_thresh=exp_thresh,
            exp_ratio=exp_ratio,
            at=torchcomp.ms2coef(attack_ms, sample_rate),
            rt=torchcomp.ms2coef(release_ms, sample_rate),
        )
        return band * gain.unsqueeze(1)
    except Exception:
        # Deliberate best-effort fallback. `Exception` (rather than a bare
        # `except`) lets KeyboardInterrupt / SystemExit propagate.
        # NOTE(review): torchcomp may reject inputs such as an all-zero
        # envelope -- confirm which failures this fallback is meant to cover.
        print(f'\t!!!failed computing {label}-band compression!!!')
        return band


def multiband_compressor(
        x: torch.Tensor,
        sample_rate: float,
        low_cutoff: torch.Tensor,
        high_cutoff: torch.Tensor,
        parallel_weight_factor: torch.Tensor,
        low_shelf_comp_thresh: torch.Tensor,
        low_shelf_comp_ratio: torch.Tensor,
        low_shelf_exp_thresh: torch.Tensor,
        low_shelf_exp_ratio: torch.Tensor,
        low_shelf_at: torch.Tensor,
        low_shelf_rt: torch.Tensor,
        mid_band_comp_thresh: torch.Tensor,
        mid_band_comp_ratio: torch.Tensor,
        mid_band_exp_thresh: torch.Tensor,
        mid_band_exp_ratio: torch.Tensor,
        mid_band_at: torch.Tensor,
        mid_band_rt: torch.Tensor,
        high_shelf_comp_thresh: torch.Tensor,
        high_shelf_comp_ratio: torch.Tensor,
        high_shelf_exp_thresh: torch.Tensor,
        high_shelf_exp_ratio: torch.Tensor,
        high_shelf_at: torch.Tensor,
        high_shelf_rt: torch.Tensor,
):
    """Multiband (Three-band) Compressor.

    Low-shelf -> Mid-band -> High-shelf

    The input is split into a low band (low-pass at ``low_cutoff``), a high
    band (high-pass at ``high_cutoff``) and a mid band (input minus the other
    two). Each band gets its own compressor/expander gain, the bands are
    summed, and the result is blended with the unprocessed input.

    Args:
        x (torch.Tensor): Time domain tensor with shape (bs, chs, seq_len)
        sample_rate (float): Audio sample rate.
        low_cutoff (torch.Tensor): Low-shelf filter cutoff frequency in Hz.
        high_cutoff (torch.Tensor): High-shelf filter cutoff frequency in Hz.
        parallel_weight_factor (torch.Tensor): Weight of the compressed signal
            in the output; the input is weighted by ``1 - parallel_weight_factor``.
        low_shelf_comp_thresh (torch.Tensor): Low-band compressor threshold.
        low_shelf_comp_ratio (torch.Tensor): Low-band compressor ratio.
        low_shelf_exp_thresh (torch.Tensor): Low-band expander threshold.
        low_shelf_exp_ratio (torch.Tensor): Low-band expander ratio.
        low_shelf_at (torch.Tensor): Low-band attack time in ms.
        low_shelf_rt (torch.Tensor): Low-band release time in ms.
        mid_band_comp_thresh (torch.Tensor): Mid-band compressor threshold.
        mid_band_comp_ratio (torch.Tensor): Mid-band compressor ratio.
        mid_band_exp_thresh (torch.Tensor): Mid-band expander threshold.
        mid_band_exp_ratio (torch.Tensor): Mid-band expander ratio.
        mid_band_at (torch.Tensor): Mid-band attack time in ms.
        mid_band_rt (torch.Tensor): Mid-band release time in ms.
        high_shelf_comp_thresh (torch.Tensor): High-band compressor threshold.
        high_shelf_comp_ratio (torch.Tensor): High-band compressor ratio.
        high_shelf_exp_thresh (torch.Tensor): High-band expander threshold.
        high_shelf_exp_ratio (torch.Tensor): High-band expander ratio.
        high_shelf_at (torch.Tensor): High-band attack time in ms.
        high_shelf_rt (torch.Tensor): High-band release time in ms.

    Returns:
        y (torch.Tensor): Processed signal with shape (bs, chs, seq_len).
    """
    bs, chs, seq_len = x.size()

    # Reshape the per-item scalars so they broadcast over channels and time.
    low_cutoff = low_cutoff.view(-1, 1, 1)
    high_cutoff = high_cutoff.view(-1, 1, 1)
    parallel_weight_factor = parallel_weight_factor.view(-1, 1, 1)

    # --- crossover filter ---
    low_band = linkwitz_riley_4th_order(x, low_cutoff, sample_rate, filter_type="low_pass")
    high_band = linkwitz_riley_4th_order(x, high_cutoff, sample_rate, filter_type="high_pass")
    # The mid band is whatever remains after removing the low and high bands.
    mid_band = x - low_band - high_band

    # --- per-band compressor / expander ---
    x_out_low = _compress_band(
        low_band, 'low', sample_rate,
        low_shelf_comp_thresh, low_shelf_comp_ratio,
        low_shelf_exp_thresh, low_shelf_exp_ratio,
        low_shelf_at, low_shelf_rt,
    )
    x_out_high = _compress_band(
        high_band, 'high', sample_rate,
        high_shelf_comp_thresh, high_shelf_comp_ratio,
        high_shelf_exp_thresh, high_shelf_exp_ratio,
        high_shelf_at, high_shelf_rt,
    )
    x_out_mid = _compress_band(
        mid_band, 'mid', sample_rate,
        mid_band_comp_thresh, mid_band_comp_ratio,
        mid_band_exp_thresh, mid_band_exp_ratio,
        mid_band_at, mid_band_rt,
    )
    x_out = x_out_low + x_out_high + x_out_mid

    # Parallel computation: blend compressed (wet) and original (dry) signals.
    x_out = parallel_weight_factor * x_out + (1 - parallel_weight_factor) * x

    # Fails loudly if broadcasting changed the shape away from the input's.
    return x_out.view(bs, chs, seq_len)
class Limiter(Processor):
    """Processor front-end for the module-level ``limiter`` function.

    Binds ``limiter`` as ``process_fn`` and declares the (min, max) range of
    its threshold, attack and release parameters in ``param_ranges``.
    """

    def __init__(
        self,
        sample_rate: int,
        min_threshold_db: float = -60.0,
        max_threshold_db: float = 0.0-EPS,
        min_attack_ms: float = 5.0,
        max_attack_ms: float = 100.0,
        min_release_ms: float = 5.0,
        max_release_ms: float = 100.0,
    ):
        """Record the sample rate and the parameter ranges.

        Args:
            sample_rate (int): Audio sample rate.
            min_threshold_db (float): Lower bound of the ``threshold`` range.
            max_threshold_db (float): Upper bound of the ``threshold`` range.
            min_attack_ms (float): Lower bound of the ``at`` range.
            max_attack_ms (float): Upper bound of the ``at`` range.
            min_release_ms (float): Lower bound of the ``rt`` range.
            max_release_ms (float): Upper bound of the ``rt`` range.
        """
        super().__init__()
        self.sample_rate = sample_rate
        self.process_fn = limiter

        # Names and bounds are paired in the keyword order of `limiter`.
        names = ("threshold", "at", "rt")
        bounds = (
            (min_threshold_db, max_threshold_db),
            (min_attack_ms, max_attack_ms),
            (min_release_ms, max_release_ms),
        )
        self.param_ranges = dict(zip(names, bounds))
        self.num_params = len(self.param_ranges)
def limiter(
        x: torch.Tensor,
        sample_rate: float,
        threshold: torch.Tensor,
        at: torch.Tensor,
        rt: torch.Tensor,
):
    """Limiter.

    from Chin-yun's paper

    The gain is computed by ``torchcomp.limiter_gain`` from the absolute value
    of the channel sum and applied identically to every channel.

    Args:
        x (torch.Tensor): Time domain tensor with shape (bs, chs, seq_len)
        sample_rate (float): Audio sample rate.
        threshold (torch.Tensor): Limiter threshold in dB.
        at (torch.Tensor): Attack time in ms (converted with ``torchcomp.ms2coef``).
        rt (torch.Tensor): Release time in ms (converted with ``torchcomp.ms2coef``).

    Returns:
        y (torch.Tensor): Limited signal with shape (bs, chs, seq_len).
    """
    bs, chs, seq_len = x.size()

    # One level signal per batch item: channels are summed before rectifying.
    level = x.sum(axis=1).abs()
    gain = torchcomp.limiter_gain(
        level,
        threshold=threshold,
        at=torchcomp.ms2coef(at, sample_rate),
        rt=torchcomp.ms2coef(rt, sample_rate),
    )

    # Re-insert the channel axis so the same gain scales every channel.
    x_out = x * gain.unsqueeze(1)

    # Fails loudly if broadcasting changed the shape away from the input's.
    return x_out.view(bs, chs, seq_len)
class Random_Augmentation_Dasp(nn.Module):
    """Random chain of audio effects driven by normalized random parameters.

    Effects are applied in the order given by ``tgt_fx_names``. Every effect is
    always computed; a per-item boolean mask then decides whether the processed
    or the unprocessed signal is carried forward.

    Attributes set in ``__init__``:
        sample_rate: Sample rate forwarded to every effect.
        tgt_fx_names: Ordered list of effect names in the chain.
        device: CUDA device if available, else CPU; used as the default device
            for masks built by ``random_mask_generator``.
        fx_prob: Probability, per effect name, that its output is kept.
        fx_processors: Effect name -> processor object ('imager' has none; it
            is applied through ``dasp_pytorch.functional.stereo_widener``).
        total_num_param: Number of parameter columns consumed by ``forward``.
    """

    def __init__(self, sample_rate, tgt_fx_names=None):
        """Build one processor per requested effect.

        Args:
            sample_rate: Audio sample rate passed to every processor.
            tgt_fx_names: Iterable of effect names; defaults to
                ``['eq', 'comp', 'imager', 'gain']``. Supported names are 'eq',
                'distortion', 'comp', 'multiband_comp', 'gain', 'imager' and
                'limiter'.

        Raises:
            AssertionError: If a name in ``tgt_fx_names`` is not supported.
        """
        super().__init__()
        # A None sentinel avoids a shared mutable default; copying keeps later
        # mutation of the caller's list from desynchronising the chain.
        if tgt_fx_names is None:
            tgt_fx_names = ['eq', 'comp', 'imager', 'gain']
        self.sample_rate = sample_rate
        self.tgt_fx_names = list(tgt_fx_names)

        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        self.fx_prob = {
            'eq': 0.9,
            'distortion': 0.3,
            'comp': 0.8,
            'multiband_comp': 0.8,
            'gain': 0.85,
            'imager': 0.6,
            'limiter': 1.0,
        }

        self.fx_processors = {}
        for cur_fx in self.tgt_fx_names:
            cur_fx_module = self._build_processor(cur_fx)
            # 'imager' has no processor object.
            if cur_fx_module is not None:
                self.fx_processors[cur_fx] = cur_fx_module

        # Count once per chain entry (not once per distinct name) so the total
        # matches what `forward` consumes even if a name is listed twice.
        self.total_num_param = sum(
            self._param_count(cur_fx) for cur_fx in self.tgt_fx_names
        )

    def _build_processor(self, cur_fx):
        """Return the processor for ``cur_fx``, or None for 'imager'."""
        sample_rate = self.sample_rate
        if cur_fx == 'eq':
            return dasp_pytorch.ParametricEQ(
                sample_rate=sample_rate,
                min_gain_db=-10.0,
                max_gain_db=10.0,
                min_q_factor=0.5,
                max_q_factor=5.0,
            )
        if cur_fx == 'distortion':
            return Distortion(
                sample_rate=sample_rate,
                min_gain_db=0.0,
                max_gain_db=4.0,
            )
        if cur_fx == 'comp':
            return dasp_pytorch.Compressor(sample_rate=sample_rate)
        if cur_fx == 'multiband_comp':
            return Multiband_Compressor(
                sample_rate=sample_rate,
                min_threshold_db_comp=-30.0,
                max_threshold_db_comp=-5.0,
                min_ratio_comp=1.5,
                max_ratio_comp=6.0,
                min_attack_ms_comp=1.0,
                max_attack_ms_comp=20.0,
                min_release_ms_comp=20.0,
                max_release_ms_comp=500.0,
                min_threshold_db_exp=-30.0,
                max_threshold_db_exp=-5.0,
                min_ratio_exp=0.0+EPS,
                max_ratio_exp=1.0-EPS,
                min_attack_ms_exp=1.0,
                max_attack_ms_exp=20.0,
                min_release_ms_exp=20.0,
                max_release_ms_exp=500.0,
            )
        if cur_fx == 'gain':
            return dasp_pytorch.Gain(
                sample_rate=sample_rate,
                min_gain_db=0.0,
                max_gain_db=6.0,
            )
        if cur_fx == 'imager':
            return None
        if cur_fx == 'limiter':
            return Limiter(
                sample_rate=sample_rate,
                min_threshold_db=-20.0,
                max_threshold_db=0.0-EPS,
                min_attack_ms=0.1,
                max_attack_ms=5.0,
                min_release_ms=20.0,
                max_release_ms=1000.0,
            )
        raise AssertionError(f"current fx name ({cur_fx}) not found")

    def _param_count(self, cur_fx):
        """Number of parameter columns one chain entry consumes."""
        if cur_fx == 'imager':
            return 1
        return self.fx_processors[cur_fx].num_params

    # network forward operation
    def forward(self, x, rand_param=None, use_mask=None):
        """Run the effect chain on ``x``.

        Args:
            x: Input audio; the first dimension is the batch size.
            rand_param: Optional tensor of shape (batch, total_num_param) whose
                columns are consumed effect by effect in chain order. Sampled
                uniformly in [0, 1) when omitted.
            use_mask: Optional dict mapping effect name to a boolean tensor
                selecting, per item, whether the processed output is kept.
                Generated from ``fx_prob`` when omitted.

        Returns:
            The processed audio.

        Raises:
            AssertionError: If ``rand_param`` does not have shape
                (batch, total_num_param).
        """
        batch_size = x.shape[0]
        if rand_param is None:
            # Sample on the CPU and move to the input's device, so the result
            # always lives next to `x` (not on a globally chosen device).
            rand_param = torch.rand((batch_size, self.total_num_param)).to(x.device)
        elif rand_param.shape[0] != batch_size or rand_param.shape[1] != self.total_num_param:
            raise AssertionError(
                f"rand_param must have shape ({batch_size}, {self.total_num_param}), "
                f"got {tuple(rand_param.shape)}"
            )
        if use_mask is None:
            use_mask = self.random_mask_generator(batch_size, device=x.device)

        # dafx chain
        cur_param_idx = 0
        for cur_fx in self.tgt_fx_names:
            cur_param_count = self._param_count(cur_fx)
            cur_input_param = rand_param[:, cur_param_idx:cur_param_idx + cur_param_count]
            if cur_fx == 'imager':
                x_processed = dasp_pytorch.functional.stereo_widener(
                    x,
                    sample_rate=self.sample_rate,
                    width=cur_input_param,
                )
            else:
                x_processed = self.fx_processors[cur_fx].process_normalized(x, cur_input_param)

            # process all FX but decide to use the processed output based on probability
            cur_mask = use_mask[cur_fx]
            x = x_processed * cur_mask + x * ~cur_mask

            # update param index
            cur_param_idx += cur_param_count

        return x

    def random_mask_generator(self, batch_size, repeat=1, device=None):
        """Draw one boolean keep/skip mask per effect in the chain.

        Args:
            batch_size: Number of independent draws per effect.
            repeat: If greater than 1, the drawn mask is tiled this many times
                along the first dimension.
            device: Device for the returned masks; defaults to ``self.device``.

        Returns:
            Dict mapping effect name to a boolean tensor of shape
            (batch_size * repeat, 1, 1); an entry is True with probability
            ``fx_prob[name]``.
        """
        target_device = self.device if device is None else device
        mask = {}
        for cur_fx in self.tgt_fx_names:
            cur_mask = self.fx_prob[cur_fx] > torch.rand(batch_size).view(-1, 1, 1)
            if repeat > 1:
                cur_mask = cur_mask.repeat(repeat, 1, 1)
            mask[cur_fx] = cur_mask.to(target_device)
        return mask