import torch
import torchaudio
import os
import re
import json
import random
import io

from huggingface_hub import hf_hub_download
from muq import MuQMuLan
from diffrhythm2.cfm import CFM
from diffrhythm2.backbones.dit import DiT
from bigvgan.model import Generator

STRUCT_INFO = {
    "[start]": 500,
    "[end]": 501,
    "[intro]": 502,
    "[verse]": 503,
    "[chorus]": 504,
    "[outro]": 505,
    "[inst]": 506,
    "[solo]": 507,
    "[bridge]": 508,
    "[hook]": 509,
    "[break]": 510,
    "[stop]": 511,
    "[space]": 512,
}

class CNENTokenizer:
    def __init__(self):
        curr_path = os.path.abspath(__file__)
        vocab_path = os.path.join(os.path.dirname(os.path.dirname(curr_path)), "g2p/g2p/vocab.json")
        with open(vocab_path, "r") as file:
            self.phone2id: dict = json.load(file)["vocab"]
        self.id2phone = {v: k for (k, v) in self.phone2id.items()}
        from g2p.g2p_generation import chn_eng_g2p
        self.tokenizer = chn_eng_g2p

    def encode(self, text):
        phone, token = self.tokenizer(text)
        token = [x + 1 for x in token]  # shift ids by 1; decode() reverses this
        return token

    def decode(self, token):
        return "|".join([self.id2phone[x - 1] for x in token])

def prepare_model(repo_id, device, dtype):
    diffrhythm2_ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename="model.safetensors",
        local_dir="./ckpt",
        local_files_only=False,
    )
    diffrhythm2_config_path = hf_hub_download(
        repo_id=repo_id,
        filename="model.json",
        local_dir="./ckpt",
        local_files_only=False,
    )
    with open(diffrhythm2_config_path) as f:
        model_config = json.load(f)
    model_config['use_flex_attn'] = False
    diffrhythm2 = CFM(
        transformer=DiT(**model_config),
        num_channels=model_config['mel_dim'],
        block_size=model_config['block_size'],
    )
    total_params = sum(p.numel() for p in diffrhythm2.parameters())
    diffrhythm2 = diffrhythm2.to(device).to(dtype)
    if diffrhythm2_ckpt_path.endswith('.safetensors'):
        from safetensors.torch import load_file
        ckpt = load_file(diffrhythm2_ckpt_path)
    else:
        ckpt = torch.load(diffrhythm2_ckpt_path, map_location='cpu')
    diffrhythm2.load_state_dict(ckpt)
    print(f"Total params: {total_params:,}")

    # load MuLan style encoder
    mulan = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large", cache_dir="./ckpt").to(device).to(dtype)

    # load lyric frontend
    lrc_tokenizer = CNENTokenizer()

    # load decoder (BigVGAN Generator)
    decoder_ckpt_path = hf_hub_download(
        repo_id=repo_id,
        filename="decoder.bin",
        local_dir="./ckpt",
        local_files_only=False,
    )
    decoder_config_path = hf_hub_download(
        repo_id=repo_id,
        filename="decoder.json",
        local_dir="./ckpt",
        local_files_only=False,
    )
    decoder = Generator(decoder_config_path, decoder_ckpt_path)
    decoder = decoder.to(device).to(dtype)

    return diffrhythm2, mulan, lrc_tokenizer, decoder

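# Usage sketch -- "your-org/DiffRhythm2" is a placeholder repo id, not
# necessarily the real checkpoint location:
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   diffrhythm2, mulan, lrc_tokenizer, decoder = prepare_model(
#       "your-org/DiffRhythm2", device, torch.float16
#   )
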
STRUCT_PATTERN = re.compile(r'^\[.*?\]$')

def parse_lyrics(lrc_tokenizer, lyrics: str):
    lyrics_with_time = []
    lyrics = lyrics.split("\n")
    get_start = False
    for line in lyrics:
        line = line.strip()
        if not line:
            continue
        struct_flag = STRUCT_PATTERN.match(line)
        if struct_flag:
            struct_idx = STRUCT_INFO.get(line.lower(), None)
            if struct_idx is None:
                continue  # skip unknown structure tags
            if struct_idx == STRUCT_INFO['[start]']:
                get_start = True
            lyrics_with_time.append([struct_idx, STRUCT_INFO['[stop]']])
        else:
            tokens = lrc_tokenizer.encode(line)
            tokens = tokens + [STRUCT_INFO['[stop]']]
            lyrics_with_time.append(tokens)
    # prepend an explicit [start] marker if the lyrics did not provide one
    if len(lyrics_with_time) != 0 and not get_start:
        lyrics_with_time = [[STRUCT_INFO['[start]'], STRUCT_INFO['[stop]']]] + lyrics_with_time
    return lyrics_with_time

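# Example (derived from the mapping above): structure tags become their
# STRUCT_INFO ids, plain lines are phonemized, and every entry is closed
# with the [stop] id:
#
#   parse_lyrics(lrc_tokenizer, "[start]\n[verse]\nhello world")
#   # -> [[500, 511], [503, 511], [<phoneme ids>, 511]]
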
def get_audio_prompt(model, audio_file, device, dtype):
    prompt_wav, sr = torchaudio.load(audio_file)
    prompt_wav = torchaudio.functional.resample(prompt_wav.to(device).to(dtype), sr, 24000)
    # use at most a random 10-second window of the reference audio
    if prompt_wav.shape[1] > 24000 * 10:
        start = random.randint(0, prompt_wav.shape[1] - 24000 * 10)
        prompt_wav = prompt_wav[:, start:start + 24000 * 10]
    prompt_wav = prompt_wav.mean(dim=0, keepdim=True)  # downmix to mono
    with torch.no_grad():
        style_prompt_embed = model(wavs=prompt_wav)
    return style_prompt_embed.squeeze(0).detach()

def get_text_prompt(model, text, device, dtype):
    with torch.no_grad():
        style_prompt_embed = model(texts=[text])
    return style_prompt_embed.squeeze(0).detach()

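# Either helper yields a MuLan style embedding, so a text description and a
# reference recording are interchangeable as conditioning. Sketch (the text
# and file name are illustrative):
#
#   style = get_text_prompt(mulan, "upbeat pop, bright synths", device, dtype)
#   # or: style = get_audio_prompt(mulan, "reference.wav", device, dtype)
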
def make_fake_stereo(audio, sampling_rate):
    # duplicate the mono channel, attenuating and delaying the copy by 10 ms
    left_channel = audio
    right_channel = audio.clone() * 0.8
    delay_samples = int(0.01 * sampling_rate)
    right_channel = torch.roll(right_channel, delay_samples, dims=-1)
    right_channel[:, :delay_samples] = 0
    stereo_audio = torch.cat([left_channel, right_channel], dim=0)  # [2, time]
    return stereo_audio

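# Sanity-check sketch: a [1, time] mono tensor becomes [2, time], with the
# right channel attenuated to 0.8x and given 10 ms of leading silence:
#
#   stereo = make_fake_stereo(torch.randn(1, 24000), 24000)
#   assert stereo.shape == (2, 24000)
#   assert torch.all(stereo[1, :240] == 0)  # int(0.01 * 24000) == 240
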
def inference(
    model,
    decoder,
    text,
    style_prompt,
    duration,
    cfg_strength=1.0,
    sample_steps=32,
    fake_stereo=True,
    odeint_method='euler',
    file_type="wav",
):
    with torch.inference_mode():
        latent = model.sample_block_cache(
            text=text.unsqueeze(0),
            duration=int(duration * 5),  # scale duration to latent frames (model-specific factor of 5)
            style_prompt=style_prompt.unsqueeze(0),
            steps=sample_steps,
            cfg_strength=cfg_strength,
            odeint_method=odeint_method,
        )
        latent = latent.transpose(1, 2).detach()
        audio = decoder.decode_audio(latent, overlap=5, chunk_size=20).detach()

        audio = audio.float().cpu().detach().squeeze()[None, :]  # [1, time]
        if fake_stereo:
            audio = make_fake_stereo(audio, decoder.h.sampling_rate)

        if file_type == 'wav':
            return (decoder.h.sampling_rate, audio.numpy().T)  # [channel, time] -> [time, channel]
        else:
            buffer = io.BytesIO()
            torchaudio.save(buffer, audio, decoder.h.sampling_rate, format=file_type)
            return buffer.getvalue()

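# End-to-end sketch using the objects from prepare_model. How the per-line
# token lists are flattened into the 1-D `text` tensor is an assumption here;
# the prompt text and duration are illustrative:
#
#   token_lists = parse_lyrics(lrc_tokenizer, lyrics_str)
#   text = torch.tensor([t for line in token_lists for t in line], device=device)
#   style = get_text_prompt(mulan, "slow acoustic ballad", device, dtype)
#   sr, wav = inference(diffrhythm2, decoder, text, style, duration=95)
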
def inference_stream(
    model,
    decoder,
    text,
    style_prompt,
    duration,
    cfg_strength=1.0,
    sample_steps=32,
    fake_stereo=True,
    odeint_method='euler',
    file_type="wav",
):
    with torch.inference_mode():
        for audio in model.sample_cache_stream(
            decoder=decoder,
            text=text.unsqueeze(0),
            duration=int(duration * 5),  # scale duration to latent frames, as in inference()
            style_prompt=style_prompt.unsqueeze(0),
            steps=sample_steps,
            cfg_strength=cfg_strength,
            chunk_size=20,
            overlap=5,
            odeint_method=odeint_method,
        ):
            # keep the chunk as a tensor until after make_fake_stereo,
            # which relies on torch ops (.clone(), torch.roll)
            audio = audio.float().cpu().squeeze()[None, :]
            if fake_stereo:
                audio = make_fake_stereo(audio, decoder.h.sampling_rate)
            yield (decoder.h.sampling_rate, audio.numpy().T)  # [time, channel]
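
# Streaming sketch: each yielded item is (sampling_rate, [time, channel]),
# ready to forward from a generator callback (e.g. a streaming gr.Audio
# output in a Gradio app -- an assumption about the surrounding UI):
#
#   for sr, chunk in inference_stream(diffrhythm2, decoder, text, style, 95):
#       yield sr, chunk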