root committed on
Commit f9e2d84 · 1 Parent(s): 410c1c2

update v1.5-beta

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. app.py +5 -7
  2. codeclm/models/builders.py +5 -4
  3. codeclm/modules/conditioners.py +13 -3
  4. codeclm/tokenizer/Flow1dVAE/cal_token_stat.py +0 -19
  5. codeclm/tokenizer/Flow1dVAE/compare_model_weight.py +0 -13
  6. codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_and_sep_npy.py +0 -121
  7. codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_sep.py +0 -94
  8. codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x2.py +0 -70
  9. codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4.py +0 -46
  10. codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4_ds.py +0 -86
  11. codeclm/tokenizer/Flow1dVAE/generate_1rvq.py +3 -32
  12. codeclm/tokenizer/Flow1dVAE/generate_2rvq.py +0 -293
  13. codeclm/tokenizer/Flow1dVAE/generate_4rvq.py +0 -292
  14. codeclm/tokenizer/Flow1dVAE/libs/datasets/MusicSoundMixedDataset.py +0 -1278
  15. codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_429.py +0 -372
  16. codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined.py +0 -830
  17. codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined_withset.py +0 -994
  18. codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song.py +0 -313
  19. codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_20s.py +0 -313
  20. codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_new_429.py +0 -313
  21. codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_stock.py +0 -461
  22. codeclm/tokenizer/Flow1dVAE/model_1rvq.py +0 -2
  23. codeclm/tokenizer/Flow1dVAE/model_2rvq.py +0 -774
  24. codeclm/tokenizer/Flow1dVAE/model_4rvq.py +0 -774
  25. codeclm/tokenizer/Flow1dVAE/model_septoken.py +0 -2
  26. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_AS2M.yaml +0 -122
  27. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_music_multinodes.yaml +0 -125
  28. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M.yaml +0 -137
  29. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes.yaml +0 -139
  30. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes_debug1node.yaml +0 -138
  31. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes_debug2node.yaml +0 -139
  32. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_orig.yaml +0 -135
  33. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_tune.yaml +0 -137
  34. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M.yaml +0 -116
  35. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq.yaml +0 -125
  36. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_chroma_multinodes.yaml +0 -128
  37. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_multinodes.yaml +0 -126
  38. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_norm_multinodes.yaml +0 -128
  39. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_norm_speech_multinodes.yaml +0 -128
  40. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrvq_multinodes.yaml +0 -121
  41. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_dac.yaml +0 -0
  42. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_dac_multinodes.yaml +0 -121
  43. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_groupbestrq_multinodes.yaml +0 -125
  44. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_mel_multinodes.yaml +0 -124
  45. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_bestrvq_multinodes.yaml +0 -108
  46. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_multinodes.yaml +0 -105
  47. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_speech_multinodes.yaml +0 -106
  48. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/run/submitit_reg.yaml +0 -20
  49. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/__init__.py +0 -2
  50. codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/ark_dataset.py +0 -115
app.py CHANGED
@@ -16,14 +16,12 @@ from download import download_model
16
  # Download the model
17
  APP_DIR = op.dirname(op.abspath(__file__))
18
  download_model(APP_DIR)
19
- base_full_path = op.join(APP_DIR, "ckpt", "songgeneration-base-full")
20
- os.makedirs(base_full_path, exist_ok=True)
21
- download_model(base_full_path, repo_id="lglg666/SongGeneration-base-full", revision="19ebdb6")
22
  print("Successful downloaded model.")
23
 
24
  # Model initialization
25
  from levo_inference import LeVoInference
26
- MODEL = LeVoInference(base_full_path)
27
 
28
  EXAMPLE_LYRICS = """
29
  [intro-medium]
@@ -225,7 +223,7 @@ lyrics
225
  minimum=0.1,
226
  maximum=2.0,
227
  step=0.1,
228
- value=0.75,
229
  interactive=True,
230
  elem_id="temperature",
231
  )
@@ -268,12 +266,12 @@ lyrics
268
  # Generate-button click handlers
269
  generate_btn.click(
270
  fn=generate_song,
271
- inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, gr.State(-1)],
272
  outputs=[output_audio, output_json]
273
  )
274
  generate_bgm_btn.click(
275
  fn=generate_song,
276
- inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, gr.State(-1), gr.State("bgm")],
277
  outputs=[output_audio, output_json]
278
  )
279
 
 
16
  # Download the model
17
  APP_DIR = op.dirname(op.abspath(__file__))
18
  download_model(APP_DIR)
19
+ download_model(op.join(APP_DIR, "ckpt"), repo_id="waytan22/SongGeneration-v1.5-beta", revision="db10f47")
 
 
20
  print("Successful downloaded model.")
21
 
22
  # Model initialization
23
  from levo_inference import LeVoInference
24
+ MODEL = LeVoInference(op.join(APP_DIR, "ckpt", "SongGeneration-v1.5-beta"))
25
 
26
  EXAMPLE_LYRICS = """
27
  [intro-medium]
 
223
  minimum=0.1,
224
  maximum=2.0,
225
  step=0.1,
226
+ value=0.8,
227
  interactive=True,
228
  elem_id="temperature",
229
  )
 
266
  # Generate-button click handlers
267
  generate_btn.click(
268
  fn=generate_song,
269
+ inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, gr.State(50)],
270
  outputs=[output_audio, output_json]
271
  )
272
  generate_bgm_btn.click(
273
  fn=generate_song,
274
+ inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, gr.State(50), gr.State("bgm")],
275
  outputs=[output_audio, output_json]
276
  )
277
 
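The app.py changes above replace the v1.0 "songgeneration-base-full" checkpoint with the pinned v1.5-beta checkpoint, raise the default sampling temperature from 0.75 to 0.8, and swap the `gr.State(-1)` extra input for `gr.State(50)`, which Gradio passes to `generate_song` as a constant additional argument on every click. Below is a minimal sketch of the new model setup, assuming `download_model` and `LeVoInference` behave as in this repository.

```python
# Sketch of the v1.5 setup in app.py (mirrors the diff above; assumes
# download_model() and LeVoInference() behave as in this repository).
import os.path as op

from download import download_model
from levo_inference import LeVoInference

APP_DIR = op.dirname(op.abspath(__file__))

download_model(APP_DIR)  # base assets
download_model(op.join(APP_DIR, "ckpt"),
               repo_id="waytan22/SongGeneration-v1.5-beta",
               revision="db10f47")  # pinned v1.5-beta checkpoint

MODEL = LeVoInference(op.join(APP_DIR, "ckpt", "SongGeneration-v1.5-beta"))
```

Note that `gr.State(50)` simply injects the literal value 50 into the inputs list; how `generate_song` interprets it (e.g. as a step count) is not shown in this diff.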
codeclm/models/builders.py CHANGED
@@ -52,7 +52,7 @@ def get_audio_tokenizer_model_cpu(checkpoint_path: str, cfg: omegaconf.DictConfi
52
  return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode, tango_device='cpu')
53
 
54
 
55
- def get_lm_model(cfg: omegaconf.DictConfig): #-> LMModel:
56
  """Instantiate a LM."""
57
  lm_kwargs = dict_from_config(getattr(cfg, 'lm'))
58
 
@@ -61,8 +61,8 @@ def get_lm_model(cfg: omegaconf.DictConfig): #-> LMModel:
61
  q_modeling = lm_kwargs.pop('q_modeling', None)
62
 
63
  # conditioner
64
- condition_provider = get_conditioner_provider(lm_kwargs["dim"], cfg)
65
-
66
  # codebook pattern: delay
67
  codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
68
  if codebooks_pattern_cfg.modeling is None:
@@ -97,7 +97,7 @@ def get_lm_model(cfg: omegaconf.DictConfig): #-> LMModel:
97
  raise KeyError(f"Unexpected LM model {lm_type}")
98
 
99
 
100
- def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditionerProvider:
101
  """Instantiate a conditioning model."""
102
  cfg = getattr(cfg, 'conditioners')
103
  dict_cfg = {} if cfg is None else dict_from_config(cfg)
@@ -115,6 +115,7 @@ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> Cond
115
  elif model_type == "QwTextTokenizer":
116
  conditioners[str(cond)] = QwTextConditioner(
117
  output_dim=output_dim,
 
118
  **model_args
119
  )
120
  elif model_type == "qt_embedding":
 
52
  return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode, tango_device='cpu')
53
 
54
 
55
+ def get_lm_model(cfg: omegaconf.DictConfig, version: str = 'v1.0'): #-> LMModel:
56
  """Instantiate a LM."""
57
  lm_kwargs = dict_from_config(getattr(cfg, 'lm'))
58
 
 
61
  q_modeling = lm_kwargs.pop('q_modeling', None)
62
 
63
  # conditioner
64
+ condition_provider = get_conditioner_provider(lm_kwargs["dim"], cfg, version=version)
65
+
66
  # codebook pattern: delay
67
  codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
68
  if codebooks_pattern_cfg.modeling is None:
 
97
  raise KeyError(f"Unexpected LM model {lm_type}")
98
 
99
 
100
+ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig, version: str = 'v1.0') -> ConditionerProvider:
101
  """Instantiate a conditioning model."""
102
  cfg = getattr(cfg, 'conditioners')
103
  dict_cfg = {} if cfg is None else dict_from_config(cfg)
 
115
  elif model_type == "QwTextTokenizer":
116
  conditioners[str(cond)] = QwTextConditioner(
117
  output_dim=output_dim,
118
+ version=version,
119
  **model_args
120
  )
121
  elif model_type == "qt_embedding":
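The builders.py change threads a new `version` argument from `get_lm_model` into `get_conditioner_provider` and on into `QwTextConditioner`, so the same config can build either the v1.0 or the v1.5 text conditioner. An illustrative call, assuming `cfg` is an `omegaconf.DictConfig` shaped like the repository's model configs (the config path below is hypothetical):

```python
# Illustrative only: how the new `version` flag propagates through the builders.
import omegaconf
from codeclm.models.builders import get_lm_model

cfg = omegaconf.OmegaConf.load("conf/songgeneration_v1_5.yaml")  # hypothetical path

# The default keeps the v1.0 behaviour; 'v1.5' is forwarded to
# get_conditioner_provider() and then to QwTextConditioner(version=...),
# which extends the Qwen2 vocabulary with the Musicality tokens.
lm = get_lm_model(cfg, version="v1.5")
```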
codeclm/modules/conditioners.py CHANGED
@@ -188,10 +188,13 @@ class QwTokenizerConditioner(TextConditioner):
188
  class QwTextConditioner(TextConditioner):
189
  def __init__(self, output_dim: int,
190
  token_path = "",
191
- max_len = 300): #""
 
192
 
193
  from transformers import Qwen2Tokenizer
194
- self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
 
 
195
  voc_size = len(self.text_tokenizer.get_vocab())
196
  # here initialize a output_proj (nn.Embedding) layer
197
  super().__init__(voc_size, output_dim, input_token=True, padding_idx=151643)
@@ -636,7 +639,14 @@ class ClassifierFreeGuidanceDropoutInference(ClassifierFreeGuidanceDropout):
636
  sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0])
637
  else:
638
  if customized is None:
639
- sample.text[condition] = None
 
 
 
 
 
 
 
640
  else:
641
  text_cond = deepcopy(sample.text[condition])
642
  if "structure" in customized:
 
188
  class QwTextConditioner(TextConditioner):
189
  def __init__(self, output_dim: int,
190
  token_path = "",
191
+ max_len = 300,
192
+ version: str = 'v1.0'): #""
193
 
194
  from transformers import Qwen2Tokenizer
195
+ self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
196
+ if version == 'v1.5':
197
+ self.text_tokenizer.add_tokens(['[Musicality-very-high]', '[Musicality-high]', '[Musicality-medium]', '[Musicality-low]', '[Musicality-very-low]'], special_tokens=True)
198
  voc_size = len(self.text_tokenizer.get_vocab())
199
  # here initialize a output_proj (nn.Embedding) layer
200
  super().__init__(voc_size, output_dim, input_token=True, padding_idx=151643)
 
639
  sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0])
640
  else:
641
  if customized is None:
642
+ if condition in ['type_info'] and sample.text[condition] is not None:
643
+ if "[Musicality-very-high]" in sample.text[condition]:
644
+ sample.text[condition] = "[Musicality-very-low], ."
645
+ print(f"cfg unconditioning: change sample.text[condition] to [Musicality-very-low]")
646
+ else:
647
+ sample.text[condition] = None
648
+ else:
649
+ sample.text[condition] = None
650
  else:
651
  text_cond = deepcopy(sample.text[condition])
652
  if "structure" in customized:
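In conditioners.py, the v1.5 `QwTextConditioner` registers five `[Musicality-*]` control tokens as special tokens on the Qwen2 tokenizer before sizing its embedding table, and the classifier-free-guidance dropout now treats the `type_info` condition specially: when the conditioned text contains `[Musicality-very-high]`, the "unconditional" branch is set to `[Musicality-very-low], .` instead of `None`, so guidance pushes generations away from low musicality. A small sketch of the token registration, assuming `token_path` points to a local Qwen2 tokenizer directory (the path below is hypothetical):

```python
# Sketch of the v1.5 tokenizer extension used by QwTextConditioner.
from transformers import Qwen2Tokenizer

MUSICALITY_TOKENS = [
    "[Musicality-very-high]", "[Musicality-high]", "[Musicality-medium]",
    "[Musicality-low]", "[Musicality-very-low]",
]

tok = Qwen2Tokenizer.from_pretrained("ckpt/qwen2_tokenizer")  # hypothetical path
tok.add_tokens(MUSICALITY_TOKENS, special_tokens=True)

# The conditioner builds its nn.Embedding from the extended vocabulary size,
# so the five new control tokens embed alongside the original Qwen2 tokens.
print(len(tok.get_vocab()))
```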
codeclm/tokenizer/Flow1dVAE/cal_token_stat.py DELETED
@@ -1,19 +0,0 @@
1
- import kaldiio
2
- from tqdm import tqdm
3
- import torch
4
-
5
- if __name__ == "__main__":
6
- bar = torch.zeros(1, 16384)
7
- with open('token.scp', 'r') as f:
8
- for item_idx, line in tqdm(enumerate(f)):
9
- idx, pos = line.strip().split()
10
- codes = kaldiio.load_mat(pos)
11
- for i0 in range(codes.shape[-1]):
12
- bar[0, codes[0, 0, i0]] += 1
13
- if(item_idx % 1000 == 0):
14
- print("=========")
15
- print(1 - (bar[0]==0).sum() / bar.shape[-1])
16
- print("=========")
17
- print("=========")
18
- print(1 - (bar[0]==0).sum() / bar.shape[-1])
19
- print("=========")
codeclm/tokenizer/Flow1dVAE/compare_model_weight.py DELETED
@@ -1,13 +0,0 @@
1
- import torch
2
- import sys
3
- from safetensors.torch import load_file
4
-
5
- if __name__ == "__main__":
6
- m0, m1 = sys.argv[1], sys.argv[2]
7
- m0 = load_file(m0)
8
- m1 = load_file(m1)
9
-
10
- ks = [k for k in m0.keys() if 'bestrq' in k]
11
- for k in ks:
12
- print(k, (m0[k] - m1[k]).abs().sum())
13
-
codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_and_sep_npy.py DELETED
@@ -1,121 +0,0 @@
1
- import torch,torchaudio
2
- import os,sys,json
3
- from tqdm import tqdm
4
- import numpy as np
5
-
6
- #from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
7
- from generate_septoken import Tango as Tango_sep
8
- from generate_2rvq import Tango as Tango_1x2
9
- import kaldiio
10
- from kaldiio import WriteHelper
11
- from audio import AudioFile
12
-
13
- from demucs.models.pretrained import get_model_from_yaml
14
- from filelock import FileLock
15
-
16
- # os.path.join(args.model_dir, "htdemucs.pth"), os.path.join(args.model_dir, "htdemucs.yaml")
17
- class Separator:
18
- def __init__(self, dm_model_path='demucs/ckpt/htdemucs.pth', dm_config_path='demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
19
- if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
20
- self.device = torch.device(f"cuda:{gpu_id}")
21
- else:
22
- self.device = torch.device("cpu")
23
- self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)
24
-
25
- def init_demucs_model(self, model_path, config_path):
26
- model = get_model_from_yaml(config_path, model_path)
27
- model.to(self.device)
28
- model.eval()
29
- return model
30
-
31
- def load_audio(self, f):
32
- a, fs = torchaudio.load(f)
33
- if (fs != 48000):
34
- a = torchaudio.functional.resample(a, fs, 48000)
35
- # if a.shape[-1] >= 48000*10:
36
- # a = a[..., :48000*10]
37
- # else:
38
- # a = torch.cat([a, a], -1)
39
- # return a[:, 0:48000*10]
40
- return a
41
-
42
- def run(self, audio_path, output_dir='demucs/test_output', ext=".flac"):
43
- name, _ = os.path.splitext(os.path.split(audio_path)[-1])
44
- output_paths = []
45
- # lock_path = os.path.join(output_dir, f"{name}.lock")
46
- # with FileLock(lock_path): # added to avoid deadlocks when multiple GPUs access it concurrently
47
- for stem in self.demucs_model.sources:
48
- output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
49
- if os.path.exists(output_path):
50
- output_paths.append(output_path)
51
- if len(output_paths) == 1: # 4
52
- # drums_path, bass_path, other_path, vocal_path = output_paths
53
- vocal_path = output_paths[0]
54
- else:
55
- lock_path = os.path.join(output_dir, f"{name}_separate.lock")
56
- with FileLock(lock_path):
57
- drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
58
- full_audio = self.load_audio(audio_path)
59
- vocal_audio = self.load_audio(vocal_path)
60
- minlen = min(full_audio.shape[-1], vocal_audio.shape[-1])
61
- # bgm_audio = full_audio[:, 0:minlen] - vocal_audio[:, 0:minlen]
62
- bgm_audio = self.load_audio(drums_path) + self.load_audio(bass_path) + self.load_audio(other_path)
63
- for path in [drums_path, bass_path, other_path, vocal_path]:
64
- os.remove(path)
65
- return full_audio, vocal_audio, bgm_audio
66
-
67
- def read_wav(fname, sample_rate=48_000):
68
- try:
69
- orig_samples, fs = torchaudio.load(fname)
70
- except:
71
- af = AudioFile(fname)
72
- orig_samples = af.read()
73
- fs = af.samplerate()
74
- orig_samples = orig_samples[0]
75
- if(fs!=sample_rate):
76
- orig_samples = torchaudio.functional.resample(orig_samples, fs, sample_rate)
77
- fs = sample_rate
78
- if orig_samples.shape[0] == 1:
79
- orig_samples = torch.cat([orig_samples, orig_samples], 0)
80
- return orig_samples
81
-
82
- if __name__ == "__main__":
83
- # Define Model
84
- json_path = sys.argv[1]
85
-
86
- mus_infos = []
87
- with open(json_path) as f:
88
- for line in f:
89
- item = json.loads(line)
90
- mus_infos.append(item)
91
-
92
- tango_sep = Tango_sep(model_path="./saved/model_septoken/model_2.safetensors")
93
- tango_1x2 = Tango_1x2(model_path = './saved/model_2rvq/model_2_fixed.safetensors', rvq_num=2)
94
- separator = Separator()
95
-
96
- # Feature extraction loop
97
- # for i in tqdm(range(2000)):
98
- first_time = True
99
- for item in tqdm(mus_infos):
100
- if(os.path.exists(item['path'])):
101
- full_path = item['path']
102
- else:
103
- full_path = '/mnt/share/' + item['path']
104
-
105
- full_tensor, vocal_tensor, bgm_tensor = separator.run(full_path)
106
-
107
- # full_tensor = read_wav(full_path)
108
- # vocal_tensor = read_wav(vocal_path)
109
- # length = min(full_tensor.shape[-1], vocal_tensor.shape[-1])
110
- # full_tensor, vocal_tensor = full_tensor[:, 0:length], vocal_tensor[:, 0:length]
111
- # bgm_tensor = full_tensor - vocal_tensor
112
- codes_1x2 = tango_1x2.sound2code(full_tensor)
113
- codes_vocal, codes_bgm = tango_sep.sound2code(vocal_tensor, bgm_tensor)
114
- codes = torch.cat([codes_1x2[:,[0],:], codes_vocal, codes_bgm], 1).cpu().numpy()
115
- save_path = full_path.replace('.wav', '.1x1_and_sep.npy').replace('.mp3', '.1x1_and_sep.npy').replace('.flac', '.1x1_and_sep.npy').replace('.ogg', '.1x1_and_sep.npy')
116
- assert save_path != full_path, (save_path, full_path)
117
- np.save(save_path, codes)
118
-
119
- if(first_time):
120
- first_time = False
121
- print(codes_vocal.shape, codes_bgm.shape)
codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_sep.py DELETED
@@ -1,94 +0,0 @@
1
- import torch,torchaudio
2
- import os,sys,json
3
- from tqdm import tqdm
4
-
5
- #from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
6
- from generate_septoken import Tango
7
- import kaldiio
8
- from kaldiio import WriteHelper
9
- from audio import AudioFile
10
-
11
- def read_wav(fname, sample_rate=48_000):
12
- try:
13
- orig_samples, fs = torchaudio.load(fname)
14
- except:
15
- af = AudioFile(fname)
16
- orig_samples = af.read()
17
- fs = af.samplerate()
18
- orig_samples = orig_samples[0]
19
- if(fs!=sample_rate):
20
- orig_samples = torchaudio.functional.resample(orig_samples, fs, sample_rate)
21
- fs = sample_rate
22
- if orig_samples.shape[0] == 1:
23
- orig_samples = torch.cat([orig_samples, orig_samples], 0)
24
- return orig_samples
25
-
26
- if __name__ == "__main__":
27
- # Define Model
28
- json_path = sys.argv[1]
29
- outdir = sys.argv[2]
30
-
31
- mus_infos = []
32
- with open(json_path) as f:
33
- for line in f:
34
- item = json.loads(line)
35
- mus_infos.append(item)
36
-
37
- tango = Tango(model_path="./saved/model_septoken/model_2.safetensors")
38
-
39
-
40
- # Feature extraction loop
41
- # for i in tqdm(range(2000)):
42
- first_time = True
43
- with WriteHelper('ark,scp:{}/token_vocal.ark,{}/token_vocal.scp'.format(outdir, outdir), write_function="pickle") as writer_vocal, WriteHelper('ark,scp:{}/token_bgm.ark,{}/token_bgm.scp'.format(outdir, outdir), write_function="pickle") as writer_bgm:
44
- print('ark,scp:{}/token_vocal.ark,{}/token_vocal.scp'.format(outdir, outdir))
45
- print('ark,scp:{}/token_bgm.ark,{}/token_bgm.scp'.format(outdir, outdir))
46
- for item in tqdm(mus_infos):
47
- try:
48
- # if True:
49
- idx = item['idx']
50
- # print(idx)
51
- if(os.path.exists(item['path'])):
52
- full_path = item['path']
53
- else:
54
- full_path = '/mnt/share/' + item['path']
55
- if(os.path.exists(item['vocal_path'])):
56
- vocal_path = item['vocal_path']
57
- bgm_paths = item['bgm_path']
58
- else:
59
- vocal_path = '/mnt/share/' + item['vocal_path']
60
- bgm_paths = ['/mnt/share/' + p for p in item['bgm_path']]
61
- vocal_tensor = read_wav(vocal_path)
62
- # full_tensor = read_wav(full_path)
63
- # length = min(full_tensor.shape[-1], vocal_tensor.shape[-1])
64
- # full_tensor, vocal_tensor = full_tensor[:, 0:length], vocal_tensor[:, 0:length]
65
- # bgm_tensor = full_tensor - vocal_tensor
66
- bgm_tensor = sum([read_wav(p) for p in bgm_paths])
67
- codes_vocal, codes_bgm = tango.sound2code(vocal_tensor, bgm_tensor)
68
- writer_vocal(str(idx), codes_vocal.cpu())
69
- writer_bgm(str(idx), codes_bgm.cpu())
70
- if(first_time):
71
- first_time = False
72
- print(codes_vocal.shape, codes_bgm.shape)
73
- except:
74
- print(item['vocal_path'])
75
- print(item['bgm_path'])
76
- continue
77
-
78
- # idx = item['idx']
79
- # # print(idx)
80
- # full_path = item['path']
81
- # vocal_path = item['vocal_path']
82
- # bgm_paths = item['bgm_path']
83
- # full_tensor = read_wav(full_path)
84
- # vocal_tensor = read_wav(vocal_path)
85
- # length = min(full_tensor.shape[-1], vocal_tensor.shape[-1])
86
- # full_tensor, vocal_tensor = full_tensor[:, 0:length], vocal_tensor[:, 0:length]
87
- # bgm_tensor = full_tensor - vocal_tensor
88
- # codes_vocal, codes_bgm = tango.sound2code(vocal_tensor, bgm_tensor)
89
- # writer_vocal(str(idx), codes_vocal.cpu())
90
- # writer_bgm(str(idx), codes_bgm.cpu())
91
- # if(first_time):
92
- # first_time = False
93
- # print(codes_vocal.shape, codes_bgm.shape)
94
-
codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x2.py DELETED
@@ -1,70 +0,0 @@
1
- import torch,torchaudio
2
- import os,sys,json
3
- from tqdm import tqdm
4
-
5
- #from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
6
- from generate_2rvq import Tango
7
- import kaldiio
8
- from kaldiio import WriteHelper
9
- import torch
10
- import subprocess
11
- import time
12
- import sys
13
-
14
- def get_gpu_memory():
15
- _output_to_list = lambda x: x.decode('ascii').split('\n')[:-1]
16
-
17
- ACCEPTABLE_AVAILABLE_MEMORY = 1024
18
- COMMAND = "nvidia-smi --query-gpu=memory.free --format=csv"
19
- memory_free_info = _output_to_list(subprocess.check_output(COMMAND.split()))[1:]
20
- memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
21
- return memory_free_values
22
-
23
- if __name__ == "__main__":
24
- # Define Model
25
- json_path = sys.argv[1]
26
- outdir = sys.argv[2]
27
-
28
- gpu_idx = int(os.environ['CUDA_VISIBLE_DEVICES'])
29
- while True:
30
- free_mem = get_gpu_memory()
31
- free_mem = free_mem[gpu_idx]
32
- if(free_mem > 25_000):
33
- print("GPU memory {}, run matrix cal".format(free_mem))
34
- break
35
- else:
36
- print("GPU memory {}, sleep 1min".format(free_mem))
37
- time.sleep(60)
38
-
39
- mus_infos = []
40
- with open(json_path) as f:
41
- for line in f:
42
- item = json.loads(line)
43
- mus_infos.append(item)
44
-
45
- tango = Tango(model_path = './saved/model_2rvq/model_2_fixed.safetensors', rvq_num=2)
46
-
47
-
48
- # Feature extraction loop
49
- # for i in tqdm(range(2000)):
50
- with WriteHelper('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir), write_function="pickle") as writer:
51
- print('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir))
52
- for item in tqdm(mus_infos):
53
- try:
54
- # if True:
55
- idx = item['idx']
56
- # print(idx)
57
- with torch.autocast(device_type="cuda", dtype=torch.float16):
58
- if(os.path.exists(item['path'])):
59
- codes = tango.file2code(item['path'])
60
- else:
61
- codes = tango.file2code('/mnt/share/' + item['path'])
62
- writer(str(idx), codes.cpu())
63
- except:
64
- print(item['path'])
65
- continue
66
- # idx = item['idx']
67
- # # print(idx)
68
- # with torch.autocast(device_type="cuda", dtype=torch.float16):
69
- # codes = tango.file2code(item['path'])
70
- # writer(str(idx), codes.cpu())
codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4.py DELETED
@@ -1,46 +0,0 @@
1
- import torch,torchaudio
2
- import os,sys,json
3
- from tqdm import tqdm
4
-
5
- #from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
6
- from generate_4rvq import Tango
7
- import kaldiio
8
- from kaldiio import WriteHelper
9
-
10
- if __name__ == "__main__":
11
- # Define Model
12
- json_path = sys.argv[1]
13
- outdir = sys.argv[2]
14
-
15
- mus_infos = []
16
- with open(json_path) as f:
17
- for line in f:
18
- item = json.loads(line)
19
- mus_infos.append(item)
20
-
21
- tango = Tango(model_path = './saved/model_4rvq/model_2_fixed.safetensors', rvq_num=4)
22
-
23
-
24
- # Feature extraction loop
25
- # for i in tqdm(range(2000)):
26
- with WriteHelper('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir), write_function="pickle") as writer:
27
- print('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir))
28
- for item in tqdm(mus_infos):
29
- try:
30
- # if True:
31
- idx = item['idx']
32
- # print(idx)
33
- with torch.autocast(device_type="cuda", dtype=torch.float16):
34
- if(os.path.exists(item['path'])):
35
- codes = tango.file2code(item['path'])
36
- else:
37
- codes = tango.file2code('/mnt/share/' + item['path'])
38
- writer(str(idx), codes.cpu())
39
- except:
40
- print(item['path'])
41
- continue
42
- # idx = item['idx']
43
- # # print(idx)
44
- # with torch.autocast(device_type="cuda", dtype=torch.float16):
45
- # codes = tango.file2code(item['path'])
46
- # writer(str(idx), codes.cpu())
codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4_ds.py DELETED
@@ -1,86 +0,0 @@
1
- import torch,torchaudio
2
- import os,sys,json
3
- from tqdm import tqdm
4
-
5
- #from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
6
- from generate_4rvq import Tango
7
- import kaldiio
8
- from kaldiio import WriteHelper
9
- import torch
10
- import subprocess
11
- import time
12
- import sys
13
-
14
- def get_gpu_memory():
15
- _output_to_list = lambda x: x.decode('ascii').split('\n')[:-1]
16
-
17
- ACCEPTABLE_AVAILABLE_MEMORY = 1024
18
- COMMAND = "nvidia-smi --query-gpu=memory.free --format=csv"
19
- memory_free_info = _output_to_list(subprocess.check_output(COMMAND.split()))[1:]
20
- memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
21
- return memory_free_values
22
-
23
- if __name__ == "__main__":
24
- # Define Model
25
- json_path = sys.argv[1]
26
- outdir = sys.argv[2]
27
- ds = int(sys.argv[3])
28
-
29
- gpu_idx = int(os.environ['CUDA_VISIBLE_DEVICES'])
30
- while True:
31
- free_mem = get_gpu_memory()
32
- free_mem = free_mem[gpu_idx]
33
- if(free_mem > 25_000):
34
- print("GPU memory {}, run matrix cal".format(free_mem))
35
- break
36
- else:
37
- print("GPU memory {}, sleep 1min".format(free_mem))
38
- time.sleep(60)
39
-
40
- mus_infos = []
41
- with open(json_path) as f:
42
- for line in f:
43
- item = json.loads(line)
44
- mus_infos.append(item)
45
-
46
- tango = Tango(model_path = './saved/model_4rvq/model_2_fixed.safetensors', rvq_num=4)
47
-
48
-
49
- # Feature extraction loop
50
- # for i in tqdm(range(2000)):
51
- with WriteHelper('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir), write_function="pickle") as writer:
52
- print('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir))
53
- bar = torch.zeros(4, 16384)
54
- for item_idx, item in tqdm(enumerate(mus_infos)):
55
- try:
56
- # if True:
57
- idx = item['idx']
58
- # print(idx)
59
- with torch.autocast(device_type="cuda", dtype=torch.float16):
60
- if(os.path.exists(item['path'])):
61
- codes = tango.file2code_ds(item['path'], ds)
62
- else:
63
- codes = tango.file2code_ds('/mnt/share/' + item['path'], ds)
64
- codes = codes.cpu()
65
- writer(str(idx), codes)
66
- for i0 in range(codes.shape[-1]):
67
- bar[0, codes[0, 0, i0]] += 1
68
- bar[1, codes[0, 1, i0]] += 1
69
- bar[2, codes[0, 2, i0]] += 1
70
- bar[3, codes[0, 3, i0]] += 1
71
- except Exception as e:
72
- print(item['path'])
73
- # print(e.message, e.args)
74
- # exit(1)
75
- continue
76
-
77
- if(item_idx % 1000 == 0):
78
- print("=========")
79
- print(1 - (bar[0]==0).sum() / bar.shape[-1])
80
- print("=========")
81
-
82
- # idx = item['idx']
83
- # # print(idx)
84
- # with torch.autocast(device_type="cuda", dtype=torch.float16):
85
- # codes = tango.file2code(item['path'])
86
- # writer(str(idx), codes.cpu())
codeclm/tokenizer/Flow1dVAE/generate_1rvq.py CHANGED
@@ -8,7 +8,6 @@ import librosa
8
  import os
9
  import math
10
  import numpy as np
11
- from tools.get_1dvae_large import get_model
12
  import tools.torch_tools as torch_tools
13
  from safetensors.torch import load_file
14
 
@@ -24,9 +23,9 @@ class Tango:
24
  scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
25
  self.device = device
26
 
27
- self.vae = get_model(vae_config, vae_model)
28
- self.vae = self.vae.to(device)
29
- self.vae=self.vae.eval()
30
  self.layer_num = layer_num
31
 
32
  self.MAX_DURATION = 360
@@ -254,37 +253,9 @@ class Tango:
254
  # print(fname, wave.shape)
255
  return wave
256
 
257
- @torch.no_grad()
258
- def sound2sound_vae(self, sound, prompt=None, steps=50, disable_progress=False):
259
- min_samples = int(40 * 25) # 40ms per frame
260
- hop_samples = min_samples // 4 * 3
261
- ovlp_samples = min_samples - hop_samples
262
- dur = 20
263
-
264
- latent_list = []
265
- for i in range(0, sound.shape[-1], dur*48000):
266
- if(i+dur*2*48000 > sound.shape[-1]):
267
- latent = tango.vae.encode_audio(sound.cuda()[None,:,i:])
268
- break
269
- else:
270
- latent = tango.vae.encode_audio(sound.cuda()[None,:,i:i+dur*48000])
271
- latent_list.append(latent)
272
-
273
- output = None
274
- for i in range(len(latent_list)):
275
- print(i)
276
- latent = latent_list[i]
277
- cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
278
- if output is None:
279
- output = cur_output
280
- else:
281
- output = torch.cat([output, cur_output], -1)
282
- return output
283
-
284
  def to(self, device=None, dtype=None, non_blocking=False):
285
  if device is not None:
286
  self.device = device
287
  self.model.device = device
288
- self.vae = self.vae.to(device, dtype, non_blocking)
289
  self.model = self.model.to(device, dtype, non_blocking)
290
  return self
 
8
  import os
9
  import math
10
  import numpy as np
 
11
  import tools.torch_tools as torch_tools
12
  from safetensors.torch import load_file
13
 
 
23
  scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
24
  self.device = device
25
 
26
+ # self.vae = get_model(vae_config, vae_model)
27
+ # self.vae = self.vae.to(device)
28
+ # self.vae=self.vae.eval()
29
  self.layer_num = layer_num
30
 
31
  self.MAX_DURATION = 360
 
253
  # print(fname, wave.shape)
254
  return wave
255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  def to(self, device=None, dtype=None, non_blocking=False):
257
  if device is not None:
258
  self.device = device
259
  self.model.device = device
 
260
  self.model = self.model.to(device, dtype, non_blocking)
261
  return self
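In generate_1rvq.py the Tango class no longer constructs the 1-D VAE at init (the `get_model` import and the `self.vae` setup are commented out) and the `to()` method no longer moves a VAE, so this module now only manages the token model. A hedged usage sketch, assuming `Tango.sound2code` keeps the same interface as in the (now deleted) 2-RVQ/4-RVQ variants; the constructor arguments and paths below are illustrative:

```python
# Hypothetical usage after this change (constructor arguments are illustrative;
# generate_1rvq.Tango keeps whatever signature the repository defines).
# Assumes sound2code() matches the interface of the removed 2-RVQ/4-RVQ variants.
import torchaudio
from generate_1rvq import Tango

tango = Tango(model_path="./saved/model_1rvq/model.safetensors")  # illustrative args

wav, sr = torchaudio.load("song.flac")  # the related extractors assume 48 kHz stereo
codes = tango.sound2code(wav)           # token ids only; decoding now needs a VAE elsewhere
```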
codeclm/tokenizer/Flow1dVAE/generate_2rvq.py DELETED
@@ -1,293 +0,0 @@
1
- import json
2
- import torch
3
- from tqdm import tqdm
4
- from model_2rvq import PromptCondAudioDiffusion
5
- from diffusers import DDIMScheduler, DDPMScheduler
6
- import torchaudio
7
- import librosa
8
- import os
9
- import math
10
- import numpy as np
11
- # from tools.get_mulan import get_mulan
12
- from tools.get_1dvae_large import get_model
13
- import tools.torch_tools as torch_tools
14
- from safetensors.torch import load_file
15
- from audio import AudioFile
16
- import kaldiio
17
-
18
- class Tango:
19
- def __init__(self, \
20
- model_path, \
21
- layer_num=6, \
22
- rvq_num=1, \
23
- device="cuda:0"):
24
-
25
- self.sample_rate = 48000
26
- scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
27
- self.device = device
28
-
29
- self.vae = get_model()
30
- self.vae = self.vae.to(device)
31
- self.vae=self.vae.eval()
32
- self.layer_num = layer_num
33
-
34
- self.MAX_DURATION = 360
35
- main_config = {
36
- "num_channels":32,
37
- "unet_model_name":None,
38
- "unet_model_config_path":"configs/models/transformer2D_wocross_inch112_1x4_multi_large.json",
39
- "snr_gamma":None,
40
- }
41
- self.rvq_num = rvq_num
42
- # print("rvq_num: ", self.rvq_num)
43
- # exit()
44
- self.model = PromptCondAudioDiffusion(**main_config).to(device)
45
- if model_path.endswith(".safetensors"):
46
- main_weights = load_file(model_path)
47
- else:
48
- main_weights = torch.load(model_path, map_location=device)
49
- self.model.load_state_dict(main_weights, strict=False)
50
- print ("Successfully loaded checkpoint from:", model_path)
51
-
52
- self.model.eval()
53
- self.model.init_device_dtype(torch.device(device), torch.float32)
54
-
55
- # self.scheduler = DDIMScheduler.from_pretrained( \
56
- # scheduler_name, subfolder="scheduler")
57
- # self.scheduler = DDPMScheduler.from_pretrained( \
58
- # scheduler_name, subfolder="scheduler")
59
- print("Successfully loaded inference scheduler from {}".format(scheduler_name))
60
-
61
-
62
-
63
- @torch.no_grad()
64
- @torch.autocast(device_type="cuda", dtype=torch.float32)
65
- def sound2code(self, orig_samples, batch_size=8):
66
- if(orig_samples.ndim == 2):
67
- audios = orig_samples.unsqueeze(0).to(self.device)
68
- elif(orig_samples.ndim == 3):
69
- audios = orig_samples.to(self.device)
70
- else:
71
- assert orig_samples.ndim in (2,3), orig_samples.shape
72
- audios = self.preprocess_audio(audios)
73
- audios = audios.squeeze(0)
74
- orig_length = audios.shape[-1]
75
- min_samples = int(40 * self.sample_rate)
76
- # 40 seconds corresponds to 10 tokens
77
- output_len = int(orig_length / float(self.sample_rate) * 25) + 1
78
- # print("output_len: ", output_len)
79
-
80
- while(audios.shape[-1] < min_samples):
81
- audios = torch.cat([audios, audios], -1)
82
- int_max_len=audios.shape[-1]//min_samples+1
83
- audios = torch.cat([audios, audios], -1)
84
- audios=audios[:,:int(int_max_len*(min_samples))]
85
- codes_list=[]
86
-
87
- audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
88
-
89
- for audio_inx in range(0, audio_input.shape[0], batch_size):
90
- # import pdb; pdb.set_trace()
91
- codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num)
92
- # print("codes",codes[0].shape)
93
-
94
- codes_list.append(torch.cat(codes, 1))
95
- # print("codes_list",codes_list[0].shape)
96
-
97
- codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T
98
- codes=codes[:,:,:output_len]
99
-
100
- return codes
101
-
102
- @torch.no_grad()
103
- @torch.autocast(device_type="cuda", dtype=torch.float32)
104
- def sound2code_ds(self, orig_samples, ds, batch_size=8):
105
- if(orig_samples.ndim == 2):
106
- audios = orig_samples.unsqueeze(0).to(self.device)
107
- elif(orig_samples.ndim == 3):
108
- audios = orig_samples.to(self.device)
109
- else:
110
- assert orig_samples.ndim in (2,3), orig_samples.shape
111
- audios = self.preprocess_audio(audios)
112
- audios = audios.squeeze(0)
113
- orig_length = audios.shape[-1]
114
- min_samples = int(40 * self.sample_rate)
115
- # 40 seconds corresponds to 10 tokens
116
- output_len = int(orig_length / float(self.sample_rate) * 25) + 1
117
- # print("output_len: ", output_len)
118
-
119
- while(audios.shape[-1] < min_samples):
120
- audios = torch.cat([audios, audios], -1)
121
- int_max_len=audios.shape[-1]//min_samples+1
122
- audios = torch.cat([audios, audios], -1)
123
- audios=audios[:,:int(int_max_len*(min_samples))]
124
- codes_list=[]
125
-
126
- audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
127
-
128
- for audio_inx in range(0, audio_input.shape[0], batch_size):
129
- # import pdb; pdb.set_trace()
130
- codes, _, spk_embeds = self.model.fetch_codes_batch_ds((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num, ds=ds)
131
- # print("codes",codes[0].shape)
132
-
133
- codes_list.append(torch.cat(codes, 1))
134
- # print("codes_list",codes_list[0].shape)
135
-
136
- codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T
137
- codes=codes[:,:,:output_len]
138
-
139
- return codes
140
-
141
- @torch.no_grad()
142
- def code2sound(self, codes, prompt=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False):
143
- codes = codes.to(self.device)
144
-
145
- min_samples = duration * 25 # 40ms per frame
146
- hop_samples = min_samples // 4 * 3
147
- ovlp_samples = min_samples - hop_samples
148
- hop_frames = hop_samples
149
- ovlp_frames = ovlp_samples
150
- first_latent = torch.randn(codes.shape[0], min_samples, 64).to(self.device)
151
- first_latent_length = 0
152
- first_latent_codes_length = 0
153
-
154
- if(isinstance(prompt, torch.Tensor)):
155
- # prepare prompt
156
- prompt = prompt.to(self.device)
157
- if(prompt.ndim == 3):
158
- assert prompt.shape[0] == 1, prompt.shape
159
- prompt = prompt[0]
160
- elif(prompt.ndim == 1):
161
- prompt = prompt.unsqueeze(0).repeat(2,1)
162
- elif(prompt.ndim == 2):
163
- if(prompt.shape[0] == 1):
164
- prompt = prompt.repeat(2,1)
165
-
166
- if(prompt.shape[-1] < int(30 * self.sample_rate)):
167
- # if less than 30s, just choose the first 10s
168
- prompt = prompt[:,:int(10*self.sample_rate)] # limit max length to 10.24
169
- else:
170
- # else choose from 20.48s which might includes verse or chorus
171
- prompt = prompt[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24
172
-
173
- true_latent = self.vae.encode_audio(prompt).permute(0,2,1)
174
- # print("true_latent.shape", true_latent.shape)
175
- # print("first_latent.shape", first_latent.shape)
176
- #true_latent.shape torch.Size([1, 250, 64])
177
- # first_latent.shape torch.Size([1, 1000, 64])
178
-
179
- first_latent[:,0:true_latent.shape[1],:] = true_latent
180
- first_latent_length = true_latent.shape[1]
181
- first_latent_codes = self.sound2code(prompt)
182
- first_latent_codes_length = first_latent_codes.shape[-1]
183
- codes = torch.cat([first_latent_codes, codes], -1)
184
-
185
- codes_len= codes.shape[-1]
186
- target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)
187
- # target_len = int(codes_len / 100 * 4 * self.sample_rate)
188
- # code repeat
189
- if(codes_len < min_samples):
190
- while(codes.shape[-1] < min_samples):
191
- codes = torch.cat([codes, codes], -1)
192
- codes = codes[:,:,0:min_samples]
193
- codes_len = codes.shape[-1]
194
- if((codes_len - ovlp_samples) % hop_samples > 0):
195
- len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples
196
- while(codes.shape[-1] < len_codes):
197
- codes = torch.cat([codes, codes], -1)
198
- codes = codes[:,:,0:len_codes]
199
- latent_length = min_samples
200
- latent_list = []
201
- spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device)
202
- with torch.autocast(device_type="cuda", dtype=torch.float16):
203
- for sinx in range(0, codes.shape[-1]-hop_samples, hop_samples):
204
- codes_input=[]
205
- codes_input.append(codes[:,:,sinx:sinx+min_samples])
206
- if(sinx == 0):
207
- # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
208
- incontext_length = first_latent_length
209
- latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
210
- latent_list.append(latents)
211
- else:
212
- # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
213
- true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1)
214
- print("true_latent.shape", true_latent.shape)
215
- len_add_to_1000 = 1000 - true_latent.shape[-2]
216
- # print("len_add_to_1000", len_add_to_1000)
217
- # exit()
218
- incontext_length = true_latent.shape[-2]
219
- true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2)
220
- latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
221
- latent_list.append(latents)
222
-
223
- latent_list = [l.float() for l in latent_list]
224
- latent_list[0] = latent_list[0][:,:,first_latent_length:]
225
- min_samples = int(min_samples * self.sample_rate // 1000 * 40)
226
- hop_samples = int(hop_samples * self.sample_rate // 1000 * 40)
227
- ovlp_samples = min_samples - hop_samples
228
- with torch.no_grad():
229
- output = None
230
- for i in range(len(latent_list)):
231
- latent = latent_list[i]
232
- cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
233
-
234
- if output is None:
235
- output = cur_output
236
- else:
237
- ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
238
- ov_win = torch.cat([ov_win, 1 - ov_win], -1)
239
- print("output.shape", output.shape)
240
- print("ov_win.shape", ov_win.shape)
241
- output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
242
- output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
243
- output = output[:, 0:target_len]
244
- return output
245
-
246
- @torch.no_grad()
247
- def preprocess_audio(self, input_audios, threshold=0.8):
248
- assert len(input_audios.shape) == 3, input_audios.shape
249
- nchan = input_audios.shape[1]
250
- input_audios = input_audios.reshape(input_audios.shape[0], -1)
251
- norm_value = torch.ones_like(input_audios[:,0])
252
- max_volume = input_audios.abs().max(dim=-1)[0]
253
- norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
254
- return input_audios.reshape(input_audios.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1)
255
-
256
- @torch.no_grad()
257
- def sound2sound(self, sound, prompt=None, steps=50, disable_progress=False):
258
- codes = self.sound2code(sound)
259
- # print(codes.shape)
260
- # exit()
261
- wave = self.code2sound(codes, prompt, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
262
- # print(fname, wave.shape)
263
- return wave
264
-
265
- def file2code(self, fname):
266
- try:
267
- orig_samples, fs = torchaudio.load(fname)
268
- except:
269
- af = AudioFile(fname)
270
- orig_samples = af.read()
271
- fs = af.samplerate()
272
- orig_samples = orig_samples[0]
273
- if(fs!=self.sample_rate):
274
- orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
275
- fs = self.sample_rate
276
- if orig_samples.shape[0] == 1:
277
- orig_samples = torch.cat([orig_samples, orig_samples], 0)
278
- return self.sound2code(orig_samples)
279
-
280
- def file2code_ds(self, fname, ds):
281
- try:
282
- orig_samples, fs = torchaudio.load(fname)
283
- except:
284
- af = AudioFile(fname)
285
- orig_samples = af.read()
286
- fs = af.samplerate()
287
- orig_samples = orig_samples[0]
288
- if(fs!=self.sample_rate):
289
- orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
290
- fs = self.sample_rate
291
- if orig_samples.shape[0] == 1:
292
- orig_samples = torch.cat([orig_samples, orig_samples], 0)
293
- return self.sound2code_ds(orig_samples, ds)
codeclm/tokenizer/Flow1dVAE/generate_4rvq.py DELETED
@@ -1,292 +0,0 @@
1
- import json
2
- import torch
3
- from tqdm import tqdm
4
- from model_4rvq import PromptCondAudioDiffusion
5
- from diffusers import DDIMScheduler, DDPMScheduler
6
- import torchaudio
7
- import librosa
8
- import os
9
- import math
10
- import numpy as np
11
- # from tools.get_mulan import get_mulan
12
- from tools.get_1dvae_large import get_model
13
- import tools.torch_tools as torch_tools
14
- from safetensors.torch import load_file
15
- from audio import AudioFile
16
-
17
- class Tango:
18
- def __init__(self, \
19
- model_path, \
20
- layer_num=6, \
21
- rvq_num=1, \
22
- device="cuda:0"):
23
-
24
- self.sample_rate = 48000
25
- scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
26
- self.device = device
27
-
28
- self.vae = get_model()
29
- self.vae = self.vae.to(device)
30
- self.vae=self.vae.eval()
31
- self.layer_num = layer_num
32
-
33
- self.MAX_DURATION = 360
34
- main_config = {
35
- "num_channels":32,
36
- "unet_model_name":None,
37
- "unet_model_config_path":"configs/models/transformer2D_wocross_inch112_1x4_multi_large.json",
38
- "snr_gamma":None,
39
- }
40
- self.rvq_num = rvq_num
41
- # print("rvq_num: ", self.rvq_num)
42
- # exit()
43
- self.model = PromptCondAudioDiffusion(**main_config).to(device)
44
- if model_path.endswith(".safetensors"):
45
- main_weights = load_file(model_path)
46
- else:
47
- main_weights = torch.load(model_path, map_location=device)
48
- self.model.load_state_dict(main_weights, strict=False)
49
- print ("Successfully loaded checkpoint from:", model_path)
50
-
51
- self.model.eval()
52
- self.model.init_device_dtype(torch.device(device), torch.float32)
53
-
54
- # self.scheduler = DDIMScheduler.from_pretrained( \
55
- # scheduler_name, subfolder="scheduler")
56
- # self.scheduler = DDPMScheduler.from_pretrained( \
57
- # scheduler_name, subfolder="scheduler")
58
- print("Successfully loaded inference scheduler from {}".format(scheduler_name))
59
-
60
-
61
-
62
- @torch.no_grad()
63
- @torch.autocast(device_type="cuda", dtype=torch.float32)
64
- def sound2code(self, orig_samples, batch_size=8):
65
- if(orig_samples.ndim == 2):
66
- audios = orig_samples.unsqueeze(0).to(self.device)
67
- elif(orig_samples.ndim == 3):
68
- audios = orig_samples.to(self.device)
69
- else:
70
- assert orig_samples.ndim in (2,3), orig_samples.shape
71
- audios = self.preprocess_audio(audios)
72
- audios = audios.squeeze(0)
73
- orig_length = audios.shape[-1]
74
- min_samples = int(40 * self.sample_rate)
75
- # 40 seconds corresponds to 10 tokens
76
- output_len = int(orig_length / float(self.sample_rate) * 25) + 1
77
- # print("output_len: ", output_len)
78
-
79
- while(audios.shape[-1] < min_samples):
80
- audios = torch.cat([audios, audios], -1)
81
- int_max_len=audios.shape[-1]//min_samples+1
82
- audios = torch.cat([audios, audios], -1)
83
- audios=audios[:,:int(int_max_len*(min_samples))]
84
- codes_list=[]
85
-
86
- audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
87
-
88
- for audio_inx in range(0, audio_input.shape[0], batch_size):
89
- # import pdb; pdb.set_trace()
90
- codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num)
91
- # print("codes",codes[0].shape)
92
-
93
- codes_list.append(torch.cat(codes, 1))
94
- # print("codes_list",codes_list[0].shape)
95
-
96
- codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T
97
- codes=codes[:,:,:output_len]
98
-
99
- return codes
100
-
101
- @torch.no_grad()
102
- @torch.autocast(device_type="cuda", dtype=torch.float32)
103
- def sound2code_ds(self, orig_samples, ds, batch_size=6):
104
- if(orig_samples.ndim == 2):
105
- audios = orig_samples.unsqueeze(0).to(self.device)
106
- elif(orig_samples.ndim == 3):
107
- audios = orig_samples.to(self.device)
108
- else:
109
- assert orig_samples.ndim in (2,3), orig_samples.shape
110
- audios = self.preprocess_audio(audios)
111
- audios = audios.squeeze(0)
112
- orig_length = audios.shape[-1]
113
- min_samples = int(40 * self.sample_rate)
114
- # 40 seconds corresponds to 10 tokens
115
- output_len = int(orig_length / float(self.sample_rate) * 25) + 1
116
- # print("output_len: ", output_len)
117
-
118
- while(audios.shape[-1] < min_samples):
119
- audios = torch.cat([audios, audios], -1)
120
- int_max_len=audios.shape[-1]//min_samples+1
121
- audios = torch.cat([audios, audios], -1)
122
- audios=audios[:,:int(int_max_len*(min_samples))]
123
- codes_list=[]
124
-
125
- audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
126
-
127
- for audio_inx in range(0, audio_input.shape[0], batch_size):
128
- # import pdb; pdb.set_trace()
129
- codes, _, spk_embeds = self.model.fetch_codes_batch_ds((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num, ds=ds)
130
- # print("codes",codes[0].shape)
131
-
132
- codes_list.append(torch.cat(codes, 1))
133
- # print("codes_list",codes_list[0].shape)
134
-
135
- codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T
136
- codes=codes[:,:,:output_len]
137
-
138
- return codes
139
-
140
- @torch.no_grad()
141
- def code2sound(self, codes, prompt=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False):
142
- codes = codes.to(self.device)
143
-
144
- min_samples = duration * 25 # 40ms per frame
145
- hop_samples = min_samples // 4 * 3
146
- ovlp_samples = min_samples - hop_samples
147
- hop_frames = hop_samples
148
- ovlp_frames = ovlp_samples
149
- first_latent = torch.randn(codes.shape[0], min_samples, 64).to(self.device)
150
- first_latent_length = 0
151
- first_latent_codes_length = 0
152
-
153
- if(isinstance(prompt, torch.Tensor)):
154
- # prepare prompt
155
- prompt = prompt.to(self.device)
156
- if(prompt.ndim == 3):
157
- assert prompt.shape[0] == 1, prompt.shape
158
- prompt = prompt[0]
159
- elif(prompt.ndim == 1):
160
- prompt = prompt.unsqueeze(0).repeat(2,1)
161
- elif(prompt.ndim == 2):
162
- if(prompt.shape[0] == 1):
163
- prompt = prompt.repeat(2,1)
164
-
165
- if(prompt.shape[-1] < int(30 * self.sample_rate)):
166
- # if less than 30s, just choose the first 10s
167
- prompt = prompt[:,:int(10*self.sample_rate)] # limit max length to 10.24
168
- else:
169
- # else choose from 20.48s which might includes verse or chorus
170
- prompt = prompt[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24
171
-
172
- true_latent = self.vae.encode_audio(prompt).permute(0,2,1)
173
- # print("true_latent.shape", true_latent.shape)
174
- # print("first_latent.shape", first_latent.shape)
175
- #true_latent.shape torch.Size([1, 250, 64])
176
- # first_latent.shape torch.Size([1, 1000, 64])
177
-
178
- first_latent[:,0:true_latent.shape[1],:] = true_latent
179
- first_latent_length = true_latent.shape[1]
180
- first_latent_codes = self.sound2code(prompt)
181
- first_latent_codes_length = first_latent_codes.shape[-1]
182
- codes = torch.cat([first_latent_codes, codes], -1)
183
-
184
- codes_len= codes.shape[-1]
185
- target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)
186
- # target_len = int(codes_len / 100 * 4 * self.sample_rate)
187
- # code repeat
188
- if(codes_len < min_samples):
189
- while(codes.shape[-1] < min_samples):
190
- codes = torch.cat([codes, codes], -1)
191
- codes = codes[:,:,0:min_samples]
192
- codes_len = codes.shape[-1]
193
- if((codes_len - ovlp_samples) % hop_samples > 0):
194
- len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples
195
- while(codes.shape[-1] < len_codes):
196
- codes = torch.cat([codes, codes], -1)
197
- codes = codes[:,:,0:len_codes]
198
- latent_length = min_samples
199
- latent_list = []
200
- spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device)
201
- with torch.autocast(device_type="cuda", dtype=torch.float16):
202
- for sinx in range(0, codes.shape[-1]-hop_samples, hop_samples):
203
- codes_input=[]
204
- codes_input.append(codes[:,:,sinx:sinx+min_samples])
205
- if(sinx == 0):
206
- # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
207
- incontext_length = first_latent_length
208
- latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
209
- latent_list.append(latents)
210
- else:
211
- # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
212
- true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1)
213
- print("true_latent.shape", true_latent.shape)
214
- len_add_to_1000 = 1000 - true_latent.shape[-2]
215
- # print("len_add_to_1000", len_add_to_1000)
216
- # exit()
217
- incontext_length = true_latent.shape[-2]
218
- true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2)
219
- latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
220
- latent_list.append(latents)
221
-
222
- latent_list = [l.float() for l in latent_list]
223
- latent_list[0] = latent_list[0][:,:,first_latent_length:]
224
- min_samples = int(min_samples * self.sample_rate // 1000 * 40)
225
- hop_samples = int(hop_samples * self.sample_rate // 1000 * 40)
226
- ovlp_samples = min_samples - hop_samples
227
- with torch.no_grad():
228
- output = None
229
- for i in range(len(latent_list)):
230
- latent = latent_list[i]
231
- cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
232
-
233
- if output is None:
234
- output = cur_output
235
- else:
236
- ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
237
- ov_win = torch.cat([ov_win, 1 - ov_win], -1)
238
- print("output.shape", output.shape)
239
- print("ov_win.shape", ov_win.shape)
240
- output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
241
- output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
242
- output = output[:, 0:target_len]
243
- return output
244
-
245
- @torch.no_grad()
246
- def preprocess_audio(self, input_audios, threshold=0.8):
247
- assert len(input_audios.shape) == 3, input_audios.shape
248
- nchan = input_audios.shape[1]
249
- input_audios = input_audios.reshape(input_audios.shape[0], -1)
250
- norm_value = torch.ones_like(input_audios[:,0])
251
- max_volume = input_audios.abs().max(dim=-1)[0]
252
- norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
253
- return input_audios.reshape(input_audios.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1)
254
-
255
- @torch.no_grad()
256
- def sound2sound(self, sound, prompt=None, steps=50, disable_progress=False):
257
- codes = self.sound2code(sound)
258
- # print(codes.shape)
259
- # exit()
260
- wave = self.code2sound(codes, prompt, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
261
- # print(fname, wave.shape)
262
- return wave
263
-
264
- def file2code(self, fname):
265
- try:
266
- orig_samples, fs = torchaudio.load(fname)
267
- except:
268
- af = AudioFile(fname)
269
- orig_samples = af.read()
270
- fs = af.samplerate()
271
- orig_samples = orig_samples[0]
272
- if(fs!=self.sample_rate):
273
- orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
274
- fs = self.sample_rate
275
- if orig_samples.shape[0] == 1:
276
- orig_samples = torch.cat([orig_samples, orig_samples], 0)
277
- return self.sound2code(orig_samples)
278
-
279
- def file2code_ds(self, fname, ds):
280
- try:
281
- orig_samples, fs = torchaudio.load(fname)
282
- except:
283
- af = AudioFile(fname)
284
- orig_samples = af.read()
285
- fs = af.samplerate()
286
- orig_samples = orig_samples[0]
287
- if(fs!=self.sample_rate):
288
- orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
289
- fs = self.sample_rate
290
- if orig_samples.shape[0] == 1:
291
- orig_samples = torch.cat([orig_samples, orig_samples], 0)
292
- return self.sound2code_ds(orig_samples, ds)
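Note on the chunked decoding above: code2sound decodes each latent chunk separately and blends consecutive chunks with a linear crossfade over their overlapping samples before concatenating. The snippet below is a minimal, self-contained sketch of that blending step for reference only; the helper name crossfade_concat and the fixed linear window are illustrative assumptions, not part of this repository.

import torch

def crossfade_concat(segments, ovlp_samples):
    # segments: list of (channels, samples) tensors; consecutive segments are
    # assumed to share `ovlp_samples` samples of overlapping audio (assumption).
    fade_in = torch.linspace(0.0, 1.0, ovlp_samples)
    fade_out = 1.0 - fade_in
    output = segments[0]
    for seg in segments[1:]:
        # Fade out the tail of the accumulated audio, fade in the head of the
        # new segment, then append the non-overlapping remainder.
        blended = output[:, -ovlp_samples:] * fade_out + seg[:, :ovlp_samples] * fade_in
        output = torch.cat([output[:, :-ovlp_samples], blended, seg[:, ovlp_samples:]], dim=-1)
    return output

# Example: stitch two one-second stereo chunks that share a 0.1 s overlap.
# wave = crossfade_concat([torch.randn(2, 48000), torch.randn(2, 48000)], 4800)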
codeclm/tokenizer/Flow1dVAE/libs/datasets/MusicSoundMixedDataset.py DELETED
@@ -1,1278 +0,0 @@
1
- from torch.utils.data import Dataset
2
- from beartype.typing import Sequence, Callable, Optional, Dict, Tuple, List, Union
3
- from beartype import beartype
4
- from beartype.door import is_bearable
5
- import random
6
- import pandas as pd
7
- import os
8
- from torchaudio.functional import resample
9
- import torch
10
- import typing as tp
11
- from pathlib import Path
12
- import torchaudio as ta
13
- import torch.nn.functional as F
14
- import numpy as np
15
- import json
16
- import yaml
17
- import torchaudio
18
- import math
19
- import re
20
- from loguru import logger
21
- import ffmpeg
22
-
23
- class Read_and_PadCrop_Normalized_T(torch.nn.Module):
24
- def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
25
-
26
- super().__init__()
27
-
28
- self.n_samples = n_samples
29
- self.sample_rate = sample_rate
30
- self.randomize = randomize
31
-
32
- def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
33
- if self.n_samples < 0: # means do not clip
34
- chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
35
- t_start = 0.
36
- t_end = 1.0
37
- offset = 0
38
- else:
39
- if(duration<(float(self.n_samples)/self.sample_rate+1)):
40
- # print(duration,(float(self.n_samples)/self.sample_rate+1))
41
- chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
42
- t_start = 0.
43
- t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
44
- offset = 0
45
- # print('c1:',chunk.shape)
46
- else:
47
- offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
48
- t_start = offset / float(cur_sample_rate) / duration
49
- t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
50
- chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
51
- # print('offset:',offset)
52
- # print('c0:',chunk.shape)
53
- # Pad with silence if necessary.
54
- if(chunk.shape[0]>1):
55
- chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
56
- else:
57
- chunk = chunk[[0],:].float()
58
- if(cur_sample_rate!=self.sample_rate):
59
- # print('a:',cur_sample_rate,chunk.shape)
60
- chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
61
- # print('b:',self.sample_rate,chunk.shape)
62
-
63
- if self.n_samples > 0:
64
- if chunk.shape[-1] < self.n_samples:
65
- chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
66
- else:
67
- chunk = chunk[:,0:self.n_samples]
68
- seconds_start = math.floor(offset / cur_sample_rate)
69
- seconds_total = math.floor(duration)
70
-
71
- return (
72
- chunk,
73
- t_start,
74
- t_end,
75
- seconds_start,
76
- seconds_total
77
- )
78
-
79
- class Read_and_PadCrop_Normalized_T_Avoid_Watermark(torch.nn.Module):
80
- def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True, w_start = 0, w_interval = 11.3):
81
-
82
- super().__init__()
83
-
84
- self.n_samples = n_samples
85
- self.sample_rate = sample_rate
86
- self.randomize = randomize
87
-
88
- self.w_start = w_start
89
- self.w_interval = w_interval
90
-
91
- def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
92
- if self.n_samples < 0: # means do not clip
93
- chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
94
- t_start = 0.
95
- t_end = 1.0
96
- offset = 0
97
- else:
98
- if(duration<(float(self.n_samples)/self.sample_rate+1)):
99
- # print(duration,(float(self.n_samples)/self.sample_rate+1))
100
- chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
101
- t_start = 0.
102
- t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
103
- offset = 0
104
- # print('c1:',chunk.shape)
105
- else:
106
- n_offset_option = (duration - self.w_start) // self.w_interval
107
- if n_offset_option <= 1:
108
- offset = 0
109
- else:
110
- offset = int((random.randint(0,n_offset_option-1) * self.w_interval + self.w_start) * cur_sample_rate)
111
- # offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
112
- t_start = offset / float(cur_sample_rate) / duration
113
- t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
114
- chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
115
- # print('offset:',offset)
116
- # print('c0:',chunk.shape)
117
- # Pad with silence if necessary.
118
- if(chunk.shape[0]>1):
119
- chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
120
- else:
121
- chunk = chunk[[0],:].float()
122
- if(cur_sample_rate!=self.sample_rate):
123
- # print('a:',cur_sample_rate,chunk.shape)
124
- chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
125
- # print('b:',self.sample_rate,chunk.shape)
126
-
127
- if self.n_samples > 0:
128
- if chunk.shape[-1] < self.n_samples:
129
- chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
130
- else:
131
- chunk = chunk[:,0:self.n_samples]
132
- seconds_start = math.floor(offset / cur_sample_rate)
133
- seconds_total = math.floor(duration)
134
-
135
- return (
136
- chunk,
137
- t_start,
138
- t_end,
139
- seconds_start,
140
- seconds_total
141
- )
142
-
143
- USE_DUMMY_AUDIO = False # When testing the code, set this to True so that no real data is read and generated silent audio is used instead
144
- if USE_DUMMY_AUDIO:
145
- logger.warning("USE_DUMMY_AUDIO flag is True, don't use it when train or test!")
146
-
147
- class SafeAudioReader:
148
- """
149
- This class is an adaptor to Read_and_PadCrop_Normalized_T, making it safe to read audio data.
150
- """
151
- def __init__(self,
152
- duration: float, # duration of the returned audio
153
- sample_rate: int, # sample rate of the returned audio; if it differs from the file's actual sample rate, the audio is resampled
154
- randomize: bool = True,
155
- use_avoid_watermark_policy = False,
156
- ):
157
- self.n_samples = int(sample_rate * duration)
158
- self.reader = (
159
- Read_and_PadCrop_Normalized_T_Avoid_Watermark if use_avoid_watermark_policy \
160
- else Read_and_PadCrop_Normalized_T
161
- )(n_samples=self.n_samples, sample_rate=sample_rate, randomize=randomize)
162
-
163
- # NOTE: this is the core function; every dataset reads audio through this call!
164
- def __call__(self,
165
- filepath: os.PathLike, # path to the audio file
166
- origin_sample_rate: Optional[int] = None, # actual sample rate read from the json metadata; if not given, it is probed from the file header
167
- origin_duration: float = None, # actual duration read from the json metadata; if not given, it is probed from the file header
168
- ) -> torch.Tensor:
169
- if USE_DUMMY_AUDIO:
170
- wav = torch.zeros(self.n_samples, dtype=torch.float32)
171
- return wav
172
- try:
173
- if origin_sample_rate is None or origin_duration is None:
174
- # audio_info = torchaudio.info(filepath)
175
- # origin_sample_rate = audio_info.sample_rate
176
- # origin_duration = audio_info.num_frames / origin_sample_rate
177
- info = ffmpeg.probe(filepath)
178
- origin_duration = float(info['format']['duration'])
179
- origin_sample_rate = int(info['streams'][0]['sample_rate'])
180
- wav, *ignored = self.reader(filepath, origin_duration, origin_sample_rate)
181
- wav = wav.squeeze_(0)
182
- except Exception as e:
183
- logger.error(f"Error reading {filepath}: {e}")
184
- wav = torch.zeros(self.n_samples, dtype=torch.float32)
185
- return wav
186
-
187
-
188
- class PromptTemplate:
189
- def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'):
190
- self.template_text = template_text
191
- self.tag_map = tag_map
192
- self.lang = lang
193
-
194
- @property
195
- def tags(self):
196
- return tuple(self.tag_map.keys())
197
-
198
- def apply(self, **kwargs):
199
- for tag in list(kwargs.keys()):
200
- if kwargs[tag] == '':
201
- kwargs.pop(tag)
202
- for tag in self.tags:
203
- if tag in kwargs:
204
- kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]')
205
- else:
206
- kwargs[tag] = ''
207
- prompt = self.template_text.format(**kwargs)
208
-
209
- return self.beautify(prompt)
210
-
211
- def beautify(self, text):
212
- if self.lang == 'en':
213
- return self._beautify_en(text)
214
- elif self.lang == 'zh':
215
- return self._beautify_zh(text)
216
- else:
217
- raise ValueError(f'Unknown language {self.lang}')
218
-
219
- @staticmethod
220
- def _beautify_en(text):
221
- # no continuous commas without content between them
222
- text = re.sub(r'[,\s]*,[,\s]*', r', ', text)
223
- # no continuous whitespace
224
- text = re.sub(r'\s+', ' ', text)
225
- # the comma is NOT followed by whitespace, and should be followed by ONE whitespace
226
- text = re.sub(r'\s+,', r',', text)
227
- text = re.sub(r',\s+', r', ', text)
228
- # no whitespace before the full stop
229
- text = re.sub(r'\s+\.', r'.', text)
230
- # strip whitespace, comma, and replace ',.'
231
- text = text.strip(' ,')
232
- text = text.replace(',.', '.')
233
- return text
234
-
235
- @staticmethod
236
- def _beautify_zh(text):
237
- # no continuous commas without content between them
238
- text = re.sub(r'[,、\s]*,[,、\s]*', r',', text)
239
- text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text)
240
- # assume there should be NO whitespace in Chinese
241
- text = re.sub(r'\s+', r'', text)
242
- # strip whitespace, comma, and replace ',。'
243
- text = text.strip(', 、')
244
- text = text.replace(',。', '。')
245
- return text
246
-
247
- def __repr__(self):
248
- return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})'
249
-
250
- __str__ = __repr__
251
-
252
- def parse_prompt_template(prompt_template_text, lang='en'):
253
- span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL)
254
- tag_pattern = re.compile(r'{.+?}', re.DOTALL)
255
-
256
- template_text = prompt_template_text.strip()
257
- span_texts = span_pattern.findall(prompt_template_text)
258
- tag_map = {}
259
- for span_text in span_texts:
260
- tag = tag_pattern.findall(span_text)[0].strip('{}')
261
- tag_map[tag] = span_text
262
- template_text = template_text.replace(span_text, '{'+tag+'}')
263
-
264
- return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang)
265
-
266
- def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]:
267
- with open(path, 'r') as f:
268
- lines = f.readlines()
269
- cnt = 0
270
- pts = []
271
- for line in lines:
272
- pt = parse_prompt_template(line, lang=lang)
273
- cnt += 1
274
- if len(pt.tags) < num:
275
- logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}')
276
- pts.append(pt)
277
-
278
- return pts
279
-
280
-
281
- def get_base_dir_file(key: os.PathLike):
282
- base = os.path.basename(key)
283
- dirname = os.path.basename(os.path.dirname(key))
284
- return os.path.join(dirname, base)
285
-
286
- def read_jsonlike(path: os.PathLike):
287
- #json or jsonl
288
- if str(path).endswith(".json"):
289
- with open(path, 'r', encoding='utf8') as f:
290
- data = json.load(f)
291
- return data
292
- elif str(path).endswith(".jsonl"):
293
- with open(path, 'r', encoding='utf8') as f:
294
- data = [json.loads(line) for line in f.readlines()]
295
- return data
296
- else:
297
- raise ValueError("Unknown file format")
298
-
299
- dist_prob_map = {
300
- 1: (1.0,),
301
- 2: (0.5, 0.5),
302
- 3: (0.3, 0.4, 0.3),
303
- 4: (0.2, 0.3, 0.3, 0.2),
304
- 5: (0.2, 0.2, 0.3, 0.2, 0.1),
305
- 6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15),
306
- 7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
307
- 8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
308
- 9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
309
- 10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
310
- }
311
-
312
- '''
313
- # Alternative scheme more biased towards short prompts
314
- dist_prob_map = {
315
- 1: (1.0,),
316
- 2: (0.7, 0.3),
317
- 3: (0.7, 0.2, 0.1),
318
- 4: (0.6, 0.2, 0.1, 0.1),
319
- 5: (0.6, 0.2, 0.1, 0.05, 0.05),
320
- 6: (0.6, 0.15, 0.1, 0.05, 0.05, 0.05),
321
- 7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
322
- 8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
323
- 9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
324
- 10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
325
- }
326
- '''
327
-
328
- # Alternative scheme that uses all tags
329
- # dist_prob_map = {
330
- # 1: (1.0,),
331
- # 2: (0, 1.0),
332
- # 3: (0, 0, 1.0),
333
- # 4: (0, 0, 0, 1.0),
334
- # 5: (0, 0, 0, 0, 1.0),
335
- # 6: (0, 0, 0, 0, 0, 1.0),
336
- # 7: (0, 0, 0, 0, 0, 0, 1.0),
337
- # 8: (0, 0, 0, 0, 0, 0, 0, 1.0),
338
- # 9: (0, 0, 0, 0, 0, 0, 0, 0, 1.0),
339
- # 10: (0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0)
340
- # }
341
-
342
- dist_prob_map_low = {
343
- 1: (1.0,),
344
- 2: (0.8, 0.2),
345
- 3: (0.8, 0.1, 0.1),
346
- 4: (0.7, 0.1, 0.1, 0.1),
347
- 5: (0.7, 0.1, 0.1, 0.05, 0.05),
348
- 6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05),
349
- }
350
-
351
- _bpm_range_rights = (
352
- (40, '20-40'),
353
- (60, '40-60'),
354
- (66, '60-66'),
355
- (76, '66-76'),
356
- (108, '76-108'),
357
- (120, '108-120'),
358
- (168, '120-168'),
359
- (176, '168-176'),
360
- (200, '176-200')
361
- )
362
- _bpm_desc_map = {
363
- '20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"),
364
- '40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"),
365
- '60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'),
366
- '66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'),
367
- '76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"),
368
- '108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'),
369
- '120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'),
370
- '168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'),
371
- '176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'),
372
- '>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo')
373
- }
374
- _bpm_desc_map_zh = {
375
- '20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"),
376
- '40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"),
377
- '60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'),
378
- '66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'),
379
- '76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"),
380
- '108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'),
381
- '120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'),
382
- '168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'),
383
- '176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'),
384
- '>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板')
385
- }
386
- def get_bpm_range(bpm):
387
- bpm = int(bpm)
388
- for right, tag in _bpm_range_rights:
389
- if bpm <= right:
390
- return tag
391
- return '>200'
392
-
393
- def gen_bpm_descript(bpm, lang='en'):
394
- bpm_range = get_bpm_range(bpm)
395
- if lang == 'en':
396
- return random.choice(_bpm_desc_map[bpm_range])
397
- elif lang == 'zh':
398
- return random.choice(_bpm_desc_map_zh[bpm_range])
399
- else:
400
- raise ValueError(f"Unknown language {lang}")
401
-
402
- def read_translate(translate: Union[Dict[str, os.PathLike], os.PathLike, None]):
403
- if translate is None:
404
- return None
405
- if isinstance(translate, str):
406
- return read_jsonlike(translate)
407
- return {k: read_jsonlike(path) for k, path in translate.items()}
408
-
409
-
410
- def gen_plain_prompt(key_list, sep=', '):
411
- if len(key_list) == 0:
412
- return 'none'
413
-
414
- key_list = [k.strip() for k in key_list]
415
-
416
- if len(key_list) > 10:
417
- random.shuffle(key_list)
418
- key_list = key_list[:10]
419
-
420
- probs = dist_prob_map[len(key_list)]
421
-
422
- num_tags = random.choices(range(1, len(key_list)+1), probs, k=1)[0]
423
-
424
- random.shuffle(key_list)
425
- tags = key_list[:num_tags]
426
- tags_str = sep.join(tags)
427
- return tags_str
428
-
429
-
430
- class MagnaTagATuneDataset(Dataset):
431
- def __init__(self):
432
- pass
433
-
434
-
435
- def tags_to_desc(tag_list, sep=',') -> str:
436
- if not isinstance(tag_list, Sequence):
437
- return str(tag_list)
438
- if isinstance(tag_list, str):
439
- return tag_list
440
- if len(tag_list) <= 0:
441
- return ''
442
- elif len(tag_list) <= 5:
443
- probs = dist_prob_map[len(tag_list)]
444
- tags_num = random.choices(range(1, len(tag_list)+1), probs)[0]
445
- random.shuffle(tag_list)
446
- tag_list = tag_list[:tags_num]
447
- return sep.join(tag_list)
448
- else:
449
- probs = dist_prob_map[5]
450
- tags_num = random.choices(range(1, 6), probs)[0]
451
- random.shuffle(tag_list)
452
- tag_list = tag_list[:tags_num]
453
- return sep.join(tag_list)
454
-
455
- def get_sr_and_duration_info(item):
456
- return item.get('sample_rate', None), item.get('duration', None)
457
-
458
- class MtgJamendoDatasetFromJson(Dataset):
459
- def __init__(self,
460
- data_dir:str,
461
- json_path:str,
462
- duration:float=10,
463
- sr:int = 0,
464
- lang = 'en',
465
- plain_rate = 0,
466
- return_audio = True,
467
- return_path = False,
468
- prompt_template_path: os.PathLike = None,
469
- tag_types = [],
470
- translate:Optional[Dict[str, os.PathLike]] = None,
471
- use_literal_none = True,
472
- ):
473
- self.audio_reader = SafeAudioReader(duration, sr)
474
-
475
- self.data_dir = data_dir
476
- self._load_metadata_json(json_path)
477
- self.sr = sr
478
- self.duration = duration
479
- self.plain_rate = plain_rate
480
- self.return_audio = return_audio
481
- self.return_path = return_path
482
- self.use_literal_none = use_literal_none
483
- self.lang = lang
484
-
485
- self.use_dynamic_prompt = prompt_template_path is not None and plain_rate < 1.0
486
- if self.use_dynamic_prompt:
487
- self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types))
488
- self.tag_types = tag_types
489
-
490
- self.translate = read_translate(translate)
491
-
492
- # These tags are considered weakly semantic; prompts containing only these tags are avoided
493
- WEAK_TAG_LIST = ["title", "artist"]
494
-
495
- def _load_metadata_json(self, json_path):
496
- with open(json_path) as fp:
497
- self.data = json.load(fp)
498
-
499
- def convert_key_to_path(self, key):
500
- return os.path.join(self.data_dir, get_base_dir_file(key))
501
-
502
- def __len__(self):
503
- return len(self.data)
504
-
505
- def __getitem__(self, idx):
506
- item = self.data[idx]
507
- path = self.convert_key_to_path(item['key'])
508
- description = self.generate_description(item)
509
-
510
- if self.return_audio:
511
- sr, duration = get_sr_and_duration_info(item)
512
- audio = self.audio_reader(path, sr, duration)
513
- else:
514
- audio = None
515
-
516
- if self.return_path:
517
- return audio, description, path
518
- return audio, description
519
-
520
- def tags_to_desc(self, tag_list, tag_type) -> str:
521
- if self.lang == 'en':
522
- return tags_to_desc(tag_list)
523
- elif self.lang == 'zh':
524
- translator = self.translate[tag_type]
525
- translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
526
- return tags_to_desc(translated_tag_list, sep='、')
527
-
528
- def generate_description(self, item):
529
- if random.random() > self.plain_rate:
530
- # dynamically generate prompt from given prompt template
531
- prompt_template = random.choice(self.prompt_templates)
532
- description = self.generate_description_dynamic(item, prompt_template)
533
- else:
534
- # use plain prompt, i.e. tags sequence separated by comma
535
- description = self.generate_description_plain(item)
536
- return description
537
-
538
- def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
539
- exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
540
- exists_weak_tag = list(filter(lambda t: t in self.WEAK_TAG_LIST, exists_tag))
541
- exists_strong_tag = list(filter(lambda t: t not in self.WEAK_TAG_LIST, exists_tag))
542
-
543
- if len(exists_strong_tag) > 0:
544
- probs = dist_prob_map[len(exists_strong_tag)]
545
- tags_num = random.choices(range(1, len(exists_strong_tag)+1), probs)[0]
546
- random.shuffle(exists_strong_tag)
547
- tags = exists_strong_tag[:tags_num]
548
- weak_probs = dist_prob_map_low[len(exists_weak_tag) + 1]
549
- weak_tags_num = random.choices(range(0, len(exists_weak_tag) + 1), weak_probs)[0]
550
- random.shuffle(exists_weak_tag)
551
- weak_tags = exists_weak_tag[:weak_tags_num]
552
- tags += weak_tags
553
- tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
554
- prompt = prompt_template.apply(**tags_args)
555
- else:
556
- # no strong tags, use all weak tags instead
557
- tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in exists_weak_tag}
558
- prompt = prompt_template.apply(**tags_args)
559
-
560
- if self.use_literal_none and len(tags_args) == 0:
561
- return 'none'
562
-
563
- return prompt
564
-
565
- def generate_description_plain(self, item):
566
- keywords = []
567
- for tag_t in self.tag_types:
568
- this_key = item[tag_t]
569
- if this_key is None:
570
- continue
571
- if isinstance(this_key, str):
572
- this_key = [this_key]
573
- if self.lang != 'en':
574
- this_key = [self.get_translation(tag_t, k) for k in this_key]
575
- keywords += this_key
576
- return gen_plain_prompt(keywords, sep=self.keysep)
577
-
578
- def get_translation(self, tag_t, k):
579
- k = k.strip()
580
- if k in self.translate[tag_t]:
581
- return self.translate[tag_t][k]
582
- else:
583
- return k
584
-
585
- @property
586
- def keysep(self):
587
- if self.lang == 'zh':
588
- return ',' if random.random() > 0.5 else '、'
589
- elif self.lang == 'en':
590
- return ', '
591
-
592
- class AudioStockDataset(Dataset):
593
- def __init__(self,
594
- metadata_path:str,
595
- duration:float=10,
596
- sr:int = 0,
597
- plain_rate = 0,
598
- return_path = False,
599
- return_audio = True,
600
- prompt_template_path: os.PathLike = None,
601
- tag_types = [],
602
- lang = 'en',
603
- translate:Optional[Dict[str, os.PathLike]] = None,
604
- use_literal_none = True,
605
- ):
606
- self.audio_reader = SafeAudioReader(duration, sr)
607
-
608
- self._load_metadata(metadata_path)
609
- self.sr = sr
610
- self.duration = duration
611
- self.plain_rate = plain_rate
612
- self.return_path = return_path
613
- self.return_audio = return_audio
614
- self.use_literal_none = use_literal_none
615
-
616
- self.use_dynamic_prompt = prompt_template_path is not None and plain_rate < 1.0
617
- if self.use_dynamic_prompt:
618
- self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang)
619
- self.tag_types = tag_types
620
-
621
- self.lang = lang
622
- self.translate = read_translate(translate)
623
-
624
- def _load_metadata(self, metadata_path):
625
- with open(metadata_path) as fp:
626
- lines = fp.readlines()
627
- self.data = []
628
- for line in lines:
629
- item = json.loads(line)
630
- self.data.append(item)
631
- self.is_info_recorded = bool('Tags' in self.data[0])
632
-
633
- def __len__(self):
634
- return len(self.data)
635
-
636
- def __getitem__(self, idx):
637
- path:str = self.data[idx]["path"]
638
- json_path = path[:path.rfind('.')] + ".json"
639
- if self.is_info_recorded:
640
- item = self.data[idx]
641
- else:
642
- try:
643
- with open(json_path) as fp:
644
- item:dict = json.load(fp)
645
- except Exception as e:
646
- print(f"Error loading json file {json_path} :\n{e}")
647
- item = {}
648
- description = self.generate_description(item)
649
- if self.return_audio:
650
- sr, duration = get_sr_and_duration_info(item)
651
- audio = self.audio_reader(path, sr, duration)
652
- else:
653
- audio = None
654
- if self.return_path:
655
- return audio, description, path
656
- return audio, description
657
-
658
- def generate_description(self, item):
659
- if random.random() > self.plain_rate:
660
- # dynamically generate prompt from given prompt template
661
- prompt_template = random.choice(self.prompt_templates)
662
- description = self.generate_description_dynamic(item, prompt_template)
663
- else:
664
- # use plain prompt, i.e. tags sequence separated by comma
665
- description = self.generate_description_plain(item)
666
- return description
667
-
668
- def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
669
- exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
670
-
671
- if len(exists_tag) > 0:
672
- probs = dist_prob_map[len(exists_tag)]
673
- tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0]
674
- random.shuffle(exists_tag)
675
- tags = exists_tag[:tags_num]
676
- tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
677
- tags_args = self.handle_BPM_tag(tags_args)
678
- prompt = prompt_template.apply(**tags_args)
679
- else:
680
- return 'none'
681
-
682
- if self.use_literal_none and len(tags_args) == 0:
683
- return 'none'
684
-
685
- return prompt
686
-
687
- def get_translation(self, tag_t, k):
688
- k = k.strip()
689
- if k in self.translate[tag_t]:
690
- return self.translate[tag_t][k]
691
- else:
692
- return k
693
-
694
- def generate_description_plain(self, item):
695
- keywords = []
696
- for tag_t in self.tag_types:
697
- if tag_t == 'BPMDescript':
698
- bpm = item['BPM']
699
- if bpm is None or bpm.strip() == '' or bpm.strip() == '0':
700
- continue
701
- this_key = gen_bpm_descript(bpm.strip(), lang=self.lang)
702
- elif tag_t == 'BPM':
703
- bpm = item['BPM']
704
- if bpm is None or bpm.strip() == '' or bpm.strip() == '0':
705
- continue
706
- this_key = f"{bpm.strip()} bpm"
707
- else:
708
- this_key = item[tag_t]
709
- if this_key is None:
710
- continue
711
- if isinstance(this_key, str):
712
- this_key = [this_key]
713
- if self.lang != 'en':
714
- this_key = [self.get_translation(tag_t, k) for k in this_key]
715
- if this_key is None:
716
- continue
717
- if isinstance(this_key, str):
718
- this_key = [this_key]
719
- keywords += this_key
720
- return gen_plain_prompt(keywords, sep=self.keysep)
721
-
722
- @property
723
- def keysep(self):
724
- if self.lang == 'zh':
725
- return ',' if random.random() > 0.5 else '、'
726
- elif self.lang == 'en':
727
- return ', '
728
-
729
- def tags_to_desc(self, tag_list, tag_type) -> str:
730
- if self.lang == 'en':
731
- return tags_to_desc(tag_list)
732
- elif self.lang == 'zh':
733
- if tag_type == 'BPM':
734
- return tags_to_desc(tag_list, sep='、')
735
- translator = self.translate[tag_type]
736
- translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
737
- return tags_to_desc(translated_tag_list, sep='、')
738
-
739
- def handle_BPM_tag(self, tags_args):
740
- if "BPM" in tags_args and 'BPMDescript' in self.tag_types:
741
- bpm = tags_args["BPM"]
742
- del tags_args["BPM"]
743
- tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript')))
744
- for tag_type in tag_types_used:
745
- tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang)
746
- return tags_args
747
-
748
- def mp3_path_to_id(mp3_path):
749
- return int(
750
- mp3_path[mp3_path.rindex('/') + 1 : mp3_path.rindex('.')]
751
- )
752
-
753
- class TmeDataset(Dataset):
754
- def __init__(self,
755
- data_index:str,
756
- music_info:str = None,
757
- duration:float = 10,
758
- sr:int = 0,
759
- plain_rate = 0,
760
- return_path = False,
761
- return_audio = True,
762
- return_ID = False,
763
- prompt_format_path: os.PathLike = None,
764
- tag_types = ['*'],
765
- lang = 'zh',
766
- translate: Optional[os.PathLike] = None,
767
- prompt_dir: os.PathLike = None, # use pre-generated prompts produced by GPT
768
- ):
769
- if plain_rate > 0:
770
- print("Tme Dataset do not support plain rate > 0, use plain_rate = 0 instead.")
771
- plain_rate = 0
772
- self.audio_reader = SafeAudioReader(duration, sr)
773
-
774
- self.sr = sr
775
- self.duration = duration
776
- self.plain_rate = plain_rate
777
- self.return_path = return_path
778
- self.return_audio = return_audio
779
- self.return_ID = return_ID
780
- self.lang = lang
781
-
782
- self.use_ready_prompt = prompt_dir is not None
783
-
784
- data_index = read_jsonlike(data_index)
785
- self.data_index_dict = {mp3_path_to_id(d['path']) : d for d in data_index}
786
- self.data_ids = list(self.data_index_dict.keys())
787
-
788
- if not self.use_ready_prompt:
789
- # Read the music metadata file
790
- music_info = read_jsonlike(music_info)
791
- if 'music' in music_info:
792
- music_info = music_info['music']
793
- self.music_info_dict = {d["歌曲ID"]:d for d in music_info}
794
- self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.music_info_dict}
795
- self.data_ids = list(self.data_index_dict.keys())
796
-
797
- with open(prompt_format_path) as fp:
798
- self.prompt_formats = yaml.load(fp, Loader=yaml.FullLoader)
799
-
800
- # Load tag types and split them into regular tag_types and the key key_tag_types
801
- if '*' in tag_types:
802
- self.tag_types = ['歌曲名', 'bpm', '专辑名', '歌手名', '作曲', 'tag']
803
- else:
804
- self.tag_types = tag_types
805
-
806
- self.key_tag_types = []
807
- if 'tag' in self.tag_types:
808
- self.tag_types.remove('tag')
809
- self.key_tag_types = list(self.prompt_formats['tag'].keys())
810
-
811
- # Load the translation table
812
- if translate is not None:
813
- self.translator = read_jsonlike(translate)
814
- else:
815
- data_ids_set = set(self.data_ids)
816
- self.prompts_dict = {}
817
- for fname in os.listdir(prompt_dir):
818
- items = read_jsonlike(os.path.join(prompt_dir, fname))
819
- for item in items:
820
- if item['ID'] not in data_ids_set or not self.is_valid_prompt_text(item['Text']):
821
- continue
822
- if item['ID'] not in self.prompts_dict:
823
- self.prompts_dict[item['ID']] = []
824
- self.prompts_dict[item['ID']].append(item['Text'])
825
- self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.prompts_dict}
826
- self.data_ids = list(self.data_index_dict.keys())
827
-
828
- def tags_to_desc(self, tag_list) -> str:
829
- if is_bearable(tag_list, int):
830
- return str(tag_list)
831
- if self.lang == 'zh':
832
- return tags_to_desc(tag_list, sep=self.sep)
833
- else:
834
- translated_tag_list = [self.translator[tag] for tag in tag_list if tag in self.translator ]
835
- return tags_to_desc(translated_tag_list, sep=self.sep)
836
-
837
- def gen_desc_of_tag(self, formats, tags):
838
- fmt = random.choice(formats)
839
- return fmt.format(self.tags_to_desc(tags))
840
-
841
- @staticmethod
842
- def check_valid(value):
843
- if isinstance(value, int) or isinstance(value, float):
844
- return value > 0
845
- if (value is not None) and (not isinstance(value, Sequence) or len(value) > 0):
846
- return True
847
- return False
848
-
849
- @staticmethod
850
- def remove_repeat(data):
851
- # If the album name equals the song title, keep only the latter
852
- album_name = data.get('专辑名', None)
853
- if album_name is not None and album_name == data.get('歌曲名', None):
854
- del data['专辑名']
855
- return data
856
-
857
- @property
858
- def comma(self):
859
- if self.lang == 'zh':
860
- return ','
861
- elif self.lang == 'en':
862
- return ', '
863
-
864
- @property
865
- def sep(self):
866
- if self.lang == 'zh':
867
- return '、'
868
- elif self.lang == 'en':
869
- return ', '
870
-
871
-
872
- def generate_description(self, item):
873
- if random.random() > self.plain_rate:
874
- # dynamically generate prompt from given prompt template
875
- description = self.generate_description_dynamic(item)
876
- else:
877
- # use plain prompt, i.e. tags sequence separated by comma
878
- description = self.generate_description_plain(item)
879
- return description
880
-
881
- def generate_description_dynamic(self, data):
882
- data = self.remove_repeat(data)
883
-
884
- weak_tags = [key for key in data if (key in self.tag_types and self.check_valid(data[key]))] # weakly semantic tags; their sampling probability is reduced
885
-
886
- key_tags = [key for key in data['tag'] if (key in self.key_tag_types and self.check_valid(data['tag'][key]))] # key tags; at least one of them must appear
887
-
888
- prompts = []
889
- if len(weak_tags) > 0:
890
- probs = dist_prob_map_low[len(weak_tags)]
891
- if len(key_tags) > 0:
892
- tags_num = random.choices(range(0, len(weak_tags)), probs)[0]
893
- else:
894
- tags_num = random.choices(range(1, len(weak_tags) + 1), probs)[0]
895
- random.shuffle(weak_tags)
896
- tags = weak_tags[:tags_num]
897
- for tag_type in tags:
898
- tag_desc = self.gen_desc_of_tag(self.prompt_formats[tag_type], int(data[tag_type]) if tag_type == 'bpm' else data[tag_type])
899
- prompts.append(tag_desc)
900
-
901
- if len(key_tags) > 0:
902
- probs = dist_prob_map[len(key_tags)]
903
- tags_num = random.choices(range(1, len(key_tags) + 1), probs)[0]
904
- random.shuffle(key_tags)
905
- tags = key_tags[:tags_num]
906
- for tag_type in tags:
907
- tag_desc = self.gen_desc_of_tag(self.prompt_formats['tag'][tag_type], data['tag'][tag_type])
908
- prompts.append(tag_desc)
909
-
910
- random.shuffle(prompts)
911
- return self.comma.join(prompts)
912
-
913
- def generate_description_plain(self, item):
914
- keywords = item['tag']
915
- if self.lang != 'en':
916
- keywords = [self.translator[k.strip()] for k in keywords]
917
- return gen_plain_prompt(keywords, sep=self.keysep)
918
-
919
- @property
920
- def keysep(self):
921
- if self.lang == 'zh':
922
- return ',' if random.random() > 0.5 else '、'
923
- elif self.lang == 'en':
924
- return ', '
925
-
926
- def is_valid_prompt_text(self, text):
927
- for bad in ('抱歉','sorry', 'Sorry'):
928
- if bad in text:
929
- return False
930
- return True
931
-
932
- def get_ready_prompt(self, path):
933
- sid = mp3_path_to_id(path)
934
- return random.choice(self.prompts_dict[sid])
935
-
936
- def __len__(self):
937
- return len(self.data_ids)
938
-
939
- def __getitem__(self, idx):
940
- data_id = self.data_ids[idx]
941
- item = self.data_index_dict[data_id]
942
- path = item['path']
943
- if not self.use_ready_prompt:
944
- info = self.music_info_dict[data_id]
945
- description = self.generate_description(info)
946
- else:
947
- description = self.get_ready_prompt(path)
948
- if self.return_audio:
949
- sr, duration = get_sr_and_duration_info(item)
950
- audio = self.audio_reader(path, sr, duration)
951
- else:
952
- audio = None
953
- if self.return_path:
954
- if self.return_ID:
955
- return audio, description, path, info['歌曲ID']
956
- return audio, description, path
957
- if self.return_ID:
958
- return audio, description, info['歌曲ID']
959
- return audio, description
960
-
961
-
962
- class Pond5Dataset(Dataset):
963
- MAX_PROMPT_LEN = 200
964
- def __init__(self,
965
- metadata_path:str,
966
- index_path:str,
967
- duration:float=10,
968
- sr:int = 0,
969
- plain_rate = 0,
970
- return_path = False,
971
- return_audio = True,
972
- lang = 'en',
973
- translate:Optional[Dict[str, os.PathLike]] = None,
974
- use_literal_none = True,
975
- use_avoid_watermark_policy = None,
976
- ):
977
-
978
- if use_avoid_watermark_policy is None:
979
- raise ValueError("`use_avoid_watermark_policy` is an important param, you need to explicitly specify it with bool type")
980
- self.use_avoid_watermark_policy = use_avoid_watermark_policy
981
- self.audio_reader = SafeAudioReader(duration, sr, use_avoid_watermark_policy=use_avoid_watermark_policy)
982
-
983
- self._load_metadata(metadata_path, index_path)
984
- self.sr = sr
985
- self.duration = duration
986
- self.plain_rate = plain_rate
987
- self.return_path = return_path
988
- self.return_audio = return_audio
989
- self.use_literal_none = use_literal_none
990
-
991
- self.lang = lang
992
- self.translate = read_translate(translate)
993
-
994
- def _load_metadata(self, metadata_path, index_path):
995
- data_index = read_jsonlike(index_path)
996
- data_ids = set([item['id'] for item in data_index])
997
-
998
- with open(metadata_path) as fp:
999
- lines = fp.readlines()
1000
-
1001
- append_ids = set()
1002
-
1003
- self.data = []
1004
- for line in lines:
1005
- item = json.loads(line)
1006
- if item['id'] in data_ids and item['id'] not in append_ids:
1007
- self.data.append(item)
1008
- append_ids.add(item['id'])
1009
-
1010
- def __len__(self):
1011
- return len(self.data)
1012
-
1013
- def __getitem__(self, idx):
1014
- item = self.data[idx]
1015
- path:str = item["path"]
1016
- description = self.generate_description(item)
1017
- if self.return_audio:
1018
- sr, duration = get_sr_and_duration_info(item)
1019
- audio = self.audio_reader(path, sr, duration)
1020
- else:
1021
- audio = None
1022
- if self.return_path:
1023
- return audio, description, path
1024
- return audio, description
1025
-
1026
- @property
1027
- def keysep(self):
1028
- if self.lang == 'zh':
1029
- return ',' if random.random() > 0.5 else '、'
1030
- elif self.lang == 'en':
1031
- return ', '
1032
-
1033
- def generate_description(self, item):
1034
- if random.random() > self.plain_rate:
1035
- # dynamically generate prompt from given prompt template
1036
- description = self.generate_description_dynamic(item)
1037
- else:
1038
- # use plain prompt, i.e. tags sequence separated by comma
1039
- description = self.generate_description_plain(item)
1040
- return description
1041
-
1042
- def get_translation(self, k):
1043
- k = k.strip()
1044
- if k in self.translate:
1045
- return self.translate[k]
1046
- else:
1047
- return k
1048
-
1049
- def generate_description_plain(self, item):
1050
- keywords = item['keywords']
1051
- if self.lang != 'en':
1052
- keywords = [self.get_translation(k) for k in keywords]
1053
- return gen_plain_prompt(keywords, sep=self.keysep)
1054
-
1055
- def generate_description_dynamic(self,item):
1056
- desc = item.get('desc', 'none')
1057
- if desc is None:
1058
- desc = 'none'
1059
- desc = desc.strip()
1060
- if len(desc) > self.MAX_PROMPT_LEN:
1061
- shorter_desc = desc[:self.MAX_PROMPT_LEN]
1062
- # find last stop
1063
- stop_idx = shorter_desc.rfind('.')
1064
- if stop_idx == -1:
1065
- stop_idx = shorter_desc.rfind('!')
1066
- if stop_idx == -1:
1067
- stop_idx = shorter_desc.rfind(',')
1068
- if stop_idx == -1:
1069
- stop_idx = self.MAX_PROMPT_LEN - 1
1070
- desc = desc[:stop_idx+1]
1071
- return desc
1072
-
1073
- class SoundDataset(Dataset):
1074
- def __init__(self,
1075
- metadata_index: str,
1076
- duration:float = 10,
1077
- min_non_silent_duration:float = 3,
1078
- sr:int = 0,
1079
- return_path = False,
1080
- return_audio = True,
1081
- ):
1082
- self.data = read_jsonlike(metadata_index)
1083
- self.sr = sr
1084
- self.reader = SafeAudioReader(duration, sr)
1085
- self.duration = duration
1086
- self.min_non_silent_duration = min_non_silent_duration
1087
- self.return_audio = return_audio
1088
- self.return_path = return_path
1089
-
1090
- def __getitem__(self, index):
1091
- item = self.data[index]
1092
- if self.return_audio:
1093
- origin_duration = item['duration']
1094
- if origin_duration < self.min_non_silent_duration:
1095
- audio = self.read_and_repeat_and_pad(item)
1096
- else:
1097
- audio = self.reader(item['path'], item['sample_rate'], origin_duration)
1098
- else:
1099
- audio = None
1100
- desc = item['caption']
1101
- if self.return_path:
1102
- return audio, desc, item['path']
1103
- else:
1104
- return audio, desc
1105
-
1106
- def __len__(self):
1107
- return len(self.data)
1108
-
1109
- def read_and_repeat_and_pad(self, item):
1110
- path = item['path']
1111
- try:
1112
- # read
1113
- clip, sr = torchaudio.load(path)
1114
- if len(clip.shape) > 1:
1115
- clip = torch.mean(clip, dim=0, keepdim=True)
1116
- clip = resample(clip, sr, self.sr)
1117
- #repeat
1118
- n_repeats = math.ceil(self.min_non_silent_duration/item['duration'])
1119
- clip = torch.repeat_interleave(clip, n_repeats, dim=0).reshape(-1)
1120
- #pad
1121
- n_samples = int(self.duration * self.sr)
1122
- if clip.shape[0] >= n_samples:
1123
- audio = clip[:n_samples]
1124
- else:
1125
- audio = torch.zeros(int(self.duration * self.sr), dtype=clip.dtype)
1126
- start_pos = np.random.randint(0, max(0,(n_samples - clip.shape[0])))
1127
- audio[start_pos:start_pos+clip.shape[0]] = clip
1128
- return audio
1129
-
1130
- except Exception as e:
1131
- logger.error(f"Error reading {path}: {e}")
1132
- wav = torch.zeros(int(self.duration * self.sr), dtype=torch.float32)
1133
- return wav
1134
-
1135
- class CombinedDataset(Dataset):
1136
- @beartype
1137
- def __init__(self, datasets: Sequence[Dataset], ratios: Sequence[int]):
1138
- self.datasets = datasets
1139
- self.datasets_index = []
1140
-
1141
- for i,dataset in enumerate(datasets):
1142
- if dataset is None:
1143
- continue
1144
- for dup in range(ratios[i]):
1145
- for j in range(len(dataset)):
1146
- self.datasets_index.append((i,j))
1147
-
1148
- def __len__(self):
1149
- return len(self.datasets_index)
1150
-
1151
- def __getitem__(self, idx):
1152
- index = self.datasets_index[idx]
1153
- i,j = index
1154
- return self.datasets[i][j]
1155
-
1156
- class CombinedDataset_random(Dataset):
1157
- @beartype
1158
- def __init__(self, num_examples:int, datasets: Sequence[Dataset], ratios: Sequence[int]):
1159
- self.datasets = datasets
1160
- self.datasets_index = []
1161
-
1162
- for i,dataset in enumerate(datasets):
1163
- if dataset is None:
1164
- continue
1165
- for dup in range(ratios[i]):
1166
- for j in range(len(dataset)):
1167
- self.datasets_index.append((i,j))
1168
-
1169
- if num_examples > 0:
1170
- self.random_choose = True
1171
- self.dataset_len = num_examples
1172
- else:
1173
- self.random_choose = False
1174
- self.dataset_len = len(self.datasets_index)
1175
-
1176
- def __len__(self):
1177
- return self.dataset_len
1178
-
1179
- def __getitem__(self, idx):
1180
- first_try = True
1181
- try_cnt = 0
1182
- while True:
1183
- try:
1184
- if(self.random_choose or not first_try):
1185
- index2 = []
1186
- index2.append(np.random.randint(0,len(self.datasets)))
1187
- index2.append(np.random.randint(0,len(self.datasets[index2[-1]])))
1188
- else:
1189
- index2 = self.datasets_index[idx]
1190
- first_try = False
1191
- out = list(self.datasets[index2[0]][index2[1]])
1192
- return out
1193
- except:
1194
- print("Error loadding ", index2)
1195
- try_cnt += 1
1196
- if(try_cnt>10):
1197
- raise ValueError()
1198
-
1199
- class SoundMixedDataset(Dataset):
1200
- @staticmethod
1201
- def music_desc(desc):
1202
- return f'Music:<{desc}>'
1203
- @staticmethod
1204
- def sound_desc(desc):
1205
- return f'Effect:<{desc}>'
1206
-
1207
- def __init__(self,
1208
- music_dataset: Dataset,
1209
- sound_dataset: Dataset,
1210
- mixed_ratios: Tuple[float, float, float] = (0.3, 0.3, 0.4) # ratio of music-only : sound-only : music+sound mixtures
1211
- ) -> None:
1212
- self.music_dataset = music_dataset
1213
- self.sound_dataset = sound_dataset
1214
- music_r, sound_r, mix_r = [r/sum(mixed_ratios) for r in mixed_ratios] # normalize the ratios to the 0-1 range
1215
- # left endpoints of the three probability intervals
1216
- self.music_anchor = 0
1217
- self.sound_anchor = music_r
1218
- self.mix_anchor = music_r + sound_r
1219
-
1220
- def __len__(self):
1221
- return len(self.music_dataset)
1222
-
1223
- def get_random_sound_data(self):
1224
- idx = random.randint(0, len(self.sound_dataset)-1)
1225
- return self.sound_dataset[idx]
1226
-
1227
- def __getitem__(self, idx):
1228
- p = random.random()
1229
- if p >= self.mix_anchor:
1230
- music, m_desc = self.music_dataset[idx]
1231
- sound, s_desc = self.get_random_sound_data()
1232
- audio = music + sound
1233
- if(audio.abs().max()>1.0):
1234
- music = music / audio.abs().max() * 0.95
1235
- audio = audio / audio.abs().max() * 0.95
1236
- desc = self.music_desc(m_desc) + self.sound_desc(s_desc)
1237
- return audio[None,:], music[None,:], desc
1238
- elif p >= self.sound_anchor:
1239
- audio, desc = self.get_random_sound_data()
1240
- return audio[None,:], torch.zeros_like(audio[None,:]), self.sound_desc(desc)
1241
- else:
1242
- audio, desc = self.music_dataset[idx]
1243
- return audio[None,:], audio[None,:], self.music_desc(desc)
1244
-
1245
-
1246
- class DecoTagDataset(Dataset):
1247
- '''Wraps an ordinary dataset into one suitable for tag-decoupled learning'''
1248
-
1249
- TAG_TYPES = ('genre', 'mood', 'instrument')
1250
-
1251
- def __init__(self, dataset_class: type, tag_map: Dict[str, str], *args, **kwargs):
1252
- self.datasets = []
1253
- for i, tag_t in enumerate(self.TAG_TYPES):
1254
- kwargs['tag_types'] = [tag_map[tag_t]]
1255
- kwargs['return_audio'] = (i == 0) # only the 0-th dataset returns audio and text; the others return text only
1256
- self.datasets.append(dataset_class(*args, **kwargs))
1257
-
1258
- def __len__(self):
1259
- return len(self.datasets[0])
1260
-
1261
- def __getitem__(self, idx):
1262
- audio, text = self.datasets[0][idx]
1263
- texts = (text, self.datasets[1][idx][1], self.datasets[2][idx][1])
1264
- return audio, texts
1265
-
1266
-
1267
- class DecoTagWrapper:
1268
- '''A wrapper that makes it easy to toggle tag-decoupled learning on or off'''
1269
- def __init__(self, dataset_class: Dataset, deco_tag_types: List[str] = list(), switch_on: bool = False):
1270
- self.dataset_class = dataset_class
1271
- self.tag_map = dict(zip(DecoTagDataset.TAG_TYPES, deco_tag_types))
1272
- self.switch_on = switch_on
1273
-
1274
- def __call__(self, *args, **kwargs):
1275
- if self.switch_on:
1276
- return DecoTagDataset(self.dataset_class, self.tag_map, *args, **kwargs)
1277
- else:
1278
- return self.dataset_class(*args, **kwargs)
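A note on the prompt templates consumed throughout this file: a template line keeps a bracketed span such as [with {instrument}] only when the corresponding tag has a value, drops the span otherwise, and then lightly normalizes punctuation. The helper below is a simplified, hedged re-implementation for illustration only (the function name fill_template and the sample template are invented here); it is not the parse_prompt_template / PromptTemplate code removed in this diff.

import re

def fill_template(template_line, **tags):
    # Keep a bracketed span (brackets stripped) when its tag has a value,
    # drop the span entirely when the tag is missing or empty.
    def repl(match):
        span = match.group(0)
        tag = re.search(r'\{(.+?)\}', span).group(1)
        value = tags.get(tag, '')
        return span.strip('[]').format(**{tag: value}) if value else ''
    text = re.sub(r'\[.*?\{.+?\}.*?\]', repl, template_line)
    text = re.sub(r'\s+', ' ', text)    # collapse repeated whitespace
    text = re.sub(r'\s+\.', '.', text)  # no space before a full stop
    return text.strip(' ,')

# fill_template('A [{genre}] track [with {instrument}].', genre='jazz')
# -> 'A jazz track.'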
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_429.py DELETED
@@ -1,372 +0,0 @@
1
- import re
2
- import sys
3
- import json
4
- from typing import List, Union
5
-
6
- from torch.utils.data import Dataset
7
- import torchaudio
8
- from torchaudio.functional import resample
9
- import torch
10
- import numpy as np
11
-
12
- from torch.nn.utils.rnn import pad_sequence
13
-
14
- PARAGRAPH_GAP = 6
15
- MIN_MUSIC_LEN = 3
16
-
17
- def check_lryics(lyric):
18
- _FILTER_STRING = [
19
- '作词', '作曲', '编曲', '【', '策划',
20
- '录音', '混音', '母带', ':', '制作',
21
- '版权', '校对', '演奏', '制作', '伴奏'
22
- ]
23
- for item in _FILTER_STRING:
24
- if item in lyric:
25
- return True
26
-
27
- return False
28
-
29
-
30
-
31
- def process_lyrics(lines):
32
- lyric_part = []
33
- timestamp_part = []
34
-
35
- timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]')
36
-
37
- for i, line in enumerate(lines):
38
-
39
- # Remove the specific metadata/credit lines that appear in the first few lines
40
- if i<10 and check_lryics(line):
41
- continue
42
-
43
- # Check whether the line contains a valid timestamp and lyric content
44
- if timestamp_pattern.match(line):
45
- timestamp_end = line.rfind(']')
46
- lyrics = line[timestamp_end + 1:].strip()
47
- timestamps = line[:timestamp_end + 1]
48
-
49
- if ':' in lyrics:
50
- if len(lyrics.split(":")[0]) <=5:
51
- lyrics = "".join(lyrics.split(":")[1:])
52
- # if lyrics: # 确保歌词部分不是空的
53
- # lyric_part.append(lyrics)
54
- # timestamp_part.append(timestamps)
55
- # print(processed_lyrics)
56
- return timestamp_part, lyric_part
57
-
58
- def get_timestamps(timestamp_part):
59
-
60
- # Convert to seconds
61
-
62
- timestamps = []
63
-
64
- for line in timestamp_part:
65
- match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line)
66
- if match:
67
- minutes = int(match.group(1))
68
- seconds = float(match.group(2))
69
- millis = float(match.group(3)) if match.group(3) else 0
70
- total_seconds = minutes * 60 + seconds + millis
71
- timestamps.append(total_seconds)
72
-
73
-
74
- return timestamps
75
-
76
- def process_lyrics_lrc(lyrics):
77
- timestamp_part, lyric_part = process_lyrics(lyrics)
78
- # print(timestamp_part)
79
- # print(lyric_part)
80
- timestamps = get_timestamps(timestamp_part)
81
- # print(timestamps)
82
- if len(timestamps) == 0:
83
- # print(f'{lyric_path}')
84
- return []
85
-
86
- slice_start = timestamps[0]
87
- slice_start_idx = 0
88
-
89
- output_list = []
90
- for i in range(1, len(timestamps)):
91
- # Slice once the accumulated time exceeds 30 seconds; if the whole lyric is under 30s, the entire piece is dropped
92
- if timestamps[i] - slice_start > 30:
93
- output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
94
-
95
- slice_start = timestamps[i]
96
- slice_start_idx = i
97
-
98
- return output_list
99
-
100
-
101
-
102
- def process_lyrics_yrc(lyrics):
103
-
104
- timestamps, lyric_part = extract_lrc(lyrics)
105
-
106
- # timestamp_part, lyric_part = process_lyrics(lyrics)
107
- # import pdb; pdb.set_trace()
108
- # print(timestamp_part)
109
- # print(lyric_part)
110
- # timestamps = get_timestamps(timestamp_part)
111
- # print(timestamps)
112
- if len(timestamps) == 0:
113
- # print(f'{lyric_path}')
114
- return []
115
-
116
- slice_start = timestamps[0]
117
- slice_start_idx = 0
118
-
119
- output_list = []
120
- for i in range(1, len(timestamps)):
121
- # Slice once the accumulated time exceeds 30 seconds
122
- if timestamps[i] - slice_start > 30:
123
- output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
124
-
125
- slice_start = timestamps[i]
126
- slice_start_idx = i
127
- # import pdb; pdb.set_trace()
128
- return output_list
129
-
130
- def extract_lrc(lyrics):
131
- timestamp_part, lyric_part = [], []
132
-
133
- for i, text in enumerate(lyrics):
134
- # Extract the content inside the square brackets
135
- bracket_content = re.search(r'\[(.*?)\]', text).group(1)
136
- bracket_content = bracket_content.split(',')
137
- # Extract the content inside the parentheses
138
- parentheses_content = re.findall(r'\((.*?)\)', text)
139
- # Extract the remaining content
140
- other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip()
141
-
142
- # How should this data be handled?
143
- if i<10 and check_lryics(other_content):
144
- continue
145
- timestamp_part.append(float(bracket_content[0])/1000)
146
- lyric_part.append(other_content)
147
- return timestamp_part, lyric_part
148
-
149
-
150
-
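For orientation, the helpers above turn LRC-style lyric lines ('[mm:ss.xx]text') into absolute times and then slice them into windows of at most 30 seconds. A compact sketch of the timestamp conversion is given below for reference; it is a simplified illustration, not the code removed in this diff.

import re

def lrc_time_to_seconds(line):
    # Parse a leading '[mm:ss.xx]' tag and return its offset in seconds.
    m = re.match(r'\[(\d+):(\d+(?:\.\d+)?)\]', line)
    if m is None:
        raise ValueError(f'no timestamp in {line!r}')
    return int(m.group(1)) * 60 + float(m.group(2))

# lrc_time_to_seconds('[01:23.45]some lyric')  ->  83.45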
151
- class WYYSongDataset(Dataset):
152
- def __init__(self,
153
- metadata_path: Union[str, List[str]],
154
- sr:int = 0,
155
- use_lang = ['en', 'zh-cn'],
156
- num_examples = -1,
157
- max_dur = 20,
158
- min_dur=0,
159
- add_music=False,
160
- pad_to_max= True,
161
- ):
162
-
163
- self.sr = sr
164
- self.use_lang = use_lang
165
- self.data = []
166
- if type(metadata_path) == str:
167
- metadata_path = [metadata_path]
168
- for _meta in metadata_path:
169
- self._load_metadata(_meta)
170
- self.max_dur = max_dur
171
- self.min_dur = min_dur
172
- self.pad_to_max = pad_to_max
173
- self.add_music = add_music
174
-
175
- # buffer
176
- self.lyric_buffer = {}
177
-
178
- if(num_examples<=0):
179
- self.dataset_len = len(self.data)
180
- self.random_slc = False
181
- else:
182
- self.dataset_len = num_examples
183
- self.random_slc = True
184
-
185
-
186
- # Read the jsonl file
187
- def _load_metadata(self, metadata_path):
188
- with open(metadata_path) as fp:
189
- lines = fp.readlines()
190
- for line in lines:
191
- item = json.loads(line)
192
- if '伴奏' not in item['path']:
193
- # if "lang_type" in item and item['lang_type'] == 'en':
194
- if "lang_type" in item:
195
- self.data.append(item)
196
-
197
-
198
- def __len__(self):
199
- return self.dataset_len
200
-
201
-
202
- def __getitem__(self, idx):
203
- try_cnt = 0
204
- while True:
205
- if(self.random_slc):
206
- idx = np.random.randint(0, len(self.data))
207
- yrc_lyrics = []
208
- lrc_lyrics = []
209
- try:
210
- info = self.data[idx]
211
-
212
- # audio path
213
- path = info["path"]
214
- lang_type = info["lang_type"]
215
- lyrics = info['lyrics'] # chinese
216
- # lyrics = info['lyrics_phone']
217
-
218
- # Randomly pick one lyric paragraph
219
-
220
- parsed_lyrics = []
221
- # st_idx = np.random.randint(0, len(lyrics))
222
- for ly_id in range(len(lyrics)):
223
- lyric = lyrics[ly_id].strip()
224
- st, et, lyric = self.parse_lyric(lyric)
225
-
226
- if et - st >= self.max_dur:
227
- continue # TODO: handle the leading/trailing [MUSIC] margins
228
-
229
- if parsed_lyrics != []:
230
- if st - parsed_lyrics[-1][1] >= PARAGRAPH_GAP: # large gap
231
- parsed_lyrics.append((parsed_lyrics[-1][1], st, '[GAP]'))
232
- elif self.add_music and st - parsed_lyrics[-1][1] >= MIN_MUSIC_LEN:
233
- parsed_lyrics.append((parsed_lyrics[-1][1], st, '[MUSIC]'))
234
-
235
- lyric = lyric.replace("\xa0", " ")
236
- lyric = " ".join(lyric.split())
237
- parsed_lyrics.append((st, et, lyric))
238
-
239
- assert parsed_lyrics != []
240
- # if parsed_lyrics[-1][1] - parsed_lyrics[0][0] > self.max_dur:
241
- # print(f"{parsed_lyrics[0][0]}-{parsed_lyrics[-1][1]} {parsed_lyrics}", file=open('tmp.txt', 'a'))
242
-
243
- parsed_lyrics = [(0, parsed_lyrics[0][0], '[GAP]')] + parsed_lyrics
244
-
245
- possible_starts = [e for e,i in enumerate(parsed_lyrics) if i[2]=='[GAP]']
246
- st_idx = np.random.choice(possible_starts)
247
-
248
- paraphrase = []
249
- for i in parsed_lyrics[st_idx+1:]:
250
- if i[2] == '[GAP]':
251
- break
252
- paraphrase.append(i)
253
- # print(paraphrase, lyrics)
254
-
255
- while paraphrase[-1][1] - paraphrase[0][0] > self.max_dur:
256
- if np.random.rand() > 0.2:
257
- paraphrase.pop(-1) # with high probability, truncate from the end
258
- else:
259
- paraphrase.pop(0) # with small probability, truncate from the front
260
-
261
- st, et, lyric = paraphrase[0][0], paraphrase[-1][1], ', '.join([i[2] for i in paraphrase]) # [SEP]
262
- # print(st, et, lyric)
263
- # import pdb; pdb.set_trace()
264
- assert self.min_dur < et - st < self.max_dur, f"{st}-{et} {lyric}"
265
- # print(et-st, lyric)
266
- # import pdb; pdb.set_trace()
267
-
268
- if info["lang_type"] == 'en':
269
- # print(len(lyric.split())/(et-st))
270
- char_num = sum([len(lrc[-1].split()) for lrc in paraphrase])
271
- assert 6 > char_num / (et-st) > 1
272
- else:
273
- # print(len(lyric.split())/(et-st))
274
- char_num = sum([len(lrc[-1]) for lrc in paraphrase])
275
- assert 6 > char_num / (et-st) > 1
276
-
277
-                 # Load the audio file
278
- cur_sample_rate = torchaudio.info(path).sample_rate
279
- offset = int(cur_sample_rate*st)
280
- num_frames = int(cur_sample_rate * (et -st))
281
- chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames)
282
- # chunk = torch.zeros(1, 48000*15)
283
-                 if abs(chunk.shape[-1] - num_frames) > num_frames * 0.05:  # audio length does not match the lyrics
284
- print(f"fail to load {path} from {st} to {et} !")
285
- raise FileNotFoundError
286
-                 # Randomly pick one channel
287
- if(chunk.shape[0]>1):
288
- chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
289
- else:
290
- chunk = chunk[[0],:].float()
291
-
292
- if(cur_sample_rate!=self.sr):
293
- # print('a:',cur_sample_rate,chunk.shape)
294
- chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr)
295
-
296
- if self.pad_to_max:
297
- chunk = self.pad_2d_tensor(chunk, int(self.max_dur * self.sr), 0)
298
-
299
- # print(self.sz_cnt)
300
- return chunk, lyric, [st, et], path, lang_type
301
-             except (AssertionError, FileNotFoundError, RuntimeError) as e:  # any other error type should still propagate
302
-                 # print("Error loading ", info["path"])
303
- try_cnt += 1
304
- idx = np.random.randint(0, len(self.data))
305
- if(try_cnt>100):
306
- raise e
307
-
308
- def parse_lyric(self, lyric):
309
- pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)'
310
- match = re.search(pattern, lyric)
311
-
312
- start_time = float(match.group(1))
313
- end_time = float(match.group(2))
314
- content = match.group(3)
315
- return start_time, end_time, content
316
-
317
- def pad_2d_tensor(self, x, max_len, pad_id):
318
-         # Get the shape of the input tensor
319
- batch_size, seq_len = x.size()
320
- max_len = max(max_len, seq_len)
321
-         # Compute how much padding is needed
322
- pad_len = max_len - seq_len
323
-
324
-         # If padding is needed
325
- if pad_len > 0:
326
-             # Create the padding tensor
327
- pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device)
328
-
329
-             # Concatenate the input tensor and the padding tensor along the second (column) dimension
330
- padded_tensor = torch.cat([x, pad_tensor], dim=1)
331
- else:
332
-             # No padding needed, return the input tensor as-is
333
- padded_tensor = x
334
-
335
- return padded_tensor
336
-
337
- def collect_data(data_list):
338
- audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1,2)
339
- lyrics = [data[1] for data in data_list]
340
- st_et = [data[2] for data in data_list]
341
- paths = [data[3] for data in data_list]
342
- lang_types = [data[4] for data in data_list]
343
- return audios, lyrics, st_et
344
- # return audios, lyrics, st_et
345
-
346
-
347
- def build_dataset(train_jsonl_list, val_jsonl_list, min_dur=0, max_dur=20, add_music=False):
348
- print(min_dur,max_dur)
349
- print(train_jsonl_list)
350
- # ["exp/wyy3_20240418_v2f.jsonl",
351
- # "exp/tme_lyric_baokuan.jsonl"]
352
- train_dataset = WYYSongDataset(
353
- metadata_path = train_jsonl_list,
354
- sr = 48000,
355
- use_lang = ['zh-cn', 'en'],
356
- num_examples = 10*10000,
357
- min_dur=min_dur,
358
- max_dur=max_dur,
359
- add_music=add_music
360
- )
361
-
362
- valid_dataset = WYYSongDataset(
363
- metadata_path = val_jsonl_list,
364
- sr = 48000,
365
- use_lang = ['zh-cn', 'en'],
366
- num_examples = 500,
367
- min_dur=min_dur,
368
- max_dur=max_dur,
369
- add_music=add_music
370
- )
371
- print(train_jsonl_list, "\t total_song = ", len(train_dataset.data))
372
- return train_dataset, valid_dataset
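For reference, a minimal sketch of the timed-lyric format that `parse_lyric` above expects: each line is `[start:end]text` with times in seconds. The sample line below is illustrative, not taken from the real metadata.

```python
import re

line = "[12.340:15.870]hello world"  # made-up lyric line in the expected format
pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)'
match = re.search(pattern, line)
st, et, content = float(match.group(1)), float(match.group(2)), match.group(3)
print(st, et, content)  # 12.34 15.87 hello world
```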
 
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined.py DELETED
@@ -1,830 +0,0 @@
1
- from torch.utils.data import Dataset
2
- from beartype.typing import Sequence, Callable, Optional, Dict, Tuple, List
3
- from beartype import beartype
4
- from beartype.door import is_bearable
5
- import random
6
- import pandas as pd
7
- import os
8
- from torchaudio.functional import resample
9
- import torch
10
- import typing as tp
11
- from pathlib import Path
12
- import torchaudio as ta
13
- import torch.nn.functional as F
14
- import numpy as np
15
- import json
16
- import yaml
17
- import torchaudio
18
- import math
19
- import re
20
- from loguru import logger
21
-
22
- class Read_and_PadCrop_Normalized_T(torch.nn.Module):
23
- def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
24
-
25
- super().__init__()
26
-
27
- self.n_samples = n_samples
28
- self.sample_rate = sample_rate
29
- self.randomize = randomize
30
-
31
- def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
32
- if(duration<(float(self.n_samples)/self.sample_rate+1)):
33
- # print(duration,(float(self.n_samples)/self.sample_rate+1))
34
- chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
35
- t_start = 0.
36
- t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
37
- offset = 0
38
- # print('c1:',chunk.shape)
39
- else:
40
- offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
41
- t_start = offset / float(cur_sample_rate) / duration
42
- t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
43
- chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
44
- # print('offset:',offset)
45
- # print('c0:',chunk.shape)
46
- # Pad with silence if necessary.
47
- if(chunk.shape[0]>1):
48
- chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
49
- else:
50
- chunk = chunk[[0],:].float()
51
- if(cur_sample_rate!=self.sample_rate):
52
- # print('a:',cur_sample_rate,chunk.shape)
53
- chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
54
- # print('b:',self.sample_rate,chunk.shape)
55
- if chunk.shape[-1] < self.n_samples:
56
- chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
57
- else:
58
- chunk = chunk[:,0:self.n_samples]
59
- seconds_start = math.floor(offset / cur_sample_rate)
60
- seconds_total = math.floor(duration)
61
-
62
- return (
63
- chunk,
64
- t_start,
65
- t_end,
66
- seconds_start,
67
- seconds_total
68
- )
69
-
70
-
71
- USE_DUMMY_AUDIO = False  # Set to True only when testing the code: real audio is not read and generated silent audio is returned instead
72
- if USE_DUMMY_AUDIO:
73
- logger.warning("USE_DUMMY_AUDIO flag is True, don't use it when train or test!")
74
-
75
- class SafeAudioReader:
76
- """
77
-     This class is an adaptor around Read_and_PadCrop_Normalized_T that makes audio reading safe.
78
- """
79
- def __init__(self,
80
-                  duration: float,  # length (in seconds) of the returned audio
81
-                  sample_rate: int,  # sample rate of the returned audio; resampled if it differs from the source sample rate
82
- randomize: bool = True
83
- ):
84
- self.n_samples = int(sample_rate * max(duration, 0))
85
- self.reader = Read_and_PadCrop_Normalized_T(n_samples=self.n_samples, sample_rate=sample_rate, randomize=randomize)
86
-
87
-     # NOTE: this is the core function; every dataset reads audio through it!
88
- def __call__(self,
89
-                  filepath: os.PathLike,  # path to the audio file
90
-                  origin_sample_rate: Optional[int] = None,  # actual sample rate read from the json metadata; if not given, it is read from the file header
91
-                  origin_duration: float = None,  # actual duration read from the json metadata; if not given, it is read from the file header
92
- ) -> torch.Tensor:
93
- if USE_DUMMY_AUDIO:
94
- wav = torch.zeros(self.n_samples, dtype=torch.float32)
95
- return wav
96
- try:
97
- if origin_sample_rate is None or origin_duration is None:
98
- audio_info = torchaudio.info(filepath)
99
- origin_sample_rate = audio_info.sample_rate
100
- origin_duration = audio_info.num_frames / origin_sample_rate
101
- wav, *ignored = self.reader(filepath, origin_duration, origin_sample_rate)
102
- except Exception as e:
103
- logger.error(f"Error reading {filepath}: {e}")
104
- wav = torch.zeros(self.n_samples, dtype=torch.float32)
105
- return wav
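A minimal usage sketch of `SafeAudioReader` as defined above; the file path is a placeholder. On any read error the reader logs it and falls back to a silent clip of the requested length.

```python
# Hypothetical usage; "song.flac" is a placeholder path.
reader = SafeAudioReader(duration=10.0, sample_rate=48000)
wav = reader("song.flac")  # sample rate / duration are read from the file header if not given
# `wav` always covers 10 s at 48 kHz; it is all zeros if loading failed.
```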
106
-
107
-
108
- class PromptTemplate:
109
- def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'):
110
- self.template_text = template_text
111
- self.tag_map = tag_map
112
- self.lang = lang
113
-
114
- @property
115
- def tags(self):
116
- return tuple(self.tag_map.keys())
117
-
118
- def apply(self, **kwargs):
119
- for tag in list(kwargs.keys()):
120
- if kwargs[tag] == '':
121
- kwargs.pop(tag)
122
- for tag in self.tags:
123
- if tag in kwargs:
124
- kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]')
125
- else:
126
- kwargs[tag] = ''
127
- prompt = self.template_text.format(**kwargs)
128
-
129
- return self.beautify(prompt)
130
-
131
- def beautify(self, text):
132
- if self.lang == 'en':
133
- return self._beautify_en(text)
134
- elif self.lang == 'zh':
135
- return self._beautify_zh(text)
136
- else:
137
- raise ValueError(f'Unknown language {self.lang}')
138
-
139
- @staticmethod
140
- def _beautify_en(text):
141
- # no continuous commas without content between them
142
- text = re.sub(r'[,\s]*,[,\s]*', r', ', text)
143
- # no continuous whitespace
144
- text = re.sub(r'\s+', ' ', text)
145
- # the comma is NOT followed by whitespace, and should be followed by ONE whitespace
146
- text = re.sub(r'\s+,', r',', text)
147
- text = re.sub(r',\s+', r', ', text)
148
- # no whitespace before the full stop
149
- text = re.sub(r'\s+\.', r'.', text)
150
- # strip whitespace, comma, and replace ',.'
151
- text = text.strip(' ,')
152
- text = text.replace(',.', '.')
153
- return text
154
-
155
- @staticmethod
156
- def _beautify_zh(text):
157
- # no continuous commas without content between them
158
- text = re.sub(r'[,、\s]*,[,、\s]*', r',', text)
159
- text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text)
160
- # assume there should be NO whitespace in Chinese
161
- text = re.sub(r'\s+', r'', text)
162
- # strip whitespace, comma, and replace ',。'
163
- text = text.strip(', 、')
164
- text = text.replace(',。', '。')
165
- return text
166
-
167
- def __repr__(self):
168
- return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})'
169
-
170
- __str__ = __repr__
171
-
172
- def parse_prompt_template(prompt_template_text, lang='en'):
173
- span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL)
174
- tag_pattern = re.compile(r'{.+?}', re.DOTALL)
175
-
176
- template_text = prompt_template_text.strip()
177
- span_texts = span_pattern.findall(prompt_template_text)
178
- tag_map = {}
179
- for span_text in span_texts:
180
- tag = tag_pattern.findall(span_text)[0].strip('{}')
181
- tag_map[tag] = span_text
182
- template_text = template_text.replace(span_text, '{'+tag+'}')
183
-
184
- return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang)
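To make the template syntax concrete: bracketed spans are optional fragments keyed by the placeholder inside them, and `apply` drops any span whose tag was not supplied. A small illustrative example (the template string below is made up, not from the repository's prompt files):

```python
pt = parse_prompt_template("This is [a {genre} track][, with a {mood} mood].", lang='en')
print(pt.tags)                               # ('genre', 'mood')
print(pt.apply(genre='rock'))                # This is a rock track.
print(pt.apply(genre='jazz', mood='calm'))   # This is a jazz track, with a calm mood.
```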
185
-
186
- def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]:
187
- with open(path, 'r') as f:
188
- lines = f.readlines()
189
- cnt = 0
190
- pts = []
191
- for line in lines:
192
- pt = parse_prompt_template(line, lang=lang)
193
- cnt += 1
194
- if len(pt.tags) < num:
195
- logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}')
196
- pts.append(pt)
197
-
198
- return pts
199
-
200
-
201
- def get_base_dir_file(key: os.PathLike):
202
- base = os.path.basename(key)
203
- dirname = os.path.basename(os.path.dirname(key))
204
- return os.path.join(dirname, base)
205
-
206
- def read_jsonlike(path: os.PathLike):
207
- #json or jsonl
208
- if str(path).endswith(".json"):
209
- with open(path, 'r', encoding='utf8') as f:
210
- data = json.load(f)
211
- return data
212
- elif str(path).endswith(".jsonl"):
213
- with open(path, 'r', encoding='utf8') as f:
214
- data = [json.loads(line) for line in f.readlines()]
215
- return data
216
- else:
217
- raise ValueError("Unknown file format")
218
-
219
- dist_prob_map = {
220
- 1: (1.0,),
221
- 2: (0.5, 0.5),
222
- 3: (0.3, 0.4, 0.3),
223
- 4: (0.2, 0.3, 0.3, 0.2),
224
- 5: (0.2, 0.2, 0.3, 0.2, 0.1),
225
- 6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15),
226
- 7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
227
- 8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
228
- 9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
229
- 10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
230
- }
231
-
232
- dist_prob_map_low = {
233
- 1: (1.0,),
234
- 2: (0.8, 0.2),
235
- 3: (0.8, 0.1, 0.1),
236
- 4: (0.7, 0.1, 0.1, 0.1),
237
- 5: (0.7, 0.1, 0.1, 0.05, 0.05),
238
- 6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05),
239
- }
240
-
241
- _bpm_range_rights = (
242
- (40, '20-40'),
243
- (60, '40-60'),
244
- (66, '60-66'),
245
- (76, '66-76'),
246
- (108, '76-108'),
247
- (120, '108-120'),
248
- (168, '120-168'),
249
- (176, '168-176'),
250
- (200, '176-200')
251
- )
252
- _bpm_desc_map = {
253
- '20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"),
254
- '40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"),
255
- '60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'),
256
- '66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'),
257
- '76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"),
258
- '108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'),
259
- '120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'),
260
- '168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'),
261
- '176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'),
262
- '>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo')
263
- }
264
- _bpm_desc_map_zh = {
265
- '20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"),
266
- '40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"),
267
- '60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'),
268
- '66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'),
269
- '76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"),
270
- '108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'),
271
- '120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'),
272
- '168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'),
273
- '176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'),
274
- '>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板')
275
- }
276
- def get_bpm_range(bpm):
277
- bpm = int(bpm)
278
- for right, tag in _bpm_range_rights:
279
- if bpm <= right:
280
- return tag
281
- return '>200'
282
-
283
- def gen_bpm_descript(bpm, lang='en'):
284
- bpm_range = get_bpm_range(bpm)
285
- if lang == 'en':
286
- return random.choice(_bpm_desc_map[bpm_range])
287
- elif lang == 'zh':
288
- return random.choice(_bpm_desc_map_zh[bpm_range])
289
- else:
290
- raise ValueError(f"Unknown language {lang}")
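The two helpers above bucket a BPM value into a tempo range and then sample one of several synonymous descriptions for that range, for example:

```python
print(get_bpm_range(128))                # '120-168'
print(gen_bpm_descript(128))             # e.g. 'brisk pace' or 'Allegro' (random choice)
print(gen_bpm_descript(128, lang='zh'))  # e.g. '快板'
```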
291
-
292
- def read_translate(translate: Optional[Dict[str, os.PathLike]]):
293
- if translate is None:
294
- return None
295
- return {k: read_jsonlike(path) for k, path in translate.items()}
296
-
297
-
298
- class MagnaTagATuneDataset(Dataset):
299
- def __init__(self):
300
- pass
301
-
302
-
303
- def tags_to_desc(tag_list, sep=',') -> str:
304
- if not isinstance(tag_list, Sequence):
305
- return str(tag_list)
306
- if isinstance(tag_list, str):
307
- return tag_list
308
- if len(tag_list) <= 0:
309
- return ''
310
- elif len(tag_list) <= 5:
311
- probs = dist_prob_map[len(tag_list)]
312
- tags_num = random.choices(range(1, len(tag_list)+1), probs)[0]
313
- random.shuffle(tag_list)
314
- tag_list = tag_list[:tags_num]
315
- return sep.join(tag_list)
316
- else:
317
- probs = dist_prob_map[5]
318
- tags_num = random.choices(range(1, 6), probs)[0]
319
- random.shuffle(tag_list)
320
- tag_list = tag_list[:tags_num]
321
- return sep.join(tag_list)
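`tags_to_desc` keeps at most five tags, drawing how many to keep from `dist_prob_map` and shuffling the list in place, so repeated calls on the same input can yield different strings:

```python
tags = ["piano", "ambient", "calm", "night", "rain", "lofi"]
print(tags_to_desc(list(tags)))          # e.g. "calm,rain,piano" (1-5 tags, random order)
print(tags_to_desc("already a string"))  # strings are returned unchanged
print(tags_to_desc(120))                 # non-sequence values are stringified -> "120"
```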
322
-
323
- def get_sr_and_duration_info(item):
324
- return item.get('sample_rate', None), item.get('duration', None)
325
-
326
- class MtgJamendoDatasetFromJson(Dataset):
327
- def __init__(self,
328
- data_dir:str,
329
- json_path:str,
330
- duration:float=10,
331
- sr:int = 0,
332
- *,
333
- lang = 'en',
334
- return_path = False,
335
- prompt_template_path: os.PathLike = None,
336
- tag_types = [],
337
- translate:Optional[Dict[str, os.PathLike]] = None,
338
- ):
339
- self.audio_reader = SafeAudioReader(duration, sr)
340
-
341
- self.data_dir = data_dir
342
- self._load_metadata_json(json_path)
343
- self.sr = sr
344
- self.duration = duration
345
- self.return_path = return_path
346
- self.lang = lang
347
-
348
- self.use_dynamic_prompt = prompt_template_path is not None
349
- if self.use_dynamic_prompt:
350
- self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types))
351
- self.tag_types = tag_types
352
-
353
- self.translate = read_translate(translate)
354
- if not self.use_dynamic_prompt and self.lang != 'en':
355
- raise NotImplementedError
356
-
357
-     # These tags are considered weakly semantic; prompts that contain only these tags are avoided
358
- WEAK_TAG_LIST = ["title", "artist"]
359
-
360
- def _load_metadata_json(self, json_path):
361
- with open(json_path) as fp:
362
- self.data = json.load(fp)
363
-
364
- def convert_key_to_path(self, key):
365
- return os.path.join(self.data_dir, get_base_dir_file(key))
366
-
367
- def __len__(self):
368
- return len(self.data)
369
-
370
- def __getitem__(self, idx):
371
- item = self.data[idx]
372
- path = self.convert_key_to_path(item['key'])
373
- description = self.generate_description(item)
374
-
375
- sr, duration = get_sr_and_duration_info(item)
376
- audio = self.audio_reader(path, sr, duration)
377
-
378
- if self.return_path:
379
- return audio, description, path
380
- return audio, description
381
-
382
- def tags_to_desc(self, tag_list, tag_type) -> str:
383
- if self.lang == 'en':
384
- return tags_to_desc(tag_list)
385
- elif self.lang == 'zh':
386
- translator = self.translate[tag_type]
387
- translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
388
- return tags_to_desc(translated_tag_list, sep='、')
389
-
390
- def generate_description(self, item):
391
- if self.use_dynamic_prompt:
392
- # dynamically generate prompt from given prompt template
393
- prompt_template = random.choice(self.prompt_templates)
394
- description = self.generate_description_dynamic(item, prompt_template)
395
-
396
- else:
397
- # use ordinary static prompt instead
398
- description = self.generate_description_ordinary(item)
399
- return description
400
-
401
- def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
402
- exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
403
- exists_weak_tag = list(filter(lambda t: t in self.WEAK_TAG_LIST, exists_tag))
404
- exists_strong_tag = list(filter(lambda t: t not in self.WEAK_TAG_LIST, exists_tag))
405
-
406
- if len(exists_strong_tag) > 0:
407
- probs = dist_prob_map[len(exists_strong_tag)]
408
- tags_num = random.choices(range(1, len(exists_strong_tag)+1), probs)[0]
409
- random.shuffle(exists_strong_tag)
410
- tags = exists_strong_tag[:tags_num]
411
- weak_probs = dist_prob_map_low[len(exists_weak_tag) + 1]
412
- weak_tags_num = random.choices(range(0, len(exists_weak_tag) + 1), weak_probs)[0]
413
- random.shuffle(exists_weak_tag)
414
- weak_tags = exists_weak_tag[:weak_tags_num]
415
- tags += weak_tags
416
- tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
417
- prompt = prompt_template.apply(**tags_args)
418
- else:
419
- # no strong tags, use all weak tags instead
420
- tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in exists_weak_tag}
421
- prompt = prompt_template.apply(**tags_args)
422
-
423
- return prompt
424
-
425
- def generate_description_ordinary(self, data, thresh = 0.3):
426
- # Initialize the description with title and artist
427
- description = f'"{data["title"]+" is " if random.random() > thresh else ""}"a piece of music by {data["artist"]}'
428
-
429
- # Add genre if available
430
- if data["genre"] and random.random() > thresh:
431
- genres = ', '.join(data["genre"])
432
- description += f', belonging to the {genres} genres'
433
-
434
- # Add moods if available
435
- if data["moods"] and random.random() > thresh:
436
- moods = ', '.join(data["moods"])
437
- description += f'. This track conveys a {moods} mood'
438
-
439
- # Add instruments if available
440
- if data["instrument"] and random.random() > thresh:
441
- instruments = ', '.join(data["instrument"])
442
- description += f', and primarily features the following instruments: {instruments}'
443
-
444
- # Add a period to end the description
445
- description += '.'
446
-
447
- return description
448
-
449
- class AudioStockDataset(Dataset):
450
- def __init__(self,
451
- metadata_path:str,
452
- duration:float=10,
453
- sr:int = 0,
454
- return_path = False,
455
- return_audio = True,
456
- prompt_template_path: os.PathLike = None,
457
- tag_types = [],
458
- lang = 'en',
459
- translate:Optional[Dict[str, os.PathLike]] = None
460
- ):
461
- self.audio_reader = SafeAudioReader(duration, sr)
462
-
463
- self._load_metadata(metadata_path)
464
- self.sr = sr
465
- self.duration = duration
466
- self.return_path = return_path
467
- self.return_audio = return_audio
468
-
469
- self.use_dynamic_prompt = prompt_template_path is not None
470
- if self.use_dynamic_prompt:
471
- self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang)
472
- self.tag_types = tag_types
473
-
474
- self.lang = lang
475
- self.translate = read_translate(translate)
476
-
477
- def _load_metadata(self, metadata_path):
478
- with open(metadata_path) as fp:
479
- lines = fp.readlines()
480
- self.data = []
481
- for line in lines:
482
- item = json.loads(line)
483
- self.data.append(item)
484
- self.is_info_recorded = bool('Tags' in self.data[0])
485
-
486
- def __len__(self):
487
- return len(self.data)
488
-
489
- def __getitem__(self, idx):
490
- path:str = self.data[idx]["path"]
491
- json_path = path[:path.rfind('.')] + ".json"
492
- if self.is_info_recorded:
493
- item = self.data[idx]
494
- else:
495
- try:
496
- with open(json_path) as fp:
497
- item:dict = json.load(fp)
498
- except Exception as e:
499
- print(f"Error loading json file {json_path} :\n{e}")
500
- item = {}
501
- description = self.generate_description(item)
502
- if self.return_audio:
503
- sr, duration = get_sr_and_duration_info(item)
504
- audio = self.audio_reader(path, sr, duration)
505
- else:
506
- audio = None
507
- if self.return_path:
508
- return audio, description, path
509
- return audio, description
510
-
511
- def generate_description(self, item):
512
- if self.use_dynamic_prompt:
513
- # dynamically generate prompt from given prompt template
514
- prompt_template = random.choice(self.prompt_templates)
515
- description = self.generate_description_dynamic(item, prompt_template)
516
- else:
517
- # use ordinary static prompt instead
518
- description = self.generate_description_ordinary(item)
519
- return description
520
-
521
- def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
522
- exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
523
-
524
- if len(exists_tag) > 0:
525
- probs = dist_prob_map[len(exists_tag)]
526
- tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0]
527
- random.shuffle(exists_tag)
528
- tags = exists_tag[:tags_num]
529
- tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
530
- tags_args = self.handle_BPM_tag(tags_args)
531
- prompt = prompt_template.apply(**tags_args)
532
- else:
533
- # no strong tags, use all weak tags instead
534
- prompt = prompt_template.apply()
535
-
536
- return prompt
537
-
538
- def tags_to_desc(self, tag_list, tag_type) -> str:
539
- if self.lang == 'en':
540
- return tags_to_desc(tag_list)
541
- elif self.lang == 'zh':
542
- if tag_type == 'BPM':
543
- return tags_to_desc(tag_list, sep='、')
544
- translator = self.translate[tag_type]
545
- translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
546
- return tags_to_desc(translated_tag_list, sep='、')
547
-
548
- def handle_BPM_tag(self, tags_args):
549
- if "BPM" in tags_args and 'BPMDescript' in self.tag_types:
550
- bpm = tags_args["BPM"]
551
- del tags_args["BPM"]
552
- tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript')))
553
- for tag_type in tag_types_used:
554
- tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang)
555
- return tags_args
556
-
557
- def generate_description_ordinary(self, data, thresh = 0.3):
558
- if self.lang != 'en':
559
- raise ValueError(f'Language {self.lang} is not supported for ordinary description generation')
560
- description = f'a piece of music by {data["Artist"]}'
561
-
562
- # Add genre if available
563
- if data["Genre"] and random.random() > thresh:
564
- genres = ', '.join(data["Genre"])
565
- description += f', belonging to the {genres} genres'
566
-
567
- # Add moods if available
568
- if data["Tags"] and random.random() > thresh:
569
- tags = ', '.join(data["Tags"])
570
- description += f'. This track contains the tags:{tags}'
571
-
572
- # Add moods if available
573
- if data["Mood"] and random.random() > thresh:
574
- moods = ', '.join(data["Mood"])
575
- description += f'. This track conveys a {moods} mood.'
576
-
577
- # Add instruments if available
578
- if data["Instrument"] and random.random() > thresh:
579
- instruments = ', '.join(data["Instrument"])
580
- description += f'. and primarily features the following instruments: {instruments}'
581
-
582
- # Add a period to end the description
583
- description += '.'
584
-
585
- return description
586
-
587
- def mp3_path_to_id(mp3_path):
588
- return int(
589
- mp3_path[mp3_path.rindex('/') + 1 : mp3_path.rindex('.mp3')]
590
- )
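`mp3_path_to_id` simply strips the directory and the `.mp3` suffix and casts the numeric stem to an int; the path below is illustrative:

```python
print(mp3_path_to_id("/data/tme/123456.mp3"))  # 123456
```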
591
-
592
- class TmeDataset(Dataset):
593
- def __init__(self,
594
- data_index:str,
595
- music_info:str = None,
596
- duration:float = 10,
597
- sr:int = 0,
598
- return_path = False,
599
- return_audio = True,
600
- prompt_format_path: os.PathLike = None,
601
- tag_types = ['*'],
602
- lang = 'zh',
603
- translate: Optional[os.PathLike] = None,
604
- prompt_dir: os.PathLike = None,
605
- ):
606
- self.audio_reader = SafeAudioReader(duration, sr)
607
-
608
- self.sr = sr
609
- self.duration = duration
610
- self.return_path = return_path
611
- self.return_audio = return_audio
612
- self.lang = lang
613
-
614
- self.use_ready_prompt = prompt_dir is not None
615
-
616
- data_index = read_jsonlike(data_index)
617
- self.data_index_dict = {mp3_path_to_id(d['path']) : d for d in data_index}
618
- self.data_ids = list(self.data_index_dict.keys())
619
-
620
- if not self.use_ready_prompt:
621
-             # Read the music info file
622
- music_info = read_jsonlike(music_info)
623
- if 'music' in music_info:
624
- music_info = music_info['music']
625
- self.music_info_dict = {d["歌曲ID"]:d for d in music_info}
626
- self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.music_info_dict}
627
- self.data_ids = list(self.data_index_dict.keys())
628
-
629
- with open(prompt_format_path) as fp:
630
- self.prompt_formats = yaml.load(fp, Loader=yaml.FullLoader)
631
-
632
-             # Load the tag types and split them into ordinary tag_types and key key_tag_types
633
- if '*' in tag_types:
634
- self.tag_types = ['歌曲名', 'bpm', '专辑名', '歌手名', '作曲', 'tag']
635
- else:
636
- self.tag_types = tag_types
637
-
638
- self.key_tag_types = []
639
- if 'tag' in self.tag_types:
640
- self.tag_types.remove('tag')
641
- self.key_tag_types = list(self.prompt_formats['tag'].keys())
642
-
643
-             # Load the translation mapping
644
- if translate is not None:
645
- self.translator = read_jsonlike(translate)
646
- else:
647
- data_ids_set = set(self.data_ids)
648
- self.prompts_dict = {}
649
- for fname in os.listdir(prompt_dir):
650
- items = read_jsonlike(os.path.join(prompt_dir, fname))
651
- for item in items:
652
- if item['ID'] not in data_ids_set or not self.is_valid_prompt_text(item['Text']):
653
- continue
654
- if item['ID'] not in self.prompts_dict:
655
- self.prompts_dict[item['ID']] = []
656
- self.prompts_dict[item['ID']].append(item['Text'])
657
- self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.prompts_dict}
658
- self.data_ids = list(self.data_index_dict.keys())
659
-
660
- def tags_to_desc(self, tag_list) -> str:
661
- if is_bearable(tag_list, int):
662
- return str(tag_list)
663
- if self.lang == 'zh':
664
- return tags_to_desc(tag_list, sep=self.sep)
665
- else:
666
- translated_tag_list = [self.translator[tag] for tag in tag_list if tag in self.translator ]
667
- return tags_to_desc(translated_tag_list, sep=self.sep)
668
-
669
- def gen_desc_of_tag(self, formats, tags):
670
- fmt = random.choice(formats)
671
- return fmt.format(self.tags_to_desc(tags))
672
-
673
- @staticmethod
674
- def check_valid(value):
675
- if isinstance(value, int) or isinstance(value, float):
676
- return value > 0
677
- if (value is not None) and (not isinstance(value, Sequence) or len(value) > 0):
678
- return True
679
- return False
680
-
681
- @staticmethod
682
- def remove_repeat(data):
683
-         # If the album name is identical to the song name, keep only the latter
684
- album_name = data.get('专辑名', None)
685
- if album_name is not None and album_name == data.get('歌曲名', None):
686
- del data['专辑名']
687
- return data
688
-
689
- @property
690
- def comma(self):
691
- if self.lang == 'zh':
692
- return ','
693
- elif self.lang == 'en':
694
- return ', '
695
-
696
- @property
697
- def sep(self):
698
- if self.lang == 'zh':
699
- return '、'
700
- elif self.lang == 'en':
701
- return ', '
702
-
703
- def generate_description(self, data):
704
- data = self.remove_repeat(data)
705
-         weak_tags = [key for key in data if (key in self.tag_types and self.check_valid(data[key]))]  # weakly semantic tags; they are sampled with a lower probability
706
-
707
-         key_tags = [key for key in data['tag'] if (key in self.key_tag_types and self.check_valid(data['tag'][key]))]  # key tags; at least one of them must appear
708
-
709
- prompts = []
710
- if len(weak_tags) > 0:
711
- probs = dist_prob_map_low[len(weak_tags)]
712
- if len(key_tags) > 0:
713
- tags_num = random.choices(range(0, len(weak_tags)), probs)[0]
714
- else:
715
- tags_num = random.choices(range(1, len(weak_tags) + 1), probs)[0]
716
- random.shuffle(weak_tags)
717
- tags = weak_tags[:tags_num]
718
- for tag_type in tags:
719
- tag_desc = self.gen_desc_of_tag(self.prompt_formats[tag_type], int(data[tag_type]) if tag_type == 'bpm' else data[tag_type])
720
- prompts.append(tag_desc)
721
-
722
- if len(key_tags) > 0:
723
- probs = dist_prob_map[len(key_tags)]
724
- tags_num = random.choices(range(1, len(key_tags) + 1), probs)[0]
725
- random.shuffle(key_tags)
726
- tags = key_tags[:tags_num]
727
- for tag_type in tags:
728
- tag_desc = self.gen_desc_of_tag(self.prompt_formats['tag'][tag_type], data['tag'][tag_type])
729
- prompts.append(tag_desc)
730
-
731
- random.shuffle(prompts)
732
- return self.comma.join(prompts)
733
-
734
- def is_valid_prompt_text(self, text):
735
- for bad in ('抱歉','sorry', 'Sorry'):
736
- if bad in text:
737
- return False
738
- return True
739
-
740
- def get_ready_prompt(self, path):
741
- sid = mp3_path_to_id(path)
742
- return random.choice(self.prompts_dict[sid])
743
-
744
- def __len__(self):
745
- return len(self.data_ids)
746
-
747
- def __getitem__(self, idx):
748
- data_id = self.data_ids[idx]
749
- item = self.data_index_dict[data_id]
750
- path = item['path']
751
- if not self.use_ready_prompt:
752
- info = self.music_info_dict[data_id]
753
- description = self.generate_description(info)
754
- else:
755
- description = self.get_ready_prompt(path)
756
- if self.return_audio:
757
- sr, duration = get_sr_and_duration_info(item)
758
- audio = self.audio_reader(path, sr, duration)
759
- else:
760
- audio = None
761
- if self.return_path:
762
- return audio, description, path
763
- return audio, description
764
-
765
- class CombinedDataset(Dataset):
766
- @beartype
767
- def __init__(self, datasets: Sequence[Dataset], ratios: Sequence[int]):
768
- self.datasets = datasets
769
- self.datasets_index = []
770
-
771
- for i,dataset in enumerate(datasets):
772
- if dataset is None:
773
- continue
774
- for dup in range(ratios[i]):
775
- for j in range(len(dataset)):
776
- self.datasets_index.append((i,j))
777
-
778
- def __len__(self):
779
- return len(self.datasets_index)
780
-
781
- def __getitem__(self, idx):
782
- index = self.datasets_index[idx]
783
- i,j = index
784
- return self.datasets[i][j]
785
-
786
- class CombinedDataset_random(Dataset):
787
- @beartype
788
- def __init__(self,
789
- num_examples:int,
790
- datasets: Sequence[Dataset], ratios: Sequence[int]
791
- ):
792
- self.datasets = datasets
793
- self.datasets_index = []
794
-
795
- for i,dataset in enumerate(datasets):
796
- if dataset is None:
797
- continue
798
- for dup in range(ratios[i]):
799
- for j in range(len(dataset)):
800
- self.datasets_index.append((i,j))
801
- if num_examples > 0:
802
- self.random_choose = True
803
- self.dataset_len = num_examples
804
- else:
805
- self.random_choose = False
806
- self.dataset_len = len(self.datasets_index)
807
-
808
- def __len__(self):
809
- return self.dataset_len
810
-
811
- def __getitem__(self, idx):
812
- first_try = True
813
- try_cnt = 0
814
- while True:
815
- try:
816
- if(self.random_choose or not first_try):
817
- index2 = []
818
- index2.append(np.random.randint(0,len(self.datasets)))
819
- index2.append(np.random.randint(0,len(self.datasets[index2[-1]])))
820
- else:
821
- index2 = self.datasets_index[idx]
822
- first_try = False
823
- out = self.datasets[index2[0]][index2[1]]
824
- if(len(out[0].shape)==1):out[0]=out[0][None,:]
825
- return out
826
- except:
827
-                 print("Error loading ", index2)
828
- try_cnt += 1
829
- if(try_cnt>10):
830
- raise ValueError()
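A small sketch of how the ratio-based index expansion in `CombinedDataset` behaves, using toy in-memory datasets (the sizes and ratios are illustrative); `CombinedDataset_random` builds the same index but can also draw random samples up to `num_examples`:

```python
import torch
from torch.utils.data import TensorDataset

ds_a = TensorDataset(torch.zeros(3, 1))   # 3 items
ds_b = TensorDataset(torch.ones(2, 1))    # 2 items
combined = CombinedDataset([ds_a, ds_b], ratios=[2, 1])
# ds_a is traversed twice and ds_b once: 2*3 + 1*2 = 8 indexable items.
print(len(combined))  # 8
```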
 
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined_withset.py DELETED
@@ -1,994 +0,0 @@
1
- from torch.utils.data import Dataset
2
- from beartype.typing import Sequence, Callable, Optional, Dict, Tuple, List
3
- from beartype import beartype
4
- from beartype.door import is_bearable
5
- import random
6
- import pandas as pd
7
- import os
8
- from torchaudio.functional import resample
9
- import torch
10
- import typing as tp
11
- from pathlib import Path
12
- import torchaudio as ta
13
- import torch.nn.functional as F
14
- import numpy as np
15
- import json
16
- import yaml
17
- import torchaudio
18
- import math
19
- import re
20
- from loguru import logger
21
-
22
- def gen_plain_prompt(key_list, sep=', '):
23
- if len(key_list) == 0:
24
- return 'none'
25
-
26
- key_list = [k.strip() for k in key_list]
27
-
28
- if len(key_list) > 10:
29
- random.shuffle(key_list)
30
- key_list = key_list[:10]
31
-
32
- probs = dist_prob_map[len(key_list)]
33
-
34
- num_tags = random.choices(range(1, len(key_list)+1), probs, k=1)[0]
35
-
36
- random.shuffle(key_list)
37
- tags = key_list[:num_tags]
38
- tags_str = sep.join(tags)
39
- return tags_str
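`gen_plain_prompt` caps the key list at ten entries, draws how many to keep from `dist_prob_map` (defined further down in this file), and joins a shuffled subset; an illustrative call:

```python
print(gen_plain_prompt([]))                           # 'none'
print(gen_plain_prompt(['piano ', 'calm', 'night']))  # e.g. 'calm, piano' (random subset)
```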
40
-
41
- class Read_and_PadCrop_Normalized_T(torch.nn.Module):
42
-
43
- def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
44
-
45
- super().__init__()
46
-
47
- self.n_samples = n_samples
48
- self.sample_rate = sample_rate
49
- self.randomize = randomize
50
- self.prob = {"is_start":0.2, "is_end":0.9}
51
- self.shift_secs = 5
52
-
53
- def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
54
- if(duration<(float(self.n_samples)/self.sample_rate+1)):
55
- raise ValueError(duration,float(self.n_samples),self.sample_rate)
56
- chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
57
- t_start = 0.
58
- t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
59
- offset = 0
60
- is_start = True
61
- is_end = True
62
- else:
63
- prob = random.uniform(0,1)
64
- if(prob<self.prob['is_start']):
65
- is_start = True
66
- is_end = False
67
- offset = 0
68
- elif(prob>self.prob['is_end']):
69
- is_start = False
70
- is_end = True
71
- offset = int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate)
72
- else:
73
- is_start = False
74
- is_end = False
75
- offset = np.random.randint(self.shift_secs*cur_sample_rate, \
76
- int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate)-self.shift_secs*cur_sample_rate)
77
- t_start = offset / float(cur_sample_rate) / duration
78
- t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
79
- chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
80
- if(chunk.shape[0]>1):
81
- chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
82
- else:
83
- chunk = chunk[[0],:].float()
84
- if(cur_sample_rate!=self.sample_rate):
85
- # print('a:',cur_sample_rate,chunk.shape)
86
- chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
87
- # print('b:',self.sample_rate,chunk.shape)
88
- if chunk.shape[-1] != self.n_samples:
89
- raise ValueError(chunk.shape, self.n_samples, offset, int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
90
- # if chunk.shape[-1] < self.n_samples:
91
- # chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
92
- # else:
93
- # chunk = chunk[:,0:self.n_samples]
94
- seconds_start = math.floor(offset / cur_sample_rate)
95
- seconds_total = math.floor(duration)
96
-
97
- # # In this dataset, we do not introduce zeros
98
- # if(is_start):
99
- # chunk = torch.cat([torch.zeros(1, self.shift_secs*self.sample_rate), chunk],1)[:,0:self.n_samples]
100
- # elif(is_end):
101
- # chunk = torch.cat([chunk, torch.zeros(1, self.shift_secs*self.sample_rate)],1)[:,self.shift_secs*self.sample_rate:]
102
-
103
- return (
104
- chunk,
105
- t_start,
106
- t_end,
107
- seconds_start,
108
- seconds_total,
109
- is_start,
110
- is_end,
111
- )
112
-
113
-
114
- USE_DUMMY_AUDIO = False  # Set to True only when testing the code: real audio is not read and generated silent audio is returned instead
115
- if USE_DUMMY_AUDIO:
116
- logger.warning("USE_DUMMY_AUDIO flag is True, don't use it when train or test!")
117
-
118
- class SafeAudioReader:
119
- """
120
-     This class is an adaptor around Read_and_PadCrop_Normalized_T that makes audio reading safe.
121
- """
122
- def __init__(self,
123
-                  duration: float,  # length (in seconds) of the returned audio
124
-                  sample_rate: int,  # sample rate of the returned audio; resampled if it differs from the source sample rate
125
- randomize: bool = True
126
- ):
127
- self.n_samples = int(sample_rate * max(duration, 0))
128
- self.reader = Read_and_PadCrop_Normalized_T(n_samples=self.n_samples, sample_rate=sample_rate, randomize=randomize)
129
-
130
-     # NOTE: this is the core function; every dataset reads audio through it!
131
- def __call__(self,
132
-                  filepath: os.PathLike,  # path to the audio file
133
-                  origin_sample_rate: Optional[int] = None,  # actual sample rate read from the json metadata; if not given, it is read from the file header
134
-                  origin_duration: float = None,  # actual duration read from the json metadata; if not given, it is read from the file header
135
- ) -> torch.Tensor:
136
- if USE_DUMMY_AUDIO:
137
- wav = torch.zeros(self.n_samples, dtype=torch.float32)
138
- return wav
139
- try:
140
- # if origin_sample_rate is None or origin_duration is None:
141
- # audio_info = torchaudio.info(filepath)
142
- # origin_sample_rate = audio_info.sample_rate
143
- # origin_duration = audio_info.num_frames / origin_sample_rate
144
- audio_info = torchaudio.info(filepath)
145
- origin_sample_rate = audio_info.sample_rate
146
- origin_duration = audio_info.num_frames / origin_sample_rate
147
- wav, *ignored, is_start, is_end = self.reader(filepath, origin_duration, origin_sample_rate)
148
- except Exception as e:
149
- logger.error(f"Error reading {filepath}: {e}")
150
- raise FileNotFoundError(filepath)
151
- return wav, is_start, is_end
152
-
153
-
154
- class PromptTemplate:
155
- def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'):
156
- self.template_text = template_text
157
- self.tag_map = tag_map
158
- self.lang = lang
159
-
160
- @property
161
- def tags(self):
162
- return tuple(self.tag_map.keys())
163
-
164
- def apply(self, **kwargs):
165
- for tag in list(kwargs.keys()):
166
- if kwargs[tag] == '':
167
- kwargs.pop(tag)
168
- for tag in self.tags:
169
- if tag in kwargs:
170
- kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]')
171
- else:
172
- kwargs[tag] = ''
173
- prompt = self.template_text.format(**kwargs)
174
-
175
- return self.beautify(prompt)
176
-
177
- def beautify(self, text):
178
- if self.lang == 'en':
179
- return self._beautify_en(text)
180
- elif self.lang == 'zh':
181
- return self._beautify_zh(text)
182
- else:
183
- raise ValueError(f'Unknown language {self.lang}')
184
-
185
- @staticmethod
186
- def _beautify_en(text):
187
- # no continuous commas without content between them
188
- text = re.sub(r'[,\s]*,[,\s]*', r', ', text)
189
- # no continuous whitespace
190
- text = re.sub(r'\s+', ' ', text)
191
- # the comma is NOT followed by whitespace, and should be followed by ONE whitespace
192
- text = re.sub(r'\s+,', r',', text)
193
- text = re.sub(r',\s+', r', ', text)
194
- # no whitespace before the full stop
195
- text = re.sub(r'\s+\.', r'.', text)
196
- # strip whitespace, comma, and replace ',.'
197
- text = text.strip(' ,')
198
- text = text.replace(',.', '.')
199
- return text
200
-
201
- @staticmethod
202
- def _beautify_zh(text):
203
- # no continuous commas without content between them
204
- text = re.sub(r'[,、\s]*,[,、\s]*', r',', text)
205
- text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text)
206
- # assume there should be NO whitespace in Chinese
207
- text = re.sub(r'\s+', r'', text)
208
- # strip whitespace, comma, and replace ',。'
209
- text = text.strip(', 、')
210
- text = text.replace(',。', '。')
211
- return text
212
-
213
- def __repr__(self):
214
- return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})'
215
-
216
- __str__ = __repr__
217
-
218
- def parse_prompt_template(prompt_template_text, lang='en'):
219
- span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL)
220
- tag_pattern = re.compile(r'{.+?}', re.DOTALL)
221
-
222
- template_text = prompt_template_text.strip()
223
- span_texts = span_pattern.findall(prompt_template_text)
224
- tag_map = {}
225
- for span_text in span_texts:
226
- tag = tag_pattern.findall(span_text)[0].strip('{}')
227
- tag_map[tag] = span_text
228
- template_text = template_text.replace(span_text, '{'+tag+'}')
229
-
230
- return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang)
231
-
232
- def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]:
233
- with open(path, 'r') as f:
234
- lines = f.readlines()
235
- cnt = 0
236
- pts = []
237
- for line in lines:
238
- pt = parse_prompt_template(line, lang=lang)
239
- cnt += 1
240
- if len(pt.tags) < num:
241
- logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}')
242
- pts.append(pt)
243
-
244
- return pts
245
-
246
-
247
- def get_base_dir_file(key: os.PathLike):
248
- base = os.path.basename(key)
249
- dirname = os.path.basename(os.path.dirname(key))
250
- return os.path.join(dirname, base)
251
-
252
- def read_jsonlike(path: os.PathLike):
253
- #json or jsonl
254
- if str(path).endswith(".json"):
255
- with open(path, 'r', encoding='utf8') as f:
256
- data = json.load(f)
257
- return data
258
- elif str(path).endswith(".jsonl"):
259
- with open(path, 'r', encoding='utf8') as f:
260
- data = [json.loads(line) for line in f.readlines()]
261
- return data
262
- else:
263
- raise ValueError("Unknown file format")
264
-
265
- dist_prob_map = {
266
- 1: (1.0,),
267
- 2: (0.5, 0.5),
268
- 3: (0.3, 0.4, 0.3),
269
- 4: (0.2, 0.3, 0.3, 0.2),
270
- 5: (0.2, 0.2, 0.3, 0.2, 0.1),
271
- 6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15),
272
- 7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
273
- 8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
274
- 9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
275
- 10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
276
- }
277
-
278
- dist_prob_map_low = {
279
- 1: (1.0,),
280
- 2: (0.8, 0.2),
281
- 3: (0.8, 0.1, 0.1),
282
- 4: (0.7, 0.1, 0.1, 0.1),
283
- 5: (0.7, 0.1, 0.1, 0.05, 0.05),
284
- 6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05),
285
- }
286
-
287
- _bpm_range_rights = (
288
- (40, '20-40'),
289
- (60, '40-60'),
290
- (66, '60-66'),
291
- (76, '66-76'),
292
- (108, '76-108'),
293
- (120, '108-120'),
294
- (168, '120-168'),
295
- (176, '168-176'),
296
- (200, '176-200')
297
- )
298
- _bpm_desc_map = {
299
- '20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"),
300
- '40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"),
301
- '60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'),
302
- '66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'),
303
- '76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"),
304
- '108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'),
305
- '120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'),
306
- '168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'),
307
- '176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'),
308
- '>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo')
309
- }
310
- _bpm_desc_map_zh = {
311
- '20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"),
312
- '40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"),
313
- '60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'),
314
- '66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'),
315
- '76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"),
316
- '108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'),
317
- '120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'),
318
- '168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'),
319
- '176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'),
320
- '>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板')
321
- }
322
- def get_bpm_range(bpm):
323
- bpm = int(bpm)
324
- for right, tag in _bpm_range_rights:
325
- if bpm <= right:
326
- return tag
327
- return '>200'
328
-
329
- def gen_bpm_descript(bpm, lang='en'):
330
- bpm_range = get_bpm_range(bpm)
331
- if lang == 'en':
332
- return random.choice(_bpm_desc_map[bpm_range])
333
- elif lang == 'zh':
334
- return random.choice(_bpm_desc_map_zh[bpm_range])
335
- else:
336
- raise ValueError(f"Unknown language {lang}")
337
-
338
- def read_translate(translate: Optional[Dict[str, os.PathLike]]):
339
- if translate is None:
340
- return None
341
- if isinstance(translate, str):
342
- return read_jsonlike(translate)
343
- return {k: read_jsonlike(path) for k, path in translate.items()}
344
-
345
-
346
- class MagnaTagATuneDataset(Dataset):
347
- def __init__(self):
348
- pass
349
-
350
-
351
- def tags_to_desc(tag_list, sep=',') -> str:
352
- if not isinstance(tag_list, Sequence):
353
- return str(tag_list)
354
- if isinstance(tag_list, str):
355
- return tag_list
356
- if len(tag_list) <= 0:
357
- return ''
358
- elif len(tag_list) <= 5:
359
- probs = dist_prob_map[len(tag_list)]
360
- tags_num = random.choices(range(1, len(tag_list)+1), probs)[0]
361
- random.shuffle(tag_list)
362
- tag_list = tag_list[:tags_num]
363
- return sep.join(tag_list)
364
- else:
365
- probs = dist_prob_map[5]
366
- tags_num = random.choices(range(1, 6), probs)[0]
367
- random.shuffle(tag_list)
368
- tag_list = tag_list[:tags_num]
369
- return sep.join(tag_list)
370
-
371
- def get_sr_and_duration_info(item):
372
- return item.get('sample_rate', None), item.get('duration', None)
373
-
374
- class MtgJamendoDatasetFromJson(Dataset):
375
- def __init__(self,
376
- data_dir:str,
377
- json_path:str,
378
- duration:float=10,
379
- sr:int = 0,
380
- *,
381
- lang = 'en',
382
- return_path = False,
383
- prompt_template_path: os.PathLike = None,
384
- tag_types = [],
385
- translate:Optional[Dict[str, os.PathLike]] = None,
386
- ):
387
- self.audio_reader = SafeAudioReader(duration, sr)
388
-
389
- self.data_dir = data_dir
390
- self._load_metadata_json(json_path)
391
- self.sr = sr
392
- self.duration = duration
393
- self.return_path = return_path
394
- self.lang = lang
395
-
396
- self.use_dynamic_prompt = prompt_template_path is not None
397
- if self.use_dynamic_prompt:
398
- self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types))
399
- self.tag_types = tag_types
400
-
401
- self.translate = read_translate(translate)
402
- if not self.use_dynamic_prompt and self.lang != 'en':
403
- raise NotImplementedError
404
-
405
-     # These tags are considered weakly semantic; prompts that contain only these tags are avoided
406
- WEAK_TAG_LIST = ["title", "artist"]
407
-
408
- def _load_metadata_json(self, json_path):
409
- with open(json_path) as fp:
410
- self.data = json.load(fp)
411
-
412
- def convert_key_to_path(self, key):
413
- return os.path.join(self.data_dir, get_base_dir_file(key))
414
-
415
- def __len__(self):
416
- return len(self.data)
417
-
418
- def __getitem__(self, idx):
419
- item = self.data[idx]
420
- path = self.convert_key_to_path(item['key'])
421
- description = self.generate_description(item)
422
-
423
- sr, duration = get_sr_and_duration_info(item)
424
- audio, is_start, is_end = self.audio_reader(path, sr, duration)
425
-
426
- if self.return_path:
427
- return audio, description, path
428
- return audio, description, is_start, is_end
429
-
430
- def tags_to_desc(self, tag_list, tag_type) -> str:
431
- if self.lang == 'en':
432
- return tags_to_desc(tag_list)
433
- elif self.lang == 'zh':
434
- translator = self.translate[tag_type]
435
- translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
436
- return tags_to_desc(translated_tag_list, sep='、')
437
-
438
- def generate_description(self, item):
439
- if self.use_dynamic_prompt:
440
- # dynamically generate prompt from given prompt template
441
- prompt_template = random.choice(self.prompt_templates)
442
- description = self.generate_description_dynamic(item, prompt_template)
443
-
444
- else:
445
- # use ordinary static prompt instead
446
- description = self.generate_description_ordinary(item)
447
- return description
448
-
449
- def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
450
- exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
451
- exists_weak_tag = list(filter(lambda t: t in self.WEAK_TAG_LIST, exists_tag))
452
- exists_strong_tag = list(filter(lambda t: t not in self.WEAK_TAG_LIST, exists_tag))
453
-
454
- if len(exists_strong_tag) > 0:
455
- probs = dist_prob_map[len(exists_strong_tag)]
456
- tags_num = random.choices(range(1, len(exists_strong_tag)+1), probs)[0]
457
- random.shuffle(exists_strong_tag)
458
- tags = exists_strong_tag[:tags_num]
459
- weak_probs = dist_prob_map_low[len(exists_weak_tag) + 1]
460
- weak_tags_num = random.choices(range(0, len(exists_weak_tag) + 1), weak_probs)[0]
461
- random.shuffle(exists_weak_tag)
462
- weak_tags = exists_weak_tag[:weak_tags_num]
463
- tags += weak_tags
464
- tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
465
- prompt = prompt_template.apply(**tags_args)
466
- else:
467
- # no strong tags, use all weak tags instead
468
- tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in exists_weak_tag}
469
- prompt = prompt_template.apply(**tags_args)
470
-
471
- return prompt
472
-
473
- def generate_description_ordinary(self, data, thresh = 0.3):
474
- # Initialize the description with title and artist
475
- description = f'"{data["title"]+" is " if random.random() > thresh else ""}"a piece of music by {data["artist"]}'
476
-
477
- # Add genre if available
478
- if data["genre"] and random.random() > thresh:
479
- genres = ', '.join(data["genre"])
480
- description += f', belonging to the {genres} genres'
481
-
482
- # Add moods if available
483
- if data["moods"] and random.random() > thresh:
484
- moods = ', '.join(data["moods"])
485
- description += f'. This track conveys a {moods} mood'
486
-
487
- # Add instruments if available
488
- if data["instrument"] and random.random() > thresh:
489
- instruments = ', '.join(data["instrument"])
490
- description += f', and primarily features the following instruments: {instruments}'
491
-
492
- # Add a period to end the description
493
- description += '.'
494
-
495
- return description
496
-
497
- class AudioStockDataset(Dataset):
498
- def __init__(self,
499
- metadata_path:str,
500
- duration:float=10,
501
- sr:int = 0,
502
- return_path = False,
503
- return_audio = True,
504
- prompt_template_path: os.PathLike = None,
505
- tag_types = [],
506
- lang = 'en',
507
- translate:Optional[Dict[str, os.PathLike]] = None
508
- ):
509
- self.audio_reader = SafeAudioReader(duration, sr)
510
-
511
- self.duration = duration
512
- self._load_metadata(metadata_path)
513
- self.sr = sr
514
- self.return_path = return_path
515
- self.return_audio = return_audio
516
-
517
- self.use_dynamic_prompt = prompt_template_path is not None
518
- if self.use_dynamic_prompt:
519
- self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang)
520
- self.tag_types = tag_types
521
-
522
- self.lang = lang
523
- self.translate = read_translate(translate)
524
-
525
- def _load_metadata(self, metadata_path):
526
- with open(metadata_path) as fp:
527
- lines = fp.readlines()
528
- self.data = []
529
- for line in lines:
530
- item = json.loads(line)
531
- if(item['duration']>self.duration+10):
532
- self.data.append(item)
533
- self.is_info_recorded = bool('Tags' in self.data[0])
534
-
535
- def __len__(self):
536
- return len(self.data)
537
-
538
- def __getitem__(self, idx):
539
- path:str = self.data[idx]["path"]
540
- json_path = path[:path.rfind('.')] + ".json"
541
- if self.is_info_recorded:
542
- item = self.data[idx]
543
- else:
544
- try:
545
- with open(json_path) as fp:
546
- item:dict = json.load(fp)
547
- except Exception as e:
548
- print(f"Error loading json file {json_path} :\n{e}")
549
- item = {}
550
- description = self.generate_description(item)
551
- if self.return_audio:
552
- sr, duration = get_sr_and_duration_info(item)
553
- audio, is_start, is_end = self.audio_reader(path, sr, duration)
554
- else:
555
- audio = None
556
- if self.return_path:
557
- return audio, description, path, is_start, is_end
558
- else:
559
- return audio, description, is_start, is_end
560
-
561
- def generate_description(self, item):
562
- if self.use_dynamic_prompt:
563
- # dynamically generate prompt from given prompt template
564
- prompt_template = random.choice(self.prompt_templates)
565
- description = self.generate_description_dynamic(item, prompt_template)
566
- else:
567
- # use ordinary static prompt instead
568
- description = self.generate_description_ordinary(item)
569
- return description
570
-
571
- def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
572
- exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
573
-
574
- if len(exists_tag) > 0:
575
- probs = dist_prob_map[len(exists_tag)]
576
- tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0]
577
- random.shuffle(exists_tag)
578
- tags = exists_tag[:tags_num]
579
- tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
580
- tags_args = self.handle_BPM_tag(tags_args)
581
- prompt = prompt_template.apply(**tags_args)
582
- else:
583
- # no strong tags, use all weak tags instead
584
- prompt = prompt_template.apply()
585
-
586
- return prompt
587
-
588
- def tags_to_desc(self, tag_list, tag_type) -> str:
589
- if self.lang == 'en':
590
- return tags_to_desc(tag_list)
591
- elif self.lang == 'zh':
592
- if tag_type == 'BPM':
593
- return tags_to_desc(tag_list, sep='、')
594
- translator = self.translate[tag_type]
595
- translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
596
- return tags_to_desc(translated_tag_list, sep='、')
597
-
598
- def handle_BPM_tag(self, tags_args):
599
- if "BPM" in tags_args and 'BPMDescript' in self.tag_types:
600
- bpm = tags_args["BPM"]
601
- del tags_args["BPM"]
602
- tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript')))
603
- for tag_type in tag_types_used:
604
- tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang)
605
- return tags_args
606
-
607
- def generate_description_ordinary(self, data, thresh = 0.3):
608
- if self.lang != 'en':
609
- raise ValueError(f'Language {self.lang} is not supported for ordinary description generation')
610
- description = f'a piece of music by {data["Artist"]}'
611
-
612
- # Add genre if available
613
- if data["Genre"] and random.random() > thresh:
614
- genres = ', '.join(data["Genre"])
615
- description += f', belonging to the {genres} genres'
616
-
617
- # Add tags if available
618
- if data["Tags"] and random.random() > thresh:
619
- tags = ', '.join(data["Tags"])
620
- description += f'. This track contains the tags: {tags}'
621
-
622
- # Add moods if available
623
- if data["Mood"] and random.random() > thresh:
624
- moods = ', '.join(data["Mood"])
625
- description += f'. This track conveys a {moods} mood'
626
-
627
- # Add instruments if available
628
- if data["Instrument"] and random.random() > thresh:
629
- instruments = ', '.join(data["Instrument"])
630
- description += f', and primarily features the following instruments: {instruments}'
631
-
632
- # Add a period to end the description
633
- description += '.'
634
-
635
- return description
636
-
637
- def mp3_path_to_id(mp3_path):
638
- return int(
639
- mp3_path[mp3_path.rindex('/') + 1 : mp3_path.rindex('.mp3')]
640
- )
641
-
642
- class TmeDataset(Dataset):
643
- def __init__(self,
644
- data_index:str,
645
- music_info:str = None,
646
- duration:float = 10,
647
- sr:int = 0,
648
- return_path = False,
649
- return_audio = True,
650
- prompt_format_path: os.PathLike = None,
651
- tag_types = ['*'],
652
- lang = 'zh',
653
- translate: Optional[os.PathLike] = None,
654
- prompt_dir: os.PathLike = None,
655
- ):
656
- self.audio_reader = SafeAudioReader(duration, sr)
657
-
658
- self.sr = sr
659
- self.duration = duration
660
- self.return_path = return_path
661
- self.return_audio = return_audio
662
- self.lang = lang
663
-
664
- self.use_ready_prompt = prompt_dir is not None
665
-
666
- data_index = read_jsonlike(data_index)
667
- data_index = [d for d in data_index if d['duration']>self.duration+10]
668
- self.data_index_dict = {mp3_path_to_id(d['path']) : d for d in data_index}
669
- self.data_ids = list(self.data_index_dict.keys())
670
-
671
- if not self.use_ready_prompt:
672
- # Load the music info file
673
- music_info = read_jsonlike(music_info)
674
- if 'music' in music_info:
675
- music_info = music_info['music']
676
- self.music_info_dict = {d["歌曲ID"]:d for d in music_info}
677
- self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.music_info_dict}
678
- self.data_ids = list(self.data_index_dict.keys())
679
-
680
- with open(prompt_format_path) as fp:
681
- self.prompt_formats = yaml.load(fp, Loader=yaml.FullLoader)
682
-
683
- # Load tag types and split them into ordinary tag_types and key key_tag_types
684
- if '*' in tag_types:
685
- self.tag_types = ['歌曲名', 'bpm', '专辑名', '歌手名', '作曲', 'tag']
686
- else:
687
- self.tag_types = tag_types
688
-
689
- self.key_tag_types = []
690
- if 'tag' in self.tag_types:
691
- self.tag_types.remove('tag')
692
- self.key_tag_types = list(self.prompt_formats['tag'].keys())
693
-
694
- # Load the translation mapping
695
- if translate is not None:
696
- self.translator = read_jsonlike(translate)
697
- else:
698
- data_ids_set = set(self.data_ids)
699
- self.prompts_dict = {}
700
- for fname in os.listdir(prompt_dir):
701
- items = read_jsonlike(os.path.join(prompt_dir, fname))
702
- for item in items:
703
- if item['ID'] not in data_ids_set or not self.is_valid_prompt_text(item['Text']):
704
- continue
705
- if item['ID'] not in self.prompts_dict:
706
- self.prompts_dict[item['ID']] = []
707
- self.prompts_dict[item['ID']].append(item['Text'])
708
- self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.prompts_dict}
709
- self.data_ids = list(self.data_index_dict.keys())
710
-
711
- def tags_to_desc(self, tag_list) -> str:
712
- if is_bearable(tag_list, int):
713
- return str(tag_list)
714
- if self.lang == 'zh':
715
- return tags_to_desc(tag_list, sep=self.sep)
716
- else:
717
- translated_tag_list = [self.translator[tag] for tag in tag_list if tag in self.translator ]
718
- return tags_to_desc(translated_tag_list, sep=self.sep)
719
-
720
- def gen_desc_of_tag(self, formats, tags):
721
- fmt = random.choice(formats)
722
- return fmt.format(self.tags_to_desc(tags))
723
-
724
- @staticmethod
725
- def check_valid(value):
726
- if isinstance(value, int) or isinstance(value, float):
727
- return value > 0
728
- if (value is not None) and (not isinstance(value, Sequence) or len(value) > 0):
729
- return True
730
- return False
731
-
732
- @staticmethod
733
- def remove_repeat(data):
734
- # If the album name is the same as the song title, keep only the latter
735
- album_name = data.get('专辑名', None)
736
- if album_name is not None and album_name == data.get('歌曲名', None):
737
- del data['专辑名']
738
- return data
739
-
740
- @property
741
- def comma(self):
742
- if self.lang == 'zh':
743
- return ','
744
- elif self.lang == 'en':
745
- return ', '
746
-
747
- @property
748
- def sep(self):
749
- if self.lang == 'zh':
750
- return '、'
751
- elif self.lang == 'en':
752
- return ', '
753
-
754
- def generate_description(self, data):
755
- data = self.remove_repeat(data)
756
- weak_tags = [key for key in data if (key in self.tag_types and self.check_valid(data[key]))] # weakly informative tags; these are sampled at a lower rate
757
-
758
- key_tags = [key for key in data['tag'] if (key in self.key_tag_types and self.check_valid(data['tag'][key]))] # key tags; at least one of them must appear
759
-
760
- prompts = []
761
- if len(weak_tags) > 0:
762
- probs = dist_prob_map_low[len(weak_tags)]
763
- if len(key_tags) > 0:
764
- tags_num = random.choices(range(0, len(weak_tags)), probs)[0]
765
- else:
766
- tags_num = random.choices(range(1, len(weak_tags) + 1), probs)[0]
767
- random.shuffle(weak_tags)
768
- tags = weak_tags[:tags_num]
769
- for tag_type in tags:
770
- tag_desc = self.gen_desc_of_tag(self.prompt_formats[tag_type], int(data[tag_type]) if tag_type == 'bpm' else data[tag_type])
771
- prompts.append(tag_desc)
772
-
773
- if len(key_tags) > 0:
774
- probs = dist_prob_map[len(key_tags)]
775
- tags_num = random.choices(range(1, len(key_tags) + 1), probs)[0]
776
- random.shuffle(key_tags)
777
- tags = key_tags[:tags_num]
778
- for tag_type in tags:
779
- tag_desc = self.gen_desc_of_tag(self.prompt_formats['tag'][tag_type], data['tag'][tag_type])
780
- prompts.append(tag_desc)
781
-
782
- random.shuffle(prompts)
783
- return self.comma.join(prompts)
784
-
785
- def is_valid_prompt_text(self, text):
786
- for bad in ('抱歉','sorry', 'Sorry'):
787
- if bad in text:
788
- return False
789
- return True
790
-
791
- def get_ready_prompt(self, path):
792
- sid = mp3_path_to_id(path)
793
- return random.choice(self.prompts_dict[sid])
794
-
795
- def __len__(self):
796
- return len(self.data_ids)
797
-
798
- def __getitem__(self, idx):
799
- data_id = self.data_ids[idx]
800
- item = self.data_index_dict[data_id]
801
- path = item['path']
802
- if not self.use_ready_prompt:
803
- info = self.music_info_dict[data_id]
804
- description = self.generate_description(info)
805
- else:
806
- description = self.get_ready_prompt(path)
807
- if self.return_audio:
808
- sr, duration = get_sr_and_duration_info(item)
809
- audio, is_start, is_end = self.audio_reader(path, sr, duration)
810
- else:
811
- audio = None
812
- if self.return_path:
813
- return audio, description, path, is_start, is_end
814
- else:
815
- return audio, description, is_start, is_end
816
-
817
- class Pond5Dataset(Dataset):
818
- MAX_PROMPT_LEN = 200
819
- def __init__(self,
820
- metadata_path:str,
821
- index_path:str,
822
- duration:float=10,
823
- sr:int = 0,
824
- plain_rate = 0,
825
- return_path = False,
826
- return_audio = True,
827
- lang = 'en',
828
- translate:Optional[Dict[str, os.PathLike]] = None,
829
- use_literal_none = True,
830
- use_avoid_watermark_policy = None,
831
- ):
832
-
833
- if use_avoid_watermark_policy is None:
834
- raise ValueError("`use_avoid_watermark_policy` is an important param, you need to explicitly specify it with bool type")
835
- self.use_avoid_watermark_policy = use_avoid_watermark_policy
836
- assert self.use_avoid_watermark_policy is False
837
- self.audio_reader = SafeAudioReader(duration, sr)
838
-
839
- self.duration = duration
840
- self._load_metadata(metadata_path, index_path)
841
- self.sr = sr
842
- self.plain_rate = plain_rate
843
- self.return_path = return_path
844
- self.return_audio = return_audio
845
- self.use_literal_none = use_literal_none
846
-
847
- self.lang = lang
848
- self.translate = read_translate(translate)
849
-
850
- def _load_metadata(self, metadata_path, index_path):
851
- data_index = read_jsonlike(index_path)
852
- data_ids = set([item['id'] for item in data_index])
853
-
854
- with open(metadata_path) as fp:
855
- lines = fp.readlines()
856
-
857
- append_ids = set()
858
-
859
- self.data = []
860
- for line in lines:
861
- item = json.loads(line)
862
- if item['id'] in data_ids and item['id'] not in append_ids and item["details"]["duration"] is not None and item["details"]["duration"]>self.duration+10:
863
- self.data.append(item)
864
- append_ids.add(item['id'])
865
-
866
- def __len__(self):
867
- return len(self.data)
868
-
869
- def __getitem__(self, idx):
870
- item = self.data[idx]
871
- path:str = item["path"]
872
- description = self.generate_description(item)
873
- if self.return_audio:
874
- sr, duration = get_sr_and_duration_info(item)
875
- audio, is_start, is_end = self.audio_reader(path, sr, duration)
876
- else:
877
- audio = None
878
- if self.return_path:
879
- return audio, description, path
880
- return audio, description, is_start, is_end
881
-
882
- @property
883
- def keysep(self):
884
- if self.lang == 'zh':
885
- return ',' if random.random() > 0.5 else '、'
886
- elif self.lang == 'en':
887
- return ', '
888
-
889
- def generate_description(self, item):
890
- if random.random() > self.plain_rate:
891
- # dynamically generate prompt from given prompt template
892
- description = self.generate_description_dynamic(item)
893
- else:
894
- # use plain prompt, i.e. tags sequence separated by comma
895
- description = self.generate_description_plain(item)
896
- return description
897
-
898
- def get_translation(self, k):
899
- k = k.strip()
900
- if k in self.translate:
901
- return self.translate[k]
902
- else:
903
- return k
904
-
905
- def generate_description_plain(self, item):
906
- keywords = item['keywords']
907
- if self.lang != 'en':
908
- keywords = [self.get_translation(k) for k in keywords]
909
- return gen_plain_prompt(keywords, sep=self.keysep)
910
-
911
- def generate_description_dynamic(self,item):
912
- desc = item.get('desc', 'none')
913
- if desc is None:
914
- desc = 'none'
915
- desc = desc.strip()
916
- if len(desc) > self.MAX_PROMPT_LEN:
917
- shorter_desc = desc[:self.MAX_PROMPT_LEN]
918
- # find last stop
919
- stop_idx = shorter_desc.rfind('.')
920
- if stop_idx == -1:
921
- stop_idx = shorter_desc.rfind('!')
922
- if stop_idx == -1:
923
- stop_idx = shorter_desc.rfind(',')
924
- if stop_idx == -1:
925
- stop_idx = self.MAX_PROMPT_LEN - 1
926
- desc = desc[:stop_idx+1]
927
- return desc
928
-
929
- class CombinedDataset(Dataset):
930
- @beartype
931
- def __init__(self, datasets: Sequence[Dataset], ratios: Sequence[int]):
932
- self.datasets = datasets
933
- self.datasets_index = []
934
-
935
- for i,dataset in enumerate(datasets):
936
- if dataset is None:
937
- continue
938
- for dup in range(ratios[i]):
939
- for j in range(len(dataset)):
940
- self.datasets_index.append((i,j))
941
-
942
- def __len__(self):
943
- return len(self.datasets_index)
944
-
945
- def __getitem__(self, idx):
946
- index = self.datasets_index[idx]
947
- i,j = index
948
- return self.datasets[i][j]
949
-
950
- class CombinedDataset_random(Dataset):
951
- @beartype
952
- def __init__(self,
953
- num_examples:int,
954
- datasets: Sequence[Dataset], ratios: Sequence[int]
955
- ):
956
- self.datasets = datasets
957
- self.datasets_index = []
958
-
959
- for i,dataset in enumerate(datasets):
960
- if dataset is None:
961
- continue
962
- for dup in range(ratios[i]):
963
- for j in range(len(dataset)):
964
- self.datasets_index.append((i,j))
965
- if num_examples > 0:
966
- self.random_choose = True
967
- self.dataset_len = num_examples
968
- else:
969
- self.random_choose = False
970
- self.dataset_len = len(self.datasets_index)
971
-
972
- def __len__(self):
973
- return self.dataset_len
974
-
975
- def __getitem__(self, idx):
976
- first_try = True
977
- try_cnt = 0
978
- while True:
979
- try:
980
- if(self.random_choose or not first_try):
981
- index2 = []
982
- index2.append(np.random.randint(0,len(self.datasets)))
983
- index2.append(np.random.randint(0,len(self.datasets[index2[-1]])))
984
- else:
985
- index2 = self.datasets_index[idx]
986
- first_try = False
987
- out = self.datasets[index2[0]][index2[1]]
988
- if(len(out[0].shape)==1):out[0]=out[0][None,:]
989
- return out
990
- except:
991
- print("Error loadding ", index2)
992
- try_cnt += 1
993
- if(try_cnt>10):
994
- raise FileNotFoundError()
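For orientation, a minimal sketch (not part of the diff) of how the ratio-based index expansion in CombinedDataset behaves; the toy ListDataset and the ratio values are hypothetical:

from torch.utils.data import Dataset

class ListDataset(Dataset):
    # hypothetical in-memory dataset, used only to illustrate the indexing
    def __init__(self, items):
        self.items = items
    def __len__(self):
        return len(self.items)
    def __getitem__(self, idx):
        return self.items[idx]

# CombinedDataset repeats every (dataset, index) pair `ratios[i]` times, so a
# ratio of 2 makes each item of that dataset appear twice in the joint index.
small = ListDataset(["a", "b"])
large = ListDataset(["x", "y", "z"])
combined = CombinedDataset([small, large], ratios=[2, 1])
assert len(combined) == 2 * 2 + 3   # 4 entries from `small`, 3 from `large`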
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song.py DELETED
@@ -1,313 +0,0 @@
1
- import re
2
- import sys
3
- import json
4
-
5
- from torch.utils.data import Dataset
6
- import torchaudio
7
- from torchaudio.functional import resample
8
- import torch
9
- import numpy as np
10
-
11
- from torch.nn.utils.rnn import pad_sequence
12
-
13
-
14
-
15
- def check_lryics(lyric):
16
- _FILTER_STRING = [
17
- '作词', '作曲', '编曲', '【', '策划',
18
- '录音', '混音', '母带', ':', '制作',
19
- '版权', '校对', '演奏', '制作', '伴奏'
20
- ]
21
- for item in _FILTER_STRING:
22
- if item in lyric:
23
- return True
24
-
25
- return False
26
-
27
-
28
-
29
- def process_lyrics(lines):
30
- lyric_part = []
31
- timestamp_part = []
32
-
33
- timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]')
34
-
35
- for i, line in enumerate(lines):
36
-
37
- # Skip credit-style metadata in the first few lines
38
- if i<10 and check_lryics(line):
39
- continue
40
-
41
- # Check whether the line contains a valid timestamp and lyric content
42
- if timestamp_pattern.match(line):
43
- timestamp_end = line.rfind(']')
44
- lyrics = line[timestamp_end + 1:].strip()
45
- timestamps = line[:timestamp_end + 1]
46
-
47
- if ':' in lyrics:
48
- if len(lyrics.split(":")[0]) <=5:
49
- lyrics = "".join(lyrics.split(":")[1:])
50
- # if lyrics:  # make sure the lyric part is not empty
51
- # lyric_part.append(lyrics)
52
- # timestamp_part.append(timestamps)
53
- # print(processed_lyrics)
54
- return timestamp_part, lyric_part
55
-
56
- def get_timestamps(timestamp_part):
57
-
58
- # Convert to seconds
59
-
60
- timestamps = []
61
-
62
- for line in timestamp_part:
63
- match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line)
64
- if match:
65
- minutes = int(match.group(1))
66
- seconds = float(match.group(2))
67
- millis = float(match.group(3)) if match.group(3) else 0
68
- total_seconds = minutes * 60 + seconds + millis
69
- timestamps.append(total_seconds)
70
-
71
-
72
- return timestamps
73
-
74
- def process_lyrics_lrc(lyrics):
75
- timestamp_part, lyric_part = process_lyrics(lyrics)
76
- # print(timestamp_part)
77
- # print(lyric_part)
78
- timestamps = get_timestamps(timestamp_part)
79
- # print(timestamps)
80
- if len(timestamps) == 0:
81
- # print(f'{lyric_path}')
82
- return []
83
-
84
- slice_start = timestamps[0]
85
- slice_start_idx = 0
86
-
87
- output_list = []
88
- for i in range(1, len(timestamps)):
89
- # Split once the accumulated time exceeds 30 s; if the whole song is shorter than 30 s the segment is dropped
90
- if timestamps[i] - slice_start > 30:
91
- output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
92
-
93
- slice_start = timestamps[i]
94
- slice_start_idx = i
95
-
96
- return output_list
97
-
98
-
99
-
100
- def process_lyrics_yrc(lyrics):
101
-
102
- timestamps, lyric_part = extract_lrc(lyrics)
103
-
104
- # timestamp_part, lyric_part = process_lyrics(lyrics)
105
- # import pdb; pdb.set_trace()
106
- # print(timestamp_part)
107
- # print(lyric_part)
108
- # timestamps = get_timestamps(timestamp_part)
109
- # print(timestamps)
110
- if len(timestamps) == 0:
111
- # print(f'{lyric_path}')
112
- return []
113
-
114
- slice_start = timestamps[0]
115
- slice_start_idx = 0
116
-
117
- output_list = []
118
- for i in range(1, len(timestamps)):
119
- # Split once the accumulated time exceeds 30 s
120
- if timestamps[i] - slice_start > 30:
121
- output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
122
-
123
- slice_start = timestamps[i]
124
- slice_start_idx = i
125
- # import pdb; pdb.set_trace()
126
- return output_list
127
-
128
- def extract_lrc(lyrics):
129
- timestamp_part, lyric_part = [], []
130
-
131
- for i, text in enumerate(lyrics):
132
- # Extract the content inside the square brackets
133
- bracket_content = re.search(r'\[(.*?)\]', text).group(1)
134
- bracket_content = bracket_content.split(',')
135
- # Extract the content inside the parentheses
136
- parentheses_content = re.findall(r'\((.*?)\)', text)
137
- # Extract the remaining content
138
- other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip()
139
-
140
- # TODO: how should this data be handled?
141
- # import pdb; pdb.set_trace()
142
- if i<10 and check_lryics(other_content):
143
- continue
144
-
145
- # import pdb; pdb.set_trace()
146
- timestamp_part.append(float(bracket_content[0])/1000)
147
- lyric_part.append(other_content)
148
- # import pdb; pdb.set_trace()
149
- return timestamp_part, lyric_part
150
-
151
-
152
-
153
- class WYYSongDataset(Dataset):
154
- def __init__(self,
155
- metadata_path:str,
156
- sr:int = 0,
157
- use_lang = ['en', 'zh-cn'],
158
- num_examples = -1,
159
- ):
160
-
161
- self.sr = sr
162
- self.use_lang = use_lang
163
- self._load_metadata(metadata_path)
164
-
165
- # buffer
166
- self.lyric_buffer = {}
167
-
168
- if(num_examples<=0):
169
- self.dataset_len = len(self.data)
170
- self.random_slc = False
171
- else:
172
- self.dataset_len = num_examples
173
- self.random_slc = True
174
-
175
- # Read the jsonl metadata file
176
- def _load_metadata(self, metadata_path):
177
- with open(metadata_path) as fp:
178
- lines = fp.readlines()
179
- self.data = []
180
- for line in lines:
181
- item = json.loads(line)
182
- # if item['lrc-lyric'] is not None and item['yrc-lyric'] is not None:
183
- if 'lyrics' in item and 'lang_info' in item:
184
- if len(item['lyrics']) > 0:
185
- for lang in self.use_lang:
186
- if lang in item['lang_info'] and item['lang_info'][lang]['proportion'] > 0.8 and item['lang_info'][lang]['probability'] > 0.9:
187
- # if '伴奏' not in item['path'] and "cloud" in item['path']:
188
- if '伴奏' not in item['path']:
189
- self.data.append(item)
190
-
191
-
192
- def __len__(self):
193
- return self.dataset_len
194
-
195
-
196
- def __getitem__(self, idx):
197
- try_cnt = 0
198
- while True:
199
- if(self.random_slc):
200
- idx = np.random.randint(0, len(self.data))
201
- yrc_lyrics = []
202
- lrc_lyrics = []
203
- try:
204
- info = self.data[idx]
205
-
206
- # audio path
207
- path:str = info["path"]
208
-
209
- # Read the lyric segments
210
- if 'lyrics' not in info:
211
- if idx not in self.lyric_buffer:
212
- # Character-level aligned lyrics
213
- if info['yrc-lyric'] is not None:
214
- with open(info['yrc-lyric']) as f_in:
215
- yrc_lyric = json.load(f_in)
216
- yrc_lyrics = process_lyrics_yrc(yrc_lyric['lyrics'][:-1])
217
-
218
- # Sentence-level aligned lyrics
219
- if info['lrc-lyric'] is not None:
220
- with open(info['lrc-lyric']) as f_in:
221
- lrc_lyric = json.load(f_in)
222
- lrc_lyrics = process_lyrics_lrc(lrc_lyric['lyrics'][:-1])
223
-
224
- # Prefer character-level aligned lyrics
225
- if len(yrc_lyrics) > 0:
226
- lyrics = yrc_lyrics
227
- else:
228
- lyrics = lrc_lyrics
229
- self.lyric_buffer[idx] = lyrics
230
-
231
- # TODO filter each lyric segment by length, dropping songs that are too long or too short
232
- else:
233
- lyrics = self.lyric_buffer[idx]
234
- else:
235
- lyrics = info['lyrics']
236
-
237
- # Randomly pick one lyric segment
238
- ly_id = torch.randint(low=1, high=len(lyrics), size=(1,))[0].item()
239
- # ly_id = 0
240
-
241
- lyric = lyrics[ly_id]
242
-
243
-
244
-
245
- st, et, lyric = self.parse_lyric(lyric)
246
-
247
- assert et - st < 40
248
-
249
- # Text filtering
250
-
251
- lyric = re.sub(r'【.*?】', '', lyric)
252
- if 'zh-cn' in info['lang_info'] and info['lang_info']['zh-cn']['proportion'] > 0.8:
253
- assert 200 > len(lyric.replace(" ", "")) > 30
254
- if ':' in lyrics:
255
- if len(lyrics.split(":")[0]) <=5:
256
- lyrics = "".join(lyrics.split(":")[1:])
257
-
258
- if ':' in lyrics:
259
- if len(lyrics.split(":")[0]) <=5:
260
- lyrics = "".join(lyrics.split(":")[1:])
261
-
262
- if 'en' in info['lang_info'] and info['lang_info']['en']['proportion'] > 0.8:
263
- assert 200 > len(lyric.split()) > 20
264
-
265
- if ':' in lyrics:
266
- if len(lyrics.split(":")[0].split()) <=3:
267
- lyrics = "".join(lyrics.split(":")[1:])
268
-
269
- if ':' in lyrics:
270
- if len(lyrics.split(":")[0].split()) <=3:
271
- lyrics = "".join(lyrics.split(":")[1:])
272
-
273
-
274
-
275
- # Read the audio file
276
- cur_sample_rate = torchaudio.info(path).sample_rate
277
- offset = int(cur_sample_rate*st)
278
- num_frames = int(cur_sample_rate * (et -st))
279
- chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames)
280
-
281
- # Randomly pick one channel
282
- if(chunk.shape[0]>1):
283
- chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
284
- else:
285
- chunk = chunk[[0],:].float()
286
-
287
- if(cur_sample_rate!=self.sr):
288
- # print('a:',cur_sample_rate,chunk.shape)
289
- chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr)
290
-
291
- return chunk, lyric, [st, et], path
292
- except:
293
- print("Error loadding ", info["path"])
294
- try_cnt += 1
295
- idx = np.random.randint(0, len(self.data))
296
- if(try_cnt>10):
297
- raise FileNotFoundError()
298
-
299
- def parse_lyric(self, lyric):
300
- pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)'
301
- match = re.search(pattern, lyric)
302
-
303
- start_time = float(match.group(1))
304
- end_time = float(match.group(2))
305
- content = match.group(3)
306
- return start_time, end_time, content
307
-
308
- def collect_song(data_list):
309
- audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1,2)
310
- lyrics = [data[1] for data in data_list]
311
- st_et = [data[2] for data in data_list]
312
- paths = [data[3] for data in data_list]
313
- return audios, lyrics, st_et
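As a point of reference, a small standalone sketch (with made-up example data) of the `[start:end]lyrics` segment strings that process_lyrics_lrc/process_lyrics_yrc emit and that parse_lyric later splits back into a time span and text:

import re

# Segments are stored as "[<start_seconds>:<end_seconds>]<comma-joined lyric lines>".
segment = "[12.5:41.2]first line, second line, third line"

match = re.search(r'\[(\d+\.\d+):(\d+\.\d+)\](.*)', segment)
start_time, end_time = float(match.group(1)), float(match.group(2))
content = match.group(3)

assert (start_time, end_time) == (12.5, 41.2)
assert content == "first line, second line, third line"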
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_20s.py DELETED
@@ -1,313 +0,0 @@
1
- import re
2
- import sys
3
- import json
4
-
5
- from torch.utils.data import Dataset
6
- import torchaudio
7
- from torchaudio.functional import resample
8
- import torch
9
- import numpy as np
10
-
11
- from torch.nn.utils.rnn import pad_sequence
12
-
13
-
14
-
15
- def check_lryics(lyric):
16
- _FILTER_STRING = [
17
- '作词', '作曲', '编曲', '【', '策划',
18
- '录音', '混音', '母带', ':', '制作',
19
- '版权', '校对', '演奏', '制作', '伴奏'
20
- ]
21
- for item in _FILTER_STRING:
22
- if item in lyric:
23
- return True
24
-
25
- return False
26
-
27
-
28
-
29
- def process_lyrics(lines):
30
- lyric_part = []
31
- timestamp_part = []
32
-
33
- timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]')
34
-
35
- for i, line in enumerate(lines):
36
-
37
- # Skip credit-style metadata in the first few lines
38
- if i<10 and check_lryics(line):
39
- continue
40
-
41
- # Check whether the line contains a valid timestamp and lyric content
42
- if timestamp_pattern.match(line):
43
- timestamp_end = line.rfind(']')
44
- lyrics = line[timestamp_end + 1:].strip()
45
- timestamps = line[:timestamp_end + 1]
46
-
47
- if ':' in lyrics:
48
- if len(lyrics.split(":")[0]) <=5:
49
- lyrics = "".join(lyrics.split(":")[1:])
50
- # if lyrics:  # make sure the lyric part is not empty
51
- # lyric_part.append(lyrics)
52
- # timestamp_part.append(timestamps)
53
- # print(processed_lyrics)
54
- return timestamp_part, lyric_part
55
-
56
- def get_timestamps(timestamp_part):
57
-
58
- # Convert to seconds
59
-
60
- timestamps = []
61
-
62
- for line in timestamp_part:
63
- match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line)
64
- if match:
65
- minutes = int(match.group(1))
66
- seconds = float(match.group(2))
67
- millis = float(match.group(3)) if match.group(3) else 0
68
- total_seconds = minutes * 60 + seconds + millis
69
- timestamps.append(total_seconds)
70
-
71
-
72
- return timestamps
73
-
74
- def process_lyrics_lrc(lyrics):
75
- timestamp_part, lyric_part = process_lyrics(lyrics)
76
- # print(timestamp_part)
77
- # print(lyric_part)
78
- timestamps = get_timestamps(timestamp_part)
79
- # print(timestamps)
80
- if len(timestamps) == 0:
81
- # print(f'{lyric_path}')
82
- return []
83
-
84
- slice_start = timestamps[0]
85
- slice_start_idx = 0
86
-
87
- output_list = []
88
- for i in range(1, len(timestamps)):
89
- # Split once the accumulated time exceeds 30 s; if the whole song is shorter than 30 s the segment is dropped
90
- if timestamps[i] - slice_start > 30:
91
- output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
92
-
93
- slice_start = timestamps[i]
94
- slice_start_idx = i
95
-
96
- return output_list
97
-
98
-
99
-
100
- def process_lyrics_yrc(lyrics):
101
-
102
- timestamps, lyric_part = extract_lrc(lyrics)
103
-
104
- # timestamp_part, lyric_part = process_lyrics(lyrics)
105
- # import pdb; pdb.set_trace()
106
- # print(timestamp_part)
107
- # print(lyric_part)
108
- # timestamps = get_timestamps(timestamp_part)
109
- # print(timestamps)
110
- if len(timestamps) == 0:
111
- # print(f'{lyric_path}')
112
- return []
113
-
114
- slice_start = timestamps[0]
115
- slice_start_idx = 0
116
-
117
- output_list = []
118
- for i in range(1, len(timestamps)):
119
- # Split once the accumulated time exceeds 30 s
120
- if timestamps[i] - slice_start > 30:
121
- output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
122
-
123
- slice_start = timestamps[i]
124
- slice_start_idx = i
125
- # import pdb; pdb.set_trace()
126
- return output_list
127
-
128
- def extract_lrc(lyrics):
129
- timestamp_part, lyric_part = [], []
130
-
131
- for i, text in enumerate(lyrics):
132
- # Extract the content inside the square brackets
133
- bracket_content = re.search(r'\[(.*?)\]', text).group(1)
134
- bracket_content = bracket_content.split(',')
135
- # Extract the content inside the parentheses
136
- parentheses_content = re.findall(r'\((.*?)\)', text)
137
- # Extract the remaining content
138
- other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip()
139
-
140
- # TODO: how should this data be handled?
141
- # import pdb; pdb.set_trace()
142
- if i<10 and check_lryics(other_content):
143
- continue
144
-
145
- # import pdb; pdb.set_trace()
146
- timestamp_part.append(float(bracket_content[0])/1000)
147
- lyric_part.append(other_content)
148
- # import pdb; pdb.set_trace()
149
- return timestamp_part, lyric_part
150
-
151
-
152
-
153
- class WYYSongDataset(Dataset):
154
- def __init__(self,
155
- metadata_path:str,
156
- sr:int = 0,
157
- use_lang = ['en', 'zh-cn'],
158
- num_examples = -1,
159
- ):
160
-
161
- self.sr = sr
162
- self.use_lang = use_lang
163
- self._load_metadata(metadata_path)
164
-
165
- # buffer
166
- self.lyric_buffer = {}
167
-
168
- if(num_examples<=0):
169
- self.dataset_len = len(self.data)
170
- self.random_slc = False
171
- else:
172
- self.dataset_len = num_examples
173
- self.random_slc = True
174
-
175
- # Read the jsonl metadata file
176
- def _load_metadata(self, metadata_path):
177
- with open(metadata_path) as fp:
178
- lines = fp.readlines()
179
- self.data = []
180
- for line in lines:
181
- item = json.loads(line)
182
- # if item['lrc-lyric'] is not None and item['yrc-lyric'] is not None:
183
- if 'lyrics' in item and 'lang_info' in item:
184
- if len(item['lyrics']) > 0:
185
- for lang in self.use_lang:
186
- if lang in item['lang_info'] and item['lang_info'][lang]['proportion'] > 0.8 and item['lang_info'][lang]['probability'] > 0.9:
187
- # if '伴奏' not in item['path'] and "cloud" in item['path']:
188
- if '伴奏' not in item['path']:
189
- self.data.append(item)
190
-
191
-
192
- def __len__(self):
193
- return self.dataset_len
194
-
195
-
196
- def __getitem__(self, idx):
197
- try_cnt = 0
198
- while True:
199
- if(self.random_slc):
200
- idx = np.random.randint(0, len(self.data))
201
- yrc_lyrics = []
202
- lrc_lyrics = []
203
- try:
204
- info = self.data[idx]
205
-
206
- # audio path
207
- path:str = info["path"]
208
-
209
- # Read the lyric segments
210
- if 'lyrics' not in info:
211
- if idx not in self.lyric_buffer:
212
- # Character-level aligned lyrics
213
- if info['yrc-lyric'] is not None:
214
- with open(info['yrc-lyric']) as f_in:
215
- yrc_lyric = json.load(f_in)
216
- yrc_lyrics = process_lyrics_yrc(yrc_lyric['lyrics'][:-1])
217
-
218
- # Sentence-level aligned lyrics
219
- if info['lrc-lyric'] is not None:
220
- with open(info['lrc-lyric']) as f_in:
221
- lrc_lyric = json.load(f_in)
222
- lrc_lyrics = process_lyrics_lrc(lrc_lyric['lyrics'][:-1])
223
-
224
- # Prefer character-level aligned lyrics
225
- if len(yrc_lyrics) > 0:
226
- lyrics = yrc_lyrics
227
- else:
228
- lyrics = lrc_lyrics
229
- self.lyric_buffer[idx] = lyrics
230
-
231
- # TODO filter each lyric segment by length, dropping songs that are too long or too short
232
- else:
233
- lyrics = self.lyric_buffer[idx]
234
- else:
235
- lyrics = info['lyrics']
236
-
237
- # Randomly pick one lyric segment
238
- ly_id = torch.randint(low=1, high=len(lyrics), size=(1,))[0].item()
239
- # ly_id = 0
240
-
241
- lyric = lyrics[ly_id]
242
-
243
-
244
-
245
- st, et, lyric = self.parse_lyric(lyric)
246
-
247
- assert et - st < 20
248
-
249
- # Text filtering
250
-
251
- lyric = re.sub(r'【.*?】', '', lyric)
252
- if 'zh-cn' in info['lang_info'] and info['lang_info']['zh-cn']['proportion'] > 0.8:
253
- assert 100 > len(lyric.replace(" ", "")) > 5
254
- if ':' in lyrics:
255
- if len(lyrics.split(":")[0]) <=5:
256
- lyrics = "".join(lyrics.split(":")[1:])
257
-
258
- if ':' in lyrics:
259
- if len(lyrics.split(":")[0]) <=5:
260
- lyrics = "".join(lyrics.split(":")[1:])
261
-
262
- if 'en' in info['lang_info'] and info['lang_info']['en']['proportion'] > 0.8:
263
- assert 100 > len(lyric.split()) > 5
264
-
265
- if ':' in lyrics:
266
- if len(lyrics.split(":")[0].split()) <=3:
267
- lyrics = "".join(lyrics.split(":")[1:])
268
-
269
- if ':' in lyrics:
270
- if len(lyrics.split(":")[0].split()) <=3:
271
- lyrics = "".join(lyrics.split(":")[1:])
272
-
273
-
274
-
275
- # Read the audio file
276
- cur_sample_rate = torchaudio.info(path).sample_rate
277
- offset = int(cur_sample_rate*st)
278
- num_frames = int(cur_sample_rate * (et -st))
279
- chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames)
280
-
281
- # Randomly pick one channel
282
- if(chunk.shape[0]>1):
283
- chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
284
- else:
285
- chunk = chunk[[0],:].float()
286
-
287
- if(cur_sample_rate!=self.sr):
288
- # print('a:',cur_sample_rate,chunk.shape)
289
- chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr)
290
-
291
- return chunk, lyric, [st, et], path
292
- except:
293
- print("Error loadding ", info["path"])
294
- try_cnt += 1
295
- idx = np.random.randint(0, len(self.data))
296
- if(try_cnt>10):
297
- raise FileNotFoundError()
298
-
299
- def parse_lyric(self, lyric):
300
- pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)'
301
- match = re.search(pattern, lyric)
302
-
303
- start_time = float(match.group(1))
304
- end_time = float(match.group(2))
305
- content = match.group(3)
306
- return start_time, end_time, content
307
-
308
- def collect_song(data_list):
309
- audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1,2)
310
- lyrics = [data[1] for data in data_list]
311
- st_et = [data[2] for data in data_list]
312
- paths = [data[3] for data in data_list]
313
- return audios, lyrics, st_et
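For context, a hedged usage sketch of how a dataset like this appears meant to be consumed, with collect_song as the DataLoader collate function; the metadata path is a placeholder and the jsonl records are assumed to carry the 'path', 'lyrics' and 'lang_info' fields that _load_metadata filters on:

from torch.utils.data import DataLoader

dataset = WYYSongDataset(
    metadata_path="train.jsonl",   # placeholder path
    sr=48000,
    use_lang=["en", "zh-cn"],
    num_examples=1000,             # >0 switches to random index sampling
)
loader = DataLoader(dataset, batch_size=4, collate_fn=collect_song)
for audios, lyrics, st_et in loader:
    # audios: (batch, 1, samples), zero-padded to the longest clip in the batch
    break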
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_new_429.py DELETED
@@ -1,313 +0,0 @@
1
- import re
2
- import sys
3
- import json
4
-
5
- from torch.utils.data import Dataset
6
- import torchaudio
7
- from torchaudio.functional import resample
8
- import torch
9
- import numpy as np
10
-
11
- from torch.nn.utils.rnn import pad_sequence
12
-
13
-
14
-
15
- def check_lryics(lyric):
16
- _FILTER_STRING = [
17
- '作词', '作曲', '编曲', '【', '策划',
18
- '录音', '混音', '母带', ':', '制作',
19
- '版权', '校对', '演奏', '制作', '伴奏'
20
- ]
21
- for item in _FILTER_STRING:
22
- if item in lyric:
23
- return True
24
-
25
- return False
26
-
27
-
28
-
29
- def process_lyrics(lines):
30
- lyric_part = []
31
- timestamp_part = []
32
-
33
- timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]')
34
-
35
- for i, line in enumerate(lines):
36
-
37
- # Skip credit-style metadata in the first few lines
38
- if i<10 and check_lryics(line):
39
- continue
40
-
41
- # Check whether the line contains a valid timestamp and lyric content
42
- if timestamp_pattern.match(line):
43
- timestamp_end = line.rfind(']')
44
- lyrics = line[timestamp_end + 1:].strip()
45
- timestamps = line[:timestamp_end + 1]
46
-
47
- if ':' in lyrics:
48
- if len(lyrics.split(":")[0]) <=5:
49
- lyrics = "".join(lyrics.split(":")[1:])
50
- # if lyrics:  # make sure the lyric part is not empty
51
- # lyric_part.append(lyrics)
52
- # timestamp_part.append(timestamps)
53
- # print(processed_lyrics)
54
- return timestamp_part, lyric_part
55
-
56
- def get_timestamps(timestamp_part):
57
-
58
- # Convert to seconds
59
-
60
- timestamps = []
61
-
62
- for line in timestamp_part:
63
- match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line)
64
- if match:
65
- minutes = int(match.group(1))
66
- seconds = float(match.group(2))
67
- millis = float(match.group(3)) if match.group(3) else 0
68
- total_seconds = minutes * 60 + seconds + millis
69
- timestamps.append(total_seconds)
70
-
71
-
72
- return timestamps
73
-
74
- def process_lyrics_lrc(lyrics):
75
- timestamp_part, lyric_part = process_lyrics(lyrics)
76
- # print(timestamp_part)
77
- # print(lyric_part)
78
- timestamps = get_timestamps(timestamp_part)
79
- # print(timestamps)
80
- if len(timestamps) == 0:
81
- # print(f'{lyric_path}')
82
- return []
83
-
84
- slice_start = timestamps[0]
85
- slice_start_idx = 0
86
-
87
- output_list = []
88
- for i in range(1, len(timestamps)):
89
- # Split once the accumulated time exceeds 30 s; if the whole song is shorter than 30 s the segment is dropped
90
- if timestamps[i] - slice_start > 30:
91
- output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
92
-
93
- slice_start = timestamps[i]
94
- slice_start_idx = i
95
-
96
- return output_list
97
-
98
-
99
-
100
- def process_lyrics_yrc(lyrics):
101
-
102
- timestamps, lyric_part = extract_lrc(lyrics)
103
-
104
- # timestamp_part, lyric_part = process_lyrics(lyrics)
105
- # import pdb; pdb.set_trace()
106
- # print(timestamp_part)
107
- # print(lyric_part)
108
- # timestamps = get_timestamps(timestamp_part)
109
- # print(timestamps)
110
- if len(timestamps) == 0:
111
- # print(f'{lyric_path}')
112
- return []
113
-
114
- slice_start = timestamps[0]
115
- slice_start_idx = 0
116
-
117
- output_list = []
118
- for i in range(1, len(timestamps)):
119
- # Split once the accumulated time exceeds 30 s
120
- if timestamps[i] - slice_start > 30:
121
- output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
122
-
123
- slice_start = timestamps[i]
124
- slice_start_idx = i
125
- # import pdb; pdb.set_trace()
126
- return output_list
127
-
128
- def extract_lrc(lyrics):
129
- timestamp_part, lyric_part = [], []
130
-
131
- for i, text in enumerate(lyrics):
132
- # Extract the content inside the square brackets
133
- bracket_content = re.search(r'\[(.*?)\]', text).group(1)
134
- bracket_content = bracket_content.split(',')
135
- # Extract the content inside the parentheses
136
- parentheses_content = re.findall(r'\((.*?)\)', text)
137
- # Extract the remaining content
138
- other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip()
139
-
140
- # TODO: how should this data be handled?
141
- if i<10 and check_lryics(other_content):
142
- continue
143
- timestamp_part.append(float(bracket_content[0])/1000)
144
- lyric_part.append(other_content)
145
- return timestamp_part, lyric_part
146
-
147
-
148
-
149
- class WYYSongDataset(Dataset):
150
- def __init__(self,
151
- metadata_path:str,
152
- sr:int = 0,
153
- use_lang = ['en', 'zh-cn'],
154
- num_examples = -1,
155
- max_dur = 20,
156
- pad_to_max= True,
157
- ):
158
-
159
- self.sr = sr
160
- self.use_lang = use_lang
161
- self._load_metadata(metadata_path)
162
- self.max_dur = max_dur
163
- self.pad_to_max = pad_to_max
164
-
165
- # buffer
166
- self.lyric_buffer = {}
167
-
168
- if(num_examples<=0):
169
- self.dataset_len = len(self.data)
170
- self.random_slc = False
171
- else:
172
- self.dataset_len = num_examples
173
- self.random_slc = True
174
-
175
- # Read the jsonl metadata file
176
- def _load_metadata(self, metadata_path):
177
- with open(metadata_path) as fp:
178
- lines = fp.readlines()
179
- self.data = []
180
- for line in lines:
181
- item = json.loads(line)
182
- if '伴奏' not in item['path']:
183
- # if "lang_type" in item and item['lang_type'] == 'en':
184
- if "lang_type" in item:
185
- self.data.append(item)
186
-
187
-
188
- def __len__(self):
189
- return self.dataset_len
190
-
191
-
192
- def __getitem__(self, idx):
193
- try_cnt = 0
194
- while True:
195
- if(self.random_slc):
196
- idx = np.random.randint(0, len(self.data))
197
- yrc_lyrics = []
198
- lrc_lyrics = []
199
- try:
200
- info = self.data[idx]
201
-
202
- # audio path
203
- path = info["path"]
204
- lang_type = info["lang_type"]
205
- if info["lang_type"] == 'en':
206
- lyrics = info['lyrics']
207
- else:
208
- lyrics = info['lyrics_phone']
209
-
210
- # Randomly pick one lyric segment
211
- ly_id = torch.randint(low=1, high=len(lyrics), size=(1,))[0].item()
212
- lyric = lyrics[ly_id].strip()
213
-
214
- st, et, lyric = self.parse_lyric(lyric)
215
- lyric = lyric.replace("\xa0", " ")
216
-
217
- lyric = " ".join(lyric.split())
218
-
219
- assert et - st < self.max_dur
220
-
221
-
222
- if info["lang_type"] == 'en':
223
- # print(len(lyric.split())/(et-st))
224
- assert 6 > len(lyric.split())/(et-st) > 1
225
- else:
226
- # print(len(lyric.split())/(et-st))
227
- lyric = lyric.replace("-", "")
228
- assert 6 > len(lyric.split())/(et-st) > 1
229
-
230
-
231
- # Read the audio file
232
- cur_sample_rate = torchaudio.info(path).sample_rate
233
- offset = int(cur_sample_rate*st)
234
- num_frames = int(cur_sample_rate * (et -st))
235
- chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames)
236
- # chunk = torch.zeros(1, 48000*15)
237
-
238
- # Randomly pick one channel
239
- if(chunk.shape[0]>1):
240
- chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
241
- else:
242
- chunk = chunk[[0],:].float()
243
-
244
- if(cur_sample_rate!=self.sr):
245
- # print('a:',cur_sample_rate,chunk.shape)
246
- chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr)
247
-
248
- if self.pad_to_max:
249
- chunk = self.pad_2d_tensor(chunk, int(self.max_dur * self.sr), 0)
250
-
251
- return chunk, lyric, et-st, path, lang_type
252
- except:
253
- # print("Error loadding ", info["path"])
254
- try_cnt += 1
255
- idx = np.random.randint(0, len(self.data))
256
- if(try_cnt>20):
257
- raise FileNotFoundError()
258
-
259
- def parse_lyric(self, lyric):
260
- pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)'
261
- match = re.search(pattern, lyric)
262
-
263
- start_time = float(match.group(1))
264
- end_time = float(match.group(2))
265
- content = match.group(3)
266
- return start_time, end_time, content
267
-
268
- def pad_2d_tensor(self, x, max_len, pad_id):
269
- # Get the shape of the input tensor
270
- batch_size, seq_len = x.size()
271
- max_len = max(max_len, seq_len)
272
- # Compute the padding length
273
- pad_len = max_len - seq_len
274
-
275
- # If padding is needed
276
- if pad_len > 0:
277
- # Create the padding tensor
278
- pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device)
279
-
280
- # Concatenate the input and padding tensors along the second (column) dimension
281
- padded_tensor = torch.cat([x, pad_tensor], dim=1)
282
- else:
283
- # If no padding is needed, return the input tensor unchanged
284
- padded_tensor = x
285
-
286
- return padded_tensor
287
-
288
- def collect_data(data_list):
289
- audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1,2)
290
- lyrics = [data[1] for data in data_list]
291
- st_et = [data[2] for data in data_list]
292
- paths = [data[3] for data in data_list]
293
- lang_types = [data[4] for data in data_list]
294
- return audios, lyrics, st_et, lang_types
295
- # return audios, lyrics, st_et
296
-
297
-
298
- def build_dataset():
299
- train_dataset = WYYSongDataset(
300
- metadata_path = "train.jsonl",
301
- sr = 48000,
302
- use_lang = ['zh-cn', 'en'],
303
- num_examples = 10*10000
304
- )
305
-
306
- valid_dataset = WYYSongDataset(
307
- metadata_path = "valid.jsonl",
308
- sr = 48000,
309
- use_lang = ['zh-cn', 'en'],
310
- num_examples = 500
311
- )
312
-
313
- return train_dataset, valid_dataset
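A brief sketch (toy values, outside the diff) of what pad_2d_tensor does when pad_to_max is enabled: a clip shorter than max_dur * sr is right-padded with a constant so every item ends up with the same length. The standalone function below mirrors the method's logic:

import torch

def pad_2d_tensor(x, max_len, pad_id):
    # reproduction of the method above, kept standalone for illustration
    batch_size, seq_len = x.size()
    max_len = max(max_len, seq_len)
    pad_len = max_len - seq_len
    if pad_len > 0:
        pad = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device)
        return torch.cat([x, pad], dim=1)
    return x

chunk = torch.randn(1, 48000 * 12)            # 12 s of mono audio at 48 kHz
padded = pad_2d_tensor(chunk, 48000 * 20, 0)  # padded up to 20 s
assert padded.shape == (1, 48000 * 20)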
codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_stock.py DELETED
@@ -1,461 +0,0 @@
1
- from torch.utils.data import Dataset
2
- from beartype.typing import Sequence, Callable, Optional, Dict, List
3
- from beartype.door import is_bearable
4
- import random
5
- import os
6
- from torchaudio.functional import resample
7
- import torch
8
- import typing as tp
9
- from pathlib import Path
10
- import torchaudio as ta
11
- import torch.nn.functional as F
12
- import soundfile
13
- import numpy as np
14
- import json
15
- import yaml
16
- import random
17
- import librosa
18
- from loguru import logger
19
- import re
20
-
21
-
22
- def _av_read(filepath, seek_time=0, duration=None):
23
- if duration is not None:
24
- sr = librosa.get_samplerate(filepath)
25
- offset = seek_time
26
- num_samples = int(duration * sr)
27
- wav, _ = librosa.load(filepath, sr=sr, offset=offset, duration=duration)
28
- else:
29
- wav, sr = librosa.load(filepath, sr=None, offset=seek_time)
30
-
31
- return wav, sr
32
-
33
- def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
34
- duration: float = -1., pad: bool = True) -> tp.Tuple[torch.Tensor, int]:
35
- """Read audio by picking the most appropriate backend tool based on the audio format.
36
-
37
- Args:
38
- filepath (str or Path): Path to audio file to read.
39
- seek_time (float): Time at which to start reading in the file.
40
- duration (float): Duration to read from the file. If set to -1, the whole file is read.
41
- pad (bool): Pad output audio if not reaching expected duration.
42
- Returns:
43
- tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
44
- """
45
- fp = Path(filepath)
46
- if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
47
- # There is some bug with ffmpeg and reading flac
48
- info = soundfile.info(filepath)
49
- frames = -1 if duration <= 0 else int(duration * info.samplerate)
50
- frame_offset = int(seek_time * info.samplerate)
51
- wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
52
- assert info.samplerate == sr, f"Mismatch of sample rates {info.samplerate} {sr}"
53
- wav = torch.from_numpy(wav).t().contiguous()
54
- if len(wav.shape) == 1:
55
- wav = torch.unsqueeze(wav, 0)
56
- elif (
57
- fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
58
- and duration <= 0 and seek_time == 0
59
- ):
60
- # Loading the whole file at once is the fast path; librosa handles the full-file read here.
61
- wav, sr = librosa.load(fp, sr=None, mono=True)
62
- else:
63
- wav, sr = _av_read(filepath, seek_time, duration)
64
- if pad and duration > 0:
65
- expected_frames = int(duration * sr)
66
- wav = F.pad(torch.tensor(wav), (0, expected_frames - wav.shape[-1]))
67
- if not isinstance(wav, torch.Tensor):
68
- wav = torch.tensor(wav)
69
- return wav, sr
70
-
71
- def random_seek_read(filepath, duration):
72
- if duration > 0:
73
- total_duration = librosa.get_duration(path=filepath)
74
- acceptable_start = max(0, total_duration - duration)
75
- wav, sr = audio_read(filepath, random.uniform(0, acceptable_start), duration, pad=True)
76
- else:
77
- wav, sr = audio_read(filepath, 0, -1, pad=False)
78
- return wav, sr
79
-
80
- def safe_random_seek_read(filepath, duration, sample_rate):
81
- try:
82
- wav, sr = random_seek_read(filepath, duration)
83
- if sr != sample_rate:
84
- wav = resample(wav, sr, sample_rate)
85
- sr = sample_rate
86
- except Exception as e:
87
- logger.error(f"Error reading {filepath}: {e}")
88
- sr = sample_rate
89
- wav = torch.zeros(sr * max(duration, 0), dtype=torch.float32)
90
- return wav, sr
91
-
92
- def read_jsonlike(path: os.PathLike):
93
- #json or jsonl
94
- if str(path).endswith(".json"):
95
- with open(path, 'r', encoding='utf8') as f:
96
- data = json.load(f)
97
- return data
98
- elif str(path).endswith(".jsonl"):
99
- with open(path, 'r', encoding='utf8') as f:
100
- data = [json.loads(line) for line in f.readlines()]
101
- return data
102
- else:
103
- raise ValueError("Unknown file format")
104
-
105
- dist_prob_map = {
106
- 1: (1.0,),
107
- 2: (0.5, 0.5),
108
- 3: (0.3, 0.4, 0.3),
109
- 4: (0.2, 0.3, 0.3, 0.2),
110
- 5: (0.2, 0.2, 0.3, 0.2, 0.1),
111
- 6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15),
112
- 7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
113
- 8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
114
- 9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
115
- 10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
116
- }
117
-
118
- dist_prob_map_low = {
119
- 1: (1.0,),
120
- 2: (0.8, 0.2),
121
- 3: (0.8, 0.1, 0.1),
122
- 4: (0.7, 0.1, 0.1, 0.1),
123
- 5: (0.7, 0.1, 0.1, 0.05, 0.05),
124
- 6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05),
125
- }
126
-
127
-
128
- _bpm_range_rights = (
129
- (40, '20-40'),
130
- (60, '40-60'),
131
- (66, '60-66'),
132
- (76, '66-76'),
133
- (108, '76-108'),
134
- (120, '108-120'),
135
- (168, '120-168'),
136
- (176, '168-176'),
137
- (200, '176-200')
138
- )
139
- _bpm_desc_map = {
140
- '20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"),
141
- '40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"),
142
- '60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'),
143
- '66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'),
144
- '76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"),
145
- '108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'),
146
- '120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'),
147
- '168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'),
148
- '176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'),
149
- '>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo')
150
- }
151
- _bpm_desc_map_zh = {
152
- '20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"),
153
- '40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"),
154
- '60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'),
155
- '66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'),
156
- '76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"),
157
- '108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'),
158
- '120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'),
159
- '168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'),
160
- '176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'),
161
- '>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板')
162
- }
163
- def get_bpm_range(bpm):
164
- bpm = int(bpm)
165
- for right, tag in _bpm_range_rights:
166
- if bpm <= right:
167
- return tag
168
- return '>200'
169
-
170
- def gen_bpm_descript(bpm, lang='en'):
171
- bpm_range = get_bpm_range(bpm)
172
- if lang == 'en':
173
- return random.choice(_bpm_desc_map[bpm_range])
174
- elif lang == 'zh':
175
- return random.choice(_bpm_desc_map_zh[bpm_range])
176
- else:
177
- raise ValueError(f"Unknown language {lang}")
178
-
179
- def read_translate(translate: Optional[Dict[str, os.PathLike]]):
180
- if translate is None:
181
- return None
182
- return {k: read_jsonlike(path) for k, path in translate.items()}
183
-
184
-
185
- def tags_to_desc(tag_list, sep=',') -> str:
186
- if not isinstance(tag_list, Sequence):
187
- return str(tag_list)
188
- if isinstance(tag_list, str):
189
- return tag_list
190
- if len(tag_list) <= 0:
191
- return ''
192
- elif len(tag_list) <= 5:
193
- probs = dist_prob_map[len(tag_list)]
194
- tags_num = random.choices(range(1, len(tag_list)+1), probs)[0]
195
- random.shuffle(tag_list)
196
- tag_list = tag_list[:tags_num]
197
- return sep.join(tag_list)
198
- else:
199
- probs = dist_prob_map[5]
200
- tags_num = random.choices(range(1, 6), probs)[0]
201
- random.shuffle(tag_list)
202
- tag_list = tag_list[:tags_num]
203
- return sep.join(tag_list)
204
-
205
-
206
- class PromptTemplate:
207
- def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'):
208
- self.template_text = template_text
209
- self.tag_map = tag_map
210
- self.lang = lang
211
-
212
- @property
213
- def tags(self):
214
- return tuple(self.tag_map.keys())
215
-
216
- def apply(self, **kwargs):
217
- for tag in list(kwargs.keys()):
218
- if kwargs[tag] == '':
219
- kwargs.pop(tag)
220
- for tag in self.tags:
221
- if tag in kwargs:
222
- kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]')
223
- else:
224
- kwargs[tag] = ''
225
- prompt = self.template_text.format(**kwargs)
226
-
227
- return self.beautify(prompt)
228
-
229
- def beautify(self, text):
230
- if self.lang == 'en':
231
- return self._beautify_en(text)
232
- elif self.lang == 'zh':
233
- return self._beautify_zh(text)
234
- else:
235
- raise ValueError(f'Unknown language {self.lang}')
236
-
237
- @staticmethod
238
- def _beautify_en(text):
239
- # no continuous commas without content between them
240
- text = re.sub(r'[,\s]*,[,\s]*', r', ', text)
241
- # no continuous whitespace
242
- text = re.sub(r'\s+', ' ', text)
243
- # the comma is NOT followed by whitespace, and should be followed by ONE whitespace
244
- text = re.sub(r'\s+,', r',', text)
245
- text = re.sub(r',\s+', r', ', text)
246
- # no whitespace before the full stop
247
- text = re.sub(r'\s+\.', r'.', text)
248
- # strip whitespace, comma, and replace ',.'
249
- text = text.strip(' ,')
250
- text = text.replace(',.', '.')
251
- return text
252
-
253
- @staticmethod
254
- def _beautify_zh(text):
255
- # no continuous commas without content between them
256
- text = re.sub(r'[,、\s]*,[,、\s]*', r',', text)
257
- text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text)
258
- # assume there should be NO whitespace in Chinese
259
- text = re.sub(r'\s+', r'', text)
260
- # strip whitespace, comma, and replace ',。'
261
- text = text.strip(', 、')
262
- text = text.replace(',。', '。')
263
- return text
264
-
265
- def __repr__(self):
266
- return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})'
267
-
268
- __str__ = __repr__
269
-
270
- def parse_prompt_template(prompt_template_text, lang='en'):
271
- span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL)
272
- tag_pattern = re.compile(r'{.+?}', re.DOTALL)
273
-
274
- template_text = prompt_template_text.strip()
275
- span_texts = span_pattern.findall(prompt_template_text)
276
- tag_map = {}
277
- for span_text in span_texts:
278
- tag = tag_pattern.findall(span_text)[0].strip('{}')
279
- tag_map[tag] = span_text
280
- template_text = template_text.replace(span_text, '{'+tag+'}')
281
-
282
- return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang)
283
-
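- Note: each bracketed span containing a `{tag}` becomes an optional fragment; `apply` fills the fragments it receives and collapses the rest. A worked example with a hypothetical template line (not one of the repo's prompt files):
-
- pt = parse_prompt_template('a [{Genre} piece] [with a {Mood} mood].')
- pt.tags                               # ('Genre', 'Mood')
- pt.apply(Genre='jazz')                # 'a jazz piece.' (the empty Mood span disappears)
- pt.apply(Genre='jazz', Mood='calm')   # 'a jazz piece with a calm mood.'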
284
- def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]:
285
- with open(path, 'r') as f:
286
- lines = f.readlines()
287
- cnt = 0
288
- pts = []
289
- for line in lines:
290
- pt = parse_prompt_template(line, lang=lang)
291
- cnt += 1
292
- if len(pt.tags) < num:
293
- logger.error(f'Not enough tags in {path} at line {cnt}: {pt.tags}')
294
- pts.append(pt)
295
-
296
- return pts
297
-
298
-
299
- class AudioStockDataset(Dataset):
300
- def __init__(self,
301
- num_examples:int,
302
- metadata_path:str,
303
- duration:float=60,
304
- sr:int = 0,
305
- return_path = False,
306
- return_audio = True,
307
- prompt_template_path: os.PathLike = None,
308
- tag_types = [],
309
- lang = 'en',
310
- translate:Optional[Dict[str, os.PathLike]] = None
311
- ):
312
- self.duration = duration
313
- self.MAX_DURATION = 360
314
- self._load_metadata(metadata_path)
315
- if num_examples > 0:
316
- self.random_choose = True
317
- self.dataset_len = num_examples
318
- else:
319
- self.random_choose = False
320
- self.dataset_len = len(self.data)
321
- self.sr = sr
322
- self.return_path = return_path
323
- self.return_audio = return_audio
324
-
325
- self.use_dynamic_prompt = prompt_template_path is not None
326
- if self.use_dynamic_prompt:
327
- self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang)
328
- self.tag_types = tag_types
329
-
330
- self.lang = lang
331
- self.translate = read_translate(translate)
332
-
333
- def _load_metadata(self, metadata_path):
334
- total_len = 0; valid_len = 0
335
- with open(metadata_path) as fp:
336
- lines = fp.readlines()
337
- self.data = []
338
- for line in lines:
339
- item = json.loads(line)
340
- total_len += 1
341
- if(item['duration']>self.duration and item['duration']<self.MAX_DURATION):
342
- valid_len += 1
343
- self.data.append(item)
344
- print("Filter data from {} to {}".format(total_len, valid_len))
345
- self.is_info_recorded = bool('Tags' in self.data[0])
346
-
347
- def __len__(self):
348
- return self.dataset_len
349
-
350
- def __getitem__(self, idx):
351
- first_try = True
352
- try_cnt = 0
353
- while True:
354
- try:
355
- if(self.random_choose or not first_try):
356
- index2 = np.random.randint(0,len(self.data))
357
- else:
358
- index2 = idx
359
- first_try = False
360
- return self.getitem_main(index2)
361
- except:
362
- print("Error loading ", self.data[index2]["path"])
363
- try_cnt += 1
364
- if(try_cnt>10):
365
- raise ValueError()
366
-
367
- def getitem_main(self, idx):
368
- path:str = self.data[idx]["path"]
369
- json_path = path[:path.rfind('.')] + ".json"
370
- if self.is_info_recorded:
371
- item = self.data[idx]
372
- else:
373
- with open(json_path) as fp:
374
- item:dict = json.load(fp)
375
- description = self.generate_description(item)
376
- if self.return_audio:
377
- audio, sr = safe_random_seek_read(path, duration=self.duration, sample_rate=self.sr)
378
- else:
379
- audio = None
380
- if self.return_path:
381
- return audio, description, path
382
- return audio, description
383
-
384
-
385
-
386
- def generate_description(self, item):
387
- if self.use_dynamic_prompt:
388
- # dynamically generate prompt from given prompt template
389
- prompt_template = random.choice(self.prompt_templates)
390
- description = self.generate_description_dynamic(item, prompt_template)
391
- else:
392
- # use ordinary static prompt instead
393
- description = self.generate_description_ordinary(item)
394
- return description
395
-
396
- def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
397
- exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
398
-
399
- if len(exists_tag) > 0:
400
- probs = dist_prob_map[len(exists_tag)]
401
- tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0]
402
- random.shuffle(exists_tag)
403
- tags = exists_tag[:tags_num]
404
- tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
405
- tags_args = self.handle_BPM_tag(tags_args)
406
- prompt = prompt_template.apply(**tags_args)
407
- else:
408
- # no strong tags, use all weak tags instead
409
- prompt = prompt_template.apply()
410
-
411
- return prompt
412
-
413
- def tags_to_desc(self, tag_list, tag_type) -> str:
414
- if self.lang == 'en':
415
- return tags_to_desc(tag_list)
416
- elif self.lang == 'zh':
417
- if tag_type == 'BPM':
418
- return tags_to_desc(tag_list, sep='、')
419
- translator = self.translate[tag_type]
420
- translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
421
- return tags_to_desc(translated_tag_list, sep='、')
422
-
423
- def handle_BPM_tag(self, tags_args):
424
- if "BPM" in tags_args and 'BPMDescript' in self.tag_types:
425
- bpm = tags_args["BPM"]
426
- del tags_args["BPM"]
427
- tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript')))
428
- for tag_type in tag_types_used:
429
- tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang)
430
- return tags_args
431
-
432
- def generate_description_ordinary(self, data, thresh = 0.3):
433
- if self.lang != 'en':
434
- raise ValueError(f'Language {self.lang} is not supported for ordinary description generation')
435
- description = f'a piece of music by {data["Artist"]}'
436
-
437
- # Add genre if available
438
- if data["Genre"] and random.random() > thresh:
439
- genres = ', '.join(data["Genre"])
440
- description += f', belonging to the {genres} genres'
441
-
442
- # Add tags if available
443
- if data["Tags"] and random.random() > thresh:
444
- tags = ', '.join(data["Tags"])
445
- description += f'. This track contains the tags: {tags}'
446
-
447
- # Add moods if available
448
- if data["Mood"] and random.random() > thresh:
449
- moods = ', '.join(data["Mood"])
450
- description += f'. This track conveys a {moods} mood'
451
-
452
- # Add instruments if available
453
- if data["Instrument"] and random.random() > thresh:
454
- instruments = ', '.join(data["Instrument"])
455
- description += f'. It primarily features the following instruments: {instruments}'
456
-
457
- # Add a period to end the description
458
- description += '.'
459
-
460
- return description
461
-
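- Note: a hedged construction sketch for the deleted dataset class above; the metadata/template paths and the tag list are placeholders, not files shipped with the repo:
-
- ds = AudioStockDataset(
-     num_examples=-1,                        # <= 0: iterate the filtered metadata in order
-     metadata_path='data/audiostock.jsonl',  # one JSON object per line with at least 'path' and 'duration'
-     duration=60,
-     sr=48000,
-     return_path=True,
-     prompt_template_path='prompts/en_templates.txt',
-     tag_types=['Genre', 'Mood', 'Instrument', 'BPM'],
-     lang='en',
- )
- audio, description, path = ds[0]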
codeclm/tokenizer/Flow1dVAE/model_1rvq.py CHANGED
@@ -270,8 +270,6 @@ class PromptCondAudioDiffusion(nn.Module):
270
  hubert_layer=None,
271
  ssl_layer=None,
272
  uncondition=True,
273
- out_paint=False,
274
- ssl_path='ckpt/encode-s12k.pt'
275
  ):
276
  super().__init__()
277
 
 
270
  hubert_layer=None,
271
  ssl_layer=None,
272
  uncondition=True,
 
 
273
  ):
274
  super().__init__()
275
 
codeclm/tokenizer/Flow1dVAE/model_2rvq.py DELETED
@@ -1,774 +0,0 @@
1
- import yaml
2
- import random
3
- import inspect
4
- import numpy as np
5
- from tqdm import tqdm
6
- import typing as tp
7
- from abc import ABC
8
-
9
- import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
- import torchaudio
13
-
14
- from einops import repeat
15
- from tools.torch_tools import wav_to_fbank
16
-
17
- import diffusers
18
- from diffusers.utils.torch_utils import randn_tensor
19
- from diffusers import DDPMScheduler
20
- from models.transformer_2d_flow import Transformer2DModel
21
- from transformers import AutoFeatureExtractor, Wav2Vec2BertModel,HubertModel
22
- # from tools.get_mulan import get_mulan
23
- from third_party.wespeaker.extract_embd import XVECModel
24
- # from libs.rvq2 import RVQEmbedding
25
- from libs.rvq.descript_quantize3_4layer_freezelayer1 import ResidualVectorQuantize
26
-
27
- from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
28
- from models_gpt.models.gpt2_config import GPT2Config
29
-
30
- from torch.cuda.amp import autocast
31
-
32
-
33
- from our_MERT_BESTRQ.test import load_model
34
-
35
- class HubertModelWithFinalProj(HubertModel):
36
- def __init__(self, config):
37
- super().__init__(config)
38
-
39
- # The final projection layer is only used for backward compatibility.
40
- # Following https://github.com/auspicious3000/contentvec/issues/6
41
- # Removing this layer is necessary to achieve the desired outcome.
42
- print("hidden_size:",config.hidden_size)
43
- print("classifier_proj_size:",config.classifier_proj_size)
44
- self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
45
-
46
-
47
- class SampleProcessor(torch.nn.Module):
- def project_sample(self, x: torch.Tensor):
- """Project the original sample to the 'space' where the diffusion will happen."""
- return x
-
- def return_sample(self, z: torch.Tensor):
- """Project back from diffusion space to the actual sample space."""
- return z
52
-
53
- class Feature1DProcessor(SampleProcessor):
54
- def __init__(self, dim: int = 100, power_std = 1., \
55
- num_samples: int = 100_000, cal_num_frames: int = 600):
56
- super().__init__()
57
-
58
- self.num_samples = num_samples
59
- self.dim = dim
60
- self.power_std = power_std
61
- self.cal_num_frames = cal_num_frames
62
- self.register_buffer('counts', torch.zeros(1))
63
- self.register_buffer('sum_x', torch.zeros(dim))
64
- self.register_buffer('sum_x2', torch.zeros(dim))
65
- self.register_buffer('sum_target_x2', torch.zeros(dim))
66
- self.counts: torch.Tensor
67
- self.sum_x: torch.Tensor
68
- self.sum_x2: torch.Tensor
69
-
70
- @property
71
- def mean(self):
72
- mean = self.sum_x / self.counts
73
- if(self.counts < 10):
74
- mean = torch.zeros_like(mean)
75
- return mean
76
-
77
- @property
78
- def std(self):
79
- std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
80
- if(self.counts < 10):
81
- std = torch.ones_like(std)
82
- return std
83
-
84
- @property
85
- def target_std(self):
86
- return 1
87
-
88
- def project_sample(self, x: torch.Tensor):
89
- assert x.dim() == 3
90
- if self.counts.item() < self.num_samples:
91
- self.counts += len(x)
92
- self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
93
- self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
94
- rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
95
- x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
96
- return x
97
-
98
- def return_sample(self, x: torch.Tensor):
99
- assert x.dim() == 3
100
- rescale = (self.std / self.target_std) ** self.power_std
101
- # print(rescale, self.mean)
102
- x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
103
- return x
104
-
105
- def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
106
- if(prior_text_encoder_hidden_states.shape[1]<len_size):
107
- prior_text_encoder_hidden_states = torch.cat([prior_text_encoder_hidden_states, \
108
- torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], \
109
- prior_text_encoder_hidden_states.shape[2], device=prior_text_mask.device, \
110
- dtype=prior_text_encoder_hidden_states.dtype)],1)
111
- prior_text_mask = torch.cat([prior_text_mask, torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], device=prior_text_mask.device, dtype=prior_text_mask.dtype)],1)
112
- else:
113
- prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:,0:len_size]
114
- prior_text_mask = prior_text_mask[:,0:len_size]
115
- prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.permute(0,2,1).contiguous()
116
- return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds
117
-
118
- class BASECFM(torch.nn.Module, ABC):
119
- def __init__(
120
- self,
121
- estimator,
122
- mlp,
123
- ssl_layer
124
- ):
125
- super().__init__()
126
- self.sigma_min = 1e-4
127
-
128
- self.estimator = estimator
129
- self.mlp = mlp
130
- self.ssl_layer = ssl_layer
131
-
132
- @torch.inference_mode()
133
- def forward(self, mu, n_timesteps, temperature=1.0):
134
- """Forward diffusion
135
-
136
- Args:
137
- mu (torch.Tensor): output of encoder
138
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
139
- n_timesteps (int): number of diffusion steps
140
- temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
141
-
142
- Returns:
143
- sample: generated mel-spectrogram
144
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
145
- """
146
- z = torch.randn_like(mu) * temperature
147
- t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
148
- return self.solve_euler(z, t_span=t_span)
149
-
150
- def solve_euler(self, x, latent_mask_input,incontext_x, incontext_length, t_span, mu,attention_mask, guidance_scale):
151
- """
152
- Fixed euler solver for ODEs.
153
- Args:
154
- x (torch.Tensor): random noise
155
- t_span (torch.Tensor): n_timesteps interpolated
156
- shape: (n_timesteps + 1,)
157
- mu (torch.Tensor): output of encoder
158
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
159
- """
160
- t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
161
- noise = x.clone()
162
-
163
- # I am storing this because I can later plot it by putting a debugger here and saving it to a file
164
- # Or in future might add like a return_all_steps flag
165
- sol = []
166
-
167
- for step in tqdm(range(1, len(t_span))):
168
- # print("incontext_x.shape:",incontext_x.shape)
169
- # print("noise.shape:",noise.shape)
170
- # print("t.shape:",t.shape)
171
- x[:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,0:incontext_length,:] + t * incontext_x[:,0:incontext_length,:]
172
- if(guidance_scale > 1.0):
173
-
174
- model_input = torch.cat([ \
175
- torch.cat([latent_mask_input, latent_mask_input], 0), \
176
- torch.cat([incontext_x, incontext_x], 0), \
177
- torch.cat([torch.zeros_like(mu), mu], 0), \
178
- torch.cat([x, x], 0), \
179
- ], 2)
180
- timestep=t.unsqueeze(-1).repeat(2)
181
-
182
- dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
183
- dphi_dt_uncond, dphi_dt_cond = dphi_dt.chunk(2,0)
184
- dphi_dt = dphi_dt_uncond + guidance_scale * (dphi_dt_cond - dphi_dt_uncond)
185
- else:
186
- model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2)
187
- timestep=t.unsqueeze(-1)
188
- dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
189
-
190
- dphi_dt = dphi_dt[: ,:, -x.shape[2]:]
191
- # print("dphi_dt.shape:",dphi_dt.shape)
192
- # print("x.shape:",x.shape)
193
-
194
- x = x + dt * dphi_dt
195
- t = t + dt
196
- sol.append(x)
197
- if step < len(t_span) - 1:
198
- dt = t_span[step + 1] - t
199
-
200
- return sol[-1]
201
-
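- Note: the guided branch above is plain classifier-free guidance applied to the predicted velocity before each Euler step; a minimal standalone sketch of that combination (names are illustrative, not the repo's API):
-
- def cfg_velocity(v_uncond, v_cond, guidance_scale):
-     # guidance_scale == 1.0 falls back to the conditional prediction
-     return v_uncond + guidance_scale * (v_cond - v_uncond)
-
- # Euler step with the guided velocity: x = x + dt * cfg_velocity(v_uncond, v_cond, guidance_scale)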
202
- def projection_loss(self,hidden_proj, bestrq_emb):
203
- bsz = hidden_proj.shape[0]
204
-
205
- hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
206
- bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)
207
-
208
- proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
209
- proj_loss = 1+proj_loss.mean()
210
-
211
- return proj_loss
212
-
213
- def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False):
214
- """Computes diffusion loss
215
-
216
- Args:
217
- x1 (torch.Tensor): Target
218
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
219
- mu (torch.Tensor): output of encoder
220
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
221
-
222
- Returns:
223
- loss: conditional flow matching loss
224
- y: conditional flow
225
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
226
- """
227
- b = mu[0].shape[0]
228
- len_x = x1.shape[2]
229
- # random timestep
230
- if(validation_mode):
231
- t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
232
- else:
233
- t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)
234
- # sample noise p(x_0)
235
- z = torch.randn_like(x1)
236
-
237
- y = (1 - (1 - self.sigma_min) * t) * z + t * x1
238
- u = x1 - (1 - self.sigma_min) * z
239
- # print("y.shape:",y.shape)
240
- #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state
241
- model_input = torch.cat([*mu,y], 2)
242
- t=t.squeeze(-1).squeeze(-1)
243
- # print("model_input.shape:",model_input.shape)
244
- # print("attention_mask.shape:",attention_mask.shape)
245
- out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True)
246
- hidden_layer = out.hidden_states[self.ssl_layer]
247
- hidden_proj = self.mlp(hidden_layer)
248
- # print("hidden_proj.shape:",hidden_proj.shape)
249
- # print("mert_emb.shape:",mert_emb.shape)
250
- # exit()
251
-
252
-
253
- out = out.last_hidden_state
254
-
255
- out=out[:,:,-len_x:]
256
- # out=self.proj_out(out)
257
-
258
- weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
259
- # print("out.shape",out.shape)
260
- # print("u.shape",u.shape)
261
- loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
262
- # print("hidden_proj.shape:",hidden_proj.shape)
263
- # print("wav2vec_embeds.shape:",wav2vec_embeds.shape)
264
- loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
265
- loss = loss_re + loss_cos * 0.5
266
- # print("loss_cos:",loss_cos,loss_cos.device)
267
- print("loss:",loss,loss.device)
268
- # exit()
269
- return loss, loss_re, loss_cos
270
-
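- Note: `compute_loss` above is standard conditional flow matching; the interpolant and regression target it builds are, as a self-contained sketch (shapes are illustrative only):
-
- import torch
-
- sigma_min = 1e-4
- x1 = torch.randn(2, 750, 64)   # target latents (illustrative shape: batch, frames, dims)
- z  = torch.randn_like(x1)      # noise sample
- t  = torch.rand(2, 1, 1)       # flow time in [0, 1]
- y  = (1 - (1 - sigma_min) * t) * z + t * x1   # point fed to the GPT-2 estimator
- u  = x1 - (1 - sigma_min) * z                 # velocity the estimator regresses with weighted MSE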
271
- class PromptCondAudioDiffusion(nn.Module):
272
- def __init__(
273
- self,
274
- num_channels,
275
- unet_model_name=None,
276
- unet_model_config_path=None,
277
- snr_gamma=None,
278
- hubert_layer=None,
279
- ssl_layer=None,
280
- uncondition=True,
281
- out_paint=False,
282
- ):
283
- super().__init__()
284
-
285
- assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
286
-
287
- self.unet_model_name = unet_model_name
288
- self.unet_model_config_path = unet_model_config_path
289
- self.snr_gamma = snr_gamma
290
- self.uncondition = uncondition
291
- self.num_channels = num_channels
292
- self.hubert_layer = hubert_layer
293
- self.ssl_layer = ssl_layer
294
-
295
- # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
296
- self.normfeat = Feature1DProcessor(dim=64)
297
-
298
- self.sample_rate = 48000
299
- self.num_samples_perseg = self.sample_rate * 20 // 1000
300
- self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
301
- self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
302
- # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
303
- # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
304
- self.bestrq = load_model(
305
- model_dir='path/to/our-MERT/mert_fairseq',
306
- checkpoint_dir='checkpoint-120000.pt',
307
- )
308
- self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
309
- self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
310
- for v in self.bestrq.parameters():v.requires_grad = False
311
- self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 2, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
312
- # for v in self.rvq_bestrq_emb.parameters():
313
- # print(v)
314
- freeze_parameters='quantizers.0'
315
- for name, param in self.rvq_bestrq_emb.named_parameters():
316
- if freeze_parameters in name:
317
- param.requires_grad = False
318
- print("Freezing RVQ parameters:", name)
319
- self.hubert = HubertModelWithFinalProj.from_pretrained("huggingface_cache/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
320
- for v in self.hubert.parameters():v.requires_grad = False
321
- self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
322
- # self.xvecmodel = XVECModel()
323
- config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
324
- unet = GPT2Model(config)
325
- mlp = nn.Sequential(
326
- nn.Linear(1200, 1024),
327
- nn.SiLU(),
328
- nn.Linear(1024, 1024),
329
- nn.SiLU(),
330
- nn.Linear(1024, 768)
331
- )
332
- self.set_from = "random"
333
- self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
334
- self.mask_emb = torch.nn.Embedding(3, 48)
335
- print("Transformer initialized from pretrain.")
336
- torch.cuda.empty_cache()
337
- # self.unet.set_attn_processor(AttnProcessor2_0())
338
- # self.unet.set_use_memory_efficient_attention_xformers(True)
339
-
340
- # self.start_embedding = nn.Parameter(torch.randn(1,1024))
341
- # self.end_embedding = nn.Parameter(torch.randn(1,1024))
342
-
343
- def compute_snr(self, timesteps):
344
- """
345
- Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
346
- """
347
- alphas_cumprod = self.noise_scheduler.alphas_cumprod
348
- sqrt_alphas_cumprod = alphas_cumprod**0.5
349
- sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
350
-
351
- # Expand the tensors.
352
- # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
353
- sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
354
- while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
355
- sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
356
- alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
357
-
358
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
359
- while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
360
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
361
- sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
362
-
363
- # Compute SNR.
364
- snr = (alpha / sigma) ** 2
365
- return snr
366
-
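- Note: the SNR computed above reduces to alphas_cumprod / (1 - alphas_cumprod); a one-line equivalent, assuming a diffusers scheduler that exposes an `alphas_cumprod` tensor:
-
- snr = scheduler.alphas_cumprod[timesteps] / (1.0 - scheduler.alphas_cumprod[timesteps])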
367
- def preprocess_audio(self, input_audios, threshold=0.9):
368
- assert len(input_audios.shape) == 2, input_audios.shape
369
- norm_value = torch.ones_like(input_audios[:,0])
370
- max_volume = input_audios.abs().max(dim=-1)[0]
371
- norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
372
- return input_audios/norm_value.unsqueeze(-1)
373
-
374
- def extract_wav2vec_embeds(self, input_audios,output_len):
375
- wav2vec_stride = 2
376
-
377
- wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024
378
- # print(wav2vec_embeds)
379
- # print("audio.shape:",input_audios.shape)
380
- wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer]
381
- # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape)
382
- wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
383
- return wav2vec_embeds_last
384
-
385
- def extract_mert_embeds(self, input_audios):
386
- prompt_stride = 3
387
- inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
388
- input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype)
389
- prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024
390
- mert_emb= prompt_embeds[-1]
391
- mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1)
392
-
393
- return mert_emb
394
-
395
- def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer):
396
- self.bestrq.eval()
397
- # print("audio shape:",input_audio_0.shape)
398
- input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
399
- # print("input_wav_mean.shape:",input_wav_mean.shape)
400
- # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device)
401
- input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True)
402
- layer_results = input_wav_mean['layer_results']
403
- # print("layer_results.shape:",layer_results[layer].shape)
404
- bestrq_emb = layer_results[layer]
405
- bestrq_emb = bestrq_emb.permute(0,2,1).contiguous()
406
- #[b,t,1024] t=t/960
407
- #35.84s->batch,896,1024
408
- return bestrq_emb
409
-
410
-
411
- def extract_spk_embeds(self, input_audios):
412
- spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
413
- spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
414
- return spk_embeds
415
-
416
- def extract_lyric_feats(self, lyric):
417
- with torch.no_grad():
418
- try:
419
- text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False)
420
- except:
421
- text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False)
422
- text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
423
- text_mask = text_mask.to(self.device)
424
- text_encoder_hidden_states, text_mask, text_prompt_embeds = \
425
- pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
426
- text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous()
427
- return text_encoder_hidden_states, text_mask
428
-
429
- def extract_energy_bar(self, input_audios):
430
- if(input_audios.shape[-1] % self.num_samples_perseg > 0):
431
- energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg)
432
- else:
433
- energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg)
434
- energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T
435
- energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int()
436
- energy_embedding = self.energy_embedding(energy_bar)
437
- energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t
438
- return energy_embedding
439
-
440
- def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \
441
- additional_feats = ['spk', 'lyric'], \
442
- train_rvq=True, train_ssl=False,layer=5):
443
- if not hasattr(self,"device"):
444
- self.device = input_audios.device
445
- if not hasattr(self,"dtype"):
446
- self.dtype = input_audios.dtype
447
- device = self.device
448
- input_audio_0 = input_audios[:,0,:]
449
- input_audio_1 = input_audios[:,1,:]
450
- input_audio_0 = self.preprocess_audio(input_audio_0)
451
- input_audio_1 = self.preprocess_audio(input_audio_1)
452
- input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0
453
- # energy_embedding = self.extract_energy_bar(input_audios)
454
- # print("energy_embedding.shape:",energy_embedding.shape)
455
- # with autocast(enabled=False):
456
- if(train_ssl):
457
- self.wav2vec.train()
458
- wav2vec_embeds = self.extract_wav2vec_embeds(input_audios)
459
- self.clap_embd_extractor.train()
460
- prompt_embeds = self.extract_mert_embeds(input_audios)
461
- if('spk' in additional_feats):
462
- self.xvecmodel.train()
463
- spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1)
464
- else:
465
- with torch.no_grad():
466
- with autocast(enabled=False):
467
- bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
468
- # mert_emb = self.extract_mert_embeds(input_audios_mert)
469
-
470
- wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2])
471
-
472
- bestrq_emb = bestrq_emb.detach()
473
- if('lyric' in additional_feats):
474
- text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
475
- else:
476
- text_encoder_hidden_states, text_mask = None, None
477
-
478
-
479
- if(train_rvq):
480
- random_num=random.random()
481
- if(random_num<0.6):
482
- rvq_layer = 1
483
- elif(random_num<0.8):
484
- rvq_layer = 2
485
- else:
486
- rvq_layer = 4
487
- quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb,n_quantizers=rvq_layer) # b,d,t
488
- else:
489
- bestrq_emb = bestrq_emb.float()
490
- self.rvq_bestrq_emb.eval()
491
- # with autocast(enabled=False):
492
- quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
493
- commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
494
- codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
495
- quantized_bestrq_emb = quantized_bestrq_emb.detach()
496
-
497
- commitment_loss = commitment_loss_bestrq_emb
498
- codebook_loss = codebook_loss_bestrq_emb
499
-
500
-
501
- alpha=1
502
- quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha)
503
-
504
- # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
505
- # print("latent_masks.shape:",latent_masks.shape)
506
- # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
507
-
508
-
509
-
510
- scenario = np.random.choice(['start_seg', 'other_seg'])
511
- if(scenario == 'other_seg'):
512
- for binx in range(input_audios.shape[0]):
513
- # latent_masks[binx,0:64] = 1
514
- latent_masks[binx,0:random.randint(64,128)] = 1
515
- quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
516
- # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
517
- # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape)
518
- # print("latent_masks.shape:",latent_masks.shape)
519
- quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
520
- + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
521
-
522
-
523
-
524
-
525
- if self.uncondition:
526
- mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
527
- if len(mask_indices) > 0:
528
- quantized_bestrq_emb[mask_indices] = 0
529
- # print("latents.shape:",latents.shape)
530
- latents = latents.permute(0,2,1).contiguous()
531
- latents = self.normfeat.project_sample(latents)
532
- latents = latents.permute(0,2,1).contiguous()
533
- incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
534
- attention_mask=(latent_masks > 0.5)
535
- B, L = attention_mask.size()
536
- attention_mask = attention_mask.view(B, 1, L)
537
- attention_mask = attention_mask * attention_mask.transpose(-1, -2)
538
- attention_mask = attention_mask.unsqueeze(1)
539
- # print("incontext_latents.shape:",incontext_latents.shape)
540
- # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
541
- latent_mask_input = self.mask_emb(latent_masks)
542
- #64+48+64+1024
543
- loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode)
544
- return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()
545
-
546
- def init_device_dtype(self, device, dtype):
547
- self.device = device
548
- self.dtype = dtype
549
-
550
- @torch.no_grad()
551
- def fetch_codes(self, input_audios, additional_feats,layer,rvq_num=1):
552
- input_audio_0 = input_audios[[0],:]
553
- input_audio_1 = input_audios[[1],:]
554
- input_audio_0 = self.preprocess_audio(input_audio_0)
555
- input_audio_1 = self.preprocess_audio(input_audio_1)
556
-
557
- self.bestrq.eval()
558
-
559
- # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
560
- # bestrq_middle = bestrq_middle.detach()
561
- # bestrq_last = bestrq_last.detach()
562
- bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
563
- bestrq_emb = bestrq_emb.detach()
564
-
565
- # self.rvq_bestrq_middle.eval()
566
- # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
567
- # self.rvq_bestrq_last.eval()
568
- # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
569
-
570
- self.rvq_bestrq_emb.eval()
571
- quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
572
- codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
573
- # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
574
- # exit()
575
-
576
-
577
- if('spk' in additional_feats):
578
- self.xvecmodel.eval()
579
- spk_embeds = self.extract_spk_embeds(input_audios)
580
- else:
581
- spk_embeds = None
582
-
583
- # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
584
- # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
585
- # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
586
- return [codes_bestrq_emb], [bestrq_emb], spk_embeds
587
- # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
588
-
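- Note: a shape sketch for the RVQ interface used by `fetch_codes` above, matching the unpacking in this file (not a general API guarantee):
-
- # bestrq_emb:             (B, 1024, T) features from the frozen BEST-RQ/MERT encoder
- # quantized, codes, *_ = self.rvq_bestrq_emb(bestrq_emb)
- # codes:                  (B, n_codebooks, T); codes[:, :rvq_num, :] keeps the first rvq_num layers
- # self.rvq_bestrq_emb.from_codes(codes) re-synthesizes the quantized embedding for inference_codes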
589
- @torch.no_grad()
590
- def fetch_codes_batch(self, input_audios, additional_feats,layer,rvq_num=1):
591
- input_audio_0 = input_audios[:,0,:]
592
- input_audio_1 = input_audios[:,1,:]
593
- input_audio_0 = self.preprocess_audio(input_audio_0)
594
- input_audio_1 = self.preprocess_audio(input_audio_1)
595
-
596
- self.bestrq.eval()
597
-
598
- # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
599
- # bestrq_middle = bestrq_middle.detach()
600
- # bestrq_last = bestrq_last.detach()
601
- bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
602
- bestrq_emb = bestrq_emb.detach()
603
-
604
- # self.rvq_bestrq_middle.eval()
605
- # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
606
- # self.rvq_bestrq_last.eval()
607
- # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
608
-
609
- self.rvq_bestrq_emb.eval()
610
- quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
611
- # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
612
- codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
613
- # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
614
- # exit()
615
-
616
-
617
- if('spk' in additional_feats):
618
- self.xvecmodel.eval()
619
- spk_embeds = self.extract_spk_embeds(input_audios)
620
- else:
621
- spk_embeds = None
622
-
623
- # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
624
- # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
625
- # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
626
- return [codes_bestrq_emb], [bestrq_emb], spk_embeds
627
- # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
628
-
629
- @torch.no_grad()
630
- def fetch_codes_batch_ds(self, input_audios, additional_feats, layer, rvq_num=1, ds=250):
631
- input_audio_0 = input_audios[:,0,:]
632
- input_audio_1 = input_audios[:,1,:]
633
- input_audio_0 = self.preprocess_audio(input_audio_0)
634
- input_audio_1 = self.preprocess_audio(input_audio_1)
635
-
636
- self.bestrq.eval()
637
-
638
- # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
639
- # bestrq_middle = bestrq_middle.detach()
640
- # bestrq_last = bestrq_last.detach()
641
- bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
642
- bestrq_emb = bestrq_emb.detach()
643
-
644
- # self.rvq_bestrq_middle.eval()
645
- # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
646
- # self.rvq_bestrq_last.eval()
647
- # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
648
-
649
- self.rvq_bestrq_emb.eval()
650
- bestrq_emb = torch.nn.functional.avg_pool1d(bestrq_emb, kernel_size=ds, stride=ds)
651
- quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
652
- # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
653
- codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
654
- # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
655
- # exit()
656
-
657
-
658
- if('spk' in additional_feats):
659
- self.xvecmodel.eval()
660
- spk_embeds = self.extract_spk_embeds(input_audios)
661
- else:
662
- spk_embeds = None
663
-
664
- # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
665
- # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
666
- # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
667
- return [codes_bestrq_emb], [bestrq_emb], spk_embeds
668
- # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
669
-
670
- @torch.no_grad()
671
- def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127,
672
- guidance_scale=2, num_steps=20,
673
- disable_progress=True, scenario='start_seg'):
674
- classifier_free_guidance = guidance_scale > 1.0
675
- device = self.device
676
- dtype = self.dtype
677
- # codes_bestrq_middle, codes_bestrq_last = codes
678
- codes_bestrq_emb = codes[0]
679
-
680
-
681
- batch_size = codes_bestrq_emb.shape[0]
682
-
683
-
684
- quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
685
- # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
686
- quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
687
- print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
688
- # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
689
-
690
-
691
-
692
-
693
- if('spk' in additional_feats):
694
- spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
695
-
696
- num_frames = quantized_bestrq_emb.shape[1]
697
-
698
- num_channels_latents = self.num_channels
699
- shape = (batch_size, num_frames, 64)
700
- latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
701
-
702
-
703
-
704
- latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
705
- latent_masks[:,0:latent_length] = 2
706
- if(scenario=='other_seg'):
707
- latent_masks[:,0:incontext_length] = 1
708
-
709
-
710
-
711
- quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
712
- + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
713
- true_latents = true_latents.permute(0,2,1).contiguous()
714
- true_latents = self.normfeat.project_sample(true_latents)
715
- true_latents = true_latents.permute(0,2,1).contiguous()
716
- incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
717
- incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
718
-
719
-
720
- attention_mask=(latent_masks > 0.5)
721
- B, L = attention_mask.size()
722
- attention_mask = attention_mask.view(B, 1, L)
723
- attention_mask = attention_mask * attention_mask.transpose(-1, -2)
724
- attention_mask = attention_mask.unsqueeze(1)
725
- latent_mask_input = self.mask_emb(latent_masks)
726
-
727
- if('spk' in additional_feats):
728
- # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
729
- additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1)
730
- else:
731
- # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
732
- additional_model_input = torch.cat([quantized_bestrq_emb],1)
733
-
734
- temperature = 1.0
735
- t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
736
- latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale)
737
-
738
- latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:]
739
- latents = latents.permute(0,2,1).contiguous()
740
- latents = self.normfeat.return_sample(latents)
741
- # latents = latents.permute(0,2,1).contiguous()
742
- return latents
743
-
744
- @torch.no_grad()
745
- def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
746
- disable_progress=True,layer=5,scenario='start_seg',rvq_num=1):
747
- codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer,rvq_num)
748
-
749
- latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
750
- guidance_scale=guidance_scale, num_steps=num_steps, \
751
- disable_progress=disable_progress,scenario=scenario)
752
- return latents
753
-
754
- @torch.no_grad()
755
- def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
756
- disable_progress=True,layer=5,scenario='start_seg'):
757
- codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
758
- import time
759
- start = time.time()
760
- latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
761
- guidance_scale=guidance_scale, num_steps=num_steps, \
762
- disable_progress=disable_progress,scenario=scenario)
763
- return latents,time.time()-start
764
-
765
- def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
766
- divisor = 4
767
- shape = (batch_size, num_channels_latents, num_frames, 32)
768
- if(num_frames%divisor>0):
769
- num_frames = round(num_frames/float(divisor))*divisor
770
- shape = (batch_size, num_channels_latents, num_frames, 32)
771
- latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
772
- return latents
773
-
774
-
codeclm/tokenizer/Flow1dVAE/model_4rvq.py DELETED
@@ -1,774 +0,0 @@
1
- import yaml
2
- import random
3
- import inspect
4
- import numpy as np
5
- from tqdm import tqdm
6
- import typing as tp
7
- from abc import ABC
8
-
9
- import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
- import torchaudio
13
-
14
- from einops import repeat
15
- from tools.torch_tools import wav_to_fbank
16
-
17
- import diffusers
18
- from diffusers.utils.torch_utils import randn_tensor
19
- from diffusers import DDPMScheduler
20
- from models.transformer_2d_flow import Transformer2DModel
21
- from transformers import AutoFeatureExtractor, Wav2Vec2BertModel,HubertModel
22
- # from tools.get_mulan import get_mulan
23
- from third_party.wespeaker.extract_embd import XVECModel
24
- # from libs.rvq2 import RVQEmbedding
25
- from libs.rvq.descript_quantize3_4layer_freezelayer1 import ResidualVectorQuantize
26
-
27
- from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
28
- from models_gpt.models.gpt2_config import GPT2Config
29
-
30
- from torch.cuda.amp import autocast
31
-
32
-
33
- from our_MERT_BESTRQ.test import load_model
34
-
35
- class HubertModelWithFinalProj(HubertModel):
36
- def __init__(self, config):
37
- super().__init__(config)
38
-
39
- # The final projection layer is only used for backward compatibility.
40
- # Following https://github.com/auspicious3000/contentvec/issues/6
41
- # Remove this layer is necessary to achieve the desired outcome.
42
- print("hidden_size:",config.hidden_size)
43
- print("classifier_proj_size:",config.classifier_proj_size)
44
- self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
45
-
46
-
47
- class SampleProcessor(torch.nn.Module):
48
- def project_sample(self, x: torch.Tensor):
49
- """Project the original sample to the 'space' where the diffusion will happen."""
50
- """Project back from diffusion space to the actual sample space."""
51
- return z
52
-
53
- class Feature1DProcessor(SampleProcessor):
54
- def __init__(self, dim: int = 100, power_std = 1., \
55
- num_samples: int = 100_000, cal_num_frames: int = 600):
56
- super().__init__()
57
-
58
- self.num_samples = num_samples
59
- self.dim = dim
60
- self.power_std = power_std
61
- self.cal_num_frames = cal_num_frames
62
- self.register_buffer('counts', torch.zeros(1))
63
- self.register_buffer('sum_x', torch.zeros(dim))
64
- self.register_buffer('sum_x2', torch.zeros(dim))
65
- self.register_buffer('sum_target_x2', torch.zeros(dim))
66
- self.counts: torch.Tensor
67
- self.sum_x: torch.Tensor
68
- self.sum_x2: torch.Tensor
69
-
70
- @property
71
- def mean(self):
72
- mean = self.sum_x / self.counts
73
- if(self.counts < 10):
74
- mean = torch.zeros_like(mean)
75
- return mean
76
-
77
- @property
78
- def std(self):
79
- std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
80
- if(self.counts < 10):
81
- std = torch.ones_like(std)
82
- return std
83
-
84
- @property
85
- def target_std(self):
86
- return 1
87
-
88
- def project_sample(self, x: torch.Tensor):
89
- assert x.dim() == 3
90
- if self.counts.item() < self.num_samples:
91
- self.counts += len(x)
92
- self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
93
- self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
94
- rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
95
- x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
96
- return x
97
-
98
- def return_sample(self, x: torch.Tensor):
99
- assert x.dim() == 3
100
- rescale = (self.std / self.target_std) ** self.power_std
101
- # print(rescale, self.mean)
102
- x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
103
- return x
104
-
105
- def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
106
- if(prior_text_encoder_hidden_states.shape[1]<len_size):
107
- prior_text_encoder_hidden_states = torch.cat([prior_text_encoder_hidden_states, \
108
- torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], \
109
- prior_text_encoder_hidden_states.shape[2], device=prior_text_mask.device, \
110
- dtype=prior_text_encoder_hidden_states.dtype)],1)
111
- prior_text_mask = torch.cat([prior_text_mask, torch.zeros(prior_text_mask.shape[0], len_size-prior_text_mask.shape[1], device=prior_text_mask.device, dtype=prior_text_mask.dtype)],1)
112
- else:
113
- prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:,0:len_size]
114
- prior_text_mask = prior_text_mask[:,0:len_size]
115
- prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.permute(0,2,1).contiguous()
116
- return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds
117
-
118
- class BASECFM(torch.nn.Module, ABC):
119
- def __init__(
120
- self,
121
- estimator,
122
- mlp,
123
- ssl_layer
124
- ):
125
- super().__init__()
126
- self.sigma_min = 1e-4
127
-
128
- self.estimator = estimator
129
- self.mlp = mlp
130
- self.ssl_layer = ssl_layer
131
-
132
- @torch.inference_mode()
133
- def forward(self, mu, n_timesteps, temperature=1.0):
134
- """Forward diffusion
135
-
136
- Args:
137
- mu (torch.Tensor): output of encoder
138
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
139
- n_timesteps (int): number of diffusion steps
140
- temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
141
-
142
- Returns:
143
- sample: generated mel-spectrogram
144
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
145
- """
146
- z = torch.randn_like(mu) * temperature
147
- t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
148
- return self.solve_euler(z, t_span=t_span)
149
-
150
- def solve_euler(self, x, latent_mask_input,incontext_x, incontext_length, t_span, mu,attention_mask, guidance_scale):
151
- """
152
- Fixed euler solver for ODEs.
153
- Args:
154
- x (torch.Tensor): random noise
155
- t_span (torch.Tensor): n_timesteps interpolated
156
- shape: (n_timesteps + 1,)
157
- mu (torch.Tensor): output of encoder
158
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
159
- """
160
- t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
161
- noise = x.clone()
162
-
163
- # I am storing this because I can later plot it by putting a debugger here and saving it to a file
164
- # Or in future might add like a return_all_steps flag
165
- sol = []
166
-
167
- for step in tqdm(range(1, len(t_span))):
168
- print("incontext_x.shape:",incontext_x.shape)
169
- print("noise.shape:",noise.shape)
170
- print("t.shape:",t.shape)
171
- x[:,0:incontext_length,:] = (1 - (1 - self.sigma_min) * t) * noise[:,0:incontext_length,:] + t * incontext_x[:,0:incontext_length,:]
172
- if(guidance_scale > 1.0):
173
-
174
- model_input = torch.cat([ \
175
- torch.cat([latent_mask_input, latent_mask_input], 0), \
176
- torch.cat([incontext_x, incontext_x], 0), \
177
- torch.cat([torch.zeros_like(mu), mu], 0), \
178
- torch.cat([x, x], 0), \
179
- ], 2)
180
- timestep=t.unsqueeze(-1).repeat(2)
181
-
182
- dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
183
- dphi_dt_uncond, dphi_dt_cond = dphi_dt.chunk(2,0)
184
- dphi_dt = dphi_dt_uncond + guidance_scale * (dphi_dt_cond - dphi_dt_uncond)
185
- else:
186
- model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2)
187
- timestep=t.unsqueeze(-1)
188
- dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state
189
-
190
- dphi_dt = dphi_dt[: ,:, -x.shape[2]:]
191
- print("dphi_dt.shape:",dphi_dt.shape)
192
- print("x.shape:",x.shape)
193
-
194
- x = x + dt * dphi_dt
195
- t = t + dt
196
- sol.append(x)
197
- if step < len(t_span) - 1:
198
- dt = t_span[step + 1] - t
199
-
200
- return sol[-1]
201
-
202
- def projection_loss(self,hidden_proj, bestrq_emb):
203
- bsz = hidden_proj.shape[0]
204
-
205
- hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
206
- bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)
207
-
208
- proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
209
- proj_loss = 1+proj_loss.mean()
210
-
211
- return proj_loss
212
-
213
- def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False):
214
- """Computes diffusion loss
215
-
216
- Args:
217
- x1 (torch.Tensor): Target
218
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
219
- mu (torch.Tensor): output of encoder
220
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
221
-
222
- Returns:
223
- loss: conditional flow matching loss
224
- y: conditional flow
225
- shape: (batch_size, n_channels, mel_timesteps, n_feats)
226
- """
227
- b = mu[0].shape[0]
228
- len_x = x1.shape[2]
229
- # random timestep
230
- if(validation_mode):
231
- t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
232
- else:
233
- t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)
234
- # sample noise p(x_0)
235
- z = torch.randn_like(x1)
236
-
237
- y = (1 - (1 - self.sigma_min) * t) * z + t * x1
238
- u = x1 - (1 - self.sigma_min) * z
239
- # print("y.shape:",y.shape)
240
- #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state
241
- model_input = torch.cat([*mu,y], 2)
242
- t=t.squeeze(-1).squeeze(-1)
243
- # print("model_input.shape:",model_input.shape)
244
- # print("attention_mask.shape:",attention_mask.shape)
245
- out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True)
246
- hidden_layer = out.hidden_states[self.ssl_layer]
247
- hidden_proj = self.mlp(hidden_layer)
248
- # print("hidden_proj.shape:",hidden_proj.shape)
249
- # print("mert_emb.shape:",mert_emb.shape)
250
- # exit()
251
-
252
-
253
- out = out.last_hidden_state
254
-
255
- out=out[:,:,-len_x:]
256
- # out=self.proj_out(out)
257
-
258
- weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
259
- # print("out.shape",out.shape)
260
- # print("u.shape",u.shape)
261
- loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
262
- # print("hidden_proj.shape:",hidden_proj.shape)
263
- # print("wav2vec_embeds.shape:",wav2vec_embeds.shape)
264
- loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
265
- loss = loss_re + loss_cos * 0.5
266
- # print("loss_cos:",loss_cos,loss_cos.device)
267
- print("loss:",loss,loss.device)
268
- # exit()
269
- return loss, loss_re, loss_cos
270
-
271
- class PromptCondAudioDiffusion(nn.Module):
272
- def __init__(
273
- self,
274
- num_channels,
275
- unet_model_name=None,
276
- unet_model_config_path=None,
277
- snr_gamma=None,
278
- hubert_layer=None,
279
- ssl_layer=None,
280
- uncondition=True,
281
- out_paint=False,
282
- ):
283
- super().__init__()
284
-
285
- assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"
286
-
287
- self.unet_model_name = unet_model_name
288
- self.unet_model_config_path = unet_model_config_path
289
- self.snr_gamma = snr_gamma
290
- self.uncondition = uncondition
291
- self.num_channels = num_channels
292
- self.hubert_layer = hubert_layer
293
- self.ssl_layer = ssl_layer
294
-
295
- # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
296
- self.normfeat = Feature1DProcessor(dim=64)
297
-
298
- self.sample_rate = 48000
299
- self.num_samples_perseg = self.sample_rate * 20 // 1000
300
- self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
301
- self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
302
- # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
303
- # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
304
- self.bestrq = load_model(
305
- model_dir='path/to/our-MERT/mert_fairseq',
306
- checkpoint_dir='checkpoint-120000.pt',
307
- )
308
- self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
309
- self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
310
- for v in self.bestrq.parameters():v.requires_grad = False
311
- self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
312
- # for v in self.rvq_bestrq_emb.parameters():
313
- # print(v)
314
- freeze_parameters='quantizers.0'
315
- for name, param in self.rvq_bestrq_emb.named_parameters():
316
- if freeze_parameters in name:
317
- param.requires_grad = False
318
- print("Freezing RVQ parameters:", name)
319
- self.hubert = HubertModelWithFinalProj.from_pretrained("huggingface_cache/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
320
- for v in self.hubert.parameters():v.requires_grad = False
321
- self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
322
- # self.xvecmodel = XVECModel()
323
- config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
324
- unet = GPT2Model(config)
325
- mlp = nn.Sequential(
326
- nn.Linear(1200, 1024),
327
- nn.SiLU(),
328
- nn.Linear(1024, 1024),
329
- nn.SiLU(),
330
- nn.Linear(1024, 768)
331
- )
332
- self.set_from = "random"
333
- self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
334
- self.mask_emb = torch.nn.Embedding(3, 48)
335
- print("Transformer initialized from pretrain.")
336
- torch.cuda.empty_cache()
337
- # self.unet.set_attn_processor(AttnProcessor2_0())
338
- # self.unet.set_use_memory_efficient_attention_xformers(True)
339
-
340
- # self.start_embedding = nn.Parameter(torch.randn(1,1024))
341
- # self.end_embedding = nn.Parameter(torch.randn(1,1024))
342
-
343
- def compute_snr(self, timesteps):
344
- """
345
- Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
346
- """
347
- alphas_cumprod = self.noise_scheduler.alphas_cumprod
348
- sqrt_alphas_cumprod = alphas_cumprod**0.5
349
- sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
350
-
351
- # Expand the tensors.
352
- # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
353
- sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
354
- while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
355
- sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
356
- alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
357
-
358
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
359
- while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
360
- sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
361
- sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
362
-
363
- # Compute SNR.
364
- snr = (alpha / sigma) ** 2
365
- return snr
366
-
367
- def preprocess_audio(self, input_audios, threshold=0.9):
368
- assert len(input_audios.shape) == 2, input_audios.shape
369
- norm_value = torch.ones_like(input_audios[:,0])
370
- max_volume = input_audios.abs().max(dim=-1)[0]
371
- norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
372
- return input_audios/norm_value.unsqueeze(-1)
373
-
374
- def extract_wav2vec_embeds(self, input_audios,output_len):
375
- wav2vec_stride = 2
376
-
377
- wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024
378
- # print(wav2vec_embeds)
379
- # print("audio.shape:",input_audios.shape)
380
- wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer]
381
- # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape)
382
- wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
383
- return wav2vec_embeds_last
384
-
385
- def extract_mert_embeds(self, input_audios):
386
- prompt_stride = 3
387
- inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
388
- input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype)
389
- prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024
390
- mert_emb= prompt_embeds[-1]
391
- mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1)
392
-
393
- return mert_emb
394
-
395
- def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer):
396
- self.bestrq.eval()
397
- # print("audio shape:",input_audio_0.shape)
398
- input_wav_mean = (input_audio_0 + input_audio_1) / 2.0
399
- # print("input_wav_mean.shape:",input_wav_mean.shape)
400
- # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device)
401
- input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True)
402
- layer_results = input_wav_mean['layer_results']
403
- # print("layer_results.shape:",layer_results[layer].shape)
404
- bestrq_emb = layer_results[layer]
405
- bestrq_emb = bestrq_emb.permute(0,2,1).contiguous()
406
- #[b,t,1024] t=t/960
407
- #35.84s->batch,896,1024
408
- return bestrq_emb
409
-
410
-
411
- def extract_spk_embeds(self, input_audios):
412
- spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
413
- spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
414
- return spk_embeds
415
-
416
- def extract_lyric_feats(self, lyric):
417
- with torch.no_grad():
418
- try:
419
- text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False)
420
- except:
421
- text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False)
422
- text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
423
- text_mask = text_mask.to(self.device)
424
- text_encoder_hidden_states, text_mask, text_prompt_embeds = \
425
- pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
426
- text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous()
427
- return text_encoder_hidden_states, text_mask
428
-
429
- def extract_energy_bar(self, input_audios):
430
- if(input_audios.shape[-1] % self.num_samples_perseg > 0):
431
- energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg)
432
- else:
433
- energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg)
434
- energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T
435
- energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int()
436
- energy_embedding = self.energy_embedding(energy_bar)
437
- energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t
438
- return energy_embedding
439
-
440
- def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \
441
- additional_feats = ['spk', 'lyric'], \
442
- train_rvq=True, train_ssl=False,layer=5):
443
- if not hasattr(self,"device"):
444
- self.device = input_audios.device
445
- if not hasattr(self,"dtype"):
446
- self.dtype = input_audios.dtype
447
- device = self.device
448
- input_audio_0 = input_audios[:,0,:]
449
- input_audio_1 = input_audios[:,1,:]
450
- input_audio_0 = self.preprocess_audio(input_audio_0)
451
- input_audio_1 = self.preprocess_audio(input_audio_1)
452
- input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0
453
- # energy_embedding = self.extract_energy_bar(input_audios)
454
- # print("energy_embedding.shape:",energy_embedding.shape)
455
- # with autocast(enabled=False):
456
- if(train_ssl):
457
- self.wav2vec.train()
458
- wav2vec_embeds = self.extract_wav2vec_embeds(input_audios)
459
- self.clap_embd_extractor.train()
460
- prompt_embeds = self.extract_mert_embeds(input_audios)
461
- if('spk' in additional_feats):
462
- self.xvecmodel.train()
463
- spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1)
464
- else:
465
- with torch.no_grad():
466
- with autocast(enabled=False):
467
- bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
468
- # mert_emb = self.extract_mert_embeds(input_audios_mert)
469
-
470
- wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2])
471
-
472
- bestrq_emb = bestrq_emb.detach()
473
- if('lyric' in additional_feats):
474
- text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
475
- else:
476
- text_encoder_hidden_states, text_mask = None, None
477
-
478
-
479
- if(train_rvq):
480
- random_num=random.random()
481
- if(random_num<0.6):
482
- rvq_layer = 1
483
- elif(random_num<0.8):
484
- rvq_layer = 2
485
- else:
486
- rvq_layer = 4
487
- quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb,n_quantizers=rvq_layer) # b,d,t
488
- else:
489
- bestrq_emb = bestrq_emb.float()
490
- self.rvq_bestrq_emb.eval()
491
- # with autocast(enabled=False):
492
- quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t
493
- commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
494
- codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
495
- quantized_bestrq_emb = quantized_bestrq_emb.detach()
496
-
497
- commitment_loss = commitment_loss_bestrq_emb
498
- codebook_loss = codebook_loss_bestrq_emb
499
-
500
-
501
- alpha=1
502
- quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha)
503
-
504
- # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
505
- # print("latent_masks.shape:",latent_masks.shape)
506
- # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
507
-
508
-
509
-
510
- scenario = np.random.choice(['start_seg', 'other_seg'])
511
- if(scenario == 'other_seg'):
512
- for binx in range(input_audios.shape[0]):
513
- # latent_masks[binx,0:64] = 1
514
- latent_masks[binx,0:random.randint(64,128)] = 1
515
- quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
516
- # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
517
- # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape)
518
- # print("latent_masks.shape:",latent_masks.shape)
519
- quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
520
- + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
521
-
522
-
523
-
524
-
525
- if self.uncondition:
526
- mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
527
- if len(mask_indices) > 0:
528
- quantized_bestrq_emb[mask_indices] = 0
529
- # print("latents.shape:",latents.shape)
530
- latents = latents.permute(0,2,1).contiguous()
531
- latents = self.normfeat.project_sample(latents)
532
- latents = latents.permute(0,2,1).contiguous()
533
- incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
534
- attention_mask=(latent_masks > 0.5)
535
- B, L = attention_mask.size()
536
- attention_mask = attention_mask.view(B, 1, L)
537
- attention_mask = attention_mask * attention_mask.transpose(-1, -2)
538
- attention_mask = attention_mask.unsqueeze(1)
539
- # print("incontext_latents.shape:",incontext_latents.shape)
540
- # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
541
- latent_mask_input = self.mask_emb(latent_masks)
542
- #64+48+64+1024
543
- loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode)
544
- return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()
545
-
546
- def init_device_dtype(self, device, dtype):
547
- self.device = device
548
- self.dtype = dtype
549
-
550
- @torch.no_grad()
551
- def fetch_codes(self, input_audios, additional_feats,layer,rvq_num=1):
552
- input_audio_0 = input_audios[[0],:]
553
- input_audio_1 = input_audios[[1],:]
554
- input_audio_0 = self.preprocess_audio(input_audio_0)
555
- input_audio_1 = self.preprocess_audio(input_audio_1)
556
-
557
- self.bestrq.eval()
558
-
559
- # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
560
- # bestrq_middle = bestrq_middle.detach()
561
- # bestrq_last = bestrq_last.detach()
562
- bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
563
- bestrq_emb = bestrq_emb.detach()
564
-
565
- # self.rvq_bestrq_middle.eval()
566
- # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
567
- # self.rvq_bestrq_last.eval()
568
- # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
569
-
570
- self.rvq_bestrq_emb.eval()
571
- quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
572
- codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
573
- # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
574
- # exit()
575
-
576
-
577
- if('spk' in additional_feats):
578
- self.xvecmodel.eval()
579
- spk_embeds = self.extract_spk_embeds(input_audios)
580
- else:
581
- spk_embeds = None
582
-
583
- # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
584
- # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
585
- # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
586
- return [codes_bestrq_emb], [bestrq_emb], spk_embeds
587
- # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
588
-
589
- @torch.no_grad()
590
- def fetch_codes_batch(self, input_audios, additional_feats,layer,rvq_num=1):
591
- input_audio_0 = input_audios[:,0,:]
592
- input_audio_1 = input_audios[:,1,:]
593
- input_audio_0 = self.preprocess_audio(input_audio_0)
594
- input_audio_1 = self.preprocess_audio(input_audio_1)
595
-
596
- self.bestrq.eval()
597
-
598
- # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
599
- # bestrq_middle = bestrq_middle.detach()
600
- # bestrq_last = bestrq_last.detach()
601
- bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
602
- bestrq_emb = bestrq_emb.detach()
603
-
604
- # self.rvq_bestrq_middle.eval()
605
- # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
606
- # self.rvq_bestrq_last.eval()
607
- # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
608
-
609
- self.rvq_bestrq_emb.eval()
610
- quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
611
- # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
612
- codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
613
- # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
614
- # exit()
615
-
616
-
617
- if('spk' in additional_feats):
618
- self.xvecmodel.eval()
619
- spk_embeds = self.extract_spk_embeds(input_audios)
620
- else:
621
- spk_embeds = None
622
-
623
- # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
624
- # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
625
- # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
626
- return [codes_bestrq_emb], [bestrq_emb], spk_embeds
627
- # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
628
-
629
- @torch.no_grad()
630
- def fetch_codes_batch_ds(self, input_audios, additional_feats, layer, rvq_num=1, ds=250):
631
- input_audio_0 = input_audios[:,0,:]
632
- input_audio_1 = input_audios[:,1,:]
633
- input_audio_0 = self.preprocess_audio(input_audio_0)
634
- input_audio_1 = self.preprocess_audio(input_audio_1)
635
-
636
- self.bestrq.eval()
637
-
638
- # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
639
- # bestrq_middle = bestrq_middle.detach()
640
- # bestrq_last = bestrq_last.detach()
641
- bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer)
642
- bestrq_emb = bestrq_emb.detach()
643
-
644
- # self.rvq_bestrq_middle.eval()
645
- # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t
646
- # self.rvq_bestrq_last.eval()
647
- # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t
648
-
649
- self.rvq_bestrq_emb.eval()
650
- bestrq_emb = torch.nn.functional.avg_pool1d(bestrq_emb, kernel_size=ds, stride=ds)
651
- quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)
652
- # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
653
- codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:]
654
- # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape)
655
- # exit()
656
-
657
-
658
- if('spk' in additional_feats):
659
- self.xvecmodel.eval()
660
- spk_embeds = self.extract_spk_embeds(input_audios)
661
- else:
662
- spk_embeds = None
663
-
664
- # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
665
- # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
666
- # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
667
- return [codes_bestrq_emb], [bestrq_emb], spk_embeds
668
- # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
669
-
670
- @torch.no_grad()
671
- def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127,
672
- guidance_scale=2, num_steps=20,
673
- disable_progress=True, scenario='start_seg'):
674
- classifier_free_guidance = guidance_scale > 1.0
675
- device = self.device
676
- dtype = self.dtype
677
- # codes_bestrq_middle, codes_bestrq_last = codes
678
- codes_bestrq_emb = codes[0]
679
-
680
-
681
- batch_size = codes_bestrq_emb.shape[0]
682
-
683
-
684
- quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
685
- # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
686
- quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous()
687
- print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape)
688
- # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True)
689
-
690
-
691
-
692
-
693
- if('spk' in additional_feats):
694
- spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach()
695
-
696
- num_frames = quantized_bestrq_emb.shape[1]
697
-
698
- num_channels_latents = self.num_channels
699
- shape = (batch_size, num_frames, 64)
700
- latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
701
-
702
-
703
-
704
- latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
705
- latent_masks[:,0:latent_length] = 2
706
- if(scenario=='other_seg'):
707
- latent_masks[:,0:incontext_length] = 1
708
-
709
-
710
-
711
- quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
712
- + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024)
713
- true_latents = true_latents.permute(0,2,1).contiguous()
714
- true_latents = self.normfeat.project_sample(true_latents)
715
- true_latents = true_latents.permute(0,2,1).contiguous()
716
- incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
717
- incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]
718
-
719
-
720
- attention_mask=(latent_masks > 0.5)
721
- B, L = attention_mask.size()
722
- attention_mask = attention_mask.view(B, 1, L)
723
- attention_mask = attention_mask * attention_mask.transpose(-1, -2)
724
- attention_mask = attention_mask.unsqueeze(1)
725
- latent_mask_input = self.mask_emb(latent_masks)
726
-
727
- if('spk' in additional_feats):
728
- # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
729
- additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1)
730
- else:
731
- # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
732
- additional_model_input = torch.cat([quantized_bestrq_emb],1)
733
-
734
- temperature = 1.0
735
- t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
736
- latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale)
737
-
738
- latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:]
739
- latents = latents.permute(0,2,1).contiguous()
740
- latents = self.normfeat.return_sample(latents)
741
- # latents = latents.permute(0,2,1).contiguous()
742
- return latents
743
-
744
- @torch.no_grad()
745
- def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
746
- disable_progress=True,layer=5,scenario='start_seg',rvq_num=1):
747
- codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer,rvq_num)
748
-
749
- latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
750
- guidance_scale=guidance_scale, num_steps=num_steps, \
751
- disable_progress=disable_progress,scenario=scenario)
752
- return latents
753
-
754
- @torch.no_grad()
755
- def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20,
756
- disable_progress=True,layer=5,scenario='start_seg'):
757
- codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer)
758
- import time
759
- start = time.time()
760
- latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \
761
- guidance_scale=guidance_scale, num_steps=num_steps, \
762
- disable_progress=disable_progress,scenario=scenario)
763
- return latents,time.time()-start
764
-
765
- def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
766
- divisor = 4
767
- shape = (batch_size, num_channels_latents, num_frames, 32)
768
- if(num_frames%divisor>0):
769
- num_frames = round(num_frames/float(divisor))*divisor
770
- shape = (batch_size, num_channels_latents, num_frames, 32)
771
- latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
772
- return latents
773
-
774
-
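The model file deleted above trains a GPT-2 backbone with a conditional flow matching (CFM) objective: a noise sample z is blended with the target latent x1 along a random timestep t, and the network regresses the velocity u = x1 - (1 - sigma_min) * z; at inference, `solve_euler` integrates that velocity field with fixed Euler steps and classifier-free guidance. Below is a minimal, self-contained sketch of the training objective only; the generic `estimator` callable and the plain MSE (without the latent-mask weighting and SSL projection loss of the original) are simplifications for illustration, not the repository's API.

```python
import torch
import torch.nn.functional as F

def cfm_loss(estimator, x1, cond, sigma_min=1e-4):
    """Minimal conditional flow matching loss (sketch of the deleted compute_loss).

    x1:   target latents, shape (B, T, D)
    cond: conditioning features aligned with x1, shape (B, T, C)
    estimator: any callable mapping (B, T, C + D) features and a (B,) timestep
               to a predicted velocity of shape (B, T, D)
    """
    b = x1.shape[0]
    t = torch.rand(b, 1, 1, device=x1.device, dtype=x1.dtype)  # random timestep per sample
    z = torch.randn_like(x1)                                    # noise sample from p(x_0)

    y = (1 - (1 - sigma_min) * t) * z + t * x1                  # point on the probability path
    u = x1 - (1 - sigma_min) * z                                 # target velocity field

    pred = estimator(torch.cat([cond, y], dim=-1), t.view(b))
    return F.mse_loss(pred, u)
```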
codeclm/tokenizer/Flow1dVAE/model_septoken.py CHANGED
@@ -252,8 +252,6 @@ class PromptCondAudioDiffusion(nn.Module):
252
  unet_model_config_path=None,
253
  snr_gamma=None,
254
  uncondition=True,
255
- out_paint=False,
256
- ssl_path='ckpt/encode-s12k.pt'
257
  ):
258
  super().__init__()
259
 
 
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_AS2M.yaml DELETED
@@ -1,122 +0,0 @@
1
- # @package _group_
2
-
3
- common:
4
- fp16: true
5
- log_format: json
6
- log_interval: 200
7
- tensorboard_logdir: tb
8
- min_loss_scale: 1e-6
9
- fp16_no_flatten_grads: true
10
- user_dir: ${env:PWD}
11
- seed: 1
12
-
13
- checkpoint:
14
- save_interval: 1
15
- save_interval_updates: 10000
16
- keep_interval_updates: 1
17
- no_epoch_checkpoints: true
18
-
19
- task:
20
- _name: mae_image_pretraining
21
- data: unbalanced_train
22
- rebuild_batches: true
23
- key: source
24
- precompute_mask_config: {}
25
- downsr_16hz: true
26
- audio_mae: true
27
- h5_format: false
28
- target_length: 1024
29
- flexible_mask: false
30
-
31
- dataset:
32
- num_workers: 10
33
- batch_size: 12
34
- skip_invalid_size_inputs_valid_test: true
35
- required_batch_size_multiple: 1
36
- disable_validation: true
37
-
38
- distributed_training:
39
- distributed_world_size: 4
40
- ddp_backend: c10d
41
-
42
- criterion:
43
- _name: model
44
- log_keys:
45
- - ema_decay
46
- - target_var
47
- - pred_var
48
- - model_norm
49
- - ema_norm
50
- - masked_pct
51
-
52
- optimization:
53
- max_update: 400000
54
- lr: [ 0.0005 ]
55
- debug_param_names: true
56
- clip_norm: 4
57
-
58
- optimizer:
59
- _name: composite
60
- dynamic_groups: true
61
- groups:
62
- default:
63
- lr_float: 0.0005
64
- optimizer:
65
- _name: adam
66
- adam_betas: [0.9,0.95]
67
- weight_decay: 0.05
68
- lr_scheduler:
69
- _name: cosine
70
- warmup_updates: 53333
71
-
72
- lr_scheduler: pass_through
73
-
74
- model:
75
- _name: data2vec_multi
76
-
77
- ema_decay: 0.9998
78
- ema_end_decay: 0.99999
79
- ema_anneal_end_step: 100000
80
- instance_norm_target_layer: true
81
- layer_norm_target_layer: false
82
- layer_norm_targets: true
83
- end_of_block_targets: false
84
-
85
- depth: 12
86
- average_top_k_layers: 12
87
- clone_batch: 16
88
-
89
- norm_eps: 1e-6
90
-
91
- min_target_var: 0
92
- min_pred_var: 0
93
-
94
- encoder_dropout: 0
95
- post_mlp_drop: 0
96
- attention_dropout: 0
97
- activation_dropout: 0
98
-
99
- supported_modality: IMAGE
100
- cls_loss: 1
101
-
102
- ema_encoder_only: false
103
-
104
- modalities:
105
- image:
106
- in_chans: 1
107
- inverse_mask: true
108
- mask_prob: 0.8
109
- mask_prob_adjust: 0.07
110
- mask_length: 5
111
- mask_noise_std: 0.01
112
- prenet_depth: 0
113
- ema_local_encoder: true
114
- num_extra_tokens: 1
115
- init_extra_token_zero: false
116
- use_alibi_encoder: false
117
- decoder:
118
- decoder_dim: 768
119
- decoder_groups: 16
120
- decoder_kernel: 3
121
- decoder_layers: 6
122
- input_dropout: 0
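The EAT/data2vec-style configs in this directory anneal the teacher's EMA decay from `ema_decay` to `ema_end_decay` over `ema_anneal_end_step` updates. A small sketch of that schedule is below; the linear ramp is an assumption inferred from the parameter names rather than a copy of fairseq's implementation.

```python
def ema_decay_at(step: int, start: float = 0.9998, end: float = 0.99999,
                 anneal_end_step: int = 100_000) -> float:
    """Teacher EMA decay: ramp linearly from `start` to `end`, then hold."""
    if step >= anneal_end_step:
        return end
    return start + (end - start) * step / anneal_end_step

# Halfway through annealing the decay is already close to its final value:
# ema_decay_at(50_000) == 0.999895
```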
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_music_multinodes.yaml DELETED
@@ -1,125 +0,0 @@
1
- # @package _group_
2
-
3
- common:
4
- fp16: true
5
- log_format: json
6
- log_interval: 200
7
- tensorboard_logdir: tb
8
- min_loss_scale: 1e-6
9
- fp16_no_flatten_grads: true
10
- user_dir: ${env:PWD}
11
- seed: 1
12
-
13
- checkpoint:
14
- save_interval: 1
15
- save_interval_updates: 10000
16
- keep_interval_updates: 1000
17
- no_epoch_checkpoints: true
18
-
19
- task:
20
- _name: mae_image_pretraining
21
- data: music4all_sh/
22
- rebuild_batches: true
23
- key: source
24
- precompute_mask_config: {}
25
- downsr_16hz: false
26
- audio_mae: true
27
- h5_format: false
28
- target_length: 752
29
- flexible_mask: false
30
- sample_rate: 24000
31
- fixed_duration: 30
32
-
33
- dataset:
34
- num_workers: 10
35
- batch_size: 12
36
- skip_invalid_size_inputs_valid_test: true
37
- required_batch_size_multiple: 1
38
- disable_validation: true
39
-
40
- distributed_training:
41
- distributed_world_size: 4
42
- ddp_backend: c10d
43
-
44
- criterion:
45
- _name: model
46
- log_keys:
47
- - ema_decay
48
- - target_var
49
- - pred_var
50
- - model_norm
51
- - ema_norm
52
- - masked_pct
53
-
54
- optimization:
55
- max_update: 400000
56
- lr: [ 0.0001 ]
57
- # debug_param_names: true
58
- clip_norm: 4
59
-
60
- optimizer:
61
- _name: composite
62
- # dynamic_groups: true
63
- groups:
64
- default:
65
- lr_float: 0.0005
66
- optimizer:
67
- _name: adam
68
- adam_betas: [0.9,0.95]
69
- weight_decay: 0.05
70
- lr_scheduler:
71
- _name: cosine
72
- warmup_updates: 10000 # 53333
73
-
74
- lr_scheduler: pass_through
75
-
76
- model:
77
- _name: data2vec_multi
78
-
79
- ema_decay: 0.9998
80
- ema_end_decay: 0.99999
81
- ema_anneal_end_step: 100000
82
- instance_norm_target_layer: true
83
- layer_norm_target_layer: false
84
- layer_norm_targets: true
85
- end_of_block_targets: false
86
-
87
- depth: 12
88
- average_top_k_layers: 12
89
- clone_batch: 16
90
-
91
- norm_eps: 1e-6
92
-
93
- min_target_var: 0
94
- min_pred_var: 0
95
-
96
- encoder_dropout: 0
97
- post_mlp_drop: 0
98
- attention_dropout: 0
99
- activation_dropout: 0
100
-
101
- supported_modality: IMAGE
102
- cls_loss: 1
103
-
104
- ema_encoder_only: false
105
-
106
- modalities:
107
- image:
108
- in_chans: 1
109
- inverse_mask: true
110
- mask_prob: 0.8
111
- mask_prob_adjust: 0.07
112
- mask_length: 5
113
- mask_noise_std: 0.01
114
- prenet_depth: 0
115
- ema_local_encoder: true
116
- num_extra_tokens: 1
117
- init_extra_token_zero: false
118
- use_alibi_encoder: false
119
- decoder:
120
- decoder_dim: 768
121
- decoder_groups: 16
122
- decoder_kernel: 3
123
- decoder_layers: 6
124
- input_dropout: 0
125
- target_length: 752
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M.yaml DELETED
@@ -1,137 +0,0 @@
1
- # @package _group_
2
- common:
3
- fp16: false
4
- log_format: json
5
- log_interval: 100
6
- seed: 1337
7
-
8
- # tensorboard_logdir: tblog_proj_name
9
- # wandb_project: wandb_proj_name
10
-
11
- checkpoint:
12
- save_interval_updates: 5000
13
- keep_interval_updates: -1
14
- no_epoch_checkpoints: true
15
-
16
-
17
- distributed_training:
18
- ddp_backend: no_c10d
19
- distributed_backend: 'nccl'
20
- distributed_world_size: 64
21
- nprocs_per_node: 8
22
- find_unused_parameters: true
23
- # reset-dataloader: true
24
-
25
- task:
26
- _name: mert_pretraining
27
- data: ???
28
- label_dir: ???
29
- labels: ???
30
- label_rate: ${model.label_rate}
31
- sharding_data: -1 # data sharding
32
- load_random_data_shard: false
33
- sample_rate: 24000
34
- # crop to 5s
35
- # max_sample_size: 120000
36
- # crop to 5.12s, which corresponds to 384 tokens per audio and can be divided by 8.
37
- max_sample_size: 122880
38
- min_sample_size: 72000
39
-
40
- pad_audio: false
41
- random_crop: true
42
- # normalize: true # must be consistent with extractor_mode: layer_norm
43
- normalize: false # must be consistent with extractor_mode: default (groupnorm)
44
-
45
-
46
- dataset:
47
- num_workers: 6
48
- max_tokens: 900000
49
- skip_invalid_size_inputs_valid_test: true
50
- validate_interval: 1
51
- validate_interval_updates: 10000
52
-
53
- criterion:
54
- _name: hubert
55
- pred_masked_weight: 1.0
56
- pred_nomask_weight: 0.0
57
- loss_weights: [10, 1]
58
-
59
- optimization:
60
- max_update: 1000000
61
- lr: [0.0015]
62
- clip_norm: 1.0
63
- update_freq: [8]
64
-
65
- optimizer:
66
- _name: adam
67
- adam_betas: (0.9,0.98)
68
- adam_eps: 1e-06
69
- weight_decay: 0.01
70
-
71
- lr_scheduler:
72
- _name: polynomial_decay
73
- warmup_updates: 32000
74
-
75
- model:
76
- _name: mert
77
- label_rate: ???
78
- skip_masked: false
79
- skip_nomask: true
80
- mask_prob: 0.8
81
- mask_length: 5
82
-
83
- logit_temp: 0.1
84
-
85
-
86
- # ----- mixture ------
87
- mixture_prob: 0.5
88
- inbatch_noise_augment_len_range: "[12000, 36000]"
89
- inbatch_noise_augment_number_range: "[1, 3]"
90
- inbatch_noise_augment_volume: 1.0
91
- # ------------------------
92
-
93
- # ---- cqt reconstruction, need to add loss weight ---
94
- audio_cqt_loss_m: true
95
- audio_cqt_bins: 336
96
-
97
- final_dim: 128
98
- encoder_layers: 24
99
- encoder_embed_dim: 1024
100
- encoder_ffn_embed_dim: 4096
101
- encoder_attention_heads: 16
102
- # default refers to group norm
103
- extractor_mode: default
104
- # extractor_mode: layer_norm
105
- conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
106
- encoder_layerdrop: 0.0
107
- dropout_input: 0.0
108
- dropout_features: 0.0
109
- dropout: 0.0
110
- attention_dropout: 0.0
111
-
112
- layer_norm_first: true
113
- feature_grad_mult: 1.0
114
-
115
- untie_final_proj: true
116
- activation_dropout: 0.0
117
-
118
- deepnorm: false
119
- attention_relax: 32.0
120
-
121
-
122
-
123
- hydra:
124
- job:
125
- config:
126
- override_dirname:
127
- kv_sep: '-'
128
- item_sep: '__'
129
- exclude_keys:
130
- - run
131
- - task.data
132
- - task.label_dir
133
- run:
134
- dir: ???
135
- sweep:
136
- dir: ???
137
- subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
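The MERT pretraining configs enable an in-batch mixture augmentation (`mixture_prob`, `inbatch_noise_augment_*`). The sketch below illustrates one plausible reading of those knobs, mixing short excerpts from other items in the batch into each waveform; only the parameter names and ranges come from the config, and the exact semantics in mert_fairseq are an assumption.

```python
import random
import torch

def inbatch_noise_augment(wavs: torch.Tensor, mixture_prob: float = 0.5,
                          len_range=(12_000, 36_000), num_range=(1, 3),
                          volume: float = 1.0) -> torch.Tensor:
    """Hypothetical in-batch mixture: add random excerpts of other batch items."""
    b, n = wavs.shape
    out = wavs.clone()
    if b < 2:
        return out
    for i in range(b):
        if random.random() > mixture_prob:
            continue
        for _ in range(random.randint(*num_range)):
            j = random.choice([k for k in range(b) if k != i])
            seg_len = min(random.randint(*len_range), n)
            src = random.randint(0, n - seg_len)
            dst = random.randint(0, n - seg_len)
            out[i, dst:dst + seg_len] += volume * wavs[j, src:src + seg_len]
    return out
```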
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes.yaml DELETED
@@ -1,139 +0,0 @@
1
- # @package _group_
2
- common:
3
- fp16: false
4
- log_format: json
5
- log_interval: 100
6
- seed: 1337
7
- # model_parallel_size: 8
8
- # amp: true
9
-
10
- # tensorboard_logdir: tblog_proj_name
11
- # wandb_project: wandb_proj_name
12
-
13
- checkpoint:
14
- save_interval_updates: 5000
15
- keep_interval_updates: -1
16
- no_epoch_checkpoints: true
17
-
18
-
19
- distributed_training:
20
- ddp_backend: c10d
21
- distributed_backend: 'nccl'
22
- distributed_world_size: 64
23
- nprocs_per_node: 8
24
- find_unused_parameters: true
25
- # reset-dataloader: true
26
-
27
- task:
28
- _name: mert_pretraining
29
- data: ???
30
- label_dir: ???
31
- labels: ???
32
- label_rate: ${model.label_rate}
33
- sharding_data: -1 # data sharding
34
- load_random_data_shard: false
35
- sample_rate: 24000
36
- # crop to 5s
37
- # max_sample_size: 120000
38
- # crop to 5.12s, which corresponds to 384 tokens per audio and can be divided by 8.
39
- max_sample_size: 122880
40
- min_sample_size: 72000
41
-
42
- pad_audio: false
43
- random_crop: true
44
- # normalize: true # must be consistent with extractor_mode: layer_norm
45
- normalize: false # must be consistent with extractor_mode: default (groupnorm)
46
-
47
-
48
- dataset:
49
- num_workers: 6
50
- max_tokens: 900000
51
- skip_invalid_size_inputs_valid_test: true
52
- validate_interval: 1
53
- validate_interval_updates: 10000
54
-
55
- criterion:
56
- _name: hubert
57
- pred_masked_weight: 1.0
58
- pred_nomask_weight: 0.0
59
- loss_weights: [10, 1]
60
-
61
- optimization:
62
- max_update: 1000000
63
- lr: [0.0015]
64
- clip_norm: 1.0
65
- update_freq: [8]
66
-
67
- optimizer:
68
- _name: adam
69
- adam_betas: (0.9,0.98)
70
- adam_eps: 1e-06
71
- weight_decay: 0.01
72
-
73
- lr_scheduler:
74
- _name: polynomial_decay
75
- warmup_updates: 32000
76
-
77
- model:
78
- _name: mert
79
- label_rate: ???
80
- skip_masked: false
81
- skip_nomask: true
82
- mask_prob: 0.8
83
- mask_length: 5
84
-
85
- logit_temp: 0.1
86
-
87
-
88
- # ----- mixture ------
89
- mixture_prob: 0.5
90
- inbatch_noise_augment_len_range: "[12000, 36000]"
91
- inbatch_noise_augment_number_range: "[1, 3]"
92
- inbatch_noise_augment_volume: 1.0
93
- # ------------------------
94
-
95
- # ---- cqt reconstruction, need to add loss weight ---
96
- audio_cqt_loss_m: true
97
- audio_cqt_bins: 336
98
-
99
- final_dim: 128
100
- encoder_layers: 24
101
- encoder_embed_dim: 1024
102
- encoder_ffn_embed_dim: 4096
103
- encoder_attention_heads: 16
104
- # default refers to group norm
105
- extractor_mode: default
106
- # extractor_mode: layer_norm
107
- conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
108
- encoder_layerdrop: 0.0
109
- dropout_input: 0.0
110
- dropout_features: 0.0
111
- dropout: 0.0
112
- attention_dropout: 0.0
113
-
114
- layer_norm_first: true
115
- feature_grad_mult: 1.0
116
-
117
- untie_final_proj: true
118
- activation_dropout: 0.0
119
-
120
- deepnorm: false
121
- attention_relax: 32.0
122
-
123
-
124
-
125
- hydra:
126
- job:
127
- config:
128
- override_dirname:
129
- kv_sep: '-'
130
- item_sep: '__'
131
- exclude_keys:
132
- - run
133
- - task.data
134
- - task.label_dir
135
- run:
136
- dir: run
137
- sweep:
138
- dir: sweep
139
- subdir: subdir
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes_debug1node.yaml DELETED
@@ -1,138 +0,0 @@
1
- # @package _group_
2
- common:
3
- fp16: false
4
- log_format: json
5
- log_interval: 100
6
- seed: 1337
7
- # amp: true
8
-
9
- # tensorboard_logdir: tblog_proj_name
10
- # wandb_project: wandb_proj_name
11
-
12
- checkpoint:
13
- save_interval_updates: 5000
14
- keep_interval_updates: -1
15
- no_epoch_checkpoints: true
16
-
17
-
18
- distributed_training:
19
- ddp_backend: c10d
20
- distributed_backend: 'nccl'
21
- distributed_world_size: 64
22
- nprocs_per_node: 8
23
- find_unused_parameters: true
24
- # reset-dataloader: true
25
-
26
- task:
27
- _name: mert_pretraining
28
- data: ???
29
- label_dir: ???
30
- labels: ???
31
- label_rate: ${model.label_rate}
32
- sharding_data: -1 # data sharding
33
- load_random_data_shard: false
34
- sample_rate: 24000
35
- # crop to 5s
36
- # max_sample_size: 120000
37
- # crop to 5.12s, which corresponds to 384 tokens per audio and can be divided by 8.
38
- max_sample_size: 122880
39
- min_sample_size: 72000
40
-
41
- pad_audio: false
42
- random_crop: true
43
- # normalize: true # must be consistent with extractor_mode: layer_norm
44
- normalize: false # must be consistent with extractor_mode: default (groupnorm)
45
-
46
-
47
- dataset:
48
- num_workers: 6
49
- max_tokens: 900000
50
- skip_invalid_size_inputs_valid_test: true
51
- validate_interval: 1
52
- validate_interval_updates: 10000
53
-
54
- criterion:
55
- _name: hubert
56
- pred_masked_weight: 1.0
57
- pred_nomask_weight: 0.0
58
- loss_weights: [10, 1]
59
-
60
- optimization:
61
- max_update: 1000000
62
- lr: [0.0015]
63
- clip_norm: 1.0
64
- update_freq: [8]
65
-
66
- optimizer:
67
- _name: adam
68
- adam_betas: (0.9,0.98)
69
- adam_eps: 1e-06
70
- weight_decay: 0.01
71
-
72
- lr_scheduler:
73
- _name: polynomial_decay
74
- warmup_updates: 32000
75
-
76
- model:
77
- _name: mert
78
- label_rate: ???
79
- skip_masked: false
80
- skip_nomask: true
81
- mask_prob: 0.8
82
- mask_length: 5
83
-
84
- logit_temp: 0.1
85
-
86
-
87
- # ----- mixture ------
88
- mixture_prob: 0.5
89
- inbatch_noise_augment_len_range: "[12000, 36000]"
90
- inbatch_noise_augment_number_range: "[1, 3]"
91
- inbatch_noise_augment_volume: 1.0
92
- # ------------------------
93
-
94
- # ---- cqt reconstruction, need to add loss weight ---
95
- audio_cqt_loss_m: true
96
- audio_cqt_bins: 336
97
-
98
- final_dim: 128
99
- encoder_layers: 24
100
- encoder_embed_dim: 1024
101
- encoder_ffn_embed_dim: 4096
102
- encoder_attention_heads: 16
103
- # default refers to group norm
104
- extractor_mode: default
105
- # extractor_mode: layer_norm
106
- conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
107
- encoder_layerdrop: 0.0
108
- dropout_input: 0.0
109
- dropout_features: 0.0
110
- dropout: 0.0
111
- attention_dropout: 0.0
112
-
113
- layer_norm_first: true
114
- feature_grad_mult: 1.0
115
-
116
- untie_final_proj: true
117
- activation_dropout: 0.0
118
-
119
- deepnorm: false
120
- attention_relax: 32.0
121
-
122
-
123
-
124
- hydra:
125
- job:
126
- config:
127
- override_dirname:
128
- kv_sep: '-'
129
- item_sep: '__'
130
- exclude_keys:
131
- - run
132
- - task.data
133
- - task.label_dir
134
- run:
135
- dir: run
136
- sweep:
137
- dir: sweep
138
- subdir: subdir
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes_debug2node.yaml DELETED
@@ -1,139 +0,0 @@
1
- # @package _group_
2
- common:
3
- fp16: false
4
- log_format: json
5
- log_interval: 100
6
- seed: 1337
7
- model_parallel_size: 8
8
- # amp: true
9
-
10
- # tensorboard_logdir: tblog_proj_name
11
- # wandb_project: wandb_proj_name
12
-
13
- checkpoint:
14
- save_interval_updates: 5000
15
- keep_interval_updates: -1
16
- no_epoch_checkpoints: true
17
-
18
-
19
- distributed_training:
20
- ddp_backend: c10d
21
- distributed_backend: 'nccl'
22
- distributed_world_size: 64
23
- nprocs_per_node: 8
24
- find_unused_parameters: true
25
- # reset-dataloader: true
26
-
27
- task:
28
- _name: mert_pretraining
29
- data: ???
30
- label_dir: ???
31
- labels: ???
32
- label_rate: ${model.label_rate}
33
- sharding_data: -1 #数据分块
34
- load_random_data_shard: false
35
- sample_rate: 24000
36
- # crop to 5s
37
- # max_sample_size: 120000
38
- # crop to 5.12s, refers to 384 token per audio, which can be devided by 8.
39
- max_sample_size: 122880
40
- min_sample_size: 72000
41
-
42
- pad_audio: false
43
- random_crop: true
44
- # normalize: true # must be consistent with extractor_mode: layer_norm
45
- normalize: false # must be consistent with extractor_mode: default (groupnorm)
46
-
47
-
48
- dataset:
49
- num_workers: 6
50
- max_tokens: null
51
- skip_invalid_size_inputs_valid_test: true
52
- validate_interval: 1
53
- validate_interval_updates: 10000
54
-
55
- criterion:
56
- _name: hubert
57
- pred_masked_weight: 1.0
58
- pred_nomask_weight: 0.0
59
- loss_weights: [10, 1]
60
-
61
- optimization:
62
- max_update: 1000000
63
- lr: [0.0015]
64
- clip_norm: 1.0
65
- update_freq: [8]
66
-
67
- optimizer:
68
- _name: adam
69
- adam_betas: (0.9,0.98)
70
- adam_eps: 1e-06
71
- weight_decay: 0.01
72
-
73
- lr_scheduler:
74
- _name: polynomial_decay
75
- warmup_updates: 32000
76
-
77
- model:
78
- _name: mert
79
- label_rate: ???
80
- skip_masked: false
81
- skip_nomask: true
82
- mask_prob: 0.8
83
- mask_length: 5
84
-
85
- logit_temp: 0.1
86
-
87
-
88
- # ----- mixture ------
89
- mixture_prob: 0.5
90
- inbatch_noise_augment_len_range: "[12000, 36000]"
91
- inbatch_noise_augment_number_range: "[1, 3]"
92
- inbatch_noise_augment_volume: 1.0
93
- # ------------------------
94
-
95
- # ---- cqt reconstruction, need to add loss weight ---
96
- audio_cqt_loss_m: true
97
- audio_cqt_bins: 336
98
-
99
- final_dim: 128
100
- encoder_layers: 24
101
- encoder_embed_dim: 1024
102
- encoder_ffn_embed_dim: 4096
103
- encoder_attention_heads: 16
104
- # default refers to group norm
105
- extractor_mode: default
106
- # extractor_mode: layer_norm
107
- conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
108
- encoder_layerdrop: 0.0
109
- dropout_input: 0.0
110
- dropout_features: 0.0
111
- dropout: 0.0
112
- attention_dropout: 0.0
113
-
114
- layer_norm_first: true
115
- feature_grad_mult: 1.0
116
-
117
- untie_final_proj: true
118
- activation_dropout: 0.0
119
-
120
- deepnorm: false
121
- attention_relax: 32.0
122
-
123
-
124
-
125
- hydra:
126
- job:
127
- config:
128
- override_dirname:
129
- kv_sep: '-'
130
- item_sep: '__'
131
- exclude_keys:
132
- - run
133
- - task.data
134
- - task.label_dir
135
- run:
136
- dir: run
137
- sweep:
138
- dir: sweep
139
- subdir: subdir
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_orig.yaml DELETED
@@ -1,135 +0,0 @@
1
- # @package _group_
2
- common:
3
- fp16: true
4
- log_format: json
5
- log_interval: 100
6
- seed: 1337
7
- # tensorboard_logdir: tblog_proj_name
8
- # wandb_project: wandb_proj_name
9
-
10
- checkpoint:
11
- save_interval_updates: 5000
12
- keep_interval_updates: -1
13
- no_epoch_checkpoints: true
14
-
15
-
16
- distributed_training:
17
- ddp_backend: no_c10d
18
- distributed_backend: 'nccl'
19
- distributed_world_size: 64
20
- nprocs_per_node: 8
21
- find_unused_parameters: true
22
-
23
- task:
24
- _name: mert_pretraining
25
- data: ???
26
- label_dir: ???
27
- labels: ???
28
- label_rate: ${model.label_rate}
29
- sharding_data: 6
30
- load_random_data_shard: false
31
- sample_rate: 24000
32
- # crop to 5s
33
- # max_sample_size: 120000
34
- # crop to 5.12s, refers to 384 token per audio, which can be devided by 8.
35
- max_sample_size: 122880
36
- min_sample_size: 72000
37
-
38
- pad_audio: false
39
- random_crop: true
40
- # normalize: true # must be consistent with extractor_mode: layer_norm
41
- normalize: false # must be consistent with extractor_mode: default (groupnorm)
42
-
43
-
44
- dataset:
45
- num_workers: 6
46
- max_tokens: 900000
47
- skip_invalid_size_inputs_valid_test: true
48
- validate_interval: 1
49
- validate_interval_updates: 10000
50
-
51
- criterion:
52
- _name: hubert
53
- pred_masked_weight: 1.0
54
- pred_nomask_weight: 0.0
55
- loss_weights: [10, 1]
56
-
57
- optimization:
58
- max_update: 400000
59
- lr: [0.0015]
60
- clip_norm: 1.0
61
- update_freq: [8]
62
-
63
- optimizer:
64
- _name: adam
65
- adam_betas: (0.9,0.98)
66
- adam_eps: 1e-06
67
- weight_decay: 0.01
68
-
69
- lr_scheduler:
70
- _name: polynomial_decay
71
- warmup_updates: 32000
72
-
73
- model:
74
- _name: mert
75
- label_rate: ???
76
- skip_masked: false
77
- skip_nomask: true
78
- mask_prob: 0.8
79
- mask_length: 5
80
-
81
- logit_temp: 0.1
82
-
83
-
84
- # ----- mixture ------
85
- mixture_prob: 0.5
86
- inbatch_noise_augment_len_range: "[12000, 36000]"
87
- inbatch_noise_augment_number_range: "[1, 3]"
88
- inbatch_noise_augment_volume: 1.0
89
- # ------------------------
90
-
91
- # ---- cqt reconstruction, need to add loss weight ---
92
- audio_cqt_loss_m: true
93
- audio_cqt_bins: 336
94
-
95
- final_dim: 128
96
- encoder_layers: 24
97
- encoder_embed_dim: 1024
98
- encoder_ffn_embed_dim: 4096
99
- encoder_attention_heads: 16
100
- # default refers to group norm
101
- extractor_mode: default
102
- # extractor_mode: layer_norm
103
- conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
104
- encoder_layerdrop: 0.0
105
- dropout_input: 0.0
106
- dropout_features: 0.0
107
- dropout: 0.0
108
- attention_dropout: 0.0
109
-
110
- layer_norm_first: true
111
- feature_grad_mult: 1.0
112
-
113
- untie_final_proj: true
114
- activation_dropout: 0.0
115
-
116
- deepnorm: false
117
- attention_relax: 32.0
118
-
119
-
120
-
121
- hydra:
122
- job:
123
- config:
124
- override_dirname:
125
- kv_sep: '-'
126
- item_sep: '__'
127
- exclude_keys:
128
- - run
129
- - task.data
130
- - task.label_dir
131
- run:
132
- dir: ???
133
- sweep:
134
- dir: ???
135
- subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
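These configs pair the Adam optimizer with fairseq's `polynomial_decay` scheduler (`warmup_updates: 32000`, peak `lr: 0.0015`, `max_update: 400000`). A compact sketch of that schedule follows; the linear warmup, power-1 decay, and end learning rate of 0 are assumed defaults rather than values stated in the config.

```python
def polynomial_decay_lr(step: int, peak_lr: float = 0.0015,
                        warmup_updates: int = 32_000, total_updates: int = 400_000,
                        end_lr: float = 0.0, power: float = 1.0) -> float:
    """Linear warmup to peak_lr, then polynomial decay toward end_lr."""
    if step < warmup_updates:
        return peak_lr * step / warmup_updates
    if step >= total_updates:
        return end_lr
    remaining = (total_updates - step) / (total_updates - warmup_updates)
    return end_lr + (peak_lr - end_lr) * remaining ** power
```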
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_tune.yaml DELETED
@@ -1,137 +0,0 @@
1
- # @package _group_
2
- common:
3
- fp16: true
4
- log_format: json
5
- log_interval: 100
6
- seed: 1337
7
-
8
- # tensorboard_logdir: tblog_proj_name
9
- # wandb_project: wandb_proj_name
10
-
11
- checkpoint:
12
- save_interval_updates: 5000
13
- keep_interval_updates: -1
14
- no_epoch_checkpoints: true
15
-
16
-
17
- distributed_training:
18
- ddp_backend: no_c10d
19
- distributed_backend: 'nccl'
20
- distributed_world_size: 64
21
- nprocs_per_node: 8
22
- find_unused_parameters: true
23
- # reset-dataloader: true
24
-
25
- task:
26
- _name: mert_pretraining
27
- data: ???
28
- label_dir: ???
29
- labels: ???
30
- label_rate: ${model.label_rate}
31
- sharding_data: -1 #数据分块
32
- load_random_data_shard: false
33
- sample_rate: 24000
34
- # crop to 5s
35
- # max_sample_size: 120000
36
- # crop to 5.12s, refers to 384 token per audio, which can be devided by 8.
37
- max_sample_size: 122880
38
- min_sample_size: 72000
39
-
40
- pad_audio: false
41
- random_crop: true
42
- # normalize: true # must be consistent with extractor_mode: layer_norm
43
- normalize: false # must be consistent with extractor_mode: default (groupnorm)
44
-
45
-
46
- dataset:
47
- num_workers: 6
48
- max_tokens: 900000
49
- skip_invalid_size_inputs_valid_test: true
50
- validate_interval: 1
51
- validate_interval_updates: 10000
52
-
53
- criterion:
54
- _name: hubert
55
- pred_masked_weight: 1.0
56
- pred_nomask_weight: 0.0
57
- loss_weights: [10, 1]
58
-
59
- optimization:
60
- max_update: 400000
61
- lr: [0.0015]
62
- clip_norm: 1.0
63
- update_freq: [8]
64
-
65
- optimizer:
66
- _name: adam
67
- adam_betas: (0.9,0.98)
68
- adam_eps: 1e-06
69
- weight_decay: 0.01
70
-
71
- lr_scheduler:
72
- _name: polynomial_decay
73
- warmup_updates: 32000
74
-
75
- model:
76
- _name: mert
77
- label_rate: ???
78
- skip_masked: false
79
- skip_nomask: true
80
- mask_prob: 0.8
81
- mask_length: 5
82
- # freeze_parameters:true
83
- logit_temp: 0.1
84
-
85
-
86
- # ----- mixture ------
87
- mixture_prob: 0.5
88
- inbatch_noise_augment_len_range: "[12000, 36000]"
89
- inbatch_noise_augment_number_range: "[1, 3]"
90
- inbatch_noise_augment_volume: 1.0
91
- # ------------------------
92
-
93
- # ---- cqt reconstruction, need to add loss weight ---
94
- audio_cqt_loss_m: true
95
- audio_cqt_bins: 336
96
-
97
- final_dim: 128
98
- encoder_layers: 24
99
- encoder_embed_dim: 1024
100
- encoder_ffn_embed_dim: 4096
101
- encoder_attention_heads: 16
102
- # default refers to group norm
103
- extractor_mode: default
104
- # extractor_mode: layer_norm
105
- conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
106
- encoder_layerdrop: 0.0
107
- dropout_input: 0.0
108
- dropout_features: 0.0
109
- dropout: 0.0
110
- attention_dropout: 0.0
111
-
112
- layer_norm_first: true
113
- feature_grad_mult: 1.0
114
-
115
- untie_final_proj: true
116
- activation_dropout: 0.0
117
-
118
- deepnorm: false
119
- attention_relax: 32.0
120
-
121
-
122
-
123
- hydra:
124
- job:
125
- config:
126
- override_dirname:
127
- kv_sep: '-'
128
- item_sep: '__'
129
- exclude_keys:
130
- - run
131
- - task.data
132
- - task.label_dir
133
- run:
134
- dir: ???
135
- sweep:
136
- dir: ???
137
- subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M.yaml DELETED
@@ -1,116 +0,0 @@
1
- # @package _group_
2
- common:
3
- fp16: false
4
- log_format: json
5
- log_interval: 200
6
- seed: 1337
7
- # tensorboard_logdir: tblog_proj_name
8
- # wandb_project: wandb_proj_name
9
-
10
- checkpoint:
11
- save_interval_updates: 25000
12
- keep_interval_updates: -1
13
- no_epoch_checkpoints: true
14
-
15
-
16
- distributed_training:
17
- ddp_backend: no_c10d
18
- distributed_backend: 'nccl'
19
- distributed_world_size: 64
20
- nprocs_per_node: 8
21
- find_unused_parameters: true
22
-
23
- task:
24
- _name: mert_pretraining
25
- data: ???
26
- label_dir: ???
27
- labels: ???
28
- label_rate: ${model.label_rate}
29
- sample_rate: 24000
30
- # crop to 5s
31
- max_sample_size: 120000
32
- min_sample_size: 72000
33
-
34
- pad_audio: false
35
- random_crop: true
36
- normalize: false # must be consistent with extractor
37
-
38
-
39
- dataset:
40
- num_workers: 6
41
- max_tokens: 2000000
42
- skip_invalid_size_inputs_valid_test: true
43
- validate_interval: 1
44
- validate_interval_updates: 10000
45
-
46
- criterion:
47
- _name: hubert
48
- pred_masked_weight: 1.0
49
- pred_nomask_weight: 0.0
50
- loss_weights: [10, 1]
51
-
52
- optimization:
53
- max_update: 400000
54
- lr: [0.0005]
55
- clip_norm: 10.0
56
-
57
- optimizer:
58
- _name: adam
59
- adam_betas: (0.9,0.98)
60
- adam_eps: 1e-06
61
- weight_decay: 0.01
62
-
63
- lr_scheduler:
64
- _name: polynomial_decay
65
- warmup_updates: 32000
66
-
67
- model:
68
- _name: mert
69
- label_rate: ???
70
- skip_masked: false
71
- skip_nomask: true
72
- mask_prob: 0.8
73
- mask_length: 5
74
-
75
- logit_temp: 0.1
76
-
77
- # ----- mixture ------
78
- mixture_prob: 0.5
79
- inbatch_noise_augment_len_range: "[12000, 24000]"
80
- inbatch_noise_augment_number_range: "[1, 3]"
81
- inbatch_noise_augment_volume: 1.0
82
- # ------------------------
83
- extractor_mode: default
84
- audio_extract_type: w2v_conv
85
- conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
86
-
87
- # ---- cqt reconstruction, need to add loss weight ---
88
- audio_cqt_loss_m: true
89
- audio_cqt_bins: 336
90
- # -----------
91
- final_dim: 64
92
- encoder_layerdrop: 0.05
93
- dropout_input: 0.1
94
- dropout_features: 0.1
95
- dropout: 0.1
96
- attention_dropout: 0.1
97
- feature_grad_mult: 0.1
98
- untie_final_proj: true
99
- activation_dropout: 0.0
100
-
101
-
102
- hydra:
103
- job:
104
- config:
105
- override_dirname:
106
- kv_sep: '-'
107
- item_sep: '__'
108
- exclude_keys:
109
- - run
110
- - task.data
111
- - task.label_dir
112
- run:
113
- dir: ???
114
- sweep:
115
- dir: ???
116
- subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
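The ??? entries in this file (task.data, task.label_dir, task.labels, model.label_rate, and the hydra run/sweep dirs) are OmegaConf "missing" markers that have to be supplied at launch time. Purely as an illustration, here is one way such a file could be loaded and completed with OmegaConf; every path and value below is a hypothetical placeholder, and actual training goes through fairseq's hydra entry point rather than this snippet.

    # Sketch only: fill the MISSING (???) fields of the config above.
    # All paths/values are placeholders, not real project settings.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("MERT_RVQ-VAE_CQT_95M.yaml")

    overrides = OmegaConf.create({
        "task": {"data": "/data/manifests", "label_dir": "/data/labels"},
        "model": {"label_rate": 75},   # consistent with the 320-sample conv stride at 24 kHz
        "hydra": {"run": {"dir": "outputs/95M"}, "sweep": {"dir": "outputs/95M_sweep"}},
    })
    cfg = OmegaConf.merge(cfg, overrides)

    assert not OmegaConf.is_missing(cfg.task, "data")
    print(OmegaConf.to_yaml(cfg.optimization))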
 
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq.yaml DELETED
@@ -1,125 +0,0 @@
1
- # @package _group_
2
- common:
3
- fp16: false
4
- log_format: json
5
- log_interval: 200
6
- seed: 1337
7
- # tensorboard_logdir: tblog_proj_name
8
- # wandb_project: wandb_proj_name
9
-
10
- checkpoint:
11
- save_interval_updates: 25000
12
- keep_interval_updates: -1
13
- no_epoch_checkpoints: true
14
-
15
-
16
- distributed_training:
17
- ddp_backend: no_c10d
18
- distributed_backend: 'nccl'
19
- distributed_world_size: 8 # 64
20
- nprocs_per_node: 8
21
- find_unused_parameters: true
22
-
23
- task:
24
- _name: mert_pretraining
25
- data: ???
26
- label_dir: ???
27
- labels: ???
28
- label_rate: ${model.label_rate}
29
- sample_rate: 24000
30
- # crop to 5s
31
- max_sample_size: 120000
32
- min_sample_size: 72000
33
-
34
- pad_audio: false
35
- random_crop: true
36
- normalize: false # must be consistent with extractor
37
-
38
-
39
- dataset:
40
- num_workers: 6
41
- max_tokens: 2000000
42
- skip_invalid_size_inputs_valid_test: true
43
- validate_interval: 1
44
- validate_interval_updates: 10000
45
-
46
- criterion:
47
- _name: hubert
48
- pred_masked_weight: 1.0
49
- pred_nomask_weight: 0.0
50
- loss_weights: [10, 1]
51
-
52
- optimization:
53
- max_update: 400000
54
- lr: [0.0005]
55
- clip_norm: 10.0
56
-
57
- optimizer:
58
- _name: adam
59
- adam_betas: (0.9,0.98)
60
- adam_eps: 1e-06
61
- weight_decay: 0.01
62
-
63
- lr_scheduler:
64
- _name: polynomial_decay
65
- warmup_updates: 32000
66
-
67
- model:
68
- _name: mert
69
- label_rate: ???
70
- skip_masked: false
71
- skip_nomask: true
72
- mask_prob: 0.8
73
- mask_length: 5
74
-
75
- logit_temp: 0.1
76
-
77
- # ----- mixture ------
78
- mixture_prob: 0.5
79
- inbatch_noise_augment_len_range: "[12000, 24000]"
80
- inbatch_noise_augment_number_range: "[1, 3]"
81
- inbatch_noise_augment_volume: 1.0
82
- # ------------------------
83
- extractor_mode: default
84
- audio_extract_type: melspec # use melspec (instead of `w2v_conv`)
85
- melspec_n_bins: 120 # for melspec we use 120, means 12 bins per octave
86
- conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
87
-
88
- # best-rq loss
89
- audio_rq_loss_m: true
90
- audio_rq_loss_embed_dim: 16
91
- audio_rq_loss_num_codebooks: 1
92
- audio_rq_loss_num_embeds: 8192
93
- audio_rq_loss_seed: 42
94
- audio_rq_loss_use_norm: true
95
-
96
- # ---- cqt reconstruction, need to add loss weight ---
97
- audio_cqt_loss_m: true
98
- audio_cqt_bins: 336
99
- # -----------
100
- final_dim: 64
101
- encoder_layerdrop: 0.05
102
- dropout_input: 0.1
103
- dropout_features: 0.1
104
- dropout: 0.1
105
- attention_dropout: 0.1
106
- feature_grad_mult: 0.1
107
- untie_final_proj: true
108
- activation_dropout: 0.0
109
-
110
-
111
- hydra:
112
- job:
113
- config:
114
- override_dirname:
115
- kv_sep: '-'
116
- item_sep: '__'
117
- exclude_keys:
118
- - run
119
- - task.data
120
- - task.label_dir
121
- run:
122
- dir: ???
123
- sweep:
124
- dir: ???
125
- subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
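This variant swaps the HuBERT-style labels for a BEST-RQ style target: mel frames are pushed through a frozen random projection and matched against a frozen random codebook, and the nearest-neighbour indices become the masked-prediction targets. The audio_rq_loss_* block carries the dimensions (16-dim codes, one codebook of 8192 entries, L2-normalised, fixed seed). A minimal sketch of that idea follows; it only illustrates the mechanism and is not the repo's implementation, whose details (projection shape, normalisation order) may differ.

    # Minimal BEST-RQ style target sketch (illustrative, not the repo's code).
    # Dimensions follow the config: embed_dim=16, num_embeds=8192, use_norm=true.
    import torch
    import torch.nn.functional as F

    class RandomProjectionQuantizer:
        def __init__(self, feat_dim=120, embed_dim=16, num_embeds=8192, seed=42):
            g = torch.Generator().manual_seed(seed)
            # Frozen random projection and codebook: sampled once, never trained.
            self.proj = torch.randn(feat_dim, embed_dim, generator=g)
            self.codebook = F.normalize(torch.randn(num_embeds, embed_dim, generator=g), dim=-1)

        def __call__(self, mel):                      # mel: (batch, frames, feat_dim)
            z = F.normalize(mel @ self.proj, dim=-1)  # project, then L2-normalise
            sims = z @ self.codebook.t()              # cosine similarity to every code
            return sims.argmax(dim=-1)                # (batch, frames) integer targets

    quantizer = RandomProjectionQuantizer(feat_dim=120)   # melspec_n_bins: 120
    targets = quantizer(torch.randn(2, 375, 120))         # 2 clips x 375 frames
    print(targets.shape)                                  # torch.Size([2, 375])

Because projection and codebook are random and frozen, the targets are cheap to recompute, and the training signal comes entirely from predicting them at the masked positions.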
 
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_chroma_multinodes.yaml DELETED
@@ -1,128 +0,0 @@
1
- # @package _group_
2
- common:
3
- fp16: false
4
- log_format: json
5
- log_interval: 200
6
- seed: 1337
7
- # tensorboard_logdir: tblog_proj_name
8
- # wandb_project: wandb_proj_name
9
-
10
- checkpoint:
11
- save_interval_updates: 12500
12
- keep_interval_updates: -1
13
- no_epoch_checkpoints: true
14
-
15
-
16
- distributed_training:
17
- ddp_backend: no_c10d
18
- distributed_backend: 'nccl'
19
- distributed_world_size: 64
20
- nprocs_per_node: 8
21
- find_unused_parameters: true
22
-
23
- task:
24
- _name: mert_pretraining
25
- data: ???
26
- label_dir: ???
27
- labels: ???
28
- label_rate: ${model.label_rate}
29
- sample_rate: 24000
30
- # crop to 5s
31
- max_sample_size: 120000
32
- min_sample_size: 72000
33
-
34
- pad_audio: false
35
- random_crop: true
36
- normalize: false # must be consistent with extractor
37
-
38
-
39
- dataset:
40
- num_workers: 6
41
- max_tokens: 2000000
42
- skip_invalid_size_inputs_valid_test: true
43
- validate_interval: 1
44
- validate_interval_updates: 10000
45
-
46
- criterion:
47
- _name: hubert
48
- pred_masked_weight: 1.0
49
- pred_nomask_weight: 0.0
50
- loss_weights: [10, 1]
51
-
52
- optimization:
53
- max_update: 400000
54
- lr: [0.0005]
55
- clip_norm: 10.0
56
- update_freq: [4]
57
-
58
- optimizer:
59
- _name: adam
60
- adam_betas: (0.9,0.98)
61
- adam_eps: 1e-06
62
- weight_decay: 0.01
63
-
64
- lr_scheduler:
65
- _name: polynomial_decay
66
- warmup_updates: 32000
67
-
68
- model:
69
- _name: mert
70
- label_rate: ???
71
- skip_masked: false
72
- skip_nomask: true
73
- mask_prob: 0.8
74
- mask_length: 5
75
-
76
- logit_temp: 0.1
77
-
78
- # ----- mixture ------
79
- mixture_prob: 0.5
80
- inbatch_noise_augment_len_range: "[12000, 24000]"
81
- inbatch_noise_augment_number_range: "[1, 3]"
82
- inbatch_noise_augment_volume: 1.0
83
- # ------------------------
84
- extractor_mode: default
85
- audio_extract_type: melspec # use melspec (instead of `w2v_conv`)
86
- melspec_n_bins: 120 # for melspec we use 120, means 12 bins per octave
87
- conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
88
-
89
- # best-rq loss
90
- audio_rq_loss_m: true
91
- audio_rq_loss_embed_dim: 16
92
- audio_rq_loss_num_codebooks: 1
93
- audio_rq_loss_num_embeds: 8192
94
- audio_rq_loss_seed: 42
95
- audio_rq_loss_use_norm: true
96
- audio_rq_loss_use_chroma: true
97
- audio_rq_loss_seed_chroma: 123
98
-
99
- # ---- cqt reconstruction, need to add loss weight ---
100
- audio_cqt_loss_m: true
101
- audio_cqt_bins: 336
102
- # -----------
103
- final_dim: 32
104
- encoder_layerdrop: 0.05
105
- dropout_input: 0.1
106
- dropout_features: 0.1
107
- dropout: 0.1
108
- attention_dropout: 0.1
109
- feature_grad_mult: 0.1
110
- untie_final_proj: true
111
- activation_dropout: 0.0
112
-
113
-
114
- hydra:
115
- job:
116
- config:
117
- override_dirname:
118
- kv_sep: '-'
119
- item_sep: '__'
120
- exclude_keys:
121
- - run
122
- - task.data
123
- - task.label_dir
124
- run:
125
- dir: ???
126
- sweep:
127
- dir: ???
128
- subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
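Compared with the single-node files, the *_multinodes configs add gradient accumulation (update_freq: [4]) on top of 64 workers. Taking dataset.max_tokens at face value as a per-GPU budget of 24 kHz waveform samples (an assumption, but consistent with how the task crops raw audio), the effective batch per optimizer step works out as follows:

    # Effective batch size for the multinode configs (max_tokens assumed to be
    # counted in 24 kHz waveform samples).
    sample_rate = 24000
    max_tokens_per_gpu = 2_000_000            # dataset.max_tokens
    world_size = 64                           # distributed_training.distributed_world_size
    update_freq = 4                           # optimization.update_freq

    seconds_per_gpu = max_tokens_per_gpu / sample_rate
    effective_seconds = seconds_per_gpu * world_size * update_freq
    print(round(seconds_per_gpu, 1))          # ~83.3 s of audio per GPU per forward pass
    print(round(effective_seconds / 3600, 2)) # ~5.93 h of audio per optimizer update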
 
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_multinodes.yaml DELETED
@@ -1,126 +0,0 @@
1
- # @package _group_
2
- common:
3
- fp16: false
4
- log_format: json
5
- log_interval: 200
6
- seed: 1337
7
- # tensorboard_logdir: tblog_proj_name
8
- # wandb_project: wandb_proj_name
9
-
10
- checkpoint:
11
- save_interval_updates: 12500
12
- keep_interval_updates: -1
13
- no_epoch_checkpoints: true
14
-
15
-
16
- distributed_training:
17
- ddp_backend: no_c10d
18
- distributed_backend: 'nccl'
19
- distributed_world_size: 64
20
- nprocs_per_node: 8
21
- find_unused_parameters: true
22
-
23
- task:
24
- _name: mert_pretraining
25
- data: ???
26
- label_dir: ???
27
- labels: ???
28
- label_rate: ${model.label_rate}
29
- sample_rate: 24000
30
- # crop to 5s
31
- max_sample_size: 120000
32
- min_sample_size: 72000
33
-
34
- pad_audio: false
35
- random_crop: true
36
- normalize: false # must be consistent with extractor
37
-
38
-
39
- dataset:
40
- num_workers: 6
41
- max_tokens: 2000000
42
- skip_invalid_size_inputs_valid_test: true
43
- validate_interval: 1
44
- validate_interval_updates: 10000
45
-
46
- criterion:
47
- _name: hubert
48
- pred_masked_weight: 1.0
49
- pred_nomask_weight: 0.0
50
- loss_weights: [10, 1]
51
-
52
- optimization:
53
- max_update: 400000
54
- lr: [0.0005]
55
- clip_norm: 10.0
56
- update_freq: [4]
57
-
58
- optimizer:
59
- _name: adam
60
- adam_betas: (0.9,0.98)
61
- adam_eps: 1e-06
62
- weight_decay: 0.01
63
-
64
- lr_scheduler:
65
- _name: polynomial_decay
66
- warmup_updates: 32000
67
-
68
- model:
69
- _name: mert
70
- label_rate: ???
71
- skip_masked: false
72
- skip_nomask: true
73
- mask_prob: 0.8
74
- mask_length: 5
75
-
76
- logit_temp: 0.1
77
-
78
- # ----- mixture ------
79
- mixture_prob: 0.5
80
- inbatch_noise_augment_len_range: "[12000, 24000]"
81
- inbatch_noise_augment_number_range: "[1, 3]"
82
- inbatch_noise_augment_volume: 1.0
83
- # ------------------------
84
- extractor_mode: default
85
- audio_extract_type: melspec # use melspec (instead of `w2v_conv`)
86
- melspec_n_bins: 120 # for melspec we use 120, means 12 bins per octave
87
- conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
88
-
89
- # best-rq loss
90
- audio_rq_loss_m: true
91
- audio_rq_loss_embed_dim: 16
92
- audio_rq_loss_num_codebooks: 1
93
- audio_rq_loss_num_embeds: 8192
94
- audio_rq_loss_seed: 42
95
- audio_rq_loss_use_norm: true
96
-
97
- # ---- cqt reconstruction, need to add loss weight ---
98
- audio_cqt_loss_m: true
99
- audio_cqt_bins: 336
100
- # -----------
101
- final_dim: 64
102
- encoder_layerdrop: 0.05
103
- dropout_input: 0.1
104
- dropout_features: 0.1
105
- dropout: 0.1
106
- attention_dropout: 0.1
107
- feature_grad_mult: 0.1
108
- untie_final_proj: true
109
- activation_dropout: 0.0
110
-
111
-
112
- hydra:
113
- job:
114
- config:
115
- override_dirname:
116
- kv_sep: '-'
117
- item_sep: '__'
118
- exclude_keys:
119
- - run
120
- - task.data
121
- - task.label_dir
122
- run:
123
- dir: ???
124
- sweep:
125
- dir: ???
126
- subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
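The mixture block shared by these configs (mixture_prob, inbatch_noise_augment_len_range, inbatch_noise_augment_number_range, inbatch_noise_augment_volume) describes mixing short excerpts from other clips of the same batch into a training clip. The sketch below is only a plausible reading of those four parameters; the repo's actual augmentation code may select segments and scale gains differently.

    # Plausible in-batch noise augmentation matching the parameters above
    # (illustrative only; the repo's implementation may differ).
    import random
    import torch

    def inbatch_noise_augment(batch, prob=0.5, len_range=(12000, 24000),
                              num_range=(1, 3), volume=1.0):
        # batch: (bsz, samples); each clip may receive 1-3 excerpts cut from
        # other clips of the same batch, added at the given volume.
        bsz, n = batch.shape
        out = batch.clone()
        for i in range(bsz):
            if random.random() >= prob or bsz < 2:
                continue
            for _ in range(random.randint(*num_range)):
                j = random.choice([k for k in range(bsz) if k != i])
                seg_len = min(random.randint(*len_range), n)
                src = random.randint(0, n - seg_len)
                dst = random.randint(0, n - seg_len)
                out[i, dst:dst + seg_len] += volume * batch[j, src:src + seg_len]
        return out

    mixed = inbatch_noise_augment(torch.randn(4, 120000))    # four 5 s clips at 24 kHz
    print(mixed.shape)                                        # torch.Size([4, 120000])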
 
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_norm_multinodes.yaml DELETED
@@ -1,128 +0,0 @@
1
- # @package _group_
2
- common:
3
- fp16: false
4
- log_format: json
5
- log_interval: 200
6
- seed: 1337
7
- # tensorboard_logdir: tblog_proj_name
8
- # wandb_project: wandb_proj_name
9
-
10
- checkpoint:
11
- save_interval_updates: 12500
12
- keep_interval_updates: -1
13
- no_epoch_checkpoints: true
14
-
15
-
16
- distributed_training:
17
- ddp_backend: no_c10d
18
- distributed_backend: 'nccl'
19
- distributed_world_size: 64
20
- nprocs_per_node: 8
21
- find_unused_parameters: true
22
-
23
- task:
24
- _name: mert_pretraining
25
- data: ???
26
- label_dir: ???
27
- labels: ???
28
- label_rate: ${model.label_rate}
29
- sample_rate: 24000
30
- # crop to 5s
31
- max_sample_size: 120000
32
- min_sample_size: 72000
33
-
34
- pad_audio: false
35
- random_crop: true
36
- normalize: false # must be consistent with extractor
37
-
38
-
39
- dataset:
40
- num_workers: 6
41
- max_tokens: 2000000
42
- skip_invalid_size_inputs_valid_test: true
43
- validate_interval: 1
44
- validate_interval_updates: 10000
45
-
46
- criterion:
47
- _name: hubert
48
- pred_masked_weight: 1.0
49
- pred_nomask_weight: 0.0
50
- loss_weights: [10, 1]
51
-
52
- optimization:
53
- max_update: 400000
54
- lr: [0.0005]
55
- clip_norm: 10.0
56
- update_freq: [4]
57
-
58
- optimizer:
59
- _name: adam
60
- adam_betas: (0.9,0.98)
61
- adam_eps: 1e-06
62
- weight_decay: 0.01
63
-
64
- lr_scheduler:
65
- _name: polynomial_decay
66
- warmup_updates: 32000
67
-
68
- model:
69
- _name: mert
70
- label_rate: ???
71
- skip_masked: false
72
- skip_nomask: true
73
- mask_prob: 0.8
74
- mask_length: 5
75
-
76
- logit_temp: 0.1
77
-
78
- # ----- mixture ------
79
- mixture_prob: 0.5
80
- inbatch_noise_augment_len_range: "[12000, 24000]"
81
- inbatch_noise_augment_number_range: "[1, 3]"
82
- inbatch_noise_augment_volume: 1.0
83
- # ------------------------
84
- extractor_mode: default
85
- audio_extract_type: melspec # use melspec (instead of `w2v_conv`)
86
- melspec_n_bins: 120 # for melspec we use 120, means 12 bins per octave
87
- conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
88
-
89
- # best-rq loss
90
- audio_rq_loss_m: true
91
- audio_rq_loss_embed_dim: 16
92
- audio_rq_loss_num_codebooks: 1
93
- audio_rq_loss_num_embeds: 8192
94
- audio_rq_loss_seed: 42
95
- audio_rq_loss_use_norm: true
96
- audio_rq_loss_use_chroma: false
97
- audio_rq_loss_seed_chroma: 123
98
-
99
- # ---- cqt reconstruction, need to add loss weight ---
100
- audio_cqt_loss_m: true
101
- audio_cqt_bins: 336
102
- # -----------
103
- final_dim: 64
104
- encoder_layerdrop: 0.05
105
- dropout_input: 0.1
106
- dropout_features: 0.1
107
- dropout: 0.1
108
- attention_dropout: 0.1
109
- feature_grad_mult: 0.1
110
- untie_final_proj: true
111
- activation_dropout: 0.0
112
-
113
-
114
- hydra:
115
- job:
116
- config:
117
- override_dirname:
118
- kv_sep: '-'
119
- item_sep: '__'
120
- exclude_keys:
121
- - run
122
- - task.data
123
- - task.label_dir
124
- run:
125
- dir: ???
126
- sweep:
127
- dir: ???
128
- subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
 
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_norm_speech_multinodes.yaml DELETED
@@ -1,128 +0,0 @@
1
- # @package _group_
2
- common:
3
- fp16: false
4
- log_format: json
5
- log_interval: 200
6
- seed: 1337
7
- # tensorboard_logdir: tblog_proj_name
8
- # wandb_project: wandb_proj_name
9
-
10
- checkpoint:
11
- save_interval_updates: 12500
12
- keep_interval_updates: -1
13
- no_epoch_checkpoints: true
14
-
15
-
16
- distributed_training:
17
- ddp_backend: no_c10d
18
- distributed_backend: 'nccl'
19
- distributed_world_size: 64
20
- nprocs_per_node: 8
21
- find_unused_parameters: true
22
-
23
- task:
24
- _name: mert_pretraining
25
- data: ???
26
- label_dir: ???
27
- labels: ???
28
- label_rate: ${model.label_rate}
29
- sample_rate: 24000
30
- # crop to 5s
31
- max_sample_size: 120000
32
- min_sample_size: 72000
33
-
34
- pad_audio: false
35
- random_crop: true
36
- normalize: false # must be consistent with extractor
37
-
38
-
39
- dataset:
40
- num_workers: 6
41
- max_tokens: 2000000
42
- skip_invalid_size_inputs_valid_test: true
43
- validate_interval: 1
44
- validate_interval_updates: 10000
45
-
46
- criterion:
47
- _name: hubert
48
- pred_masked_weight: 1.0
49
- pred_nomask_weight: 0.0
50
- loss_weights: [10, 1]
51
-
52
- optimization:
53
- max_update: 400000
54
- lr: [0.0005]
55
- clip_norm: 10.0
56
- update_freq: [4]
57
-
58
- optimizer:
59
- _name: adam
60
- adam_betas: (0.9,0.98)
61
- adam_eps: 1e-06
62
- weight_decay: 0.01
63
-
64
- lr_scheduler:
65
- _name: polynomial_decay
66
- warmup_updates: 32000
67
-
68
- model:
69
- _name: mert
70
- label_rate: ???
71
- skip_masked: false
72
- skip_nomask: true
73
- mask_prob: 0.8
74
- mask_length: 5
75
-
76
- logit_temp: 0.1
77
-
78
- # ----- mixture ------
79
- mixture_prob: 0 # 0.5
80
- inbatch_noise_augment_len_range: "[12000, 24000]"
81
- inbatch_noise_augment_number_range: "[1, 3]"
82
- inbatch_noise_augment_volume: 1.0
83
- # ------------------------
84
- extractor_mode: default
85
- audio_extract_type: melspec # use melspec (instead of `w2v_conv`)
86
- melspec_n_bins: 80 # 120 # for melspec we use 120, means 12 bins per octave
87
- conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
88
-
89
- # best-rq loss
90
- audio_rq_loss_m: true
91
- audio_rq_loss_embed_dim: 16
92
- audio_rq_loss_num_codebooks: 1
93
- audio_rq_loss_num_embeds: 8192
94
- audio_rq_loss_seed: 42
95
- audio_rq_loss_use_norm: true
96
- audio_rq_loss_use_chroma: false
97
- audio_rq_loss_seed_chroma: 123
98
-
99
- # ---- cqt reconstruction, need to add loss weight ---
100
- audio_cqt_loss_m: false
101
- audio_cqt_bins: 336
102
- # -----------
103
- final_dim: 64
104
- encoder_layerdrop: 0.05
105
- dropout_input: 0.1
106
- dropout_features: 0.1
107
- dropout: 0.1
108
- attention_dropout: 0.1
109
- feature_grad_mult: 0.1
110
- untie_final_proj: true
111
- activation_dropout: 0.0
112
-
113
-
114
- hydra:
115
- job:
116
- config:
117
- override_dirname:
118
- kv_sep: '-'
119
- item_sep: '__'
120
- exclude_keys:
121
- - run
122
- - task.data
123
- - task.label_dir
124
- run:
125
- dir: ???
126
- sweep:
127
- dir: ???
128
- subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
 
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrvq_multinodes.yaml DELETED
@@ -1,121 +0,0 @@
1
- # @package _group_
2
- common:
3
- fp16: false
4
- log_format: json
5
- log_interval: 200
6
- seed: 1337
7
- # tensorboard_logdir: tblog_proj_name
8
- # wandb_project: wandb_proj_name
9
-
10
- checkpoint:
11
- save_interval_updates: 12500
12
- keep_interval_updates: -1
13
- no_epoch_checkpoints: true
14
-
15
-
16
- distributed_training:
17
- ddp_backend: no_c10d
18
- distributed_backend: 'nccl'
19
- distributed_world_size: 64
20
- nprocs_per_node: 8
21
- find_unused_parameters: true
22
-
23
- task:
24
- _name: mert_pretraining
25
- data: ???
26
- label_dir: ???
27
- labels: ???
28
- label_rate: ${model.label_rate}
29
- sample_rate: 24000
30
- # crop to 5s
31
- max_sample_size: 120000
32
- min_sample_size: 72000
33
-
34
- pad_audio: false
35
- random_crop: true
36
- normalize: false # must be consistent with extractor
37
-
38
-
39
- dataset:
40
- num_workers: 6
41
- max_tokens: 2000000
42
- skip_invalid_size_inputs_valid_test: true
43
- validate_interval: 1
44
- validate_interval_updates: 10000
45
-
46
- criterion:
47
- _name: hubert
48
- pred_masked_weight: 1.0
49
- pred_nomask_weight: 0.0
50
- loss_weights: [10, 1]
51
-
52
- optimization:
53
- max_update: 400000
54
- lr: [0.0005]
55
- clip_norm: 10.0
56
- update_freq: [4]
57
-
58
- optimizer:
59
- _name: adam
60
- adam_betas: (0.9,0.98)
61
- adam_eps: 1e-06
62
- weight_decay: 0.01
63
-
64
- lr_scheduler:
65
- _name: polynomial_decay
66
- warmup_updates: 32000
67
-
68
- model:
69
- _name: mert
70
- label_rate: ???
71
- skip_masked: false
72
- skip_nomask: true
73
- mask_prob: 0.8
74
- mask_length: 5
75
-
76
- logit_temp: 0.1
77
-
78
- # ----- mixture ------
79
- mixture_prob: 0.5
80
- inbatch_noise_augment_len_range: "[12000, 24000]"
81
- inbatch_noise_augment_number_range: "[1, 3]"
82
- inbatch_noise_augment_volume: 1.0
83
- # ------------------------
84
- extractor_mode: default
85
- audio_extract_type: w2v_conv
86
- conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
87
-
88
- # ---- codec target
89
- audio_codec_type: rvq
90
- audio_codec_ckpt_path: RVQ_3000.pth
91
-
92
- # ---- cqt reconstruction, need to add loss weight ---
93
- audio_cqt_loss_m: true
94
- audio_cqt_bins: 336
95
- # -----------
96
- final_dim: 64
97
- encoder_layerdrop: 0.05
98
- dropout_input: 0.1
99
- dropout_features: 0.1
100
- dropout: 0.1
101
- attention_dropout: 0.1
102
- feature_grad_mult: 0.1
103
- untie_final_proj: true
104
- activation_dropout: 0.0
105
-
106
-
107
- hydra:
108
- job:
109
- config:
110
- override_dirname:
111
- kv_sep: '-'
112
- item_sep: '__'
113
- exclude_keys:
114
- - run
115
- - task.data
116
- - task.label_dir
117
- run:
118
- dir: ???
119
- sweep:
120
- dir: ???
121
- subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
 
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_dac.yaml DELETED
File without changes
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_dac_multinodes.yaml DELETED
@@ -1,121 +0,0 @@
1
- # @package _group_
2
- common:
3
- fp16: false
4
- log_format: json
5
- log_interval: 200
6
- seed: 1337
7
- # tensorboard_logdir: tblog_proj_name
8
- # wandb_project: wandb_proj_name
9
-
10
- checkpoint:
11
- save_interval_updates: 12500
12
- keep_interval_updates: -1
13
- no_epoch_checkpoints: true
14
-
15
-
16
- distributed_training:
17
- ddp_backend: no_c10d
18
- distributed_backend: 'nccl'
19
- distributed_world_size: 64
20
- nprocs_per_node: 8
21
- find_unused_parameters: true
22
-
23
- task:
24
- _name: mert_pretraining
25
- data: ???
26
- label_dir: ???
27
- labels: ???
28
- label_rate: ${model.label_rate}
29
- sample_rate: 24000
30
- # crop to 5s
31
- max_sample_size: 120000
32
- min_sample_size: 72000
33
-
34
- pad_audio: false
35
- random_crop: true
36
- normalize: false # must be consistent with extractor
37
-
38
-
39
- dataset:
40
- num_workers: 6
41
- max_tokens: 2000000
42
- skip_invalid_size_inputs_valid_test: true
43
- validate_interval: 1
44
- validate_interval_updates: 10000
45
-
46
- criterion:
47
- _name: hubert
48
- pred_masked_weight: 1.0
49
- pred_nomask_weight: 0.0
50
- loss_weights: [10, 1]
51
-
52
- optimization:
53
- max_update: 400000
54
- lr: [0.0005]
55
- clip_norm: 10.0
56
- update_freq: [4]
57
-
58
- optimizer:
59
- _name: adam
60
- adam_betas: (0.9,0.98)
61
- adam_eps: 1e-06
62
- weight_decay: 0.01
63
-
64
- lr_scheduler:
65
- _name: polynomial_decay
66
- warmup_updates: 32000
67
-
68
- model:
69
- _name: mert
70
- label_rate: ???
71
- skip_masked: false
72
- skip_nomask: true
73
- mask_prob: 0.8
74
- mask_length: 5
75
-
76
- logit_temp: 0.1
77
-
78
- # ----- mixture ------
79
- mixture_prob: 0.5
80
- inbatch_noise_augment_len_range: "[12000, 24000]"
81
- inbatch_noise_augment_number_range: "[1, 3]"
82
- inbatch_noise_augment_volume: 1.0
83
- # ------------------------
84
- extractor_mode: default
85
- audio_extract_type: w2v_conv
86
- conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
87
-
88
- # ---- codec target
89
- audio_codec_type: dac
90
- audio_codec_dac_model_path: weights_24khz_8kbps_0.0.4.pth #nj
91
-
92
- # ---- cqt reconstruction, need to add loss weight ---
93
- audio_cqt_loss_m: true
94
- audio_cqt_bins: 336
95
- # -----------
96
- final_dim: 64
97
- encoder_layerdrop: 0.05
98
- dropout_input: 0.1
99
- dropout_features: 0.1
100
- dropout: 0.1
101
- attention_dropout: 0.1
102
- feature_grad_mult: 0.1
103
- untie_final_proj: true
104
- activation_dropout: 0.0
105
-
106
-
107
- hydra:
108
- job:
109
- config:
110
- override_dirname:
111
- kv_sep: '-'
112
- item_sep: '__'
113
- exclude_keys:
114
- - run
115
- - task.data
116
- - task.label_dir
117
- run:
118
- dir: ???
119
- sweep:
120
- dir: ???
121
- subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
 
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_groupbestrq_multinodes.yaml DELETED
@@ -1,125 +0,0 @@
1
- # @package _group_
2
- common:
3
- fp16: false
4
- log_format: json
5
- log_interval: 200
6
- seed: 1337
7
- # tensorboard_logdir: tblog_proj_name
8
- # wandb_project: wandb_proj_name
9
-
10
- checkpoint:
11
- save_interval_updates: 12500
12
- keep_interval_updates: -1
13
- no_epoch_checkpoints: true
14
-
15
-
16
- distributed_training:
17
- ddp_backend: no_c10d
18
- distributed_backend: 'nccl'
19
- distributed_world_size: 64
20
- nprocs_per_node: 8
21
- find_unused_parameters: true
22
-
23
- task:
24
- _name: mert_pretraining
25
- data: ???
26
- label_dir: ???
27
- labels: ???
28
- label_rate: ${model.label_rate}
29
- sample_rate: 24000
30
- # crop to 5s
31
- max_sample_size: 120000
32
- min_sample_size: 72000
33
-
34
- pad_audio: false
35
- random_crop: true
36
- normalize: false # must be consistent with extractor
37
-
38
-
39
- dataset:
40
- num_workers: 6
41
- max_tokens: 2000000
42
- skip_invalid_size_inputs_valid_test: true
43
- validate_interval: 1
44
- validate_interval_updates: 10000
45
-
46
- criterion:
47
- _name: hubert
48
- pred_masked_weight: 1.0
49
- pred_nomask_weight: 0.0
50
- loss_weights: [10, 1]
51
-
52
- optimization:
53
- max_update: 400000
54
- lr: [0.0005]
55
- clip_norm: 10.0
56
- update_freq: [4]
57
-
58
- optimizer:
59
- _name: adam
60
- adam_betas: (0.9,0.98)
61
- adam_eps: 1e-06
62
- weight_decay: 0.01
63
-
64
- lr_scheduler:
65
- _name: polynomial_decay
66
- warmup_updates: 32000
67
-
68
- model:
69
- _name: mert
70
- label_rate: ???
71
- skip_masked: false
72
- skip_nomask: true
73
- mask_prob: 0.8
74
- mask_length: 5
75
-
76
- logit_temp: 0.1
77
-
78
- # ----- mixture ------
79
- mixture_prob: 0.5
80
- inbatch_noise_augment_len_range: "[12000, 24000]"
81
- inbatch_noise_augment_number_range: "[1, 3]"
82
- inbatch_noise_augment_volume: 1.0
83
- # ------------------------
84
- extractor_mode: default
85
- audio_extract_type: melspec # use melspec (instead of `w2v_conv`)
86
- melspec_n_bins: 120 # for melspec we use 120, means 12 bins per octave
87
- conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
88
-
89
- # best-rq loss
90
- audio_rq_loss_m: true
91
- audio_rq_loss_embed_dim: 16
92
- audio_rq_loss_num_codebooks: 64 # 32
93
- audio_rq_loss_num_embeds: 1024
94
- audio_rq_loss_seed: 42
95
-
96
- # ---- cqt reconstruction, need to add loss weight ---
97
- audio_cqt_loss_m: true
98
- audio_cqt_bins: 336
99
- # -----------
100
- final_dim: 16 # 64
101
- encoder_layerdrop: 0.05
102
- dropout_input: 0.1
103
- dropout_features: 0.1
104
- dropout: 0.1
105
- attention_dropout: 0.1
106
- feature_grad_mult: 0.1
107
- untie_final_proj: true
108
- activation_dropout: 0.0
109
-
110
-
111
- hydra:
112
- job:
113
- config:
114
- override_dirname:
115
- kv_sep: '-'
116
- item_sep: '__'
117
- exclude_keys:
118
- - run
119
- - task.data
120
- - task.label_dir
121
- run:
122
- dir: ???
123
- sweep:
124
- dir: ???
125
- subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
 
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_mel_multinodes.yaml DELETED
@@ -1,124 +0,0 @@
1
- # @package _group_
2
- common:
3
- fp16: false
4
- log_format: json
5
- log_interval: 200
6
- seed: 1337
7
- # tensorboard_logdir: tblog_proj_name
8
- # wandb_project: wandb_proj_name
9
-
10
- checkpoint:
11
- save_interval_updates: 12500
12
- keep_interval_updates: -1
13
- no_epoch_checkpoints: true
14
-
15
-
16
- distributed_training:
17
- ddp_backend: no_c10d
18
- distributed_backend: 'nccl'
19
- distributed_world_size: 64
20
- nprocs_per_node: 8
21
- find_unused_parameters: true
22
-
23
- task:
24
- _name: mert_pretraining
25
- data: ???
26
- label_dir: ???
27
- labels: ???
28
- label_rate: ${model.label_rate}
29
- sample_rate: 24000
30
- # crop to 5s
31
- max_sample_size: 120000
32
- min_sample_size: 72000
33
-
34
- pad_audio: false
35
- random_crop: true
36
- normalize: false # must be consistent with extractor
37
-
38
-
39
- dataset:
40
- num_workers: 6
41
- max_tokens: 2000000
42
- skip_invalid_size_inputs_valid_test: true
43
- validate_interval: 1
44
- validate_interval_updates: 10000
45
-
46
- criterion:
47
- _name: hubert
48
- pred_masked_weight: 1.0
49
- pred_nomask_weight: 0.0
50
- loss_weights: [10, 1]
51
-
52
- optimization:
53
- max_update: 400000
54
- lr: [0.0005]
55
- clip_norm: 10.0
56
- update_freq: [4]
57
-
58
- optimizer:
59
- _name: adam
60
- adam_betas: (0.9,0.98)
61
- adam_eps: 1e-06
62
- weight_decay: 0.01
63
-
64
- lr_scheduler:
65
- _name: polynomial_decay
66
- warmup_updates: 32000
67
-
68
- model:
69
- _name: mert
70
- label_rate: ???
71
- skip_masked: false
72
- skip_nomask: true
73
- mask_prob: 0.8
74
- mask_length: 5
75
-
76
- logit_temp: 0.1
77
-
78
- # ----- mixture ------
79
- mixture_prob: 0.5
80
- inbatch_noise_augment_len_range: "[12000, 24000]"
81
- inbatch_noise_augment_number_range: "[1, 3]"
82
- inbatch_noise_augment_volume: 1.0
83
- # ------------------------
84
- extractor_mode: default
85
- audio_extract_type: melspec # use melspec (instead of `w2v_conv`)
86
- melspec_n_bins: 120 # for melspec we use 120, means 12 bins per octave
87
- conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
88
-
89
- # best-rq loss
90
- audio_rq_loss_m: false
91
- audio_rq_loss_embed_dim: 16
92
- audio_rq_loss_num_codebooks: 1
93
- audio_rq_loss_num_embeds: 8192
94
-
95
- # ---- cqt reconstruction, need to add loss weight ---
96
- audio_cqt_loss_m: true
97
- audio_cqt_bins: 336
98
- # -----------
99
- final_dim: 64
100
- encoder_layerdrop: 0.05
101
- dropout_input: 0.1
102
- dropout_features: 0.1
103
- dropout: 0.1
104
- attention_dropout: 0.1
105
- feature_grad_mult: 0.1
106
- untie_final_proj: true
107
- activation_dropout: 0.0
108
-
109
-
110
- hydra:
111
- job:
112
- config:
113
- override_dirname:
114
- kv_sep: '-'
115
- item_sep: '__'
116
- exclude_keys:
117
- - run
118
- - task.data
119
- - task.label_dir
120
- run:
121
- dir: ???
122
- sweep:
123
- dir: ???
124
- subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
 
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_bestrvq_multinodes.yaml DELETED
@@ -1,108 +0,0 @@
1
- # @package _group_
2
- common:
3
- fp16: false
4
- log_format: json
5
- log_interval: 200
6
- seed: 1337
7
- # tensorboard_logdir: tblog_proj_name
8
- # wandb_project: wandb_proj_name
9
-
10
- checkpoint:
11
- save_interval_updates: 12500
12
- keep_interval_updates: -1
13
- no_epoch_checkpoints: true
14
-
15
-
16
- distributed_training:
17
- ddp_backend: no_c10d
18
- distributed_backend: 'nccl'
19
- distributed_world_size: 64
20
- nprocs_per_node: 8
21
- find_unused_parameters: true
22
-
23
- task:
24
- _name: mert_pretraining
25
- data: ???
26
- label_dir: ???
27
- labels: ???
28
- label_rate: ${model.label_rate}
29
- sample_rate: 24000
30
- # # crop to 5s
31
- # max_sample_size: 120000
32
- # min_sample_size: 72000
33
-
34
- # crop to 30s
35
- max_sample_size: 720000
36
- min_sample_size: 432000
37
- clip_secs: 30
38
-
39
- pad_audio: false
40
- random_crop: true
41
- normalize: false # must be consistent with extractor
42
-
43
-
44
- dataset:
45
- num_workers: 6
46
- max_tokens: 2000000
47
- skip_invalid_size_inputs_valid_test: true
48
- validate_interval: 1
49
- validate_interval_updates: 10000
50
-
51
- criterion:
52
- _name: model
53
- # log_keys:
54
- # - accuracies
55
-
56
- optimization:
57
- max_update: 400000
58
- lr: [0.0005]
59
- clip_norm: 10.0
60
- update_freq: [1]
61
-
62
- optimizer:
63
- _name: adam
64
- adam_betas: (0.9,0.98)
65
- adam_eps: 1e-06
66
- weight_decay: 0.01
67
-
68
- lr_scheduler:
69
- _name: polynomial_decay
70
- warmup_updates: 32000
71
-
72
- model:
73
- _name: musicfm
74
- label_rate: 25
75
- num_codebooks: 1
76
- codebook_dim: 16
77
- codebook_size: 8192 # 4096
78
- features: ["melspec_2048"]
79
- hop_length: 240
80
- n_mels: 128
81
- conv_dim: 512
82
- encoder_dim: 1024
83
- encoder_depth: 12
84
- mask_hop: 0.4
85
- mask_prob: 0.6
86
- is_flash: false
87
-
88
- stat_path: msd_stats.json
89
- model_path: null
90
- w2v2_config_path: our-MERT/data/models--facebook--wav2vec2-conformer-rope-large-960h-ft/snapshots/6b36ef01c6443c67ae7ed0822876d091ab50e4aa
91
- use_rvq_target: true
92
- rvq_ckpt_path: RVQ_4000.pth
93
-
94
- hydra:
95
- job:
96
- config:
97
- override_dirname:
98
- kv_sep: '-'
99
- item_sep: '__'
100
- exclude_keys:
101
- - run
102
- - task.data
103
- - task.label_dir
104
- run:
105
- dir: ???
106
- sweep:
107
- dir: ???
108
- subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
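The MusicFM configs change the framing: 30-second crops, a 128-bin mel front-end with hop 240, and a 25 Hz token rate with 0.4-second masking hops. A short consistency check of those numbers, again assuming the 24 kHz task sample rate:

    # Consistency check for the MusicFM framing above (24 kHz audio assumed).
    sample_rate = 24000
    crop_samples = 720_000              # task.max_sample_size -> 30 s crops
    hop_length = 240                    # model.hop_length
    label_rate = 25                     # model.label_rate (tokens per second)
    mask_hop = 0.4                      # model.mask_hop, in seconds

    print(crop_samples / sample_rate)   # 30.0 s per clip
    print(sample_rate / hop_length)     # 100.0 mel frames per second
    print(crop_samples // hop_length)   # 3000 mel frames per clip
    print(label_rate * 30)              # 750 target tokens per clip
    print(mask_hop * label_rate)        # 10.0 tokens per masking hop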
 
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_multinodes.yaml DELETED
@@ -1,105 +0,0 @@
1
- # @package _group_
2
- common:
3
- fp16: false
4
- log_format: json
5
- log_interval: 200
6
- seed: 1337
7
- # tensorboard_logdir: tblog_proj_name
8
- # wandb_project: wandb_proj_name
9
-
10
- checkpoint:
11
- save_interval_updates: 12500
12
- keep_interval_updates: -1
13
- no_epoch_checkpoints: true
14
-
15
-
16
- distributed_training:
17
- ddp_backend: no_c10d
18
- distributed_backend: 'nccl'
19
- distributed_world_size: 64
20
- nprocs_per_node: 8
21
- find_unused_parameters: true
22
-
23
- task:
24
- _name: mert_pretraining
25
- data: ???
26
- label_dir: ???
27
- labels: ???
28
- label_rate: ${model.label_rate}
29
- sample_rate: 24000
30
- # # crop to 5s
31
- # max_sample_size: 120000
32
- # min_sample_size: 72000
33
-
34
- # crop to 30s
35
- max_sample_size: 720000
36
- min_sample_size: 432000
37
- clip_secs: 30
38
-
39
- pad_audio: false
40
- random_crop: true
41
- normalize: false # must be consistent with extractor
42
-
43
-
44
- dataset:
45
- num_workers: 6
46
- max_tokens: 2000000
47
- skip_invalid_size_inputs_valid_test: true
48
- validate_interval: 1
49
- validate_interval_updates: 10000
50
-
51
- criterion:
52
- _name: model
53
- # log_keys:
54
- # - accuracies
55
-
56
- optimization:
57
- max_update: 400000
58
- lr: [0.0005]
59
- clip_norm: 10.0
60
- update_freq: [1]
61
-
62
- optimizer:
63
- _name: adam
64
- adam_betas: (0.9,0.98)
65
- adam_eps: 1e-06
66
- weight_decay: 0.01
67
-
68
- lr_scheduler:
69
- _name: polynomial_decay
70
- warmup_updates: 32000
71
-
72
- model:
73
- _name: musicfm
74
- label_rate: 25
75
- num_codebooks: 1
76
- codebook_dim: 16
77
- codebook_size: 4096
78
- features: ["melspec_2048"]
79
- hop_length: 240
80
- n_mels: 128
81
- conv_dim: 512
82
- encoder_dim: 1024
83
- encoder_depth: 12
84
- mask_hop: 0.4
85
- mask_prob: 0.6
86
- is_flash: false
87
- stat_path: msd_stats.json
88
- model_path: pretrained_msd.pt
89
- w2v2_config_path: models--facebook--wav2vec2-conformer-rope-large-960h-ft/snapshots/6b36ef01c6443c67ae7ed0822876d091ab50e4aa
90
-
91
- hydra:
92
- job:
93
- config:
94
- override_dirname:
95
- kv_sep: '-'
96
- item_sep: '__'
97
- exclude_keys:
98
- - run
99
- - task.data
100
- - task.label_dir
101
- run:
102
- dir: ???
103
- sweep:
104
- dir: ???
105
- subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
 
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_speech_multinodes.yaml DELETED
@@ -1,106 +0,0 @@
1
- # @package _group_
2
- common:
3
- fp16: false
4
- log_format: json
5
- log_interval: 200
6
- seed: 1337
7
- # tensorboard_logdir: tblog_proj_name
8
- # wandb_project: wandb_proj_name
9
-
10
- checkpoint:
11
- save_interval_updates: 2500
12
- keep_interval_updates: 10000
13
- no_epoch_checkpoints: true
14
-
15
-
16
- distributed_training:
17
- ddp_backend: no_c10d
18
- distributed_backend: 'nccl'
19
- distributed_world_size: 64
20
- nprocs_per_node: 8
21
- find_unused_parameters: true
22
-
23
- task:
24
- _name: mert_pretraining
25
- data: ???
26
- label_dir: ???
27
- labels: ???
28
- label_rate: ${model.label_rate}
29
- sample_rate: 24000
30
- # # crop to 5s
31
- # max_sample_size: 120000
32
- # min_sample_size: 72000
33
-
34
- # crop to 30s
35
- max_sample_size: 720000
36
- min_sample_size: 12000
37
- # clip_secs: 30
38
-
39
- pad_audio: false
40
- random_crop: true
41
- normalize: false # must be consistent with extractor
42
-
43
-
44
- dataset:
45
- num_workers: 6
46
- max_tokens: 2000000
47
- skip_invalid_size_inputs_valid_test: true
48
- validate_interval: 1
49
- validate_interval_updates: 10000
50
- disable_validation: true
51
-
52
- criterion:
53
- _name: model
54
- # log_keys:
55
- # - accuracies
56
-
57
- optimization:
58
- max_update: 400000
59
- lr: [0.0005]
60
- clip_norm: 10.0
61
- update_freq: [1]
62
-
63
- optimizer:
64
- _name: adam
65
- adam_betas: (0.9,0.98)
66
- adam_eps: 1e-06
67
- weight_decay: 0.01
68
-
69
- lr_scheduler:
70
- _name: polynomial_decay
71
- warmup_updates: 32000
72
-
73
- model:
74
- _name: musicfm
75
- label_rate: 25
76
- num_codebooks: 1
77
- codebook_dim: 16
78
- codebook_size: 4096
79
- features: ["melspec_2048"]
80
- hop_length: 240
81
- n_mels: 128
82
- conv_dim: 512
83
- encoder_dim: 1024
84
- encoder_depth: 12
85
- mask_hop: 0.4
86
- mask_prob: 0.6
87
- is_flash: false
88
- stat_path: msd_stats.json
89
- model_path: null
90
- w2v2_config_path: models--facebook--wav2vec2-conformer-rope-large-960h-ft/snapshots/6b36ef01c6443c67ae7ed0822876d091ab50e4aa
91
-
92
- hydra:
93
- job:
94
- config:
95
- override_dirname:
96
- kv_sep: '-'
97
- item_sep: '__'
98
- exclude_keys:
99
- - run
100
- - task.data
101
- - task.label_dir
102
- run:
103
- dir: ???
104
- sweep:
105
- dir: ???
106
- subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
 
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/run/submitit_reg.yaml DELETED
@@ -1,20 +0,0 @@
1
- # @package _global_
2
-
3
- hydra:
4
- launcher:
5
- cpus_per_task: 8
6
- gpus_per_node: 8
7
- tasks_per_node: ${hydra.launcher.gpus_per_node}
8
- nodes: 4
9
- comment: null
10
- mem_gb: 384
11
- timeout_min: 4320
12
- max_num_timeout: 100
13
- constraint: volta32gb
14
- name: ${hydra.job.config_name}/${hydra.job.override_dirname}
15
- submitit_folder: ${hydra.sweep.dir}/submitit/%j
16
-
17
- distributed_training:
18
- distributed_world_size: 32
19
- distributed_port: 29671
20
- nprocs_per_node: 8
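Throughout these configs, hydra.job.config.override_dirname controls how a run's directory name is derived from the CLI overrides: key and value joined with '-', items joined with '__', and run/data/label paths excluded. The helper below reproduces that naming scheme for illustration only; it mirrors the kv_sep/item_sep/exclude_keys settings but is not Hydra's own implementation.

    # Illustration of the override_dirname naming used by these configs
    # (mimics kv_sep='-', item_sep='__', exclude_keys; not Hydra's own code).
    def override_dirname(overrides, kv_sep="-", item_sep="__",
                         exclude_keys=("run", "task.data", "task.label_dir")):
        parts = []
        for key, value in sorted(overrides.items()):
            if key in exclude_keys:
                continue
            parts.append(f"{key}{kv_sep}{value}")
        return item_sep.join(parts)

    print(override_dirname({
        "model.label_rate": 75,
        "task.data": "/data/manifests",   # excluded from the directory name
        "optimization.lr": "[0.0005]",
    }))
    # -> model.label_rate-75__optimization.lr-[0.0005]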
 
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from .mert_dataset import MERTDataset
2
- from .eat_data import *
 
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/ark_dataset.py DELETED
@@ -1,115 +0,0 @@
1
- import logging
2
- import torch
3
- import torch.nn.functional as F
4
- from fairseq.data.audio.raw_audio_dataset import RawAudioDataset
5
- from typing import Tuple
6
- try:
7
- import kaldiio
8
- except:
9
- kaldiio = None
10
- import warnings
11
-
12
- logger = logging.getLogger(__name__)
13
-
14
-
15
- class ArkDataset(RawAudioDataset):
16
- def __init__(
17
- self,
18
- wav_scp,
19
- dur_scp,
20
- sr = 24000,
21
- max_dur = 20,
22
- num_buckets=0,
23
- normalize=False,
24
- ):
25
- super().__init__(
26
- sample_rate=sr,
27
- max_sample_size=max_dur*sr,
28
- min_sample_size=1200,
29
- shuffle=True,
30
- pad=True,
31
- normalize=normalize,
32
- compute_mask=False,
33
- )
34
- self.sr = sr
35
- self.max_dur = max_dur
36
- self.normalize = normalize
37
-
38
- logger.info("Loading Kaldi scp files from {}".format(wav_scp))
39
-
40
- self.wav_data = kaldiio.load_scp(wav_scp)
41
- self.keys = list(self.wav_data.keys())
42
- dur_data = {}
43
- keys_set = set(self.keys)
44
-
45
- with open(dur_scp, 'r') as f:
46
- for line in f:
47
- line = line.strip().split()
48
- if line[0] in keys_set:
49
- dur_data[line[0]] = float(line[-1])
50
- self.sizes = [int(dur_data[k]*self.sr/100) for k in self.keys]
51
-
52
- logger.info("Loading Kaldi scp files done")
53
-
54
- self.dataset_len = len(self.keys)
55
- self.set_bucket_info(num_buckets)
56
-
57
- def __len__(self):
58
- return self.dataset_len
59
-
60
- def __getitem__(self, idx):
61
- # print("getitem idx: ", idx)
62
- try_cnt = 0
63
- while True:
64
- idx = idx + try_cnt
65
- try:
66
- with warnings.catch_warnings():
67
- warnings.simplefilter("ignore")
68
- key = self.keys[idx]
69
- # print(self.wav_data[key].keys())
70
- wav = self.wav_data[key]['wav']
71
-
72
- wav = torch.from_numpy(wav).float()
73
- wav = self.postprocess(wav)
74
- # print("success load", idx, " shape =", wav.shape)
75
- return {"id": idx, "source": wav}
76
- except Exception as e:
77
- # from traceback import print_exc
78
- # print_exc()
79
- # print("Error loadding ", idx)
80
- # return {"id": idx, "source": None}
81
- try_cnt += 1
82
- if try_cnt > 50:
83
- return {"id": idx, "source": None}
84
- continue
85
-
86
- def size(self, idx):
87
- return self.sizes[idx]
88
-
89
- def postprocess(self, wav):
90
- if wav.dim() == 2:
91
- wav = wav.mean(-1)
92
- assert wav.dim() == 1, wav.dim()
93
-
94
- if self.normalize:
95
- with torch.no_grad():
96
- wav = F.layer_norm(wav, wav.shape)
97
- return wav
98
-
99
- def collater(self, samples):
100
- # print("collate from:", [s['source'].shape for s in samples if s['source'] is not None])
101
- return super().collater(samples)
102
-
103
- if __name__ == '__main__':
104
- import torch
105
- raw_tensor_str = torch.Tensor.__repr__
106
- torch.Tensor.__str__ = torch.Tensor.__repr__ = lambda self: f'Tensor{{Size({[*self.shape]}) {self.device} {str(self.dtype)[6]}{str(self.dtype)[-2:]}}}' if self.numel() > 10 else raw_tensor_str(self)
107
-
108
- ds = ArkDataset(
109
- wav_scp='data/ark_demo/wav_ark.scp',
110
- dur_scp='data/ark_demo/dur_ark.scp',
111
- sr=24000,
112
- )
113
-
114
- for i in range(len(ds)):
115
- print(ds[i])
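Two practical notes on ArkDataset: it needs kaldiio installed (the import is wrapped in try/except, so a missing package only fails once the scp is actually loaded), and the duration scp it consumes is interpreted in hundredths of a second — the sizes computation multiplies the last column by sr/100. A tiny sketch of that convention, with made-up keys and values:

    # Duration convention inferred from the ArkDataset.sizes computation above:
    # the last column of the dur scp is hundredths of a second.
    sr = 24000
    dur_scp_lines = [
        "utt_0001 500",    # 500 / 100 = 5.0 s
        "utt_0002 1230",   # 12.3 s
    ]
    for line in dur_scp_lines:
        key, dur = line.split()
        n_samples = int(float(dur) * sr / 100)
        print(key, n_samples)   # utt_0001 120000 / utt_0002 295200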