# Spaces:
# Running
# Running
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).

# ## Citations

# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic },
#   organization={IEEE}
# }
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```
#
import sys | |
import random | |
import math | |
import torchaudio | |
import torch | |
# Select the "sox_io" backend for torchaudio file I/O at import time.
# NOTE(review): newer torchaudio releases appear to deprecate/remove
# `set_audio_backend` - TODO confirm. The guard keeps the original call
# whenever the API exists and avoids an AttributeError at import otherwise.
if hasattr(torchaudio, "set_audio_backend"):
    torchaudio.set_audio_backend("sox_io")
def db2amp(db):
    """Convert a decibel value to a linear amplitude ratio.

    Args:
        db: level in decibels; 0 dB maps to an amplitude of 1.0.

    Returns:
        The amplitude ``10 ** (db / 20)``.
    """
    return pow(10, db / 20)
def amp2db(amp):
    """Convert a linear amplitude ratio to decibels.

    Args:
        amp: a strictly positive amplitude.

    Returns:
        ``20 * log10(amp)``.

    Raises:
        ValueError: from ``math.log10`` when ``amp`` is zero or negative.
    """
    return 20 * math.log10(amp)
def make_poly_distortion(conf):
    """Generate a dB-domain polynomial distortion function.

    The returned function converts the sample magnitude to a normalized dB
    value ``d = amp2db(|x|) / 100 + 1`` (clamped below at 0), applies

        f(d) = a * d^m * (1 - d)^n + d

    clamps the result above at 1, converts it back to an amplitude (capped at
    0.9997) and restores the sign of the input sample.

    Args:
        conf: a dict {'a': #int, 'm': #int, 'n': #int}

    Returns:
        The polynomial function, which could be applied on
        a float amplitude value
    """
    a = conf["a"]
    m = conf["m"]
    n = conf["n"]

    def poly_distortion(x):
        abs_x = abs(x)
        # Near-silent samples are returned untouched; this also keeps
        # amp2db() away from log10(0).
        if abs_x < 0.000001:
            return x

        # Map the magnitude into the normalized dB domain.
        db_norm = amp2db(abs_x) / 100 + 1
        if db_norm < 0:
            db_norm = 0

        db_norm = a * pow(db_norm, m) * pow((1 - db_norm), n) + db_norm
        if db_norm > 1:
            db_norm = 1

        # Back to the amplitude domain, capped just below full scale.
        db = (db_norm - 1) * 100
        amp = db2amp(db)
        if amp >= 0.9997:
            amp = 0.9997

        # Restore the sign of the original sample.
        return amp if x > 0 else -amp

    return poly_distortion
def make_quad_distortion():
    """Generate the quadratic special case of the polynomial distortion.

    Returns:
        The function built by ``make_poly_distortion`` with
        ``a = 1, m = 1, n = 1``.
    """
    return make_poly_distortion({"a": 1, "m": 1, "n": 1})
# The amplitude of every non-zero sample point is set to the maximum.
def make_max_distortion(conf):
    """Generate a max distortion function.

    Every non-zero sample is replaced by the maximum amplitude carrying the
    sample's sign; zero samples map to 0.0.

    Args:
        conf: a dict {'max_db': float}
            'max_db': the maximum value in dB. A falsy value (``None`` or
                ``0``) selects the fallback amplitude 0.997.

    Returns:
        The max function, which could be applied on
        a float amplitude value
    """
    max_db = conf["max_db"]
    # Truthiness test on purpose: both None and 0 dB use the 0.997 fallback.
    if max_db:
        max_amp = db2amp(max_db)
    else:
        max_amp = 0.997

    def max_distortion(x):
        if x > 0:
            return max_amp
        if x < 0:
            return -max_amp
        return 0.0

    return max_distortion
def make_amp_mask(db_mask=None):
    """Get an amplitude-domain mask from a dB-domain mask.

    Args:
        db_mask: Optional. A list of ``(low_db, high_db)`` tuples. If None,
            the built-in default slots are used.

    Returns:
        A list of tuples: the same slots with both bounds converted to
        amplitudes via ``db2amp``.
    """
    if db_mask is None:
        db_mask = [(-110, -95), (-90, -80), (-65, -60), (-50, -30), (-15, 0)]
    return [(db2amp(slot[0]), db2amp(slot[1])) for slot in db_mask]
# Amplitude-domain mask built once at import time from the default dB slots;
# used as the positive-side mask when a distortion is configured with
# mask_number <= 0.
default_mask = make_amp_mask()
def generate_amp_mask(mask_num):
    """Generate an amplitude-domain mask randomly in [-100db, 0db].

    ``2 * mask_num`` increasing boundaries are drawn (each step is uniform in
    [0.5, 1]), rescaled so the first boundary is -100 dB and the last is
    0 dB, and paired up into ``mask_num`` slots.

    Args:
        mask_num: the slot number of the mask; must be positive.

    Returns:
        A list of tuples, each tuple defining one slot, converted to the
        amplitude domain by ``make_amp_mask``. In the dB domain the slots
        look like [(-100, -80), (-65, -60), (-50, -30), (-15, 0)]
        for mask_num = 4.

    Raises:
        ValueError: if ``mask_num`` is not positive.
    """
    if mask_num <= 0:
        raise ValueError("mask_num must be positive, got {}".format(mask_num))

    # Cumulative boundaries: the first is 0, each next one adds a random step.
    bounds = [0]
    for _ in range(2 * mask_num - 1):
        bounds.append(bounds[-1] + random.uniform(0.5, 1))
    max_val = bounds[-1]

    # Pair consecutive boundaries and rescale them into [-100, 0] dB.
    db_mask = [
        (
            ((bounds[2 * i] - max_val) / max_val) * 100,
            ((bounds[2 * i + 1] - max_val) / max_val) * 100,
        )
        for i in range(mask_num)
    ]
    return make_amp_mask(db_mask)
def make_fence_distortion(conf):
    """Generate a fence distortion function.

    In this fence-like shape function, the values in mask slots are
    set to the maximum, while the values not in mask slots are set to 0.
    Separate masks are used for positive and negative amplitudes.

    Args:
        conf: a dict {'mask_number': int, 'max_db': float}
            'mask_number': the slot number in mask. If <= 0, the default
                mask is used for positive samples and a single (-50 dB, 0 dB)
                slot for negative samples; otherwise both masks are
                generated randomly.
            'max_db': the maximum value, in dB.

    Returns:
        The fence function, which could be applied on
        a float amplitude value
    """
    mask_number = conf["mask_number"]
    max_db = conf["max_db"]
    max_amp = db2amp(max_db)
    if mask_number <= 0:
        positive_mask = default_mask
        negative_mask = make_amp_mask([(-50, 0)])
    else:
        positive_mask = generate_amp_mask(mask_number)
        negative_mask = generate_amp_mask(mask_number)

    def fence_distortion(x):
        if x > 0:
            mask = positive_mask
        elif x < 0:
            mask = negative_mask
        else:
            # Zero samples are returned unchanged.
            return x

        magnitude = abs(x)
        for low, high in mask:
            if low <= magnitude <= high:
                # NOTE(review): negative samples inside a slot also map to
                # +max_amp (the sign is dropped), unlike make_max_distortion.
                # Kept as in the original - confirm whether this is intended.
                return max_amp
        return 0.0

    return fence_distortion
# | |
def make_jag_distortion(conf):
    """Generate a jag distortion function.

    In this jag-like shape function, the values in mask slots are
    not changed, while the values not in mask slots are set to 0.
    Separate masks are used for positive and negative amplitudes.

    Args:
        conf: a dict {'mask_number': #int}
            'mask_number': the slot number in mask. If <= 0, the default
                mask is used for positive samples and a single (-50 dB, 0 dB)
                slot for negative samples; otherwise both masks are
                generated randomly.

    Returns:
        The jag function, which could be applied on
        a float amplitude value
    """
    mask_number = conf["mask_number"]
    if mask_number <= 0:
        positive_mask = default_mask
        negative_mask = make_amp_mask([(-50, 0)])
    else:
        positive_mask = generate_amp_mask(mask_number)
        negative_mask = generate_amp_mask(mask_number)

    def jag_distortion(x):
        if x > 0:
            mask = positive_mask
        elif x < 0:
            mask = negative_mask
        else:
            # Zero samples are returned unchanged.
            return x

        magnitude = abs(x)
        for low, high in mask:
            if low <= magnitude <= high:
                # Inside a slot: the sample passes through untouched.
                return x
        return 0.0

    return jag_distortion
# A gain of 20 dB means amp = amp * 10.
# A gain of -20 dB means amp = amp / 10.
def make_gain_db(conf):
    """Generate a dB-domain gain function.

    Args:
        conf: a dict {'db': #float}
            'db': the gain value in dB (20 dB multiplies the amplitude by
                10, -20 dB divides it by 10).

    Returns:
        The dB gain function, which could be applied on
        a float amplitude value. Its result is clipped to [-0.997, 0.997].
    """
    db = conf["db"]
    # The linear gain factor does not depend on the sample; compute it once.
    gain = pow(10, db / 20)

    def gain_db(x):
        # Clip symmetrically: the previous one-sided min() let boosted
        # negative samples go below -0.997.
        return max(-0.997, min(0.997, x * gain))

    return gain_db
def distort(x, func, rate=0.8):
    """Distort a waveform at sample-point level.

    Only the first row ``x[0]`` is visited and it is modified in place; each
    of its ``x.shape[1]`` samples is replaced by ``func(sample)`` with
    probability ``rate``.

    Args:
        x: the original waveform, indexed as ``x[0][i]`` and exposing
            ``shape``
        func: the distort function, applied to a Python float
        rate: sample point-level distort probability

    Returns:
        the distorted waveform (the same object as ``x``)
    """
    channel = x[0]
    for i in range(x.shape[1]):
        # random.random() is uniform on [0, 1), same as uniform(0, 1).
        if random.random() < rate:
            channel[i] = func(float(channel[i]))
    return x
def distort_chain(x, funcs, rate=0.8):
    """Distort a waveform at sample-point level with a chain of functions.

    Only the first row ``x[0]`` is visited and it is modified in place. Each
    of its ``x.shape[1]`` samples is selected with probability ``rate``; a
    selected sample is passed through every function in ``funcs`` in order,
    the result being written back to ``x`` after each step.

    Args:
        x: the original waveform, indexed as ``x[0][i]`` and exposing
            ``shape``
        funcs: an iterable of distort functions, each applied to a Python
            float
        rate: sample point-level distort probability

    Returns:
        the distorted waveform (the same object as ``x``)
    """
    channel = x[0]
    for i in range(x.shape[1]):
        # random.random() is uniform on [0, 1), same as uniform(0, 1).
        if random.random() < rate:
            for func in funcs:
                # Write back after every step so each function sees the value
                # as stored in the waveform's own dtype.
                channel[i] = func(float(channel[i]))
    return x
# x is a numpy array
def distort_wav_conf(x, distort_type, distort_conf, rate=0.1):
    """Apply the distortion selected by name to a waveform.

    Args:
        x: the waveform; it is passed to ``distort`` and therefore modified
            in place
        distort_type: one of "gain_db", "max_distortion", "fence_distortion",
            "jag_distortion", "poly_distortion", "quad_distortion" or
            "none_distortion"
        distort_conf: the configuration dict handed to the matching
            ``make_*`` factory (unused for "quad_distortion" and
            "none_distortion")
        rate: sample point-level distort probability

    Returns:
        the (possibly distorted) waveform. For "none_distortion" and for an
        unknown type the input is returned unchanged; an unknown type also
        prints a message.
    """
    if distort_type == "gain_db":
        gain_db = make_gain_db(distort_conf)
        # NOTE(review): `rate` is not forwarded here, so distort() falls back
        # to its own default of 0.8. Kept as in the original - confirm
        # whether that is intended.
        x = distort(x, gain_db)
    elif distort_type == "max_distortion":
        max_distortion = make_max_distortion(distort_conf)
        x = distort(x, max_distortion, rate=rate)
    elif distort_type == "fence_distortion":
        fence_distortion = make_fence_distortion(distort_conf)
        x = distort(x, fence_distortion, rate=rate)
    elif distort_type == "jag_distortion":
        jag_distortion = make_jag_distortion(distort_conf)
        x = distort(x, jag_distortion, rate=rate)
    elif distort_type == "poly_distortion":
        poly_distortion = make_poly_distortion(distort_conf)
        x = distort(x, poly_distortion, rate=rate)
    elif distort_type == "quad_distortion":
        quad_distortion = make_quad_distortion()
        x = distort(x, quad_distortion, rate=rate)
    elif distort_type == "none_distortion":
        pass
    else:
        print("unsupport type")
    return x
def distort_wav_conf_and_save(distort_type, distort_conf, rate, wav_in, wav_out):
    """Load a wav file, distort it and save the result.

    Args:
        distort_type: the distortion name understood by ``distort_wav_conf``
        distort_conf: the configuration dict for that distortion
        rate: sample point-level distort probability
        wav_in: path of the wav file to read
        wav_out: path of the wav file to write, saved with the sample rate
            of the input
    """
    waveform, sample_rate = torchaudio.load(wav_in)
    # The distortion functions work on a numpy array.
    samples = waveform.detach().numpy()
    distorted = distort_wav_conf(samples, distort_type, distort_conf, rate)
    torchaudio.save(wav_out, torch.from_numpy(distorted), sample_rate)
if __name__ == "__main__":
    # Usage: <script> <distort_type> <wav_in> <wav_out>
    if len(sys.argv) < 4:
        sys.exit(
            "usage: {} <distort_type> <wav_in> <wav_out>".format(sys.argv[0])
        )
    distort_type = sys.argv[1]
    wav_in = sys.argv[2]
    wav_out = sys.argv[3]
    rate = 0.1

    # The CLI used to recognise only the "new_*" names, which
    # distort_wav_conf does not support, while the supported names received
    # conf=None. Map the legacy names onto the supported ones.
    legacy_names = {
        "new_jag_distortion": "jag_distortion",
        "new_fence_distortion": "fence_distortion",
    }
    distort_type = legacy_names.get(distort_type, distort_type)

    # Built-in configurations. Types without an entry get conf=None, which
    # is only valid for those that ignore their configuration
    # ("quad_distortion", "none_distortion").
    default_confs = {
        "jag_distortion": {"mask_number": 4},
        "fence_distortion": {"mask_number": 1, "max_db": -30},
        "poly_distortion": {"a": 4, "m": 2, "n": 2},
    }
    conf = default_confs.get(distort_type)

    distort_wav_conf_and_save(distort_type, conf, rate, wav_in, wav_out)