# FluxMusicGUI / test.py

import os
import json


def test_reconstruct():
    """Round-trip one dataset sample through the AudioLDM2 VAE and vocoder."""
    import yaml
    from diffusers import AutoencoderKL
    from transformers import SpeechT5HifiGan
    from audioldm2.utilities.data.dataset import AudioDataset
    from utils import load_clap, load_t5
    model_path = '/maindata/data/shared/multimodal/public/dataset_music/audioldm2'
    config = yaml.load(
        open('config/16k_64.yaml', 'r'),
        Loader=yaml.FullLoader,
    )
    print(config)
    t5 = load_t5('cuda', max_length=256)
    clap = load_clap('cuda', max_length=256)
    dataset = AudioDataset(
        config=config,
        split="train",
        waveform_only=False,
        dataset_json_path='mini_dataset.json',
        tokenizer=clap.tokenizer,
        uncond_pro=0.1,
        text_ctx_len=77,
        tokenizer_t5=t5.tokenizer,
        text_ctx_len_t5=256,
        uncond_pro_t5=0.1,
    )
    print(dataset[0]['log_mel_spec'].unsqueeze(0).unsqueeze(0).size())
    vae = AutoencoderKL.from_pretrained(os.path.join(model_path, 'vae'))
    vocoder = SpeechT5HifiGan.from_pretrained(os.path.join(model_path, 'vocoder'))

    # Encode the log-mel spectrogram into the scaled VAE latent space.
    latents = vae.encode(
        dataset[0]['log_mel_spec'].unsqueeze(0).unsqueeze(0)
    ).latent_dist.sample().mul_(vae.config.scaling_factor)
    print('latent size:', latents.size())

    # Undo the scaling and decode back to a mel spectrogram.
    latents = 1 / vae.config.scaling_factor * latents
    mel_spectrogram = vae.decode(latents).sample
    print(mel_spectrogram.size())
    if mel_spectrogram.dim() == 4:
        mel_spectrogram = mel_spectrogram.squeeze(1)

    # Vocode the mel spectrogram back to a waveform.
    waveform = vocoder(mel_spectrogram)
    waveform = waveform[0].cpu().float().detach().numpy()
    print(waveform.shape)
    # Optional: write the reconstruction to disk for listening.
    # import soundfile as sf
    # sf.write('reconstruct.wav', waveform, samplerate=16000)
    from scipy.io import wavfile
    # wavfile.write('reconstruct.wav', 16000, waveform)
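
    # Listening to this round trip gauges the autoencoding ceiling: anything
    # generated in this latent space and decoded through the same VAE and
    # vocoder cannot sound better than this reconstruction.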


def mini_dataset(num=32):
    """Write a tiny JSON manifest that repeats one local example `num` times."""
    data = []
    for _ in range(num):
        data.append(
            {
                'wav': 'case.mp3',
                'label': 'a beautiful music',
            }
        )
    with open('mini_dataset.json', 'w') as f:
        json.dump(data, f, indent=4)
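
# Note on schema: every manifest written in this file is a flat JSON list of
# {"wav": <audio path>, "label": <caption>} records; this is the shape that
# AudioDataset's dataset_json_path argument consumes in test_reconstruct above.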


def fma_dataset():
    """Build a wav/label manifest from the FMA-large parquet annotations."""
    import pandas as pd
    annotation_prefix = "/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/annotation"
    annotation_list = [
        'test-00000-of-00001.parquet',
        'train-00000-of-00001.parquet',
        'valid-00000-of-00001.parquet',
    ]
    dataset_prefix = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/fma_large'
    data = []
    for annotation_file in annotation_list:
        annotation_file = os.path.join(annotation_prefix, annotation_file)
        df = pd.read_parquet(annotation_file)
        print(df.shape)
        for _, row in df.iterrows():
            tmp_path = os.path.join(dataset_prefix, row['path'] + '.mp3')
            # Keep only entries whose audio file actually exists on disk.
            if os.path.exists(tmp_path):
                data.append(
                    {
                        'wav': tmp_path,
                        'label': row['pseudo_caption'],
                    }
                )
    print(len(data))
    with open('fma_dataset.json', 'w') as f:
        json.dump(data, f, indent=4)


def audioset_dataset():
    """Build a wav/label manifest from the AudioSet balanced-train parquet."""
    import pandas as pd
    dataset_prefix = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/audioset'
    annotation_path = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/audioset/balanced_train-00000-of-00001.parquet'
    df = pd.read_parquet(annotation_path)
    print(df.shape)
    data = []
    for _, row in df.iterrows():
        try:
            tmp_path = os.path.join(dataset_prefix, row['path'] + '.flac')
        except TypeError:
            # Malformed (non-string) paths: report and skip the row, so a
            # stale tmp_path from a previous iteration is never reused.
            print(row['path'])
            continue
        if os.path.exists(tmp_path):
            data.append(
                {
                    'wav': tmp_path,
                    'label': row['pseudo_caption'],
                }
            )
    print(len(data))
    with open('audioset_dataset.json', 'w') as f:
        json.dump(data, f, indent=4)


def combine_dataset():
    """Concatenate the per-source manifests into combine_dataset.json."""
    data_list = ['fma_dataset.json', 'audioset_dataset.json']
    data = []
    for data_file in data_list:
        with open(data_file, 'r') as f:
            data += json.load(f)
    print(len(data))
    with open('combine_dataset.json', 'w') as f:
        json.dump(data, f, indent=4)
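

# If a clip were to appear in more than one source manifest, the combined file
# would carry duplicates. A minimal cleanup sketch (hypothetical, not part of
# the original pipeline) that de-duplicates a manifest by its 'wav' path:
def dedup_dataset(manifest='combine_dataset.json'):
    with open(manifest, 'r') as f:
        data = json.load(f)
    seen = set()
    unique = []
    for item in data:
        # Keep the first record for each distinct audio path.
        if item['wav'] not in seen:
            seen.add(item['wav'])
            unique.append(item)
    print(len(data), '->', len(unique))
    with open(manifest, 'w') as f:
        json.dump(unique, f, indent=4)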


def test_music_format():
    """Sanity-check that torchaudio can decode a local FLAC file."""
    import torchaudio
    filename = '2.flac'
    waveform, sr = torchaudio.load(filename)
    print(waveform, sr)
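

def test_resample():
    # Hypothetical companion check (not in the original file): bring a clip to
    # the 16 kHz rate implied by config/16k_64.yaml before it enters the
    # pipeline, using torchaudio's functional resampler.
    import torchaudio
    waveform, sr = torchaudio.load('2.flac')
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
    print(waveform.shape, sr, '-> 16000')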


def test_flops():
    """Profile the model's forward pass with thop and report GFLOPs."""
    import torch
    from constants import build_model
    from thop import profile

    version = 'giant'
    model = build_model(version).cuda()
    # Dummy inputs matching the model's expected shapes.
    img_ids = torch.randn((1, 1024, 3)).cuda()
    txt = torch.randn((1, 256, 4096)).cuda()
    txt_ids = torch.randn((1, 256, 3)).cuda()
    y = torch.randn((1, 768)).cuda()
    x = torch.randn((1, 1024, 32)).cuda()
    t = torch.tensor([1]).cuda()
    # thop counts multiply-accumulate ops; x2 converts to FLOPs, /1000**3 to G.
    flops, _ = profile(model, inputs=(x, img_ids, txt, txt_ids, t, y))
    print('FLOPs = ' + str(flops * 2 / 1000 ** 3) + 'G')
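

def count_params(version='giant'):
    # Hypothetical companion to test_flops (not in the original file): report
    # the trainable parameter count of the same build_model(version) model.
    from constants import build_model
    model = build_model(version)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('Params = ' + str(n_params / 1000 ** 2) + 'M')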


if __name__ == '__main__':
    # test_music_format()
    # test_reconstruct()
    # mini_dataset()
    # fma_dataset()
    # audioset_dataset()
    # combine_dataset()
    test_flops()