ASLP-lab's picture
Update diffrhythm2/utils.py
beb10c1 verified
raw
history blame
7.54 kB
import torch
import torchaudio
import os
import re
import json
import random
import io
from huggingface_hub import hf_hub_download
from muq import MuQMuLan
from diffrhythm2.cfm import CFM
from diffrhythm2.backbones.dit import DiT
from bigvgan.model import Generator
# Token ids for lyric-structure tags. parse_lyrics emits these ids next to the
# phoneme token ids produced by the lyrics tokenizer. Ids are consecutive,
# starting at 500, in the order the tags are listed here. Keys are the
# lowercase, bracketed tag strings matched against lyric lines.
STRUCT_INFO = dict(zip(
    (
        "[start]", "[end]", "[intro]", "[verse]", "[chorus]", "[outro]",
        "[inst]", "[solo]", "[bridge]", "[hook]", "[break]", "[stop]",
        "[space]",
    ),
    range(500, 513),
))
class CNENTokenizer:
    """Chinese/English grapheme-to-phoneme tokenizer for lyric lines.

    Wraps ``g2p.g2p_generation.chn_eng_g2p``. Token ids returned by
    :meth:`encode` are the g2p ids shifted by +1, and :meth:`decode` undoes
    that shift. The shift presumably keeps id 0 free (e.g. for padding);
    TODO confirm.
    """

    def __init__(self):
        curr_path = os.path.abspath(__file__)
        # vocab.json is resolved relative to the parent of this file's directory.
        vocab_path = os.path.join(
            os.path.dirname(os.path.dirname(curr_path)), "g2p/g2p/vocab.json"
        )
        # Explicit UTF-8: JSON is UTF-8 by spec, and the platform default
        # encoding (locale-dependent, e.g. cp1252 on Windows) may fail on any
        # non-ASCII phone symbols in the vocabulary.
        with open(vocab_path, 'r', encoding='utf-8') as file:
            self.phone2id: dict = json.load(file)['vocab']
        self.id2phone = {v: k for (k, v) in self.phone2id.items()}
        # Imported inside __init__, so g2p is only loaded once a tokenizer is built.
        from g2p.g2p_generation import chn_eng_g2p
        self.tokenizer = chn_eng_g2p

    def encode(self, text):
        """Return the phoneme token ids for ``text``, each shifted by +1."""
        _, token = self.tokenizer(text)
        return [x + 1 for x in token]

    def decode(self, token):
        """Map shifted token ids back to phone symbols joined with ``"|"``.

        Raises:
            KeyError: if an id (after removing the +1 shift) is not in the
                vocabulary.
        """
        return "|".join([self.id2phone[x - 1] for x in token])
def _download_ckpt_file(repo_id, filename):
    """Download ``filename`` from Hugging Face repo ``repo_id`` into ./ckpt and return its local path."""
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        local_dir="./ckpt",
        local_files_only=False,
    )


