import os
import argparse

import torch
from einops import rearrange, repeat
from diffusers import AutoencoderKL
from transformers import SpeechT5HifiGan
from scipy.io import wavfile

from utils import load_t5, load_clap
from train import RF
from constants import build_model


def prepare(t5, clip, img, prompt):
    """Patchify the noise latent and encode the prompts into model conditioning."""
    bs, c, h, w = img.shape
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)

    # Fold 2x2 latent patches into the sequence dimension.
    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    # Positional ids for every latent patch: row index in channel 1, column index in channel 2.
    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    if isinstance(prompt, str):
        prompt = [prompt]

    # Per-token text features from T5; pooled conditioning vector from CLAP.
    txt = t5(prompt)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = clip(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    print(img_ids.size(), txt.size(), vec.size())
    return img, {
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "y": vec.to(img.device),
    }


def main(args):
    print('generate with MusicFlux')
    torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Latent mel-spectrogram resolution (time frames x frequency bins in latent space).
    latent_size = (256, 16)

    # Build the transformer and load the EMA weights from a training checkpoint.
    model = build_model(args.version).to(device)
    local_path = '/maindata/data/shared/multimodal/zhengcong.fei/code/music-flow/results/base/checkpoints/0050000.pt'
    state_dict = torch.load(local_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict['ema'])
    model.eval()  # important!
    diffusion = RF()

    model_path = '/maindata/data/shared/multimodal/public/ckpts/FLUX.1-dev'  # unused here; T5/CLAP are loaded via utils

    # Text encoders: T5 for token features, CLAP for the pooled conditioning vector.
    t5 = load_t5(device, max_length=256)
    clap = load_clap(device, max_length=256)

    # AudioLDM2 VAE (latent <-> mel spectrogram) and HiFi-GAN vocoder (mel -> waveform).
    model_path = '/maindata/data/shared/multimodal/public/dataset_music/audioldm2'
    vae = AutoencoderKL.from_pretrained(os.path.join(model_path, 'vae')).to(device)
    vocoder = SpeechT5HifiGan.from_pretrained(os.path.join(model_path, 'vocoder')).to(device)

    # One prompt per line; the same negative prompt is reused for every sample.
    with open(args.prompt_file, 'r') as f:
        conds_txt = f.readlines()
    L = len(conds_txt)
    unconds_txt = ["low quality, gentle"] * L
    print(L, conds_txt, unconds_txt)

    init_noise = torch.randn(L, 8, latent_size[0], latent_size[1]).to(device)

    STEPSIZE = 50  # number of sampling steps
    img, conds = prepare(t5, clap, init_noise, conds_txt)
    _, unconds = prepare(t5, clap, init_noise, unconds_txt)
    with torch.autocast(device_type='cuda'):
        images = diffusion.sample_with_xps(
            model, img, conds=conds, null_cond=unconds, sample_steps=STEPSIZE, cfg=7.0
        )
    print(images[-1].size())

    # Unpatchify the final step back to the latent layout: (256, 16) = (128 * 2, 8 * 2).
    images = rearrange(
        images[-1],
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=128, w=8, ph=2, pw=2,
    )

    # Rescale and decode the latents into mel spectrograms.
    latents = 1 / vae.config.scaling_factor * images
    mel_spectrogram = vae.decode(latents).sample
    print(mel_spectrogram.size())

    # Vocode each mel spectrogram and write it out as a 16 kHz wav file.
    os.makedirs('wav', exist_ok=True)
    for i in range(L):
        x_i = mel_spectrogram[i]
        if x_i.dim() == 4:
            x_i = x_i.squeeze(1)
        waveform = vocoder(x_i)
        waveform = waveform[0].cpu().float().detach().numpy()
        print(waveform.shape)
        # import soundfile as sf
        # sf.write('reconstruct.wav', waveform, samplerate=16000)
        wavfile.write('wav/sample_' + str(i) + '.wav', 16000, waveform)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--version", type=str, default="base")
    parser.add_argument("--prompt_file", type=str, default='config/example.txt')
    parser.add_argument("--seed", type=int, default=2024)
    args = parser.parse_args()
    main(args)
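
# Example invocation (hypothetical: the script filename is assumed, and the
# hard-coded checkpoint and AudioLDM2 paths above must exist on your machine):
#
#   python sample.py --version base --prompt_file config/example.txt --seed 2024
#
# The prompt file is read with readlines(), so it should contain one text prompt
# per line; each prompt produces a 16 kHz file wav/sample_<i>.wav.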