Spaces: Running on L40S

root committed · Commit f9e2d84 · 1 Parent(s): 410c1c2
update v1.5-beta

This view is limited to 50 files because it contains too many changes. See raw diff.
- app.py +5 -7
- codeclm/models/builders.py +5 -4
- codeclm/modules/conditioners.py +13 -3
- codeclm/tokenizer/Flow1dVAE/cal_token_stat.py +0 -19
- codeclm/tokenizer/Flow1dVAE/compare_model_weight.py +0 -13
- codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_and_sep_npy.py +0 -121
- codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_sep.py +0 -94
- codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x2.py +0 -70
- codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4.py +0 -46
- codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4_ds.py +0 -86
- codeclm/tokenizer/Flow1dVAE/generate_1rvq.py +3 -32
- codeclm/tokenizer/Flow1dVAE/generate_2rvq.py +0 -293
- codeclm/tokenizer/Flow1dVAE/generate_4rvq.py +0 -292
- codeclm/tokenizer/Flow1dVAE/libs/datasets/MusicSoundMixedDataset.py +0 -1278
- codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_429.py +0 -372
- codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined.py +0 -830
- codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined_withset.py +0 -994
- codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song.py +0 -313
- codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_20s.py +0 -313
- codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_new_429.py +0 -313
- codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_stock.py +0 -461
- codeclm/tokenizer/Flow1dVAE/model_1rvq.py +0 -2
- codeclm/tokenizer/Flow1dVAE/model_2rvq.py +0 -774
- codeclm/tokenizer/Flow1dVAE/model_4rvq.py +0 -774
- codeclm/tokenizer/Flow1dVAE/model_septoken.py +0 -2
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_AS2M.yaml +0 -122
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_music_multinodes.yaml +0 -125
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M.yaml +0 -137
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes.yaml +0 -139
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes_debug1node.yaml +0 -138
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes_debug2node.yaml +0 -139
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_orig.yaml +0 -135
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_tune.yaml +0 -137
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M.yaml +0 -116
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq.yaml +0 -125
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_chroma_multinodes.yaml +0 -128
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_multinodes.yaml +0 -126
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_norm_multinodes.yaml +0 -128
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_norm_speech_multinodes.yaml +0 -128
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrvq_multinodes.yaml +0 -121
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_dac.yaml +0 -0
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_dac_multinodes.yaml +0 -121
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_groupbestrq_multinodes.yaml +0 -125
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_mel_multinodes.yaml +0 -124
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_bestrvq_multinodes.yaml +0 -108
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_multinodes.yaml +0 -105
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_speech_multinodes.yaml +0 -106
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/run/submitit_reg.yaml +0 -20
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/__init__.py +0 -2
- codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/ark_dataset.py +0 -115
app.py
CHANGED
@@ -16,14 +16,12 @@ from download import download_model
 # Download the model
 APP_DIR = op.dirname(op.abspath(__file__))
 download_model(APP_DIR)
-
-os.makedirs(base_full_path, exist_ok=True)
-download_model(base_full_path, repo_id="lglg666/SongGeneration-base-full", revision="19ebdb6")
+download_model(op.join(APP_DIR, "ckpt"), repo_id="waytan22/SongGeneration-v1.5-beta", revision="db10f47")
 print("Successful downloaded model.")
 
 # Model initialization
 from levo_inference import LeVoInference
-MODEL = LeVoInference(
+MODEL = LeVoInference(op.join(APP_DIR, "ckpt", "SongGeneration-v1.5-beta"))
 
 EXAMPLE_LYRICS = """
 [intro-medium]
@@ -225,7 +223,7 @@ lyrics
         minimum=0.1,
         maximum=2.0,
         step=0.1,
-        value=0.
+        value=0.8,
         interactive=True,
         elem_id="temperature",
     )
@@ -268,12 +266,12 @@ lyrics
     # Generate-button click events
     generate_btn.click(
         fn=generate_song,
-        inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, gr.State(
+        inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, gr.State(50)],
         outputs=[output_audio, output_json]
     )
     generate_bgm_btn.click(
         fn=generate_song,
-        inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, gr.State(
+        inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, gr.State(50), gr.State("bgm")],
         outputs=[output_audio, output_json]
     )
 
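Note on the app.py change: v1.5-beta downloads a single checkpoint (waytan22/SongGeneration-v1.5-beta) instead of the base-full model, raises the default temperature to 0.8, and passes extra gr.State inputs to the click handlers. The sketch below is an assumption about generate_song's signature, which this diff does not show: Gradio forwards each gr.State value as a positional argument, so gr.State(50) and gr.State("bgm") presumably land in trailing parameters like these.

# Hypothetical signature; the real one is defined elsewhere in app.py.
def generate_song(lyric, description, prompt_audio, genre, cfg_coef,
                  temperature, steps=50, mode="mix"):
    # `steps` receives gr.State(50); the BGM button additionally sends
    # gr.State("bgm"), letting one handler switch generation modes.
    ...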
codeclm/models/builders.py
CHANGED
@@ -52,7 +52,7 @@ def get_audio_tokenizer_model_cpu(checkpoint_path: str, cfg: omegaconf.DictConfig):
     return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode, tango_device='cpu')
 
 
-def get_lm_model(cfg: omegaconf.DictConfig): #-> LMModel:
+def get_lm_model(cfg: omegaconf.DictConfig, version: str = 'v1.0'): #-> LMModel:
     """Instantiate a LM."""
     lm_kwargs = dict_from_config(getattr(cfg, 'lm'))
 
@@ -61,8 +61,8 @@ def get_lm_model(cfg: omegaconf.DictConfig): #-> LMModel:
     q_modeling = lm_kwargs.pop('q_modeling', None)
 
     # conditioner
-    condition_provider = get_conditioner_provider(lm_kwargs["dim"], cfg)
-
+    condition_provider = get_conditioner_provider(lm_kwargs["dim"], cfg, version=version)
+
     # codebook pattern: delay
     codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
     if codebooks_pattern_cfg.modeling is None:
@@ -97,7 +97,7 @@ def get_lm_model(cfg: omegaconf.DictConfig): #-> LMModel:
     raise KeyError(f"Unexpected LM model {lm_type}")
 
 
-def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditionerProvider:
+def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig, version: str = 'v1.0') -> ConditionerProvider:
     """Instantiate a conditioning model."""
     cfg = getattr(cfg, 'conditioners')
     dict_cfg = {} if cfg is None else dict_from_config(cfg)
@@ -115,6 +115,7 @@ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditionerProvider:
     elif model_type == "QwTextTokenizer":
         conditioners[str(cond)] = QwTextConditioner(
             output_dim=output_dim,
+            version=version,
             **model_args
         )
     elif model_type == "qt_embedding":
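The builders.py change only threads a `version` string from get_lm_model through get_conditioner_provider into each QwTextConditioner. A minimal call-site sketch, assuming a standard omegaconf config (the config filename is illustrative):

import omegaconf

cfg = omegaconf.OmegaConf.load("songgeneration.yaml")  # illustrative path
lm = get_lm_model(cfg, version='v1.5')  # 'v1.5' ultimately enables the Musicality tokens below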
codeclm/modules/conditioners.py
CHANGED
@@ -188,10 +188,13 @@ class QwTokenizerConditioner(TextConditioner):
 class QwTextConditioner(TextConditioner):
     def __init__(self, output_dim: int,
                  token_path = "",
-                 max_len = 300
+                 max_len = 300,
+                 version: str = 'v1.0'): #""
 
         from transformers import Qwen2Tokenizer
-        self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
+        self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
+        if version == 'v1.5':
+            self.text_tokenizer.add_tokens(['[Musicality-very-high]', '[Musicality-high]', '[Musicality-medium]', '[Musicality-low]', '[Musicality-very-low]'], special_tokens=True)
         voc_size = len(self.text_tokenizer.get_vocab())
         # here initialize a output_proj (nn.Embedding) layer
         super().__init__(voc_size, output_dim, input_token=True, padding_idx=151643)
@@ -636,7 +639,14 @@ class ClassifierFreeGuidanceDropoutInference(ClassifierFreeGuidanceDropout):
                 sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0])
             else:
                 if customized is None:
-                    sample.text[condition]
+                    if condition in ['type_info'] and sample.text[condition] is not None:
+                        if "[Musicality-very-high]" in sample.text[condition]:
+                            sample.text[condition] = "[Musicality-very-low], ."
+                            print(f"cfg unconditioning: change sample.text[condition] to [Musicality-very-low]")
+                        else:
+                            sample.text[condition] = None
+                    else:
+                        sample.text[condition] = None
                 else:
                     text_cond = deepcopy(sample.text[condition])
                     if "structure" in customized:
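The conditioners.py change registers five Musicality control tokens when version == 'v1.5', and the CFG dropout now flips [Musicality-very-high] to [Musicality-very-low] for the type_info condition instead of simply nulling it, so the unconditional branch actively steers away from high musicality. A standalone sketch of the tokenizer side ("Qwen/Qwen2-0.5B" is a stand-in; the repo loads from its own token_path):

from transformers import Qwen2Tokenizer

tok = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2-0.5B")  # stand-in checkpoint
before = len(tok.get_vocab())
tok.add_tokens(['[Musicality-very-high]', '[Musicality-high]', '[Musicality-medium]',
                '[Musicality-low]', '[Musicality-very-low]'], special_tokens=True)
print(len(tok.get_vocab()) - before)  # 5: voc_size, and hence the nn.Embedding, grows to match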
codeclm/tokenizer/Flow1dVAE/cal_token_stat.py
DELETED
@@ -1,19 +0,0 @@
-import kaldiio
-from tqdm import tqdm
-import torch
-
-if __name__ == "__main__":
-    bar = torch.zeros(1, 16384)
-    with open('token.scp', 'r') as f:
-        for item_idx, line in tqdm(enumerate(f)):
-            idx, pos = line.strip().split()
-            codes = kaldiio.load_mat(pos)
-            for i0 in range(codes.shape[-1]):
-                bar[0, codes[0, 0, i0]] += 1
-            if(item_idx % 1000 == 0):
-                print("=========")
-                print(1 - (bar[0]==0).sum() / bar.shape[-1])
-                print("=========")
-    print("=========")
-    print(1 - (bar[0]==0).sum() / bar.shape[-1])
-    print("=========")
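The deleted cal_token_stat.py measured codebook utilization: the fraction of the 16384 RVQ codebook entries observed at least once across a token.scp manifest. The same metric in isolation:

import torch

def codebook_utilization(counts: torch.Tensor) -> float:
    # counts: shape (codebook_size,), histogram of code occurrences
    return float(1 - (counts == 0).sum() / counts.shape[-1])

print(codebook_utilization(torch.tensor([3., 0., 1., 0.])))  # 0.5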
codeclm/tokenizer/Flow1dVAE/compare_model_weight.py
DELETED
@@ -1,13 +0,0 @@
-import torch
-import sys
-from safetensors.torch import load_file
-
-if __name__ == "__main__":
-    m0, m1 = sys.argv[1], sys.argv[2]
-    m0 = load_file(m0)
-    m1 = load_file(m1)
-
-    ks = [k for k in m0.keys() if 'bestrq' in k]
-    for k in ks:
-        print(k, (m0[k] - m1[k]).abs().sum())
-
codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_and_sep_npy.py
DELETED
@@ -1,121 +0,0 @@
-import torch,torchaudio
-import os,sys,json
-from tqdm import tqdm
-import numpy as np
-
-#from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
-from generate_septoken import Tango as Tango_sep
-from generate_2rvq import Tango as Tango_1x2
-import kaldiio
-from kaldiio import WriteHelper
-from audio import AudioFile
-
-from demucs.models.pretrained import get_model_from_yaml
-from filelock import FileLock
-
-# os.path.join(args.model_dir, "htdemucs.pth"), os.path.join(args.model_dir, "htdemucs.yaml")
-class Separator:
-    def __init__(self, dm_model_path='demucs/ckpt/htdemucs.pth', dm_config_path='demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
-        if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
-            self.device = torch.device(f"cuda:{gpu_id}")
-        else:
-            self.device = torch.device("cpu")
-        self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)
-
-    def init_demucs_model(self, model_path, config_path):
-        model = get_model_from_yaml(config_path, model_path)
-        model.to(self.device)
-        model.eval()
-        return model
-
-    def load_audio(self, f):
-        a, fs = torchaudio.load(f)
-        if (fs != 48000):
-            a = torchaudio.functional.resample(a, fs, 48000)
-        # if a.shape[-1] >= 48000*10:
-        #     a = a[..., :48000*10]
-        # else:
-        #     a = torch.cat([a, a], -1)
-        # return a[:, 0:48000*10]
-        return a
-
-    def run(self, audio_path, output_dir='demucs/test_output', ext=".flac"):
-        name, _ = os.path.splitext(os.path.split(audio_path)[-1])
-        output_paths = []
-        # lock_path = os.path.join(output_dir, f"{name}.lock")
-        # with FileLock(lock_path):  # add a lock to avoid deadlock when multiple GPUs access the file
-        for stem in self.demucs_model.sources:
-            output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
-            if os.path.exists(output_path):
-                output_paths.append(output_path)
-        if len(output_paths) == 1:  # 4
-            # drums_path, bass_path, other_path, vocal_path = output_paths
-            vocal_path = output_paths[0]
-        else:
-            lock_path = os.path.join(output_dir, f"{name}_separate.lock")
-            with FileLock(lock_path):
-                drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
-        full_audio = self.load_audio(audio_path)
-        vocal_audio = self.load_audio(vocal_path)
-        minlen = min(full_audio.shape[-1], vocal_audio.shape[-1])
-        # bgm_audio = full_audio[:, 0:minlen] - vocal_audio[:, 0:minlen]
-        bgm_audio = self.load_audio(drums_path) + self.load_audio(bass_path) + self.load_audio(other_path)
-        for path in [drums_path, bass_path, other_path, vocal_path]:
-            os.remove(path)
-        return full_audio, vocal_audio, bgm_audio
-
-def read_wav(fname, sample_rate=48_000):
-    try:
-        orig_samples, fs = torchaudio.load(fname)
-    except:
-        af = AudioFile(fname)
-        orig_samples = af.read()
-        fs = af.samplerate()
-        orig_samples = orig_samples[0]
-    if(fs!=sample_rate):
-        orig_samples = torchaudio.functional.resample(orig_samples, fs, sample_rate)
-        fs = sample_rate
-    if orig_samples.shape[0] == 1:
-        orig_samples = torch.cat([orig_samples, orig_samples], 0)
-    return orig_samples
-
-if __name__ == "__main__":
-    # Define Model
-    json_path = sys.argv[1]
-
-    mus_infos = []
-    with open(json_path) as f:
-        for line in f:
-            item = json.loads(line)
-            mus_infos.append(item)
-
-    tango_sep = Tango_sep(model_path="./saved/model_septoken/model_2.safetensors")
-    tango_1x2 = Tango_1x2(model_path = './saved/model_2rvq/model_2_fixed.safetensors', rvq_num=2)
-    separator = Separator()
-
-    # Feature extraction loop
-    # for i in tqdm(range(2000)):
-    first_time = True
-    for item in tqdm(mus_infos):
-        if(os.path.exists(item['path'])):
-            full_path = item['path']
-        else:
-            full_path = '/mnt/share/' + item['path']
-
-        full_tensor, vocal_tensor, bgm_tensor = separator.run(full_path)
-
-        # full_tensor = read_wav(full_path)
-        # vocal_tensor = read_wav(vocal_path)
-        # length = min(full_tensor.shape[-1], vocal_tensor.shape[-1])
-        # full_tensor, vocal_tensor = full_tensor[:, 0:length], vocal_tensor[:, 0:length]
-        # bgm_tensor = full_tensor - vocal_tensor
-        codes_1x2 = tango_1x2.sound2code(full_tensor)
-        codes_vocal, codes_bgm = tango_sep.sound2code(vocal_tensor, bgm_tensor)
-        codes = torch.cat([codes_1x2[:,[0],:], codes_vocal, codes_bgm], 1).cpu().numpy()
-        save_path = full_path.replace('.wav', '.1x1_and_sep.npy').replace('.mp3', '.1x1_and_sep.npy').replace('.flac', '.1x1_and_sep.npy').replace('.ogg', '.1x1_and_sep.npy')
-        assert save_path != full_path, (save_path, full_path)
-        np.save(save_path, codes)
-
-        if(first_time):
-            first_time = False
-            print(codes_vocal.shape, codes_bgm.shape)
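The deleted extraction script stacked one RVQ layer of the full mix with separated vocal and BGM token streams before saving each song as a .1x1_and_sep.npy file. A shape-only sketch of that concatenation (the per-stream shapes are assumptions inferred from the cat call; the script itself only prints them):

import torch

T = 250  # e.g. 10 s at 25 tokens/s
codes_1x2   = torch.randint(0, 16384, (1, 2, T))  # 2-RVQ codes of the full mix
codes_vocal = torch.randint(0, 16384, (1, 1, T))
codes_bgm   = torch.randint(0, 16384, (1, 1, T))
codes = torch.cat([codes_1x2[:, [0], :], codes_vocal, codes_bgm], 1)
print(codes.shape)  # torch.Size([1, 3, 250]): mix layer 0, vocal, bgm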
codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_sep.py
DELETED
@@ -1,94 +0,0 @@
-import torch,torchaudio
-import os,sys,json
-from tqdm import tqdm
-
-#from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
-from generate_septoken import Tango
-import kaldiio
-from kaldiio import WriteHelper
-from audio import AudioFile
-
-def read_wav(fname, sample_rate=48_000):
-    try:
-        orig_samples, fs = torchaudio.load(fname)
-    except:
-        af = AudioFile(fname)
-        orig_samples = af.read()
-        fs = af.samplerate()
-        orig_samples = orig_samples[0]
-    if(fs!=sample_rate):
-        orig_samples = torchaudio.functional.resample(orig_samples, fs, sample_rate)
-        fs = sample_rate
-    if orig_samples.shape[0] == 1:
-        orig_samples = torch.cat([orig_samples, orig_samples], 0)
-    return orig_samples
-
-if __name__ == "__main__":
-    # Define Model
-    json_path = sys.argv[1]
-    outdir = sys.argv[2]
-
-    mus_infos = []
-    with open(json_path) as f:
-        for line in f:
-            item = json.loads(line)
-            mus_infos.append(item)
-
-    tango = Tango(model_path="./saved/model_septoken/model_2.safetensors")
-
-
-    # Feature extraction loop
-    # for i in tqdm(range(2000)):
-    first_time = True
-    with WriteHelper('ark,scp:{}/token_vocal.ark,{}/token_vocal.scp'.format(outdir, outdir), write_function="pickle") as writer_vocal, WriteHelper('ark,scp:{}/token_bgm.ark,{}/token_bgm.scp'.format(outdir, outdir), write_function="pickle") as writer_bgm:
-        print('ark,scp:{}/token_vocal.ark,{}/token_vocal.scp'.format(outdir, outdir))
-        print('ark,scp:{}/token_bgm.ark,{}/token_bgm.scp'.format(outdir, outdir))
-        for item in tqdm(mus_infos):
-            try:
-                # if True:
-                idx = item['idx']
-                # print(idx)
-                if(os.path.exists(item['path'])):
-                    full_path = item['path']
-                else:
-                    full_path = '/mnt/share/' + item['path']
-                if(os.path.exists(item['vocal_path'])):
-                    vocal_path = item['vocal_path']
-                    bgm_paths = item['bgm_path']
-                else:
-                    vocal_path = '/mnt/share/' + item['vocal_path']
-                    bgm_paths = ['/mnt/share/' + p for p in item['bgm_path']]
-                vocal_tensor = read_wav(vocal_path)
-                # full_tensor = read_wav(full_path)
-                # length = min(full_tensor.shape[-1], vocal_tensor.shape[-1])
-                # full_tensor, vocal_tensor = full_tensor[:, 0:length], vocal_tensor[:, 0:length]
-                # bgm_tensor = full_tensor - vocal_tensor
-                bgm_tensor = sum([read_wav(p) for p in bgm_paths])
-                codes_vocal, codes_bgm = tango.sound2code(vocal_tensor, bgm_tensor)
-                writer_vocal(str(idx), codes_vocal.cpu())
-                writer_bgm(str(idx), codes_bgm.cpu())
-                if(first_time):
-                    first_time = False
-                    print(codes_vocal.shape, codes_bgm.shape)
-            except:
-                print(item['vocal_path'])
-                print(item['bgm_path'])
-                continue
-
-        # idx = item['idx']
-        # # print(idx)
-        # full_path = item['path']
-        # vocal_path = item['vocal_path']
-        # bgm_paths = item['bgm_path']
-        # full_tensor = read_wav(full_path)
-        # vocal_tensor = read_wav(vocal_path)
-        # length = min(full_tensor.shape[-1], vocal_tensor.shape[-1])
-        # full_tensor, vocal_tensor = full_tensor[:, 0:length], vocal_tensor[:, 0:length]
-        # bgm_tensor = full_tensor - vocal_tensor
-        # codes_vocal, codes_bgm = tango.sound2code(vocal_tensor, bgm_tensor)
-        # writer_vocal(str(idx), codes_vocal.cpu())
-        # writer_bgm(str(idx), codes_bgm.cpu())
-        # if(first_time):
-        #     first_time = False
-        #     print(codes_vocal.shape, codes_bgm.shape)
-
codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x2.py
DELETED
@@ -1,70 +0,0 @@
-import torch,torchaudio
-import os,sys,json
-from tqdm import tqdm
-
-#from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
-from generate_2rvq import Tango
-import kaldiio
-from kaldiio import WriteHelper
-import torch
-import subprocess
-import time
-import sys
-
-def get_gpu_memory():
-    _output_to_list = lambda x: x.decode('ascii').split('\n')[:-1]
-
-    ACCEPTABLE_AVAILABLE_MEMORY = 1024
-    COMMAND = "nvidia-smi --query-gpu=memory.free --format=csv"
-    memory_free_info = _output_to_list(subprocess.check_output(COMMAND.split()))[1:]
-    memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
-    return memory_free_values
-
-if __name__ == "__main__":
-    # Define Model
-    json_path = sys.argv[1]
-    outdir = sys.argv[2]
-
-    gpu_idx = int(os.environ['CUDA_VISIBLE_DEVICES'])
-    while True:
-        free_mem = get_gpu_memory()
-        free_mem = free_mem[gpu_idx]
-        if(free_mem > 25_000):
-            print("GPU memory {}, run matrix cal".format(free_mem))
-            break
-        else:
-            print("GPU memory {}, sleep 1min".format(free_mem))
-            time.sleep(60)
-
-    mus_infos = []
-    with open(json_path) as f:
-        for line in f:
-            item = json.loads(line)
-            mus_infos.append(item)
-
-    tango = Tango(model_path = './saved/model_2rvq/model_2_fixed.safetensors', rvq_num=2)
-
-
-    # Feature extraction loop
-    # for i in tqdm(range(2000)):
-    with WriteHelper('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir), write_function="pickle") as writer:
-        print('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir))
-        for item in tqdm(mus_infos):
-            try:
-                # if True:
-                idx = item['idx']
-                # print(idx)
-                with torch.autocast(device_type="cuda", dtype=torch.float16):
-                    if(os.path.exists(item['path'])):
-                        codes = tango.file2code(item['path'])
-                    else:
-                        codes = tango.file2code('/mnt/share/' + item['path'])
-                writer(str(idx), codes.cpu())
-            except:
-                print(item['path'])
-                continue
-        # idx = item['idx']
-        # # print(idx)
-        # with torch.autocast(device_type="cuda", dtype=torch.float16):
-        #     codes = tango.file2code(item['path'])
-        # writer(str(idx), codes.cpu())
codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4.py
DELETED
@@ -1,46 +0,0 @@
-import torch,torchaudio
-import os,sys,json
-from tqdm import tqdm
-
-#from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
-from generate_4rvq import Tango
-import kaldiio
-from kaldiio import WriteHelper
-
-if __name__ == "__main__":
-    # Define Model
-    json_path = sys.argv[1]
-    outdir = sys.argv[2]
-
-    mus_infos = []
-    with open(json_path) as f:
-        for line in f:
-            item = json.loads(line)
-            mus_infos.append(item)
-
-    tango = Tango(model_path = './saved/model_4rvq/model_2_fixed.safetensors', rvq_num=4)
-
-
-    # Feature extraction loop
-    # for i in tqdm(range(2000)):
-    with WriteHelper('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir), write_function="pickle") as writer:
-        print('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir))
-        for item in tqdm(mus_infos):
-            try:
-                # if True:
-                idx = item['idx']
-                # print(idx)
-                with torch.autocast(device_type="cuda", dtype=torch.float16):
-                    if(os.path.exists(item['path'])):
-                        codes = tango.file2code(item['path'])
-                    else:
-                        codes = tango.file2code('/mnt/share/' + item['path'])
-                writer(str(idx), codes.cpu())
-            except:
-                print(item['path'])
-                continue
-        # idx = item['idx']
-        # # print(idx)
-        # with torch.autocast(device_type="cuda", dtype=torch.float16):
-        #     codes = tango.file2code(item['path'])
-        # writer(str(idx), codes.cpu())
codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4_ds.py
DELETED
@@ -1,86 +0,0 @@
-import torch,torchaudio
-import os,sys,json
-from tqdm import tqdm
-
-#from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
-from generate_4rvq import Tango
-import kaldiio
-from kaldiio import WriteHelper
-import torch
-import subprocess
-import time
-import sys
-
-def get_gpu_memory():
-    _output_to_list = lambda x: x.decode('ascii').split('\n')[:-1]
-
-    ACCEPTABLE_AVAILABLE_MEMORY = 1024
-    COMMAND = "nvidia-smi --query-gpu=memory.free --format=csv"
-    memory_free_info = _output_to_list(subprocess.check_output(COMMAND.split()))[1:]
-    memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
-    return memory_free_values
-
-if __name__ == "__main__":
-    # Define Model
-    json_path = sys.argv[1]
-    outdir = sys.argv[2]
-    ds = int(sys.argv[3])
-
-    gpu_idx = int(os.environ['CUDA_VISIBLE_DEVICES'])
-    while True:
-        free_mem = get_gpu_memory()
-        free_mem = free_mem[gpu_idx]
-        if(free_mem > 25_000):
-            print("GPU memory {}, run matrix cal".format(free_mem))
-            break
-        else:
-            print("GPU memory {}, sleep 1min".format(free_mem))
-            time.sleep(60)
-
-    mus_infos = []
-    with open(json_path) as f:
-        for line in f:
-            item = json.loads(line)
-            mus_infos.append(item)
-
-    tango = Tango(model_path = './saved/model_4rvq/model_2_fixed.safetensors', rvq_num=4)
-
-
-    # Feature extraction loop
-    # for i in tqdm(range(2000)):
-    with WriteHelper('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir), write_function="pickle") as writer:
-        print('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir))
-        bar = torch.zeros(4, 16384)
-        for item_idx, item in tqdm(enumerate(mus_infos)):
-            try:
-                # if True:
-                idx = item['idx']
-                # print(idx)
-                with torch.autocast(device_type="cuda", dtype=torch.float16):
-                    if(os.path.exists(item['path'])):
-                        codes = tango.file2code_ds(item['path'], ds)
-                    else:
-                        codes = tango.file2code_ds('/mnt/share/' + item['path'], ds)
-                codes = codes.cpu()
-                writer(str(idx), codes)
-                for i0 in range(codes.shape[-1]):
-                    bar[0, codes[0, 0, i0]] += 1
-                    bar[1, codes[0, 1, i0]] += 1
-                    bar[2, codes[0, 2, i0]] += 1
-                    bar[3, codes[0, 3, i0]] += 1
-            except Exception as e:
-                print(item['path'])
-                # print(e.message, e.args)
-                # exit(1)
-                continue
-
-            if(item_idx % 1000 == 0):
-                print("=========")
-                print(1 - (bar[0]==0).sum() / bar.shape[-1])
-                print("=========")
-
-            # idx = item['idx']
-            # # print(idx)
-            # with torch.autocast(device_type="cuda", dtype=torch.float16):
-            #     codes = tango.file2code(item['path'])
-            # writer(str(idx), codes.cpu())
codeclm/tokenizer/Flow1dVAE/generate_1rvq.py
CHANGED
@@ -8,7 +8,6 @@ import librosa
 import os
 import math
 import numpy as np
-from tools.get_1dvae_large import get_model
 import tools.torch_tools as torch_tools
 from safetensors.torch import load_file
 
@@ -24,9 +23,9 @@ class Tango:
         scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
         self.device = device
 
-        self.vae = get_model(vae_config, vae_model)
-        self.vae = self.vae.to(device)
-        self.vae=self.vae.eval()
+        # self.vae = get_model(vae_config, vae_model)
+        # self.vae = self.vae.to(device)
+        # self.vae=self.vae.eval()
         self.layer_num = layer_num
 
         self.MAX_DURATION = 360
@@ -254,37 +253,9 @@ class Tango:
         # print(fname, wave.shape)
         return wave
 
-    @torch.no_grad()
-    def sound2sound_vae(self, sound, prompt=None, steps=50, disable_progress=False):
-        min_samples = int(40 * 25) # 40ms per frame
-        hop_samples = min_samples // 4 * 3
-        ovlp_samples = min_samples - hop_samples
-        dur = 20
-
-        latent_list = []
-        for i in range(0, sound.shape[-1], dur*48000):
-            if(i+dur*2*48000 > sound.shape[-1]):
-                latent = tango.vae.encode_audio(sound.cuda()[None,:,i:])
-                break
-            else:
-                latent = tango.vae.encode_audio(sound.cuda()[None,:,i:i+dur*48000])
-            latent_list.append(latent)
-
-        output = None
-        for i in range(len(latent_list)):
-            print(i)
-            latent = latent_list[i]
-            cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
-            if output is None:
-                output = cur_output
-            else:
-                output = torch.cat([output, cur_output], -1)
-        return output
-
     def to(self, device=None, dtype=None, non_blocking=False):
         if device is not None:
             self.device = device
             self.model.device = device
-        self.vae = self.vae.to(device, dtype, non_blocking)
         self.model = self.model.to(device, dtype, non_blocking)
         return self
codeclm/tokenizer/Flow1dVAE/generate_2rvq.py
DELETED
@@ -1,293 +0,0 @@
-import json
-import torch
-from tqdm import tqdm
-from model_2rvq import PromptCondAudioDiffusion
-from diffusers import DDIMScheduler, DDPMScheduler
-import torchaudio
-import librosa
-import os
-import math
-import numpy as np
-# from tools.get_mulan import get_mulan
-from tools.get_1dvae_large import get_model
-import tools.torch_tools as torch_tools
-from safetensors.torch import load_file
-from audio import AudioFile
-import kaldiio
-
-class Tango:
-    def __init__(self, \
-        model_path, \
-        layer_num=6, \
-        rvq_num=1, \
-        device="cuda:0"):
-
-        self.sample_rate = 48000
-        scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
-        self.device = device
-
-        self.vae = get_model()
-        self.vae = self.vae.to(device)
-        self.vae=self.vae.eval()
-        self.layer_num = layer_num
-
-        self.MAX_DURATION = 360
-        main_config = {
-            "num_channels":32,
-            "unet_model_name":None,
-            "unet_model_config_path":"configs/models/transformer2D_wocross_inch112_1x4_multi_large.json",
-            "snr_gamma":None,
-        }
-        self.rvq_num = rvq_num
-        # print("rvq_num: ", self.rvq_num)
-        # exit()
-        self.model = PromptCondAudioDiffusion(**main_config).to(device)
-        if model_path.endswith(".safetensors"):
-            main_weights = load_file(model_path)
-        else:
-            main_weights = torch.load(model_path, map_location=device)
-        self.model.load_state_dict(main_weights, strict=False)
-        print ("Successfully loaded checkpoint from:", model_path)
-
-        self.model.eval()
-        self.model.init_device_dtype(torch.device(device), torch.float32)
-
-        # self.scheduler = DDIMScheduler.from_pretrained( \
-        #     scheduler_name, subfolder="scheduler")
-        # self.scheduler = DDPMScheduler.from_pretrained( \
-        #     scheduler_name, subfolder="scheduler")
-        print("Successfully loaded inference scheduler from {}".format(scheduler_name))
-
-
-
-    @torch.no_grad()
-    @torch.autocast(device_type="cuda", dtype=torch.float32)
-    def sound2code(self, orig_samples, batch_size=8):
-        if(orig_samples.ndim == 2):
-            audios = orig_samples.unsqueeze(0).to(self.device)
-        elif(orig_samples.ndim == 3):
-            audios = orig_samples.to(self.device)
-        else:
-            assert orig_samples.ndim in (2,3), orig_samples.shape
-        audios = self.preprocess_audio(audios)
-        audios = audios.squeeze(0)
-        orig_length = audios.shape[-1]
-        min_samples = int(40 * self.sample_rate)
-        # 40 seconds correspond to 10 tokens
-        output_len = int(orig_length / float(self.sample_rate) * 25) + 1
-        # print("output_len: ", output_len)
-
-        while(audios.shape[-1] < min_samples):
-            audios = torch.cat([audios, audios], -1)
-        int_max_len=audios.shape[-1]//min_samples+1
-        audios = torch.cat([audios, audios], -1)
-        audios=audios[:,:int(int_max_len*(min_samples))]
-        codes_list=[]
-
-        audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
-
-        for audio_inx in range(0, audio_input.shape[0], batch_size):
-            # import pdb; pdb.set_trace()
-            codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num)
-            # print("codes",codes[0].shape)
-
-            codes_list.append(torch.cat(codes, 1))
-            # print("codes_list",codes_list[0].shape)
-
-        codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T
-        codes=codes[:,:,:output_len]
-
-        return codes
-
-    @torch.no_grad()
-    @torch.autocast(device_type="cuda", dtype=torch.float32)
-    def sound2code_ds(self, orig_samples, ds, batch_size=8):
-        if(orig_samples.ndim == 2):
-            audios = orig_samples.unsqueeze(0).to(self.device)
-        elif(orig_samples.ndim == 3):
-            audios = orig_samples.to(self.device)
-        else:
-            assert orig_samples.ndim in (2,3), orig_samples.shape
-        audios = self.preprocess_audio(audios)
-        audios = audios.squeeze(0)
-        orig_length = audios.shape[-1]
-        min_samples = int(40 * self.sample_rate)
-        # 40 seconds correspond to 10 tokens
-        output_len = int(orig_length / float(self.sample_rate) * 25) + 1
-        # print("output_len: ", output_len)
-
-        while(audios.shape[-1] < min_samples):
-            audios = torch.cat([audios, audios], -1)
-        int_max_len=audios.shape[-1]//min_samples+1
-        audios = torch.cat([audios, audios], -1)
-        audios=audios[:,:int(int_max_len*(min_samples))]
-        codes_list=[]
-
-        audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
-
-        for audio_inx in range(0, audio_input.shape[0], batch_size):
-            # import pdb; pdb.set_trace()
-            codes, _, spk_embeds = self.model.fetch_codes_batch_ds((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num, ds=ds)
-            # print("codes",codes[0].shape)
-
-            codes_list.append(torch.cat(codes, 1))
-            # print("codes_list",codes_list[0].shape)
-
-        codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T
-        codes=codes[:,:,:output_len]
-
-        return codes
-
-    @torch.no_grad()
-    def code2sound(self, codes, prompt=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False):
-        codes = codes.to(self.device)
-
-        min_samples = duration * 25 # 40ms per frame
-        hop_samples = min_samples // 4 * 3
-        ovlp_samples = min_samples - hop_samples
-        hop_frames = hop_samples
-        ovlp_frames = ovlp_samples
-        first_latent = torch.randn(codes.shape[0], min_samples, 64).to(self.device)
-        first_latent_length = 0
-        first_latent_codes_length = 0
-
-        if(isinstance(prompt, torch.Tensor)):
-            # prepare prompt
-            prompt = prompt.to(self.device)
-            if(prompt.ndim == 3):
-                assert prompt.shape[0] == 1, prompt.shape
-                prompt = prompt[0]
-            elif(prompt.ndim == 1):
-                prompt = prompt.unsqueeze(0).repeat(2,1)
-            elif(prompt.ndim == 2):
-                if(prompt.shape[0] == 1):
-                    prompt = prompt.repeat(2,1)
-
-            if(prompt.shape[-1] < int(30 * self.sample_rate)):
-                # if less than 30s, just choose the first 10s
-                prompt = prompt[:,:int(10*self.sample_rate)] # limit max length to 10.24
-            else:
-                # else choose from 20.48s which might includes verse or chorus
-                prompt = prompt[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24
-
-            true_latent = self.vae.encode_audio(prompt).permute(0,2,1)
-            # print("true_latent.shape", true_latent.shape)
-            # print("first_latent.shape", first_latent.shape)
-            #true_latent.shape torch.Size([1, 250, 64])
-            # first_latent.shape torch.Size([1, 1000, 64])
-
-            first_latent[:,0:true_latent.shape[1],:] = true_latent
-            first_latent_length = true_latent.shape[1]
-            first_latent_codes = self.sound2code(prompt)
-            first_latent_codes_length = first_latent_codes.shape[-1]
-            codes = torch.cat([first_latent_codes, codes], -1)
-
-        codes_len= codes.shape[-1]
-        target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)
-        # target_len = int(codes_len / 100 * 4 * self.sample_rate)
-        # code repeat
-        if(codes_len < min_samples):
-            while(codes.shape[-1] < min_samples):
-                codes = torch.cat([codes, codes], -1)
-            codes = codes[:,:,0:min_samples]
-        codes_len = codes.shape[-1]
-        if((codes_len - ovlp_samples) % hop_samples > 0):
-            len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples
-            while(codes.shape[-1] < len_codes):
-                codes = torch.cat([codes, codes], -1)
-            codes = codes[:,:,0:len_codes]
-        latent_length = min_samples
-        latent_list = []
-        spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device)
-        with torch.autocast(device_type="cuda", dtype=torch.float16):
-            for sinx in range(0, codes.shape[-1]-hop_samples, hop_samples):
-                codes_input=[]
-                codes_input.append(codes[:,:,sinx:sinx+min_samples])
-                if(sinx == 0):
-                    # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
-                    incontext_length = first_latent_length
-                    latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
-                    latent_list.append(latents)
-                else:
-                    # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
-                    true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1)
-                    print("true_latent.shape", true_latent.shape)
-                    len_add_to_1000 = 1000 - true_latent.shape[-2]
-                    # print("len_add_to_1000", len_add_to_1000)
-                    # exit()
-                    incontext_length = true_latent.shape[-2]
-                    true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2)
-                    latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
-                    latent_list.append(latents)
-
-        latent_list = [l.float() for l in latent_list]
-        latent_list[0] = latent_list[0][:,:,first_latent_length:]
-        min_samples = int(min_samples * self.sample_rate // 1000 * 40)
-        hop_samples = int(hop_samples * self.sample_rate // 1000 * 40)
-        ovlp_samples = min_samples - hop_samples
-        with torch.no_grad():
-            output = None
-            for i in range(len(latent_list)):
-                latent = latent_list[i]
-                cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
-
-                if output is None:
-                    output = cur_output
-                else:
-                    ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
-                    ov_win = torch.cat([ov_win, 1 - ov_win], -1)
-                    print("output.shape", output.shape)
-                    print("ov_win.shape", ov_win.shape)
-                    output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
-                    output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
-            output = output[:, 0:target_len]
-        return output
-
-    @torch.no_grad()
-    def preprocess_audio(self, input_audios, threshold=0.8):
-        assert len(input_audios.shape) == 3, input_audios.shape
-        nchan = input_audios.shape[1]
-        input_audios = input_audios.reshape(input_audios.shape[0], -1)
-        norm_value = torch.ones_like(input_audios[:,0])
-        max_volume = input_audios.abs().max(dim=-1)[0]
-        norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
-        return input_audios.reshape(input_audios.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1)
-
-    @torch.no_grad()
-    def sound2sound(self, sound, prompt=None, steps=50, disable_progress=False):
-        codes = self.sound2code(sound)
-        # print(codes.shape)
-        # exit()
-        wave = self.code2sound(codes, prompt, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
-        # print(fname, wave.shape)
-        return wave
-
-    def file2code(self, fname):
-        try:
-            orig_samples, fs = torchaudio.load(fname)
-        except:
-            af = AudioFile(fname)
-            orig_samples = af.read()
-            fs = af.samplerate()
-            orig_samples = orig_samples[0]
-        if(fs!=self.sample_rate):
-            orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
-            fs = self.sample_rate
-        if orig_samples.shape[0] == 1:
-            orig_samples = torch.cat([orig_samples, orig_samples], 0)
-        return self.sound2code(orig_samples)
-
-    def file2code_ds(self, fname, ds):
-        try:
-            orig_samples, fs = torchaudio.load(fname)
-        except:
-            af = AudioFile(fname)
-            orig_samples = af.read()
-            fs = af.samplerate()
-            orig_samples = orig_samples[0]
-        if(fs!=self.sample_rate):
-            orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
-            fs = self.sample_rate
-        if orig_samples.shape[0] == 1:
-            orig_samples = torch.cat([orig_samples, orig_samples], 0)
-        return self.sound2code_ds(orig_samples, ds)
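generate_2rvq.py's code2sound stitched 40 s decoded windows together with a linear crossfade over the overlap region; the same pattern appears in the other generate_*rvq variants deleted here and, presumably, survives in generate_1rvq.py. A self-contained restatement of just that blend:

import torch
import numpy as np

def overlap_add(prev: torch.Tensor, nxt: torch.Tensor, ovlp: int) -> torch.Tensor:
    # Linear ramp: prev fades out while nxt fades in across `ovlp` samples.
    ramp = torch.from_numpy(np.linspace(0, 1, ovlp)[None, :]).float()
    prev[:, -ovlp:] = prev[:, -ovlp:] * (1 - ramp) + nxt[:, :ovlp] * ramp
    return torch.cat([prev, nxt[:, ovlp:]], -1)

a, b = torch.ones(1, 8), torch.zeros(1, 8)
print(overlap_add(a, b, 4))  # the 4 overlapped samples fade 1 -> 0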
codeclm/tokenizer/Flow1dVAE/generate_4rvq.py
DELETED
@@ -1,292 +0,0 @@
-import json
-import torch
-from tqdm import tqdm
-from model_4rvq import PromptCondAudioDiffusion
-from diffusers import DDIMScheduler, DDPMScheduler
-import torchaudio
-import librosa
-import os
-import math
-import numpy as np
-# from tools.get_mulan import get_mulan
-from tools.get_1dvae_large import get_model
-import tools.torch_tools as torch_tools
-from safetensors.torch import load_file
-from audio import AudioFile
-
-class Tango:
-    def __init__(self, \
-        model_path, \
-        layer_num=6, \
-        rvq_num=1, \
-        device="cuda:0"):
-
-        self.sample_rate = 48000
-        scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
-        self.device = device
-
-        self.vae = get_model()
-        self.vae = self.vae.to(device)
-        self.vae=self.vae.eval()
-        self.layer_num = layer_num
-
-        self.MAX_DURATION = 360
-        main_config = {
-            "num_channels":32,
-            "unet_model_name":None,
-            "unet_model_config_path":"configs/models/transformer2D_wocross_inch112_1x4_multi_large.json",
-            "snr_gamma":None,
-        }
-        self.rvq_num = rvq_num
-        # print("rvq_num: ", self.rvq_num)
-        # exit()
-        self.model = PromptCondAudioDiffusion(**main_config).to(device)
-        if model_path.endswith(".safetensors"):
-            main_weights = load_file(model_path)
-        else:
-            main_weights = torch.load(model_path, map_location=device)
-        self.model.load_state_dict(main_weights, strict=False)
-        print ("Successfully loaded checkpoint from:", model_path)
-
-        self.model.eval()
-        self.model.init_device_dtype(torch.device(device), torch.float32)
-
-        # self.scheduler = DDIMScheduler.from_pretrained( \
-        #     scheduler_name, subfolder="scheduler")
-        # self.scheduler = DDPMScheduler.from_pretrained( \
-        #     scheduler_name, subfolder="scheduler")
-        print("Successfully loaded inference scheduler from {}".format(scheduler_name))
-
-
-
-    @torch.no_grad()
-    @torch.autocast(device_type="cuda", dtype=torch.float32)
-    def sound2code(self, orig_samples, batch_size=8):
-        if(orig_samples.ndim == 2):
-            audios = orig_samples.unsqueeze(0).to(self.device)
-        elif(orig_samples.ndim == 3):
-            audios = orig_samples.to(self.device)
-        else:
-            assert orig_samples.ndim in (2,3), orig_samples.shape
-        audios = self.preprocess_audio(audios)
-        audios = audios.squeeze(0)
-        orig_length = audios.shape[-1]
-        min_samples = int(40 * self.sample_rate)
-        # 40 seconds correspond to 10 tokens
-        output_len = int(orig_length / float(self.sample_rate) * 25) + 1
-        # print("output_len: ", output_len)
-
-        while(audios.shape[-1] < min_samples):
-            audios = torch.cat([audios, audios], -1)
-        int_max_len=audios.shape[-1]//min_samples+1
-        audios = torch.cat([audios, audios], -1)
-        audios=audios[:,:int(int_max_len*(min_samples))]
-        codes_list=[]
-
-        audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
-
-        for audio_inx in range(0, audio_input.shape[0], batch_size):
-            # import pdb; pdb.set_trace()
-            codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num)
-            # print("codes",codes[0].shape)
-
-            codes_list.append(torch.cat(codes, 1))
-            # print("codes_list",codes_list[0].shape)
-
-        codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T
-        codes=codes[:,:,:output_len]
-
-        return codes
-
-    @torch.no_grad()
-    @torch.autocast(device_type="cuda", dtype=torch.float32)
-    def sound2code_ds(self, orig_samples, ds, batch_size=6):
-        if(orig_samples.ndim == 2):
-            audios = orig_samples.unsqueeze(0).to(self.device)
-        elif(orig_samples.ndim == 3):
-            audios = orig_samples.to(self.device)
-        else:
-            assert orig_samples.ndim in (2,3), orig_samples.shape
-        audios = self.preprocess_audio(audios)
-        audios = audios.squeeze(0)
-        orig_length = audios.shape[-1]
-        min_samples = int(40 * self.sample_rate)
-        # 40 seconds correspond to 10 tokens
-        output_len = int(orig_length / float(self.sample_rate) * 25) + 1
-        # print("output_len: ", output_len)
-
-        while(audios.shape[-1] < min_samples):
-            audios = torch.cat([audios, audios], -1)
-        int_max_len=audios.shape[-1]//min_samples+1
-        audios = torch.cat([audios, audios], -1)
-        audios=audios[:,:int(int_max_len*(min_samples))]
-        codes_list=[]
-
-        audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
-
-        for audio_inx in range(0, audio_input.shape[0], batch_size):
-            # import pdb; pdb.set_trace()
-            codes, _, spk_embeds = self.model.fetch_codes_batch_ds((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num, ds=ds)
-            # print("codes",codes[0].shape)
-
-            codes_list.append(torch.cat(codes, 1))
-            # print("codes_list",codes_list[0].shape)
-
-        codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T
-        codes=codes[:,:,:output_len]
-
-        return codes
-
-    @torch.no_grad()
-    def code2sound(self, codes, prompt=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False):
-        codes = codes.to(self.device)
-
-        min_samples = duration * 25 # 40ms per frame
-        hop_samples = min_samples // 4 * 3
-        ovlp_samples = min_samples - hop_samples
-        hop_frames = hop_samples
-        ovlp_frames = ovlp_samples
-        first_latent = torch.randn(codes.shape[0], min_samples, 64).to(self.device)
first_latent = torch.randn(codes.shape[0], min_samples, 64).to(self.device)
|
| 150 |
-
first_latent_length = 0
|
| 151 |
-
first_latent_codes_length = 0
|
| 152 |
-
|
| 153 |
-
if(isinstance(prompt, torch.Tensor)):
|
| 154 |
-
# prepare prompt
|
| 155 |
-
prompt = prompt.to(self.device)
|
| 156 |
-
if(prompt.ndim == 3):
|
| 157 |
-
assert prompt.shape[0] == 1, prompt.shape
|
| 158 |
-
prompt = prompt[0]
|
| 159 |
-
elif(prompt.ndim == 1):
|
| 160 |
-
prompt = prompt.unsqueeze(0).repeat(2,1)
|
| 161 |
-
elif(prompt.ndim == 2):
|
| 162 |
-
if(prompt.shape[0] == 1):
|
| 163 |
-
prompt = prompt.repeat(2,1)
|
| 164 |
-
|
| 165 |
-
if(prompt.shape[-1] < int(30 * self.sample_rate)):
|
| 166 |
-
# if less than 30s, just choose the first 10s
|
| 167 |
-
prompt = prompt[:,:int(10*self.sample_rate)] # limit max length to 10.24
|
| 168 |
-
else:
|
| 169 |
-
# else choose from 20.48s which might includes verse or chorus
|
| 170 |
-
prompt = prompt[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24
|
| 171 |
-
|
| 172 |
-
true_latent = self.vae.encode_audio(prompt).permute(0,2,1)
|
| 173 |
-
# print("true_latent.shape", true_latent.shape)
|
| 174 |
-
# print("first_latent.shape", first_latent.shape)
|
| 175 |
-
#true_latent.shape torch.Size([1, 250, 64])
|
| 176 |
-
# first_latent.shape torch.Size([1, 1000, 64])
|
| 177 |
-
|
| 178 |
-
first_latent[:,0:true_latent.shape[1],:] = true_latent
|
| 179 |
-
first_latent_length = true_latent.shape[1]
|
| 180 |
-
first_latent_codes = self.sound2code(prompt)
|
| 181 |
-
first_latent_codes_length = first_latent_codes.shape[-1]
|
| 182 |
-
codes = torch.cat([first_latent_codes, codes], -1)
|
| 183 |
-
|
| 184 |
-
codes_len= codes.shape[-1]
|
| 185 |
-
target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)
|
| 186 |
-
# target_len = int(codes_len / 100 * 4 * self.sample_rate)
|
| 187 |
-
# code repeat
|
| 188 |
-
if(codes_len < min_samples):
|
| 189 |
-
while(codes.shape[-1] < min_samples):
|
| 190 |
-
codes = torch.cat([codes, codes], -1)
|
| 191 |
-
codes = codes[:,:,0:min_samples]
|
| 192 |
-
codes_len = codes.shape[-1]
|
| 193 |
-
if((codes_len - ovlp_samples) % hop_samples > 0):
|
| 194 |
-
len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples
|
| 195 |
-
while(codes.shape[-1] < len_codes):
|
| 196 |
-
codes = torch.cat([codes, codes], -1)
|
| 197 |
-
codes = codes[:,:,0:len_codes]
|
| 198 |
-
latent_length = min_samples
|
| 199 |
-
latent_list = []
|
| 200 |
-
spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device)
|
| 201 |
-
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
| 202 |
-
for sinx in range(0, codes.shape[-1]-hop_samples, hop_samples):
|
| 203 |
-
codes_input=[]
|
| 204 |
-
codes_input.append(codes[:,:,sinx:sinx+min_samples])
|
| 205 |
-
if(sinx == 0):
|
| 206 |
-
# print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
|
| 207 |
-
incontext_length = first_latent_length
|
| 208 |
-
latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
|
| 209 |
-
latent_list.append(latents)
|
| 210 |
-
else:
|
| 211 |
-
# print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
|
| 212 |
-
true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1)
|
| 213 |
-
print("true_latent.shape", true_latent.shape)
|
| 214 |
-
len_add_to_1000 = 1000 - true_latent.shape[-2]
|
| 215 |
-
# print("len_add_to_1000", len_add_to_1000)
|
| 216 |
-
# exit()
|
| 217 |
-
incontext_length = true_latent.shape[-2]
|
| 218 |
-
true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2)
|
| 219 |
-
latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
|
| 220 |
-
latent_list.append(latents)
|
| 221 |
-
|
| 222 |
-
latent_list = [l.float() for l in latent_list]
|
| 223 |
-
latent_list[0] = latent_list[0][:,:,first_latent_length:]
|
| 224 |
-
min_samples = int(min_samples * self.sample_rate // 1000 * 40)
|
| 225 |
-
hop_samples = int(hop_samples * self.sample_rate // 1000 * 40)
|
| 226 |
-
ovlp_samples = min_samples - hop_samples
|
| 227 |
-
with torch.no_grad():
|
| 228 |
-
output = None
|
| 229 |
-
for i in range(len(latent_list)):
|
| 230 |
-
latent = latent_list[i]
|
| 231 |
-
cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
|
| 232 |
-
|
| 233 |
-
if output is None:
|
| 234 |
-
output = cur_output
|
| 235 |
-
else:
|
| 236 |
-
ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
|
| 237 |
-
ov_win = torch.cat([ov_win, 1 - ov_win], -1)
|
| 238 |
-
print("output.shape", output.shape)
|
| 239 |
-
print("ov_win.shape", ov_win.shape)
|
| 240 |
-
output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
|
| 241 |
-
output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
|
| 242 |
-
output = output[:, 0:target_len]
|
| 243 |
-
return output
|
| 244 |
-
|
| 245 |
-
@torch.no_grad()
|
| 246 |
-
def preprocess_audio(self, input_audios, threshold=0.8):
|
| 247 |
-
assert len(input_audios.shape) == 3, input_audios.shape
|
| 248 |
-
nchan = input_audios.shape[1]
|
| 249 |
-
input_audios = input_audios.reshape(input_audios.shape[0], -1)
|
| 250 |
-
norm_value = torch.ones_like(input_audios[:,0])
|
| 251 |
-
max_volume = input_audios.abs().max(dim=-1)[0]
|
| 252 |
-
norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
|
| 253 |
-
return input_audios.reshape(input_audios.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1)
|
| 254 |
-
|
| 255 |
-
@torch.no_grad()
|
| 256 |
-
def sound2sound(self, sound, prompt=None, steps=50, disable_progress=False):
|
| 257 |
-
codes = self.sound2code(sound)
|
| 258 |
-
# print(codes.shape)
|
| 259 |
-
# exit()
|
| 260 |
-
wave = self.code2sound(codes, prompt, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
|
| 261 |
-
# print(fname, wave.shape)
|
| 262 |
-
return wave
|
| 263 |
-
|
| 264 |
-
def file2code(self, fname):
|
| 265 |
-
try:
|
| 266 |
-
orig_samples, fs = torchaudio.load(fname)
|
| 267 |
-
except:
|
| 268 |
-
af = AudioFile(fname)
|
| 269 |
-
orig_samples = af.read()
|
| 270 |
-
fs = af.samplerate()
|
| 271 |
-
orig_samples = orig_samples[0]
|
| 272 |
-
if(fs!=self.sample_rate):
|
| 273 |
-
orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
|
| 274 |
-
fs = self.sample_rate
|
| 275 |
-
if orig_samples.shape[0] == 1:
|
| 276 |
-
orig_samples = torch.cat([orig_samples, orig_samples], 0)
|
| 277 |
-
return self.sound2code(orig_samples)
|
| 278 |
-
|
| 279 |
-
def file2code_ds(self, fname, ds):
|
| 280 |
-
try:
|
| 281 |
-
orig_samples, fs = torchaudio.load(fname)
|
| 282 |
-
except:
|
| 283 |
-
af = AudioFile(fname)
|
| 284 |
-
orig_samples = af.read()
|
| 285 |
-
fs = af.samplerate()
|
| 286 |
-
orig_samples = orig_samples[0]
|
| 287 |
-
if(fs!=self.sample_rate):
|
| 288 |
-
orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
|
| 289 |
-
fs = self.sample_rate
|
| 290 |
-
if orig_samples.shape[0] == 1:
|
| 291 |
-
orig_samples = torch.cat([orig_samples, orig_samples], 0)
|
| 292 |
-
return self.sound2code_ds(orig_samples, ds)
|
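
In practical terms, this deletion removes the 4-codebook RVQ tokenizer round trip. A minimal usage sketch of the deleted API, assuming a checkout from before this commit; the class and method names come from the file above, but the checkpoint and audio paths here are invented for illustration:

import torchaudio
from generate_4rvq import Tango  # module removed by this commit

tango = Tango(model_path="ckpt/model_4rvq.safetensors", rvq_num=4)  # hypothetical checkpoint path
codes = tango.file2code("song.flac")          # -> (1, rvq_num, T) token grid at 25 tokens/s
wave = tango.code2sound(codes, num_steps=20)  # decode the tokens back to 48 kHz stereo
torchaudio.save("roundtrip.wav", wave, tango.sample_rate)

Note that code2sound decodes long sequences in 40 s windows with a 30 s hop (hop_samples is three quarters of the window), so consecutive windows overlap by 10 s before being stitched together.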
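That stitching relies on a linear crossfade across the overlapped samples: the tail of the audio decoded so far fades out while the head of the newly decoded chunk fades in. A self-contained sketch of the overlap-add step, independent of the model (the function name is mine):

import torch

def linear_crossfade_concat(prev: torch.Tensor, nxt: torch.Tensor, ovlp: int) -> torch.Tensor:
    # Join two (channels, time) clips, fading `prev` out and `nxt` in over `ovlp` samples.
    ramp = torch.linspace(0.0, 1.0, ovlp)
    mixed = prev[:, -ovlp:] * (1.0 - ramp) + nxt[:, :ovlp] * ramp
    return torch.cat([prev[:, :-ovlp], mixed, nxt[:, ovlp:]], dim=-1)

a = torch.randn(2, 48000)
b = torch.randn(2, 48000)
out = linear_crossfade_concat(a, b, ovlp=12000)  # 0.25 s crossfade at 48 kHz
assert out.shape == (2, 48000 + 48000 - 12000)

This mirrors the ov_win logic in the deleted file, where cat([ramp, 1 - ramp]) builds the fade-in and fade-out halves of one window.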
codeclm/tokenizer/Flow1dVAE/libs/datasets/MusicSoundMixedDataset.py
DELETED
|
@@ -1,1278 +0,0 @@
|
|
| 1 |
-
from torch.utils.data import Dataset
|
| 2 |
-
from beartype.typing import Sequence, Callable, Optional, Dict, Tuple, List, Union
|
| 3 |
-
from beartype import beartype
|
| 4 |
-
from beartype.door import is_bearable
|
| 5 |
-
import random
|
| 6 |
-
import pandas as pd
|
| 7 |
-
import os
|
| 8 |
-
from torchaudio.functional import resample
|
| 9 |
-
import torch
|
| 10 |
-
import typing as tp
|
| 11 |
-
from pathlib import Path
|
| 12 |
-
import torchaudio as ta
|
| 13 |
-
import torch.nn.functional as F
|
| 14 |
-
import numpy as np
|
| 15 |
-
import json
|
| 16 |
-
import yaml
|
| 17 |
-
import torchaudio
|
| 18 |
-
import math
|
| 19 |
-
import re
|
| 20 |
-
from loguru import logger
|
| 21 |
-
import ffmpeg
|
| 22 |
-
|
| 23 |
-
class Read_and_PadCrop_Normalized_T(torch.nn.Module):
|
| 24 |
-
def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
|
| 25 |
-
|
| 26 |
-
super().__init__()
|
| 27 |
-
|
| 28 |
-
self.n_samples = n_samples
|
| 29 |
-
self.sample_rate = sample_rate
|
| 30 |
-
self.randomize = randomize
|
| 31 |
-
|
| 32 |
-
def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
|
| 33 |
-
if self.n_samples < 0: #means not clip
|
| 34 |
-
chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
|
| 35 |
-
t_start = 0.
|
| 36 |
-
t_end = 1.0
|
| 37 |
-
offset = 0
|
| 38 |
-
else:
|
| 39 |
-
if(duration<(float(self.n_samples)/self.sample_rate+1)):
|
| 40 |
-
# print(duration,(float(self.n_samples)/self.sample_rate+1))
|
| 41 |
-
chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
|
| 42 |
-
t_start = 0.
|
| 43 |
-
t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
|
| 44 |
-
offset = 0
|
| 45 |
-
# print('c1:',chunk.shape)
|
| 46 |
-
else:
|
| 47 |
-
offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
|
| 48 |
-
t_start = offset / float(cur_sample_rate) / duration
|
| 49 |
-
t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
|
| 50 |
-
chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
|
| 51 |
-
# print('offset:',offset)
|
| 52 |
-
# print('c0:',chunk.shape)
|
| 53 |
-
# Pad with silence if necessary.
|
| 54 |
-
if(chunk.shape[0]>1):
|
| 55 |
-
chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
|
| 56 |
-
else:
|
| 57 |
-
chunk = chunk[[0],:].float()
|
| 58 |
-
if(cur_sample_rate!=self.sample_rate):
|
| 59 |
-
# print('a:',cur_sample_rate,chunk.shape)
|
| 60 |
-
chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
|
| 61 |
-
# print('b:',self.sample_rate,chunk.shape)
|
| 62 |
-
|
| 63 |
-
if self.n_samples > 0:
|
| 64 |
-
if chunk.shape[-1] < self.n_samples:
|
| 65 |
-
chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
|
| 66 |
-
else:
|
| 67 |
-
chunk = chunk[:,0:self.n_samples]
|
| 68 |
-
seconds_start = math.floor(offset / cur_sample_rate)
|
| 69 |
-
seconds_total = math.floor(duration)
|
| 70 |
-
|
| 71 |
-
return (
|
| 72 |
-
chunk,
|
| 73 |
-
t_start,
|
| 74 |
-
t_end,
|
| 75 |
-
seconds_start,
|
| 76 |
-
seconds_total
|
| 77 |
-
)
|
| 78 |
-
|
| 79 |
-
class Read_and_PadCrop_Normalized_T_Avoid_Watermark(torch.nn.Module):
|
| 80 |
-
def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True, w_start = 0, w_interval = 11.3):
|
| 81 |
-
|
| 82 |
-
super().__init__()
|
| 83 |
-
|
| 84 |
-
self.n_samples = n_samples
|
| 85 |
-
self.sample_rate = sample_rate
|
| 86 |
-
self.randomize = randomize
|
| 87 |
-
|
| 88 |
-
self.w_start = w_start
|
| 89 |
-
self.w_interval = w_interval
|
| 90 |
-
|
| 91 |
-
def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
|
| 92 |
-
if self.n_samples < 0: #means not clip
|
| 93 |
-
chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
|
| 94 |
-
t_start = 0.
|
| 95 |
-
t_end = 1.0
|
| 96 |
-
offset = 0
|
| 97 |
-
else:
|
| 98 |
-
if(duration<(float(self.n_samples)/self.sample_rate+1)):
|
| 99 |
-
# print(duration,(float(self.n_samples)/self.sample_rate+1))
|
| 100 |
-
chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
|
| 101 |
-
t_start = 0.
|
| 102 |
-
t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
|
| 103 |
-
offset = 0
|
| 104 |
-
# print('c1:',chunk.shape)
|
| 105 |
-
else:
|
| 106 |
-
n_offset_option = (duration - self.w_start) // self.w_interval
|
| 107 |
-
if n_offset_option <= 1:
|
| 108 |
-
offset = 0
|
| 109 |
-
else:
|
| 110 |
-
offset = int((random.randint(0,n_offset_option-1) * self.w_interval + self.w_start) * cur_sample_rate)
|
| 111 |
-
# offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
|
| 112 |
-
t_start = offset / float(cur_sample_rate) / duration
|
| 113 |
-
t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
|
| 114 |
-
chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
|
| 115 |
-
# print('offset:',offset)
|
| 116 |
-
# print('c0:',chunk.shape)
|
| 117 |
-
# Pad with silence if necessary.
|
| 118 |
-
if(chunk.shape[0]>1):
|
| 119 |
-
chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
|
| 120 |
-
else:
|
| 121 |
-
chunk = chunk[[0],:].float()
|
| 122 |
-
if(cur_sample_rate!=self.sample_rate):
|
| 123 |
-
# print('a:',cur_sample_rate,chunk.shape)
|
| 124 |
-
chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
|
| 125 |
-
# print('b:',self.sample_rate,chunk.shape)
|
| 126 |
-
|
| 127 |
-
if self.n_samples > 0:
|
| 128 |
-
if chunk.shape[-1] < self.n_samples:
|
| 129 |
-
chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
|
| 130 |
-
else:
|
| 131 |
-
chunk = chunk[:,0:self.n_samples]
|
| 132 |
-
seconds_start = math.floor(offset / cur_sample_rate)
|
| 133 |
-
seconds_total = math.floor(duration)
|
| 134 |
-
|
| 135 |
-
return (
|
| 136 |
-
chunk,
|
| 137 |
-
t_start,
|
| 138 |
-
t_end,
|
| 139 |
-
seconds_start,
|
| 140 |
-
seconds_total
|
| 141 |
-
)
|
| 142 |
-
|
| 143 |
-
USE_DUMMY_AUDIO = False #当测试代码时,可以将其置为True,这样就不会读取实际数据,而是用生成的静默音频代替
|
| 144 |
-
if USE_DUMMY_AUDIO:
|
| 145 |
-
logger.warning("USE_DUMMY_AUDIO flag is True, don't use it when train or test!")
|
| 146 |
-
|
| 147 |
-
class SafeAudioReader:
|
| 148 |
-
"""
|
| 149 |
-
This class is an adaptor to Read_and_PadCrop_Normalized_T, make it safe to read audio data.
|
| 150 |
-
"""
|
| 151 |
-
def __init__(self,
|
| 152 |
-
duration: float, # 返回音频长度
|
| 153 |
-
sample_rate: int, # 返回音频的采样率,如与实际音频采样率不同,会作resample
|
| 154 |
-
randomize: bool = True,
|
| 155 |
-
use_avoid_watermark_policy = False,
|
| 156 |
-
):
|
| 157 |
-
self.n_samples = int(sample_rate * duration)
|
| 158 |
-
self.reader = (
|
| 159 |
-
Read_and_PadCrop_Normalized_T_Avoid_Watermark if use_avoid_watermark_policy \
|
| 160 |
-
else Read_and_PadCrop_Normalized_T
|
| 161 |
-
)(n_samples=self.n_samples, sample_rate=sample_rate, randomize=randomize)
|
| 162 |
-
|
| 163 |
-
#NOTE:这个是核心的函数,所有数据集读取音频都是调用的这个函数!
|
| 164 |
-
def __call__(self,
|
| 165 |
-
filepath: os.PathLike, # 音频路径
|
| 166 |
-
origin_sample_rate: Optional[int] = None, # 从json文件中读取的实际采样率,如果不给定,则会从文件头中读取
|
| 167 |
-
origin_duration: float = None, # 从json文件中读取的实际时长,如果不给定,则会从文件头中读取
|
| 168 |
-
) -> torch.Tensor:
|
| 169 |
-
if USE_DUMMY_AUDIO:
|
| 170 |
-
wav = torch.zeros(self.n_samples, dtype=torch.float32)
|
| 171 |
-
return wav
|
| 172 |
-
try:
|
| 173 |
-
if origin_sample_rate is None or origin_duration is None:
|
| 174 |
-
# audio_info = torchaudio.info(filepath)
|
| 175 |
-
# origin_sample_rate = audio_info.sample_rate
|
| 176 |
-
# origin_duration = audio_info.num_frames / origin_sample_rate
|
| 177 |
-
info = ffmpeg.probe(filepath)
|
| 178 |
-
origin_duration = float(info['format']['duration'])
|
| 179 |
-
origin_sample_rate = int(info['streams'][0]['sample_rate'])
|
| 180 |
-
wav, *ignored = self.reader(filepath, origin_duration, origin_sample_rate)
|
| 181 |
-
wav = wav.squeeze_(0)
|
| 182 |
-
except Exception as e:
|
| 183 |
-
logger.error(f"Error reading {filepath}: {e}")
|
| 184 |
-
wav = torch.zeros(self.n_samples, dtype=torch.float32)
|
| 185 |
-
return wav
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
class PromptTemplate:
|
| 189 |
-
def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'):
|
| 190 |
-
self.template_text = template_text
|
| 191 |
-
self.tag_map = tag_map
|
| 192 |
-
self.lang = lang
|
| 193 |
-
|
| 194 |
-
@property
|
| 195 |
-
def tags(self):
|
| 196 |
-
return tuple(self.tag_map.keys())
|
| 197 |
-
|
| 198 |
-
def apply(self, **kwargs):
|
| 199 |
-
for tag in list(kwargs.keys()):
|
| 200 |
-
if kwargs[tag] == '':
|
| 201 |
-
kwargs.pop(tag)
|
| 202 |
-
for tag in self.tags:
|
| 203 |
-
if tag in kwargs:
|
| 204 |
-
kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]')
|
| 205 |
-
else:
|
| 206 |
-
kwargs[tag] = ''
|
| 207 |
-
prompt = self.template_text.format(**kwargs)
|
| 208 |
-
|
| 209 |
-
return self.beautify(prompt)
|
| 210 |
-
|
| 211 |
-
def beautify(self, text):
|
| 212 |
-
if self.lang == 'en':
|
| 213 |
-
return self._beautify_en(text)
|
| 214 |
-
elif self.lang == 'zh':
|
| 215 |
-
return self._beautify_zh(text)
|
| 216 |
-
else:
|
| 217 |
-
raise ValueError(f'Unknown language {self.lang}')
|
| 218 |
-
|
| 219 |
-
@staticmethod
|
| 220 |
-
def _beautify_en(text):
|
| 221 |
-
# no continuous commas without content between them
|
| 222 |
-
text = re.sub(r'[,\s]*,[,\s]*', r', ', text)
|
| 223 |
-
# no continuous whitespace
|
| 224 |
-
text = re.sub(r'\s+', ' ', text)
|
| 225 |
-
# the comma is NOT followed by whitespace, and should be followed by ONE whitespace
|
| 226 |
-
text = re.sub(r'\s+,', r',', text)
|
| 227 |
-
text = re.sub(r',\s+', r', ', text)
|
| 228 |
-
# no whitespace before the full stop
|
| 229 |
-
text = re.sub(r'\s+\.', r'.', text)
|
| 230 |
-
# strip whitespace, comma, and replace ',.'
|
| 231 |
-
text = text.strip(' ,')
|
| 232 |
-
text = text.replace(',.', '.')
|
| 233 |
-
return text
|
| 234 |
-
|
| 235 |
-
@staticmethod
|
| 236 |
-
def _beautify_zh(text):
|
| 237 |
-
# no continuous commas without content between them
|
| 238 |
-
text = re.sub(r'[,、\s]*,[,、\s]*', r',', text)
|
| 239 |
-
text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text)
|
| 240 |
-
# assume there should be NO whitespace in Chinese
|
| 241 |
-
text = re.sub(r'\s+', r'', text)
|
| 242 |
-
# strip whitespace, comma, and replace ',。'
|
| 243 |
-
text = text.strip(', 、')
|
| 244 |
-
text = text.replace(',。', '。')
|
| 245 |
-
return text
|
| 246 |
-
|
| 247 |
-
def __repr__(self):
|
| 248 |
-
return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})'
|
| 249 |
-
|
| 250 |
-
__str__ = __repr__
|
| 251 |
-
|
| 252 |
-
def parse_prompt_template(prompt_template_text, lang='en'):
|
| 253 |
-
span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL)
|
| 254 |
-
tag_pattern = re.compile(r'{.+?}', re.DOTALL)
|
| 255 |
-
|
| 256 |
-
template_text = prompt_template_text.strip()
|
| 257 |
-
span_texts = span_pattern.findall(prompt_template_text)
|
| 258 |
-
tag_map = {}
|
| 259 |
-
for span_text in span_texts:
|
| 260 |
-
tag = tag_pattern.findall(span_text)[0].strip('{}')
|
| 261 |
-
tag_map[tag] = span_text
|
| 262 |
-
template_text = template_text.replace(span_text, '{'+tag+'}')
|
| 263 |
-
|
| 264 |
-
return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang)
|
| 265 |
-
|
| 266 |
-
def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]:
|
| 267 |
-
with open(path, 'r') as f:
|
| 268 |
-
lines = f.readlines()
|
| 269 |
-
cnt = 0
|
| 270 |
-
pts = []
|
| 271 |
-
for line in lines:
|
| 272 |
-
pt = parse_prompt_template(line, lang=lang)
|
| 273 |
-
cnt += 1
|
| 274 |
-
if len(pt.tags) < num:
|
| 275 |
-
logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}')
|
| 276 |
-
pts.append(pt)
|
| 277 |
-
|
| 278 |
-
return pts
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
def get_base_dir_file(key: os.PathLike):
|
| 282 |
-
base = os.path.basename(key)
|
| 283 |
-
dirname = os.path.basename(os.path.dirname(key))
|
| 284 |
-
return os.path.join(dirname, base)
|
| 285 |
-
|
| 286 |
-
def read_jsonlike(path: os.PathLike):
|
| 287 |
-
#json or jsonl
|
| 288 |
-
if str(path).endswith(".json"):
|
| 289 |
-
with open(path, 'r', encoding='utf8') as f:
|
| 290 |
-
data = json.load(f)
|
| 291 |
-
return data
|
| 292 |
-
elif str(path).endswith(".jsonl"):
|
| 293 |
-
with open(path, 'r', encoding='utf8') as f:
|
| 294 |
-
data = [json.loads(line) for line in f.readlines()]
|
| 295 |
-
return data
|
| 296 |
-
else:
|
| 297 |
-
raise ValueError("Unknown file format")
|
| 298 |
-
|
| 299 |
-
dist_prob_map = {
|
| 300 |
-
1: (1.0,),
|
| 301 |
-
2: (0.5, 0.5),
|
| 302 |
-
3: (0.3, 0.4, 0.3),
|
| 303 |
-
4: (0.2, 0.3, 0.3, 0.2),
|
| 304 |
-
5: (0.2, 0.2, 0.3, 0.2, 0.1),
|
| 305 |
-
6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15),
|
| 306 |
-
7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
|
| 307 |
-
8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
|
| 308 |
-
9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
|
| 309 |
-
10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
|
| 310 |
-
}
|
| 311 |
-
|
| 312 |
-
'''
|
| 313 |
-
#更加偏向短文本的方案
|
| 314 |
-
dist_prob_map = {
|
| 315 |
-
1: (1.0,),
|
| 316 |
-
2: (0.7, 0.3),
|
| 317 |
-
3: (0.7, 0.2, 0.1),
|
| 318 |
-
4: (0.6, 0.2, 0.1, 0.1),
|
| 319 |
-
5: (0.6, 0.2, 0.1, 0.05, 0.05),
|
| 320 |
-
6: (0.6, 0.15, 0.1, 0.05, 0.05, 0.05),
|
| 321 |
-
7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
|
| 322 |
-
8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
|
| 323 |
-
9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
|
| 324 |
-
10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
|
| 325 |
-
}
|
| 326 |
-
'''
|
| 327 |
-
|
| 328 |
-
#全部都用的方案
|
| 329 |
-
# dist_prob_map = {
|
| 330 |
-
# 1: (1.0,),
|
| 331 |
-
# 2: (0, 1.0),
|
| 332 |
-
# 3: (0, 0, 1.0),
|
| 333 |
-
# 4: (0, 0, 0, 1.0),
|
| 334 |
-
# 5: (0, 0, 0, 0, 1.0),
|
| 335 |
-
# 6: (0, 0, 0, 0, 0, 1.0),
|
| 336 |
-
# 7: (0, 0, 0, 0, 0, 0, 1.0),
|
| 337 |
-
# 8: (0, 0, 0, 0, 0, 0, 0, 1.0),
|
| 338 |
-
# 9: (0, 0, 0, 0, 0, 0, 0, 0, 1.0),
|
| 339 |
-
# 10: (0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0)
|
| 340 |
-
# }
|
| 341 |
-
|
| 342 |
-
dist_prob_map_low = {
|
| 343 |
-
1: (1.0,),
|
| 344 |
-
2: (0.8, 0.2),
|
| 345 |
-
3: (0.8, 0.1, 0.1),
|
| 346 |
-
4: (0.7, 0.1, 0.1, 0.1),
|
| 347 |
-
5: (0.7, 0.1, 0.1, 0.05, 0.05),
|
| 348 |
-
6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05),
|
| 349 |
-
}
|
| 350 |
-
|
| 351 |
-
_bpm_range_rights = (
|
| 352 |
-
(40, '20-40'),
|
| 353 |
-
(60, '40-60'),
|
| 354 |
-
(66, '60-66'),
|
| 355 |
-
(76, '66-76'),
|
| 356 |
-
(108, '76-108'),
|
| 357 |
-
(120, '108-120'),
|
| 358 |
-
(168, '120-168'),
|
| 359 |
-
(176, '168-176'),
|
| 360 |
-
(200, '176-200')
|
| 361 |
-
)
|
| 362 |
-
_bpm_desc_map = {
|
| 363 |
-
'20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"),
|
| 364 |
-
'40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"),
|
| 365 |
-
'60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'),
|
| 366 |
-
'66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'),
|
| 367 |
-
'76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"),
|
| 368 |
-
'108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'),
|
| 369 |
-
'120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'),
|
| 370 |
-
'168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'),
|
| 371 |
-
'176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'),
|
| 372 |
-
'>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo')
|
| 373 |
-
}
|
| 374 |
-
_bpm_desc_map_zh = {
|
| 375 |
-
'20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"),
|
| 376 |
-
'40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"),
|
| 377 |
-
'60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'),
|
| 378 |
-
'66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'),
|
| 379 |
-
'76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"),
|
| 380 |
-
'108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'),
|
| 381 |
-
'120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'),
|
| 382 |
-
'168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'),
|
| 383 |
-
'176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'),
|
| 384 |
-
'>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板')
|
| 385 |
-
}
|
| 386 |
-
def get_bpm_range(bpm):
|
| 387 |
-
bpm = int(bpm)
|
| 388 |
-
for right, tag in _bpm_range_rights:
|
| 389 |
-
if bpm <= right:
|
| 390 |
-
return tag
|
| 391 |
-
return '>200'
|
| 392 |
-
|
| 393 |
-
def gen_bpm_descript(bpm, lang='en'):
|
| 394 |
-
bpm_range = get_bpm_range(bpm)
|
| 395 |
-
if lang == 'en':
|
| 396 |
-
return random.choice(_bpm_desc_map[bpm_range])
|
| 397 |
-
elif lang == 'zh':
|
| 398 |
-
return random.choice(_bpm_desc_map_zh[bpm_range])
|
| 399 |
-
else:
|
| 400 |
-
raise ValueError(f"Unknown language {lang}")
|
| 401 |
-
|
| 402 |
-
def read_translate(translate: Union[Dict[str, os.PathLike], os.PathLike, None]):
|
| 403 |
-
if translate is None:
|
| 404 |
-
return None
|
| 405 |
-
if isinstance(translate, str):
|
| 406 |
-
return read_jsonlike(translate)
|
| 407 |
-
return {k: read_jsonlike(path) for k, path in translate.items()}
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
def gen_plain_prompt(key_list, sep=', '):
|
| 411 |
-
if len(key_list) == 0:
|
| 412 |
-
return 'none'
|
| 413 |
-
|
| 414 |
-
key_list = [k.strip() for k in key_list]
|
| 415 |
-
|
| 416 |
-
if len(key_list) > 10:
|
| 417 |
-
random.shuffle(key_list)
|
| 418 |
-
key_list = key_list[:10]
|
| 419 |
-
|
| 420 |
-
probs = dist_prob_map[len(key_list)]
|
| 421 |
-
|
| 422 |
-
num_tags = random.choices(range(1, len(key_list)+1), probs, k=1)[0]
|
| 423 |
-
|
| 424 |
-
random.shuffle(key_list)
|
| 425 |
-
tags = key_list[:num_tags]
|
| 426 |
-
tags_str = sep.join(tags)
|
| 427 |
-
return tags_str
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
class MagnaTagATuneDataset(Dataset):
|
| 431 |
-
def __init__(self):
|
| 432 |
-
pass
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
def tags_to_desc(tag_list, sep=',') -> str:
|
| 436 |
-
if not isinstance(tag_list, Sequence):
|
| 437 |
-
return str(tag_list)
|
| 438 |
-
if isinstance(tag_list, str):
|
| 439 |
-
return tag_list
|
| 440 |
-
if len(tag_list) <= 0:
|
| 441 |
-
return ''
|
| 442 |
-
elif len(tag_list) <= 5:
|
| 443 |
-
probs = dist_prob_map[len(tag_list)]
|
| 444 |
-
tags_num = random.choices(range(1, len(tag_list)+1), probs)[0]
|
| 445 |
-
random.shuffle(tag_list)
|
| 446 |
-
tag_list = tag_list[:tags_num]
|
| 447 |
-
return sep.join(tag_list)
|
| 448 |
-
else:
|
| 449 |
-
probs = dist_prob_map[5]
|
| 450 |
-
tags_num = random.choices(range(1, 6), probs)[0]
|
| 451 |
-
random.shuffle(tag_list)
|
| 452 |
-
tag_list = tag_list[:tags_num]
|
| 453 |
-
return sep.join(tag_list)
|
| 454 |
-
|
| 455 |
-
def get_sr_and_duration_info(item):
|
| 456 |
-
return item.get('sample_rate', None), item.get('duration', None)
|
| 457 |
-
|
| 458 |
-
class MtgJamendoDatasetFromJson(Dataset):
|
| 459 |
-
def __init__(self,
|
| 460 |
-
data_dir:str,
|
| 461 |
-
json_path:str,
|
| 462 |
-
duration:float=10,
|
| 463 |
-
sr:int = 0,
|
| 464 |
-
lang = 'en',
|
| 465 |
-
plain_rate = 0,
|
| 466 |
-
return_audio = True,
|
| 467 |
-
return_path = False,
|
| 468 |
-
prompt_template_path: os.PathLike = None,
|
| 469 |
-
tag_types = [],
|
| 470 |
-
translate:Optional[Dict[str, os.PathLike]] = None,
|
| 471 |
-
use_literal_none = True,
|
| 472 |
-
):
|
| 473 |
-
self.audio_reader = SafeAudioReader(duration, sr)
|
| 474 |
-
|
| 475 |
-
self.data_dir = data_dir
|
| 476 |
-
self._load_metadata_json(json_path)
|
| 477 |
-
self.sr = sr
|
| 478 |
-
self.duration = duration
|
| 479 |
-
self.plain_rate = plain_rate
|
| 480 |
-
self.return_audio = return_audio
|
| 481 |
-
self.return_path = return_path
|
| 482 |
-
self.use_literal_none = use_literal_none
|
| 483 |
-
self.lang = lang
|
| 484 |
-
|
| 485 |
-
self.use_dynamic_prompt = prompt_template_path is not None and plain_rate < 1.0
|
| 486 |
-
if self.use_dynamic_prompt:
|
| 487 |
-
self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types))
|
| 488 |
-
self.tag_types = tag_types
|
| 489 |
-
|
| 490 |
-
self.translate = read_translate(translate)
|
| 491 |
-
|
| 492 |
-
#这些tag被认为是弱语义的,会避免产生仅包含这些tag的文本提示
|
| 493 |
-
WEAK_TAG_LIST = ["title", "artist"]
|
| 494 |
-
|
| 495 |
-
def _load_metadata_json(self, json_path):
|
| 496 |
-
with open(json_path) as fp:
|
| 497 |
-
self.data = json.load(fp)
|
| 498 |
-
|
| 499 |
-
def convert_key_to_path(self, key):
|
| 500 |
-
return os.path.join(self.data_dir, get_base_dir_file(key))
|
| 501 |
-
|
| 502 |
-
def __len__(self):
|
| 503 |
-
return len(self.data)
|
| 504 |
-
|
| 505 |
-
def __getitem__(self, idx):
|
| 506 |
-
item = self.data[idx]
|
| 507 |
-
path = self.convert_key_to_path(item['key'])
|
| 508 |
-
description = self.generate_description(item)
|
| 509 |
-
|
| 510 |
-
if self.return_audio:
|
| 511 |
-
sr, duration = get_sr_and_duration_info(item)
|
| 512 |
-
audio = self.audio_reader(path, sr, duration)
|
| 513 |
-
else:
|
| 514 |
-
audio = None
|
| 515 |
-
|
| 516 |
-
if self.return_path:
|
| 517 |
-
return audio, description, path
|
| 518 |
-
return audio, description
|
| 519 |
-
|
| 520 |
-
def tags_to_desc(self, tag_list, tag_type) -> str:
|
| 521 |
-
if self.lang == 'en':
|
| 522 |
-
return tags_to_desc(tag_list)
|
| 523 |
-
elif self.lang == 'zh':
|
| 524 |
-
translator = self.translate[tag_type]
|
| 525 |
-
translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
|
| 526 |
-
return tags_to_desc(translated_tag_list, sep='、')
|
| 527 |
-
|
| 528 |
-
def generate_description(self, item):
|
| 529 |
-
if random.random() > self.plain_rate:
|
| 530 |
-
# dynamically generate prompt from given prompt template
|
| 531 |
-
prompt_template = random.choice(self.prompt_templates)
|
| 532 |
-
description = self.generate_description_dynamic(item, prompt_template)
|
| 533 |
-
else:
|
| 534 |
-
# use plain prompt, i.e. tags sequence separated by comma
|
| 535 |
-
description = self.generate_description_plain(item)
|
| 536 |
-
return description
|
| 537 |
-
|
| 538 |
-
def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
|
| 539 |
-
exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
|
| 540 |
-
exists_weak_tag = list(filter(lambda t: t in self.WEAK_TAG_LIST, exists_tag))
|
| 541 |
-
exists_strong_tag = list(filter(lambda t: t not in self.WEAK_TAG_LIST, exists_tag))
|
| 542 |
-
|
| 543 |
-
if len(exists_strong_tag) > 0:
|
| 544 |
-
probs = dist_prob_map[len(exists_strong_tag)]
|
| 545 |
-
tags_num = random.choices(range(1, len(exists_strong_tag)+1), probs)[0]
|
| 546 |
-
random.shuffle(exists_strong_tag)
|
| 547 |
-
tags = exists_strong_tag[:tags_num]
|
| 548 |
-
weak_probs = dist_prob_map_low[len(exists_weak_tag) + 1]
|
| 549 |
-
weak_tags_num = random.choices(range(0, len(exists_weak_tag) + 1), weak_probs)[0]
|
| 550 |
-
random.shuffle(exists_weak_tag)
|
| 551 |
-
weak_tags = exists_weak_tag[:weak_tags_num]
|
| 552 |
-
tags += weak_tags
|
| 553 |
-
tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
|
| 554 |
-
prompt = prompt_template.apply(**tags_args)
|
| 555 |
-
else:
|
| 556 |
-
# no strong tags, use all weak tags instead
|
| 557 |
-
tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in exists_weak_tag}
|
| 558 |
-
prompt = prompt_template.apply(**tags_args)
|
| 559 |
-
|
| 560 |
-
if self.use_literal_none and len(tags_args) == 0:
|
| 561 |
-
return 'none'
|
| 562 |
-
|
| 563 |
-
return prompt
|
| 564 |
-
|
| 565 |
-
def generate_description_plain(self, item):
|
| 566 |
-
keywords = []
|
| 567 |
-
for tag_t in self.tag_types:
|
| 568 |
-
this_key = item[tag_t]
|
| 569 |
-
if this_key is None:
|
| 570 |
-
continue
|
| 571 |
-
if isinstance(this_key, str):
|
| 572 |
-
this_key = [this_key]
|
| 573 |
-
if self.lang != 'en':
|
| 574 |
-
this_key = [self.get_translation(tag_t, k) for k in this_key]
|
| 575 |
-
keywords += this_key
|
| 576 |
-
return gen_plain_prompt(keywords, sep=self.keysep)
|
| 577 |
-
|
| 578 |
-
def get_translation(self, tag_t, k):
|
| 579 |
-
k = k.strip()
|
| 580 |
-
if k in self.translate[tag_t]:
|
| 581 |
-
return self.translate[tag_t][k]
|
| 582 |
-
else:
|
| 583 |
-
return k
|
| 584 |
-
|
| 585 |
-
@property
|
| 586 |
-
def keysep(self):
|
| 587 |
-
if self.lang == 'zh':
|
| 588 |
-
return ',' if random.random() > 0.5 else '、'
|
| 589 |
-
elif self.lang == 'en':
|
| 590 |
-
return ', '
|
| 591 |
-
|
| 592 |
-
class AudioStockDataset(Dataset):
|
| 593 |
-
def __init__(self,
|
| 594 |
-
metadata_path:str,
|
| 595 |
-
duration:float=10,
|
| 596 |
-
sr:int = 0,
|
| 597 |
-
plain_rate = 0,
|
| 598 |
-
return_path = False,
|
| 599 |
-
return_audio = True,
|
| 600 |
-
prompt_template_path: os.PathLike = None,
|
| 601 |
-
tag_types = [],
|
| 602 |
-
lang = 'en',
|
| 603 |
-
translate:Optional[Dict[str, os.PathLike]] = None,
|
| 604 |
-
use_literal_none = True,
|
| 605 |
-
):
|
| 606 |
-
self.audio_reader = SafeAudioReader(duration, sr)
|
| 607 |
-
|
| 608 |
-
self._load_metadata(metadata_path)
|
| 609 |
-
self.sr = sr
|
| 610 |
-
self.duration = duration
|
| 611 |
-
self.plain_rate = plain_rate
|
| 612 |
-
self.return_path = return_path
|
| 613 |
-
self.return_audio = return_audio
|
| 614 |
-
self.use_literal_none = use_literal_none
|
| 615 |
-
|
| 616 |
-
self.use_dynamic_prompt = prompt_template_path is not None and plain_rate < 1.0
|
| 617 |
-
if self.use_dynamic_prompt:
|
| 618 |
-
self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang)
|
| 619 |
-
self.tag_types = tag_types
|
| 620 |
-
|
| 621 |
-
self.lang = lang
|
| 622 |
-
self.translate = read_translate(translate)
|
| 623 |
-
|
| 624 |
-
def _load_metadata(self, metadata_path):
|
| 625 |
-
with open(metadata_path) as fp:
|
| 626 |
-
lines = fp.readlines()
|
| 627 |
-
self.data = []
|
| 628 |
-
for line in lines:
|
| 629 |
-
item = json.loads(line)
|
| 630 |
-
self.data.append(item)
|
| 631 |
-
self.is_info_recorded = bool('Tags' in self.data[0])
|
| 632 |
-
|
| 633 |
-
def __len__(self):
|
| 634 |
-
return len(self.data)
|
| 635 |
-
|
| 636 |
-
def __getitem__(self, idx):
|
| 637 |
-
path:str = self.data[idx]["path"]
|
| 638 |
-
json_path = path[:path.rfind('.')] + ".json"
|
| 639 |
-
if self.is_info_recorded:
|
| 640 |
-
item = self.data[idx]
|
| 641 |
-
else:
|
| 642 |
-
try:
|
| 643 |
-
with open(json_path) as fp:
|
| 644 |
-
item:dict = json.load(fp)
|
| 645 |
-
except Exception as e:
|
| 646 |
-
print(f"Error loading json file {json_path} :\n{e}")
|
| 647 |
-
item = {}
|
| 648 |
-
description = self.generate_description(item)
|
| 649 |
-
if self.return_audio:
|
| 650 |
-
sr, duration = get_sr_and_duration_info(item)
|
| 651 |
-
audio = self.audio_reader(path, sr, duration)
|
| 652 |
-
else:
|
| 653 |
-
audio = None
|
| 654 |
-
if self.return_path:
|
| 655 |
-
return audio, description, path
|
| 656 |
-
return audio, description
|
| 657 |
-
|
| 658 |
-
def generate_description(self, item):
|
| 659 |
-
if random.random() > self.plain_rate:
|
| 660 |
-
# dynamically generate prompt from given prompt template
|
| 661 |
-
prompt_template = random.choice(self.prompt_templates)
|
| 662 |
-
description = self.generate_description_dynamic(item, prompt_template)
|
| 663 |
-
else:
|
| 664 |
-
# use plain prompt, i.e. tags sequence separated by comma
|
| 665 |
-
description = self.generate_description_plain(item)
|
| 666 |
-
return description
|
| 667 |
-
|
| 668 |
-
def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
|
| 669 |
-
exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
|
| 670 |
-
|
| 671 |
-
if len(exists_tag) > 0:
|
| 672 |
-
probs = dist_prob_map[len(exists_tag)]
|
| 673 |
-
tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0]
|
| 674 |
-
random.shuffle(exists_tag)
|
| 675 |
-
tags = exists_tag[:tags_num]
|
| 676 |
-
tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
|
| 677 |
-
tags_args = self.handle_BPM_tag(tags_args)
|
| 678 |
-
prompt = prompt_template.apply(**tags_args)
|
| 679 |
-
else:
|
| 680 |
-
return 'none'
|
| 681 |
-
|
| 682 |
-
if self.use_literal_none and len(tags_args) == 0:
|
| 683 |
-
return 'none'
|
| 684 |
-
|
| 685 |
-
return prompt
|
| 686 |
-
|
| 687 |
-
def get_translation(self, tag_t, k):
|
| 688 |
-
k = k.strip()
|
| 689 |
-
if k in self.translate[tag_t]:
|
| 690 |
-
return self.translate[tag_t][k]
|
| 691 |
-
else:
|
| 692 |
-
return k
|
| 693 |
-
|
| 694 |
-
def generate_description_plain(self, item):
|
| 695 |
-
keywords = []
|
| 696 |
-
for tag_t in self.tag_types:
|
| 697 |
-
if tag_t == 'BPMDescript':
|
| 698 |
-
bpm = item['BPM']
|
| 699 |
-
if bpm is None or bpm.strip() == '' or bpm.strip() == '0':
|
| 700 |
-
continue
|
| 701 |
-
this_key = gen_bpm_descript(bpm.strip(), lang=self.lang)
|
| 702 |
-
elif tag_t == 'BPM':
|
| 703 |
-
bpm = item['BPM']
|
| 704 |
-
if bpm is None or bpm.strip() == '' or bpm.strip() == '0':
|
| 705 |
-
continue
|
| 706 |
-
this_key = f"{bpm.strip()} bpm"
|
| 707 |
-
else:
|
| 708 |
-
this_key = item[tag_t]
|
| 709 |
-
if this_key is None:
|
| 710 |
-
continue
|
| 711 |
-
if isinstance(this_key, str):
|
| 712 |
-
this_key = [this_key]
|
| 713 |
-
if self.lang != 'en':
|
| 714 |
-
this_key = [self.get_translation(tag_t, k) for k in this_key]
|
| 715 |
-
if this_key is None:
|
| 716 |
-
continue
|
| 717 |
-
if isinstance(this_key, str):
|
| 718 |
-
this_key = [this_key]
|
| 719 |
-
keywords += this_key
|
| 720 |
-
return gen_plain_prompt(keywords, sep=self.keysep)
|
| 721 |
-
|
| 722 |
-
@property
|
| 723 |
-
def keysep(self):
|
| 724 |
-
if self.lang == 'zh':
|
| 725 |
-
return ',' if random.random() > 0.5 else '、'
|
| 726 |
-
elif self.lang == 'en':
|
| 727 |
-
return ', '
|
| 728 |
-
|
| 729 |
-
def tags_to_desc(self, tag_list, tag_type) -> str:
|
| 730 |
-
if self.lang == 'en':
|
| 731 |
-
return tags_to_desc(tag_list)
|
| 732 |
-
elif self.lang == 'zh':
|
| 733 |
-
if tag_type == 'BPM':
|
| 734 |
-
return tags_to_desc(tag_list, sep='、')
|
| 735 |
-
translator = self.translate[tag_type]
|
| 736 |
-
translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
|
| 737 |
-
return tags_to_desc(translated_tag_list, sep='、')
|
| 738 |
-
|
| 739 |
-
def handle_BPM_tag(self, tags_args):
|
| 740 |
-
if "BPM" in tags_args and 'BPMDescript' in self.tag_types:
|
| 741 |
-
bpm = tags_args["BPM"]
|
| 742 |
-
del tags_args["BPM"]
|
| 743 |
-
tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript')))
|
| 744 |
-
for tag_type in tag_types_used:
|
| 745 |
-
tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang)
|
| 746 |
-
return tags_args
|
| 747 |
-
|
| 748 |
-
def mp3_path_to_id(mp3_path):
|
| 749 |
-
return int(
|
| 750 |
-
mp3_path[mp3_path.rindex('/') + 1 : mp3_path.rindex('.')]
|
| 751 |
-
)
|
| 752 |
-
|
| 753 |
-
class TmeDataset(Dataset):
|
| 754 |
-
def __init__(self,
|
| 755 |
-
data_index:str,
|
| 756 |
-
music_info:str = None,
|
| 757 |
-
duration:float = 10,
|
| 758 |
-
sr:int = 0,
|
| 759 |
-
plain_rate = 0,
|
| 760 |
-
return_path = False,
|
| 761 |
-
return_audio = True,
|
| 762 |
-
return_ID = False,
|
| 763 |
-
prompt_format_path: os.PathLike = None,
|
| 764 |
-
tag_types = ['*'],
|
| 765 |
-
lang = 'zh',
|
| 766 |
-
translate: Optional[os.PathLike] = None,
|
| 767 |
-
prompt_dir: os.PathLike = None, #使用GPT生成的预有的prompt
|
| 768 |
-
):
|
| 769 |
-
if plain_rate > 0:
|
| 770 |
-
print("Tme Dataset do not support plain rate > 0, use plain_rate = 0 instead.")
|
| 771 |
-
plain_rate = 0
|
| 772 |
-
self.audio_reader = SafeAudioReader(duration, sr)
|
| 773 |
-
|
| 774 |
-
self.sr = sr
|
| 775 |
-
self.duration = duration
|
| 776 |
-
self.plain_rate = plain_rate
|
| 777 |
-
self.return_path = return_path
|
| 778 |
-
self.return_audio = return_audio
|
| 779 |
-
self.return_ID = return_ID
|
| 780 |
-
self.lang = lang
|
| 781 |
-
|
| 782 |
-
self.use_ready_prompt = prompt_dir is not None
|
| 783 |
-
|
| 784 |
-
data_index = read_jsonlike(data_index)
|
| 785 |
-
self.data_index_dict = {mp3_path_to_id(d['path']) : d for d in data_index}
|
| 786 |
-
self.data_ids = list(self.data_index_dict.keys())
|
| 787 |
-
|
| 788 |
-
if not self.use_ready_prompt:
|
| 789 |
-
#读取音乐的信息文件
|
| 790 |
-
music_info = read_jsonlike(music_info)
|
| 791 |
-
if 'music' in music_info:
|
| 792 |
-
music_info = music_info['music']
|
| 793 |
-
self.music_info_dict = {d["歌曲ID"]:d for d in music_info}
|
| 794 |
-
self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.music_info_dict}
|
| 795 |
-
self.data_ids = list(self.data_index_dict.keys())
|
| 796 |
-
|
| 797 |
-
with open(prompt_format_path) as fp:
|
| 798 |
-
self.prompt_formats = yaml.load(fp, Loader=yaml.FullLoader)
|
| 799 |
-
|
| 800 |
-
#加载tag types,并分成一般的tag_types和关键的key_tag_types
|
| 801 |
-
if '*' in tag_types:
|
| 802 |
-
self.tag_types = ['歌曲名', 'bpm', '专辑名', '歌手名', '作曲', 'tag']
|
| 803 |
-
else:
|
| 804 |
-
self.tag_types = tag_types
|
| 805 |
-
|
| 806 |
-
self.key_tag_types = []
|
| 807 |
-
if 'tag' in self.tag_types:
|
| 808 |
-
self.tag_types.remove('tag')
|
| 809 |
-
self.key_tag_types = list(self.prompt_formats['tag'].keys())
|
| 810 |
-
|
| 811 |
-
#加载translate翻译
|
| 812 |
-
if translate is not None:
|
| 813 |
-
self.translator = read_jsonlike(translate)
|
| 814 |
-
else:
|
| 815 |
-
data_ids_set = set(self.data_ids)
|
| 816 |
-
self.prompts_dict = {}
|
| 817 |
-
for fname in os.listdir(prompt_dir):
|
| 818 |
-
items = read_jsonlike(os.path.join(prompt_dir, fname))
|
| 819 |
-
for item in items:
|
| 820 |
-
if item['ID'] not in data_ids_set or not self.is_valid_prompt_text(item['Text']):
|
| 821 |
-
continue
|
| 822 |
-
if item['ID'] not in self.prompts_dict:
|
| 823 |
-
self.prompts_dict[item['ID']] = []
|
| 824 |
-
self.prompts_dict[item['ID']].append(item['Text'])
|
| 825 |
-
self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.prompts_dict}
|
| 826 |
-
self.data_ids = list(self.data_index_dict.keys())
|
| 827 |
-
|
| 828 |
-
def tags_to_desc(self, tag_list) -> str:
|
| 829 |
-
if is_bearable(tag_list, int):
|
| 830 |
-
return str(tag_list)
|
| 831 |
-
if self.lang == 'zh':
|
| 832 |
-
return tags_to_desc(tag_list, sep=self.sep)
|
| 833 |
-
else:
|
| 834 |
-
translated_tag_list = [self.translator[tag] for tag in tag_list if tag in self.translator ]
|
| 835 |
-
return tags_to_desc(translated_tag_list, sep=self.sep)
|
| 836 |
-
|
| 837 |
-
def gen_desc_of_tag(self, formats, tags):
|
| 838 |
-
fmt = random.choice(formats)
|
| 839 |
-
return fmt.format(self.tags_to_desc(tags))
|
| 840 |
-
|
| 841 |
-
@staticmethod
|
| 842 |
-
def check_valid(value):
|
| 843 |
-
if isinstance(value, int) or isinstance(value, float):
|
| 844 |
-
return value > 0
|
| 845 |
-
if (value is not None) and (not isinstance(value, Sequence) or len(value) > 0):
|
| 846 |
-
return True
|
| 847 |
-
return False
|
| 848 |
-
|
| 849 |
-
@staticmethod
|
| 850 |
-
def remove_repeat(data):
|
| 851 |
-
#若专辑名和歌曲名相同,则只使用后者
|
| 852 |
-
album_name = data.get('专辑名', None)
|
| 853 |
-
if album_name is not None and album_name == data.get('歌曲名', None):
|
| 854 |
-
del data['专辑名']
|
| 855 |
-
return data
|
| 856 |
-
|
| 857 |
-
@property
|
| 858 |
-
def comma(self):
|
| 859 |
-
if self.lang == 'zh':
|
| 860 |
-
return ','
|
| 861 |
-
elif self.lang == 'en':
|
| 862 |
-
return ', '
|
| 863 |
-
|
| 864 |
-
@property
|
| 865 |
-
def sep(self):
|
| 866 |
-
if self.lang == 'zh':
|
| 867 |
-
return '、'
|
| 868 |
-
elif self.lang == 'en':
|
| 869 |
-
return ', '
|
| 870 |
-
|
| 871 |
-
|
| 872 |
-
def generate_description(self, item):
|
| 873 |
-
if random.random() > self.plain_rate:
|
| 874 |
-
# dynamically generate prompt from given prompt template
|
| 875 |
-
description = self.generate_description_dynamic(item)
|
| 876 |
-
else:
|
| 877 |
-
# use plain prompt, i.e. tags sequence separated by comma
|
| 878 |
-
description = self.generate_description_plain(item)
|
| 879 |
-
return description
|
| 880 |
-
|
| 881 |
-
def generate_description_dynamic(self, data):
|
| 882 |
-
data = self.remove_repeat(data)
|
| 883 |
-
|
| 884 |
-
weak_tags = [key for key in data if (key in self.tag_types and self.check_valid(data[key]))] #弱语义的tag,这些tag的出现比例会放低
|
| 885 |
-
|
| 886 |
-
key_tags = [key for key in data['tag'] if (key in self.key_tag_types and self.check_valid(data['tag'][key]))] #关键的tag,这些tag必须出现至少一个
|
| 887 |
-
|
| 888 |
-
prompts = []
|
| 889 |
-
if len(weak_tags) > 0:
|
| 890 |
-
probs = dist_prob_map_low[len(weak_tags)]
|
| 891 |
-
if len(key_tags) > 0:
|
| 892 |
-
tags_num = random.choices(range(0, len(weak_tags)), probs)[0]
|
| 893 |
-
else:
|
| 894 |
-
tags_num = random.choices(range(1, len(weak_tags) + 1), probs)[0]
|
| 895 |
-
random.shuffle(weak_tags)
|
| 896 |
-
tags = weak_tags[:tags_num]
|
| 897 |
-
for tag_type in tags:
|
| 898 |
-
tag_desc = self.gen_desc_of_tag(self.prompt_formats[tag_type], int(data[tag_type]) if tag_type == 'bpm' else data[tag_type])
|
| 899 |
-
prompts.append(tag_desc)
|
| 900 |
-
|
| 901 |
-
if len(key_tags) > 0:
|
| 902 |
-
probs = dist_prob_map[len(key_tags)]
|
| 903 |
-
tags_num = random.choices(range(1, len(key_tags) + 1), probs)[0]
|
| 904 |
-
random.shuffle(key_tags)
|
| 905 |
-
tags = key_tags[:tags_num]
|
| 906 |
-
for tag_type in tags:
|
| 907 |
-
tag_desc = self.gen_desc_of_tag(self.prompt_formats['tag'][tag_type], data['tag'][tag_type])
|
| 908 |
-
prompts.append(tag_desc)
|
| 909 |
-
|
| 910 |
-
random.shuffle(prompts)
|
| 911 |
-
return self.comma.join(prompts)
|
| 912 |
-
|
| 913 |
-
def generate_description_plain(self, item):
|
| 914 |
-
keywords = item['tag']
|
| 915 |
-
if self.lang != 'en':
|
| 916 |
-
keywords = [self.translator[k.strip()] for k in keywords]
|
| 917 |
-
return gen_plain_prompt(keywords, sep=self.keysep)
|
| 918 |
-
|
| 919 |
-
@property
|
| 920 |
-
def keysep(self):
|
| 921 |
-
if self.lang == 'zh':
|
| 922 |
-
return ',' if random.random() > 0.5 else '、'
|
| 923 |
-
elif self.lang == 'en':
|
| 924 |
-
return ', '
|
| 925 |
-
|
| 926 |
-
def is_valid_prompt_text(self, text):
|
| 927 |
-
for bad in ('抱歉','sorry', 'Sorry'):
|
| 928 |
-
if bad in text:
|
| 929 |
-
return False
|
| 930 |
-
return True
|
| 931 |
-
|
| 932 |
-
def get_ready_prompt(self, path):
|
| 933 |
-
sid = mp3_path_to_id(path)
|
| 934 |
-
return random.choice(self.prompts_dict[sid])
|
| 935 |
-
|
| 936 |
-
def __len__(self):
|
| 937 |
-
return len(self.data_ids)
|
| 938 |
-
|
| 939 |
-
def __getitem__(self, idx):
|
| 940 |
-
data_id = self.data_ids[idx]
|
| 941 |
-
item = self.data_index_dict[data_id]
|
| 942 |
-
path = item['path']
|
| 943 |
-
if not self.use_ready_prompt:
|
| 944 |
-
info = self.music_info_dict[data_id]
|
| 945 |
-
description = self.generate_description(info)
|
| 946 |
-
else:
|
| 947 |
-
description = self.get_ready_prompt(path)
|
| 948 |
-
if self.return_audio:
|
| 949 |
-
sr, duration = get_sr_and_duration_info(item)
|
| 950 |
-
audio = self.audio_reader(path, sr, duration)
|
| 951 |
-
else:
|
| 952 |
-
audio = None
|
| 953 |
-
if self.return_path:
|
| 954 |
-
if self.return_ID:
|
| 955 |
-
return audio, description, path, info['歌曲ID']
|
| 956 |
-
return audio, description, path
|
| 957 |
-
if self.return_ID:
|
| 958 |
-
            return audio, description, info['歌曲ID']
        return audio, description


class Pond5Dataset(Dataset):
    MAX_PROMPT_LEN = 200
    def __init__(self,
                metadata_path:str,
                index_path:str,
                duration:float=10,
                sr:int = 0,
                plain_rate = 0,
                return_path = False,
                return_audio = True,
                lang = 'en',
                translate:Optional[Dict[str, os.PathLike]] = None,
                use_literal_none = True,
                use_avoid_watermark_policy = None,
                ):

        if use_avoid_watermark_policy is None:
            raise ValueError("`use_avoid_watermark_policy` is an important param, you need to explicitly specify it with bool type")
        self.use_avoid_watermark_policy = use_avoid_watermark_policy
        self.audio_reader = SafeAudioReader(duration, sr, use_avoid_watermark_policy=use_avoid_watermark_policy)

        self._load_metadata(metadata_path, index_path)
        self.sr = sr
        self.duration = duration
        self.plain_rate = plain_rate
        self.return_path = return_path
        self.return_audio = return_audio
        self.use_literal_none = use_literal_none

        self.lang = lang
        self.translate = read_translate(translate)

    def _load_metadata(self, metadata_path, index_path):
        data_index = read_jsonlike(index_path)
        data_ids = set([item['id'] for item in data_index])

        with open(metadata_path) as fp:
            lines = fp.readlines()

        append_ids = set()

        self.data = []
        for line in lines:
            item = json.loads(line)
            if item['id'] in data_ids and item['id'] not in append_ids:
                self.data.append(item)
                append_ids.add(item['id'])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        path:str = item["path"]
        description = self.generate_description(item)
        if self.return_audio:
            sr, duration = get_sr_and_duration_info(item)
            audio = self.audio_reader(path, sr, duration)
        else:
            audio = None
        if self.return_path:
            return audio, description, path
        return audio, description

    @property
    def keysep(self):
        if self.lang == 'zh':
            return ',' if random.random() > 0.5 else '、'
        elif self.lang == 'en':
            return ', '

    def generate_description(self, item):
        if random.random() > self.plain_rate:
            # dynamically generate prompt from the given prompt template
            description = self.generate_description_dynamic(item)
        else:
            # use a plain prompt, i.e. a tag sequence separated by commas
            description = self.generate_description_plain(item)
        return description

    def get_translation(self, k):
        k = k.strip()
        if k in self.translate:
            return self.translate[k]
        else:
            return k

    def generate_description_plain(self, item):
        keywords = item['keywords']
        if self.lang != 'en':
            keywords = [self.get_translation(k) for k in keywords]
        return gen_plain_prompt(keywords, sep=self.keysep)

    def generate_description_dynamic(self, item):
        desc = item.get('desc', 'none')
        if desc is None:
            desc = 'none'
        desc = desc.strip()
        if len(desc) > self.MAX_PROMPT_LEN:
            shorter_desc = desc[:self.MAX_PROMPT_LEN]
            # truncate at the last sentence stop
            stop_idx = shorter_desc.rfind('.')
            if stop_idx == -1:
                stop_idx = shorter_desc.rfind('!')
            if stop_idx == -1:
                stop_idx = shorter_desc.rfind(',')
            if stop_idx == -1:
                stop_idx = self.MAX_PROMPT_LEN - 1
            desc = desc[:stop_idx+1]
        return desc

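A minimal standalone sketch of the truncation rule in generate_description_dynamic above; truncate_at_last_stop is a hypothetical helper name, not part of the original module:

def truncate_at_last_stop(desc: str, max_len: int = 200) -> str:
    if len(desc) <= max_len:
        return desc
    shorter = desc[:max_len]
    # prefer cutting at '.', then '!', then ','; otherwise hard-cut at max_len
    for stop in ('.', '!', ','):
        idx = shorter.rfind(stop)
        if idx != -1:
            return desc[:idx + 1]
    return desc[:max_len]

print(truncate_at_last_stop("A calm piano piece. Gentle strings swell later. " * 10))
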
class SoundDataset(Dataset):
    def __init__(self,
                metadata_index: str,
                duration:float = 10,
                min_non_silent_duration:float = 3,
                sr:int = 0,
                return_path = False,
                return_audio = True,
                ):
        self.data = read_jsonlike(metadata_index)
        self.sr = sr
        self.reader = SafeAudioReader(duration, sr)
        self.duration = duration
        self.min_non_silent_duration = min_non_silent_duration
        self.return_audio = return_audio
        self.return_path = return_path

    def __getitem__(self, index):
        item = self.data[index]
        if self.return_audio:
            origin_duration = item['duration']
            if origin_duration < self.min_non_silent_duration:
                audio = self.read_and_repeat_and_pad(item)
            else:
                audio = self.reader(item['path'], item['sample_rate'], origin_duration)
        else:
            audio = None
        desc = item['caption']
        if self.return_path:
            return audio, desc, item['path']
        else:
            return audio, desc

    def __len__(self):
        return len(self.data)

    def read_and_repeat_and_pad(self, item):
        path = item['path']
        try:
            # read
            clip, sr = torchaudio.load(path)
            if len(clip.shape) > 1:
                clip = torch.mean(clip, dim=0, keepdim=True)
            clip = resample(clip, sr, self.sr)
            # repeat
            n_repeats = math.ceil(self.min_non_silent_duration/item['duration'])
            clip = torch.repeat_interleave(clip, n_repeats, dim=0).reshape(-1)
            # pad
            n_samples = int(self.duration * self.sr)
            if clip.shape[0] >= n_samples:
                audio = clip[:n_samples]
            else:
                audio = torch.zeros(int(self.duration * self.sr), dtype=clip.dtype)
                start_pos = np.random.randint(0, max(0,(n_samples - clip.shape[0])))
                audio[start_pos:start_pos+clip.shape[0]] = clip
            return audio

        except Exception as e:
            logger.error(f"Error reading {path}: {e}")
            wav = torch.zeros(int(self.duration * self.sr), dtype=torch.float32)
            return wav

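A self-contained sketch of the repeat-and-pad step above on a synthetic clip; the durations and sample rate are illustrative, not taken from the module:

import math
import torch

sr = 48000
duration, min_non_silent = 10.0, 3.0
clip = torch.randn(1, int(1.2 * sr))   # a 1.2 s mono clip, shaped (1, T) like the loader output
# repeat the short clip until it covers the minimum non-silent span
n_repeats = math.ceil(min_non_silent / 1.2)
clip = torch.repeat_interleave(clip, n_repeats, dim=0).reshape(-1)
# drop the repeated clip at a random offset inside a zero-padded window
n_samples = int(duration * sr)
audio = torch.zeros(n_samples)
start = torch.randint(0, n_samples - clip.shape[0], (1,)).item()
audio[start:start + clip.shape[0]] = clip
print(audio.shape)  # torch.Size([480000])
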
class CombinedDataset(Dataset):
    @beartype
    def __init__(self, datasets: Sequence[Dataset], ratios: Sequence[int]):
        self.datasets = datasets
        self.datasets_index = []

        for i,dataset in enumerate(datasets):
            if dataset is None:
                continue
            for dup in range(ratios[i]):
                for j in range(len(dataset)):
                    self.datasets_index.append((i,j))

    def __len__(self):
        return len(self.datasets_index)

    def __getitem__(self, idx):
        index = self.datasets_index[idx]
        i,j = index
        return self.datasets[i][j]

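A usage sketch of CombinedDataset, assuming the class above is importable; ToyDataset is a hypothetical stand-in, and the ratios duplicate each dataset's indices, so a 2:1 ratio makes the first dataset twice as frequent per epoch:

from torch.utils.data import Dataset

class ToyDataset(Dataset):
    """Tiny stand-in dataset used only for this illustration."""
    def __init__(self, name, n):
        self.items = [f"{name}-{i}" for i in range(n)]
    def __len__(self):
        return len(self.items)
    def __getitem__(self, idx):
        return self.items[idx]

combined = CombinedDataset([ToyDataset("music", 3), ToyDataset("sound", 2)], ratios=[2, 1])
print(len(combined))   # 2*3 + 1*2 = 8 entries
print(combined[0])     # 'music-0'
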
class CombinedDataset_random(Dataset):
    @beartype
    def __init__(self, num_examples:int, datasets: Sequence[Dataset], ratios: Sequence[int]):
        self.datasets = datasets
        self.datasets_index = []

        for i,dataset in enumerate(datasets):
            if dataset is None:
                continue
            for dup in range(ratios[i]):
                for j in range(len(dataset)):
                    self.datasets_index.append((i,j))

        if num_examples > 0:
            self.random_choose = True
            self.dataset_len = num_examples
        else:
            self.random_choose = False
            self.dataset_len = len(self.datasets_index)

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, idx):
        first_try = True
        try_cnt = 0
        while True:
            try:
                if(self.random_choose or not first_try):
                    index2 = []
                    index2.append(np.random.randint(0,len(self.datasets)))
                    index2.append(np.random.randint(0,len(self.datasets[index2[-1]])))
                else:
                    index2 = self.datasets_index[idx]
                first_try = False
                out = list(self.datasets[index2[0]][index2[1]])
                return out
            except:
                print("Error loading ", index2)
                try_cnt += 1
                if(try_cnt>10):
                    raise ValueError()

class SoundMixedDataset(Dataset):
    @staticmethod
    def music_desc(desc):
        return f'Music:<{desc}>'
    @staticmethod
    def sound_desc(desc):
        return f'Effect:<{desc}>'

    def __init__(self,
                music_dataset: Dataset,
                sound_dataset: Dataset,
                mixed_ratios: Tuple[float, float, float] = (0.3, 0.3, 0.4)  # ratio of music-only : sound-only : music+sound mixtures
                ) -> None:
        self.music_dataset = music_dataset
        self.sound_dataset = sound_dataset
        music_r, sound_r, mix_r = [r/sum(mixed_ratios) for r in mixed_ratios]  # normalize to proportions in [0, 1]
        # left endpoints of the three probability intervals
        self.music_anchor = 0
        self.sound_anchor = music_r
        self.mix_anchor = music_r + sound_r

    def __len__(self):
        return len(self.music_dataset)

    def get_random_sound_data(self):
        idx = random.randint(0, len(self.sound_dataset)-1)
        return self.sound_dataset[idx]

    def __getitem__(self, idx):
        p = random.random()
        if p >= self.mix_anchor:
            music, m_desc = self.music_dataset[idx]
            sound, s_desc = self.get_random_sound_data()
            audio = music + sound
            if(audio.abs().max()>1.0):
                music = music / audio.abs().max() * 0.95
                audio = audio / audio.abs().max() * 0.95
            desc = self.music_desc(m_desc) + self.sound_desc(s_desc)
            return audio[None,:], music[None,:], desc
        elif p >= self.sound_anchor:
            audio, desc = self.get_random_sound_data()
            return audio[None,:], torch.zeros_like(audio[None,:]), self.sound_desc(desc)
        else:
            audio, desc = self.music_dataset[idx]
            return audio[None,:], audio[None,:], self.music_desc(desc)


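A quick sketch of how the three anchors partition the unit interval, using the default (0.3, 0.3, 0.4) ratios; the counts are a sanity check, not part of the original module:

import random

ratios = (0.3, 0.3, 0.4)                       # music-only : sound-only : mixed
music_r, sound_r, mix_r = [r / sum(ratios) for r in ratios]
music_anchor, sound_anchor, mix_anchor = 0.0, music_r, music_r + sound_r

counts = {"music": 0, "sound": 0, "mix": 0}
for _ in range(10000):
    p = random.random()
    if p >= mix_anchor:        # p in [0.6, 1.0) -> mixed
        counts["mix"] += 1
    elif p >= sound_anchor:    # p in [0.3, 0.6) -> sound only
        counts["sound"] += 1
    else:                      # p in [0.0, 0.3) -> music only
        counts["music"] += 1
print(counts)  # roughly 3000 / 3000 / 4000
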
class DecoTagDataset(Dataset):
    '''Wraps an ordinary dataset into one suitable for tag-decoupled learning.'''

    TAG_TYPES = ('genre', 'mood', 'instrument')

    def __init__(self, dataset_class: type, tag_map: Dict[str, str], *args, **kwargs):
        self.datasets = []
        for i, tag_t in enumerate(self.TAG_TYPES):
            kwargs['tag_types'] = [tag_map[tag_t]]
            kwargs['return_audio'] = (i == 0)  # only the 0th dataset returns audio and text; the others return text only
            self.datasets.append(dataset_class(*args, **kwargs))

    def __len__(self):
        return len(self.datasets[0])

    def __getitem__(self, idx):
        audio, text = self.datasets[0][idx]
        texts = (text, self.datasets[1][idx][1], self.datasets[2][idx][1])
        return audio, texts


class DecoTagWrapper:
    '''A wrapper that makes it easy to switch tag-decoupled learning on or off.'''
    def __init__(self, dataset_class: Dataset, deco_tag_types: List[str] = list(), switch_on: bool = False):
        self.dataset_class = dataset_class
        self.tag_map = dict(zip(DecoTagDataset.TAG_TYPES, deco_tag_types))
        self.switch_on = switch_on

    def __call__(self, *args, **kwargs):
        if self.switch_on:
            return DecoTagDataset(self.dataset_class, self.tag_map, *args, **kwargs)
        else:
            return self.dataset_class(*args, **kwargs)
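A hypothetical usage sketch of the wrapper above, assuming an AudioStockDataset-style constructor that accepts tag_types and return_audio; 'meta.jsonl' and the tag names are placeholders:

factory = DecoTagWrapper(
    AudioStockDataset,
    deco_tag_types=['Genre', 'Mood', 'Instrument'],
    switch_on=True,
)
# With switch_on=True this builds a DecoTagDataset that returns
# (audio, (genre_text, mood_text, instrument_text)); with switch_on=False
# it builds a plain AudioStockDataset returning (audio, text).
dataset = factory(metadata_path='meta.jsonl', duration=10, sr=48000)
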
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_429.py DELETED
@@ -1,372 +0,0 @@
import re
import sys
import json
from typing import List, Union

from torch.utils.data import Dataset
import torchaudio
from torchaudio.functional import resample
import torch
import numpy as np

from torch.nn.utils.rnn import pad_sequence

PARAGRAPH_GAP = 6
MIN_MUSIC_LEN = 3

def check_lryics(lyric):
    _FILTER_STRING = [
        '作词', '作曲', '编曲', '【', '策划',
        '录音', '混音', '母带', ':', '制作',
        '版权', '校对', '演奏', '制作', '伴奏'
    ]
    for item in _FILTER_STRING:
        if item in lyric:
            return True

    return False



def process_lyrics(lines):
    lyric_part = []
    timestamp_part = []

    timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]')

    for i, line in enumerate(lines):

        # drop credit lines (lyricist, arranger, etc.) among the first few lines
        if i<10 and check_lryics(line):
            continue

        # check the line carries a valid timestamp plus lyric content
        if timestamp_pattern.match(line):
            timestamp_end = line.rfind(']')
            lyrics = line[timestamp_end + 1:].strip()
            timestamps = line[:timestamp_end + 1]

            if ':' in lyrics:
                if len(lyrics.split(":")[0]) <=5:
                    lyrics = "".join(lyrics.split(":")[1:])
            # if lyrics:  # make sure the lyric part is not empty
            #     lyric_part.append(lyrics)
            #     timestamp_part.append(timestamps)
    # print(processed_lyrics)
    return timestamp_part, lyric_part

def get_timestamps(timestamp_part):

    # convert to seconds

    timestamps = []

    for line in timestamp_part:
        match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line)
        if match:
            minutes = int(match.group(1))
            seconds = float(match.group(2))
            millis = float(match.group(3)) if match.group(3) else 0
            total_seconds = minutes * 60 + seconds + millis
            timestamps.append(total_seconds)


    return timestamps

def process_lyrics_lrc(lyrics):
    timestamp_part, lyric_part = process_lyrics(lyrics)
    # print(timestamp_part)
    # print(lyric_part)
    timestamps = get_timestamps(timestamp_part)
    # print(timestamps)
    if len(timestamps) == 0:
        # print(f'{lyric_path}')
        return []

    slice_start = timestamps[0]
    slice_start_idx = 0

    output_list = []
    for i in range(1, len(timestamps)):
        # split once the accumulated time exceeds 30 s; if the whole song is under 30 s, the entire segment is dropped
        if timestamps[i] - slice_start > 30:
            output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))

            slice_start = timestamps[i]
            slice_start_idx = i

    return output_list



def process_lyrics_yrc(lyrics):

    timestamps, lyric_part = extract_lrc(lyrics)

    # timestamp_part, lyric_part = process_lyrics(lyrics)
    # import pdb; pdb.set_trace()
    # print(timestamp_part)
    # print(lyric_part)
    # timestamps = get_timestamps(timestamp_part)
    # print(timestamps)
    if len(timestamps) == 0:
        # print(f'{lyric_path}')
        return []

    slice_start = timestamps[0]
    slice_start_idx = 0

    output_list = []
    for i in range(1, len(timestamps)):
        # split once the accumulated time exceeds 30 s
        if timestamps[i] - slice_start > 30:
            output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))

            slice_start = timestamps[i]
            slice_start_idx = i
    # import pdb; pdb.set_trace()
    return output_list

def extract_lrc(lyrics):
    timestamp_part, lyric_part = [], []

    for i, text in enumerate(lyrics):
        # extract the content inside square brackets
        bracket_content = re.search(r'\[(.*?)\]', text).group(1)
        bracket_content = bracket_content.split(',')
        # extract the content inside parentheses
        parentheses_content = re.findall(r'\((.*?)\)', text)
        # extract the remaining content
        other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip()

        # how should this data be handled?
        if i<10 and check_lryics(other_content):
            continue
        timestamp_part.append(float(bracket_content[0])/1000)
        lyric_part.append(other_content)
    return timestamp_part, lyric_part


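A standalone check of the timestamp conversion performed by get_timestamps above; the sample LRC line is invented, and lrc_time_to_seconds is a hypothetical mirror of the parsing logic:

import re

def lrc_time_to_seconds(line: str) -> float:
    # mirrors get_timestamps: '[mm:ss.xx]' -> seconds
    m = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line)
    minutes, seconds = int(m.group(1)), float(m.group(2))
    millis = float(m.group(3)) if m.group(3) else 0
    return minutes * 60 + seconds + millis

print(lrc_time_to_seconds('[01:23.5]'))   # 83.5
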
class WYYSongDataset(Dataset):
    def __init__(self,
                metadata_path: Union[str, List[str]],
                sr:int = 0,
                use_lang = ['en', 'zh-cn'],
                num_examples = -1,
                max_dur = 20,
                min_dur=0,
                add_music=False,
                pad_to_max= True,
                ):

        self.sr = sr
        self.use_lang = use_lang
        self.data = []
        if type(metadata_path) == str:
            metadata_path = [metadata_path]
        for _meta in metadata_path:
            self._load_metadata(_meta)
        self.max_dur = max_dur
        self.min_dur = min_dur
        self.pad_to_max = pad_to_max
        self.add_music = add_music

        # buffer
        self.lyric_buffer = {}

        if(num_examples<=0):
            self.dataset_len = len(self.data)
            self.random_slc = False
        else:
            self.dataset_len = num_examples
            self.random_slc = True


    # read the jsonl metadata file
    def _load_metadata(self, metadata_path):
        with open(metadata_path) as fp:
            lines = fp.readlines()
            for line in lines:
                item = json.loads(line)
                if '伴奏' not in item['path']:  # skip accompaniment-only tracks
                    # if "lang_type" in item and item['lang_type'] == 'en':
                    if "lang_type" in item:
                        self.data.append(item)


    def __len__(self):
        return self.dataset_len


    def __getitem__(self, idx):
        try_cnt = 0
        while True:
            if(self.random_slc):
                idx = np.random.randint(0, len(self.data))
            yrc_lyrics = []
            lrc_lyrics = []
            try:
                info = self.data[idx]

                # audio path
                path = info["path"]
                lang_type = info["lang_type"]
                lyrics = info['lyrics']  # chinese
                # lyrics = info['lyrics_phone']

                # randomly pick one lyric paragraph

                parsed_lyrics = []
                # st_idx = np.random.randint(0, len(lyrics))
                for ly_id in range(len(lyrics)):
                    lyric = lyrics[ly_id].strip()
                    st, et, lyric = self.parse_lyric(lyric)

                    if et - st >= self.max_dur:
                        continue  # TODO extend head/tail with [MUSIC]

                    if parsed_lyrics != []:
                        if st - parsed_lyrics[-1][1] >= PARAGRAPH_GAP:  # large gap
                            parsed_lyrics.append((parsed_lyrics[-1][1], st, '[GAP]'))
                        elif self.add_music and st - parsed_lyrics[-1][1] >= MIN_MUSIC_LEN:
                            parsed_lyrics.append((parsed_lyrics[-1][1], st, '[MUSIC]'))

                    lyric = lyric.replace("\xa0", " ")
                    lyric = " ".join(lyric.split())
                    parsed_lyrics.append((st, et, lyric))

                assert parsed_lyrics != []
                # if parsed_lyrics[-1][1] - parsed_lyrics[0][0] > self.max_dur:
                #     print(f"{parsed_lyrics[0][0]}-{parsed_lyrics[-1][1]} {parsed_lyrics}", file=open('tmp.txt', 'a'))

                parsed_lyrics = [(0, parsed_lyrics[0][0], '[GAP]')] + parsed_lyrics

                possible_starts = [e for e,i in enumerate(parsed_lyrics) if i[2]=='[GAP]']
                st_idx = np.random.choice(possible_starts)

                paraphrase = []
                for i in parsed_lyrics[st_idx+1:]:
                    if i[2] == '[GAP]':
                        break
                    paraphrase.append(i)
                # print(paraphrase, lyrics)

                while paraphrase[-1][1] - paraphrase[0][0] > self.max_dur:
                    if np.random.rand() > 0.2:
                        paraphrase.pop(-1)  # usually truncate from the end
                    else:
                        paraphrase.pop(0)   # occasionally truncate from the start

                st, et, lyric = paraphrase[0][0], paraphrase[-1][1], ', '.join([i[2] for i in paraphrase])  # [SEP]
                # print(st, et, lyric)
                # import pdb; pdb.set_trace()
                assert self.min_dur < et - st < self.max_dur, f"{st}-{et} {lyric}"
                # print(et-st, lyric)
                # import pdb; pdb.set_trace()

                if info["lang_type"] == 'en':
                    # print(len(lyric.split())/(et-st))
                    char_num = sum([len(lrc[-1].split()) for lrc in paraphrase])
                    assert 6 > char_num / (et-st) > 1
                else:
                    # print(len(lyric.split())/(et-st))
                    char_num = sum([len(lrc[-1]) for lrc in paraphrase])
                    assert 6 > char_num / (et-st) > 1

                # load the audio file
                cur_sample_rate = torchaudio.info(path).sample_rate
                offset = int(cur_sample_rate*st)
                num_frames = int(cur_sample_rate * (et -st))
                chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames)
                # chunk = torch.zeros(1, 48000*15)
                if abs(chunk.shape[-1] - num_frames) > num_frames * 0.05:  # audio length does not match the lyrics
                    print(f"fail to load {path} from {st} to {et} !")
                    raise FileNotFoundError
                # randomly pick one channel
                if(chunk.shape[0]>1):
                    chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
                else:
                    chunk = chunk[[0],:].float()

                if(cur_sample_rate!=self.sr):
                    # print('a:',cur_sample_rate,chunk.shape)
                    chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr)

                if self.pad_to_max:
                    chunk = self.pad_2d_tensor(chunk, int(self.max_dur * self.sr), 0)

                # print(self.sz_cnt)
                return chunk, lyric, [st, et], path, lang_type
            except (AssertionError, FileNotFoundError, RuntimeError) as e:  # other errors are not tolerated
                # print("Error loading ", info["path"])
                try_cnt += 1
                idx = np.random.randint(0, len(self.data))
                if(try_cnt>100):
                    raise e

    def parse_lyric(self, lyric):
        pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)'
        match = re.search(pattern, lyric)

        start_time = float(match.group(1))
        end_time = float(match.group(2))
        content = match.group(3)
        return start_time, end_time, content

    def pad_2d_tensor(self, x, max_len, pad_id):
        # shape of the input tensor
        batch_size, seq_len = x.size()
        max_len = max(max_len, seq_len)
        # length that needs padding
        pad_len = max_len - seq_len

        # pad if necessary
        if pad_len > 0:
            # build the padding tensor
            pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device)

            # concatenate input and padding along the second (column) dimension
            padded_tensor = torch.cat([x, pad_tensor], dim=1)
        else:
            # no padding needed; return the input tensor unchanged
            padded_tensor = x

        return padded_tensor

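A quick check of the right-padding that pad_2d_tensor performs, on a toy tensor:

import torch

x = torch.ones(1, 5)                        # (batch, seq_len)
pad = torch.full((1, 3), 0, dtype=x.dtype)  # pad_id = 0, pad_len = 8 - 5
padded = torch.cat([x, pad], dim=1)
print(padded)  # tensor([[1., 1., 1., 1., 1., 0., 0., 0.]])
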
def collect_data(data_list):
    audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1,2)
    lyrics = [data[1] for data in data_list]
    st_et = [data[2] for data in data_list]
    paths = [data[3] for data in data_list]
    lang_types = [data[4] for data in data_list]
    return audios, lyrics, st_et
    # return audios, lyrics, st_et


def build_dataset(train_jsonl_list, val_jsonl_list, min_dur=0, max_dur=20, add_music=False):
    print(min_dur,max_dur)
    print(train_jsonl_list)
    # ["exp/wyy3_20240418_v2f.jsonl",
    #  "exp/tme_lyric_baokuan.jsonl"]
    train_dataset = WYYSongDataset(
        metadata_path = train_jsonl_list,
        sr = 48000,
        use_lang = ['zh-cn', 'en'],
        num_examples = 10*10000,
        min_dur=min_dur,
        max_dur=max_dur,
        add_music=add_music
    )

    valid_dataset = WYYSongDataset(
        metadata_path = val_jsonl_list,
        sr = 48000,
        use_lang = ['zh-cn', 'en'],
        num_examples = 500,
        min_dur=min_dur,
        max_dur=max_dur,
        add_music=add_music
    )
    print(train_jsonl_list, "\t total_song = ", len(train_dataset.data))
    return train_dataset, valid_dataset
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined.py DELETED
@@ -1,830 +0,0 @@
from torch.utils.data import Dataset
from beartype.typing import Sequence, Callable, Optional, Dict, Tuple, List
from beartype import beartype
from beartype.door import is_bearable
import random
import pandas as pd
import os
from torchaudio.functional import resample
import torch
import typing as tp
from pathlib import Path
import torchaudio as ta
import torch.nn.functional as F
import numpy as np
import json
import yaml
import torchaudio
import math
import re
from loguru import logger

class Read_and_PadCrop_Normalized_T(torch.nn.Module):
    def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):

        super().__init__()

        self.n_samples = n_samples
        self.sample_rate = sample_rate
        self.randomize = randomize

    def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
        if(duration<(float(self.n_samples)/self.sample_rate+1)):
            # print(duration,(float(self.n_samples)/self.sample_rate+1))
            chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
            t_start = 0.
            t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
            offset = 0
            # print('c1:',chunk.shape)
        else:
            offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
            t_start = offset / float(cur_sample_rate) / duration
            t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
            chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
            # print('offset:',offset)
            # print('c0:',chunk.shape)
        # Pad with silence if necessary.
        if(chunk.shape[0]>1):
            chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
        else:
            chunk = chunk[[0],:].float()
        if(cur_sample_rate!=self.sample_rate):
            # print('a:',cur_sample_rate,chunk.shape)
            chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
            # print('b:',self.sample_rate,chunk.shape)
        if chunk.shape[-1] < self.n_samples:
            chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
        else:
            chunk = chunk[:,0:self.n_samples]
        seconds_start = math.floor(offset / cur_sample_rate)
        seconds_total = math.floor(duration)

        return (
            chunk,
            t_start,
            t_end,
            seconds_start,
            seconds_total
        )


USE_DUMMY_AUDIO = False  # set to True when testing code: no real data is read, generated silent audio is used instead
if USE_DUMMY_AUDIO:
    logger.warning("USE_DUMMY_AUDIO flag is True, don't use it when train or test!")

class SafeAudioReader:
    """
    This class is an adaptor to Read_and_PadCrop_Normalized_T, make it safe to read audio data.
    """
    def __init__(self,
                 duration: float,   # duration of the returned audio
                 sample_rate: int,  # sample rate of the returned audio; if it differs from the file's actual rate, the audio is resampled
                 randomize: bool = True
                ):
        self.n_samples = int(sample_rate * max(duration, 0))
        self.reader = Read_and_PadCrop_Normalized_T(n_samples=self.n_samples, sample_rate=sample_rate, randomize=randomize)

    # NOTE: this is the core function; every dataset reads audio through this call!
    def __call__(self,
                 filepath: os.PathLike,                     # audio path
                 origin_sample_rate: Optional[int] = None,  # actual sample rate from the json file; if not given, it is read from the file header
                 origin_duration: float = None,             # actual duration from the json file; if not given, it is read from the file header
                ) -> torch.Tensor:
        if USE_DUMMY_AUDIO:
            wav = torch.zeros(self.n_samples, dtype=torch.float32)
            return wav
        try:
            if origin_sample_rate is None or origin_duration is None:
                audio_info = torchaudio.info(filepath)
                origin_sample_rate = audio_info.sample_rate
                origin_duration = audio_info.num_frames / origin_sample_rate
            wav, *ignored = self.reader(filepath, origin_duration, origin_sample_rate)
        except Exception as e:
            logger.error(f"Error reading {filepath}: {e}")
            wav = torch.zeros(self.n_samples, dtype=torch.float32)
        return wav

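A minimal usage sketch of SafeAudioReader, assuming 'song.flac' is a placeholder path; on any read error the reader falls back to a silent tensor of the requested length:

reader = SafeAudioReader(duration=10, sample_rate=48000)
wav = reader('song.flac')   # rate/duration probed from the file header when not given
print(wav.shape)            # (1, 480000) on success; a flat zeros(480000) on failure
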
class PromptTemplate:
    def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'):
        self.template_text = template_text
        self.tag_map = tag_map
        self.lang = lang

    @property
    def tags(self):
        return tuple(self.tag_map.keys())

    def apply(self, **kwargs):
        for tag in list(kwargs.keys()):
            if kwargs[tag] == '':
                kwargs.pop(tag)
        for tag in self.tags:
            if tag in kwargs:
                kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]')
            else:
                kwargs[tag] = ''
        prompt = self.template_text.format(**kwargs)

        return self.beautify(prompt)

    def beautify(self, text):
        if self.lang == 'en':
            return self._beautify_en(text)
        elif self.lang == 'zh':
            return self._beautify_zh(text)
        else:
            raise ValueError(f'Unknown language {self.lang}')

    @staticmethod
    def _beautify_en(text):
        # no continuous commas without content between them
        text = re.sub(r'[,\s]*,[,\s]*', r', ', text)
        # no continuous whitespace
        text = re.sub(r'\s+', ' ', text)
        # the comma is NOT preceded by whitespace, and should be followed by ONE whitespace
        text = re.sub(r'\s+,', r',', text)
        text = re.sub(r',\s+', r', ', text)
        # no whitespace before the full stop
        text = re.sub(r'\s+\.', r'.', text)
        # strip whitespace, comma, and replace ',.'
        text = text.strip(' ,')
        text = text.replace(',.', '.')
        return text

    @staticmethod
    def _beautify_zh(text):
        # no continuous commas without content between them
        text = re.sub(r'[,、\s]*,[,、\s]*', r',', text)
        text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text)
        # assume there should be NO whitespace in Chinese
        text = re.sub(r'\s+', r'', text)
        # strip whitespace, comma, and replace ',。'
        text = text.strip(', 、')
        text = text.replace(',。', '。')
        return text

    def __repr__(self):
        return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})'

    __str__ = __repr__

def parse_prompt_template(prompt_template_text, lang='en'):
    span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL)
    tag_pattern = re.compile(r'{.+?}', re.DOTALL)

    template_text = prompt_template_text.strip()
    span_texts = span_pattern.findall(prompt_template_text)
    tag_map = {}
    for span_text in span_texts:
        tag = tag_pattern.findall(span_text)[0].strip('{}')
        tag_map[tag] = span_text
        template_text = template_text.replace(span_text, '{'+tag+'}')

    return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang)

def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]:
    with open(path, 'r') as f:
        lines = f.readlines()
    cnt = 0
    pts = []
    for line in lines:
        pt = parse_prompt_template(line, lang=lang)
        cnt += 1
        if len(pt.tags) < num:
            logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}')
        pts.append(pt)

    return pts

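A small demonstration of the bracketed-span grammar above, assuming the module is importable; the template string is invented. Each [ ... {tag} ... ] span is kept only when its tag is supplied, and dropped otherwise:

pt = parse_prompt_template('[This is a {genre} piece][ with a {mood} mood].')
print(pt.tags)                          # ('genre', 'mood')
print(pt.apply(genre='jazz'))           # 'This is a jazz piece.'
print(pt.apply(genre='jazz', mood='calm'))
# 'This is a jazz piece with a calm mood.'
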
def get_base_dir_file(key: os.PathLike):
    base = os.path.basename(key)
    dirname = os.path.basename(os.path.dirname(key))
    return os.path.join(dirname, base)

def read_jsonlike(path: os.PathLike):
    # json or jsonl
    if str(path).endswith(".json"):
        with open(path, 'r', encoding='utf8') as f:
            data = json.load(f)
        return data
    elif str(path).endswith(".jsonl"):
        with open(path, 'r', encoding='utf8') as f:
            data = [json.loads(line) for line in f.readlines()]
        return data
    else:
        raise ValueError("Unknown file format")

dist_prob_map = {
    1: (1.0,),
    2: (0.5, 0.5),
    3: (0.3, 0.4, 0.3),
    4: (0.2, 0.3, 0.3, 0.2),
    5: (0.2, 0.2, 0.3, 0.2, 0.1),
    6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15),
    7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
    8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
    9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
    10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
}

dist_prob_map_low = {
    1: (1.0,),
    2: (0.8, 0.2),
    3: (0.8, 0.1, 0.1),
    4: (0.7, 0.1, 0.1, 0.1),
    5: (0.7, 0.1, 0.1, 0.05, 0.05),
    6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05),
}

_bpm_range_rights = (
    (40, '20-40'),
    (60, '40-60'),
    (66, '60-66'),
    (76, '66-76'),
    (108, '76-108'),
    (120, '108-120'),
    (168, '120-168'),
    (176, '168-176'),
    (200, '176-200')
)
_bpm_desc_map = {
    '20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"),
    '40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"),
    '60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'),
    '66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'),
    '76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"),
    '108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'),
    '120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'),
    '168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'),
    '176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'),
    '>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo')
}
_bpm_desc_map_zh = {
    '20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"),
    '40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"),
    '60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'),
    '66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'),
    '76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"),
    '108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'),
    '120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'),
    '168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'),
    '176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'),
    '>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板')
}
def get_bpm_range(bpm):
    bpm = int(bpm)
    for right, tag in _bpm_range_rights:
        if bpm <= right:
            return tag
    return '>200'

def gen_bpm_descript(bpm, lang='en'):
    bpm_range = get_bpm_range(bpm)
    if lang == 'en':
        return random.choice(_bpm_desc_map[bpm_range])
    elif lang == 'zh':
        return random.choice(_bpm_desc_map_zh[bpm_range])
    else:
        raise ValueError(f"Unknown language {lang}")

def read_translate(translate: Optional[Dict[str, os.PathLike]]):
    if translate is None:
        return None
    return {k: read_jsonlike(path) for k, path in translate.items()}

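How the BPM helpers above resolve a tempo value, assuming the module is importable; the descriptive pick is random among the synonyms for the matched range:

print(get_bpm_range(72))                 # '66-76'
print(gen_bpm_descript(72))              # e.g. 'slow and steady' or 'Adagio'
print(gen_bpm_descript(130, lang='zh'))  # a random entry from the '120-168' Chinese set
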
class MagnaTagATuneDataset(Dataset):
    def __init__(self):
        pass


def tags_to_desc(tag_list, sep=',') -> str:
    if not isinstance(tag_list, Sequence):
        return str(tag_list)
    if isinstance(tag_list, str):
        return tag_list
    if len(tag_list) <= 0:
        return ''
    elif len(tag_list) <= 5:
        probs = dist_prob_map[len(tag_list)]
        tags_num = random.choices(range(1, len(tag_list)+1), probs)[0]
        random.shuffle(tag_list)
        tag_list = tag_list[:tags_num]
        return sep.join(tag_list)
    else:
        probs = dist_prob_map[5]
        tags_num = random.choices(range(1, 6), probs)[0]
        random.shuffle(tag_list)
        tag_list = tag_list[:tags_num]
        return sep.join(tag_list)

def get_sr_and_duration_info(item):
    return item.get('sample_rate', None), item.get('duration', None)

class MtgJamendoDatasetFromJson(Dataset):
    def __init__(self,
                data_dir:str,
                json_path:str,
                duration:float=10,
                sr:int = 0,
                *,
                lang = 'en',
                return_path = False,
                prompt_template_path: os.PathLike = None,
                tag_types = [],
                translate:Optional[Dict[str, os.PathLike]] = None,
                ):
        self.audio_reader = SafeAudioReader(duration, sr)

        self.data_dir = data_dir
        self._load_metadata_json(json_path)
        self.sr = sr
        self.duration = duration
        self.return_path = return_path
        self.lang = lang

        self.use_dynamic_prompt = prompt_template_path is not None
        if self.use_dynamic_prompt:
            self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types))
        self.tag_types = tag_types

        self.translate = read_translate(translate)
        if not self.use_dynamic_prompt and self.lang != 'en':
            raise NotImplementedError

    # these tags are considered weakly semantic; prompts containing only these tags are avoided
    WEAK_TAG_LIST = ["title", "artist"]

    def _load_metadata_json(self, json_path):
        with open(json_path) as fp:
            self.data = json.load(fp)

    def convert_key_to_path(self, key):
        return os.path.join(self.data_dir, get_base_dir_file(key))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        path = self.convert_key_to_path(item['key'])
        description = self.generate_description(item)

        sr, duration = get_sr_and_duration_info(item)
        audio = self.audio_reader(path, sr, duration)

        if self.return_path:
            return audio, description, path
        return audio, description

    def tags_to_desc(self, tag_list, tag_type) -> str:
        if self.lang == 'en':
            return tags_to_desc(tag_list)
        elif self.lang == 'zh':
            translator = self.translate[tag_type]
            translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
            return tags_to_desc(translated_tag_list, sep='、')

    def generate_description(self, item):
        if self.use_dynamic_prompt:
            # dynamically generate prompt from given prompt template
            prompt_template = random.choice(self.prompt_templates)
            description = self.generate_description_dynamic(item, prompt_template)

        else:
            # use ordinary static prompt instead
            description = self.generate_description_ordinary(item)
        return description

    def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
        exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
        exists_weak_tag = list(filter(lambda t: t in self.WEAK_TAG_LIST, exists_tag))
        exists_strong_tag = list(filter(lambda t: t not in self.WEAK_TAG_LIST, exists_tag))

        if len(exists_strong_tag) > 0:
            probs = dist_prob_map[len(exists_strong_tag)]
            tags_num = random.choices(range(1, len(exists_strong_tag)+1), probs)[0]
            random.shuffle(exists_strong_tag)
            tags = exists_strong_tag[:tags_num]
            weak_probs = dist_prob_map_low[len(exists_weak_tag) + 1]
            weak_tags_num = random.choices(range(0, len(exists_weak_tag) + 1), weak_probs)[0]
            random.shuffle(exists_weak_tag)
            weak_tags = exists_weak_tag[:weak_tags_num]
            tags += weak_tags
            tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
            prompt = prompt_template.apply(**tags_args)
        else:
            # no strong tags, use all weak tags instead
            tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in exists_weak_tag}
            prompt = prompt_template.apply(**tags_args)

        return prompt

    def generate_description_ordinary(self, data, thresh = 0.3):
        # Initialize the description with title and artist
        description = f'"{data["title"]+" is " if random.random() > thresh else ""}"a piece of music by {data["artist"]}'

        # Add genre if available
        if data["genre"] and random.random() > thresh:
            genres = ', '.join(data["genre"])
            description += f', belonging to the {genres} genres'

        # Add moods if available
        if data["moods"] and random.random() > thresh:
            moods = ', '.join(data["moods"])
            description += f'. This track conveys a {moods} mood'

        # Add instruments if available
        if data["instrument"] and random.random() > thresh:
            instruments = ', '.join(data["instrument"])
            description += f', and primarily features the following instruments: {instruments}'

        # Add a period to end the description
        description += '.'

        return description

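A seeded sketch of how the module-level tags_to_desc above samples a tag subset, assuming the module is importable; the exact subset varies with the seed:

import random

random.seed(0)                   # make the sampled subset reproducible
print(tags_to_desc('rock'))      # 'rock' — strings pass through unchanged
print(tags_to_desc(['rock']))    # 'rock' — single tag, nothing to sample
print(tags_to_desc(['rock', 'indie', 'live'], sep=', '))
# a comma-joined random subset, e.g. 'indie, live'
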
class AudioStockDataset(Dataset):
    def __init__(self,
                metadata_path:str,
                duration:float=10,
                sr:int = 0,
                return_path = False,
                return_audio = True,
                prompt_template_path: os.PathLike = None,
                tag_types = [],
                lang = 'en',
                translate:Optional[Dict[str, os.PathLike]] = None
                ):
        self.audio_reader = SafeAudioReader(duration, sr)

        self._load_metadata(metadata_path)
        self.sr = sr
        self.duration = duration
        self.return_path = return_path
        self.return_audio = return_audio

        self.use_dynamic_prompt = prompt_template_path is not None
        if self.use_dynamic_prompt:
            self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang)
        self.tag_types = tag_types

        self.lang = lang
        self.translate = read_translate(translate)

    def _load_metadata(self, metadata_path):
        with open(metadata_path) as fp:
            lines = fp.readlines()
        self.data = []
        for line in lines:
            item = json.loads(line)
            self.data.append(item)
        self.is_info_recorded = bool('Tags' in self.data[0])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        path:str = self.data[idx]["path"]
        json_path = path[:path.rfind('.')] + ".json"
        if self.is_info_recorded:
            item = self.data[idx]
        else:
            try:
                with open(json_path) as fp:
                    item:dict = json.load(fp)
            except Exception as e:
                print(f"Error loading json file {json_path} :\n{e}")
                item = {}
        description = self.generate_description(item)
        if self.return_audio:
            sr, duration = get_sr_and_duration_info(item)
            audio = self.audio_reader(path, sr, duration)
        else:
            audio = None
        if self.return_path:
            return audio, description, path
        return audio, description

    def generate_description(self, item):
        if self.use_dynamic_prompt:
            # dynamically generate prompt from given prompt template
            prompt_template = random.choice(self.prompt_templates)
            description = self.generate_description_dynamic(item, prompt_template)
        else:
            # use ordinary static prompt instead
            description = self.generate_description_ordinary(item)
        return description

    def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
        exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]

        if len(exists_tag) > 0:
            probs = dist_prob_map[len(exists_tag)]
            tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0]
            random.shuffle(exists_tag)
            tags = exists_tag[:tags_num]
            tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
            tags_args = self.handle_BPM_tag(tags_args)
            prompt = prompt_template.apply(**tags_args)
        else:
            # no strong tags, use all weak tags instead
            prompt = prompt_template.apply()

        return prompt

    def tags_to_desc(self, tag_list, tag_type) -> str:
        if self.lang == 'en':
            return tags_to_desc(tag_list)
        elif self.lang == 'zh':
            if tag_type == 'BPM':
                return tags_to_desc(tag_list, sep='、')
            translator = self.translate[tag_type]
            translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
            return tags_to_desc(translated_tag_list, sep='、')

    def handle_BPM_tag(self, tags_args):
        if "BPM" in tags_args and 'BPMDescript' in self.tag_types:
            bpm = tags_args["BPM"]
            del tags_args["BPM"]
            tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript')))
            for tag_type in tag_types_used:
                tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang)
        return tags_args

    def generate_description_ordinary(self, data, thresh = 0.3):
        if self.lang != 'en':
            raise ValueError(f'Language {self.lang} is not supported for ordinary description generation')
        description = f'a piece of music by {data["Artist"]}'

        # Add genre if available
        if data["Genre"] and random.random() > thresh:
            genres = ', '.join(data["Genre"])
            description += f', belonging to the {genres} genres'

        # Add tags if available
        if data["Tags"] and random.random() > thresh:
            tags = ', '.join(data["Tags"])
            description += f'. This track contains the tags:{tags}'

        # Add moods if available
        if data["Mood"] and random.random() > thresh:
            moods = ', '.join(data["Mood"])
            description += f'. This track conveys a {moods} mood.'

        # Add instruments if available
        if data["Instrument"] and random.random() > thresh:
            instruments = ', '.join(data["Instrument"])
            description += f', and primarily features the following instruments: {instruments}'

        # Add a period to end the description
        description += '.'

        return description

| 587 |
-
def mp3_path_to_id(mp3_path):
|
| 588 |
-
return int(
|
| 589 |
-
mp3_path[mp3_path.rindex('/') + 1 : mp3_path.rindex('.mp3')]
|
| 590 |
-
)
|
| 591 |
-
|
| 592 |
-
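Not part of the original file: a minimal sketch of the BPM handling above, with a stand-in fake_gen_bpm_descript in place of the module's gen_bpm_descript. handle_BPM_tag randomly surfaces the tempo as the raw number, a prose descriptor, or both.

import random

def fake_gen_bpm_descript(bpm):  # hypothetical stand-in for gen_bpm_descript(bpm, lang='en')
    return "quick and lively" if 120 <= int(bpm) <= 168 else "moderate tempo"

tags_args = {"BPM": "128"}
choice = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript')))
rendered = {t: tags_args["BPM"] if t == 'BPM' else fake_gen_bpm_descript(tags_args["BPM"])
            for t in choice}
print(rendered)  # e.g. {'BPM': '128', 'BPMDescript': 'quick and lively'}
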
class TmeDataset(Dataset):
    def __init__(self,
                 data_index:str,
                 music_info:str = None,
                 duration:float = 10,
                 sr:int = 0,
                 return_path = False,
                 return_audio = True,
                 prompt_format_path: os.PathLike = None,
                 tag_types = ['*'],
                 lang = 'zh',
                 translate: Optional[os.PathLike] = None,
                 prompt_dir: os.PathLike = None,
                 ):
        self.audio_reader = SafeAudioReader(duration, sr)

        self.sr = sr
        self.duration = duration
        self.return_path = return_path
        self.return_audio = return_audio
        self.lang = lang

        self.use_ready_prompt = prompt_dir is not None

        data_index = read_jsonlike(data_index)
        self.data_index_dict = {mp3_path_to_id(d['path']) : d for d in data_index}
        self.data_ids = list(self.data_index_dict.keys())

        if not self.use_ready_prompt:
            # Read the music info file
            music_info = read_jsonlike(music_info)
            if 'music' in music_info:
                music_info = music_info['music']
            self.music_info_dict = {d["歌曲ID"]:d for d in music_info}
            self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.music_info_dict}
            self.data_ids = list(self.data_index_dict.keys())

            with open(prompt_format_path) as fp:
                self.prompt_formats = yaml.load(fp, Loader=yaml.FullLoader)

            # Load tag types, splitting them into ordinary tag_types and key_tag_types
            if '*' in tag_types:
                self.tag_types = ['歌曲名', 'bpm', '专辑名', '歌手名', '作曲', 'tag']
            else:
                self.tag_types = tag_types

            self.key_tag_types = []
            if 'tag' in self.tag_types:
                self.tag_types.remove('tag')
                self.key_tag_types = list(self.prompt_formats['tag'].keys())

            # Load the translation table
            if translate is not None:
                self.translator = read_jsonlike(translate)
        else:
            data_ids_set = set(self.data_ids)
            self.prompts_dict = {}
            for fname in os.listdir(prompt_dir):
                items = read_jsonlike(os.path.join(prompt_dir, fname))
                for item in items:
                    if item['ID'] not in data_ids_set or not self.is_valid_prompt_text(item['Text']):
                        continue
                    if item['ID'] not in self.prompts_dict:
                        self.prompts_dict[item['ID']] = []
                    self.prompts_dict[item['ID']].append(item['Text'])
            self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.prompts_dict}
            self.data_ids = list(self.data_index_dict.keys())

    def tags_to_desc(self, tag_list) -> str:
        if is_bearable(tag_list, int):
            return str(tag_list)
        if self.lang == 'zh':
            return tags_to_desc(tag_list, sep=self.sep)
        else:
            translated_tag_list = [self.translator[tag] for tag in tag_list if tag in self.translator]
            return tags_to_desc(translated_tag_list, sep=self.sep)

    def gen_desc_of_tag(self, formats, tags):
        fmt = random.choice(formats)
        return fmt.format(self.tags_to_desc(tags))

    @staticmethod
    def check_valid(value):
        if isinstance(value, int) or isinstance(value, float):
            return value > 0
        if (value is not None) and (not isinstance(value, Sequence) or len(value) > 0):
            return True
        return False

    @staticmethod
    def remove_repeat(data):
        # If the album name is identical to the song title, keep only the latter
        album_name = data.get('专辑名', None)
        if album_name is not None and album_name == data.get('歌曲名', None):
            del data['专辑名']
        return data

    @property
    def comma(self):
        if self.lang == 'zh':
            return ','
        elif self.lang == 'en':
            return ', '

    @property
    def sep(self):
        if self.lang == 'zh':
            return '、'
        elif self.lang == 'en':
            return ', '

    def generate_description(self, data):
        data = self.remove_repeat(data)
        weak_tags = [key for key in data if (key in self.tag_types and self.check_valid(data[key]))]  # weakly semantic tags, sampled at a reduced rate

        key_tags = [key for key in data['tag'] if (key in self.key_tag_types and self.check_valid(data['tag'][key]))]  # key tags, at least one of which must appear

        prompts = []
        if len(weak_tags) > 0:
            probs = dist_prob_map_low[len(weak_tags)]
            if len(key_tags) > 0:
                tags_num = random.choices(range(0, len(weak_tags)), probs)[0]
            else:
                tags_num = random.choices(range(1, len(weak_tags) + 1), probs)[0]
            random.shuffle(weak_tags)
            tags = weak_tags[:tags_num]
            for tag_type in tags:
                tag_desc = self.gen_desc_of_tag(self.prompt_formats[tag_type], int(data[tag_type]) if tag_type == 'bpm' else data[tag_type])
                prompts.append(tag_desc)

        if len(key_tags) > 0:
            probs = dist_prob_map[len(key_tags)]
            tags_num = random.choices(range(1, len(key_tags) + 1), probs)[0]
            random.shuffle(key_tags)
            tags = key_tags[:tags_num]
            for tag_type in tags:
                tag_desc = self.gen_desc_of_tag(self.prompt_formats['tag'][tag_type], data['tag'][tag_type])
                prompts.append(tag_desc)

        random.shuffle(prompts)
        return self.comma.join(prompts)

    def is_valid_prompt_text(self, text):
        for bad in ('抱歉', 'sorry', 'Sorry'):
            if bad in text:
                return False
        return True

    def get_ready_prompt(self, path):
        sid = mp3_path_to_id(path)
        return random.choice(self.prompts_dict[sid])

    def __len__(self):
        return len(self.data_ids)

    def __getitem__(self, idx):
        data_id = self.data_ids[idx]
        item = self.data_index_dict[data_id]
        path = item['path']
        if not self.use_ready_prompt:
            info = self.music_info_dict[data_id]
            description = self.generate_description(info)
        else:
            description = self.get_ready_prompt(path)
        if self.return_audio:
            sr, duration = get_sr_and_duration_info(item)
            audio = self.audio_reader(path, sr, duration)
        else:
            audio = None
        if self.return_path:
            return audio, description, path
        return audio, description

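Not in the original file: a self-contained sketch of the weighted tag-count sampling used by generate_description above, assuming a toy probability table in place of the module's dist_prob_map_low. The skewed weights mean that most draws keep only one weak tag.

import random

# Toy stand-in for dist_prob_map_low: P(pick k tags) given 3 candidates.
toy_probs = {3: (0.8, 0.1, 0.1)}

weak_tags = ['歌手名', 'bpm', '专辑名']
probs = toy_probs[len(weak_tags)]
# random.choices draws a count k weighted by probs; usually k == 1 here.
tags_num = random.choices(range(1, len(weak_tags) + 1), probs)[0]
random.shuffle(weak_tags)
print(weak_tags[:tags_num])  # e.g. ['bpm']
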
class CombinedDataset(Dataset):
    @beartype
    def __init__(self, datasets: Sequence[Dataset], ratios: Sequence[int]):
        self.datasets = datasets
        self.datasets_index = []

        for i, dataset in enumerate(datasets):
            if dataset is None:
                continue
            for dup in range(ratios[i]):
                for j in range(len(dataset)):
                    self.datasets_index.append((i, j))

    def __len__(self):
        return len(self.datasets_index)

    def __getitem__(self, idx):
        index = self.datasets_index[idx]
        i, j = index
        return self.datasets[i][j]

class CombinedDataset_random(Dataset):
    @beartype
    def __init__(self,
                 num_examples:int,
                 datasets: Sequence[Dataset], ratios: Sequence[int]
                 ):
        self.datasets = datasets
        self.datasets_index = []

        for i, dataset in enumerate(datasets):
            if dataset is None:
                continue
            for dup in range(ratios[i]):
                for j in range(len(dataset)):
                    self.datasets_index.append((i, j))
        if num_examples > 0:
            self.random_choose = True
            self.dataset_len = num_examples
        else:
            self.random_choose = False
            self.dataset_len = len(self.datasets_index)

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, idx):
        first_try = True
        try_cnt = 0
        while True:
            try:
                if self.random_choose or not first_try:
                    index2 = []
                    index2.append(np.random.randint(0, len(self.datasets)))
                    index2.append(np.random.randint(0, len(self.datasets[index2[-1]])))
                else:
                    index2 = self.datasets_index[idx]
                first_try = False
                out = self.datasets[index2[0]][index2[1]]
                if len(out[0].shape) == 1:
                    out[0] = out[0][None, :]
                return out
            except:
                print("Error loading ", index2)
                try_cnt += 1
                if try_cnt > 10:
                    raise ValueError()

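Not from the commit: a small usage sketch for the two combiners above, with a hypothetical in-memory dataset. CombinedDataset walks a ratio-weighted concatenation of indices; CombinedDataset_random resamples uniformly and retries failed reads. Note each item is returned as a mutable list so the combiner can reshape out[0] in place.

import torch
from torch.utils.data import Dataset

class ToyAudioSet(Dataset):  # hypothetical stand-in for the real datasets
    def __init__(self, n, tag):
        self.n, self.tag = n, tag
    def __len__(self):
        return self.n
    def __getitem__(self, idx):
        # [audio, description], matching the (wav, text) pairs returned above
        return [torch.zeros(1, 16000), f"{self.tag} #{idx}"]

combined = CombinedDataset([ToyAudioSet(2, 'a'), ToyAudioSet(1, 'b')], ratios=[1, 2])
print(len(combined))   # 2*1 + 1*2 = 4 indexed examples
print(combined[3][1])  # 'b #0' (second duplicate of dataset b)
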
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined_withset.py
DELETED
@@ -1,994 +0,0 @@
from torch.utils.data import Dataset
from beartype.typing import Sequence, Callable, Optional, Dict, Tuple, List
from beartype import beartype
from beartype.door import is_bearable
import random
import pandas as pd
import os
from torchaudio.functional import resample
import torch
import typing as tp
from pathlib import Path
import torchaudio as ta
import torch.nn.functional as F
import numpy as np
import json
import yaml
import torchaudio
import math
import re
from loguru import logger

def gen_plain_prompt(key_list, sep=', '):
    if len(key_list) == 0:
        return 'none'

    key_list = [k.strip() for k in key_list]

    if len(key_list) > 10:
        random.shuffle(key_list)
        key_list = key_list[:10]

    probs = dist_prob_map[len(key_list)]

    num_tags = random.choices(range(1, len(key_list)+1), probs, k=1)[0]

    random.shuffle(key_list)
    tags = key_list[:num_tags]
    tags_str = sep.join(tags)
    return tags_str

class Read_and_PadCrop_Normalized_T(torch.nn.Module):

    def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):

        super().__init__()

        self.n_samples = n_samples
        self.sample_rate = sample_rate
        self.randomize = randomize
        self.prob = {"is_start": 0.2, "is_end": 0.9}
        self.shift_secs = 5

    def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
        if duration < (float(self.n_samples) / self.sample_rate + 1):
            raise ValueError(duration, float(self.n_samples), self.sample_rate)
            # NOTE: the block below is unreachable after the raise above
            chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
            t_start = 0.
            t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
            offset = 0
            is_start = True
            is_end = True
        else:
            prob = random.uniform(0, 1)
            if prob < self.prob['is_start']:
                is_start = True
                is_end = False
                offset = 0
            elif prob > self.prob['is_end']:
                is_start = False
                is_end = True
                offset = int(duration*cur_sample_rate) - int(float(self.n_samples)/self.sample_rate*cur_sample_rate)
            else:
                is_start = False
                is_end = False
                offset = np.random.randint(self.shift_secs*cur_sample_rate, \
                    int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate)-self.shift_secs*cur_sample_rate)
            t_start = offset / float(cur_sample_rate) / duration
            t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
            chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
        if chunk.shape[0] > 1:
            chunk = chunk[torch.randint(chunk.shape[0], size=(1,)), :].float()
        else:
            chunk = chunk[[0], :].float()
        if cur_sample_rate != self.sample_rate:
            # print('a:',cur_sample_rate,chunk.shape)
            chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
            # print('b:',self.sample_rate,chunk.shape)
        if chunk.shape[-1] != self.n_samples:
            raise ValueError(chunk.shape, self.n_samples, offset, int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
        # if chunk.shape[-1] < self.n_samples:
        #     chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
        # else:
        #     chunk = chunk[:,0:self.n_samples]
        seconds_start = math.floor(offset / cur_sample_rate)
        seconds_total = math.floor(duration)

        # # In this dataset, we do not introduce zeros
        # if(is_start):
        #     chunk = torch.cat([torch.zeros(1, self.shift_secs*self.sample_rate), chunk],1)[:,0:self.n_samples]
        # elif(is_end):
        #     chunk = torch.cat([chunk, torch.zeros(1, self.shift_secs*self.sample_rate)],1)[:,self.shift_secs*self.sample_rate:]

        return (
            chunk,
            t_start,
            t_end,
            seconds_start,
            seconds_total,
            is_start,
            is_end,
        )

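Not part of the file: a tiny numeric sketch of the crop-window selection above, assuming a 30 s source file at 48 kHz and a 10 s target clip. With probability 0.2 the clip is taken from the start, with probability 0.1 from the end, otherwise from a uniformly random offset at least shift_secs from either edge.

import numpy as np

duration, cur_sr = 30.0, 48000            # hypothetical source file
n_samples, target_sr = 10 * 48000, 48000  # 10 s clip at 48 kHz
shift_secs = 5

clip_frames = int(float(n_samples) / target_sr * cur_sr)
lo = shift_secs * cur_sr                                          # 240_000
hi = int(duration * cur_sr) - clip_frames - shift_secs * cur_sr   # 720_000
offset = np.random.randint(lo, hi)
print(offset / cur_sr)  # crop start in seconds, somewhere in [5, 15)
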
USE_DUMMY_AUDIO = False  # May be set to True when testing code: no real data is read, and generated silent audio is used instead
if USE_DUMMY_AUDIO:
    logger.warning("USE_DUMMY_AUDIO flag is True, don't use it when train or test!")

class SafeAudioReader:
    """
    This class is an adaptor to Read_and_PadCrop_Normalized_T, make it safe to read audio data.
    """
    def __init__(self,
                 duration: float,   # length of the returned audio, in seconds
                 sample_rate: int,  # sample rate of the returned audio; resampled if it differs from the source
                 randomize: bool = True
                 ):
        self.n_samples = int(sample_rate * max(duration, 0))
        self.reader = Read_and_PadCrop_Normalized_T(n_samples=self.n_samples, sample_rate=sample_rate, randomize=randomize)

    # NOTE: this is the core function -- every dataset reads its audio through it!
    def __call__(self,
                 filepath: os.PathLike,                     # audio path
                 origin_sample_rate: Optional[int] = None,  # actual sample rate recorded in the json file; if not given, read from the file header
                 origin_duration: float = None,             # actual duration recorded in the json file; if not given, read from the file header
                 ) -> torch.Tensor:
        if USE_DUMMY_AUDIO:
            wav = torch.zeros(self.n_samples, dtype=torch.float32)
            return wav
        try:
            # if origin_sample_rate is None or origin_duration is None:
            #     audio_info = torchaudio.info(filepath)
            #     origin_sample_rate = audio_info.sample_rate
            #     origin_duration = audio_info.num_frames / origin_sample_rate
            audio_info = torchaudio.info(filepath)
            origin_sample_rate = audio_info.sample_rate
            origin_duration = audio_info.num_frames / origin_sample_rate
            wav, *ignored, is_start, is_end = self.reader(filepath, origin_duration, origin_sample_rate)
        except Exception as e:
            logger.error(f"Error reading {filepath}: {e}")
            raise FileNotFoundError(filepath)
        return wav, is_start, is_end

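Not from the commit: a hedged usage sketch for SafeAudioReader. The path is hypothetical; any decodable audio file longer than the requested window plus the 1 s/shift margins would do.

# Read a random 10 s mono crop at 48 kHz; raises FileNotFoundError on any decode issue.
reader = SafeAudioReader(duration=10, sample_rate=48000)
wav, is_start, is_end = reader("/path/to/some_song.flac")  # hypothetical path
print(wav.shape, is_start, is_end)  # torch.Size([1, 480000]), edge flags usually False
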
class PromptTemplate:
    def __init__(self, template_text: str, tag_map: Dict[str, str], lang: str = 'en'):
        self.template_text = template_text
        self.tag_map = tag_map
        self.lang = lang

    @property
    def tags(self):
        return tuple(self.tag_map.keys())

    def apply(self, **kwargs):
        for tag in list(kwargs.keys()):
            if kwargs[tag] == '':
                kwargs.pop(tag)
        for tag in self.tags:
            if tag in kwargs:
                kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]')
            else:
                kwargs[tag] = ''
        prompt = self.template_text.format(**kwargs)

        return self.beautify(prompt)

    def beautify(self, text):
        if self.lang == 'en':
            return self._beautify_en(text)
        elif self.lang == 'zh':
            return self._beautify_zh(text)
        else:
            raise ValueError(f'Unknown language {self.lang}')

    @staticmethod
    def _beautify_en(text):
        # no continuous commas without content between them
        text = re.sub(r'[,\s]*,[,\s]*', r', ', text)
        # no continuous whitespace
        text = re.sub(r'\s+', ' ', text)
        # the comma is NOT preceded by whitespace, and should be followed by ONE whitespace
        text = re.sub(r'\s+,', r',', text)
        text = re.sub(r',\s+', r', ', text)
        # no whitespace before the full stop
        text = re.sub(r'\s+\.', r'.', text)
        # strip whitespace, comma, and replace ',.'
        text = text.strip(' ,')
        text = text.replace(',.', '.')
        return text

    @staticmethod
    def _beautify_zh(text):
        # no continuous commas without content between them
        text = re.sub(r'[,、\s]*,[,、\s]*', r',', text)
        text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text)
        # assume there should be NO whitespace in Chinese
        text = re.sub(r'\s+', r'', text)
        # strip whitespace, comma, and replace ',。'
        text = text.strip(', 、')
        text = text.replace(',。', '。')
        return text

    def __repr__(self):
        return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})'

    __str__ = __repr__

def parse_prompt_template(prompt_template_text, lang='en'):
    span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL)
    tag_pattern = re.compile(r'{.+?}', re.DOTALL)

    template_text = prompt_template_text.strip()
    span_texts = span_pattern.findall(prompt_template_text)
    tag_map = {}
    for span_text in span_texts:
        tag = tag_pattern.findall(span_text)[0].strip('{}')
        tag_map[tag] = span_text
        template_text = template_text.replace(span_text, '{'+tag+'}')

    return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang)

def load_prompt_templates(path, num=5, lang='en') -> List[PromptTemplate]:
    with open(path, 'r') as f:
        lines = f.readlines()
    cnt = 0
    pts = []
    for line in lines:
        pt = parse_prompt_template(line, lang=lang)
        cnt += 1
        if len(pt.tags) < num:
            logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}')
        pts.append(pt)

    return pts

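Not in the original: a short demonstration of the bracket-span template syntax that parse_prompt_template handles. The template string here is made up; spans in [...] are dropped when their tag is missing, and beautify cleans up the leftovers.

pt = parse_prompt_template('A [{genre} track][ with a {mood} mood].')  # hypothetical template
print(pt.tags)                               # ('genre', 'mood')
print(pt.apply(genre='jazz'))                # 'A jazz track.'
print(pt.apply(genre='jazz', mood='calm'))   # 'A jazz track with a calm mood.'
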
def get_base_dir_file(key: os.PathLike):
    base = os.path.basename(key)
    dirname = os.path.basename(os.path.dirname(key))
    return os.path.join(dirname, base)

def read_jsonlike(path: os.PathLike):
    # json or jsonl
    if str(path).endswith(".json"):
        with open(path, 'r', encoding='utf8') as f:
            data = json.load(f)
        return data
    elif str(path).endswith(".jsonl"):
        with open(path, 'r', encoding='utf8') as f:
            data = [json.loads(line) for line in f.readlines()]
        return data
    else:
        raise ValueError("Unknown file format")

dist_prob_map = {
    1: (1.0,),
    2: (0.5, 0.5),
    3: (0.3, 0.4, 0.3),
    4: (0.2, 0.3, 0.3, 0.2),
    5: (0.2, 0.2, 0.3, 0.2, 0.1),
    6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15),
    7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
    8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
    9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
    10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
}

dist_prob_map_low = {
    1: (1.0,),
    2: (0.8, 0.2),
    3: (0.8, 0.1, 0.1),
    4: (0.7, 0.1, 0.1, 0.1),
    5: (0.7, 0.1, 0.1, 0.05, 0.05),
    6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05),
}

_bpm_range_rights = (
    (40, '20-40'),
    (60, '40-60'),
    (66, '60-66'),
    (76, '66-76'),
    (108, '76-108'),
    (120, '108-120'),
    (168, '120-168'),
    (176, '168-176'),
    (200, '176-200')
)
_bpm_desc_map = {
    '20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"),
    '40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"),
    '60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'),
    '66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'),
    '76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"),
    '108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'),
    '120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'),
    '168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'),
    '176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'),
    '>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo')
}
_bpm_desc_map_zh = {
    '20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"),
    '40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"),
    '60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'),
    '66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'),
    '76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"),
    '108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'),
    '120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'),
    '168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'),
    '176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'),
    '>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板')
}
def get_bpm_range(bpm):
    bpm = int(bpm)
    for right, tag in _bpm_range_rights:
        if bpm <= right:
            return tag
    return '>200'

def gen_bpm_descript(bpm, lang='en'):
    bpm_range = get_bpm_range(bpm)
    if lang == 'en':
        return random.choice(_bpm_desc_map[bpm_range])
    elif lang == 'zh':
        return random.choice(_bpm_desc_map_zh[bpm_range])
    else:
        raise ValueError(f"Unknown language {lang}")

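Not from the file: a worked example of the BPM bucketing above. 128 BPM falls in the '120-168' bucket, so the English descriptor is drawn from the Allegro set and the Chinese one from the 快板 set.

print(get_bpm_range(128))                # '120-168'
print(gen_bpm_descript(128))             # e.g. 'brisk pace' or 'Allegro'
print(gen_bpm_descript(128, lang='zh'))  # e.g. '快板'
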
def read_translate(translate: Optional[Dict[str, os.PathLike]]):
    if translate is None:
        return None
    if isinstance(translate, str):
        return read_jsonlike(translate)
    return {k: read_jsonlike(path) for k, path in translate.items()}


class MagnaTagATuneDataset(Dataset):
    def __init__(self):
        pass


def tags_to_desc(tag_list, sep=',') -> str:
    if not isinstance(tag_list, Sequence):
        return str(tag_list)
    if isinstance(tag_list, str):
        return tag_list
    if len(tag_list) <= 0:
        return ''
    elif len(tag_list) <= 5:
        probs = dist_prob_map[len(tag_list)]
        tags_num = random.choices(range(1, len(tag_list)+1), probs)[0]
        random.shuffle(tag_list)
        tag_list = tag_list[:tags_num]
        return sep.join(tag_list)
    else:
        probs = dist_prob_map[5]
        tags_num = random.choices(range(1, 6), probs)[0]
        random.shuffle(tag_list)
        tag_list = tag_list[:tags_num]
        return sep.join(tag_list)

def get_sr_and_duration_info(item):
    return item.get('sample_rate', None), item.get('duration', None)

class MtgJamendoDatasetFromJson(Dataset):
    def __init__(self,
                 data_dir:str,
                 json_path:str,
                 duration:float=10,
                 sr:int = 0,
                 *,
                 lang = 'en',
                 return_path = False,
                 prompt_template_path: os.PathLike = None,
                 tag_types = [],
                 translate: Optional[Dict[str, os.PathLike]] = None,
                 ):
        self.audio_reader = SafeAudioReader(duration, sr)

        self.data_dir = data_dir
        self._load_metadata_json(json_path)
        self.sr = sr
        self.duration = duration
        self.return_path = return_path
        self.lang = lang

        self.use_dynamic_prompt = prompt_template_path is not None
        if self.use_dynamic_prompt:
            self.prompt_templates = load_prompt_templates(prompt_template_path, num=len(tag_types))
            self.tag_types = tag_types

        self.translate = read_translate(translate)
        if not self.use_dynamic_prompt and self.lang != 'en':
            raise NotImplementedError

    # These tags are considered weakly semantic; prompts containing only these tags are avoided
    WEAK_TAG_LIST = ["title", "artist"]

    def _load_metadata_json(self, json_path):
        with open(json_path) as fp:
            self.data = json.load(fp)

    def convert_key_to_path(self, key):
        return os.path.join(self.data_dir, get_base_dir_file(key))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        path = self.convert_key_to_path(item['key'])
        description = self.generate_description(item)

        sr, duration = get_sr_and_duration_info(item)
        audio, is_start, is_end = self.audio_reader(path, sr, duration)

        if self.return_path:
            return audio, description, path
        return audio, description, is_start, is_end

    def tags_to_desc(self, tag_list, tag_type) -> str:
        if self.lang == 'en':
            return tags_to_desc(tag_list)
        elif self.lang == 'zh':
            translator = self.translate[tag_type]
            translated_tag_list = [translator[tag] for tag in tag_list if tag in translator]
            return tags_to_desc(translated_tag_list, sep='、')

    def generate_description(self, item):
        if self.use_dynamic_prompt:
            # dynamically generate prompt from given prompt template
            prompt_template = random.choice(self.prompt_templates)
            description = self.generate_description_dynamic(item, prompt_template)

        else:
            # use ordinary static prompt instead
            description = self.generate_description_ordinary(item)
        return description

    def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
        exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
        exists_weak_tag = list(filter(lambda t: t in self.WEAK_TAG_LIST, exists_tag))
        exists_strong_tag = list(filter(lambda t: t not in self.WEAK_TAG_LIST, exists_tag))

        if len(exists_strong_tag) > 0:
            probs = dist_prob_map[len(exists_strong_tag)]
            tags_num = random.choices(range(1, len(exists_strong_tag)+1), probs)[0]
            random.shuffle(exists_strong_tag)
            tags = exists_strong_tag[:tags_num]
            weak_probs = dist_prob_map_low[len(exists_weak_tag) + 1]
            weak_tags_num = random.choices(range(0, len(exists_weak_tag) + 1), weak_probs)[0]
            random.shuffle(exists_weak_tag)
            weak_tags = exists_weak_tag[:weak_tags_num]
            tags += weak_tags
            tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
            prompt = prompt_template.apply(**tags_args)
        else:
            # no strong tags, use all weak tags instead
            tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in exists_weak_tag}
            prompt = prompt_template.apply(**tags_args)

        return prompt

    def generate_description_ordinary(self, data, thresh=0.3):
        # Initialize the description with title and artist
        description = f'{data["title"] + " is " if random.random() > thresh else ""}a piece of music by {data["artist"]}'

        # Add genre if available
        if data["genre"] and random.random() > thresh:
            genres = ', '.join(data["genre"])
            description += f', belonging to the {genres} genres'

        # Add moods if available
        if data["moods"] and random.random() > thresh:
            moods = ', '.join(data["moods"])
            description += f'. This track conveys a {moods} mood'

        # Add instruments if available
        if data["instrument"] and random.random() > thresh:
            instruments = ', '.join(data["instrument"])
            description += f', and primarily features the following instruments: {instruments}'

        # Add a period to end the description
        description += '.'

        return description

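Not in the commit: a compact sketch of the strong/weak split in generate_description_dynamic above, with toy tag data and a simplified draw in place of the dist_prob_map tables. Strong tags (genre, moods, ...) always contribute at least one entry; weak tags (title, artist) are added sparingly.

import random

WEAK = ["title", "artist"]
data = {"title": ["Song A"], "artist": ["Band B"], "genre": ["rock"], "moods": ["happy"]}

exists = [k for k, v in data.items() if v]
strong = [t for t in exists if t not in WEAK]   # ['genre', 'moods']
weak = [t for t in exists if t in WEAK]         # ['title', 'artist']

n_strong = random.randint(1, len(strong))       # toy stand-in for the weighted draw
n_weak = random.choices(range(0, len(weak) + 1), (0.8, 0.1, 0.1))[0]  # usually 0
print(random.sample(strong, n_strong) + random.sample(weak, n_weak))
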
class AudioStockDataset(Dataset):
    def __init__(self,
                 metadata_path:str,
                 duration:float=10,
                 sr:int = 0,
                 return_path = False,
                 return_audio = True,
                 prompt_template_path: os.PathLike = None,
                 tag_types = [],
                 lang = 'en',
                 translate: Optional[Dict[str, os.PathLike]] = None
                 ):
        self.audio_reader = SafeAudioReader(duration, sr)

        self.duration = duration
        self._load_metadata(metadata_path)
        self.sr = sr
        self.return_path = return_path
        self.return_audio = return_audio

        self.use_dynamic_prompt = prompt_template_path is not None
        if self.use_dynamic_prompt:
            self.prompt_templates = load_prompt_templates(prompt_template_path, num=len(tag_types), lang=lang)
        self.tag_types = tag_types

        self.lang = lang
        self.translate = read_translate(translate)

    def _load_metadata(self, metadata_path):
        with open(metadata_path) as fp:
            lines = fp.readlines()
        self.data = []
        for line in lines:
            item = json.loads(line)
            if item['duration'] > self.duration + 10:
                self.data.append(item)
        self.is_info_recorded = bool('Tags' in self.data[0])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        path:str = self.data[idx]["path"]
        json_path = path[:path.rfind('.')] + ".json"
        if self.is_info_recorded:
            item = self.data[idx]
        else:
            try:
                with open(json_path) as fp:
                    item:dict = json.load(fp)
            except Exception as e:
                print(f"Error loading json file {json_path} :\n{e}")
                item = {}
        description = self.generate_description(item)
        if self.return_audio:
            sr, duration = get_sr_and_duration_info(item)
            audio, is_start, is_end = self.audio_reader(path, sr, duration)
        else:
            audio = None
        if self.return_path:
            return audio, description, path, is_start, is_end
        else:
            return audio, description, is_start, is_end

    def generate_description(self, item):
        if self.use_dynamic_prompt:
            # dynamically generate prompt from given prompt template
            prompt_template = random.choice(self.prompt_templates)
            description = self.generate_description_dynamic(item, prompt_template)
        else:
            # use ordinary static prompt instead
            description = self.generate_description_ordinary(item)
        return description

    def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
        exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]

        if len(exists_tag) > 0:
            probs = dist_prob_map[len(exists_tag)]
            tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0]
            random.shuffle(exists_tag)
            tags = exists_tag[:tags_num]
            tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
            tags_args = self.handle_BPM_tag(tags_args)
            prompt = prompt_template.apply(**tags_args)
        else:
            # no strong tags, use all weak tags instead
            prompt = prompt_template.apply()

        return prompt

    def tags_to_desc(self, tag_list, tag_type) -> str:
        if self.lang == 'en':
            return tags_to_desc(tag_list)
        elif self.lang == 'zh':
            if tag_type == 'BPM':
                return tags_to_desc(tag_list, sep='、')
            translator = self.translate[tag_type]
            translated_tag_list = [translator[tag] for tag in tag_list if tag in translator]
            return tags_to_desc(translated_tag_list, sep='、')

    def handle_BPM_tag(self, tags_args):
        if "BPM" in tags_args and 'BPMDescript' in self.tag_types:
            bpm = tags_args["BPM"]
            del tags_args["BPM"]
            tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript')))
            for tag_type in tag_types_used:
                tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang)
        return tags_args

    def generate_description_ordinary(self, data, thresh = 0.3):
        if self.lang != 'en':
            raise ValueError(f'Language {self.lang} is not supported for ordinary description generation')
        description = f'a piece of music by {data["Artist"]}'

        # Add genre if available
        if data["Genre"] and random.random() > thresh:
            genres = ', '.join(data["Genre"])
            description += f', belonging to the {genres} genres'

        # Add tags if available
        if data["Tags"] and random.random() > thresh:
            tags = ', '.join(data["Tags"])
            description += f'. This track contains the tags:{tags}'

        # Add moods if available
        if data["Mood"] and random.random() > thresh:
            moods = ', '.join(data["Mood"])
            description += f'. This track conveys a {moods} mood.'

        # Add instruments if available
        if data["Instrument"] and random.random() > thresh:
            instruments = ', '.join(data["Instrument"])
            description += f'. and primarily features the following instruments: {instruments}'

        # Add a period to end the description
        description += '.'

        return description

def mp3_path_to_id(mp3_path):
    return int(
        mp3_path[mp3_path.rindex('/') + 1 : mp3_path.rindex('.mp3')]
    )

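Not part of the commit: a one-line view of the metadata filter in _load_metadata above. Items must exceed the crop window by at least 10 s, which leaves room for the 5 s shift_secs margin on each side of a random crop.

duration = 20  # requested crop length in seconds
items = [{"path": "a.mp3", "duration": 25.0}, {"path": "b.mp3", "duration": 31.0}]  # made-up metadata
kept = [it for it in items if it['duration'] > duration + 10]
print([it['path'] for it in kept])  # ['b.mp3']
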
class TmeDataset(Dataset):
    def __init__(self,
                 data_index:str,
                 music_info:str = None,
                 duration:float = 10,
                 sr:int = 0,
                 return_path = False,
                 return_audio = True,
                 prompt_format_path: os.PathLike = None,
                 tag_types = ['*'],
                 lang = 'zh',
                 translate: Optional[os.PathLike] = None,
                 prompt_dir: os.PathLike = None,
                 ):
        self.audio_reader = SafeAudioReader(duration, sr)

        self.sr = sr
        self.duration = duration
        self.return_path = return_path
        self.return_audio = return_audio
        self.lang = lang

        self.use_ready_prompt = prompt_dir is not None

        data_index = read_jsonlike(data_index)
        data_index = [d for d in data_index if d['duration'] > self.duration + 10]
        self.data_index_dict = {mp3_path_to_id(d['path']) : d for d in data_index}
        self.data_ids = list(self.data_index_dict.keys())

        if not self.use_ready_prompt:
            # Read the music info file
            music_info = read_jsonlike(music_info)
            if 'music' in music_info:
                music_info = music_info['music']
            self.music_info_dict = {d["歌曲ID"]:d for d in music_info}
            self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.music_info_dict}
            self.data_ids = list(self.data_index_dict.keys())

            with open(prompt_format_path) as fp:
                self.prompt_formats = yaml.load(fp, Loader=yaml.FullLoader)

            # Load tag types, splitting them into ordinary tag_types and key_tag_types
            if '*' in tag_types:
                self.tag_types = ['歌曲名', 'bpm', '专辑名', '歌手名', '作曲', 'tag']
            else:
                self.tag_types = tag_types

            self.key_tag_types = []
            if 'tag' in self.tag_types:
                self.tag_types.remove('tag')
                self.key_tag_types = list(self.prompt_formats['tag'].keys())

            # Load the translation table
            if translate is not None:
                self.translator = read_jsonlike(translate)
        else:
            data_ids_set = set(self.data_ids)
            self.prompts_dict = {}
            for fname in os.listdir(prompt_dir):
                items = read_jsonlike(os.path.join(prompt_dir, fname))
                for item in items:
                    if item['ID'] not in data_ids_set or not self.is_valid_prompt_text(item['Text']):
                        continue
                    if item['ID'] not in self.prompts_dict:
                        self.prompts_dict[item['ID']] = []
                    self.prompts_dict[item['ID']].append(item['Text'])
            self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.prompts_dict}
            self.data_ids = list(self.data_index_dict.keys())

    def tags_to_desc(self, tag_list) -> str:
        if is_bearable(tag_list, int):
            return str(tag_list)
        if self.lang == 'zh':
            return tags_to_desc(tag_list, sep=self.sep)
        else:
            translated_tag_list = [self.translator[tag] for tag in tag_list if tag in self.translator]
            return tags_to_desc(translated_tag_list, sep=self.sep)

    def gen_desc_of_tag(self, formats, tags):
        fmt = random.choice(formats)
        return fmt.format(self.tags_to_desc(tags))

    @staticmethod
    def check_valid(value):
        if isinstance(value, int) or isinstance(value, float):
            return value > 0
        if (value is not None) and (not isinstance(value, Sequence) or len(value) > 0):
            return True
        return False

    @staticmethod
    def remove_repeat(data):
        # If the album name is identical to the song title, keep only the latter
        album_name = data.get('专辑名', None)
        if album_name is not None and album_name == data.get('歌曲名', None):
            del data['专辑名']
        return data

    @property
    def comma(self):
        if self.lang == 'zh':
            return ','
        elif self.lang == 'en':
            return ', '

    @property
    def sep(self):
        if self.lang == 'zh':
            return '、'
        elif self.lang == 'en':
            return ', '

    def generate_description(self, data):
        data = self.remove_repeat(data)
        weak_tags = [key for key in data if (key in self.tag_types and self.check_valid(data[key]))]  # weakly semantic tags, sampled at a reduced rate

        key_tags = [key for key in data['tag'] if (key in self.key_tag_types and self.check_valid(data['tag'][key]))]  # key tags, at least one of which must appear

        prompts = []
        if len(weak_tags) > 0:
            probs = dist_prob_map_low[len(weak_tags)]
            if len(key_tags) > 0:
                tags_num = random.choices(range(0, len(weak_tags)), probs)[0]
            else:
                tags_num = random.choices(range(1, len(weak_tags) + 1), probs)[0]
            random.shuffle(weak_tags)
            tags = weak_tags[:tags_num]
            for tag_type in tags:
                tag_desc = self.gen_desc_of_tag(self.prompt_formats[tag_type], int(data[tag_type]) if tag_type == 'bpm' else data[tag_type])
                prompts.append(tag_desc)

        if len(key_tags) > 0:
            probs = dist_prob_map[len(key_tags)]
            tags_num = random.choices(range(1, len(key_tags) + 1), probs)[0]
            random.shuffle(key_tags)
            tags = key_tags[:tags_num]
            for tag_type in tags:
                tag_desc = self.gen_desc_of_tag(self.prompt_formats['tag'][tag_type], data['tag'][tag_type])
                prompts.append(tag_desc)

        random.shuffle(prompts)
        return self.comma.join(prompts)

    def is_valid_prompt_text(self, text):
        for bad in ('抱歉', 'sorry', 'Sorry'):
            if bad in text:
                return False
        return True

    def get_ready_prompt(self, path):
        sid = mp3_path_to_id(path)
        return random.choice(self.prompts_dict[sid])

    def __len__(self):
        return len(self.data_ids)

    def __getitem__(self, idx):
        data_id = self.data_ids[idx]
        item = self.data_index_dict[data_id]
        path = item['path']
        if not self.use_ready_prompt:
            info = self.music_info_dict[data_id]
            description = self.generate_description(info)
        else:
            description = self.get_ready_prompt(path)
        if self.return_audio:
            sr, duration = get_sr_and_duration_info(item)
            audio, is_start, is_end = self.audio_reader(path, sr, duration)
        else:
            audio = None
        if self.return_path:
            return audio, description, path, is_start, is_end
        else:
            return audio, description, is_start, is_end

class Pond5Dataset(Dataset):
    MAX_PROMPT_LEN = 200
    def __init__(self,
                 metadata_path:str,
                 index_path:str,
                 duration:float=10,
                 sr:int = 0,
                 plain_rate = 0,
                 return_path = False,
                 return_audio = True,
                 lang = 'en',
                 translate: Optional[Dict[str, os.PathLike]] = None,
                 use_literal_none = True,
                 use_avoid_watermark_policy = None,
                 ):

        if use_avoid_watermark_policy is None:
            raise ValueError("`use_avoid_watermark_policy` is an important param, you need to explicitly specify it with bool type")
        self.use_avoid_watermark_policy = use_avoid_watermark_policy
        assert self.use_avoid_watermark_policy is False
        self.audio_reader = SafeAudioReader(duration, sr)

        self.duration = duration
        self._load_metadata(metadata_path, index_path)
        self.sr = sr
        self.plain_rate = plain_rate
        self.return_path = return_path
        self.return_audio = return_audio
        self.use_literal_none = use_literal_none

        self.lang = lang
        self.translate = read_translate(translate)

    def _load_metadata(self, metadata_path, index_path):
        data_index = read_jsonlike(index_path)
        data_ids = set([item['id'] for item in data_index])

        with open(metadata_path) as fp:
            lines = fp.readlines()

        append_ids = set()

        self.data = []
        for line in lines:
            item = json.loads(line)
            if item['id'] in data_ids and item['id'] not in append_ids and item["details"]["duration"] is not None and item["details"]["duration"] > self.duration + 10:
                self.data.append(item)
                append_ids.add(item['id'])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        path:str = item["path"]
        description = self.generate_description(item)
        if self.return_audio:
            sr, duration = get_sr_and_duration_info(item)
            audio, is_start, is_end = self.audio_reader(path, sr, duration)
        else:
            audio = None
        if self.return_path:
            return audio, description, path
        return audio, description, is_start, is_end

    @property
    def keysep(self):
        if self.lang == 'zh':
            return ',' if random.random() > 0.5 else '、'
        elif self.lang == 'en':
            return ', '

    def generate_description(self, item):
        if random.random() > self.plain_rate:
            # dynamically generate prompt from given prompt template
            description = self.generate_description_dynamic(item)
        else:
            # use plain prompt, i.e. tags sequence separated by comma
            description = self.generate_description_plain(item)
        return description

    def get_translation(self, k):
        k = k.strip()
        if k in self.translate:
            return self.translate[k]
        else:
            return k

    def generate_description_plain(self, item):
        keywords = item['keywords']
        if self.lang != 'en':
            keywords = [self.get_translation(k) for k in keywords]
        return gen_plain_prompt(keywords, sep=self.keysep)

    def generate_description_dynamic(self, item):
        desc = item.get('desc', 'none')
        if desc is None:
            desc = 'none'
        desc = desc.strip()
        if len(desc) > self.MAX_PROMPT_LEN:
            shorter_desc = desc[:self.MAX_PROMPT_LEN]
            # find last stop
            stop_idx = shorter_desc.rfind('.')
            if stop_idx == -1:
                stop_idx = shorter_desc.rfind('!')
            if stop_idx == -1:
                stop_idx = shorter_desc.rfind(',')
            if stop_idx == -1:
                stop_idx = self.MAX_PROMPT_LEN - 1
            desc = desc[:stop_idx+1]
        return desc

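Not from the file: a quick illustration of the MAX_PROMPT_LEN truncation in generate_description_dynamic above, using a made-up description and a small limit so the effect is visible. The max() over stop marks is a simplification; the original prefers '.' over '!' over ','.

MAX_LEN = 40  # stand-in for Pond5Dataset.MAX_PROMPT_LEN (200)
desc = "An upbeat electro swing track. Brass stabs and a walking bassline."
if len(desc) > MAX_LEN:
    shorter = desc[:MAX_LEN]
    stop = max(shorter.rfind('.'), shorter.rfind('!'), shorter.rfind(','))  # latest stop of any kind
    if stop == -1:
        stop = MAX_LEN - 1
    desc = desc[:stop + 1]
print(desc)  # 'An upbeat electro swing track.'
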
class CombinedDataset(Dataset):
|
| 930 |
-
@beartype
|
| 931 |
-
def __init__(self, datasets: Sequence[Dataset], ratios: Sequence[int]):
|
| 932 |
-
self.datasets = datasets
|
| 933 |
-
self.datasets_index = []
|
| 934 |
-
|
| 935 |
-
for i,dataset in enumerate(datasets):
|
| 936 |
-
if dataset is None:
|
| 937 |
-
continue
|
| 938 |
-
for dup in range(ratios[i]):
|
| 939 |
-
for j in range(len(dataset)):
|
| 940 |
-
self.datasets_index.append((i,j))
|
| 941 |
-
|
| 942 |
-
def __len__(self):
|
| 943 |
-
return len(self.datasets_index)
|
| 944 |
-
|
| 945 |
-
def __getitem__(self, idx):
|
| 946 |
-
index = self.datasets_index[idx]
|
| 947 |
-
i,j = index
|
| 948 |
-
return self.datasets[i][j]
|
| 949 |
-
|
| 950 |
-
class CombinedDataset_random(Dataset):
|
| 951 |
-
@beartype
|
| 952 |
-
def __init__(self,
|
| 953 |
-
num_examples:int,
|
| 954 |
-
datasets: Sequence[Dataset], ratios: Sequence[int]
|
| 955 |
-
):
|
| 956 |
-
self.datasets = datasets
|
| 957 |
-
self.datasets_index = []
|
| 958 |
-
|
| 959 |
-
for i,dataset in enumerate(datasets):
|
| 960 |
-
if dataset is None:
|
| 961 |
-
continue
|
| 962 |
-
for dup in range(ratios[i]):
|
| 963 |
-
for j in range(len(dataset)):
|
| 964 |
-
self.datasets_index.append((i,j))
|
| 965 |
-
if num_examples > 0:
|
| 966 |
-
self.random_choose = True
|
| 967 |
-
self.dataset_len = num_examples
|
| 968 |
-
else:
|
| 969 |
-
self.random_choose = False
|
| 970 |
-
self.dataset_len = len(self.datasets_index)
|
| 971 |
-
|
| 972 |
-
def __len__(self):
|
| 973 |
-
return self.dataset_len
|
| 974 |
-
|
| 975 |
-
def __getitem__(self, idx):
|
| 976 |
-
first_try = True
|
| 977 |
-
try_cnt = 0
|
| 978 |
-
while True:
|
| 979 |
-
try:
|
| 980 |
-
if(self.random_choose or not first_try):
|
| 981 |
-
index2 = []
|
| 982 |
-
index2.append(np.random.randint(0,len(self.datasets)))
|
| 983 |
-
index2.append(np.random.randint(0,len(self.datasets[index2[-1]])))
|
| 984 |
-
else:
|
| 985 |
-
index2 = self.datasets_index[idx]
|
| 986 |
-
first_try = False
|
| 987 |
-
out = self.datasets[index2[0]][index2[1]]
|
| 988 |
-
if(len(out[0].shape)==1):out[0]=out[0][None,:]
|
| 989 |
-
return out
|
| 990 |
-
except:
|
| 991 |
-
print("Error loadding ", index2)
|
| 992 |
-
try_cnt += 1
|
| 993 |
-
if(try_cnt>10):
|
| 994 |
-
raise FileNotFoundError()
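
# A minimal sketch of the ratio-based index expansion used by both classes
# above, with plain lists standing in for datasets (values are illustrative):
datasets = [["a0", "a1"], ["b0"]]
ratios = [1, 3]  # the second dataset is oversampled 3x

index = []
for i, ds in enumerate(datasets):
    for _ in range(ratios[i]):
        for j in range(len(ds)):
            index.append((i, j))

print(index)       # [(0, 0), (0, 1), (1, 0), (1, 0), (1, 0)]
print(len(index))  # 5 -- the epoch length grows with the ratios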
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song.py
DELETED
@@ -1,313 +0,0 @@
import re
import sys
import json

from torch.utils.data import Dataset
import torchaudio
from torchaudio.functional import resample
import torch
import numpy as np

from torch.nn.utils.rnn import pad_sequence


def check_lryics(lyric):
    _FILTER_STRING = [
        '作词', '作曲', '编曲', '【', '策划',
        '录音', '混音', '母带', ':', '制作',
        '版权', '校对', '演奏', '制作', '伴奏'
    ]
    for item in _FILTER_STRING:
        if item in lyric:
            return True

    return False


def process_lyrics(lines):
    lyric_part = []
    timestamp_part = []

    timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]')

    for i, line in enumerate(lines):

        # drop credit lines (composer, mixing, etc.) from the first few rows
        if i < 10 and check_lryics(line):
            continue

        # keep only lines with a valid timestamp and lyric content
        if timestamp_pattern.match(line):
            timestamp_end = line.rfind(']')
            lyrics = line[timestamp_end + 1:].strip()
            timestamps = line[:timestamp_end + 1]

            if ':' in lyrics:
                if len(lyrics.split(":")[0]) <= 5:
                    lyrics = "".join(lyrics.split(":")[1:])
            # if lyrics:  # make sure the lyric part is not empty
            #     lyric_part.append(lyrics)
            #     timestamp_part.append(timestamps)
    # NOTE: with the appends above commented out, this returns empty lists
    return timestamp_part, lyric_part


def get_timestamps(timestamp_part):
    # convert '[mm:ss.xx]' tags to seconds
    timestamps = []

    for line in timestamp_part:
        match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line)
        if match:
            minutes = int(match.group(1))
            seconds = float(match.group(2))
            millis = float(match.group(3)) if match.group(3) else 0
            total_seconds = minutes * 60 + seconds + millis
            timestamps.append(total_seconds)

    return timestamps


def process_lyrics_lrc(lyrics):
    timestamp_part, lyric_part = process_lyrics(lyrics)
    timestamps = get_timestamps(timestamp_part)
    if len(timestamps) == 0:
        return []

    slice_start = timestamps[0]
    slice_start_idx = 0

    output_list = []
    for i in range(1, len(timestamps)):
        # cut a new slice once the accumulated span exceeds 30 s;
        # note that a song shorter than 30 s overall yields no slice at all
        if timestamps[i] - slice_start > 30:
            output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))

            slice_start = timestamps[i]
            slice_start_idx = i

    return output_list


def process_lyrics_yrc(lyrics):
    timestamps, lyric_part = extract_lrc(lyrics)
    if len(timestamps) == 0:
        return []

    slice_start = timestamps[0]
    slice_start_idx = 0

    output_list = []
    for i in range(1, len(timestamps)):
        # cut a new slice once the accumulated span exceeds 30 s
        if timestamps[i] - slice_start > 30:
            output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))

            slice_start = timestamps[i]
            slice_start_idx = i
    return output_list


def extract_lrc(lyrics):
    timestamp_part, lyric_part = [], []

    for i, text in enumerate(lyrics):
        # content inside square brackets: '<start_ms>,<duration_ms>'
        bracket_content = re.search(r'\[(.*?)\]', text).group(1)
        bracket_content = bracket_content.split(',')
        # content inside parentheses (extracted but unused)
        parentheses_content = re.findall(r'\((.*?)\)', text)
        # everything else is the lyric text itself
        other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip()

        if i < 10 and check_lryics(other_content):
            continue

        timestamp_part.append(float(bracket_content[0]) / 1000)
        lyric_part.append(other_content)
    return timestamp_part, lyric_part
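
# Standalone check of the '[mm:ss.xx]' handling above. A small subtlety: the
# fractional part is captured with its leading dot, so float() of the group
# already yields the sub-second value (float('.5') == 0.5).
import re

def to_seconds(tag):
    m = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', tag)
    minutes = int(m.group(1))
    seconds = float(m.group(2))
    millis = float(m.group(3)) if m.group(3) else 0
    return minutes * 60 + seconds + millis

print(to_seconds('[01:23.5]'))  # 83.5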

class WYYSongDataset(Dataset):
    def __init__(self,
                 metadata_path: str,
                 sr: int = 0,
                 use_lang = ['en', 'zh-cn'],
                 num_examples = -1,
                 ):
        self.sr = sr
        self.use_lang = use_lang
        self._load_metadata(metadata_path)

        # buffer
        self.lyric_buffer = {}

        if num_examples <= 0:
            self.dataset_len = len(self.data)
            self.random_slc = False
        else:
            self.dataset_len = num_examples
            self.random_slc = True

    # read the jsonl metadata file
    def _load_metadata(self, metadata_path):
        with open(metadata_path) as fp:
            lines = fp.readlines()
        self.data = []
        for line in lines:
            item = json.loads(line)
            if 'lyrics' in item and 'lang_info' in item:
                if len(item['lyrics']) > 0:
                    for lang in self.use_lang:
                        if lang in item['lang_info'] and item['lang_info'][lang]['proportion'] > 0.8 and item['lang_info'][lang]['probability'] > 0.9:
                            if '伴奏' not in item['path']:  # skip accompaniment-only tracks
                                self.data.append(item)

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, idx):
        try_cnt = 0
        while True:
            if self.random_slc:
                idx = np.random.randint(0, len(self.data))
            yrc_lyrics = []
            lrc_lyrics = []
            try:
                info = self.data[idx]

                # audio path
                path: str = info["path"]

                # read the lyric sections
                if 'lyrics' not in info:
                    if idx not in self.lyric_buffer:
                        # word-level aligned lyrics
                        if info['yrc-lyric'] is not None:
                            with open(info['yrc-lyric']) as f_in:
                                yrc_lyric = json.load(f_in)
                            yrc_lyrics = process_lyrics_yrc(yrc_lyric['lyrics'][:-1])

                        # sentence-level aligned lyrics
                        if info['lrc-lyric'] is not None:
                            with open(info['lrc-lyric']) as f_in:
                                lrc_lyric = json.load(f_in)
                            lrc_lyrics = process_lyrics_lrc(lrc_lyric['lyrics'][:-1])

                        # prefer the word-level aligned lyrics
                        if len(yrc_lyrics) > 0:
                            lyrics = yrc_lyrics
                        else:
                            lyrics = lrc_lyrics
                        self.lyric_buffer[idx] = lyrics

                        # TODO filter each section by length; drop songs that are too long or too short
                    else:
                        lyrics = self.lyric_buffer[idx]
                else:
                    lyrics = info['lyrics']

                # pick one lyric section at random
                ly_id = torch.randint(low=1, high=len(lyrics), size=(1,))[0].item()
                lyric = lyrics[ly_id]

                st, et, lyric = self.parse_lyric(lyric)

                assert et - st < 40

                # text filtering
                lyric = re.sub(r'【.*?】', '', lyric)
                if 'zh-cn' in info['lang_info'] and info['lang_info']['zh-cn']['proportion'] > 0.8:
                    assert 200 > len(lyric.replace(" ", "")) > 30
                    # (note: the checks below operate on the `lyrics` list, not the selected line)
                    if ':' in lyrics:
                        if len(lyrics.split(":")[0]) <= 5:
                            lyrics = "".join(lyrics.split(":")[1:])
                    if ':' in lyrics:
                        if len(lyrics.split(":")[0]) <= 5:
                            lyrics = "".join(lyrics.split(":")[1:])

                if 'en' in info['lang_info'] and info['lang_info']['en']['proportion'] > 0.8:
                    assert 200 > len(lyric.split()) > 20
                    if ':' in lyrics:
                        if len(lyrics.split(":")[0].split()) <= 3:
                            lyrics = "".join(lyrics.split(":")[1:])
                    if ':' in lyrics:
                        if len(lyrics.split(":")[0].split()) <= 3:
                            lyrics = "".join(lyrics.split(":")[1:])

                # load the audio segment
                cur_sample_rate = torchaudio.info(path).sample_rate
                offset = int(cur_sample_rate * st)
                num_frames = int(cur_sample_rate * (et - st))
                chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames)

                # pick one channel at random
                if chunk.shape[0] > 1:
                    chunk = chunk[torch.randint(chunk.shape[0], size=(1,)), :].float()
                else:
                    chunk = chunk[[0], :].float()

                if cur_sample_rate != self.sr:
                    chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr)

                return chunk, lyric, [st, et], path
            except Exception:
                print("Error loading ", info["path"])
                try_cnt += 1
                idx = np.random.randint(0, len(self.data))
                if try_cnt > 10:
                    raise FileNotFoundError()

    def parse_lyric(self, lyric):
        pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)'
        match = re.search(pattern, lyric)

        start_time = float(match.group(1))
        end_time = float(match.group(2))
        content = match.group(3)
        return start_time, end_time, content


def collect_song(data_list):
    audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1, 2)
    lyrics = [data[1] for data in data_list]
    st_et = [data[2] for data in data_list]
    paths = [data[3] for data in data_list]  # collected but not returned
    return audios, lyrics, st_et
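
# A hedged sketch of the segment loading inside __getitem__ above: seek to a
# section's start time and decode exactly (et - st) seconds rather than the
# whole file. 'song.wav' and the times are placeholders, not dataset values.
import torch
import torchaudio

st, et = 12.0, 27.5
sr = torchaudio.info('song.wav').sample_rate
chunk, _ = torchaudio.load('song.wav',
                           frame_offset=int(sr * st),
                           num_frames=int(sr * (et - st)))
if chunk.shape[0] > 1:                    # random mono pick, as above
    chunk = chunk[torch.randint(chunk.shape[0], size=(1,)), :]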
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_20s.py
DELETED
@@ -1,313 +0,0 @@
(This file was a line-for-line copy of dataset_song.py above, deleted in the same commit. The only differences were tighter segment bounds in WYYSongDataset.__getitem__:

    assert et - st < 20                               # was: et - st < 40
    assert 100 > len(lyric.replace(" ", "")) > 5      # zh-cn; was: 200 > ... > 30
    assert 100 > len(lyric.split()) > 5               # en;    was: 200 > ... > 20

All other lines, including the lyric-processing helpers and collect_song, were identical.)
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_new_429.py
DELETED
@@ -1,313 +0,0 @@
(Lines 1-148 repeated the imports and lyric-processing helpers of dataset_song.py above — check_lryics, process_lyrics, get_timestamps, process_lyrics_lrc, process_lyrics_yrc and extract_lrc — unchanged apart from dropped debug comments. The dataset class itself differed:)

class WYYSongDataset(Dataset):
    def __init__(self,
                 metadata_path: str,
                 sr: int = 0,
                 use_lang = ['en', 'zh-cn'],
                 num_examples = -1,
                 max_dur = 20,
                 pad_to_max = True,
                 ):
        self.sr = sr
        self.use_lang = use_lang
        self._load_metadata(metadata_path)
        self.max_dur = max_dur
        self.pad_to_max = pad_to_max

        # buffer
        self.lyric_buffer = {}

        if num_examples <= 0:
            self.dataset_len = len(self.data)
            self.random_slc = False
        else:
            self.dataset_len = num_examples
            self.random_slc = True

    # read the jsonl metadata file
    def _load_metadata(self, metadata_path):
        with open(metadata_path) as fp:
            lines = fp.readlines()
        self.data = []
        for line in lines:
            item = json.loads(line)
            if '伴奏' not in item['path']:  # skip accompaniment-only tracks
                if "lang_type" in item:
                    self.data.append(item)

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, idx):
        try_cnt = 0
        while True:
            if self.random_slc:
                idx = np.random.randint(0, len(self.data))
            yrc_lyrics = []
            lrc_lyrics = []
            try:
                info = self.data[idx]

                # audio path
                path = info["path"]
                lang_type = info["lang_type"]
                if info["lang_type"] == 'en':
                    lyrics = info['lyrics']
                else:
                    lyrics = info['lyrics_phone']

                # pick one lyric section at random
                ly_id = torch.randint(low=1, high=len(lyrics), size=(1,))[0].item()
                lyric = lyrics[ly_id].strip()

                st, et, lyric = self.parse_lyric(lyric)
                lyric = lyric.replace("\xa0", " ")
                lyric = " ".join(lyric.split())

                assert et - st < self.max_dur

                # keep sections with a plausible words-per-second rate
                if info["lang_type"] == 'en':
                    assert 6 > len(lyric.split()) / (et - st) > 1
                else:
                    lyric = lyric.replace("-", "")
                    assert 6 > len(lyric.split()) / (et - st) > 1

                # load the audio segment
                cur_sample_rate = torchaudio.info(path).sample_rate
                offset = int(cur_sample_rate * st)
                num_frames = int(cur_sample_rate * (et - st))
                chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames)

                # pick one channel at random
                if chunk.shape[0] > 1:
                    chunk = chunk[torch.randint(chunk.shape[0], size=(1,)), :].float()
                else:
                    chunk = chunk[[0], :].float()

                if cur_sample_rate != self.sr:
                    chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr)

                if self.pad_to_max:
                    chunk = self.pad_2d_tensor(chunk, int(self.max_dur * self.sr), 0)

                return chunk, lyric, et - st, path, lang_type
            except Exception:
                try_cnt += 1
                idx = np.random.randint(0, len(self.data))
                if try_cnt > 20:
                    raise FileNotFoundError()

    def parse_lyric(self, lyric):
        pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)'
        match = re.search(pattern, lyric)

        start_time = float(match.group(1))
        end_time = float(match.group(2))
        content = match.group(3)
        return start_time, end_time, content

    def pad_2d_tensor(self, x, max_len, pad_id):
        # right-pad a (channels, T) tensor with pad_id up to max_len samples
        batch_size, seq_len = x.size()
        max_len = max(max_len, seq_len)
        pad_len = max_len - seq_len

        if pad_len > 0:
            pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device)
            padded_tensor = torch.cat([x, pad_tensor], dim=1)
        else:
            # no padding needed; return the input tensor unchanged
            padded_tensor = x

        return padded_tensor
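
# Standalone check of the fixed-length padding above. With pad_to_max every
# item has the same length, so batching is trivial. Shapes are illustrative.
import torch

def pad_2d(x, max_len, pad_id=0):
    batch, seq_len = x.size()
    pad_len = max(max_len, seq_len) - seq_len
    if pad_len > 0:
        x = torch.cat([x, torch.full((batch, pad_len), pad_id, dtype=x.dtype)], dim=1)
    return x

chunk = torch.randn(1, 48000 * 12)        # 12 s at 48 kHz
print(pad_2d(chunk, 48000 * 20).shape)    # torch.Size([1, 960000]) -- 20 s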

def collect_data(data_list):
    audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1, 2)
    lyrics = [data[1] for data in data_list]
    st_et = [data[2] for data in data_list]
    paths = [data[3] for data in data_list]  # collected but not returned
    lang_types = [data[4] for data in data_list]
    return audios, lyrics, st_et, lang_types


def build_dataset():
    train_dataset = WYYSongDataset(
        metadata_path = "train.jsonl",
        sr = 48000,
        use_lang = ['zh-cn', 'en'],
        num_examples = 10 * 10000
    )

    valid_dataset = WYYSongDataset(
        metadata_path = "valid.jsonl",
        sr = 48000,
        use_lang = ['zh-cn', 'en'],
        num_examples = 500
    )

    return train_dataset, valid_dataset
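
# Hedged usage sketch, assuming the definitions above and that train.jsonl
# exists: the custom collate_fn is what carries the lyric strings and
# lang_types through batching (the default collate would reject them).
from torch.utils.data import DataLoader

train_set, _ = build_dataset()
loader = DataLoader(train_set, batch_size=8, num_workers=4, collate_fn=collect_data)
audios, lyrics, st_et, lang_types = next(iter(loader))
print(audios.shape)  # (8, 1, 960000) with the 20 s / 48 kHz defaults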
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_stock.py
DELETED
|
@@ -1,461 +0,0 @@
|
|
| 1 |
-
from torch.utils.data import Dataset
|
| 2 |
-
from beartype.typing import Sequence, Callable, Optional, Dict, List
|
| 3 |
-
from beartype.door import is_bearable
|
| 4 |
-
import random
|
| 5 |
-
import os
|
| 6 |
-
from torchaudio.functional import resample
|
| 7 |
-
import torch
|
| 8 |
-
import typing as tp
|
| 9 |
-
from pathlib import Path
|
| 10 |
-
import torchaudio as ta
|
| 11 |
-
import torch.nn.functional as F
|
| 12 |
-
import soundfile
|
| 13 |
-
import numpy as np
|
| 14 |
-
import json
|
| 15 |
-
import yaml
|
| 16 |
-
import random
|
| 17 |
-
import librosa
|
| 18 |
-
from loguru import logger
|
| 19 |
-
import re
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def _av_read(filepath, seek_time=0, duration=None):
|
| 23 |
-
if duration is not None:
|
| 24 |
-
sr = librosa.get_samplerate(filepath)
|
| 25 |
-
offset = seek_time
|
| 26 |
-
num_samples = int(duration * sr)
|
| 27 |
-
wav, _ = librosa.load(filepath, sr=sr, offset=offset, duration=duration)
|
| 28 |
-
else:
|
| 29 |
-
wav, sr = librosa.load(filepath, sr=None, offset=seek_time)
|
| 30 |
-
|
| 31 |
-
return wav, sr
|
| 32 |
-
|
| 33 |
-
def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
|
| 34 |
-
duration: float = -1., pad: bool = True) -> tp.Tuple[torch.Tensor, int]:
|
| 35 |
-
"""Read audio by picking the most appropriate backend tool based on the audio format.
|
| 36 |
-
|
| 37 |
-
Args:
|
| 38 |
-
filepath (str or Path): Path to audio file to read.
|
| 39 |
-
seek_time (float): Time at which to start reading in the file.
|
| 40 |
-
duration (float): Duration to read from the file. If set to -1, the whole file is read.
|
| 41 |
-
pad (bool): Pad output audio if not reaching expected duration.
|
| 42 |
-
Returns:
|
| 43 |
-
tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
|
| 44 |
-
"""
|
| 45 |
-
fp = Path(filepath)
|
| 46 |
-
if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
|
| 47 |
-
# There is some bug with ffmpeg and reading flac
|
| 48 |
-
info = soundfile.info(filepath)
|
| 49 |
-
frames = -1 if duration <= 0 else int(duration * info.samplerate)
|
| 50 |
-
frame_offset = int(seek_time * info.samplerate)
|
| 51 |
-
wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
|
| 52 |
-
assert info.samplerate == sr, f"Mismatch of sample rates {info.samplerate} {sr}"
|
| 53 |
-
wav = torch.from_numpy(wav).t().contiguous()
|
| 54 |
-
if len(wav.shape) == 1:
|
| 55 |
-
wav = torch.unsqueeze(wav, 0)
|
| 56 |
-
elif (
|
| 57 |
-
fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
|
| 58 |
-
and duration <= 0 and seek_time == 0
|
| 59 |
-
):
|
| 60 |
-
# Torchaudio is faster if we load an entire file at once.
|
| 61 |
-
wav, sr = librosa.load(fp, sr=None, mono=True)
|
| 62 |
-
else:
|
| 63 |
-
wav, sr = _av_read(filepath, seek_time, duration)
|
| 64 |
-
if pad and duration > 0:
|
| 65 |
-
expected_frames = int(duration * sr)
|
| 66 |
-
wav = F.pad(torch.tensor(wav), (0, expected_frames - wav.shape[-1]))
|
| 67 |
-
if not isinstance(wav, torch.Tensor):
|
| 68 |
-
wav = torch.tensor(wav)
|
| 69 |
-
return wav, sr
|
| 70 |
-
|
| 71 |
-
def random_seek_read(filepath, duration):
|
| 72 |
-
if duration > 0:
|
| 73 |
-
total_duration = librosa.get_duration(path=filepath)
|
| 74 |
-
acceptable_start = max(0, total_duration - duration)
|
| 75 |
-
wav, sr = audio_read(filepath, random.uniform(0, acceptable_start), duration, pad=True)
|
| 76 |
-
else:
|
| 77 |
-
wav, sr = audio_read(filepath, 0, -1, pad=False)
|
| 78 |
-
return wav, sr
|
| 79 |
-
|
| 80 |
-
def safe_random_seek_read(filepath, duration, sample_rate):
|
| 81 |
-
try:
|
| 82 |
-
wav, sr = random_seek_read(filepath, duration)
|
| 83 |
-
if sr != sample_rate:
|
| 84 |
-
wav = resample(wav, sr, sample_rate)
|
| 85 |
-
sr = sample_rate
|
| 86 |
-
except Exception as e:
|
| 87 |
-
logger.error(f"Error reading {filepath}: {e}")
|
| 88 |
-
sr = sample_rate
|
| 89 |
-
wav = torch.zeros(sr * max(duration, 0), dtype=torch.float32)
|
| 90 |
-
return wav, sr
|
| 91 |
-
|
| 92 |
-
def read_jsonlike(path: os.PathLike):
|
| 93 |
-
#json or jsonl
|
| 94 |
-
if str(path).endswith(".json"):
|
| 95 |
-
with open(path, 'r', encoding='utf8') as f:
|
| 96 |
-
data = json.load(f)
|
| 97 |
-
return data
|
| 98 |
-
elif str(path).endswith(".jsonl"):
|
| 99 |
-
with open(path, 'r', encoding='utf8') as f:
|
| 100 |
-
data = [json.loads(line) for line in f.readlines()]
|
| 101 |
-
return data
|
| 102 |
-
else:
|
| 103 |
-
raise ValueError("Unknown file format")
|
| 104 |
-
|
| 105 |
-
dist_prob_map = {
|
| 106 |
-
1: (1.0,),
|
| 107 |
-
2: (0.5, 0.5),
|
| 108 |
-
3: (0.3, 0.4, 0.3),
|
| 109 |
-
4: (0.2, 0.3, 0.3, 0.2),
|
| 110 |
-
5: (0.2, 0.2, 0.3, 0.2, 0.1),
|
| 111 |
-
6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15),
|
| 112 |
-
7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
|
| 113 |
-
8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
|
| 114 |
-
9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
|
| 115 |
-
10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
|
| 116 |
-
}
|
| 117 |
-
|
| 118 |
-
dist_prob_map_low = {
|
| 119 |
-
1: (1.0,),
|
| 120 |
-
2: (0.8, 0.2),
|
| 121 |
-
3: (0.8, 0.1, 0.1),
|
| 122 |
-
4: (0.7, 0.1, 0.1, 0.1),
|
| 123 |
-
5: (0.7, 0.1, 0.1, 0.05, 0.05),
|
| 124 |
-
6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05),
|
| 125 |
-
}
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
_bpm_range_rights = (
|
| 129 |
-
(40, '20-40'),
|
| 130 |
-
(60, '40-60'),
|
| 131 |
-
(66, '60-66'),
|
| 132 |
-
(76, '66-76'),
|
| 133 |
-
(108, '76-108'),
|
| 134 |
-
(120, '108-120'),
|
| 135 |
-
(168, '120-168'),
|
| 136 |
-
(176, '168-176'),
|
| 137 |
-
(200, '176-200')
|
| 138 |
-
)
|
| 139 |
-
_bpm_desc_map = {
|
| 140 |
-
'20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"),
|
| 141 |
-
'40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"),
|
| 142 |
-
'60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'),
|
| 143 |
-
'66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'),
|
| 144 |
-
'76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"),
|
| 145 |
-
'108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'),
|
| 146 |
-
'120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'),
|
| 147 |
-
'168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'),
|
| 148 |
-
'176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'),
|
| 149 |
-
'>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo')
|
| 150 |
-
}
|
| 151 |
-
_bpm_desc_map_zh = {
|
| 152 |
-
'20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"),
|
| 153 |
-
'40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"),
|
| 154 |
-
'60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'),
|
| 155 |
-
'66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'),
|
| 156 |
-
'76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"),
|
| 157 |
-
'108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'),
|
| 158 |
-
'120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'),
|
| 159 |
-
'168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'),
|
| 160 |
-
'176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'),
|
| 161 |
-
    '>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板')
}

def get_bpm_range(bpm):
    bpm = int(bpm)
    for right, tag in _bpm_range_rights:
        if bpm <= right:
            return tag
    return '>200'

def gen_bpm_descript(bpm, lang='en'):
    bpm_range = get_bpm_range(bpm)
    if lang == 'en':
        return random.choice(_bpm_desc_map[bpm_range])
    elif lang == 'zh':
        return random.choice(_bpm_desc_map_zh[bpm_range])
    else:
        raise ValueError(f"Unknown language {lang}")

def read_translate(translate: Optional[Dict[str, os.PathLike]]):
    if translate is None:
        return None
    return {k: read_jsonlike(path) for k, path in translate.items()}


def tags_to_desc(tag_list, sep=',') -> str:
    if not isinstance(tag_list, Sequence):
        return str(tag_list)
    if isinstance(tag_list, str):
        return tag_list
    if len(tag_list) <= 0:
        return ''
    elif len(tag_list) <= 5:
        probs = dist_prob_map[len(tag_list)]
        tags_num = random.choices(range(1, len(tag_list)+1), probs)[0]
        random.shuffle(tag_list)
        tag_list = tag_list[:tags_num]
        return sep.join(tag_list)
    else:
        probs = dist_prob_map[5]
        tags_num = random.choices(range(1, 6), probs)[0]
        random.shuffle(tag_list)
        tag_list = tag_list[:tags_num]
        return sep.join(tag_list)


class PromptTemplate:
    def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'):
        self.template_text = template_text
        self.tag_map = tag_map
        self.lang = lang

    @property
    def tags(self):
        return tuple(self.tag_map.keys())

    def apply(self, **kwargs):
        for tag in list(kwargs.keys()):
            if kwargs[tag] == '':
                kwargs.pop(tag)
        for tag in self.tags:
            if tag in kwargs:
                kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]')
            else:
                kwargs[tag] = ''
        prompt = self.template_text.format(**kwargs)

        return self.beautify(prompt)

    def beautify(self, text):
        if self.lang == 'en':
            return self._beautify_en(text)
        elif self.lang == 'zh':
            return self._beautify_zh(text)
        else:
            raise ValueError(f'Unknown language {self.lang}')

    @staticmethod
    def _beautify_en(text):
        # no continuous commas without content between them
        text = re.sub(r'[,\s]*,[,\s]*', r', ', text)
        # no continuous whitespace
        text = re.sub(r'\s+', ' ', text)
        # no whitespace before a comma; exactly one space after it
        text = re.sub(r'\s+,', r',', text)
        text = re.sub(r',\s+', r', ', text)
        # no whitespace before the full stop
        text = re.sub(r'\s+\.', r'.', text)
        # strip whitespace, comma, and replace ',.'
        text = text.strip(' ,')
        text = text.replace(',.', '.')
        return text

    @staticmethod
    def _beautify_zh(text):
        # no continuous commas without content between them
        text = re.sub(r'[,、\s]*,[,、\s]*', r',', text)
        text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text)
        # assume there should be NO whitespace in Chinese
        text = re.sub(r'\s+', r'', text)
        # strip whitespace, comma, and replace ',。'
        text = text.strip(', 、')
        text = text.replace(',。', '。')
        return text

    def __repr__(self):
        return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})'

    __str__ = __repr__

def parse_prompt_template(prompt_template_text, lang='en'):
    span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL)
    tag_pattern = re.compile(r'{.+?}', re.DOTALL)

    template_text = prompt_template_text.strip()
    span_texts = span_pattern.findall(prompt_template_text)
    tag_map = {}
    for span_text in span_texts:
        tag = tag_pattern.findall(span_text)[0].strip('{}')
        tag_map[tag] = span_text
        template_text = template_text.replace(span_text, '{'+tag+'}')

    return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang)

def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]:
    with open(path, 'r') as f:
        lines = f.readlines()
    cnt = 0
    pts = []
    for line in lines:
        pt = parse_prompt_template(line, lang=lang)
        cnt += 1
        if len(pt.tags) < num:
            logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}')
        pts.append(pt)

    return pts

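An editorial aside before the dataset class: the span-based template machinery above is easiest to see on a concrete input. A minimal sketch (the template string and tag values are hypothetical, not taken from the repository's template files):

# Illustrative sketch, not part of the original file.
pt = parse_prompt_template('A [{Genre} ]track[ with {Instrument}][ at {BPM} BPM].')
print(pt.tags)                            # ('Genre', 'Instrument', 'BPM')
print(pt.apply(Genre='jazz', BPM='120'))  # 'A jazz track at 120 BPM.'
# The unfilled Instrument span collapses to '', and _beautify_en cleans up
# the leftover whitespace around it.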
class AudioStockDataset(Dataset):
    def __init__(self,
                 num_examples:int,
                 metadata_path:str,
                 duration:float=60,
                 sr:int = 0,
                 return_path = False,
                 return_audio = True,
                 prompt_template_path: os.PathLike = None,
                 tag_types = [],
                 lang = 'en',
                 translate:Optional[Dict[str, os.PathLike]] = None
                 ):
        self.duration = duration
        self.MAX_DURATION = 360
        self._load_metadata(metadata_path)
        if num_examples > 0:
            self.random_choose = True
            self.dataset_len = num_examples
        else:
            self.random_choose = False
            self.dataset_len = len(self.data)
        self.sr = sr
        self.return_path = return_path
        self.return_audio = return_audio

        self.use_dynamic_prompt = prompt_template_path is not None
        if self.use_dynamic_prompt:
            self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang)
        self.tag_types = tag_types

        self.lang = lang
        self.translate = read_translate(translate)

    def _load_metadata(self, metadata_path):
        total_len = 0; valid_len = 0
        with open(metadata_path) as fp:
            lines = fp.readlines()
            self.data = []
            for line in lines:
                item = json.loads(line)
                total_len += 1
                if(item['duration']>self.duration and item['duration']<self.MAX_DURATION):
                    valid_len += 1
                    self.data.append(item)
        print("Filter data from {} to {}".format(total_len, valid_len))
        self.is_info_recorded = bool('Tags' in self.data[0])

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, idx):
        first_try = True
        try_cnt = 0
        while True:
            try:
                if(self.random_choose or not first_try):
                    index2 = np.random.randint(0,len(self.data))
                else:
                    index2 = idx
                first_try = False
                return self.getitem_main(index2)
            except:
                print("Error loading ", self.data[idx]["path"])
                try_cnt += 1
                if(try_cnt>10):
                    raise ValueError()

    def getitem_main(self, idx):
        path:str = self.data[idx]["path"]
        json_path = path[:path.rfind('.')] + ".json"
        if self.is_info_recorded:
            item = self.data[idx]
        else:
            with open(json_path) as fp:
                item:dict = json.load(fp)
        description = self.generate_description(item)
        if self.return_audio:
            audio, sr = safe_random_seek_read(path, duration=self.duration, sample_rate=self.sr)
        else:
            audio = None
        if self.return_path:
            return audio, description, path
        return audio, description


    def generate_description(self, item):
        if self.use_dynamic_prompt:
            # dynamically generate prompt from given prompt template
            prompt_template = random.choice(self.prompt_templates)
            description = self.generate_description_dynamic(item, prompt_template)
        else:
            # use ordinary static prompt instead
            description = self.generate_description_ordinary(item)
        return description

    def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
        exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]

        if len(exists_tag) > 0:
            probs = dist_prob_map[len(exists_tag)]
            tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0]
            random.shuffle(exists_tag)
            tags = exists_tag[:tags_num]
            tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
            tags_args = self.handle_BPM_tag(tags_args)
            prompt = prompt_template.apply(**tags_args)
        else:
            # no strong tags, use all weak tags instead
            prompt = prompt_template.apply()

        return prompt

    def tags_to_desc(self, tag_list, tag_type) -> str:
        if self.lang == 'en':
            return tags_to_desc(tag_list)
        elif self.lang == 'zh':
            if tag_type == 'BPM':
                return tags_to_desc(tag_list, sep='、')
            translator = self.translate[tag_type]
            translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
            return tags_to_desc(translated_tag_list, sep='、')

    def handle_BPM_tag(self, tags_args):
        if "BPM" in tags_args and 'BPMDescript' in self.tag_types:
            bpm = tags_args["BPM"]
            del tags_args["BPM"]
            tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript')))
            for tag_type in tag_types_used:
                tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang)
        return tags_args

    def generate_description_ordinary(self, data, thresh = 0.3):
        if self.lang != 'en':
            raise ValueError(f'Language {self.lang} is not supported for ordinary description generation')
        description = f'a piece of music by {data["Artist"]}'

        # Add genre if available
        if data["Genre"] and random.random() > thresh:
            genres = ', '.join(data["Genre"])
            description += f', belonging to the {genres} genres'

        # Add tags if available
        if data["Tags"] and random.random() > thresh:
            tags = ', '.join(data["Tags"])
            description += f'. This track contains the tags:{tags}'

        # Add moods if available
        if data["Mood"] and random.random() > thresh:
            moods = ', '.join(data["Mood"])
            description += f'. This track conveys a {moods} mood.'

        # Add instruments if available
        if data["Instrument"] and random.random() > thresh:
            instruments = ', '.join(data["Instrument"])
            description += f'. and primarily features the following instruments: {instruments}'

        # Add a period to end the description
        description += '.'

        return description
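For orientation, a minimal sketch of driving AudioStockDataset end to end (every path and field value is a placeholder; the real schema is whatever the metadata JSONL carried, one JSON object per line with at least "path" and "duration"):

# Illustrative sketch, not part of the original file.
ds = AudioStockDataset(
    num_examples=-1,                 # <= 0: iterate the filtered metadata in order
    metadata_path='meta.jsonl',      # hypothetical path
    duration=60,                     # keeps items with 60 < duration < 360 seconds
    prompt_template_path='templates_en.txt',  # hypothetical path
    tag_types=['Genre', 'Mood', 'Instrument', 'BPM'],
    lang='en',
)
audio, description = ds[0]           # waveform (via safe_random_seek_read) and its text prompt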
codeclm/tokenizer/Flow1dVAE/model_1rvq.py CHANGED
@@ -270,8 +270,6 @@ class PromptCondAudioDiffusion(nn.Module):
         hubert_layer=None,
         ssl_layer=None,
         uncondition=True,
-        out_paint=False,
-        ssl_path='ckpt/encode-s12k.pt'
     ):
         super().__init__()
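A minimal instantiation sketch against the constructor tail this hunk shows (all values are placeholders, and the leading keyword names are inferred from the sibling model files, so treat this as an assumption rather than the repository's confirmed API):

# Illustrative sketch, not part of model_1rvq.py.
model = PromptCondAudioDiffusion(
    num_channels=64,                    # hypothetical
    unet_model_config_path='cfg.yaml',  # hypothetical; a model name or a config path is required
    snr_gamma=None,
    hubert_layer=-1,                    # hypothetical layer indices
    ssl_layer=5,
    uncondition=True,
)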
codeclm/tokenizer/Flow1dVAE/model_2rvq.py DELETED
@@ -1,774 +0,0 @@
import yaml
import random
import inspect
import numpy as np
from tqdm import tqdm
import typing as tp
from abc import ABC

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

from einops import repeat
from tools.torch_tools import wav_to_fbank

import diffusers
from diffusers.utils.torch_utils import randn_tensor
from diffusers import DDPMScheduler
from models.transformer_2d_flow import Transformer2DModel
from transformers import AutoFeatureExtractor, Wav2Vec2BertModel,HubertModel
# from tools.get_mulan import get_mulan
from third_party.wespeaker.extract_embd import XVECModel
# from libs.rvq2 import RVQEmbedding
from libs.rvq.descript_quantize3_4layer_freezelayer1 import ResidualVectorQuantize

from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
from models_gpt.models.gpt2_config import GPT2Config

from torch.cuda.amp import autocast


from our_MERT_BESTRQ.test import load_model

class HubertModelWithFinalProj(HubertModel):
    def __init__(self, config):
        super().__init__(config)

        # The final projection layer is only used for backward compatibility.
        # Following https://github.com/auspicious3000/contentvec/issues/6
        # Removing this layer is necessary to achieve the desired outcome.
        print("hidden_size:",config.hidden_size)
        print("classifier_proj_size:",config.classifier_proj_size)
        self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)


class SampleProcessor(torch.nn.Module):
    def project_sample(self, x: torch.Tensor):
        """Project the original sample to the 'space' where the diffusion will happen."""
        """Project back from diffusion space to the actual sample space."""
        return z

class Feature1DProcessor(SampleProcessor):
    def __init__(self, dim: int = 100, power_std = 1., \
            num_samples: int = 100_000, cal_num_frames: int = 600):
        super().__init__()

        self.num_samples = num_samples
        self.dim = dim
        self.power_std = power_std
        self.cal_num_frames = cal_num_frames
        self.register_buffer('counts', torch.zeros(1))
        self.register_buffer('sum_x', torch.zeros(dim))
        self.register_buffer('sum_x2', torch.zeros(dim))
        self.register_buffer('sum_target_x2', torch.zeros(dim))
        self.counts: torch.Tensor
        self.sum_x: torch.Tensor
        self.sum_x2: torch.Tensor

    @property
    def mean(self):
        mean = self.sum_x / self.counts
        if(self.counts < 10):
            mean = torch.zeros_like(mean)
        return mean

    @property
    def std(self):
        std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
        if(self.counts < 10):
            std = torch.ones_like(std)
        return std

    @property
    def target_std(self):
        return 1

    def project_sample(self, x: torch.Tensor):
        assert x.dim() == 3
        if self.counts.item() < self.num_samples:
            self.counts += len(x)
            self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
            self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
        rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
        x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
        return x

    def return_sample(self, x: torch.Tensor):
        assert x.dim() == 3
        rescale = (self.std / self.target_std) ** self.power_std
        # print(rescale, self.mean)
        x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
        return x

def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
    if(prior_text_encoder_hidden_states.shape[1]<len_size):
        prior_text_encoder_hidden_states = torch.cat([prior_text_encoder_hidden_states, \
            torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], \
            prior_text_encoder_hidden_states.shape[2], device=prior_text_mask.device, \
            dtype=prior_text_encoder_hidden_states.dtype)],1)
        prior_text_mask = torch.cat([prior_text_mask, torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], device=prior_text_mask.device, dtype=prior_text_mask.dtype)],1)
    else:
        prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:,0:len_size]
        prior_text_mask = prior_text_mask[:,0:len_size]
    prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.permute(0,2,1).contiguous()
    return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds

class BASECFM(torch.nn.Module, ABC):
    def __init__(
        self,
        estimator,
        mlp,
        ssl_layer
    ):
        super().__init__()
        self.sigma_min = 1e-4

        self.estimator = estimator
        self.mlp = mlp
        self.ssl_layer = ssl_layer

    @torch.inference_mode()
    def forward(self, mu, n_timesteps, temperature=1.0):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_channels, mel_timesteps, n_feats)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_channels, mel_timesteps, n_feats)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, t_span=t_span)

    def solve_euler(self, x, latent_mask_input,incontext_x, incontext_length, t_span, mu,attention_mask, guidance_scale):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_channels, mel_timesteps, n_feats)
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        noise = x.clone()

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        for step in tqdm(range(1, len(t_span))):
            # print("incontext_x.shape:",incontext_x.shape)
            # print("noise.shape:",noise.shape)
            # print("t.shape:",t.shape)
            x[:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,0:incontext_length,:] + t * incontext_x[:,0:incontext_length,:]
            if(guidance_scale > 1.0):

                model_input = torch.cat([ \
                    torch.cat([latent_mask_input, latent_mask_input], 0), \
                    torch.cat([incontext_x, incontext_x], 0), \
                    torch.cat([torch.zeros_like(mu), mu], 0), \
                    torch.cat([x, x], 0), \
                    ], 2)
                timestep=t.unsqueeze(-1).repeat(2)

                dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
                dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0)
                dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond)
            else:
                model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2)
                timestep=t.unsqueeze(-1)
                dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state

            dphi_dt = dphi_dt[: ,:, -x.shape[2]:]
            # print("dphi_dt.shape:",dphi_dt.shape)
            # print("x.shape:",x.shape)

            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1]

    def projection_loss(self,hidden_proj, bestrq_emb):
        bsz = hidden_proj.shape[0]

        hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
        bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)

        proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
        proj_loss = 1+proj_loss.mean()

        return proj_loss

    def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_channels, mel_timesteps, n_feats)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_channels, mel_timesteps, n_feats)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_channels, mel_timesteps, n_feats)
        """
        b = mu[0].shape[0]
        len_x = x1.shape[2]
        # random timestep
        if(validation_mode):
            t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
        else:
            t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z
        # print("y.shape:",y.shape)
        #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state
        model_input = torch.cat([*mu,y], 2)
        t=t.squeeze(-1).squeeze(-1)
        # print("model_input.shape:",model_input.shape)
        # print("attention_mask.shape:",attention_mask.shape)
        out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True)
        hidden_layer = out.hidden_states[self.ssl_layer]
        hidden_proj = self.mlp(hidden_layer)
        # print("hidden_proj.shape:",hidden_proj.shape)
        # print("mert_emb.shape:",mert_emb.shape)
        # exit()


        out = out.last_hidden_state

        out=out[:,:,-len_x:]
        # out=self.proj_out(out)

        weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
        # print("out.shape",out.shape)
        # print("u.shape",u.shape)
        loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
        # print("hidden_proj.shape:",hidden_proj.shape)
        # print("wav2vec_embeds.shape:",wav2vec_embeds.shape)
        loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
        loss = loss_re + loss_cos * 0.5
        # print("loss_cos:",loss_cos,loss_cos.device)
        print("loss:",loss,loss.device)
        # exit()
        return loss, loss_re, loss_cos

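Between the two classes of this deleted file, an editorial aside: the conditional flow-matching algebra above is compact enough to verify numerically. A self-contained sketch (toy tensors only; the two formulas are the ones in compute_loss and solve_euler, everything else is made up for illustration):

# Illustrative sketch, not part of model_2rvq.py.
import torch

sigma_min = 1e-4
x1 = torch.randn(2, 8, 4)   # toy "data" latents
z = torch.randn_like(x1)    # noise sample from p(x_0)
t = torch.rand(2, 1, 1)

y = (1 - (1 - sigma_min) * t) * z + t * x1  # interpolant, as in compute_loss
u = x1 - (1 - sigma_min) * z                # target velocity dy/dt

# One Euler step with a perfect estimator lands exactly on the interpolant
# at t + dt, which is why solve_euler can integrate from t=0 to t=1 in coarse steps.
dt = 0.1
y_next = y + dt * u
expected = (1 - (1 - sigma_min) * (t + dt)) * z + (t + dt) * x1
assert torch.allclose(y_next, expected, atol=1e-5)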
class PromptCondAudioDiffusion(nn.Module):
    def __init__(
        self,
        num_channels,
        unet_model_name=None,
        unet_model_config_path=None,
        snr_gamma=None,
        hubert_layer=None,
        ssl_layer=None,
        uncondition=True,
        out_paint=False,
    ):
        super().__init__()

        assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"

        self.unet_model_name = unet_model_name
        self.unet_model_config_path = unet_model_config_path
        self.snr_gamma = snr_gamma
        self.uncondition = uncondition
        self.num_channels = num_channels
        self.hubert_layer = hubert_layer
        self.ssl_layer = ssl_layer

        # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
        self.normfeat = Feature1DProcessor(dim=64)

        self.sample_rate = 48000
        self.num_samples_perseg = self.sample_rate * 20 // 1000
        self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
        self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
        # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
        # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
        self.bestrq = load_model(
            model_dir='path/to/our-MERT/mert_fairseq',
            checkpoint_dir='checkpoint-120000.pt',
        )
        self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
        self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
        for v in self.bestrq.parameters():v.requires_grad = False
        self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 2, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
        # for v in self.rvq_bestrq_emb.parameters():
        #     print(v)
        freeze_parameters='quantizers.0'
        for name, param in self.rvq_bestrq_emb.named_parameters():
            if freeze_parameters in name:
                param.requires_grad = False
                print("Freezing RVQ parameters:", name)
        self.hubert = HubertModelWithFinalProj.from_pretrained("huggingface_cache/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
        for v in self.hubert.parameters():v.requires_grad = False
        self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
        # self.xvecmodel = XVECModel()
        config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
        unet = GPT2Model(config)
        mlp = nn.Sequential(
            nn.Linear(1200, 1024),
            nn.SiLU(),
            nn.Linear(1024, 1024),
            nn.SiLU(),
            nn.Linear(1024, 768)
        )
        self.set_from = "random"
        self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
        self.mask_emb = torch.nn.Embedding(3, 48)
        print("Transformer initialized from pretrain.")
        torch.cuda.empty_cache()
        # self.unet.set_attn_processor(AttnProcessor2_0())
        # self.unet.set_use_memory_efficient_attention_xformers(True)

        # self.start_embedding = nn.Parameter(torch.randn(1,1024))
        # self.end_embedding = nn.Parameter(torch.randn(1,1024))

    def compute_snr(self, timesteps):
        """
        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
        """
        alphas_cumprod = self.noise_scheduler.alphas_cumprod
        sqrt_alphas_cumprod = alphas_cumprod**0.5
        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

        # Expand the tensors.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
        while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
        alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
        while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
        sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

        # Compute SNR.
        snr = (alpha / sigma) ** 2
        return snr

    def preprocess_audio(self, input_audios, threshold=0.9):
        assert len(input_audios.shape) == 2, input_audios.shape
        norm_value = torch.ones_like(input_audios[:,0])
        max_volume = input_audios.abs().max(dim=-1)[0]
        norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
        return input_audios/norm_value.unsqueeze(-1)

    def extract_wav2vec_embeds(self, input_audios,output_len):
        wav2vec_stride = 2

        wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024
        # print(wav2vec_embeds)
        # print("audio.shape:",input_audios.shape)
        wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer]
        # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape)
        wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
        return wav2vec_embeds_last

    def extract_mert_embeds(self, input_audios):
        prompt_stride = 3
        inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
        input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype)
        prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024
        mert_emb= prompt_embeds[-1]
        mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1)

        return mert_emb

    def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer):
        self.bestrq.eval()
        # print("audio shape:",input_audio_0.shape)
        input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
        # print("input_wav_mean.shape:",input_wav_mean.shape)
        # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device)
        input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True)
        layer_results = input_wav_mean['layer_results']
        # print("layer_results.shape:",layer_results[layer].shape)
        bestrq_emb = layer_results[layer]
        bestrq_emb = bestrq_emb.permute(0,2,1).contiguous()
        #[b,t,1024] t=t/960
        #35.84s->batch,896,1024
        return bestrq_emb


    def extract_spk_embeds(self, input_audios):
        spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
        spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
        return spk_embeds

    def extract_lyric_feats(self, lyric):
        with torch.no_grad():
            try:
                text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False)
            except:
                text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False)
            text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
            text_mask = text_mask.to(self.device)
            text_encoder_hidden_states, text_mask, text_prompt_embeds = \
                pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
            text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous()
            return text_encoder_hidden_states, text_mask

    def extract_energy_bar(self, input_audios):
        if(input_audios.shape[-1] % self.num_samples_perseg > 0):
            energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg)
        else:
            energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg)
        energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T
        energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int()
        energy_embedding = self.energy_embedding(energy_bar)
        energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t
        return energy_embedding

    def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \
                additional_feats = ['spk', 'lyric'], \
                train_rvq=True, train_ssl=False,layer=5):
        if not hasattr(self,"device"):
            self.device = input_audios.device
        if not hasattr(self,"dtype"):
            self.dtype = input_audios.dtype
        device = self.device
        input_audio_0 = input_audios[:,0,:]
        input_audio_1 = input_audios[:,1,:]
        input_audio_0 = self.preprocess_audio(input_audio_0)
        input_audio_1 = self.preprocess_audio(input_audio_1)
        input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0
        # energy_embedding = self.extract_energy_bar(input_audios)
        # print("energy_embedding.shape:",energy_embedding.shape)
        # with autocast(enabled=False):
        if(train_ssl):
            self.wav2vec.train()
            wav2vec_embeds = self.extract_wav2vec_embeds(input_audios)
            self.clap_embd_extractor.train()
            prompt_embeds = self.extract_mert_embeds(input_audios)
            if('spk' in additional_feats):
                self.xvecmodel.train()
                spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1)
        else:
            with torch.no_grad():
                with autocast(enabled=False):
                    bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
                    # mert_emb = self.extract_mert_embeds(input_audios_mert)

                    wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2])

                bestrq_emb = bestrq_emb.detach()
        if('lyric' in additional_feats):
            text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
        else:
            text_encoder_hidden_states, text_mask = None, None


        if(train_rvq):
            random_num=random.random()
            if(random_num<0.6):
                rvq_layer = 1
            elif(random_num<0.8):
                rvq_layer = 2
            else:
                rvq_layer = 4
            quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb,n_quantizers=rvq_layer) # b,d,t
        else:
            bestrq_emb = bestrq_emb.float()
            self.rvq_bestrq_emb.eval()
            # with autocast(enabled=False):
            quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
            commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
            codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
            quantized_bestrq_emb = quantized_bestrq_emb.detach()

        commitment_loss = commitment_loss_bestrq_emb
        codebook_loss = codebook_loss_bestrq_emb


        alpha=1
        quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha)

        # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
        # print("latent_masks.shape:",latent_masks.shape)
        # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)



        scenario = np.random.choice(['start_seg', 'other_seg'])
        if(scenario == 'other_seg'):
            for binx in range(input_audios.shape[0]):
                # latent_masks[binx,0:64] = 1
                latent_masks[binx,0:random.randint(64,128)] = 1
        quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
        # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
        # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape)
        # print("latent_masks.shape:",latent_masks.shape)
        quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
            + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)




        if self.uncondition:
            mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
            if len(mask_indices) > 0:
                quantized_bestrq_emb[mask_indices] = 0
        # print("latents.shape:",latents.shape)
        latents = latents.permute(0,2,1).contiguous()
        latents = self.normfeat.project_sample(latents)
        latents = latents.permute(0,2,1).contiguous()
        incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
        attention_mask=(latent_masks > 0.5)
        B, L = attention_mask.size()
        attention_mask = attention_mask.view(B, 1, L)
        attention_mask = attention_mask * attention_mask.transpose(-1, -2)
        attention_mask = attention_mask.unsqueeze(1)
        # print("incontext_latents.shape:",incontext_latents.shape)
        # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
        latent_mask_input = self.mask_emb(latent_masks)
        #64+48+64+1024
        loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode)
        return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()

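An editorial aside on the mask plumbing shared by forward above and inference_codes below: the per-frame mask (0 = padding, 1 = in-context prefix, 2 = frames to generate) is expanded into a pairwise self-attention mask. A shape-only sketch (toy values, no model weights involved):

# Illustrative sketch, not part of model_2rvq.py.
import torch

latent_masks = torch.tensor([[2, 2, 1, 0]])  # 4 toy frames
attention_mask = (latent_masks > 0.5)        # (B, L) frame validity
B, L = attention_mask.size()
attention_mask = attention_mask.view(B, 1, L)
attention_mask = attention_mask * attention_mask.transpose(-1, -2)  # (B, L, L): keep pairs of valid frames
attention_mask = attention_mask.unsqueeze(1)                        # (B, 1, L, L), broadcastable over heads
print(attention_mask.shape)  # torch.Size([1, 1, 4, 4])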
    def init_device_dtype(self, device, dtype):
        self.device = device
        self.dtype = dtype

    @torch.no_grad()
    def fetch_codes(self, input_audios, additional_feats,layer,rvq_num=1):
        input_audio_0 = input_audios[[0],:]
        input_audio_1 = input_audios[[1],:]
        input_audio_0 = self.preprocess_audio(input_audio_0)
        input_audio_1 = self.preprocess_audio(input_audio_1)

        self.bestrq.eval()

        # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
        # bestrq_middle = bestrq_middle.detach()
        # bestrq_last = bestrq_last.detach()
        bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
        bestrq_emb = bestrq_emb.detach()

        # self.rvq_bestrq_middle.eval()
        # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
        # self.rvq_bestrq_last.eval()
        # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t

        self.rvq_bestrq_emb.eval()
        quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
        codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
        # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
        # exit()


        if('spk' in additional_feats):
            self.xvecmodel.eval()
            spk_embeds = self.extract_spk_embeds(input_audios)
        else:
            spk_embeds = None

        # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
        # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
        # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
        return [codes_bestrq_emb], [bestrq_emb], spk_embeds
        # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds

    @torch.no_grad()
    def fetch_codes_batch(self, input_audios, additional_feats,layer,rvq_num=1):
        input_audio_0 = input_audios[:,0,:]
        input_audio_1 = input_audios[:,1,:]
        input_audio_0 = self.preprocess_audio(input_audio_0)
        input_audio_1 = self.preprocess_audio(input_audio_1)

        self.bestrq.eval()

        # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
        # bestrq_middle = bestrq_middle.detach()
        # bestrq_last = bestrq_last.detach()
        bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
        bestrq_emb = bestrq_emb.detach()

        # self.rvq_bestrq_middle.eval()
        # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
        # self.rvq_bestrq_last.eval()
        # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t

        self.rvq_bestrq_emb.eval()
        quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
        # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
        codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
        # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
        # exit()


        if('spk' in additional_feats):
            self.xvecmodel.eval()
            spk_embeds = self.extract_spk_embeds(input_audios)
        else:
            spk_embeds = None

        # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
        # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
        # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
        return [codes_bestrq_emb], [bestrq_emb], spk_embeds
        # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds

    @torch.no_grad()
    def fetch_codes_batch_ds(self, input_audios, additional_feats, layer, rvq_num=1, ds=250):
        input_audio_0 = input_audios[:,0,:]
        input_audio_1 = input_audios[:,1,:]
        input_audio_0 = self.preprocess_audio(input_audio_0)
        input_audio_1 = self.preprocess_audio(input_audio_1)

        self.bestrq.eval()

        # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
        # bestrq_middle = bestrq_middle.detach()
        # bestrq_last = bestrq_last.detach()
        bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
        bestrq_emb = bestrq_emb.detach()

        # self.rvq_bestrq_middle.eval()
        # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
        # self.rvq_bestrq_last.eval()
        # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t

        self.rvq_bestrq_emb.eval()
        bestrq_emb = torch.nn.functional.avg_pool1d(bestrq_emb, kernel_size=ds, stride=ds)
        quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
        # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
        codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
        # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
        # exit()


        if('spk' in additional_feats):
            self.xvecmodel.eval()
            spk_embeds = self.extract_spk_embeds(input_audios)
        else:
            spk_embeds = None

        # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
        # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
        # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
        return [codes_bestrq_emb], [bestrq_emb], spk_embeds
        # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds

    @torch.no_grad()
    def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127,
                        guidance_scale=2, num_steps=20,
                        disable_progress=True, scenario='start_seg'):
        classifier_free_guidance = guidance_scale > 1.0
        device = self.device
        dtype = self.dtype
        # codes_bestrq_middle, codes_bestrq_last = codes
        codes_bestrq_emb = codes[0]


        batch_size = codes_bestrq_emb.shape[0]


        quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
        # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
        quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
        print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
        # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)




        if('spk' in additional_feats):
            spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()

        num_frames = quantized_bestrq_emb.shape[1]

        num_channels_latents = self.num_channels
        shape = (batch_size, num_frames, 64)
        latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)



        latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
        latent_masks[:,0:latent_length] = 2
        if(scenario=='other_seg'):
            latent_masks[:,0:incontext_length] = 1



        quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
            + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
        true_latents = true_latents.permute(0,2,1).contiguous()
        true_latents = self.normfeat.project_sample(true_latents)
        true_latents = true_latents.permute(0,2,1).contiguous()
        incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
        incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]


        attention_mask=(latent_masks > 0.5)
        B, L = attention_mask.size()
        attention_mask = attention_mask.view(B, 1, L)
        attention_mask = attention_mask * attention_mask.transpose(-1, -2)
        attention_mask = attention_mask.unsqueeze(1)
        latent_mask_input = self.mask_emb(latent_masks)

        if('spk' in additional_feats):
            # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
            additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1)
        else:
            # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
            additional_model_input = torch.cat([quantized_bestrq_emb],1)

        temperature = 1.0
        t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
        latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale)

        latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:]
        latents = latents.permute(0,2,1).contiguous()
        latents = self.normfeat.return_sample(latents)
        # latents = latents.permute(0,2,1).contiguous()
        return latents

    @torch.no_grad()
    def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
                  disable_progress=True,layer=5,scenario='start_seg',rvq_num=1):
        codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer,rvq_num)

        latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
            guidance_scale=guidance_scale, num_steps=num_steps, \
            disable_progress=disable_progress,scenario=scenario)
        return latents

    @torch.no_grad()
    def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
                      disable_progress=True,layer=5,scenario='start_seg'):
        codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
        import time
        start = time.time()
        latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
            guidance_scale=guidance_scale, num_steps=num_steps, \
            disable_progress=disable_progress,scenario=scenario)
        return latents,time.time()-start

    def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
        divisor = 4
        shape = (batch_size, num_channels_latents, num_frames, 32)
        if(num_frames%divisor>0):
            num_frames = round(num_frames/float(divisor))*divisor
            shape = (batch_size, num_channels_latents, num_frames, 32)
        latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
        return latents
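One last editorial aside for this file: the classifier-free-guidance combination inside solve_euler is worth isolating, since the rest of the inference path is just from_codes, then solve_euler, then return_sample. A minimal sketch (toy vectors; guidance_scale=2 matches the default of inference_codes):

# Illustrative sketch, not part of model_2rvq.py.
import torch

guidance_scale = 2.0
v_uncond = torch.tensor([0.0, 1.0])  # estimator velocity with zeroed conditioning
v_cond = torch.tensor([1.0, 1.0])    # estimator velocity with the real conditioning

v = v_uncond + guidance_scale * (v_cond - v_uncond)
print(v)  # tensor([2., 1.]): a scale > 1 extrapolates past the conditional estimate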
codeclm/tokenizer/Flow1dVAE/model_4rvq.py
DELETED
@@ -1,774 +0,0 @@
-import yaml
-import random
-import inspect
-import numpy as np
-from tqdm import tqdm
-import typing as tp
-from abc import ABC
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchaudio
-
-from einops import repeat
-from tools.torch_tools import wav_to_fbank
-
-import diffusers
-from diffusers.utils.torch_utils import randn_tensor
-from diffusers import DDPMScheduler
-from models.transformer_2d_flow import Transformer2DModel
-from transformers import AutoFeatureExtractor, Wav2Vec2BertModel,HubertModel
-# from tools.get_mulan import get_mulan
-from third_party.wespeaker.extract_embd import XVECModel
-# from libs.rvq2 import RVQEmbedding
-from libs.rvq.descript_quantize3_4layer_freezelayer1 import ResidualVectorQuantize
-
-from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
-from models_gpt.models.gpt2_config import GPT2Config
-
-from torch.cuda.amp import autocast
-
-
-from our_MERT_BESTRQ.test import load_model
-
-class HubertModelWithFinalProj(HubertModel):
-    def __init__(self, config):
-        super().__init__(config)
-
-        # The final projection layer is only used for backward compatibility.
-        # Following https://github.com/auspicious3000/contentvec/issues/6
-        # Remove this layer is necessary to achieve the desired outcome.
-        print("hidden_size:",config.hidden_size)
-        print("classifier_proj_size:",config.classifier_proj_size)
-        self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
-
-
-class SampleProcessor(torch.nn.Module):
-    def project_sample(self, x: torch.Tensor):
-        """Project the original sample to the 'space' where the diffusion will happen."""
-        """Project back from diffusion space to the actual sample space."""
-        return z
-
-class Feature1DProcessor(SampleProcessor):
-    def __init__(self, dim: int = 100, power_std = 1., \
-        num_samples: int = 100_000, cal_num_frames: int = 600):
-        super().__init__()
-
-        self.num_samples = num_samples
-        self.dim = dim
-        self.power_std = power_std
-        self.cal_num_frames = cal_num_frames
-        self.register_buffer('counts', torch.zeros(1))
-        self.register_buffer('sum_x', torch.zeros(dim))
-        self.register_buffer('sum_x2', torch.zeros(dim))
-        self.register_buffer('sum_target_x2', torch.zeros(dim))
-        self.counts: torch.Tensor
-        self.sum_x: torch.Tensor
-        self.sum_x2: torch.Tensor
-
-    @property
-    def mean(self):
-        mean = self.sum_x / self.counts
-        if(self.counts < 10):
-            mean = torch.zeros_like(mean)
-        return mean
-
-    @property
-    def std(self):
-        std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
-        if(self.counts < 10):
-            std = torch.ones_like(std)
-        return std
-
-    @property
-    def target_std(self):
-        return 1
-
-    def project_sample(self, x: torch.Tensor):
-        assert x.dim() == 3
-        if self.counts.item() < self.num_samples:
-            self.counts += len(x)
-            self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
-            self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
-        rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
-        x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
-        return x
-
-    def return_sample(self, x: torch.Tensor):
-        assert x.dim() == 3
-        rescale = (self.std / self.target_std) ** self.power_std
-        # print(rescale, self.mean)
-        x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
-        return x
-
-def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
-    if(prior_text_encoder_hidden_states.shape[1]<len_size):
-        prior_text_encoder_hidden_states = torch.cat([prior_text_encoder_hidden_states, \
-            torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], \
-            prior_text_encoder_hidden_states.shape[2], device=prior_text_mask.device, \
-            dtype=prior_text_encoder_hidden_states.dtype)],1)
-        prior_text_mask = torch.cat([prior_text_mask, torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], device=prior_text_mask.device, dtype=prior_text_mask.dtype)],1)
-    else:
-        prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:,0:len_size]
-        prior_text_mask = prior_text_mask[:,0:len_size]
-    prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.permute(0,2,1).contiguous()
-    return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds
-
-class BASECFM(torch.nn.Module, ABC):
-    def __init__(
-        self,
-        estimator,
-        mlp,
-        ssl_layer
-    ):
-        super().__init__()
-        self.sigma_min = 1e-4
-
-        self.estimator = estimator
-        self.mlp = mlp
-        self.ssl_layer = ssl_layer
-
-    @torch.inference_mode()
-    def forward(self, mu, n_timesteps, temperature=1.0):
-        """Forward diffusion
-
-        Args:
-            mu (torch.Tensor): output of encoder
-                shape: (batch_size, n_channels, mel_timesteps, n_feats)
-            n_timesteps (int): number of diffusion steps
-            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
-
-        Returns:
-            sample: generated mel-spectrogram
-                shape: (batch_size, n_channels, mel_timesteps, n_feats)
-        """
-        z = torch.randn_like(mu) * temperature
-        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
-        return self.solve_euler(z, t_span=t_span)
-
-    def solve_euler(self, x, latent_mask_input,incontext_x, incontext_length, t_span, mu,attention_mask, guidance_scale):
-        """
-        Fixed euler solver for ODEs.
-        Args:
-            x (torch.Tensor): random noise
-            t_span (torch.Tensor): n_timesteps interpolated
-                shape: (n_timesteps + 1,)
-            mu (torch.Tensor): output of encoder
-                shape: (batch_size, n_channels, mel_timesteps, n_feats)
-        """
-        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
-        noise = x.clone()
-
-        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
-        # Or in future might add like a return_all_steps flag
-        sol = []
-
-        for step in tqdm(range(1, len(t_span))):
-            print("incontext_x.shape:",incontext_x.shape)
-            print("noise.shape:",noise.shape)
-            print("t.shape:",t.shape)
-            x[:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,0:incontext_length,:] + t * incontext_x[:,0:incontext_length,:]
-            if(guidance_scale > 1.0):
-
-                model_input = torch.cat([ \
-                    torch.cat([latent_mask_input, latent_mask_input], 0), \
-                    torch.cat([incontext_x, incontext_x], 0), \
-                    torch.cat([torch.zeros_like(mu), mu], 0), \
-                    torch.cat([x, x], 0), \
-                    ], 2)
-                timestep=t.unsqueeze(-1).repeat(2)
-
-                dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
-                dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0)
-                dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond)
-            else:
-                model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2)
-                timestep=t.unsqueeze(-1)
-                dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
-
-            dphi_dt = dphi_dt[: ,:, -x.shape[2]:]
-            print("dphi_dt.shape:",dphi_dt.shape)
-            print("x.shape:",x.shape)
-
-            x = x + dt * dphi_dt
-            t = t + dt
-            sol.append(x)
-            if step < len(t_span) - 1:
-                dt = t_span[step + 1] - t
-
-        return sol[-1]
-
-    def projection_loss(self,hidden_proj, bestrq_emb):
-        bsz = hidden_proj.shape[0]
-
-        hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
-        bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)
-
-        proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
-        proj_loss = 1+proj_loss.mean()
-
-        return proj_loss
-
-    def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False):
-        """Computes diffusion loss
-
-        Args:
-            x1 (torch.Tensor): Target
-                shape: (batch_size, n_channels, mel_timesteps, n_feats)
-            mu (torch.Tensor): output of encoder
-                shape: (batch_size, n_channels, mel_timesteps, n_feats)
-
-        Returns:
-            loss: conditional flow matching loss
-            y: conditional flow
-                shape: (batch_size, n_channels, mel_timesteps, n_feats)
-        """
-        b = mu[0].shape[0]
-        len_x = x1.shape[2]
-        # random timestep
-        if(validation_mode):
-            t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
-        else:
-            t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)
-        # sample noise p(x_0)
-        z = torch.randn_like(x1)
-
-        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
-        u = x1 - (1 - self.sigma_min) * z
-        # print("y.shape:",y.shape)
-        #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state
-        model_input = torch.cat([*mu,y], 2)
-        t=t.squeeze(-1).squeeze(-1)
-        # print("model_input.shape:",model_input.shape)
-        # print("attention_mask.shape:",attention_mask.shape)
-        out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True)
-        hidden_layer = out.hidden_states[self.ssl_layer]
-        hidden_proj = self.mlp(hidden_layer)
-        # print("hidden_proj.shape:",hidden_proj.shape)
-        # print("mert_emb.shape:",mert_emb.shape)
-        # exit()
-
-
-        out = out.last_hidden_state
-
-        out=out[:,:,-len_x:]
-        # out=self.proj_out(out)
-
-        weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
-        # print("out.shape",out.shape)
-        # print("u.shape",u.shape)
-        loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
-        # print("hidden_proj.shape:",hidden_proj.shape)
-        # print("wav2vec_embeds.shape:",wav2vec_embeds.shape)
-        loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
-        loss = loss_re + loss_cos * 0.5
-        # print("loss_cos:",loss_cos,loss_cos.device)
-        print("loss:",loss,loss.device)
-        # exit()
-        return loss, loss_re, loss_cos
-
-class PromptCondAudioDiffusion(nn.Module):
-    def __init__(
-        self,
-        num_channels,
-        unet_model_name=None,
-        unet_model_config_path=None,
-        snr_gamma=None,
-        hubert_layer=None,
-        ssl_layer=None,
-        uncondition=True,
-        out_paint=False,
-    ):
-        super().__init__()
-
-        assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
-
-        self.unet_model_name = unet_model_name
-        self.unet_model_config_path = unet_model_config_path
-        self.snr_gamma = snr_gamma
-        self.uncondition = uncondition
-        self.num_channels = num_channels
-        self.hubert_layer = hubert_layer
-        self.ssl_layer = ssl_layer
-
-        # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
-        self.normfeat = Feature1DProcessor(dim=64)
-
-        self.sample_rate = 48000
-        self.num_samples_perseg = self.sample_rate * 20 // 1000
-        self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
-        self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
-        # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
-        # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
-        self.bestrq = load_model(
-            model_dir='path/to/our-MERT/mert_fairseq',
-            checkpoint_dir='checkpoint-120000.pt',
-        )
-        self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
-        self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
-        for v in self.bestrq.parameters():v.requires_grad = False
-        self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
-        # for v in self.rvq_bestrq_emb.parameters():
-        #     print(v)
-        freeze_parameters='quantizers.0'
-        for name, param in self.rvq_bestrq_emb.named_parameters():
-            if freeze_parameters in name:
-                param.requires_grad = False
-                print("Freezing RVQ parameters:", name)
-        self.hubert = HubertModelWithFinalProj.from_pretrained("huggingface_cache/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
-        for v in self.hubert.parameters():v.requires_grad = False
-        self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
-        # self.xvecmodel = XVECModel()
-        config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
-        unet = GPT2Model(config)
-        mlp = nn.Sequential(
-            nn.Linear(1200, 1024),
-            nn.SiLU(),
-            nn.Linear(1024, 1024),
-            nn.SiLU(),
-            nn.Linear(1024, 768)
-        )
-        self.set_from = "random"
-        self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
-        self.mask_emb = torch.nn.Embedding(3, 48)
-        print("Transformer initialized from pretrain.")
-        torch.cuda.empty_cache()
-        # self.unet.set_attn_processor(AttnProcessor2_0())
-        # self.unet.set_use_memory_efficient_attention_xformers(True)
-
-        # self.start_embedding = nn.Parameter(torch.randn(1,1024))
-        # self.end_embedding = nn.Parameter(torch.randn(1,1024))
-
-    def compute_snr(self, timesteps):
-        """
-        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
-        """
-        alphas_cumprod = self.noise_scheduler.alphas_cumprod
-        sqrt_alphas_cumprod = alphas_cumprod**0.5
-        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
-
-        # Expand the tensors.
-        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
-        sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
-        while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
-            sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
-        alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
-
-        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
-        while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
-            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
-        sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
-
-        # Compute SNR.
-        snr = (alpha / sigma) ** 2
-        return snr
-
-    def preprocess_audio(self, input_audios, threshold=0.9):
-        assert len(input_audios.shape) == 2, input_audios.shape
-        norm_value = torch.ones_like(input_audios[:,0])
-        max_volume = input_audios.abs().max(dim=-1)[0]
-        norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
-        return input_audios/norm_value.unsqueeze(-1)
-
-    def extract_wav2vec_embeds(self, input_audios,output_len):
-        wav2vec_stride = 2
-
-        wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024
-        # print(wav2vec_embeds)
-        # print("audio.shape:",input_audios.shape)
-        wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer]
-        # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape)
-        wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
-        return wav2vec_embeds_last
-
-    def extract_mert_embeds(self, input_audios):
-        prompt_stride = 3
-        inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
-        input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype)
-        prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024
-        mert_emb= prompt_embeds[-1]
-        mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1)
-
-        return mert_emb
-
-    def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer):
-        self.bestrq.eval()
-        # print("audio shape:",input_audio_0.shape)
-        input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
-        # print("input_wav_mean.shape:",input_wav_mean.shape)
-        # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device)
-        input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True)
-        layer_results = input_wav_mean['layer_results']
-        # print("layer_results.shape:",layer_results[layer].shape)
-        bestrq_emb = layer_results[layer]
-        bestrq_emb = bestrq_emb.permute(0,2,1).contiguous()
-        #[b,t,1024] t=t/960
-        #35.84s->batch,896,1024
-        return bestrq_emb
-
-
-    def extract_spk_embeds(self, input_audios):
-        spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
-        spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
-        return spk_embeds
-
-    def extract_lyric_feats(self, lyric):
-        with torch.no_grad():
-            try:
-                text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False)
-            except:
-                text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False)
-            text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
-            text_mask = text_mask.to(self.device)
-            text_encoder_hidden_states, text_mask, text_prompt_embeds = \
-                pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
-            text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous()
-            return text_encoder_hidden_states, text_mask
-
-    def extract_energy_bar(self, input_audios):
-        if(input_audios.shape[-1] % self.num_samples_perseg > 0):
-            energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg)
-        else:
-            energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg)
-        energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T
-        energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int()
-        energy_embedding = self.energy_embedding(energy_bar)
-        energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t
-        return energy_embedding
-
-    def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \
-        additional_feats = ['spk', 'lyric'], \
-        train_rvq=True, train_ssl=False,layer=5):
-        if not hasattr(self,"device"):
-            self.device = input_audios.device
-        if not hasattr(self,"dtype"):
-            self.dtype = input_audios.dtype
-        device = self.device
-        input_audio_0 = input_audios[:,0,:]
-        input_audio_1 = input_audios[:,1,:]
-        input_audio_0 = self.preprocess_audio(input_audio_0)
-        input_audio_1 = self.preprocess_audio(input_audio_1)
-        input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0
-        # energy_embedding = self.extract_energy_bar(input_audios)
-        # print("energy_embedding.shape:",energy_embedding.shape)
-        # with autocast(enabled=False):
-        if(train_ssl):
-            self.wav2vec.train()
-            wav2vec_embeds = self.extract_wav2vec_embeds(input_audios)
-            self.clap_embd_extractor.train()
-            prompt_embeds = self.extract_mert_embeds(input_audios)
-            if('spk' in additional_feats):
-                self.xvecmodel.train()
-                spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1)
-        else:
-            with torch.no_grad():
-                with autocast(enabled=False):
-                    bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
-                    # mert_emb = self.extract_mert_embeds(input_audios_mert)
-
-                    wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2])
-
-                bestrq_emb = bestrq_emb.detach()
-        if('lyric' in additional_feats):
-            text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
-        else:
-            text_encoder_hidden_states, text_mask = None, None
-
-
-        if(train_rvq):
-            random_num=random.random()
-            if(random_num<0.6):
-                rvq_layer = 1
-            elif(random_num<0.8):
-                rvq_layer = 2
-            else:
-                rvq_layer = 4
-            quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb,n_quantizers=rvq_layer) # b,d,t
-        else:
-            bestrq_emb = bestrq_emb.float()
-            self.rvq_bestrq_emb.eval()
-            # with autocast(enabled=False):
-            quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
-            commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
-            codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
-            quantized_bestrq_emb = quantized_bestrq_emb.detach()
-
-        commitment_loss = commitment_loss_bestrq_emb
-        codebook_loss = codebook_loss_bestrq_emb
-
-
-        alpha=1
-        quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha)
-
-        # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
-        # print("latent_masks.shape:",latent_masks.shape)
-        # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
-
-
-
-        scenario = np.random.choice(['start_seg', 'other_seg'])
-        if(scenario == 'other_seg'):
-            for binx in range(input_audios.shape[0]):
-                # latent_masks[binx,0:64] = 1
-                latent_masks[binx,0:random.randint(64,128)] = 1
-        quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
-        # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
-        # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape)
-        # print("latent_masks.shape:",latent_masks.shape)
-        quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
-            + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
-
-
-
-
-        if self.uncondition:
-            mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
-            if len(mask_indices) > 0:
-                quantized_bestrq_emb[mask_indices] = 0
-        # print("latents.shape:",latents.shape)
-        latents = latents.permute(0,2,1).contiguous()
-        latents = self.normfeat.project_sample(latents)
-        latents = latents.permute(0,2,1).contiguous()
-        incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
-        attention_mask=(latent_masks > 0.5)
-        B, L = attention_mask.size()
-        attention_mask = attention_mask.view(B, 1, L)
-        attention_mask = attention_mask * attention_mask.transpose(-1, -2)
-        attention_mask = attention_mask.unsqueeze(1)
-        # print("incontext_latents.shape:",incontext_latents.shape)
-        # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
-        latent_mask_input = self.mask_emb(latent_masks)
-        #64+48+64+1024
-        loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode)
-        return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()
-
-    def init_device_dtype(self, device, dtype):
-        self.device = device
-        self.dtype = dtype
-
-    @torch.no_grad()
-    def fetch_codes(self, input_audios, additional_feats,layer,rvq_num=1):
-        input_audio_0 = input_audios[[0],:]
-        input_audio_1 = input_audios[[1],:]
-        input_audio_0 = self.preprocess_audio(input_audio_0)
-        input_audio_1 = self.preprocess_audio(input_audio_1)
-
-        self.bestrq.eval()
-
-        # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
-        # bestrq_middle = bestrq_middle.detach()
-        # bestrq_last = bestrq_last.detach()
-        bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
-        bestrq_emb = bestrq_emb.detach()
-
-        # self.rvq_bestrq_middle.eval()
-        # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
-        # self.rvq_bestrq_last.eval()
-        # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
-
-        self.rvq_bestrq_emb.eval()
-        quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
-        codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
-        # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
-        # exit()
-
-
-        if('spk' in additional_feats):
-            self.xvecmodel.eval()
-            spk_embeds = self.extract_spk_embeds(input_audios)
-        else:
-            spk_embeds = None
-
-        # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
-        # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
-        # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
-        return [codes_bestrq_emb], [bestrq_emb], spk_embeds
-        # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
-
-    @torch.no_grad()
-    def fetch_codes_batch(self, input_audios, additional_feats,layer,rvq_num=1):
-        input_audio_0 = input_audios[:,0,:]
-        input_audio_1 = input_audios[:,1,:]
-        input_audio_0 = self.preprocess_audio(input_audio_0)
-        input_audio_1 = self.preprocess_audio(input_audio_1)
-
-        self.bestrq.eval()
-
-        # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
-        # bestrq_middle = bestrq_middle.detach()
-        # bestrq_last = bestrq_last.detach()
-        bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
-        bestrq_emb = bestrq_emb.detach()
-
-        # self.rvq_bestrq_middle.eval()
-        # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
-        # self.rvq_bestrq_last.eval()
-        # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
-
-        self.rvq_bestrq_emb.eval()
-        quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
-        # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
-        codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
-        # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
-        # exit()
-
-
-        if('spk' in additional_feats):
-            self.xvecmodel.eval()
-            spk_embeds = self.extract_spk_embeds(input_audios)
-        else:
-            spk_embeds = None
-
-        # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
-        # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
-        # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
-        return [codes_bestrq_emb], [bestrq_emb], spk_embeds
-        # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
-
-    @torch.no_grad()
-    def fetch_codes_batch_ds(self, input_audios, additional_feats, layer, rvq_num=1, ds=250):
-        input_audio_0 = input_audios[:,0,:]
-        input_audio_1 = input_audios[:,1,:]
-        input_audio_0 = self.preprocess_audio(input_audio_0)
-        input_audio_1 = self.preprocess_audio(input_audio_1)
-
-        self.bestrq.eval()
-
-        # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
-        # bestrq_middle = bestrq_middle.detach()
-        # bestrq_last = bestrq_last.detach()
-        bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
-        bestrq_emb = bestrq_emb.detach()
-
-        # self.rvq_bestrq_middle.eval()
-        # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
-        # self.rvq_bestrq_last.eval()
-        # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
-
-        self.rvq_bestrq_emb.eval()
-        bestrq_emb = torch.nn.functional.avg_pool1d(bestrq_emb, kernel_size=ds, stride=ds)
-        quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
-        # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
-        codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
-        # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
-        # exit()
-
-
-        if('spk' in additional_feats):
-            self.xvecmodel.eval()
-            spk_embeds = self.extract_spk_embeds(input_audios)
-        else:
-            spk_embeds = None
-
-        # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
-        # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
-        # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
-        return [codes_bestrq_emb], [bestrq_emb], spk_embeds
-        # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
-
-    @torch.no_grad()
-    def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127,
-                   guidance_scale=2, num_steps=20,
-                   disable_progress=True, scenario='start_seg'):
-        classifier_free_guidance = guidance_scale > 1.0
-        device = self.device
-        dtype = self.dtype
-        # codes_bestrq_middle, codes_bestrq_last = codes
-        codes_bestrq_emb = codes[0]
-
-
-        batch_size = codes_bestrq_emb.shape[0]
-
-
-        quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
-        # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
-        quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
-        print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
-        # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
-
-
-
-
-        if('spk' in additional_feats):
-            spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
-
-        num_frames = quantized_bestrq_emb.shape[1]
-
-        num_channels_latents = self.num_channels
-        shape = (batch_size, num_frames, 64)
-        latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
-
-
-
-        latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
-        latent_masks[:,0:latent_length] = 2
-        if(scenario=='other_seg'):
-            latent_masks[:,0:incontext_length] = 1
-
-
-
-        quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
-            + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
-        true_latents = true_latents.permute(0,2,1).contiguous()
-        true_latents = self.normfeat.project_sample(true_latents)
-        true_latents = true_latents.permute(0,2,1).contiguous()
-        incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
-        incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
-
-
-        attention_mask=(latent_masks > 0.5)
-        B, L = attention_mask.size()
-        attention_mask = attention_mask.view(B, 1, L)
-        attention_mask = attention_mask * attention_mask.transpose(-1, -2)
-        attention_mask = attention_mask.unsqueeze(1)
-        latent_mask_input = self.mask_emb(latent_masks)
-
-        if('spk' in additional_feats):
-            # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
-            additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1)
-        else:
-            # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
-            additional_model_input = torch.cat([quantized_bestrq_emb],1)
-
-        temperature = 1.0
-        t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
-        latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale)
-
-        latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:]
-        latents = latents.permute(0,2,1).contiguous()
-        latents = self.normfeat.return_sample(latents)
-        # latents = latents.permute(0,2,1).contiguous()
-        return latents
-
-    @torch.no_grad()
-    def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
-                  disable_progress=True,layer=5,scenario='start_seg',rvq_num=1):
-        codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer,rvq_num)
-
-        latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
-            guidance_scale=guidance_scale, num_steps=num_steps, \
-            disable_progress=disable_progress,scenario=scenario)
-        return latents
-
-    @torch.no_grad()
-    def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
-                  disable_progress=True,layer=5,scenario='start_seg'):
-        codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
-        import time
-        start = time.time()
-        latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
-            guidance_scale=guidance_scale, num_steps=num_steps, \
-            disable_progress=disable_progress,scenario=scenario)
-        return latents,time.time()-start
-
-    def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
-        divisor = 4
-        shape = (batch_size, num_channels_latents, num_frames, 32)
-        if(num_frames%divisor>0):
-            num_frames = round(num_frames/float(divisor))*divisor
-            shape = (batch_size, num_channels_latents, num_frames, 32)
-        latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
-        return latents
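Note: the deleted model_4rvq.py trains a GPT-2 velocity estimator with conditional flow matching plus classifier-free guidance (see BASECFM.compute_loss and solve_euler above). A minimal sketch of just that math, where the velocity inputs stand in for the estimator outputs and are placeholders, not part of the original file:

import torch

sigma_min = 1e-4  # same constant as BASECFM.sigma_min

def cfm_training_pair(x1: torch.Tensor, t: torch.Tensor):
    # As in compute_loss: interpolant y_t = (1 - (1 - sigma_min) t) z + t x1,
    # with regression target u = x1 - (1 - sigma_min) z for the estimator.
    z = torch.randn_like(x1)
    y = (1 - (1 - sigma_min) * t) * z + t * x1
    u = x1 - (1 - sigma_min) * z
    return y, u

def cfg_euler_step(x, t, dt, v_uncond, v_cond, guidance_scale):
    # One Euler update with classifier-free guidance, as in solve_euler:
    # mix the unconditional and conditional velocities, then step x and t.
    v = v_uncond + guidance_scale * (v_cond - v_uncond)
    return x + dt * v, t + dt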
codeclm/tokenizer/Flow1dVAE/model_septoken.py
CHANGED
@@ -252,8 +252,6 @@ class PromptCondAudioDiffusion(nn.Module):
         unet_model_config_path=None,
         snr_gamma=None,
         uncondition=True,
-        out_paint=False,
-        ssl_path='ckpt/encode-s12k.pt'
     ):
         super().__init__()
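Note: after this change, PromptCondAudioDiffusion in model_septoken.py no longer accepts out_paint or ssl_path. A hypothetical before/after for a call site (the remaining arguments are illustrative, not taken from the diff):

# Before v1.5-beta:
#   model = PromptCondAudioDiffusion(num_channels, out_paint=False, ssl_path='ckpt/encode-s12k.pt')
# After v1.5-beta (passing either removed kwarg now raises TypeError):
#   model = PromptCondAudioDiffusion(num_channels)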
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_AS2M.yaml
DELETED
@@ -1,122 +0,0 @@
-# @package _group_
-
-common:
-  fp16: true
-  log_format: json
-  log_interval: 200
-  tensorboard_logdir: tb
-  min_loss_scale: 1e-6
-  fp16_no_flatten_grads: true
-  user_dir: ${env:PWD}
-  seed: 1
-
-checkpoint:
-  save_interval: 1
-  save_interval_updates: 10000
-  keep_interval_updates: 1
-  no_epoch_checkpoints: true
-
-task:
-  _name: mae_image_pretraining
-  data: unbalanced_train
-  rebuild_batches: true
-  key: source
-  precompute_mask_config: {}
-  downsr_16hz: true
-  audio_mae: true
-  h5_format: false
-  target_length: 1024
-  flexible_mask: false
-
-dataset:
-  num_workers: 10
-  batch_size: 12
-  skip_invalid_size_inputs_valid_test: true
-  required_batch_size_multiple: 1
-  disable_validation: true
-
-distributed_training:
-  distributed_world_size: 4
-  ddp_backend: c10d
-
-criterion:
-  _name: model
-  log_keys:
-    - ema_decay
-    - target_var
-    - pred_var
-    - model_norm
-    - ema_norm
-    - masked_pct
-
-optimization:
-  max_update: 400000
-  lr: [ 0.0005 ]
-  debug_param_names: true
-  clip_norm: 4
-
-optimizer:
-  _name: composite
-  dynamic_groups: true
-  groups:
-    default:
-      lr_float: 0.0005
-      optimizer:
-        _name: adam
-        adam_betas: [0.9,0.95]
-        weight_decay: 0.05
-      lr_scheduler:
-        _name: cosine
-        warmup_updates: 53333
-
-lr_scheduler: pass_through
-
-model:
-  _name: data2vec_multi
-
-  ema_decay: 0.9998
-  ema_end_decay: 0.99999
-  ema_anneal_end_step: 100000
-  instance_norm_target_layer: true
-  layer_norm_target_layer: false
-  layer_norm_targets: true
-  end_of_block_targets: false
-
-  depth: 12
-  average_top_k_layers: 12
-  clone_batch: 16
-
-  norm_eps: 1e-6
-
-  min_target_var: 0
-  min_pred_var: 0
-
-  encoder_dropout: 0
-  post_mlp_drop: 0
-  attention_dropout: 0
-  activation_dropout: 0
-
-  supported_modality: IMAGE
-  cls_loss: 1
-
-  ema_encoder_only: false
-
-  modalities:
-    image:
-      in_chans: 1
-      inverse_mask: true
-      mask_prob: 0.8
-      mask_prob_adjust: 0.07
-      mask_length: 5
-      mask_noise_std: 0.01
-      prenet_depth: 0
-      ema_local_encoder: true
-      num_extra_tokens: 1
-      init_extra_token_zero: false
-      use_alibi_encoder: false
-      decoder:
-        decoder_dim: 768
-        decoder_groups: 16
-        decoder_kernel: 3
-        decoder_layers: 6
-        input_dropout: 0
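Note: the deleted configs under mert_fairseq/config/pretrain are fairseq-hydra configs. A sketch of how a config like EAT_pretraining_AS2M was presumably launched; the entry point and override syntax are standard fairseq/hydra, while the data path is a placeholder:

import subprocess

# Launch fairseq pretraining from the repo's config directory (path from the diff above);
# "task.data=..." is a hypothetical dataset location passed as a hydra override.
subprocess.run(
    [
        "fairseq-hydra-train",
        "--config-dir", "our_MERT_BESTRQ/mert_fairseq/config/pretrain",
        "--config-name", "EAT_pretraining_AS2M",
        "task.data=/path/to/unbalanced_train",
    ],
    check=True,
)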
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_music_multinodes.yaml
DELETED
@@ -1,125 +0,0 @@
-# @package _group_
-
-common:
-  fp16: true
-  log_format: json
-  log_interval: 200
-  tensorboard_logdir: tb
-  min_loss_scale: 1e-6
-  fp16_no_flatten_grads: true
-  user_dir: ${env:PWD}
-  seed: 1
-
-checkpoint:
-  save_interval: 1
-  save_interval_updates: 10000
-  keep_interval_updates: 1000
-  no_epoch_checkpoints: true
-
-task:
-  _name: mae_image_pretraining
-  data: music4all_sh/
-  rebuild_batches: true
-  key: source
-  precompute_mask_config: {}
-  downsr_16hz: false
-  audio_mae: true
-  h5_format: false
-  target_length: 752
-  flexible_mask: false
-  sample_rate: 24000
-  fixed_duration: 30
-
-dataset:
-  num_workers: 10
-  batch_size: 12
-  skip_invalid_size_inputs_valid_test: true
-  required_batch_size_multiple: 1
-  disable_validation: true
-
-distributed_training:
-  distributed_world_size: 4
-  ddp_backend: c10d
-
-criterion:
-  _name: model
-  log_keys:
-    - ema_decay
-    - target_var
-    - pred_var
-    - model_norm
-    - ema_norm
-    - masked_pct
-
-optimization:
-  max_update: 400000
-  lr: [ 0.0001 ]
-  # debug_param_names: true
-  clip_norm: 4
-
-optimizer:
-  _name: composite
-  # dynamic_groups: true
-  groups:
-    default:
-      lr_float: 0.0005
-      optimizer:
-        _name: adam
-        adam_betas: [0.9,0.95]
-        weight_decay: 0.05
-      lr_scheduler:
-        _name: cosine
-        warmup_updates: 10000 # 53333
-
-lr_scheduler: pass_through
-
-model:
-  _name: data2vec_multi
-
-  ema_decay: 0.9998
-  ema_end_decay: 0.99999
-  ema_anneal_end_step: 100000
-  instance_norm_target_layer: true
-  layer_norm_target_layer: false
-  layer_norm_targets: true
-  end_of_block_targets: false
-
-  depth: 12
-  average_top_k_layers: 12
-  clone_batch: 16
-
-  norm_eps: 1e-6
-
-  min_target_var: 0
-  min_pred_var: 0
-
-  encoder_dropout: 0
-  post_mlp_drop: 0
-  attention_dropout: 0
-  activation_dropout: 0
-
-  supported_modality: IMAGE
-  cls_loss: 1
-
-  ema_encoder_only: false
-
-  modalities:
-    image:
-      in_chans: 1
-      inverse_mask: true
-      mask_prob: 0.8
-      mask_prob_adjust: 0.07
-      mask_length: 5
-      mask_noise_std: 0.01
-      prenet_depth: 0
-      ema_local_encoder: true
-      num_extra_tokens: 1
-      init_extra_token_zero: false
-      use_alibi_encoder: false
-      decoder:
-        decoder_dim: 768
-        decoder_groups: 16
-        decoder_kernel: 3
-        decoder_layers: 6
-        input_dropout: 0
-        target_length: 752
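Note: this music variant differs from EAT_pretraining_AS2M mainly in the data path, 24 kHz sampling with 30 s crops, target_length 1024 -> 752, checkpoint retention, warmup, and the lr entries. A small sketch that surfaces such differences mechanically, assuming both YAML files are available locally:

import yaml

def flatten(node, prefix=""):
    # Flatten nested dicts to dotted keys for a simple value-level diff.
    flat = {}
    for key, value in node.items():
        path = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat

with open("EAT_pretraining_AS2M.yaml") as fa, open("EAT_pretraining_music_multinodes.yaml") as fb:
    a, b = flatten(yaml.safe_load(fa)), flatten(yaml.safe_load(fb))

for key in sorted(set(a) | set(b)):
    if a.get(key) != b.get(key):
        print(f"{key}: {a.get(key)!r} -> {b.get(key)!r}")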
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M.yaml
DELETED
@@ -1,137 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 100
-  seed: 1337
-
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 5000
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-  # reset-dataloader: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sharding_data: -1 # data sharding
-  load_random_data_shard: false
-  sample_rate: 24000
-  # crop to 5s
-  # max_sample_size: 120000
-  # crop to 5.12s, refers to 384 tokens per audio, which can be divided by 8.
-  max_sample_size: 122880
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  # normalize: true # must be consistent with extractor_mode: layer_norm
-  normalize: false # must be consistent with extractor_mode: default (groupnorm)
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 900000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 1000000
-  lr: [0.0015]
-  clip_norm: 1.0
-  update_freq: [8]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 36000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-
-  final_dim: 128
-  encoder_layers: 24
-  encoder_embed_dim: 1024
-  encoder_ffn_embed_dim: 4096
-  encoder_attention_heads: 16
-  # default refers to group norm
-  extractor_mode: default
-  # extractor_mode: layer_norm
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-  encoder_layerdrop: 0.0
-  dropout_input: 0.0
-  dropout_features: 0.0
-  dropout: 0.0
-  attention_dropout: 0.0
-
-  layer_norm_first: true
-  feature_grad_mult: 1.0
-
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-  deepnorm: false
-  attention_relax: 32.0
-
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
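The cropping comment in this config (5.12 s corresponds to 384 tokens, divisible by 8) follows from the conv_feature_layers strides, whose product is the encoder's downsampling factor. A quick check using only values from the file:

import functools, operator

spec = '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
layers = eval(spec)  # fairseq stores the conv stack as (dim, kernel, stride) triples
downsample = functools.reduce(operator.mul, (stride for _, _, stride in layers))

print(downsample)            # 320, i.e. 24000 Hz / 320 = 75 frames per second
print(122880 / 24000)        # 5.12 s crops (max_sample_size / sample_rate)
print(122880 // downsample)  # 384 tokens per crop, divisible by 8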
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes.yaml
DELETED
@@ -1,139 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 100
-  seed: 1337
-  # model_parallel_size: 8
-  # amp: true
-
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 5000
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-  # reset-dataloader: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sharding_data: -1 # data sharding
-  load_random_data_shard: false
-  sample_rate: 24000
-  # crop to 5s
-  # max_sample_size: 120000
-  # crop to 5.12s, refers to 384 tokens per audio, which can be divided by 8.
-  max_sample_size: 122880
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  # normalize: true # must be consistent with extractor_mode: layer_norm
-  normalize: false # must be consistent with extractor_mode: default (groupnorm)
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 900000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 1000000
-  lr: [0.0015]
-  clip_norm: 1.0
-  update_freq: [8]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 36000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-
-  final_dim: 128
-  encoder_layers: 24
-  encoder_embed_dim: 1024
-  encoder_ffn_embed_dim: 4096
-  encoder_attention_heads: 16
-  # default refers to group norm
-  extractor_mode: default
-  # extractor_mode: layer_norm
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-  encoder_layerdrop: 0.0
-  dropout_input: 0.0
-  dropout_features: 0.0
-  dropout: 0.0
-  attention_dropout: 0.0
-
-  layer_norm_first: true
-  feature_grad_mult: 1.0
-
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-  deepnorm: false
-  attention_relax: 32.0
-
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: run
-  sweep:
-    dir: sweep
-    subdir: subdir
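For scale, the settings above imply roughly the following amount of audio per optimizer step, assuming max_tokens here counts raw waveform samples per GPU batch (as the sample-level crop sizes in the task section suggest):

max_tokens, sample_rate = 900_000, 24_000
world_size, update_freq = 64, 8

seconds_per_gpu_batch = max_tokens / sample_rate                    # 37.5 s
seconds_per_update = seconds_per_gpu_batch * world_size * update_freq
print(seconds_per_update / 3600)                                    # ~5.3 h of audio per update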
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes_debug1node.yaml
DELETED
@@ -1,138 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 100
-  seed: 1337
-  # amp: true
-
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 5000
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-  # reset-dataloader: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sharding_data: -1 # data sharding
-  load_random_data_shard: false
-  sample_rate: 24000
-  # crop to 5s
-  # max_sample_size: 120000
-  # crop to 5.12s, refers to 384 tokens per audio, which can be divided by 8.
-  max_sample_size: 122880
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  # normalize: true # must be consistent with extractor_mode: layer_norm
-  normalize: false # must be consistent with extractor_mode: default (groupnorm)
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 900000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 1000000
-  lr: [0.0015]
-  clip_norm: 1.0
-  update_freq: [8]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 36000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-
-  final_dim: 128
-  encoder_layers: 24
-  encoder_embed_dim: 1024
-  encoder_ffn_embed_dim: 4096
-  encoder_attention_heads: 16
-  # default refers to group norm
-  extractor_mode: default
-  # extractor_mode: layer_norm
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-  encoder_layerdrop: 0.0
-  dropout_input: 0.0
-  dropout_features: 0.0
-  dropout: 0.0
-  attention_dropout: 0.0
-
-  layer_norm_first: true
-  feature_grad_mult: 1.0
-
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-  deepnorm: false
-  attention_relax: 32.0
-
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: run
-  sweep:
-    dir: sweep
-    subdir: subdir
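The lr_scheduler block in these configs pairs a 32k-step linear warmup with fairseq's polynomial_decay down to max_update; with the scheduler's default power of 1.0 the decay is linear. A minimal sketch of that schedule:

def poly_decay_lr(step, peak_lr=0.0015, warmup=32_000, max_update=1_000_000,
                  end_lr=0.0, power=1.0):
    # Linear warmup, then polynomial decay (power=1.0 -> linear) to end_lr.
    if step < warmup:
        return peak_lr * step / warmup
    remaining = max(0.0, 1.0 - (step - warmup) / (max_update - warmup))
    return end_lr + (peak_lr - end_lr) * remaining ** power

for s in (0, 16_000, 32_000, 500_000, 1_000_000):
    print(s, poly_decay_lr(s))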
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes_debug2node.yaml
DELETED
@@ -1,139 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 100
-  seed: 1337
-  model_parallel_size: 8
-  # amp: true
-
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 5000
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-  # reset-dataloader: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sharding_data: -1 # data sharding
-  load_random_data_shard: false
-  sample_rate: 24000
-  # crop to 5s
-  # max_sample_size: 120000
-  # crop to 5.12s, refers to 384 tokens per audio, which can be divided by 8.
-  max_sample_size: 122880
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  # normalize: true # must be consistent with extractor_mode: layer_norm
-  normalize: false # must be consistent with extractor_mode: default (groupnorm)
-
-
-dataset:
-  num_workers: 6
-  max_tokens: null
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 1000000
-  lr: [0.0015]
-  clip_norm: 1.0
-  update_freq: [8]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 36000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-
-  final_dim: 128
-  encoder_layers: 24
-  encoder_embed_dim: 1024
-  encoder_ffn_embed_dim: 4096
-  encoder_attention_heads: 16
-  # default refers to group norm
-  extractor_mode: default
-  # extractor_mode: layer_norm
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-  encoder_layerdrop: 0.0
-  dropout_input: 0.0
-  dropout_features: 0.0
-  dropout: 0.0
-  attention_dropout: 0.0
-
-  layer_norm_first: true
-  feature_grad_mult: 1.0
-
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-  deepnorm: false
-  attention_relax: 32.0
-
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: run
-  sweep:
-    dir: sweep
-    subdir: subdir
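mask_prob and mask_length follow the HuBERT/wav2vec 2.0 convention: mask_prob is the chance that a frame is picked as the start of a mask_length-frame span, so spans overlap and actual coverage lands well below 0.8. A simplified version of that sampling (fairseq's compute_mask_indices has many more options):

import numpy as np

def sample_mask(num_frames, mask_prob=0.8, mask_length=5, rng=np.random):
    num_spans = int(mask_prob * num_frames / mask_length + rng.random())
    starts = rng.choice(num_frames - mask_length, num_spans, replace=False)
    mask = np.zeros(num_frames, dtype=bool)
    for s in starts:
        mask[s:s + mask_length] = True
    return mask

print(sample_mask(384).mean())  # typically ~0.55-0.65 coverage after overlaps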
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_orig.yaml
DELETED
@@ -1,135 +0,0 @@
-# @package _group_
-common:
-  fp16: true
-  log_format: json
-  log_interval: 100
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 5000
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sharding_data: 6
-  load_random_data_shard: false
-  sample_rate: 24000
-  # crop to 5s
-  # max_sample_size: 120000
-  # crop to 5.12s, refers to 384 tokens per audio, which can be divided by 8.
-  max_sample_size: 122880
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  # normalize: true # must be consistent with extractor_mode: layer_norm
-  normalize: false # must be consistent with extractor_mode: default (groupnorm)
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 900000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 400000
-  lr: [0.0015]
-  clip_norm: 1.0
-  update_freq: [8]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 36000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-
-  final_dim: 128
-  encoder_layers: 24
-  encoder_embed_dim: 1024
-  encoder_ffn_embed_dim: 4096
-  encoder_attention_heads: 16
-  # default refers to group norm
-  extractor_mode: default
-  # extractor_mode: layer_norm
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-  encoder_layerdrop: 0.0
-  dropout_input: 0.0
-  dropout_features: 0.0
-  dropout: 0.0
-  attention_dropout: 0.0
-
-  layer_norm_first: true
-  feature_grad_mult: 1.0
-
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-  deepnorm: false
-  attention_relax: 32.0
-
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
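logit_temp: 0.1 is the temperature of the HuBERT-style criterion used here: masked-frame projections are compared against the label (codeword) embeddings by cosine similarity, and the similarities are divided by the temperature before cross-entropy. A minimal sketch, not the fairseq implementation itself:

import torch
import torch.nn.functional as F

def masked_prediction_logits(proj_x, label_embs, logit_temp=0.1):
    # proj_x: (num_masked_frames, final_dim); label_embs: (num_classes, final_dim)
    sims = F.cosine_similarity(proj_x.unsqueeze(1), label_embs.unsqueeze(0), dim=-1)
    return sims / logit_temp  # small temperature sharpens the distribution

logits = masked_prediction_logits(torch.randn(10, 128), torch.randn(500, 128))
print(F.cross_entropy(logits, torch.randint(500, (10,))).item())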
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_tune.yaml
DELETED
@@ -1,137 +0,0 @@
-# @package _group_
-common:
-  fp16: true
-  log_format: json
-  log_interval: 100
-  seed: 1337
-
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 5000
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-  # reset-dataloader: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sharding_data: -1 # data sharding
-  load_random_data_shard: false
-  sample_rate: 24000
-  # crop to 5s
-  # max_sample_size: 120000
-  # crop to 5.12s, refers to 384 tokens per audio, which can be divided by 8.
-  max_sample_size: 122880
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  # normalize: true # must be consistent with extractor_mode: layer_norm
-  normalize: false # must be consistent with extractor_mode: default (groupnorm)
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 900000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 400000
-  lr: [0.0015]
-  clip_norm: 1.0
-  update_freq: [8]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-  # freeze_parameters: true
-  logit_temp: 0.1
-
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 36000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-
-  final_dim: 128
-  encoder_layers: 24
-  encoder_embed_dim: 1024
-  encoder_ffn_embed_dim: 4096
-  encoder_attention_heads: 16
-  # default refers to group norm
-  extractor_mode: default
-  # extractor_mode: layer_norm
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-  encoder_layerdrop: 0.0
-  dropout_input: 0.0
-  dropout_features: 0.0
-  dropout: 0.0
-  attention_dropout: 0.0
-
-  layer_norm_first: true
-  feature_grad_mult: 1.0
-
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-  deepnorm: false
-  attention_relax: 32.0
-
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
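The hydra block at the bottom of these configs controls how command-line overrides are folded into run-directory names: each kept override becomes key and value joined by kv_sep, items joined by item_sep, with the excluded keys dropped. A toy reproduction of that naming rule:

def override_dirname(overrides, kv_sep='-', item_sep='__',
                     exclude=('run', 'task.data', 'task.label_dir')):
    kept = {k: v for k, v in overrides.items() if k not in exclude}
    return item_sep.join(f'{k}{kv_sep}{v}' for k, v in sorted(kept.items()))

print(override_dirname({'model.mask_prob': 0.8, 'task.data': '/data'}))
# model.mask_prob-0.8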
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M.yaml
DELETED
@@ -1,116 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 25000
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # crop to 5s
-  max_sample_size: 120000
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 24000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-  extractor_mode: default
-  audio_extract_type: w2v_conv
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-  # -----------
-  final_dim: 64
-  encoder_layerdrop: 0.05
-  dropout_input: 0.1
-  dropout_features: 0.1
-  dropout: 0.1
-  attention_dropout: 0.1
-  feature_grad_mult: 0.1
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
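feature_grad_mult: 0.1 scales the gradients flowing back into the conv feature extractor while leaving the forward pass untouched; fairseq implements this with a GradMultiply autograd function, equivalent to:

import torch

class GradMultiply(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.new(x)  # identity on the forward pass

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None  # shrink gradients on the way back

features = torch.randn(2, 512, 100, requires_grad=True)
GradMultiply.apply(features, 0.1).sum().backward()
print(features.grad.unique())  # tensor([0.1000]) instead of 1.0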
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq.yaml
DELETED
@@ -1,125 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 25000
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 8 # 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # crop to 5s
-  max_sample_size: 120000
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 24000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-  extractor_mode: default
-  audio_extract_type: melspec # use melspec (instead of `w2v_conv`)
-  melspec_n_bins: 120 # for melspec we use 120, means 12 bins per octave
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-
-  # best-rq loss
-  audio_rq_loss_m: true
-  audio_rq_loss_embed_dim: 16
-  audio_rq_loss_num_codebooks: 1
-  audio_rq_loss_num_embeds: 8192
-  audio_rq_loss_seed: 42
-  audio_rq_loss_use_norm: true
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-  # -----------
-  final_dim: 64
-  encoder_layerdrop: 0.05
-  dropout_input: 0.1
-  dropout_features: 0.1
-  dropout: 0.1
-  attention_dropout: 0.1
-  feature_grad_mult: 0.1
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
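The audio_rq_loss_* fields configure a BEST-RQ-style target generator: a frozen random projection into a 16-dim space and a frozen random codebook of 8192 entries, with L2 normalization before the nearest-neighbor lookup. The values below mirror the config; the module layout itself is an assumption, not this repository's code:

import torch
import torch.nn.functional as F

class RandomProjectionQuantizer(torch.nn.Module):
    def __init__(self, input_dim=120, embed_dim=16, num_embeds=8192, seed=42):
        super().__init__()
        g = torch.Generator().manual_seed(seed)  # audio_rq_loss_seed
        self.register_buffer('proj', torch.empty(input_dim, embed_dim).normal_(generator=g))
        self.register_buffer('codebook', F.normalize(
            torch.empty(num_embeds, embed_dim).normal_(generator=g), dim=-1))

    @torch.no_grad()
    def forward(self, feats):                       # feats: (T, input_dim) mel frames
        z = F.normalize(feats @ self.proj, dim=-1)  # audio_rq_loss_use_norm: true
        return torch.cdist(z, self.codebook).argmin(dim=-1)

q = RandomProjectionQuantizer()
print(q(torch.randn(375, 120)).shape)  # one frozen-codebook label per frame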
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_chroma_multinodes.yaml
DELETED
@@ -1,128 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 12500
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # crop to 5s
-  max_sample_size: 120000
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [4]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 24000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-  extractor_mode: default
-  audio_extract_type: melspec # use melspec (instead of `w2v_conv`)
-  melspec_n_bins: 120 # for melspec we use 120, means 12 bins per octave
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-
-  # best-rq loss
-  audio_rq_loss_m: true
-  audio_rq_loss_embed_dim: 16
-  audio_rq_loss_num_codebooks: 1
-  audio_rq_loss_num_embeds: 8192
-  audio_rq_loss_seed: 42
-  audio_rq_loss_use_norm: true
-  audio_rq_loss_use_chroma: true
-  audio_rq_loss_seed_chroma: 123
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-  # -----------
-  final_dim: 32
-  encoder_layerdrop: 0.05
-  dropout_input: 0.1
-  dropout_features: 0.1
-  dropout: 0.1
-  attention_dropout: 0.1
-  feature_grad_mult: 0.1
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
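audio_extract_type: melspec swaps the learned wav2vec conv front-end for a fixed mel filterbank (this variant adding a second, chroma-seeded codebook for the BEST-RQ targets). A plausible torchaudio equivalent of the mel front-end; the hop length is an assumption chosen to keep a 75 Hz frame rate:

import torch
import torchaudio

melspec = torchaudio.transforms.MelSpectrogram(
    sample_rate=24000, n_mels=120, hop_length=320)  # 24000 / 320 = 75 frames/s
frames = melspec(torch.randn(1, 120000))            # one 5 s crop
print(frames.shape)                                 # (1, 120, 376)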
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_multinodes.yaml
DELETED
@@ -1,126 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 12500
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # crop to 5s
-  max_sample_size: 120000
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [4]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 24000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-  extractor_mode: default
-  audio_extract_type: melspec # use melspec (instead of `w2v_conv`)
-  melspec_n_bins: 120 # for melspec we use 120, means 12 bins per octave
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-
-  # best-rq loss
-  audio_rq_loss_m: true
-  audio_rq_loss_embed_dim: 16
-  audio_rq_loss_num_codebooks: 1
-  audio_rq_loss_num_embeds: 8192
-  audio_rq_loss_seed: 42
-  audio_rq_loss_use_norm: true
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-  # -----------
-  final_dim: 64
-  encoder_layerdrop: 0.05
-  dropout_input: 0.1
-  dropout_features: 0.1
-  dropout: 0.1
-  attention_dropout: 0.1
-  feature_grad_mult: 0.1
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
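With this file's batch settings, the same back-of-envelope as for the 330M configs (again assuming max_tokens counts raw waveform samples per GPU):

max_tokens, sr, world_size, update_freq = 2_000_000, 24_000, 64, 4
hours = max_tokens / sr * world_size * update_freq / 3600
print(round(hours, 2))  # ~5.93 h of audio per optimizer update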
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_norm_multinodes.yaml
DELETED
@@ -1,128 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 12500
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # crop to 5s
-  max_sample_size: 120000
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [4]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 24000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-  extractor_mode: default
-  audio_extract_type: melspec # use melspec (instead of `w2v_conv`)
-  melspec_n_bins: 120 # for melspec we use 120, means 12 bins per octave
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-
-  # best-rq loss
-  audio_rq_loss_m: true
-  audio_rq_loss_embed_dim: 16
-  audio_rq_loss_num_codebooks: 1
-  audio_rq_loss_num_embeds: 8192
-  audio_rq_loss_seed: 42
-  audio_rq_loss_use_norm: true
-  audio_rq_loss_use_chroma: false
-  audio_rq_loss_seed_chroma: 123
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-  # -----------
-  final_dim: 64
-  encoder_layerdrop: 0.05
-  dropout_input: 0.1
-  dropout_features: 0.1
-  dropout: 0.1
-  attention_dropout: 0.1
-  feature_grad_mult: 0.1
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_norm_speech_multinodes.yaml
DELETED
@@ -1,128 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 12500
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # crop to 5s
-  max_sample_size: 120000
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [4]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-  # ----- mixture ------
-  mixture_prob: 0 # 0.5
-  inbatch_noise_augment_len_range: "[12000, 24000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-  extractor_mode: default
-  audio_extract_type: melspec # use melspec (instead of `w2v_conv`)
-  melspec_n_bins: 80 # 120 # for melspec we use 120, means 12 bins per octave
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-
-  # best-rq loss
-  audio_rq_loss_m: true
-  audio_rq_loss_embed_dim: 16
-  audio_rq_loss_num_codebooks: 1
-  audio_rq_loss_num_embeds: 8192
-  audio_rq_loss_seed: 42
-  audio_rq_loss_use_norm: true
-  audio_rq_loss_use_chroma: false
-  audio_rq_loss_seed_chroma: 123
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: false
-  audio_cqt_bins: 336
-  # -----------
-  final_dim: 64
-  encoder_layerdrop: 0.05
-  dropout_input: 0.1
-  dropout_features: 0.1
-  dropout: 0.1
-  attention_dropout: 0.1
-  feature_grad_mult: 0.1
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}

codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrvq_multinodes.yaml
DELETED
@@ -1,121 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 12500
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # crop to 5s
-  max_sample_size: 120000
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [4]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 24000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-  extractor_mode: default
-  audio_extract_type: w2v_conv
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-
-  # ---- codec target
-  audio_codec_type: rvq
-  audio_codec_ckpt_path: RVQ_3000.pth
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-  # -----------
-  final_dim: 64
-  encoder_layerdrop: 0.05
-  dropout_input: 0.1
-  dropout_features: 0.1
-  dropout: 0.1
-  attention_dropout: 0.1
-  feature_grad_mult: 0.1
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
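
One detail worth spelling out: fairseq evaluates the conv_feature_layers string into a list of (dim, kernel, stride) tuples, so the product of the strides fixes the extractor's output frame rate. A quick back-of-envelope check (a sketch, not repo code):

layers = eval('[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2')  # 7 conv layers
hop = 1
for _dim, _kernel, stride in layers:
    hop *= stride
print(hop)           # 320 samples per output frame
print(24000 / hop)   # 75.0 frames/s at the 24 kHz sample_rate above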

codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_dac.yaml
DELETED
File without changes

codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_dac_multinodes.yaml
DELETED
@@ -1,121 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 12500
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # crop to 5s
-  max_sample_size: 120000
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [4]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 24000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-  extractor_mode: default
-  audio_extract_type: w2v_conv
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-
-  # ---- codec target
-  audio_codec_type: dac
-  audio_codec_dac_model_path: weights_24khz_8kbps_0.0.4.pth #nj
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-  # -----------
-  final_dim: 64
-  encoder_layerdrop: 0.05
-  dropout_input: 0.1
-  dropout_features: 0.1
-  dropout: 0.1
-  attention_dropout: 0.1
-  feature_grad_mult: 0.1
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}

codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_groupbestrq_multinodes.yaml
DELETED
@@ -1,125 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 12500
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # crop to 5s
-  max_sample_size: 120000
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [4]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 24000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-  extractor_mode: default
-  audio_extract_type: melspec # use melspec (instead of `w2v_conv`)
-  melspec_n_bins: 120 # for melspec we use 120, means 12 bins per octave
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-
-  # best-rq loss
-  audio_rq_loss_m: true
-  audio_rq_loss_embed_dim: 16
-  audio_rq_loss_num_codebooks: 64 # 32
-  audio_rq_loss_num_embeds: 1024
-  audio_rq_loss_seed: 42
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-  # -----------
-  final_dim: 16 # 64
-  encoder_layerdrop: 0.05
-  dropout_input: 0.1
-  dropout_features: 0.1
-  dropout: 0.1
-  attention_dropout: 0.1
-  feature_grad_mult: 0.1
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
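
The "group" variant above swaps the single large codebook for 64 independent random codebooks of 1024 entries each (audio_rq_loss_num_codebooks: 64), yielding 64 target ids per frame and one cross-entropy term per codebook. A hedged sketch of that grouping, with all names assumed:

import torch

def group_targets(z, num_codebooks=64, num_embeds=1024, embed_dim=16, seed=42):
    # z: (frames, embed_dim) features already projected to the codebook space
    g = torch.Generator().manual_seed(seed)
    books = torch.randn(num_codebooks, num_embeds, embed_dim, generator=g)  # frozen
    d = torch.cdist(z.unsqueeze(0), books)   # broadcast to (num_codebooks, frames, num_embeds)
    return d.argmin(dim=-1)                  # (num_codebooks, frames) target ids

ids = group_targets(torch.randn(100, 16))    # 64 parallel id streams for 100 frames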

codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_mel_multinodes.yaml
DELETED
@@ -1,124 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 12500
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # crop to 5s
-  max_sample_size: 120000
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [4]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 24000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-  extractor_mode: default
-  audio_extract_type: melspec # use melspec (instead of `w2v_conv`)
-  melspec_n_bins: 120 # for melspec we use 120, means 12 bins per octave
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-
-  # best-rq loss
-  audio_rq_loss_m: false
-  audio_rq_loss_embed_dim: 16
-  audio_rq_loss_num_codebooks: 1
-  audio_rq_loss_num_embeds: 8192
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-  # -----------
-  final_dim: 64
-  encoder_layerdrop: 0.05
-  dropout_input: 0.1
-  dropout_features: 0.1
-  dropout: 0.1
-  attention_dropout: 0.1
-  feature_grad_mult: 0.1
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
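
Here audio_extract_type switches the front end from the wav2vec-style conv stack to a 120-bin mel spectrogram, with the BEST-RQ loss disabled and only the CQT reconstruction loss active. Only n_mels is pinned by the config, so the FFT size and hop in this sketch are assumptions:

import torch
import torchaudio

# n_mels matches melspec_n_bins; n_fft/hop_length are illustrative guesses
frontend = torchaudio.transforms.MelSpectrogram(
    sample_rate=24000, n_fft=2048, hop_length=240, n_mels=120
)
feats = frontend(torch.randn(1, 24000))  # (1, 120, 101) for 1 s of audio at 100 fps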

codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_bestrvq_multinodes.yaml
DELETED
@@ -1,108 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 12500
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # # crop to 5s
-  # max_sample_size: 120000
-  # min_sample_size: 72000
-
-  # crop to 30s
-  max_sample_size: 720000
-  min_sample_size: 432000
-  clip_secs: 30
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: model
-  # log_keys:
-  #   - accuracies
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [1]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: musicfm
-  label_rate: 25
-  num_codebooks: 1
-  codebook_dim: 16
-  codebook_size: 8192 # 4096
-  features: ["melspec_2048"]
-  hop_length: 240
-  n_mels: 128
-  conv_dim: 512
-  encoder_dim: 1024
-  encoder_depth: 12
-  mask_hop: 0.4
-  mask_prob: 0.6
-  is_flash: false
-
-  stat_path: msd_stats.json
-  model_path: null
-  w2v2_config_path: our-MERT/data/models--facebook--wav2vec2-conformer-rope-large-960h-ft/snapshots/6b36ef01c6443c67ae7ed0822876d091ab50e4aa
-  use_rvq_target: true
-  rvq_ckpt_path: RVQ_4000.pth
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
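
Reading mask_hop and mask_prob together (on the assumption that MusicFM masks contiguous mask_hop-second chunks at the 25 Hz label rate): a 30 s clip gives 750 frames, grouped into 0.4 s chunks of 10 frames, each chunk masked with probability 0.6. A quick illustration:

import torch

label_rate, clip_secs, mask_hop, mask_prob = 25, 30, 0.4, 0.6
frames = label_rate * clip_secs               # 750 frames per clip
chunk = int(mask_hop * label_rate)            # 10 frames per mask unit
hit = torch.rand(frames // chunk) < mask_prob # one masking decision per chunk
mask = hit.repeat_interleave(chunk)           # (750,) boolean frame-level mask
print(mask.float().mean())                    # ~0.6 of frames end up masked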

codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_multinodes.yaml
DELETED
@@ -1,105 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 12500
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # # crop to 5s
-  # max_sample_size: 120000
-  # min_sample_size: 72000
-
-  # crop to 30s
-  max_sample_size: 720000
-  min_sample_size: 432000
-  clip_secs: 30
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: model
-  # log_keys:
-  #   - accuracies
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [1]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: musicfm
-  label_rate: 25
-  num_codebooks: 1
-  codebook_dim: 16
-  codebook_size: 4096
-  features: ["melspec_2048"]
-  hop_length: 240
-  n_mels: 128
-  conv_dim: 512
-  encoder_dim: 1024
-  encoder_depth: 12
-  mask_hop: 0.4
-  mask_prob: 0.6
-  is_flash: false
-  stat_path: msd_stats.json
-  model_path: pretrained_msd.pt
-  w2v2_config_path: models--facebook--wav2vec2-conformer-rope-large-960h-ft/snapshots/6b36ef01c6443c67ae7ed0822876d091ab50e4aa
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}

codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_speech_multinodes.yaml
DELETED
@@ -1,106 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 2500
-  keep_interval_updates: 10000
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # # crop to 5s
-  # max_sample_size: 120000
-  # min_sample_size: 72000
-
-  # crop to 30s
-  max_sample_size: 720000
-  min_sample_size: 12000
-  # clip_secs: 30
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-  disable_validation: true
-
-criterion:
-  _name: model
-  # log_keys:
-  #   - accuracies
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [1]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: musicfm
-  label_rate: 25
-  num_codebooks: 1
-  codebook_dim: 16
-  codebook_size: 4096
-  features: ["melspec_2048"]
-  hop_length: 240
-  n_mels: 128
-  conv_dim: 512
-  encoder_dim: 1024
-  encoder_depth: 12
-  mask_hop: 0.4
-  mask_prob: 0.6
-  is_flash: false
-  stat_path: msd_stats.json
-  model_path: null
-  w2v2_config_path: models--facebook--wav2vec2-conformer-rope-large-960h-ft/snapshots/6b36ef01c6443c67ae7ed0822876d091ab50e4aa
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}

codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/run/submitit_reg.yaml
DELETED
@@ -1,20 +0,0 @@
-# @package _global_
-
-hydra:
-  launcher:
-    cpus_per_task: 8
-    gpus_per_node: 8
-    tasks_per_node: ${hydra.launcher.gpus_per_node}
-    nodes: 4
-    comment: null
-    mem_gb: 384
-    timeout_min: 4320
-    max_num_timeout: 100
-    constraint: volta32gb
-    name: ${hydra.job.config_name}/${hydra.job.override_dirname}
-    submitit_folder: ${hydra.sweep.dir}/submitit/%j
-
-distributed_training:
-  distributed_world_size: 32
-  distributed_port: 29671
-  nprocs_per_node: 8
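
A small sanity check on the launcher geometry above: the submitit job requests 4 nodes at 8 GPUs each, which must equal the distributed_world_size declared at the bottom of the same file.

nodes, gpus_per_node, world_size = 4, 8, 32
assert nodes * gpus_per_node == world_size, "launcher geometry disagrees with world size"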

codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/__init__.py
DELETED
@@ -1,2 +0,0 @@
-from .mert_dataset import MERTDataset
-from .eat_data import *

codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/ark_dataset.py
DELETED
@@ -1,115 +0,0 @@
-import logging
-import torch
-import torch.nn.functional as F
-from fairseq.data.audio.raw_audio_dataset import RawAudioDataset
-from typing import Tuple
-try:
-    import kaldiio
-except:
-    kaldiio = None
-import warnings
-
-logger = logging.getLogger(__name__)
-
-
-class ArkDataset(RawAudioDataset):
-    def __init__(
-        self,
-        wav_scp,
-        dur_scp,
-        sr = 24000,
-        max_dur = 20,
-        num_buckets=0,
-        normalize=False,
-    ):
-        super().__init__(
-            sample_rate=sr,
-            max_sample_size=max_dur*sr,
-            min_sample_size=1200,
-            shuffle=True,
-            pad=True,
-            normalize=normalize,
-            compute_mask=False,
-        )
-        self.sr = sr
-        self.max_dur = max_dur
-        self.normalize = normalize
-
-        logger.info("Loading Kaldi scp files from {}".format(wav_scp))
-
-        self.wav_data = kaldiio.load_scp(wav_scp)
-        self.keys = list(self.wav_data.keys())
-        dur_data = {}
-        keys_set = set(self.keys)
-
-        with open(dur_scp, 'r') as f:
-            for line in f:
-                line = line.strip().split()
-                if line[0] in keys_set:
-                    dur_data[line[0]] = float(line[-1])
-        self.sizes = [int(dur_data[k]*self.sr/100) for k in self.keys]
-
-        logger.info("Loading Kaldi scp files done")
-
-        self.dataset_len = len(self.keys)
-        self.set_bucket_info(num_buckets)
-
-    def __len__(self):
-        return self.dataset_len
-
-    def __getitem__(self, idx):
-        # print("getitem idx: ", idx)
-        try_cnt = 0
-        while True:
-            idx = idx + try_cnt
-            try:
-                with warnings.catch_warnings():
-                    warnings.simplefilter("ignore")
-                    key = self.keys[idx]
-                    # print(self.wav_data[key].keys())
-                    wav = self.wav_data[key]['wav']
-
-                wav = torch.from_numpy(wav).float()
-                wav = self.postprocess(wav)
-                # print("success load", idx, " shape =", wav.shape)
-                return {"id": idx, "source": wav}
-            except Exception as e:
-                # from traceback import print_exc
-                # print_exc()
-                # print("Error loadding ", idx)
-                # return {"id": idx, "source": None}
-                try_cnt += 1
-                if try_cnt > 50:
-                    return {"id": idx, "source": None}
-                continue
-
-    def size(self, idx):
-        return self.sizes[idx]
-
-    def postprocess(self, wav):
-        if wav.dim() == 2:
-            wav = wav.mean(-1)
-        assert wav.dim() == 1, wav.dim()
-
-        if self.normalize:
-            with torch.no_grad():
-                wav = F.layer_norm(wav, wav.shape)
-        return wav
-
-    def collater(self, samples):
-        # print("collate from:", [s['source'].shape for s in samples if s['source'] is not None])
-        return super().collater(samples)
-
-if __name__ == '__main__':
-    import torch
-    raw_tensor_str = torch.Tensor.__repr__
-    torch.Tensor.__str__ = torch.Tensor.__repr__ = lambda self: f'Tensor{{Size({[*self.shape]}) {self.device} {str(self.dtype)[6]}{str(self.dtype)[-2:]}}}' if self.numel() > 10 else raw_tensor_str(self)
-
-    ds = ArkDataset(
-        wav_scp='data/ark_demo/wav_ark.scp',
-        dur_scp='data/ark_demo/dur_ark.scp',
-        sr=24000,
-    )
-
-    for i in range(len(ds)):
-        print(ds[i])
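
One non-obvious contract in ArkDataset: sizes are computed as int(dur * sr / 100), which suggests the dur scp stores durations in 10 ms frames (a common Kaldi convention). That is inferred from the code rather than documented; a worked example under that assumption:

sr = 24000
dur_frames = 2000                        # hypothetical dur scp value: 2000 x 10 ms
num_samples = int(dur_frames * sr / 100)
print(num_samples, num_samples / sr)     # 480000 samples -> 20.0 s, the max_dur default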