def prepare_model(repo_id, device, dtype):
    """Download and build every component needed for DiffRhythm2 inference.

    Args:
        repo_id: Hugging Face repo containing ``model.safetensors``,
            ``model.json``, ``decoder.bin`` and ``decoder.json``.
        device: device the CFM model, MuQ-MuLan and decoder are moved to.
        dtype: dtype the CFM model, MuQ-MuLan and decoder are cast to.

    Returns:
        ``(diffrhythm2, mulan, lrc_tokenizer, decoder)``: the CFM/DiT model,
        the ``OpenMuQ/MuQ-MuLan-large`` model, a :class:`CNENTokenizer`, and
        the ``bigvgan`` ``Generator``. The three torch modules are returned
        in eval mode, because this function prepares them for inference.
    """
    diffrhythm2_ckpt_path = _download_ckpt_file(repo_id, "model.safetensors")
    diffrhythm2_config_path = _download_ckpt_file(repo_id, "model.json")
    with open(diffrhythm2_config_path, encoding="utf-8") as f:
        model_config = json.load(f)
    # Always build the DiT without flex attention, whatever the config says.
    model_config['use_flex_attn'] = False
    diffrhythm2 = CFM(
        transformer=DiT(
            **model_config
        ),
        num_channels=model_config['mel_dim'],
        block_size=model_config['block_size'],
    )
    total_params = sum(p.numel() for p in diffrhythm2.parameters())
    diffrhythm2 = diffrhythm2.to(device).to(dtype)
    if diffrhythm2_ckpt_path.endswith('.safetensors'):
        from safetensors.torch import load_file
        ckpt = load_file(diffrhythm2_ckpt_path)
    else:
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # trusted checkpoints (consider weights_only=True where supported).
        ckpt = torch.load(diffrhythm2_ckpt_path, map_location='cpu')
    diffrhythm2.load_state_dict(ckpt)
    # Inference only: disable dropout/batch-norm training behaviour.
    diffrhythm2.eval()
    print(f"Total params: {total_params:,}")

    # load Mulan
    mulan = MuQMuLan.from_pretrained(
        "OpenMuQ/MuQ-MuLan-large", cache_dir="./ckpt"
    ).to(device).to(dtype)
    mulan.eval()

    # load frontend
    lrc_tokenizer = CNENTokenizer()

    # load decoder
    decoder_ckpt_path = _download_ckpt_file(repo_id, "decoder.bin")
    decoder_config_path = _download_ckpt_file(repo_id, "decoder.json")
    decoder = Generator(decoder_config_path, decoder_ckpt_path)
    decoder = decoder.to(device).to(dtype)
    decoder.eval()

    return diffrhythm2, mulan, lrc_tokenizer, decoder
# A whole line wrapped in square brackets, e.g. "[verse]".
STRUCT_PATTERN = re.compile(r'^\[.*?\]$')


def parse_lyrics(lrc_tokenizer, lyrics: str):
    """Convert raw lyrics text into a list of token-id segments.

    Each non-empty line (after stripping whitespace) yields at most one segment:

    * a known structure tag such as ``[verse]`` (case-insensitive) becomes
      ``[tag_id, stop_id]``;
    * a bracketed line that is not a known tag is skipped;
    * any other line becomes ``lrc_tokenizer.encode(line)`` followed by
      ``stop_id``.

    If at least one segment was produced and no ``[start]`` tag was seen, a
    ``[start_id, stop_id]`` segment is prepended.

    Args:
        lrc_tokenizer: object whose ``encode(str)`` returns a list of ints
            (e.g. ``CNENTokenizer``).
        lyrics: newline-separated lyrics text.

    Returns:
        A list of int lists; ``[]`` when ``lyrics`` has no non-empty line.
    """
    start_id = STRUCT_INFO['[start]']
    stop_id = STRUCT_INFO['[stop]']
    segments = []
    saw_start = False
    for raw_line in lyrics.split("\n"):
        text = raw_line.strip()
        if not text:
            continue
        if not STRUCT_PATTERN.match(text):
            # Plain lyric line: phoneme tokens terminated by the [stop] id.
            segments.append(lrc_tokenizer.encode(text) + [stop_id])
            continue
        tag_id = STRUCT_INFO.get(text.lower(), None)
        if tag_id is None:
            # Bracketed but unknown tag: dropped.
            continue
        if tag_id == start_id:
            saw_start = True
        segments.append([tag_id, stop_id])
    if segments and not saw_start:
        segments.insert(0, [start_id, stop_id])
    return segments
