|
import os |
|
import torch |
|
import h5py |
|
import random |
|
import numpy as np |
|
import soundfile as sf |
|
from models import DiT |
|
from diffusion import create_diffusion |
|
from tqdm import tqdm |
|
import sys |
|
sys.path.append('./tools/bigvgan_v2_22khz_80band_256x') |
|
from bigvgan import BigVGAN |
|
from torch import nn |
|
import torch.nn.functional as F |
|
import argparse |
|
|
|
device = 'cuda:1' if torch.cuda.is_available() else 'cpu' |
|
|
|
class MelToAudio_bigvgan(nn.Module): |
|
def __init__(self): |
|
super().__init__() |
|
self.vocoder = BigVGAN.from_pretrained('/home/zheqid/workspace/music_dit/bigvgan_v2_22khz_80band_256x', use_cuda_kernel=False) |
|
self.vocoder.remove_weight_norm() |
|
|
|
def __call__(self, z): |
|
x = self.mel_to_audio(z) |
|
return x |
|
|
|
def mel_to_audio(self, x): |
|
with torch.no_grad(): |
|
self.vocoder.eval() |
|
y = self.vocoder(x[:, :, :]) |
|
y = y.squeeze(0) |
|
return y |
|
|
|
vocoder = MelToAudio_bigvgan().to(device) |
|
|
|
def load_trained_model(checkpoint_path): |
|
model = DiT( |
|
input_size=(80, 800), |
|
patch_size=8, |
|
in_channels=1, |
|
hidden_size=384, |
|
depth=12, |
|
num_heads=6, |
|
) |
|
model.to(device) |
|
checkpoint = torch.load(checkpoint_path) |
|
model.load_state_dict(checkpoint['model_state_dict']) |
|
model.eval() |
|
return model |
|
|
|
def load_all_meta_and_mel_from_h5(h5_file): |
|
with h5py.File(h5_file, 'r') as f: |
|
keys = list(f.keys()) |
|
for key in keys: |
|
meta_latent = torch.FloatTensor(f[key]['meta'][:]).to(device) |
|
mel = torch.FloatTensor(f[key]['mel'][:]).to(device) |
|
yield key, meta_latent, mel |
|
|
|
def extract_random_mel_segment(mel, segment_length=800): |
|
total_length = mel.shape[2] |
|
if total_length > segment_length: |
|
start = np.random.randint(0, total_length - segment_length) |
|
mel_segment = mel[:, :, start:start + segment_length] |
|
else: |
|
padding = segment_length - total_length |
|
mel_segment = F.pad(mel, (0, padding), mode='constant', value=0) |
|
|
|
mel_segment = (mel_segment + 10) / 20 |
|
return mel_segment |
|
|
|
def infer_and_generate_audio(model, diffusion, meta_latent): |
|
latent_size = (80, 800) |
|
z = torch.randn(1, 1, latent_size[0], latent_size[1], device=device) |
|
model_kwargs = dict(y=meta_latent) |
|
|
|
with torch.no_grad(): |
|
samples = diffusion.p_sample_loop( |
|
model.forward, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device |
|
) |
|
|
|
return samples |
|
|
|
def save_audio(mel, vocoder, output_path, sample_rate=24000): |
|
with torch.no_grad(): |
|
if mel.dim() == 4 and mel.shape[1] == 1: |
|
mel = mel[0, 0, :, :] |
|
elif mel.dim() == 3 and mel.shape[0] == 1: |
|
mel = mel[0] |
|
else: |
|
raise ValueError(f"Unexpected mel shape: {mel.shape}") |
|
|
|
mel = mel.unsqueeze(0) |
|
wav = vocoder(mel * 20 - 10).cpu().numpy() |
|
|
|
sf.write(output_path, wav[0], samplerate=sample_rate) |
|
print(f"Saved audio to: {output_path}") |
|
|
|
def main(): |
|
parser = argparse.ArgumentParser(description='Generate audio using DiT and BigVGAN') |
|
parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint') |
|
parser.add_argument('--h5_file', type=str, required=True, help='Path to input H5 file') |
|
parser.add_argument('--output_gt_dir', type=str, required=True, help='Directory to save ground truth audio') |
|
parser.add_argument('--output_gen_dir', type=str, required=True, help='Directory to save generated audio') |
|
parser.add_argument('--segment_length', type=int, default=800, help='Segment length for mel slices (default: 800)') |
|
parser.add_argument('--sample_rate', type=int, default=22050, help='Sample rate for output audio (default: 24000)') |
|
args = parser.parse_args() |
|
|
|
model = load_trained_model(args.checkpoint) |
|
diffusion = create_diffusion(timestep_respacing="") |
|
|
|
for i, (key, meta_latent, mel) in enumerate(tqdm(load_all_meta_and_mel_from_h5(args.h5_file))): |
|
mel_segment = extract_random_mel_segment(mel, segment_length=args.segment_length) |
|
|
|
ground_truth_wav_path = os.path.join(args.output_gt_dir, f"{key}.wav") |
|
save_audio(mel_segment, vocoder, ground_truth_wav_path, sample_rate=args.sample_rate) |
|
|
|
generated_mel = infer_and_generate_audio(model, diffusion, meta_latent) |
|
|
|
output_wav_path = os.path.join(args.output_gen_dir, f"{key}.wav") |
|
save_audio(generated_mel, vocoder, output_wav_path, sample_rate=args.sample_rate) |
|
|
|
if __name__ == "__main__": |
|
main() |
|
|
|
|
|
''' |
|
python sample.py --checkpoint ./gtzan-ck/model_epoch_20000.pt \ |
|
--h5_file ./dataset/gtzan_test.h5 \ |
|
--output_gt_dir ./sample/gn \ |
|
--output_gen_dir ./sample/gt \ |
|
--segment_length 800 \ |
|
--sample_rate 22050 |
|
''' |
|
|