import torch
import torchaudio
import random
import itertools
import numpy as np
####from tools.mix import mix
from e2_tts_pytorch.mix import mix
import time
import traceback
import os
#from datasets import load_dataset
####from transformers import ClapModel, ClapProcessor
####clap = ClapModel.from_pretrained("/ckptstorage/zhanghaomin/models/EnCLAP/larger_clap_general/").to("cpu")
####clap.eval()
####for param in clap.parameters():
#### param.requires_grad = False
####clap_processor = ClapProcessor.from_pretrained("/ckptstorage/zhanghaomin/models/EnCLAP/larger_clap_general/")
#from msclap import CLAP
#clap_model = CLAP("/ckptstorage/zhanghaomin/models/msclap/clapcap_weights_2023.pth", version="clapcap", use_cuda=False)
#clap_model.clapcap.eval()
#for param in clap_model.clapcap.parameters():
# param.requires_grad = False
#new_freq = 16000
#hop_size = 160
new_freq = 24000  # target sample rate (Hz)
#hop_size = 256
hop_size = 320  # STFT hop size in samples (750 frames * 320 / 24000 Hz = 10 s)
#total_length = 1024
#MIN_TARGET_LEN = 281
#MAX_TARGET_LEN = 937
total_length = 750  # spectrogram length in frames
MIN_TARGET_LEN = 750  # minimum segment length in frames
MAX_TARGET_LEN = 750  # maximum segment length in frames
#LEN_D = 1
LEN_D = 0  # extra frames added to each reported fbank length
clap_freq = 48000  # CLAP input sample rate (Hz)
msclap_freq = 44100  # MSCLAP input sample rate (Hz)
max_len_in_seconds = 10
max_len_in_seconds_msclap = 7
#period_length = 30
period_length = 7  # seconds per period when slicing long files
cut_length = 10  # maximum cached clip length in seconds
def normalize_wav(waveform):
    # Remove DC offset, then peak-normalize by channel 0 and scale to 0.5.
    waveform = waveform - torch.mean(waveform)
    waveform = waveform / (torch.max(torch.abs(waveform[0, :])) + 1e-8)
    return waveform * 0.5
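# Usage sketch for normalize_wav (hypothetical shapes):
#   >>> wav = torch.randn(1, 24000)        # 1 channel, 1 s at 24 kHz
#   >>> out = normalize_wav(wav)
#   >>> out.abs().max()                    # ~0.5 after peak normalization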
def _pad_spec(fbank, target_length=total_length):
    # Zero-pad or truncate a (batch, frames, channels) spectrogram to exactly
    # target_length frames; drop the last channel if the channel count is odd.
    batch, n_frames, channels = fbank.shape
    p = target_length - n_frames
    if p > 0:
        pad = torch.zeros(batch, p, channels).to(fbank.device)
        fbank = torch.cat([fbank, pad], 1)
    elif p < 0:
        fbank = fbank[:, :target_length, :]
    if channels % 2 != 0:
        fbank = fbank[:, :, :-1]
    return fbank
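# Usage sketch for _pad_spec (hypothetical shapes):
#   >>> spec = torch.randn(2, 600, 64)
#   >>> _pad_spec(spec, 750).shape         # zero-padded to (2, 750, 64)
#   >>> _pad_spec(spec, 500).shape         # truncated to (2, 500, 64)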
SCORE_THRESHOLD_VAL = 0.15
#SCORE_THRESHOLD_TRAIN = {
# "/zhanghaomin/datas/audiocaps": -np.inf,
# "/radiostorage/WavCaps": -np.inf,
# "/radiostorage/AudioGroup": -np.inf,
# "/ckptstorage/zhanghaomin/audioset": -np.inf,
# "/ckptstorage/zhanghaomin/BBCSoundEffects": -np.inf,
# "/ckptstorage/zhanghaomin/CLAP_freesound": -np.inf,
#}
# Whether each dataset prefix holds short sound effects (changes padding/repeat logic).
SOUNDEFFECT = {
"/zhanghaomin/datas/audiocaps": False,
"/radiostorage/WavCaps": False,
"/radiostorage/AudioGroup": True,
"/ckptstorage/zhanghaomin/audioset": False,
"/ckptstorage/zhanghaomin/BBCSoundEffects": False,
"/ckptstorage/zhanghaomin/CLAP_freesound": False,
"/zhanghaomin/datas/musiccap": False,
"/ckptstorage/zhanghaomin/TangoPromptBank": False,
"/ckptstorage/zhanghaomin/audiosetsl": False,
"/ckptstorage/zhanghaomin/giantsoundeffects": True,
}
# Per-prefix counters: [clips scored, clips filtered below threshold].
FILTER_NUM = {
"/zhanghaomin/datas/audiocaps": [0,0],
"/radiostorage/WavCaps": [0,0],
"/radiostorage/AudioGroup": [0,0],
"/ckptstorage/zhanghaomin/audioset": [0,0],
"/ckptstorage/zhanghaomin/BBCSoundEffects": [0,0],
"/ckptstorage/zhanghaomin/CLAP_freesound": [0,0],
"/zhanghaomin/datas/musiccap": [0,0],
"/ckptstorage/zhanghaomin/TangoPromptBank": [0,0],
"/ckptstorage/zhanghaomin/audiosetsl": [0,0],
"/ckptstorage/zhanghaomin/giantsoundeffects": [0,0],
}
TURNOFF_CLAP_FILTER_GLOBAL = False
def pad_wav(waveform, segment_length, text, prefix, val):
    # Pad, tile, or truncate a waveform to segment_length samples. In validation,
    # short sound effects (under a third of the segment) are tiled instead of
    # zero-padded so the segment is not mostly silence.
    waveform_length = waveform.shape[1]
    if segment_length is None or waveform_length == segment_length:
        return waveform, text
    elif waveform_length > segment_length:
        return waveform[:, :segment_length], text
    else:
        if val:
            if (not SOUNDEFFECT[prefix]) or (waveform_length > segment_length / 3.0):
                pad = torch.zeros((waveform.shape[0], segment_length-waveform_length)).to(waveform.device)
                waveform = torch.cat([waveform, pad], 1)
                return waveform, text
            else:
                min_repeats = max(int(segment_length / 3.0 // waveform_length), 2)
                max_repeats = segment_length // waveform_length
                if val:
                    repeats = (min_repeats + max_repeats) // 2
                else:
                    repeats = random.randint(min_repeats, max_repeats)
                waveform = torch.cat([waveform]*repeats, 1)
                if waveform.shape[1] < segment_length:
                    pad = torch.zeros((waveform.shape[0], segment_length-waveform.shape[1])).to(waveform.device)
                    waveform = torch.cat([waveform, pad], 1)
                #if text[-1] in [",", "."]:
                #    text = text[:-1] + " repeat " + str(repeats) + " times" + text[-1]
                #else:
                #    text = text + " repeat " + str(repeats) + " times"
                return waveform, text
        else:
            # Training: tile past segment_length, then truncate.
            repeats = segment_length // waveform_length + 1
            waveform = torch.cat([waveform]*repeats, 1)
            assert(waveform.shape[0] == 1 and waveform.shape[1] >= segment_length)
            return waveform[:, :segment_length], text
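# Usage sketch for pad_wav (hypothetical path prefix; must be a SOUNDEFFECT key):
#   >>> wav = torch.randn(1, 48000)        # 2 s at 24 kHz
#   >>> out, txt = pad_wav(wav, 240000, "a dog barks", "/zhanghaomin/datas/audiocaps", True)
#   >>> out.shape                          # zero-padded to (1, 240000), i.e. 10 s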
def msclap_generate(waveform, freq):
    # Caption a waveform with the MSCLAP captioning model (requires clap_model
    # from the commented-out block above). Resamples to 44.1 kHz, randomly crops
    # or zero-pads to max_len_in_seconds_msclap, and returns the caption plus
    # the (start, end) crop times in seconds.
    waveform_msclap = torchaudio.functional.resample(waveform, orig_freq=freq, new_freq=msclap_freq)[0]
    start = 0
    end = waveform_msclap.shape[0]
    if waveform_msclap.shape[0] > msclap_freq*max_len_in_seconds_msclap:
        start = random.randint(0, waveform_msclap.shape[0]-msclap_freq*max_len_in_seconds_msclap)
        end = start+msclap_freq*max_len_in_seconds_msclap
        waveform_msclap = waveform_msclap[start: end]
    if waveform_msclap.shape[0] < msclap_freq*max_len_in_seconds_msclap:
        waveform_msclap = torch.cat([waveform_msclap, torch.zeros(msclap_freq*max_len_in_seconds_msclap-waveform_msclap.shape[0])])
    waveform_msclap = waveform_msclap.reshape(1, 1, msclap_freq*max_len_in_seconds_msclap)
    caption = clap_model.generate_caption(waveform_msclap)[0]
    return caption, (start/float(msclap_freq), end/float(msclap_freq))
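# Usage sketch for msclap_generate (only valid once the msclap block above is
# uncommented so clap_model exists; "example.wav" is hypothetical):
#   >>> wav, sr = torchaudio.load("example.wav")
#   >>> caption, (t0, t1) = msclap_generate(wav, sr)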
def do_clap_filter(waveform, text, filename, val, if_clap_filter, main_process, SCORE_THRESHOLD_TRAIN):
    # Score an (audio, text) pair with CLAP and decide whether to filter it out.
    # Returns (score, filtered, (text_embeds, audio_embeds, soundeffect)).
    # Requires clap/clap_processor from the commented-out block above when
    # if_clap_filter is True.
    global FILTER_NUM
    if isinstance(filename, tuple):
        filename = filename[0]
    if filename.startswith("/radiostorage/"):
        prefix = "/".join(filename.split("/")[:3])
    else:
        prefix = "/".join(filename.split("/")[:4])
    soundeffect = SOUNDEFFECT[prefix]
    if not if_clap_filter:
        return np.inf, False, (None, None, soundeffect)
    if TURNOFF_CLAP_FILTER_GLOBAL:
        score_threshold = -np.inf
    elif val:
        score_threshold = SCORE_THRESHOLD_VAL
    else:
        score_threshold = SCORE_THRESHOLD_TRAIN[prefix]
    #print(prefix, score_threshold)
    resampled = torchaudio.functional.resample(waveform.reshape(1, -1), orig_freq=new_freq, new_freq=clap_freq)[0].numpy()
    resampled = resampled[:clap_freq*max_len_in_seconds]
    inputs = clap_processor(text=[text], audios=[resampled], return_tensors="pt", padding=True, sampling_rate=clap_freq)
    inputs = {k: v.to("cpu") for k, v in inputs.items()}
    with torch.no_grad():
        outputs = clap(**inputs)
    score = torch.dot(outputs.text_embeds[0, :], outputs.audio_embeds[0, :]).item()
    #print("do_clap_filter:", filename, text, resampled.shape, outputs.logits_per_audio, outputs.logits_per_text, score, score < score_threshold)
    if torch.any(torch.isnan(outputs.text_embeds)) or torch.any(torch.isnan(outputs.audio_embeds)):
        return -np.inf, True, None
    if main_process and not TURNOFF_CLAP_FILTER_GLOBAL:
        FILTER_NUM[prefix][0] += 1
        if score < score_threshold:
            FILTER_NUM[prefix][1] += 1
        if FILTER_NUM[prefix][0] % 10000 == 0 or FILTER_NUM[prefix][0] == 1000:
            print(prefix, FILTER_NUM[prefix][0], FILTER_NUM[prefix][1]/float(FILTER_NUM[prefix][0]))
    return score, score < score_threshold, (outputs.text_embeds, outputs.audio_embeds, soundeffect)
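# Usage sketch for do_clap_filter (hypothetical path; with if_clap_filter=False
# the CLAP model is never touched and nothing is filtered):
#   >>> wav = torch.randn(240000)          # 10 s mono at 24 kHz
#   >>> score, filtered, emb = do_clap_filter(
#   ...     wav, "a dog barks", "/zhanghaomin/datas/audiocaps/x.wav",
#   ...     False, False, True, {})
#   >>> score, filtered                    # (inf, False)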
def read_wav_file(filename, text, segment_length, val, if_clap_filter, main_process, SCORE_THRESHOLD_TRAIN, nch):
    # Load one clip. Tuple filenames (path, utterance, period) select a
    # period_length-second slice of a long file. The resampled, trimmed clip is
    # cached under wav_temp so later epochs skip decode and resample. Returns
    # (waveform, text, embeddings), or (None, None, None) on failure/filtering.
    try:
if isinstance(filename, tuple):
if filename[0].startswith("/radiostorage/"):
prefix = "/".join(filename[0].split("/")[:3])
else:
prefix = "/".join(filename[0].split("/")[:4])
#print(filename, text, segment_length, val)
wav, utt, period = filename
#size = os.path.getsize(wav)
#if size > 200000000:
# print("Exception too large file:", filename, text, size)
# return None, None, None
base, name = wav.rsplit("/", 1)
temp_base = "/ailab-train/speech/zhanghaomin/wav_temp/" + base.replace("/", "__") + "/"
temp_filename = temp_base + name
if os.path.exists(temp_filename):
waveform, sr = torchaudio.load(temp_filename)
else:
#start = time.time()
waveform0, sr = torchaudio.load(wav)
#end = time.time()
#print("load", end-start, wav)
waveform = torchaudio.functional.resample(waveform0, orig_freq=sr, new_freq=new_freq)[0:nch, :]
#if nch >= 2:
# waveform = torch.cat([waveform.mean(axis=0, keepdims=True), waveform], 0)
#print("resample", time.time()-end, wav)
waveform = waveform[:, new_freq*period*period_length: new_freq*(period+1)*period_length] # 0~period_length s
waveform = waveform[:, :new_freq*cut_length]
os.makedirs(temp_base, exist_ok=True)
torchaudio.save(temp_filename, waveform, new_freq)
start = 0
if waveform.shape[1] > new_freq*max_len_in_seconds:
if not val:
start = random.randint(0, waveform.shape[1]-new_freq*max_len_in_seconds)
waveform = waveform[:, start: start+new_freq*max_len_in_seconds]
if val:
text_index = 0
else:
#text_index = random.choice([0,1,2])
#text_index = random.choice([0,1])
text_index = 0
text = text[text_index]
#text, timestamps = msclap_generate(waveform0[:, sr*period*period_length: sr*(period+1)*period_length], sr)
#waveform = waveform[int(timestamps[0]*new_freq): int(timestamps[1]*new_freq)]
#print(waveform.shape, text)
score, filtered, embeddings = do_clap_filter(waveform[0, :], text, filename, val, if_clap_filter, main_process, SCORE_THRESHOLD_TRAIN)
if filtered:
print("Exception below threshold:", filename, text, score)
return None, None, None
else:
if filename.startswith("/radiostorage/"):
prefix = "/".join(filename.split("/")[:3])
else:
prefix = "/".join(filename.split("/")[:4])
#size = os.path.getsize(filename)
#if size > 200000000:
# print("Exception too large file:", filename, text, size)
# return None, None, None
base, name = filename.rsplit("/", 1)
temp_base = "/ailab-train/speech/zhanghaomin/wav_temp/" + base.replace("/", "__") + "/"
temp_filename = temp_base + name
if os.path.exists(temp_filename):
#print("wav exist", temp_filename)
waveform, sr = torchaudio.load(temp_filename)
else:
#print("wav not exist", filename)
#start = time.time()
waveform, sr = torchaudio.load(filename) # Faster!!!
#end = time.time()
#print("load", end-start, filename)
waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=new_freq)[0:nch, :]
#if nch >= 2:
# waveform = torch.cat([waveform.mean(axis=0, keepdims=True), waveform], 0)
#print("resample", time.time()-end, filename)
waveform = waveform[:, :new_freq*cut_length]
os.makedirs(temp_base, exist_ok=True)
torchaudio.save(temp_filename, waveform, new_freq)
start = 0
if waveform.shape[1] > new_freq*max_len_in_seconds:
if not val:
start = random.randint(0, waveform.shape[1]-new_freq*max_len_in_seconds)
waveform = waveform[:, start: start+new_freq*max_len_in_seconds]
if isinstance(text, tuple):
if val:
text_index = 0
else:
text_index = random.choice(list(range(len(text))))
text = text[text_index]
score, filtered, embeddings = do_clap_filter(waveform[0, :], text, filename, val, if_clap_filter, main_process, SCORE_THRESHOLD_TRAIN)
if filtered:
print("Exception below threshold:", filename, text, score)
return None, None, None
except Exception as e:
print("Exception load:", filename, text)
traceback.print_exc()
return None, None, None
#try:
# waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=new_freq)[0]
#except Exception as e:
# print("Exception resample:", waveform.shape, sr, filename, text)
# return None, None, None
if (waveform.shape[1] / float(new_freq) < 0.2) and (not SOUNDEFFECT[prefix]):
print("Exception too short wav:", waveform.shape, sr, new_freq, filename, text)
traceback.print_exc()
return None, None, None
try:
waveform = normalize_wav(waveform)
except Exception as e:
        print("Exception normalizing:", waveform.shape, sr, new_freq, filename, text)
traceback.print_exc()
#waveform = torch.ones(sample_freq*max_len_in_seconds)
return None, None, None
waveform, text = pad_wav(waveform, segment_length, text, prefix, val)
waveform = waveform / (torch.max(torch.abs(waveform[0, :])) + 1e-8)
waveform = 0.5 * waveform
#print(text)
return waveform, text, embeddings
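# Usage sketch for read_wav_file (hypothetical path; segment_length in samples):
#   >>> wav, txt, emb = read_wav_file(
#   ...     "/zhanghaomin/datas/audiocaps/x.wav", ("a dog barks",),
#   ...     total_length * hop_size, True, False, True, "", 1)
#   >>> None if wav is None else wav.shape  # (1, 240000) on success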
def get_mel_from_wav(audio, _stft):
    audio = torch.nan_to_num(torch.clip(audio, -1, 1))
    audio = audio.detach()  # torch.autograd.Variable is deprecated; a plain detached tensor suffices
    melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
    return melspec, log_magnitudes_stft, energy
def argmax_lst(lst):
    # Index of the maximum element of a Python list.
    return max(range(len(lst)), key=lst.__getitem__)
def select_segment(waveform, target_length):
    # Return the target_length-frame window with the highest mean absolute
    # amplitude. A running sum over per-frame energies makes the window search
    # O(n) instead of recomputing every window from scratch.
    ch, wav_length = waveform.shape
    assert(ch == 1 and wav_length == total_length * hop_size)
    energy = []
    for i in range(total_length):
        energy.append(torch.mean(torch.abs(waveform[:, i*hop_size: (i+1)*hop_size])))
    #sum_energy = []
    #for i in range(total_length-target_length+1):
    #    sum_energy.append(sum(energy[i: i+target_length]))
    sum_energy = [sum(energy[:target_length])]
    for i in range(1, total_length-target_length+1):
        # Slide the window: drop the frame leaving, add the frame entering.
        sum_energy.append(sum_energy[-1]-energy[i-1]+energy[i+target_length-1])
    start = argmax_lst(sum_energy)
    segment = waveform[:, start*hop_size: (start+target_length)*hop_size]
    ch, wav_length = segment.shape
    assert(ch == 1 and wav_length == target_length * hop_size)
    return segment
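# Usage sketch for select_segment (input length is fixed by total_length*hop_size):
#   >>> wav = torch.randn(1, total_length * hop_size)
#   >>> seg = select_segment(wav, 500)
#   >>> seg.shape                          # (1, 500 * hop_size)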
def wav_to_fbank(paths, texts, num, target_length=total_length, fn_STFT=None, val=False, if_clap_filter=True, main_process=True, SCORE_THRESHOLD_TRAIN="", nch=1):
    # Load up to num clips, pick a segment per clip (highest-energy window when
    # training), run fn_STFT, and zero-pad the batch to a common frame length.
    # Returns (fbank, texts, None, None, None, fbank_lens).
    assert fn_STFT is not None
#raw_results = [read_wav_file(path, text, target_length * hop_size, val, if_clap_filter, main_process, SCORE_THRESHOLD_TRAIN, nch) for path, text in zip(paths, texts)]
results = []
#for result in raw_results:
# if result[0] is not None:
# results.append(result)
for path, text in zip(paths, texts):
result = read_wav_file(path, text, target_length * hop_size, val, if_clap_filter, main_process, SCORE_THRESHOLD_TRAIN, nch)
if result[0] is not None:
results.append(result)
if num > 0 and len(results) >= num:
break
if len(results) == 0:
####return None, None, None, None, None
return None, None, None, None, None, None
####waveform = torch.cat([result[0] for result in results], 0)
texts = [result[1] for result in results]
embeddings = [result[2] for result in results]
####fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
####fbank = fbank.transpose(1, 2)
####log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
####fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
#### log_magnitudes_stft, target_length
####)
####return fbank, texts, embeddings, log_magnitudes_stft, waveform
####fbank = fn_STFT(waveform)
fbanks = []
fbank_lens = []
for result in results:
if not val:
length = random.randint(MIN_TARGET_LEN, MAX_TARGET_LEN)
else:
length = (MIN_TARGET_LEN + MAX_TARGET_LEN) // 2
fbank_lens.append(length+LEN_D)
if not val:
waveform = select_segment(result[0], length)
else:
waveform = result[0][:, :length*hop_size]
fbank = fn_STFT(waveform).transpose(-1,-2)
#print("stft", waveform.shape, fbank.shape)
fbanks.append(fbank)
max_length = max(fbank_lens)
for i in range(len(fbanks)):
if fbanks[i].shape[1] < max_length:
fbanks[i] = torch.cat([fbanks[i], torch.zeros(fbanks[i].shape[0], max_length-fbanks[i].shape[1], fbanks[i].shape[2])], 1)
fbank = torch.cat(fbanks, 0)
fbank_lens = torch.Tensor(fbank_lens).to(torch.int32)
#print("fbank", fbank.shape, fbank_lens)
return fbank, texts, None, None, None, fbank_lens
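# Usage sketch for wav_to_fbank (hypothetical fn_STFT: any callable mapping a
# (1, samples) waveform to a (1, n_mels, frames) spectrogram, e.g. a torchaudio
# MelSpectrogram; the path and caption are hypothetical):
#   >>> import torchaudio.transforms as T
#   >>> stft = T.MelSpectrogram(sample_rate=new_freq, hop_length=hop_size)
#   >>> fbank, texts, _, _, _, lens = wav_to_fbank(
#   ...     ["/zhanghaomin/datas/audiocaps/x.wav"], [("a dog barks",)],
#   ...     num=1, fn_STFT=stft, val=True, if_clap_filter=False)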
def uncapitalize(s):
if s:
return s[:1].lower() + s[1:]
else:
return ""
def mix_wavs_and_captions(path1, path2, caption1, caption2, target_length=total_length, main_process=True, SCORE_THRESHOLD_TRAIN="", nch=1):
    # Mix two clips with the mix() helper and join their captions with "and".
    sound1, caption1, embeddings1 = read_wav_file(path1, caption1, target_length * hop_size, False, False, main_process, SCORE_THRESHOLD_TRAIN, nch)
    sound2, caption2, embeddings2 = read_wav_file(path2, caption2, target_length * hop_size, False, False, main_process, SCORE_THRESHOLD_TRAIN, nch)
if sound1 is not None and sound2 is not None:
mixed_sound = mix(sound1.numpy(), sound2.numpy(), 0.5, new_freq)
mixed_sound = mixed_sound.astype(np.float32)
mixed_caption = "{} and {}".format(caption1, uncapitalize(caption2))
#resampled = torchaudio.functional.resample(torch.from_numpy(mixed_sound).reshape(1, -1), orig_freq=new_freq, new_freq=clap_freq)[0].numpy()
#resampled = resampled[:clap_freq*max_len_in_seconds]
#inputs = clap_processor(text=[mixed_caption], audios=[resampled], return_tensors="pt", padding=True, sampling_rate=clap_freq)
#inputs = {k: v.to("cpu") for k, v in inputs.items()}
#with torch.no_grad():
# outputs = clap(**inputs)
        # embeddings[2] is the soundeffect flag; if either source is a sound
        # effect, attribute the mixture to a sound-effect prefix so downstream
        # SOUNDEFFECT lookups stay consistent.
        if not (embeddings1[2] or embeddings2[2]):
            filename = path1
        else:
            filename = "/radiostorage/AudioGroup"
score, filtered, embeddings = do_clap_filter(torch.from_numpy(mixed_sound)[0, :], mixed_caption, filename, False, False, main_process, SCORE_THRESHOLD_TRAIN)
#print(score, filtered, embeddings if embeddings is None else embeddings[2], path1, path2, filename)
if filtered:
#print("Exception below threshold:", path1, path2, caption1, caption2, filename, score)
return None, None, None
return mixed_sound, mixed_caption, embeddings
else:
return None, None, None
def augment(paths, texts, num_items=4, target_length=total_length, main_process=True, SCORE_THRESHOLD_TRAIN="", nch=1):
    # Build up to num_items mixtures from random pairs of the given clips.
    mixed_sounds, mixed_captions, mixed_embeddings = [], [], []
    combinations = list(itertools.combinations(list(range(len(texts))), 2))
    random.shuffle(combinations)
    selected_combinations = combinations[:num_items]  # slicing handles len < num_items
for (i, j) in selected_combinations:
new_sound, new_caption, new_embeddings = mix_wavs_and_captions(paths[i], paths[j], texts[i], texts[j], target_length, main_process, SCORE_THRESHOLD_TRAIN, nch)
if new_sound is not None:
mixed_sounds.append(new_sound)
mixed_captions.append(new_caption)
mixed_embeddings.append(new_embeddings)
if len(mixed_sounds) == 0:
return None, None, None
waveform = torch.tensor(np.concatenate(mixed_sounds, 0))
waveform = waveform / (torch.max(torch.abs(waveform[0, :])) + 1e-8)
waveform = 0.5 * waveform
return waveform, mixed_captions, mixed_embeddings
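# Usage sketch for augment (hypothetical paths/captions):
#   >>> wavs = ["/zhanghaomin/datas/audiocaps/a.wav", "/zhanghaomin/datas/audiocaps/b.wav"]
#   >>> caps = [("a dog barks",), ("rain falls",)]
#   >>> mixed, mixed_caps, embs = augment(wavs, caps, num_items=1)
#   >>> mixed_caps                         # e.g. ["a dog barks and rain falls"]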
def augment_wav_to_fbank(paths, texts, num_items=4, target_length=total_length, fn_STFT=None, main_process=True, SCORE_THRESHOLD_TRAIN="", nch=1):
    # Mixture-augmentation counterpart of wav_to_fbank; returns
    # (fbank, None, None, captions, None, fbank_lens).
    assert fn_STFT is not None
waveform, captions, embeddings = augment(paths, texts, num_items, target_length, main_process, SCORE_THRESHOLD_TRAIN, nch)
if waveform is None:
####return None, None, None, None, None
return None, None, None, None, None, None
####fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
####fbank = fbank.transpose(1, 2)
####log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
####
####fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
#### log_magnitudes_stft, target_length
####)
####
####return fbank, log_magnitudes_stft, waveform, captions, embeddings
####fbank = fn_STFT(waveform)
fbanks = []
fbank_lens = []
for i in range(waveform.shape[0]):
length = random.randint(MIN_TARGET_LEN, MAX_TARGET_LEN)
fbank_lens.append(length+LEN_D)
####fbank = fn_STFT(waveform[i:i+1, :length*hop_size]).transpose(-1,-2)
fbank = fn_STFT(select_segment(waveform[i:i+1, :], length)).transpose(-1,-2)
fbanks.append(fbank)
max_length = max(fbank_lens)
for i in range(len(fbanks)):
if fbanks[i].shape[1] < max_length:
fbanks[i] = torch.cat([fbanks[i], torch.zeros(fbanks[i].shape[0], max_length-fbanks[i].shape[1], fbanks[i].shape[2])], 1)
fbank = torch.cat(fbanks, 0)
fbank_lens = torch.Tensor(fbank_lens).to(torch.int32)
return fbank, None, None, captions, None, fbank_lens
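# Usage sketch for augment_wav_to_fbank (reusing the hypothetical stft/wavs/caps above):
#   >>> fbank, _, _, mixed_caps, _, lens = augment_wav_to_fbank(
#   ...     wavs, caps, num_items=1, fn_STFT=stft)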