@torch.no_grad()
def get_audio_prompt(model, audio_file, device, dtype):
    """Embed a reference audio clip as a style prompt.

    Processing steps:

    1. Load the clip and mix it down to mono.
    2. Resample it to 24 kHz.
    3. If it is longer than 10 s, randomly crop it to a 10 s window.
    4. Pass it to ``model`` as ``wavs``.

    Args:
        model: callable accepting ``wavs=``. Its output's leading dimension
            is squeezed, so it is assumed to be a batch of size 1.
        audio_file: anything ``torchaudio.load`` accepts (path or file-like).
        device: device the waveform is moved to before resampling.
        dtype: dtype of the waveform handed to ``model``.

    Returns:
        The detached embedding with dimension 0 squeezed.
    """
    prompt_wav, sr = torchaudio.load(audio_file)
    # Mix down first (one channel is cheaper to resample), then resample in
    # float32. torchaudio builds the resampling kernel in the waveform's
    # dtype, so resampling directly in a half-precision ``dtype`` would lose
    # accuracy. The cast to ``dtype`` happens only at the model boundary.
    prompt_wav = prompt_wav.to(device=device, dtype=torch.float32)
    prompt_wav = prompt_wav.mean(dim=0, keepdim=True)
    prompt_wav = torchaudio.functional.resample(prompt_wav, sr, 24000)
    max_samples = 24000 * 10
    if prompt_wav.shape[1] > max_samples:
        start = random.randint(0, prompt_wav.shape[1] - max_samples)
        prompt_wav = prompt_wav[:, start:start + max_samples]
    style_prompt_embed = model(wavs=prompt_wav.to(dtype))
    return style_prompt_embed.squeeze(0).detach()
@torch.no_grad()
def get_text_prompt(model, text, device, dtype):
    """Embed a text description as a style prompt.

    ``model`` is called with a single-item batch ``texts=[text]``, and
    dimension 0 of its output is squeezed away. ``device`` and ``dtype`` are
    accepted for signature parity with ``get_audio_prompt`` and are unused.
    Gradients are disabled by the decorator.
    """
    embedding = model(texts=[text])
    squeezed = embedding.squeeze(0)
    return squeezed.detach()
@torch.no_grad()
def make_fake_stereo(audio, sampling_rate):
    """Build a pseudo-stereo signal from mono audio.

    The left channel is the input unchanged. The right channel is the input
    scaled by 0.8 and delayed by 10 ms; its first 10 ms are zeroed rather
    than wrapped around.

    Args:
        audio: 2-D array shaped ``(channels, time)``, normally ``(1, time)``.
            Either a ``torch.Tensor`` or anything ``torch.as_tensor`` accepts,
            such as a NumPy array.
        sampling_rate: sample rate in Hz, used to convert 10 ms to samples.

    Returns:
        Array shaped ``(2 * channels, time)``: the input rows followed by the
        delayed rows. It is a ``torch.Tensor`` if ``audio`` was one, and a
        NumPy array otherwise.
    """
    # Accept NumPy input (inference_stream used to pass it), which has no .clone().
    return_numpy = not isinstance(audio, torch.Tensor)
    audio = torch.as_tensor(audio)
    left_channel = audio
    delay_samples = int(0.01 * sampling_rate)
    # Roll along time only. Without `dims`, torch.roll flattens the tensor
    # and would shift samples across rows when there is more than one row.
    right_channel = torch.roll(audio * 0.8, delay_samples, dims=-1)
    right_channel[:, :delay_samples] = 0
    stereo_audio = torch.cat([left_channel, right_channel], dim=0)
    return stereo_audio.numpy() if return_numpy else stereo_audio
