Spaces:
Runtime error
Runtime error
| import sys | |
| import os | |
| import time | |
| import json | |
| import torch | |
| import torchaudio | |
| import numpy as np | |
| from omegaconf import OmegaConf | |
| from codeclm.trainer.codec_song_pl import CodecLM_PL | |
| from codeclm.models import CodecLM | |
| from third_party.demucs.models.pretrained import get_model_from_yaml | |
class Separator:
    """Split a prompt song into full-mix, vocal and accompaniment (BGM) clips.

    Wraps a Demucs source-separation model built by ``get_model_from_yaml``.
    Every clip returned by ``run`` is resampled to ``SAMPLE_RATE`` and
    loop-padded / truncated to ``CLIP_SECONDS`` seconds (see ``load_audio``).
    """

    # Target sample rate (Hz) and clip length (s) produced by ``load_audio``.
    SAMPLE_RATE = 48000
    CLIP_SECONDS = 10
    # NOTE(review): assumes the Demucs vocal stem is named "vocals", as in the
    # upstream htdemucs ``sources`` list — confirm against the vendored model.
    VOCAL_STEM = "vocals"

    def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
        """Choose a device and load the Demucs model.

        Args:
            dm_model_path: checkpoint path passed to ``get_model_from_yaml``.
            dm_config_path: YAML config path passed to ``get_model_from_yaml``.
            gpu_id: CUDA device index; falls back to CPU when CUDA is not
                available or the index is out of range.
        """
        if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
            self.device = torch.device(f"cuda:{gpu_id}")
        else:
            self.device = torch.device("cpu")
        self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)

    def init_demucs_model(self, model_path, config_path):
        """Build the Demucs model, move it to ``self.device`` and set eval mode."""
        model = get_model_from_yaml(config_path, model_path)
        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _fit_length(a, target_len):
        """Loop-pad or truncate ``a`` along its last axis to ``target_len`` samples.

        Short inputs are tiled as many times as needed (the original code only
        doubled them once, leaving clips under half the target too short).
        An empty input is returned unchanged since it cannot be tiled.
        """
        n = a.shape[-1]
        if 0 < n < target_len:
            repeats = -(-target_len // n)  # ceil(target_len / n)
            a = a.repeat(*([1] * (a.dim() - 1)), repeats)
        return a[..., :target_len]

    def load_audio(self, f):
        """Load ``f``, resample to ``SAMPLE_RATE`` and fit it to ``CLIP_SECONDS``.

        Returns the tensor produced by ``torchaudio.load`` (channels first),
        loop-padded or truncated to exactly ``SAMPLE_RATE * CLIP_SECONDS``
        samples (unless the file is empty).
        """
        a, fs = torchaudio.load(f)
        if fs != self.SAMPLE_RATE:
            a = torchaudio.functional.resample(a, fs, self.SAMPLE_RATE)
        return self._fit_length(a, self.SAMPLE_RATE * self.CLIP_SECONDS)

    def run(self, audio_path, output_dir='tmp', ext=".flac"):
        """Separate ``audio_path`` and return ``(full, vocal, bgm)`` clips.

        The vocal stem is cached in ``output_dir`` as
        ``<name>_<VOCAL_STEM><ext>``. If that file already exists, separation
        is skipped. Otherwise ``demucs_model.separate`` is called, and the
        drums/bass/other stems it returns are deleted so only vocals remain.

        Returns:
            ``(full_audio, vocal_audio, bgm_audio)``, each loaded through
            ``load_audio``; ``bgm_audio`` is ``full_audio - vocal_audio``.
        """
        os.makedirs(output_dir, exist_ok=True)
        name = os.path.splitext(os.path.basename(audio_path))[0]
        vocal_path = os.path.join(output_dir, f"{name}_{self.VOCAL_STEM}{ext}")
        # Look for the vocals file itself: the old "exactly one stem exists"
        # heuristic mistook any lone leftover stem for vocals.
        if not os.path.exists(vocal_path):
            drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(
                audio_path, output_dir, device=self.device)
            for path in (drums_path, bass_path, other_path):
                os.remove(path)
        full_audio = self.load_audio(audio_path)
        vocal_audio = self.load_audio(vocal_path)
        bgm_audio = full_audio - vocal_audio
        return full_audio, vocal_audio, bgm_audio
def _read_jsonl(path):
    """Return the JSON objects in ``path`` (UTF-8), one per non-blank line."""
    with open(path, "r", encoding="utf-8") as fp:
        return [json.loads(line) for line in fp if line.strip()]


def main_sep():
    """Generate one song per JSONL item, conditioned on a separated prompt audio.

    Usage: ``python <script> <config.yaml> <save_dir> <input.jsonl> <sidx>``

    For each item (fields read: ``idx``, ``descriptions``, ``gt_lyric``,
    ``prompt_audio_path``), the prompt is split into full/vocal/BGM clips by
    ``Separator``. Tokens are then generated with ``CodecLM`` and decoded to
    audio. Outputs:
      - ``<save_dir>/token/<idx>_s<sidx>.npy``: generated tokens
      - ``<save_dir>/audios/<idx>_s<sidx>.flac``: decoded audio
      - ``<save_dir>/jsonl/<input file name>-s<sidx>.jsonl``: the items with
        ``idx`` suffixed by ``_s<sidx>`` and ``tk_path`` added.
    """
    # cuDNN is disabled because some flaky nodes on the Taiji cluster raise
    # obscure errors with it enabled.
    torch.backends.cudnn.enabled = False
    # SECURITY NOTE: the "eval" resolver executes arbitrary Python taken from
    # the config file; only load trusted configs.
    OmegaConf.register_new_resolver("eval", lambda x: eval(x))
    OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
    OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
    OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))

    cfg = OmegaConf.load(sys.argv[1])
    save_dir = sys.argv[2]
    input_jsonl = sys.argv[3]
    sidx = sys.argv[4]
    cfg.mode = 'inference'
    max_duration = cfg.max_dur

    # Build the Lightning wrapper, then expose its parts through CodecLM.
    model_light = CodecLM_PL(cfg)
    model_light = model_light.eval().cuda()
    model_light.audiolm.cfg = cfg
    model = CodecLM(name="tmp",
                    lm=model_light.audiolm,
                    audiotokenizer=model_light.audio_tokenizer,
                    max_duration=max_duration,
                    seperate_tokenizer=model_light.seperate_tokenizer,
                    )
    separator = Separator()

    cfg_coef = 1.5
    temp = 1.0
    top_k = 50
    top_p = 0.0
    record_tokens = True
    record_window = 50
    model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
                                top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
    # Generated sequences longer than this are truncated before decoding.
    max_tokens = 3000

    os.makedirs(save_dir + "/token", exist_ok=True)
    os.makedirs(save_dir + "/audios", exist_ok=True)
    os.makedirs(save_dir + "/jsonl", exist_ok=True)

    new_items = []
    for item in _read_jsonl(input_jsonl):
        target_name = f"{save_dir}/token/{item['idx']}_s{sidx}.npy"
        target_wav_name = f"{save_dir}/audios/{item['idx']}_s{sidx}.flac"
        descriptions = item["descriptions"]
        lyric = item["gt_lyric"]

        start_time = time.time()
        pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
        generate_inp = {
            # NOTE(review): as shown this replaces a space with a space; the
            # original likely targeted a different whitespace character that
            # was lost in transcription — confirm against the upstream source.
            'lyrics': [lyric.replace(" ", " ")],
            'descriptions': [descriptions],
            'melody_wavs': pmt_wav,
            'vocal_wavs': vocal_wav,
            'bgm_wavs': bgm_wav,
        }
        mid_time = time.time()
        # Inference only: no autograd graph is needed for token generation.
        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
            tokens = model.generate(**generate_inp, return_tokens=True)
        end_time = time.time()

        if tokens.shape[-1] > max_tokens:
            tokens = tokens[..., :max_tokens]
        with torch.no_grad():
            wav_seperate = model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav)
        torchaudio.save(target_wav_name, wav_seperate[0].cpu().float(), cfg.sample_rate)
        np.save(target_name, tokens.cpu().squeeze(0).numpy())
        print(f"process{item['idx']}, demucs cost {mid_time - start_time}s, lm cost {end_time - mid_time}s")

        item["idx"] = f"{item['idx']}_s{sidx}"
        item["tk_path"] = target_name
        new_items.append(item)

    src_jsonl_name = os.path.split(input_jsonl)[-1]
    with open(f"{save_dir}/jsonl/{src_jsonl_name}-s{sidx}.jsonl", "w", encoding='utf-8') as fw:
        for item in new_items:
            fw.write(json.dumps(item, ensure_ascii=False) + "\n")
if "__main__" == __name__:
    # Only kick off batch generation when run as a script, never on import.
    main_sep()