import torch
import torchaudio
import random
import itertools
import numpy as np
####from tools.mix import mix
from e2_tts_pytorch.mix import mix
import time
import traceback
import os

#from datasets import load_dataset

####from transformers import ClapModel, ClapProcessor
####clap = ClapModel.from_pretrained("/ckptstorage/zhanghaomin/models/EnCLAP/larger_clap_general/").to("cpu")
####clap.eval()
####for param in clap.parameters():
####    param.requires_grad = False
####clap_processor = ClapProcessor.from_pretrained("/ckptstorage/zhanghaomin/models/EnCLAP/larger_clap_general/")

#from msclap import CLAP
#clap_model = CLAP("/ckptstorage/zhanghaomin/models/msclap/clapcap_weights_2023.pth", version="clapcap", use_cuda=False)
#clap_model.clapcap.eval()
#for param in clap_model.clapcap.parameters():
#    param.requires_grad = False

#new_freq = 16000
#hop_size = 160
new_freq = 24000
#hop_size = 256
hop_size = 320

#total_length = 1024
#MIN_TARGET_LEN = 281
#MAX_TARGET_LEN = 937
total_length = 750
MIN_TARGET_LEN = 750
MAX_TARGET_LEN = 750
#LEN_D = 1
LEN_D = 0

clap_freq = 48000
msclap_freq = 44100
max_len_in_seconds = 10
max_len_in_seconds_msclap = 7
#period_length = 30
period_length = 7
cut_length = 10


def normalize_wav(waveform):
    waveform = waveform - torch.mean(waveform)
    waveform = waveform / (torch.max(torch.abs(waveform[0, :])) + 1e-8)
    return waveform * 0.5


def _pad_spec(fbank, target_length=total_length):
    batch, n_frames, channels = fbank.shape
    p = target_length - n_frames
    if p > 0:
        pad = torch.zeros(batch, p, channels).to(fbank.device)
        fbank = torch.cat([fbank, pad], 1)
    elif p < 0:
        fbank = fbank[:, :target_length, :]
    if channels % 2 != 0:
        fbank = fbank[:, :, :-1]
    return fbank


SCORE_THRESHOLD_VAL = 0.15
#SCORE_THRESHOLD_TRAIN = {
#    "/zhanghaomin/datas/audiocaps": -np.inf,
#    "/radiostorage/WavCaps": -np.inf,
#    "/radiostorage/AudioGroup": -np.inf,
#    "/ckptstorage/zhanghaomin/audioset": -np.inf,
#    "/ckptstorage/zhanghaomin/BBCSoundEffects": -np.inf,
#    "/ckptstorage/zhanghaomin/CLAP_freesound": -np.inf,
#}
SOUNDEFFECT = {
    "/zhanghaomin/datas/audiocaps": False,
    "/radiostorage/WavCaps": False,
    "/radiostorage/AudioGroup": True,
    "/ckptstorage/zhanghaomin/audioset": False,
    "/ckptstorage/zhanghaomin/BBCSoundEffects": False,
    "/ckptstorage/zhanghaomin/CLAP_freesound": False,
    "/zhanghaomin/datas/musiccap": False,
    "/ckptstorage/zhanghaomin/TangoPromptBank": False,
    "/ckptstorage/zhanghaomin/audiosetsl": False,
    "/ckptstorage/zhanghaomin/giantsoundeffects": True,
}
FILTER_NUM = {
    "/zhanghaomin/datas/audiocaps": [0, 0],
    "/radiostorage/WavCaps": [0, 0],
    "/radiostorage/AudioGroup": [0, 0],
    "/ckptstorage/zhanghaomin/audioset": [0, 0],
    "/ckptstorage/zhanghaomin/BBCSoundEffects": [0, 0],
    "/ckptstorage/zhanghaomin/CLAP_freesound": [0, 0],
    "/zhanghaomin/datas/musiccap": [0, 0],
    "/ckptstorage/zhanghaomin/TangoPromptBank": [0, 0],
    "/ckptstorage/zhanghaomin/audiosetsl": [0, 0],
    "/ckptstorage/zhanghaomin/giantsoundeffects": [0, 0],
}
TURNOFF_CLAP_FILTER_GLOBAL = False


def pad_wav(waveform, segment_length, text, prefix, val):
    waveform_length = waveform.shape[1]
    if segment_length is None or waveform_length == segment_length:
        return waveform, text
    elif waveform_length > segment_length:
        return waveform[:, :segment_length], text
    else:
        if val:
            if (not SOUNDEFFECT[prefix]) or (waveform_length > segment_length / 3.0):
                pad_wav = torch.zeros((waveform.shape[0], segment_length - waveform_length)).to(waveform.device)
                waveform = torch.cat([waveform, pad_wav], 1)
                return waveform, text
            else:
                min_repeats = max(int(segment_length / 3.0 // waveform_length), 2)
                max_repeats = segment_length // waveform_length
                if val:
                    repeats = (min_repeats + max_repeats) // 2
                else:
                    repeats = random.randint(min_repeats, max_repeats)
                waveform = torch.cat([waveform] * repeats, 1)
                if waveform.shape[1] < segment_length:
                    pad_wav = torch.zeros((waveform.shape[0], segment_length - waveform.shape[1])).to(waveform.device)
                    waveform = torch.cat([waveform, pad_wav], 1)
                #if text[-1] in [",", "."]:
                #    text = text[:-1] + " repeat " + str(repeats) + " times" + text[-1]
                #else:
                #    text = text + " repeat " + str(repeats) + " times"
                return waveform, text
        else:
            repeats = segment_length // waveform_length + 1
            waveform = torch.cat([waveform] * repeats, 1)
            assert(waveform.shape[0] == 1 and waveform.shape[1] >= segment_length)
            return waveform[:, :segment_length], text
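
# msclap_generate: resamples a clip to 44.1 kHz (msclap_freq), crops or
# zero-pads it to a fixed 7-second window (max_len_in_seconds_msclap) and asks
# the msclap "clapcap" model for a caption, returning the caption together
# with the (start, end) of the chosen window in seconds. It depends on the
# clap_model object that is commented out at the top of this file, and its
# only call site below is also commented out, so it is effectively disabled
# in this version.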
def msclap_generate(waveform, freq):
    waveform_msclap = torchaudio.functional.resample(waveform, orig_freq=freq, new_freq=msclap_freq)[0]
    start = 0
    end = waveform_msclap.shape[0]
    if waveform_msclap.shape[0] > msclap_freq*max_len_in_seconds_msclap:
        start = random.randint(0, waveform_msclap.shape[0]-msclap_freq*max_len_in_seconds_msclap)
        end = start+msclap_freq*max_len_in_seconds_msclap
        waveform_msclap = waveform_msclap[start: end]
    if waveform_msclap.shape[0] < msclap_freq*max_len_in_seconds_msclap:
        waveform_msclap = torch.cat([waveform_msclap, torch.zeros(msclap_freq*max_len_in_seconds_msclap-waveform_msclap.shape[0])])
    waveform_msclap = waveform_msclap.reshape(1, 1, msclap_freq*max_len_in_seconds_msclap)
    caption = clap_model.generate_caption(waveform_msclap)[0]
    return caption, (start/float(msclap_freq), end/float(msclap_freq))


def do_clap_filter(waveform, text, filename, val, if_clap_filter, main_process, SCORE_THRESHOLD_TRAIN):
    global FILTER_NUM
    if isinstance(filename, tuple):
        filename = filename[0]
    if filename.startswith("/radiostorage/"):
        prefix = "/".join(filename.split("/")[:3])
    else:
        prefix = "/".join(filename.split("/")[:4])
    soundeffect = SOUNDEFFECT[prefix]
    if not if_clap_filter:
        return np.inf, False, (None, None, soundeffect)
    score_threshold = SCORE_THRESHOLD_VAL if val else SCORE_THRESHOLD_TRAIN
    if not if_clap_filter or TURNOFF_CLAP_FILTER_GLOBAL:
        score_threshold = -np.inf
    else:
        if not val:
            score_threshold = SCORE_THRESHOLD_TRAIN[prefix]
    #print(prefix, score_threshold)
    resampled = torchaudio.functional.resample(waveform.reshape(1, -1), orig_freq=new_freq, new_freq=clap_freq)[0].numpy()
    resampled = resampled[:clap_freq*max_len_in_seconds]
    inputs = clap_processor(text=[text], audios=[resampled], return_tensors="pt", padding=True, sampling_rate=clap_freq)
    inputs = {k: v.to("cpu") for k, v in inputs.items()}
    with torch.no_grad():
        outputs = clap(**inputs)
    score = torch.dot(outputs.text_embeds[0, :], outputs.audio_embeds[0, :]).item()
    #print("do_clap_filter:", filename, text, resampled.shape, outputs.logits_per_audio, outputs.logits_per_text, score, score < score_threshold)
    if torch.any(torch.isnan(outputs.text_embeds)) or torch.any(torch.isnan(outputs.audio_embeds)):
        return -np.inf, True, None
    if main_process and if_clap_filter and not TURNOFF_CLAP_FILTER_GLOBAL:
        FILTER_NUM[prefix][0] += 1
        if score < score_threshold:
            FILTER_NUM[prefix][1] += 1
        if FILTER_NUM[prefix][0] % 10000 == 0 or FILTER_NUM[prefix][0] == 1000:
            print(prefix, FILTER_NUM[prefix][0], FILTER_NUM[prefix][1]/float(FILTER_NUM[prefix][0]))
    return score, score < score_threshold, (outputs.text_embeds, outputs.audio_embeds, soundeffect)
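
# read_wav_file: loads one item, which is either a plain wav path or a
# (path, utterance, period) tuple for long recordings processed in
# period_length-second windows. The audio is resampled to new_freq and cached
# under /ailab-train/speech/zhanghaomin/wav_temp/, cropped to at most
# max_len_in_seconds, one caption is selected, the CLAP filter is optionally
# applied, and the waveform is normalized and padded/tiled to segment_length
# samples with pad_wav. Returns (waveform, text, embeddings), or
# (None, None, None) if loading, filtering, or normalization fails.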
def read_wav_file(filename, text, segment_length, val, if_clap_filter, main_process, SCORE_THRESHOLD_TRAIN, nch):
    try:
        if isinstance(filename, tuple):
            if filename[0].startswith("/radiostorage/"):
                prefix = "/".join(filename[0].split("/")[:3])
            else:
                prefix = "/".join(filename[0].split("/")[:4])
            #print(filename, text, segment_length, val)
            wav, utt, period = filename
            #size = os.path.getsize(wav)
            #if size > 200000000:
            #    print("Exception too large file:", filename, text, size)
            #    return None, None, None
            base, name = wav.rsplit("/", 1)
            temp_base = "/ailab-train/speech/zhanghaomin/wav_temp/" + base.replace("/", "__") + "/"
            temp_filename = temp_base + name
            if os.path.exists(temp_filename):
                waveform, sr = torchaudio.load(temp_filename)
            else:
                #start = time.time()
                waveform0, sr = torchaudio.load(wav)
                #end = time.time()
                #print("load", end-start, wav)
                waveform = torchaudio.functional.resample(waveform0, orig_freq=sr, new_freq=new_freq)[0:nch, :]
                #if nch >= 2:
                #    waveform = torch.cat([waveform.mean(axis=0, keepdims=True), waveform], 0)
                #print("resample", time.time()-end, wav)
                waveform = waveform[:, new_freq*period*period_length: new_freq*(period+1)*period_length]  # 0~period_length s
                waveform = waveform[:, :new_freq*cut_length]
                os.makedirs(temp_base, exist_ok=True)
                torchaudio.save(temp_filename, waveform, new_freq)
            start = 0
            if waveform.shape[1] > new_freq*max_len_in_seconds:
                if not val:
                    start = random.randint(0, waveform.shape[1]-new_freq*max_len_in_seconds)
                waveform = waveform[:, start: start+new_freq*max_len_in_seconds]
            if val:
                text_index = 0
            else:
                #text_index = random.choice([0,1,2])
                #text_index = random.choice([0,1])
                text_index = 0
            text = text[text_index]
            #text, timestamps = msclap_generate(waveform0[:, sr*period*period_length: sr*(period+1)*period_length], sr)
            #waveform = waveform[int(timestamps[0]*new_freq): int(timestamps[1]*new_freq)]
            #print(waveform.shape, text)
            score, filtered, embeddings = do_clap_filter(waveform[0, :], text, filename, val, if_clap_filter, main_process, SCORE_THRESHOLD_TRAIN)
            if filtered:
                print("Exception below threshold:", filename, text, score)
                return None, None, None
        else:
            if filename.startswith("/radiostorage/"):
                prefix = "/".join(filename.split("/")[:3])
            else:
                prefix = "/".join(filename.split("/")[:4])
            #size = os.path.getsize(filename)
            #if size > 200000000:
            #    print("Exception too large file:", filename, text, size)
            #    return None, None, None
            base, name = filename.rsplit("/", 1)
            temp_base = "/ailab-train/speech/zhanghaomin/wav_temp/" + base.replace("/", "__") + "/"
            temp_filename = temp_base + name
            if os.path.exists(temp_filename):
                #print("wav exist", temp_filename)
                waveform, sr = torchaudio.load(temp_filename)
            else:
                #print("wav not exist", filename)
                #start = time.time()
                waveform, sr = torchaudio.load(filename)  # Faster!!!
                #end = time.time()
                #print("load", end-start, filename)
                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=new_freq)[0:nch, :]
                #if nch >= 2:
                #    waveform = torch.cat([waveform.mean(axis=0, keepdims=True), waveform], 0)
                #print("resample", time.time()-end, filename)
                waveform = waveform[:, :new_freq*cut_length]
                os.makedirs(temp_base, exist_ok=True)
                torchaudio.save(temp_filename, waveform, new_freq)
            start = 0
            if waveform.shape[1] > new_freq*max_len_in_seconds:
                if not val:
                    start = random.randint(0, waveform.shape[1]-new_freq*max_len_in_seconds)
                waveform = waveform[:, start: start+new_freq*max_len_in_seconds]
            if isinstance(text, tuple):
                if val:
                    text_index = 0
                else:
                    text_index = random.choice(list(range(len(text))))
                text = text[text_index]
            score, filtered, embeddings = do_clap_filter(waveform[0, :], text, filename, val, if_clap_filter, main_process, SCORE_THRESHOLD_TRAIN)
            if filtered:
                print("Exception below threshold:", filename, text, score)
                return None, None, None
    except Exception as e:
        print("Exception load:", filename, text)
        traceback.print_exc()
        return None, None, None
    #try:
    #    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=new_freq)[0]
    #except Exception as e:
    #    print("Exception resample:", waveform.shape, sr, filename, text)
    #    return None, None, None
    if (waveform.shape[1] / float(new_freq) < 0.2) and (not SOUNDEFFECT[prefix]):
        print("Exception too short wav:", waveform.shape, sr, new_freq, filename, text)
        traceback.print_exc()
        return None, None, None
    try:
        waveform = normalize_wav(waveform)
    except Exception as e:
        print("Exception normalizing:", waveform.shape, sr, new_freq, filename, text)
        traceback.print_exc()
        #waveform = torch.ones(sample_freq*max_len_in_seconds)
        return None, None, None
    waveform, text = pad_wav(waveform, segment_length, text, prefix, val)
    waveform = waveform / (torch.max(torch.abs(waveform[0, :])) + 1e-8)
    waveform = 0.5 * waveform
    #print(text)
    return waveform, text, embeddings
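
# Spectrogram and segment helpers. get_mel_from_wav wraps an external _stft
# object exposing mel_spectrogram() and is only referenced by the
# commented-out code paths further below. select_segment picks the most
# energetic window of target_length frames from a full-length waveform:
# per-frame mean absolute amplitude is computed once, and the window sums are
# maintained with the rolling update
#     sum_energy[i] = sum_energy[i-1] - energy[i-1] + energy[i+target_length-1]
# so the search costs O(total_length) rather than O(total_length * target_length).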
def get_mel_from_wav(audio, _stft):
    audio = torch.nan_to_num(torch.clip(audio, -1, 1))
    audio = torch.autograd.Variable(audio, requires_grad=False)
    melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
    return melspec, log_magnitudes_stft, energy


def argmax_lst(lst):
    return max(range(len(lst)), key=lst.__getitem__)


def select_segment(waveform, target_length):
    ch, wav_length = waveform.shape
    assert(ch == 1 and wav_length == total_length * hop_size)
    energy = []
    for i in range(total_length):
        energy.append(torch.mean(torch.abs(waveform[:, i*hop_size: (i+1)*hop_size])))
    #sum_energy = []
    #for i in range(total_length-target_length+1):
    #    sum_energy.append(sum(energy[i: i+target_length]))
    sum_energy = [sum(energy[:target_length])]
    for i in range(1, total_length-target_length+1):
        sum_energy.append(sum_energy[-1]-energy[i-1]+energy[i+target_length-1])
    start = argmax_lst(sum_energy)
    segment = waveform[:, start*hop_size: (start+target_length)*hop_size]
    ch, wav_length = segment.shape
    assert(ch == 1 and wav_length == target_length * hop_size)
    return segment
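
# wav_to_fbank: batch helper that reads up to `num` usable (path, caption)
# pairs, samples a target length per item (fixed at 750 frames with the
# current MIN_TARGET_LEN/MAX_TARGET_LEN settings), takes the most energetic
# segment during training or the leading segment during validation, applies
# fn_STFT, zero-pads every spectrogram to the longest one in the batch, and
# returns (fbank, texts, None, None, None, fbank_lens). The None slots stand
# in for the embeddings, log-magnitude STFT, and raw waveform that the
# commented-out legacy path used to return.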
def wav_to_fbank(paths, texts, num, target_length=total_length, fn_STFT=None, val=False, if_clap_filter=True, main_process=True, SCORE_THRESHOLD_TRAIN="", nch=1):
    assert fn_STFT is not None
    #raw_results = [read_wav_file(path, text, target_length * hop_size, val, if_clap_filter, main_process, SCORE_THRESHOLD_TRAIN, nch) for path, text in zip(paths, texts)]
    results = []
    #for result in raw_results:
    #    if result[0] is not None:
    #        results.append(result)
    for path, text in zip(paths, texts):
        result = read_wav_file(path, text, target_length * hop_size, val, if_clap_filter, main_process, SCORE_THRESHOLD_TRAIN, nch)
        if result[0] is not None:
            results.append(result)
        if num > 0 and len(results) >= num:
            break
    if len(results) == 0:
        ####return None, None, None, None, None
        return None, None, None, None, None, None
    ####waveform = torch.cat([result[0] for result in results], 0)
    texts = [result[1] for result in results]
    embeddings = [result[2] for result in results]
    ####fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
    ####fbank = fbank.transpose(1, 2)
    ####log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
    ####fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
    ####    log_magnitudes_stft, target_length
    ####)
    ####return fbank, texts, embeddings, log_magnitudes_stft, waveform
    ####fbank = fn_STFT(waveform)
    fbanks = []
    fbank_lens = []
    for result in results:
        if not val:
            length = random.randint(MIN_TARGET_LEN, MAX_TARGET_LEN)
        else:
            length = (MIN_TARGET_LEN + MAX_TARGET_LEN) // 2
        fbank_lens.append(length+LEN_D)
        if not val:
            waveform = select_segment(result[0], length)
        else:
            waveform = result[0][:, :length*hop_size]
        fbank = fn_STFT(waveform).transpose(-1, -2)
        #print("stft", waveform.shape, fbank.shape)
        fbanks.append(fbank)
    max_length = max(fbank_lens)
    for i in range(len(fbanks)):
        if fbanks[i].shape[1] < max_length:
            fbanks[i] = torch.cat([fbanks[i], torch.zeros(fbanks[i].shape[0], max_length-fbanks[i].shape[1], fbanks[i].shape[2])], 1)
    fbank = torch.cat(fbanks, 0)
    fbank_lens = torch.Tensor(fbank_lens).to(torch.int32)
    #print("fbank", fbank.shape, fbank_lens)
    return fbank, texts, None, None, None, fbank_lens


def uncapitalize(s):
    if s:
        return s[:1].lower() + s[1:]
    else:
        return ""


def mix_wavs_and_captions(path1, path2, caption1, caption2, target_length=total_length, main_process=True, SCORE_THRESHOLD_TRAIN="", nch=1):
    sound1, caption1, embeddings1 = read_wav_file(path1, caption1, target_length * hop_size, False, False, main_process, SCORE_THRESHOLD_TRAIN, nch)  #[0].numpy()
    sound2, caption2, embeddings2 = read_wav_file(path2, caption2, target_length * hop_size, False, False, main_process, SCORE_THRESHOLD_TRAIN, nch)  #[0].numpy()
    if sound1 is not None and sound2 is not None:
        mixed_sound = mix(sound1.numpy(), sound2.numpy(), 0.5, new_freq)
        mixed_sound = mixed_sound.astype(np.float32)
        mixed_caption = "{} and {}".format(caption1, uncapitalize(caption2))
        #resampled = torchaudio.functional.resample(torch.from_numpy(mixed_sound).reshape(1, -1), orig_freq=new_freq, new_freq=clap_freq)[0].numpy()
        #resampled = resampled[:clap_freq*max_len_in_seconds]
        #inputs = clap_processor(text=[mixed_caption], audios=[resampled], return_tensors="pt", padding=True, sampling_rate=clap_freq)
        #inputs = {k: v.to("cpu") for k, v in inputs.items()}
        #with torch.no_grad():
        #    outputs = clap(**inputs)
        if not (embeddings1[2] or embeddings2[2]):
            filename = path1
        else:
            filename = "/radiostorage/AudioGroup"
        score, filtered, embeddings = do_clap_filter(torch.from_numpy(mixed_sound)[0, :], mixed_caption, filename, False, False, main_process, SCORE_THRESHOLD_TRAIN)
        #print(score, filtered, embeddings if embeddings is None else embeddings[2], path1, path2, filename)
        if filtered:
            #print("Exception below threshold:", path1, path2, caption1, caption2, filename, score)
            return None, None, None
        return mixed_sound, mixed_caption, embeddings
    else:
        return None, None, None
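
# augment / augment_wav_to_fbank: on-the-fly caption-and-audio mixing.
# augment samples up to num_items random index pairs from the batch, mixes
# each pair of waveforms with mix(..., 0.5, new_freq) and joins the captions
# as "<caption1> and <caption2>"; augment_wav_to_fbank then converts the
# mixed waveforms into padded spectrograms with the same 6-tuple layout as
# wav_to_fbank.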
def augment(paths, texts, num_items=4, target_length=total_length, main_process=True, SCORE_THRESHOLD_TRAIN="", nch=1):
    mixed_sounds, mixed_captions, mixed_embeddings = [], [], []
    combinations = list(itertools.combinations(list(range(len(texts))), 2))
    random.shuffle(combinations)
    if len(combinations) < num_items:
        selected_combinations = combinations
    else:
        selected_combinations = combinations[:num_items]
    for (i, j) in selected_combinations:
        new_sound, new_caption, new_embeddings = mix_wavs_and_captions(paths[i], paths[j], texts[i], texts[j], target_length, main_process, SCORE_THRESHOLD_TRAIN, nch)
        if new_sound is not None:
            mixed_sounds.append(new_sound)
            mixed_captions.append(new_caption)
            mixed_embeddings.append(new_embeddings)
    if len(mixed_sounds) == 0:
        return None, None, None
    waveform = torch.tensor(np.concatenate(mixed_sounds, 0))
    waveform = waveform / (torch.max(torch.abs(waveform[0, :])) + 1e-8)
    waveform = 0.5 * waveform
    return waveform, mixed_captions, mixed_embeddings


def augment_wav_to_fbank(paths, texts, num_items=4, target_length=total_length, fn_STFT=None, main_process=True, SCORE_THRESHOLD_TRAIN="", nch=1):
    assert fn_STFT is not None
    waveform, captions, embeddings = augment(paths, texts, num_items, target_length, main_process, SCORE_THRESHOLD_TRAIN, nch)
    if waveform is None:
        ####return None, None, None, None, None
        return None, None, None, None, None, None
    ####fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
    ####fbank = fbank.transpose(1, 2)
    ####log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
    ####
    ####fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
    ####    log_magnitudes_stft, target_length
    ####)
    ####
    ####return fbank, log_magnitudes_stft, waveform, captions, embeddings
    ####fbank = fn_STFT(waveform)
    fbanks = []
    fbank_lens = []
    for i in range(waveform.shape[0]):
        length = random.randint(MIN_TARGET_LEN, MAX_TARGET_LEN)
        fbank_lens.append(length+LEN_D)
        ####fbank = fn_STFT(waveform[i:i+1, :length*hop_size]).transpose(-1,-2)
        fbank = fn_STFT(select_segment(waveform[i:i+1, :], length)).transpose(-1, -2)
        fbanks.append(fbank)
    max_length = max(fbank_lens)
    for i in range(len(fbanks)):
        if fbanks[i].shape[1] < max_length:
            fbanks[i] = torch.cat([fbanks[i], torch.zeros(fbanks[i].shape[0], max_length-fbanks[i].shape[1], fbanks[i].shape[2])], 1)
    fbank = torch.cat(fbanks, 0)
    fbank_lens = torch.Tensor(fbank_lens).to(torch.int32)
    return fbank, None, None, captions, None, fbank_lens
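
# Minimal usage sketch (not part of the training pipeline): shows how a batch
# of (path, caption) pairs could be turned into padded spectrogram features
# with wav_to_fbank. The MelSpectrogram transform, file paths, and captions
# below are illustrative assumptions, not values from the original code;
# fn_STFT only needs to be a callable mapping a (1, samples) waveform to a
# (1, n_mels, frames) tensor. Note that read_wav_file caches resampled audio
# under /ailab-train/speech/zhanghaomin/wav_temp/, so this only runs where
# that directory is writable, and paths must start with one of the dataset
# roots listed in SOUNDEFFECT.
if __name__ == "__main__":
    # Hypothetical mel front-end standing in for the project's fn_STFT.
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=new_freq, n_fft=1024, hop_length=hop_size, n_mels=100
    )
    # Placeholder files under a known SOUNDEFFECT root; replace with real wavs.
    paths = [
        "/zhanghaomin/datas/audiocaps/example1.wav",
        "/zhanghaomin/datas/audiocaps/example2.wav",
    ]
    texts = ["A dog barks in the distance.", "Rain falls on a tin roof."]
    # if_clap_filter=False avoids the (commented-out) CLAP model dependency.
    fbank, caps, _, _, _, fbank_lens = wav_to_fbank(
        paths, texts, num=len(paths), fn_STFT=mel, val=True, if_clap_filter=False
    )
    if fbank is not None:
        print(fbank.shape, fbank_lens, caps)  # fbank: (batch, frames, n_mels)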