def inference(
    model,
    decoder,
    text,
    style_prompt,
    duration,
    cfg_strength=1.0,
    sample_steps=32,
    fake_stereo=True,
    odeint_method='euler',
    file_type="wav"
):
    """Generate a complete song in one pass.

    ``model.sample_block_cache`` produces a latent, and ``decoder.decode_audio``
    turns it into a waveform. Everything runs under ``torch.inference_mode()``.

    Args:
        model: CFM model exposing ``sample_block_cache``.
        decoder: vocoder exposing ``decode_audio`` and ``h.sampling_rate``.
        text: 1-D token tensor; a batch dimension is added here.
        style_prompt: style embedding tensor; a batch dimension is added here.
        duration: target length in seconds. It is multiplied by 5 to get the
            latent length (presumably 5 latent frames per second; confirm).
        cfg_strength, sample_steps, odeint_method: forwarded to the sampler.
        fake_stereo: if True, duplicate the mono output into pseudo-stereo
            via ``make_fake_stereo``.
        file_type: ``"wav"`` returns raw samples. Any other value is used as
            the ``torchaudio.save`` format and encoded bytes are returned.

    Returns:
        For ``"wav"``: ``(sampling_rate, ndarray)`` with the array shaped
        ``[time, channel]``. Otherwise: the encoded audio as ``bytes``.
    """
    sampling_rate = decoder.h.sampling_rate
    with torch.inference_mode():
        latent = model.sample_block_cache(
            text=text.unsqueeze(0),
            duration=int(duration * 5),
            style_prompt=style_prompt.unsqueeze(0),
            steps=sample_steps,
            cfg_strength=cfg_strength,
            odeint_method=odeint_method
        )
        latent = latent.transpose(1, 2).detach()
        audio = decoder.decode_audio(latent, overlap=5, chunk_size=20).detach()
        # Collapse to a single mono row: [1, time].
        audio = audio.float().cpu().detach().squeeze()[None, :]
        if fake_stereo:
            audio = make_fake_stereo(audio, sampling_rate)
        if file_type == 'wav':
            # audio is [channel, time]; transpose to [time, channel].
            return (sampling_rate, audio.numpy().T)
        buffer = io.BytesIO()
        torchaudio.save(buffer, audio, sampling_rate, format=file_type)
        return buffer.getvalue()
def inference_stream(
    model,
    decoder,
    text,
    style_prompt,
    duration,
    cfg_strength=1.0,
    sample_steps=32,
    fake_stereo=True,
    odeint_method='euler',
    file_type="wav"
):
    """Generate a song incrementally, yielding decoded audio chunk by chunk.

    Args:
        model: CFM model exposing ``sample_cache_stream``, which yields
            decoded audio tensors.
        decoder: vocoder exposing ``h.sampling_rate``; it is also passed to
            the model's stream.
        text: 1-D token tensor; a batch dimension is added here.
        style_prompt: style embedding tensor; a batch dimension is added here.
        duration: target length in seconds. It is multiplied by 5 to get the
            latent length (presumably 5 latent frames per second; confirm).
        cfg_strength, sample_steps, odeint_method: forwarded to the sampler.
        fake_stereo: if True, each chunk is turned into pseudo-stereo via
            ``make_fake_stereo``.
        file_type: currently ignored; chunks are always raw samples.

    Yields:
        ``(sampling_rate, ndarray)`` per chunk, with the array shaped
        ``[time, channel]``.
    """
    sampling_rate = decoder.h.sampling_rate
    # NOTE(review): the inference_mode context stays active while this
    # generator is suspended at `yield`, so consumer code between chunks
    # also runs in inference mode; confirm this is acceptable.
    with torch.inference_mode():
        for audio in model.sample_cache_stream(
            decoder=decoder,
            text=text.unsqueeze(0),
            duration=int(duration * 5),
            style_prompt=style_prompt.unsqueeze(0),
            steps=sample_steps,
            cfg_strength=cfg_strength,
            chunk_size=20,
            overlap=5,
            odeint_method=odeint_method
        ):
            # Keep the chunk as a tensor until after the stereo step.
            # Converting to NumPy first made make_fake_stereo's .clone() raise.
            audio = audio.float().cpu().squeeze()[None, :]
            if fake_stereo:
                audio = make_fake_stereo(audio, sampling_rate)
            # audio is [channel, time]; transpose to [time, channel].
            yield (sampling_rate, audio.numpy().T)