import torch
import torchaudio
import random
import itertools
import numpy as np
####from tools.mix import mix
from e2_tts_pytorch.mix import mix
import time
import traceback
import os
#from datasets import load_dataset

####from transformers import ClapModel, ClapProcessor
####clap = ClapModel.from_pretrained("/ckptstorage/zhanghaomin/models/EnCLAP/larger_clap_general/").to("cpu")
####clap.eval()
####for param in clap.parameters():
####    param.requires_grad = False
####clap_processor = ClapProcessor.from_pretrained("/ckptstorage/zhanghaomin/models/EnCLAP/larger_clap_general/")

#from msclap import CLAP
#clap_model = CLAP("/ckptstorage/zhanghaomin/models/msclap/clapcap_weights_2023.pth", version="clapcap", use_cuda=False)
#clap_model.clapcap.eval()
#for param in clap_model.clapcap.parameters():
#    param.requires_grad = False

#new_freq = 16000
#hop_size = 160
new_freq = 24000
#hop_size = 256
hop_size = 320

#total_length = 1024
#MIN_TARGET_LEN = 281
#MAX_TARGET_LEN = 937
total_length = 750
MIN_TARGET_LEN = 750
MAX_TARGET_LEN = 750
#LEN_D = 1
LEN_D = 0

clap_freq = 48000
msclap_freq = 44100
max_len_in_seconds = 10
max_len_in_seconds_msclap = 7
#period_length = 30
period_length = 7
cut_length = 10
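
# With new_freq = 24000 and hop_size = 320, total_length = 750 frames covers
# exactly 750 * 320 = 240000 samples, i.e. cut_length = 10 seconds of audio.


# Remove DC offset, then peak-normalize (using channel 0's peak) to a maximum
# amplitude of 0.5.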
def normalize_wav(waveform):
    waveform = waveform - torch.mean(waveform)
    waveform = waveform / (torch.max(torch.abs(waveform[0, :])) + 1e-8)
    return waveform * 0.5
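

# Pad (with zeros) or truncate a (batch, frames, channels) spectrogram to
# target_length frames, dropping the last channel if the channel count is odd.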
def _pad_spec(fbank, target_length=total_length):
    batch, n_frames, channels = fbank.shape
    p = target_length - n_frames
    if p > 0:
        pad = torch.zeros(batch, p, channels).to(fbank.device)
        fbank = torch.cat([fbank, pad], 1)
    elif p < 0:
        fbank = fbank[:, :target_length, :]
    if channels % 2 != 0:
        fbank = fbank[:, :, :-1]
    return fbank


SCORE_THRESHOLD_VAL = 0.15
#SCORE_THRESHOLD_TRAIN = {
#    "/zhanghaomin/datas/audiocaps": -np.inf,
#    "/radiostorage/WavCaps": -np.inf,
#    "/radiostorage/AudioGroup": -np.inf,
#    "/ckptstorage/zhanghaomin/audioset": -np.inf,
#    "/ckptstorage/zhanghaomin/BBCSoundEffects": -np.inf,
#    "/ckptstorage/zhanghaomin/CLAP_freesound": -np.inf,
#}

# Whether each source corpus is a sound-effect dataset (changes the padding/looping
# behavior in pad_wav and the short-clip check in read_wav_file).
SOUNDEFFECT = {
    "/zhanghaomin/datas/audiocaps": False,
    "/radiostorage/WavCaps": False,
    "/radiostorage/AudioGroup": True,
    "/ckptstorage/zhanghaomin/audioset": False,
    "/ckptstorage/zhanghaomin/BBCSoundEffects": False,
    "/ckptstorage/zhanghaomin/CLAP_freesound": False,
    "/zhanghaomin/datas/musiccap": False,
    "/ckptstorage/zhanghaomin/TangoPromptBank": False,
    "/ckptstorage/zhanghaomin/audiosetsl": False,
    "/ckptstorage/zhanghaomin/giantsoundeffects": True,
}

# Per-corpus [samples seen, samples filtered] counters for CLAP filter statistics.
FILTER_NUM = {
    "/zhanghaomin/datas/audiocaps": [0, 0],
    "/radiostorage/WavCaps": [0, 0],
    "/radiostorage/AudioGroup": [0, 0],
    "/ckptstorage/zhanghaomin/audioset": [0, 0],
    "/ckptstorage/zhanghaomin/BBCSoundEffects": [0, 0],
    "/ckptstorage/zhanghaomin/CLAP_freesound": [0, 0],
    "/zhanghaomin/datas/musiccap": [0, 0],
    "/ckptstorage/zhanghaomin/TangoPromptBank": [0, 0],
    "/ckptstorage/zhanghaomin/audiosetsl": [0, 0],
    "/ckptstorage/zhanghaomin/giantsoundeffects": [0, 0],
}
TURNOFF_CLAP_FILTER_GLOBAL = False
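

# Pad or loop a waveform (shape (channels, samples)) out to segment_length samples.
# Training: short clips are tiled to fill the segment and truncated. Validation:
# clips are zero-padded, except short sound effects (under a third of the segment),
# which are looped a deterministic number of times first. `text` is returned
# unchanged (the caption rewriting for repeats is commented out below).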
def pad_wav(waveform, segment_length, text, prefix, val):
    waveform_length = waveform.shape[1]
    if segment_length is None or waveform_length == segment_length:
        return waveform, text
    elif waveform_length > segment_length:
        return waveform[:, :segment_length], text
    else:
        if val:
            if (not SOUNDEFFECT[prefix]) or (waveform_length > segment_length / 3.0):
                pad_wav = torch.zeros((waveform.shape[0], segment_length - waveform_length)).to(waveform.device)
                waveform = torch.cat([waveform, pad_wav], 1)
                return waveform, text
            else:
                min_repeats = max(int(segment_length / 3.0 // waveform_length), 2)
                max_repeats = segment_length // waveform_length
                if val:  # always true here; the random branch below is unreachable
                    repeats = (min_repeats + max_repeats) // 2
                else:
                    repeats = random.randint(min_repeats, max_repeats)
                waveform = torch.cat([waveform] * repeats, 1)
                if waveform.shape[1] < segment_length:
                    pad_wav = torch.zeros((waveform.shape[0], segment_length - waveform.shape[1])).to(waveform.device)
                    waveform = torch.cat([waveform, pad_wav], 1)
                #if text[-1] in [",", "."]:
                #    text = text[:-1] + " repeat " + str(repeats) + " times" + text[-1]
                #else:
                #    text = text + " repeat " + str(repeats) + " times"
                return waveform, text
        else:
            repeats = segment_length // waveform_length + 1
            waveform = torch.cat([waveform] * repeats, 1)
            assert(waveform.shape[0] == 1 and waveform.shape[1] >= segment_length)
            return waveform[:, :segment_length], text
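

# Caption a clip with the msclap captioning model over a random window of up to
# max_len_in_seconds_msclap seconds. Depends on `clap_model`, which is commented
# out at module level, so this path is currently inactive.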
def msclap_generate(waveform, freq):
    waveform_msclap = torchaudio.functional.resample(waveform, orig_freq=freq, new_freq=msclap_freq)[0]
    start = 0
    end = waveform_msclap.shape[0]
    if waveform_msclap.shape[0] > msclap_freq * max_len_in_seconds_msclap:
        start = random.randint(0, waveform_msclap.shape[0] - msclap_freq * max_len_in_seconds_msclap)
        end = start + msclap_freq * max_len_in_seconds_msclap
        waveform_msclap = waveform_msclap[start:end]
    if waveform_msclap.shape[0] < msclap_freq * max_len_in_seconds_msclap:
        waveform_msclap = torch.cat([waveform_msclap, torch.zeros(msclap_freq * max_len_in_seconds_msclap - waveform_msclap.shape[0])])
    waveform_msclap = waveform_msclap.reshape(1, 1, msclap_freq * max_len_in_seconds_msclap)
    caption = clap_model.generate_caption(waveform_msclap)[0]
    return caption, (start / float(msclap_freq), end / float(msclap_freq))
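

# CLAP text/audio agreement scoring. Returns (score, filtered, payload) where
# payload is (text_embeds, audio_embeds, soundeffect). With if_clap_filter=False
# it is a no-op that only looks up whether the corpus is a sound-effect set.
# The scoring path depends on `clap`/`clap_processor`, commented out above.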
def do_clap_filter(waveform, text, filename, val, if_clap_filter, main_process, SCORE_THRESHOLD_TRAIN):
    global FILTER_NUM
    if isinstance(filename, tuple):
        filename = filename[0]
    if filename.startswith("/radiostorage/"):
        prefix = "/".join(filename.split("/")[:3])
    else:
        prefix = "/".join(filename.split("/")[:4])
    soundeffect = SOUNDEFFECT[prefix]
    if not if_clap_filter:
        return np.inf, False, (None, None, soundeffect)
    score_threshold = SCORE_THRESHOLD_VAL if val else SCORE_THRESHOLD_TRAIN
    if not if_clap_filter or TURNOFF_CLAP_FILTER_GLOBAL:
        score_threshold = -np.inf
    else:
        if not val:
            score_threshold = SCORE_THRESHOLD_TRAIN[prefix]
    #print(prefix, score_threshold)
    resampled = torchaudio.functional.resample(waveform.reshape(1, -1), orig_freq=new_freq, new_freq=clap_freq)[0].numpy()
    resampled = resampled[:clap_freq*max_len_in_seconds]
    inputs = clap_processor(text=[text], audios=[resampled], return_tensors="pt", padding=True, sampling_rate=clap_freq)
    inputs = {k: v.to("cpu") for k, v in inputs.items()}
    with torch.no_grad():
        outputs = clap(**inputs)
    score = torch.dot(outputs.text_embeds[0, :], outputs.audio_embeds[0, :]).item()
    #print("do_clap_filter:", filename, text, resampled.shape, outputs.logits_per_audio, outputs.logits_per_text, score, score < score_threshold)
    if torch.any(torch.isnan(outputs.text_embeds)) or torch.any(torch.isnan(outputs.audio_embeds)):
        return -np.inf, True, None
    if main_process and if_clap_filter and not TURNOFF_CLAP_FILTER_GLOBAL:
        FILTER_NUM[prefix][0] += 1
        if score < score_threshold:
            FILTER_NUM[prefix][1] += 1
        if FILTER_NUM[prefix][0] % 10000 == 0 or FILTER_NUM[prefix][0] == 1000:
            print(prefix, FILTER_NUM[prefix][0], FILTER_NUM[prefix][1]/float(FILTER_NUM[prefix][0]))
    return score, score < score_threshold, (outputs.text_embeds, outputs.audio_embeds, soundeffect)
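

# Load one sample. `filename` is either a plain wav path or a (wav, utt, period)
# tuple selecting the period-th period_length-second chunk of a long recording.
# Decoded and resampled audio is cached under
# /ailab-train/speech/zhanghaomin/wav_temp/ so later epochs skip decoding.
# Returns (waveform, text, embeddings), or (None, None, None) on any failure.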
def read_wav_file(filename, text, segment_length, val, if_clap_filter, main_process, SCORE_THRESHOLD_TRAIN, nch):
    try:
        if isinstance(filename, tuple):
            if filename[0].startswith("/radiostorage/"):
                prefix = "/".join(filename[0].split("/")[:3])
            else:
                prefix = "/".join(filename[0].split("/")[:4])
            #print(filename, text, segment_length, val)
            wav, utt, period = filename
            #size = os.path.getsize(wav)
            #if size > 200000000:
            #    print("Exception too large file:", filename, text, size)
            #    return None, None, None
            base, name = wav.rsplit("/", 1)
            temp_base = "/ailab-train/speech/zhanghaomin/wav_temp/" + base.replace("/", "__") + "/"
            temp_filename = temp_base + name
            if os.path.exists(temp_filename):
                waveform, sr = torchaudio.load(temp_filename)
            else:
                #start = time.time()
                waveform0, sr = torchaudio.load(wav)
                #end = time.time()
                #print("load", end-start, wav)
                waveform = torchaudio.functional.resample(waveform0, orig_freq=sr, new_freq=new_freq)[0:nch, :]
                #if nch >= 2:
                #    waveform = torch.cat([waveform.mean(axis=0, keepdims=True), waveform], 0)
                #print("resample", time.time()-end, wav)
                waveform = waveform[:, new_freq*period*period_length: new_freq*(period+1)*period_length]  # 0~period_length s
                waveform = waveform[:, :new_freq*cut_length]
                os.makedirs(temp_base, exist_ok=True)
                torchaudio.save(temp_filename, waveform, new_freq)
            start = 0
            if waveform.shape[1] > new_freq*max_len_in_seconds:
                if not val:
                    start = random.randint(0, waveform.shape[1]-new_freq*max_len_in_seconds)
                waveform = waveform[:, start: start+new_freq*max_len_in_seconds]
            if val:
                text_index = 0
            else:
                #text_index = random.choice([0,1,2])
                #text_index = random.choice([0,1])
                text_index = 0
            text = text[text_index]
            #text, timestamps = msclap_generate(waveform0[:, sr*period*period_length: sr*(period+1)*period_length], sr)
            #waveform = waveform[int(timestamps[0]*new_freq): int(timestamps[1]*new_freq)]
            #print(waveform.shape, text)
            score, filtered, embeddings = do_clap_filter(waveform[0, :], text, filename, val, if_clap_filter, main_process, SCORE_THRESHOLD_TRAIN)
            if filtered:
                print("Exception below threshold:", filename, text, score)
                return None, None, None
        else:
            if filename.startswith("/radiostorage/"):
                prefix = "/".join(filename.split("/")[:3])
            else:
                prefix = "/".join(filename.split("/")[:4])
            #size = os.path.getsize(filename)
            #if size > 200000000:
            #    print("Exception too large file:", filename, text, size)
            #    return None, None, None
            base, name = filename.rsplit("/", 1)
            temp_base = "/ailab-train/speech/zhanghaomin/wav_temp/" + base.replace("/", "__") + "/"
            temp_filename = temp_base + name
            if os.path.exists(temp_filename):
                #print("wav exist", temp_filename)
                waveform, sr = torchaudio.load(temp_filename)
            else:
                #print("wav not exist", filename)
                #start = time.time()
                waveform, sr = torchaudio.load(filename)  # Faster!!!
                #end = time.time()
                #print("load", end-start, filename)
                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=new_freq)[0:nch, :]
                #if nch >= 2:
                #    waveform = torch.cat([waveform.mean(axis=0, keepdims=True), waveform], 0)
                #print("resample", time.time()-end, filename)
                waveform = waveform[:, :new_freq*cut_length]
                os.makedirs(temp_base, exist_ok=True)
                torchaudio.save(temp_filename, waveform, new_freq)
            start = 0
            if waveform.shape[1] > new_freq*max_len_in_seconds:
                if not val:
                    start = random.randint(0, waveform.shape[1]-new_freq*max_len_in_seconds)
                waveform = waveform[:, start: start+new_freq*max_len_in_seconds]
            if isinstance(text, tuple):
                if val:
                    text_index = 0
                else:
                    text_index = random.choice(list(range(len(text))))
                text = text[text_index]
            score, filtered, embeddings = do_clap_filter(waveform[0, :], text, filename, val, if_clap_filter, main_process, SCORE_THRESHOLD_TRAIN)
            if filtered:
                print("Exception below threshold:", filename, text, score)
                return None, None, None
    except Exception as e:
        print("Exception load:", filename, text)
        traceback.print_exc()
        return None, None, None
    #try:
    #    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=new_freq)[0]
    #except Exception as e:
    #    print("Exception resample:", waveform.shape, sr, filename, text)
    #    return None, None, None
    if (waveform.shape[1] / float(new_freq) < 0.2) and (not SOUNDEFFECT[prefix]):
        print("Exception too short wav:", waveform.shape, sr, new_freq, filename, text)
        return None, None, None
    try:
        waveform = normalize_wav(waveform)
    except Exception as e:
        print("Exception normalizing:", waveform.shape, sr, new_freq, filename, text)
        traceback.print_exc()
        #waveform = torch.ones(sample_freq*max_len_in_seconds)
        return None, None, None
    waveform, text = pad_wav(waveform, segment_length, text, prefix, val)
    waveform = waveform / (torch.max(torch.abs(waveform[0, :])) + 1e-8)
    waveform = 0.5 * waveform
    #print(text)
    return waveform, text, embeddings
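

# Legacy mel path: compute mel spectrogram, log magnitudes and energy from an
# _stft object exposing mel_spectrogram() (unused by the current fn_STFT path).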
def get_mel_from_wav(audio, _stft):
    audio = torch.nan_to_num(torch.clip(audio, -1, 1))
    audio = torch.autograd.Variable(audio, requires_grad=False)
    melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
    return melspec, log_magnitudes_stft, energy
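

# argmax_lst returns the index of a list's maximum. select_segment picks the
# target_length-frame window with the highest total energy, using per-frame
# energies and a running window sum (drop the leftmost frame, add the next one).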
def argmax_lst(lst):
    return max(range(len(lst)), key=lst.__getitem__)


def select_segment(waveform, target_length):
    ch, wav_length = waveform.shape
    assert(ch == 1 and wav_length == total_length * hop_size)
    energy = []
    for i in range(total_length):
        energy.append(torch.mean(torch.abs(waveform[:, i*hop_size: (i+1)*hop_size])))
    #sum_energy = []
    #for i in range(total_length-target_length+1):
    #    sum_energy.append(sum(energy[i: i+target_length]))
    sum_energy = [sum(energy[:target_length])]
    for i in range(1, total_length - target_length + 1):
        sum_energy.append(sum_energy[-1] - energy[i-1] + energy[i+target_length-1])
    start = argmax_lst(sum_energy)
    segment = waveform[:, start*hop_size: (start+target_length)*hop_size]
    ch, wav_length = segment.shape
    assert(ch == 1 and wav_length == target_length * hop_size)
    return segment
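

# Batch front end: read up to `num` usable wavs, cut each to a target frame
# length (highest-energy segment at train time, the leading segment at val
# time), apply fn_STFT, zero-pad to the batch max and return
# (fbank, texts, None, None, None, fbank_lens); the unused slots keep the
# return signature compatible with the legacy mel path commented out below.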
def wav_to_fbank(paths, texts, num, target_length=total_length, fn_STFT=None, val=False, if_clap_filter=True, main_process=True, SCORE_THRESHOLD_TRAIN="", nch=1):
    assert fn_STFT is not None
    #raw_results = [read_wav_file(path, text, target_length * hop_size, val, if_clap_filter, main_process, SCORE_THRESHOLD_TRAIN, nch) for path, text in zip(paths, texts)]
    results = []
    #for result in raw_results:
    #    if result[0] is not None:
    #        results.append(result)
    for path, text in zip(paths, texts):
        result = read_wav_file(path, text, target_length * hop_size, val, if_clap_filter, main_process, SCORE_THRESHOLD_TRAIN, nch)
        if result[0] is not None:
            results.append(result)
        if num > 0 and len(results) >= num:
            break
    if len(results) == 0:
        ####return None, None, None, None, None
        return None, None, None, None, None, None
    ####waveform = torch.cat([result[0] for result in results], 0)
    texts = [result[1] for result in results]
    embeddings = [result[2] for result in results]
    ####fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
    ####fbank = fbank.transpose(1, 2)
    ####log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
    ####fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
    ####    log_magnitudes_stft, target_length
    ####)
    ####return fbank, texts, embeddings, log_magnitudes_stft, waveform
    ####fbank = fn_STFT(waveform)
    fbanks = []
    fbank_lens = []
    for result in results:
        if not val:
            length = random.randint(MIN_TARGET_LEN, MAX_TARGET_LEN)
        else:
            length = (MIN_TARGET_LEN + MAX_TARGET_LEN) // 2
        fbank_lens.append(length + LEN_D)
        if not val:
            waveform = select_segment(result[0], length)
        else:
            waveform = result[0][:, :length*hop_size]
        fbank = fn_STFT(waveform).transpose(-1, -2)
        #print("stft", waveform.shape, fbank.shape)
        fbanks.append(fbank)
    max_length = max(fbank_lens)
    for i in range(len(fbanks)):
        if fbanks[i].shape[1] < max_length:
            fbanks[i] = torch.cat([fbanks[i], torch.zeros(fbanks[i].shape[0], max_length - fbanks[i].shape[1], fbanks[i].shape[2])], 1)
    fbank = torch.cat(fbanks, 0)
    fbank_lens = torch.Tensor(fbank_lens).to(torch.int32)
    #print("fbank", fbank.shape, fbank_lens)
    return fbank, texts, None, None, None, fbank_lens


def uncapitalize(s):
    if s:
        return s[:1].lower() + s[1:]
    else:
        return ""
def mix_wavs_and_captions(path1, path2, caption1, caption2, target_length=total_length, main_process=True, SCORE_THRESHOLD_TRAIN="", nch=1):
    sound1, caption1, embeddings1 = read_wav_file(path1, caption1, target_length * hop_size, False, False, main_process, SCORE_THRESHOLD_TRAIN, nch)  #[0].numpy()
    sound2, caption2, embeddings2 = read_wav_file(path2, caption2, target_length * hop_size, False, False, main_process, SCORE_THRESHOLD_TRAIN, nch)  #[0].numpy()
    if sound1 is not None and sound2 is not None:
        mixed_sound = mix(sound1.numpy(), sound2.numpy(), 0.5, new_freq)
        mixed_sound = mixed_sound.astype(np.float32)
        mixed_caption = "{} and {}".format(caption1, uncapitalize(caption2))
        #resampled = torchaudio.functional.resample(torch.from_numpy(mixed_sound).reshape(1, -1), orig_freq=new_freq, new_freq=clap_freq)[0].numpy()
        #resampled = resampled[:clap_freq*max_len_in_seconds]
        #inputs = clap_processor(text=[mixed_caption], audios=[resampled], return_tensors="pt", padding=True, sampling_rate=clap_freq)
        #inputs = {k: v.to("cpu") for k, v in inputs.items()}
        #with torch.no_grad():
        #    outputs = clap(**inputs)
        if not (embeddings1[2] or embeddings2[2]):
            filename = path1
        else:
            filename = "/radiostorage/AudioGroup"
        score, filtered, embeddings = do_clap_filter(torch.from_numpy(mixed_sound)[0, :], mixed_caption, filename, False, False, main_process, SCORE_THRESHOLD_TRAIN)
        #print(score, filtered, embeddings if embeddings is None else embeddings[2], path1, path2, filename)
        if filtered:
            #print("Exception below threshold:", path1, path2, caption1, caption2, filename, score)
            return None, None, None
        return mixed_sound, mixed_caption, embeddings
    else:
        return None, None, None
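

# Create up to num_items mixed samples from random index pairs of the batch.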
def augment(paths, texts, num_items=4, target_length=total_length, main_process=True, SCORE_THRESHOLD_TRAIN="", nch=1):
    mixed_sounds, mixed_captions, mixed_embeddings = [], [], []
    combinations = list(itertools.combinations(list(range(len(texts))), 2))
    random.shuffle(combinations)
    if len(combinations) < num_items:
        selected_combinations = combinations
    else:
        selected_combinations = combinations[:num_items]
    for (i, j) in selected_combinations:
        new_sound, new_caption, new_embeddings = mix_wavs_and_captions(paths[i], paths[j], texts[i], texts[j], target_length, main_process, SCORE_THRESHOLD_TRAIN, nch)
        if new_sound is not None:
            mixed_sounds.append(new_sound)
            mixed_captions.append(new_caption)
            mixed_embeddings.append(new_embeddings)
    if len(mixed_sounds) == 0:
        return None, None, None
    waveform = torch.tensor(np.concatenate(mixed_sounds, 0))
    waveform = waveform / (torch.max(torch.abs(waveform[0, :])) + 1e-8)
    waveform = 0.5 * waveform
    return waveform, mixed_captions, mixed_embeddings
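

# Like wav_to_fbank, but over mixed (augmented) pairs produced by augment().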
def augment_wav_to_fbank(paths, texts, num_items=4, target_length=total_length, fn_STFT=None, main_process=True, SCORE_THRESHOLD_TRAIN="", nch=1):
    assert fn_STFT is not None
    waveform, captions, embeddings = augment(paths, texts, num_items, target_length, main_process, SCORE_THRESHOLD_TRAIN, nch)
    if waveform is None:
        ####return None, None, None, None, None
        return None, None, None, None, None, None
    ####fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
    ####fbank = fbank.transpose(1, 2)
    ####log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2)
    ####
    ####fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
    ####    log_magnitudes_stft, target_length
    ####)
    ####
    ####return fbank, log_magnitudes_stft, waveform, captions, embeddings
    ####fbank = fn_STFT(waveform)
    fbanks = []
    fbank_lens = []
    for i in range(waveform.shape[0]):
        length = random.randint(MIN_TARGET_LEN, MAX_TARGET_LEN)
        fbank_lens.append(length + LEN_D)
        ####fbank = fn_STFT(waveform[i:i+1, :length*hop_size]).transpose(-1,-2)
        fbank = fn_STFT(select_segment(waveform[i:i+1, :], length)).transpose(-1, -2)
        fbanks.append(fbank)
    max_length = max(fbank_lens)
    for i in range(len(fbanks)):
        if fbanks[i].shape[1] < max_length:
            fbanks[i] = torch.cat([fbanks[i], torch.zeros(fbanks[i].shape[0], max_length - fbanks[i].shape[1], fbanks[i].shape[2])], 1)
    fbank = torch.cat(fbanks, 0)
    fbank_lens = torch.Tensor(fbank_lens).to(torch.int32)
    return fbank, None, None, captions, None, fbank_lens
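

# Usage sketch (hypothetical: the wav path below is a placeholder under a known
# corpus prefix, and `fake_stft` stands in for the real mel-spectrogram callable
# that callers normally pass as fn_STFT):
#if __name__ == "__main__":
#    fake_stft = lambda wav: torch.randn(wav.shape[0], 100, wav.shape[1] // hop_size)
#    fbank, texts, _, _, _, fbank_lens = wav_to_fbank(
#        ["/zhanghaomin/datas/audiocaps/example.wav"], ["a dog barks"], 1,
#        fn_STFT=fake_stft, val=True, if_clap_filter=False)
#    if fbank is not None:
#        print(fbank.shape, texts, fbank_lens)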