# Spaces: Running
# (Hosting-page status banner captured together with the file; it is not part
# of the script.)
import os | |
import json | |
def test_reconstuct():
    """Round-trip one dataset sample through the pretrained VAE and vocoder.

    Loads ``config/16k_64.yaml``, builds an ``AudioDataset`` from
    ``mini_dataset.json``, encodes the first sample's ``log_mel_spec`` with the
    VAE, decodes it again and feeds the result to the vocoder. The shapes of
    the intermediate tensors are printed; nothing is returned or written to
    disk.
    """
    import torch
    import yaml
    from diffusers import AutoencoderKL
    from transformers import SpeechT5HifiGan
    from audioldm2.utilities.data.dataset import AudioDataset
    from utils import load_clap, load_t5

    model_path = '/maindata/data/shared/multimodal/public/dataset_music/audioldm2'

    # Use a context manager so the config file handle is closed deterministically.
    # NOTE(review): FullLoader is fine for a trusted local config; do not point
    # this at untrusted YAML.
    with open('config/16k_64.yaml', 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    print(config)

    t5 = load_t5('cuda', max_length=256)
    clap = load_clap('cuda', max_length=256)
    dataset = AudioDataset(
        config=config, split="train", waveform_only=False, dataset_json_path='mini_dataset.json',
        tokenizer=clap.tokenizer,
        uncond_pro=0.1,
        text_ctx_len=77,
        tokenizer_t5=t5.tokenizer,
        text_ctx_len_t5=256,
        uncond_pro_t5=0.1,
    )

    # Fetch the sample once so the printed shape and the encoded tensor are
    # guaranteed to describe the same item (and the audio is processed once).
    # Two leading axes are added — presumably batch and channel; TODO confirm.
    mel_input = dataset[0]['log_mel_spec'].unsqueeze(0).unsqueeze(0)
    print(mel_input.size())

    vae = AutoencoderKL.from_pretrained(os.path.join(model_path, 'vae'))
    vocoder = SpeechT5HifiGan.from_pretrained(os.path.join(model_path, 'vocoder'))

    # Pure inference: no autograd graph is needed for the round trip.
    with torch.no_grad():
        # Scale the sampled latents, then undo the scaling before decoding.
        latents = vae.encode(mel_input).latent_dist.sample().mul_(vae.config.scaling_factor)
        print('laten size:', latents.size())
        latents = 1 / vae.config.scaling_factor * latents
        mel_spectrogram = vae.decode(latents).sample
        print(mel_spectrogram.size())

        # Drop axis 1 when the decoder output is 4-D before calling the vocoder.
        if mel_spectrogram.dim() == 4:
            mel_spectrogram = mel_spectrogram.squeeze(1)
        waveform = vocoder(mel_spectrogram)

    waveform = waveform[0].cpu().float().detach().numpy()
    print(waveform.shape)
    # To listen to the result, save `waveform` as a WAV file (the config name
    # suggests a 16 kHz sample rate — TODO confirm), e.g. with
    # soundfile.write('reconstruct.wav', waveform, samplerate=16000).
def mini_dataset(num=32):
    """Write ``mini_dataset.json`` containing ``num`` identical placeholder entries.

    Each entry maps ``'wav'`` to ``'case.mp3'`` and ``'label'`` to
    ``'a beautiful music'``. The file is created in the current working
    directory and overwritten if it already exists.
    """
    entries = [
        {'wav': 'case.mp3', 'label': 'a beautiful music'}
        for _ in range(num)
    ]
    with open('mini_dataset.json', 'w') as out_file:
        json.dump(entries, out_file, indent=4)
def fma_dataset():
    """Build ``fma_dataset.json`` from the FMA annotation parquet files.

    For every row of the test/train/valid annotation files, the audio path is
    formed as ``<dataset root>/<row['path']>.mp3``. Rows whose file exists on
    disk are recorded as ``{'wav': path, 'label': row['pseudo_caption']}``.
    The number of collected entries is printed before the JSON is written.
    """
    import pandas as pd

    annotation_root = "/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/annotation"
    audio_root = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/fma_large'
    parquet_names = [
        'test-00000-of-00001.parquet',
        'train-00000-of-00001.parquet',
        'valid-00000-of-00001.parquet',
    ]

    entries = []
    for parquet_name in parquet_names:
        frame = pd.read_parquet(os.path.join(annotation_root, parquet_name))
        print(frame.shape)
        for _, record in frame.iterrows():
            audio_path = os.path.join(audio_root, record['path'] + '.mp3')
            # Keep only annotations whose audio file is actually present.
            if not os.path.exists(audio_path):
                continue
            entries.append(
                {
                    'wav': audio_path,
                    'label': record['pseudo_caption'],
                }
            )

    print(len(entries))
    with open('fma_dataset.json', 'w') as out_file:
        json.dump(entries, out_file, indent=4)
def audioset_dataset():
    """Build ``audioset_dataset.json`` from the AudioSet balanced-train parquet.

    For every annotation row, the audio path is formed as
    ``<dataset root>/<row['path']>.flac``. Rows whose file exists on disk are
    recorded as ``{'wav': path, 'label': row['pseudo_caption']}``. Rows whose
    ``path`` cannot be concatenated with ``'.flac'`` (for example a missing
    value) are printed and skipped. The number of collected entries is printed
    before the JSON is written.
    """
    import pandas as pd

    dataset_prex = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/audioset'
    annotation_path = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/audioset/balanced_train-00000-of-00001.parquet'
    df = pd.read_parquet(annotation_path)
    print(df.shape)

    data = []
    for _, row in df.iterrows():
        try:
            tmp_path = os.path.join(dataset_prex, row['path'] + '.flac')
        except TypeError:
            # A non-string path cannot be joined. Skip the row: falling through
            # would either hit an undefined `tmp_path` or reuse the previous
            # row's path with this row's caption.
            print(row['path'])
            continue
        if os.path.exists(tmp_path):
            data.append(
                {
                    'wav': tmp_path,
                    'label': row['pseudo_caption'],
                }
            )

    print(len(data))
    with open('audioset_dataset.json', 'w') as f:
        json.dump(data, f, indent=4)
def combine_dataset():
    """Concatenate the FMA and AudioSet JSON lists into ``combine_dataset.json``.

    Reads ``fma_dataset.json`` and ``audioset_dataset.json`` from the current
    working directory (in that order), prints the total number of entries and
    writes the merged list with 4-space indentation.
    """
    merged = []
    for source_name in ('fma_dataset.json', 'audioset_dataset.json'):
        with open(source_name, 'r') as source_file:
            merged.extend(json.load(source_file))

    print(len(merged))
    with open('combine_dataset.json', 'w') as target_file:
        json.dump(merged, target_file, indent=4)
def test_music_format():
    """Load ``2.flac`` with torchaudio and print what ``torchaudio.load`` returns.

    The two returned values (the waveform and the sample rate) are printed
    as-is; nothing is returned.
    """
    import torchaudio

    audio, sample_rate = torchaudio.load('2.flac')
    print(audio, sample_rate)
def test_flops():
    """Profile the ``'giant'`` model with thop and print the result in G.

    Builds the model via ``constants.build_model``, moves it to CUDA, feeds
    random inputs for a batch of one and prints ``flops * 2 / 1000**3``.
    NOTE(review): the factor 2 presumably converts thop's multiply-accumulate
    count into FLOPs — confirm against the thop documentation.
    """
    import torch
    from constants import build_model
    from thop import profile

    version = 'giant'
    model = build_model(version).cuda()

    # Random dummy inputs; creation order is kept so the RNG stream is unchanged.
    img_ids = torch.randn((1, 1024, 3)).cuda()
    txt = torch.randn((1, 256, 4096)).cuda()
    txt_ids = torch.randn((1, 256, 3)).cuda()
    y = torch.randn((1, 768)).cuda()
    x = torch.randn((1, 1024, 32)).cuda()
    t = torch.tensor([1]).cuda()

    model_inputs = (x, img_ids, txt, txt_ids, t, y)
    flops, _ = profile(model, inputs=model_inputs)
    scaled = flops * 2 / 1000**3
    print('FLOPs = ' + str(scaled) + 'G')
# Manual entry points: uncomment the routine you want to run. The selected
# call executes whenever this module is run or imported.
# test_music_format()
# test_reconstuct()
# mini_dataset()
# fma_dataset()
# audioset_dataset()
# combine_dataset()
test_flops()