ing0 committed on
Commit f5b749d · 1 Parent(s): 0d355fb
app.py CHANGED
@@ -27,22 +27,16 @@ from diffrhythm.infer.infer import inference
27
 
28
  MAX_SEED = np.iinfo(np.int32).max
29
  device='cuda'
30
- cfm, cfm_full, tokenizer, muq, vae = prepare_model(device)
31
  cfm = torch.compile(cfm)
32
- cfm_full = torch.compile(cfm_full)
33
 
34
- @spaces.GPU(duration=40)
35
- def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42, randomize_seed=False, steps=32, cfg_strength=4.0, file_type='wav', odeint_method='euler', Music_Duration='95s', device='cuda'):
36
- if Music_Duration == '95s':
37
- max_frames = 2048
38
- cfm_model = cfm
39
- else:
40
- max_frames = 6144
41
- cfm_model = cfm_full
42
  if randomize_seed:
43
  seed = random.randint(0, MAX_SEED)
44
  torch.manual_seed(seed)
45
- sway_sampling_coef = -1 if steps < 32 else None
46
  vocal_flag = False
47
  try:
48
  lrc_prompt, start_time = get_lrc_token(max_frames, lrc, tokenizer, device)
@@ -53,9 +47,16 @@ def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42,
53
  except Exception as e:
54
  raise gr.Error(f"Error: {str(e)}")
55
  negative_style_prompt = get_negative_style_prompt(device)
56
- latent_prompt = get_reference_latent(device, max_frames)
57
- generated_song = inference(cfm_model=cfm_model,
 
 
 
 
 
58
  vae_model=vae,
 
 
59
  cond=latent_prompt,
60
  text=lrc_prompt,
61
  duration=max_frames,
@@ -68,6 +69,8 @@ def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42,
68
  file_type=file_type,
69
  vocal_flag=vocal_flag,
70
  odeint_method=odeint_method,
 
 
71
  )
72
  return generated_song
73
 
@@ -234,8 +237,8 @@ with gr.Blocks(css=css) as demo:
234
  - If loading audio result is slow, you can select Output Format as mp3 in Advanced Settings.
235
 
236
  """)
237
- Music_Duration = gr.Radio(["95s", "285s"], label="Music Duration", value="95s")
238
-
239
  lyrics_btn = gr.Button("Generate", variant="primary")
240
  audio_output = gr.Audio(label="Audio Result", type="filepath", elem_id="audio_output")
241
  with gr.Accordion("Advanced Settings", open=False):
@@ -266,6 +269,12 @@ with gr.Blocks(css=css) as demo:
266
  interactive=True,
267
  elem_id="step_slider"
268
  )
269
  odeint_method = gr.Radio(["euler", "midpoint", "rk4","implicit_adams"], label="ODE Solver", value="euler")
270
  file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="wav")
271
 
@@ -409,7 +418,7 @@ with gr.Blocks(css=css) as demo:
409
 
410
  lyrics_btn.click(
411
  fn=infer_music,
412
- inputs=[lrc, audio_prompt, text_prompt, current_prompt_type, seed, randomize_seed, steps, cfg_strength, file_type, odeint_method, Music_Duration],
413
  outputs=audio_output
414
  )
415
 
 
27
 
28
  MAX_SEED = np.iinfo(np.int32).max
29
  device='cuda'
30
+ cfm, tokenizer, muq, vae, eval_model, eval_muq = prepare_model(device)
31
  cfm = torch.compile(cfm)
 
32
 
33
+ @spaces.GPU
34
+ def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42, randomize_seed=False, steps=32, cfg_strength=4.0, file_type='wav', odeint_method='euler', preference_infer="quality first", edit=False, edit_segments=None, device='cuda'):
35
+ max_frames = 2048
36
+ sway_sampling_coef = -1 if steps < 32 else None
 
 
 
 
37
  if randomize_seed:
38
  seed = random.randint(0, MAX_SEED)
39
  torch.manual_seed(seed)
 
40
  vocal_flag = False
41
  try:
42
  lrc_prompt, start_time = get_lrc_token(max_frames, lrc, tokenizer, device)
 
47
  except Exception as e:
48
  raise gr.Error(f"Error: {str(e)}")
49
  negative_style_prompt = get_negative_style_prompt(device)
50
+ latent_prompt, pred_frames = get_reference_latent(device, max_frames, edit, edit_segments, ref_audio_path, vae)
51
+
52
+ if preference_infer == "quality first":
53
+ batch_infer_num = 5
54
+ else:
55
+ batch_infer_num = 1
56
+ generated_song = inference(cfm_model=cfm,
57
  vae_model=vae,
58
+ eval_model=eval_model,
59
+ eval_muq=eval_muq,
60
  cond=latent_prompt,
61
  text=lrc_prompt,
62
  duration=max_frames,
 
69
  file_type=file_type,
70
  vocal_flag=vocal_flag,
71
  odeint_method=odeint_method,
72
+ pred_frames=pred_frames,
73
+ batch_infer_num=batch_infer_num,
74
  )
75
  return generated_song
76
 
 
237
  - If loading audio result is slow, you can select Output Format as mp3 in Advanced Settings.
238
 
239
  """)
240
+ # Music_Duration = gr.Radio(["95s", "285s"], label="Music Duration", value="95s")
241
+ preference_infer = gr.Radio(["quality first", "speed first"], label="Preference", value="quality first")
242
  lyrics_btn = gr.Button("Generate", variant="primary")
243
  audio_output = gr.Audio(label="Audio Result", type="filepath", elem_id="audio_output")
244
  with gr.Accordion("Advanced Settings", open=False):
 
269
  interactive=True,
270
  elem_id="step_slider"
271
  )
272
+ edit = gr.Checkbox(label="edit", value=False)
273
+ edit_segments = gr.Textbox(
274
+ label="Edit Segments",
275
+ placeholder="Time segments to edit (in seconds). Format: `[[start1,end1],...]",
276
+ )
277
+
278
  odeint_method = gr.Radio(["euler", "midpoint", "rk4","implicit_adams"], label="ODE Solver", value="euler")
279
  file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="wav")
280
 
 
418
 
419
  lyrics_btn.click(
420
  fn=infer_music,
421
+ inputs=[lrc, audio_prompt, text_prompt, current_prompt_type, seed, randomize_seed, steps, cfg_strength, file_type, odeint_method, preference_infer, edit, edit_segments],
422
  outputs=audio_output
423
  )
424
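Note on the new Edit Segments input above: the textbox takes a JSON list of [start, end] pairs in seconds, which get_reference_latent() converts to latent-frame spans (44100 Hz audio, 2048-sample hop, -1 meaning the start or end of the song). A minimal standalone sketch of that conversion, for illustration only (the helper name is ours, not part of the commit):

    import json

    SAMPLING_RATE = 44100      # audio sample rate used throughout the repo
    DOWNSAMPLE_RATE = 2048     # samples per latent frame (VAE hop size)

    def segments_to_frames(edit_segments: str, max_frames: int = 2048):
        """Map '[[start, end], ...]' in seconds to latent frame spans; -1 = song boundary."""
        spans = []
        for st, et in json.loads(edit_segments):
            sf = 0 if st == -1 else int(st * SAMPLING_RATE / DOWNSAMPLE_RATE)
            ef = max_frames if et == -1 else int(et * SAMPLING_RATE / DOWNSAMPLE_RATE)
            spans.append((sf, ef))
        return spans

    # segments_to_frames("[[10, 20]]") -> [(215, 430)]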
 
diffrhythm/config/{diffrhythm-1b.json → config.json} RENAMED
@@ -2,7 +2,7 @@
2
  "model_type": "diffrhythm",
3
  "model": {
4
  "dim": 2048,
5
- "depth": 16,
6
  "heads": 32,
7
  "ff_mult": 4,
8
  "text_dim": 512,
 
2
  "model_type": "diffrhythm",
3
  "model": {
4
  "dim": 2048,
5
+ "depth": 16,
6
  "heads": 32,
7
  "ff_mult": 4,
8
  "text_dim": 512,
diffrhythm/infer/infer.py CHANGED
@@ -2,82 +2,51 @@ import torch
2
  import torchaudio
3
  from einops import rearrange
4
  import argparse
5
- import json
6
  import os
7
- from tqdm import tqdm
8
  import random
 
 
 
9
  import numpy as np
10
- import time
11
  import io
12
  import pydub
13
 
14
  from diffrhythm.infer.infer_utils import (
15
- get_reference_latent,
16
  get_lrc_token,
17
- get_audio_style_prompt,
 
 
18
  prepare_model,
19
- get_negative_style_prompt
20
  )
21
 
22
- def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
23
- downsampling_ratio = 2048
24
- io_channels = 2
25
- if not chunked:
26
- # default behavior. Decode the entire latent in parallel
27
- return vae_model.decode_export(latents)
28
- else:
29
- # chunked decoding
30
- hop_size = chunk_size - overlap
31
- total_size = latents.shape[2]
32
- batch_size = latents.shape[0]
33
- chunks = []
34
- i = 0
35
- for i in range(0, total_size - chunk_size + 1, hop_size):
36
- chunk = latents[:,:,i:i+chunk_size]
37
- chunks.append(chunk)
38
- if i+chunk_size != total_size:
39
- # Final chunk
40
- chunk = latents[:,:,-chunk_size:]
41
- chunks.append(chunk)
42
- chunks = torch.stack(chunks)
43
- num_chunks = chunks.shape[0]
44
- # samples_per_latent is just the downsampling ratio
45
- samples_per_latent = downsampling_ratio
46
- # Create an empty waveform, we will populate it with chunks as decode them
47
- y_size = total_size * samples_per_latent
48
- y_final = torch.zeros((batch_size,io_channels,y_size)).to(latents.device)
49
- for i in range(num_chunks):
50
- x_chunk = chunks[i,:]
51
- # decode the chunk
52
- y_chunk = vae_model.decode_export(x_chunk)
53
- # figure out where to put the audio along the time domain
54
- if i == num_chunks-1:
55
- # final chunk always goes at the end
56
- t_end = y_size
57
- t_start = t_end - y_chunk.shape[2]
58
- else:
59
- t_start = i * hop_size * samples_per_latent
60
- t_end = t_start + chunk_size * samples_per_latent
61
- # remove the edges of the overlaps
62
- ol = (overlap//2) * samples_per_latent
63
- chunk_start = 0
64
- chunk_end = y_chunk.shape[2]
65
- if i > 0:
66
- # no overlap for the start of the first chunk
67
- t_start += ol
68
- chunk_start += ol
69
- if i < num_chunks-1:
70
- # no overlap for the end of the last chunk
71
- t_end -= ol
72
- chunk_end -= ol
73
- # paste the chunked audio into our y_final output audio
74
- y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
75
- return y_final
76
-
77
- def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, cfg_strength, sway_sampling_coef, start_time, file_type, vocal_flag, odeint_method):
78
79
  with torch.inference_mode():
80
- generated, _ = cfm_model.sample(
81
  cond=cond,
82
  text=text,
83
  duration=duration,
@@ -89,17 +58,27 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
89
  start_time=start_time,
90
  vocal_flag=vocal_flag,
91
  odeint_method=odeint_method,
 
 
92
  )
93
-
94
- generated = generated.to(torch.float32)
95
- latent = generated.transpose(1, 2) # [b d t]
96
- output = decode_audio(latent, vae_model, chunked=False)
97
 
98
- # Rearrange audio batch to a single sequence
99
- output = rearrange(output, "b d n -> d (b n)")
100
- output_tensor = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).cpu()
101
  output_np = output_tensor.numpy().T.astype(np.float32)
102
-
103
  if file_type == 'wav':
104
  return (44100, output_np)
105
  else:
@@ -111,52 +90,140 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
111
  else:
112
  song.export(buffer, format="ogg", bitrate="320k")
113
  return buffer.getvalue()
114
-
115
 
 
 
116
  if __name__ == "__main__":
117
  parser = argparse.ArgumentParser()
118
- parser.add_argument('--lrc-path', type=str, default="example/eg.lrc") # lyrics of target song
119
- parser.add_argument('--ref-audio-path', type=str, default="example/eg.mp3") # reference audio as style prompt for target song
120
- parser.add_argument('--audio-length', type=int, default=95) # length of target song
121
- parser.add_argument('--output-dir', type=str, default="example/output")
122
  args = parser.parse_args()
123
-
124
- device = 'cuda'
125
-
126
  audio_length = args.audio_length
127
  if audio_length == 95:
128
  max_frames = 2048
129
  elif audio_length == 285:
130
  max_frames = 6144
131
 
132
- cfm, tokenizer, muq, vae = prepare_model(device)
133
-
134
- with open(args.lrc_path, 'r') as f:
135
- lrc = f.read()
136
- lrc_prompt, start_time = get_lrc_token(lrc, tokenizer, device)
137
-
138
- style_prompt = get_audio_style_prompt(muq, args.ref_audio_path)
139
 
140
- negative_style_prompt = get_negative_style_prompt(device)
141
 
142
- latent_prompt = get_reference_latent(device, max_frames)
143
 
144
- s_t = time.time()
145
- generated_song = inference(cfm_model=cfm,
146
- vae_model=vae,
147
- cond=latent_prompt,
148
- text=lrc_prompt,
149
- duration=max_frames,
150
- style_prompt=style_prompt,
151
- negative_style_prompt=negative_style_prompt,
152
- start_time=start_time
153
- )
154
  e_t = time.time() - s_t
155
- print(f"inference cost {e_t} seconds")
156
-
157
  output_dir = args.output_dir
158
  os.makedirs(output_dir, exist_ok=True)
159
-
160
  output_path = os.path.join(output_dir, "output.wav")
161
  torchaudio.save(output_path, generated_song, sample_rate=44100)
162
-
 
2
  import torchaudio
3
  from einops import rearrange
4
  import argparse
 
5
  import os
6
+ import time
7
  import random
8
+
9
+ import torch
10
+ import torchaudio
11
  import numpy as np
12
+ from einops import rearrange
13
  import io
14
  import pydub
15
 
16
  from diffrhythm.infer.infer_utils import (
17
+ decode_audio,
18
  get_lrc_token,
19
+ get_negative_style_prompt,
20
+ get_reference_latent,
21
+ get_style_prompt,
22
  prepare_model,
23
+ eval_song,
24
  )
25
26
 
27
+ def inference(
28
+ cfm_model,
29
+ vae_model,
30
+ eval_model,
31
+ eval_muq,
32
+ cond,
33
+ text,
34
+ duration,
35
+ style_prompt,
36
+ negative_style_prompt,
37
+ steps,
38
+ cfg_strength,
39
+ sway_sampling_coef,
40
+ start_time,
41
+ file_type,
42
+ vocal_flag,
43
+ odeint_method,
44
+ pred_frames,
45
+ batch_infer_num,
46
+ chunked=True,
47
+ ):
48
  with torch.inference_mode():
49
+ latents, _ = cfm_model.sample(
50
  cond=cond,
51
  text=text,
52
  duration=duration,
 
58
  start_time=start_time,
59
  vocal_flag=vocal_flag,
60
  odeint_method=odeint_method,
61
+ latent_pred_segments=pred_frames,
62
+ batch_infer_num=batch_infer_num
63
  )
 
 
 
 
64
 
65
+ outputs = []
66
+ for latent in latents:
67
+ latent = latent.to(torch.float32)
68
+ latent = latent.transpose(1, 2) # [b d t]
69
+
70
+ output = decode_audio(latent, vae_model, chunked=chunked)
71
+
72
+ # Rearrange audio batch to a single sequence
73
+ output = rearrange(output, "b d n -> d (b n)")
74
+
75
+ outputs.append(output)
76
+ if batch_infer_num > 1:
77
+ generated_song = eval_song(eval_model, eval_muq, outputs)
78
+ else:
79
+ generated_song = outputs[0]
80
+ output_tensor = generated_song.to(torch.float32).div(torch.max(torch.abs(generated_song))).clamp(-1, 1).cpu()
81
  output_np = output_tensor.numpy().T.astype(np.float32)
 
82
  if file_type == 'wav':
83
  return (44100, output_np)
84
  else:
 
90
  else:
91
  song.export(buffer, format="ogg", bitrate="320k")
92
  return buffer.getvalue()
 
93
 
94
+
95
+
96
  if __name__ == "__main__":
97
  parser = argparse.ArgumentParser()
98
+ parser.add_argument(
99
+ "--lrc-path",
100
+ type=str,
101
+ help="lyrics of target song",
102
+ ) # lyrics of target song
103
+ parser.add_argument(
104
+ "--ref-prompt",
105
+ type=str,
106
+ help="reference prompt as style prompt for target song",
107
+ required=False,
108
+ ) # reference prompt as style prompt for target song
109
+ parser.add_argument(
110
+ "--ref-audio-path",
111
+ type=str,
112
+ help="reference audio as style prompt for target song",
113
+ required=False,
114
+ ) # reference audio as style prompt for target song
115
+ parser.add_argument(
116
+ "--chunked",
117
+ action="store_true",
118
+ help="whether to use chunked decoding",
119
+ ) # whether to use chunked decoding
120
+ parser.add_argument(
121
+ "--audio-length",
122
+ type=int,
123
+ default=95,
124
+ choices=[95, 285],
125
+ help="length of generated song",
126
+ ) # length of target song
127
+ parser.add_argument(
128
+ "--repo-id", type=str, default="ASLP-lab/DiffRhythm-base", help="target model"
129
+ )
130
+ parser.add_argument(
131
+ "--output-dir",
132
+ type=str,
133
+ default="infer/example/output",
134
+ help="output directory fo generated song",
135
+ ) # output directory of target song
136
+ parser.add_argument(
137
+ "--edit",
138
+ action="store_true",
139
+ help="whether to open edit mode",
140
+ ) # edit flag
141
+ parser.add_argument(
142
+ "--ref-song",
143
+ type=str,
144
+ required=False,
145
+ help="reference prompt as latent prompt for editing",
146
+ ) # reference prompt as latent prompt for editing
147
+ parser.add_argument(
148
+ "--edit-segments",
149
+ type=str,
150
+ required=False,
151
+ help="edit segments o target song",
152
+ ) # edit segments of target song
153
  args = parser.parse_args()
154
+
155
+ assert (
156
+ args.ref_prompt or args.ref_audio_path
157
+ ), "either ref_prompt or ref_audio_path should be provided"
158
+ assert not (
159
+ args.ref_prompt and args.ref_audio_path
160
+ ), "only one of them should be provided"
161
+ if args.edit:
162
+ assert (
163
+ args.ref_song and args.edit_segments
164
+ ), "reference song and edit segments should be provided for editing"
165
+
166
+ device = "cpu"
167
+ if torch.cuda.is_available():
168
+ device = "cuda"
169
+ elif torch.mps.is_available():
170
+ device = "mps"
171
+
172
  audio_length = args.audio_length
173
  if audio_length == 95:
174
  max_frames = 2048
175
  elif audio_length == 285:
176
  max_frames = 6144
177
+
178
+ cfm, tokenizer, muq, vae, eval_model, eval_muq = prepare_model(max_frames, device, repo_id=args.repo_id)
179
+
180
+ if args.lrc_path:
181
+ with open(args.lrc_path, "r", encoding='utf-8') as f:
182
+ lrc = f.read()
183
+ else:
184
+ lrc = ""
185
+ lrc_prompt, start_time = get_lrc_token(max_frames, lrc, tokenizer, device)
186
+
187
+ if args.ref_audio_path:
188
+ style_prompt = get_style_prompt(muq, args.ref_audio_path)
189
+ else:
190
+ style_prompt = get_style_prompt(muq, prompt=args.ref_prompt)
191
+
192
+ negative_style_prompt = get_negative_style_prompt(device)
193
+
194
+ latent_prompt, pred_frames = get_reference_latent(device, max_frames, args.edit, args.edit_segments, args.ref_song, vae)
195
+
196
+ s_t = time.time()
197
+ generated_songs = inference(
198
+ cfm_model=cfm,
199
+ vae_model=vae,
200
+ cond=latent_prompt,
201
+ text=lrc_prompt,
202
+ duration=max_frames,
203
+ style_prompt=style_prompt,
204
+ negative_style_prompt=negative_style_prompt,
205
+ start_time=start_time,
206
+ pred_frames=pred_frames,
207
+ chunked=args.chunked,
208
+ )
209
 
 
 
 
 
 
 
 
210
 
 
211
 
212
+ generated_song = eval_song(eval_model, eval_muq, generated_songs)
213
 
214
+ # Peak normalize, clip, convert to int16, and save to file
215
+ generated_song = (
216
+ generated_song.to(torch.float32)
217
+ .div(torch.max(torch.abs(generated_song)))
218
+ .clamp(-1, 1)
219
+ .mul(32767)
220
+ .to(torch.int16)
221
+ .cpu()
222
+ )
 
223
  e_t = time.time() - s_t
224
+ print(f"inference cost {e_t:.2f} seconds")
 
225
  output_dir = args.output_dir
226
  os.makedirs(output_dir, exist_ok=True)
227
+
228
  output_path = os.path.join(output_dir, "output.wav")
229
  torchaudio.save(output_path, generated_song, sample_rate=44100)
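Note on the new preference handling: with "quality first" the Gradio path sets batch_infer_num=5, inference() decodes five candidate latents, and eval_song() keeps the highest-scoring one; with "speed first" a single candidate is returned unscored. A minimal sketch of the pick-best step with a stand-in scorer (score_fn is a placeholder, not the commit's eval-model API):

    import torch

    def pick_best(candidates, score_fn):
        """candidates: list of decoded waveforms [d, n]; return the one the scorer ranks highest."""
        scores = torch.tensor([float(score_fn(wav)) for wav in candidates])
        return candidates[int(scores.argmax())]

    # usage sketch: best = pick_best(outputs, score_fn=lambda wav: wav.abs().mean())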
 
diffrhythm/infer/infer_utils.py CHANGED
@@ -1,66 +1,308 @@
1
  import torch
2
  import librosa
 
3
  import random
4
  import json
5
- from muq import MuQMuLan
6
  from mutagen.mp3 import MP3
7
  import os
8
  import numpy as np
9
  from huggingface_hub import hf_hub_download
 
 
 
 
10
  from diffrhythm.model import DiT, CFM
11
12
 
13
  def prepare_model(device):
14
  # prepare cfm model
15
- dit_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-base", filename="cfm_model.pt")
16
- dit_full_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-full", filename="cfm_model.pt")
17
- dit_config_path = "./diffrhythm/config/diffrhythm-1b.json"
18
  with open(dit_config_path) as f:
19
  model_config = json.load(f)
20
  dit_model_cls = DiT
21
  cfm = CFM(
22
- transformer=dit_model_cls(**model_config["model"], use_style_prompt=True, max_pos=2048),
23
  num_channels=model_config["model"]['mel_dim'],
24
- use_style_prompt=True
25
  )
26
  cfm = cfm.to(device)
27
  cfm = load_checkpoint(cfm, dit_ckpt_path, device=device, use_ema=False)
28
-
29
- cfm_full = CFM(
30
- transformer=dit_model_cls(**model_config["model"], use_style_prompt=True, max_pos=6144),
31
- num_channels=model_config["model"]['mel_dim'],
32
- use_style_prompt=True
33
- )
34
- cfm_full = cfm_full.to(device)
35
- cfm_full = load_checkpoint(cfm_full, dit_full_ckpt_path, device=device, use_ema=False)
36
-
37
  # prepare tokenizer
38
  tokenizer = CNENTokenizer()
39
-
40
  # prepare muq
41
- muq = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large")
42
  muq = muq.to(device).eval()
43
-
44
  # prepare vae
45
  vae_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-vae", filename="vae_model.pt")
46
- vae = torch.jit.load(vae_ckpt_path, map_location='cpu').to(device)
 
47
 
48
- return cfm, cfm_full, tokenizer, muq, vae
49
 
50
 
51
  # for song edit, will be added in the future
52
- def get_reference_latent(device, max_frames):
53
- return torch.zeros(1, max_frames, 64).to(device)
54
 
55
  def get_negative_style_prompt(device):
56
  file_path = "./src/negative_prompt.npy"
57
  vocal_stlye = np.load(file_path)
58
-
59
- vocal_stlye = torch.from_numpy(vocal_stlye).to(device) # [1, 512]
60
  vocal_stlye = vocal_stlye.half()
61
-
62
  return vocal_stlye
63
64
  def get_audio_style_prompt(model, wav_path):
65
  vocal_flag = False
66
  mulan = model
@@ -85,6 +327,8 @@ def get_audio_style_prompt(model, wav_path):
85
 
86
  return audio_emb, vocal_flag
87
 
 
 
88
  def get_text_style_prompt(model, text_prompt):
89
  mulan = model
90
 
@@ -95,50 +339,88 @@ def get_text_style_prompt(model, text_prompt):
95
  return text_emb
96
 
97
 
98
 
99
  def parse_lyrics(lyrics: str):
100
  lyrics_with_time = []
101
  lyrics = lyrics.strip()
102
- for line in lyrics.split('\n'):
103
  try:
104
  time, lyric = line[1:9], line[10:]
105
  lyric = lyric.strip()
106
- mins, secs = time.split(':')
107
  secs = int(mins) * 60 + float(secs)
108
  lyrics_with_time.append((secs, lyric))
109
  except:
110
  continue
111
  return lyrics_with_time
112
 
113
- class CNENTokenizer():
 
114
  def __init__(self):
115
- with open('./diffrhythm/g2p/g2p/vocab.json', 'r') as file:
116
- self.phone2id:dict = json.load(file)['vocab']
117
- self.id2phone = {v:k for (k, v) in self.phone2id.items()}
118
  from diffrhythm.g2p.g2p_generation import chn_eng_g2p
 
119
  self.tokenizer = chn_eng_g2p
 
120
  def encode(self, text):
121
  phone, token = self.tokenizer(text)
122
- token = [x+1 for x in token]
123
  return token
 
124
  def decode(self, token):
125
- return "|".join([self.id2phone[x-1] for x in token])
126
-
 
127
  def get_lrc_token(max_frames, text, tokenizer, device):
128
 
129
  lyrics_shift = 0
130
  sampling_rate = 44100
131
  downsample_rate = 2048
132
  max_secs = max_frames / (sampling_rate / downsample_rate)
133
-
134
- pad_token_id = 0
135
  comma_token_id = 1
136
- period_token_id = 2
137
- if text == "":
138
- return torch.zeros((max_frames,), dtype=torch.long).unsqueeze(0).to(device), torch.tensor(0.).unsqueeze(0).to(device).half()
139
 
140
  lrc_with_time = parse_lyrics(text)
141
-
142
  modified_lrc_with_time = []
143
  for i in range(len(lrc_with_time)):
144
  time, line = lrc_with_time[i]
@@ -146,44 +428,49 @@ def get_lrc_token(max_frames, text, tokenizer, device):
146
  modified_lrc_with_time.append((time, line_token))
147
  lrc_with_time = modified_lrc_with_time
148
 
149
- lrc_with_time = [(time_start, line) for (time_start, line) in lrc_with_time if time_start < max_secs]
150
- # lrc_with_time = lrc_with_time[:-1] if len(lrc_with_time) >= 1 else lrc_with_time
151
-
152
- normalized_start_time = 0.
 
 
 
 
 
153
 
154
  lrc = torch.zeros((max_frames,), dtype=torch.long)
155
 
156
  tokens_count = 0
157
  last_end_pos = 0
158
  for time_start, line in lrc_with_time:
159
- tokens = [token if token != period_token_id else comma_token_id for token in line] + [period_token_id]
 
 
160
  tokens = torch.tensor(tokens, dtype=torch.long)
161
  num_tokens = tokens.shape[0]
162
 
163
  gt_frame_start = int(time_start * sampling_rate / downsample_rate)
164
-
165
- frame_shift = random.randint(int(lyrics_shift), int(lyrics_shift))
166
-
167
  frame_start = max(gt_frame_start - frame_shift, last_end_pos)
168
  frame_len = min(num_tokens, max_frames - frame_start)
169
 
170
-
171
-
172
- lrc[frame_start:frame_start + frame_len] = tokens[:frame_len]
173
 
174
  tokens_count += num_tokens
175
- last_end_pos = frame_start + frame_len
176
-
177
  lrc_emb = lrc.unsqueeze(0).to(device)
178
-
179
  normalized_start_time = torch.tensor(normalized_start_time).unsqueeze(0).to(device)
180
  normalized_start_time = normalized_start_time.half()
181
-
182
  return lrc_emb, normalized_start_time
183
 
 
184
  def load_checkpoint(model, ckpt_path, device, use_ema=True):
185
- if device == "cuda":
186
- model = model.half()
187
 
188
  ckpt_type = ckpt_path.split(".")[-1]
189
  if ckpt_type == "safetensors":
@@ -207,4 +494,4 @@ def load_checkpoint(model, ckpt_path, device, use_ema=True):
207
  checkpoint = {"model_state_dict": checkpoint}
208
  model.load_state_dict(checkpoint["model_state_dict"], strict=False)
209
 
210
- return model.to(device)
 
1
  import torch
2
  import librosa
3
+ import torchaudio
4
  import random
5
  import json
6
+ from muq import MuQMuLan, MuQ
7
  from mutagen.mp3 import MP3
8
  import os
9
  import numpy as np
10
  from huggingface_hub import hf_hub_download
11
+ from hydra.utils import instantiate
12
+ from omegaconf import OmegaConf
13
+ from safetensors.torch import load_file
14
+
15
  from diffrhythm.model import DiT, CFM
16
 
17
+ def vae_sample(mean, scale):
18
+ stdev = torch.nn.functional.softplus(scale) + 1e-4
19
+ var = stdev * stdev
20
+ logvar = torch.log(var)
21
+ latents = torch.randn_like(mean) * stdev + mean
22
+
23
+ kl = (mean * mean + var - logvar - 1).sum(1).mean()
24
+
25
+ return latents, kl
26
+
27
+ def normalize_audio(y, target_dbfs=0):
28
+ max_amplitude = torch.max(torch.abs(y))
29
+
30
+ target_amplitude = 10.0**(target_dbfs / 20.0)
31
+ scale_factor = target_amplitude / max_amplitude
32
+
33
+ normalized_audio = y * scale_factor
34
+
35
+ return normalized_audio
36
+
37
+ def set_audio_channels(audio, target_channels):
38
+ if target_channels == 1:
39
+ # Convert to mono
40
+ audio = audio.mean(1, keepdim=True)
41
+ elif target_channels == 2:
42
+ # Convert to stereo
43
+ if audio.shape[1] == 1:
44
+ audio = audio.repeat(1, 2, 1)
45
+ elif audio.shape[1] > 2:
46
+ audio = audio[:, :2, :]
47
+ return audio
48
+
49
+ class PadCrop(torch.nn.Module):
50
+ def __init__(self, n_samples, randomize=True):
51
+ super().__init__()
52
+ self.n_samples = n_samples
53
+ self.randomize = randomize
54
+
55
+ def __call__(self, signal):
56
+ n, s = signal.shape
57
+ start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
58
+ end = start + self.n_samples
59
+ output = signal.new_zeros([n, self.n_samples])
60
+ output[:, :min(s, self.n_samples)] = signal[:, start:end]
61
+ return output
62
+
63
+ def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
64
+
65
+ audio = audio.to(device)
66
+
67
+ if in_sr != target_sr:
68
+ resample_tf = torchaudio.transforms.Resample(in_sr, target_sr).to(device)
69
+ audio = resample_tf(audio)
70
+ if target_length is None:
71
+ target_length = audio.shape[-1]
72
+ audio = PadCrop(target_length, randomize=False)(audio)
73
+
74
+ # Add batch dimension
75
+ if audio.dim() == 1:
76
+ audio = audio.unsqueeze(0).unsqueeze(0)
77
+ elif audio.dim() == 2:
78
+ audio = audio.unsqueeze(0)
79
+
80
+ audio = set_audio_channels(audio, target_channels)
81
+
82
+ return audio
83
+
84
+ def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
85
+ downsampling_ratio = 2048
86
+ io_channels = 2
87
+ if not chunked:
88
+ return vae_model.decode_export(latents)
89
+ else:
90
+ # chunked decoding
91
+ hop_size = chunk_size - overlap
92
+ total_size = latents.shape[2]
93
+ batch_size = latents.shape[0]
94
+ chunks = []
95
+ i = 0
96
+ for i in range(0, total_size - chunk_size + 1, hop_size):
97
+ chunk = latents[:, :, i : i + chunk_size]
98
+ chunks.append(chunk)
99
+ if i + chunk_size != total_size:
100
+ # Final chunk
101
+ chunk = latents[:, :, -chunk_size:]
102
+ chunks.append(chunk)
103
+ chunks = torch.stack(chunks)
104
+ num_chunks = chunks.shape[0]
105
+ # samples_per_latent is just the downsampling ratio
106
+ samples_per_latent = downsampling_ratio
107
+ # Create an empty waveform; we will populate it with chunks as we decode them
108
+ y_size = total_size * samples_per_latent
109
+ y_final = torch.zeros((batch_size, io_channels, y_size)).to(latents.device)
110
+ for i in range(num_chunks):
111
+ x_chunk = chunks[i, :]
112
+ # decode the chunk
113
+ y_chunk = vae_model.decode_export(x_chunk)
114
+ # figure out where to put the audio along the time domain
115
+ if i == num_chunks - 1:
116
+ # final chunk always goes at the end
117
+ t_end = y_size
118
+ t_start = t_end - y_chunk.shape[2]
119
+ else:
120
+ t_start = i * hop_size * samples_per_latent
121
+ t_end = t_start + chunk_size * samples_per_latent
122
+ # remove the edges of the overlaps
123
+ ol = (overlap // 2) * samples_per_latent
124
+ chunk_start = 0
125
+ chunk_end = y_chunk.shape[2]
126
+ if i > 0:
127
+ # no overlap for the start of the first chunk
128
+ t_start += ol
129
+ chunk_start += ol
130
+ if i < num_chunks - 1:
131
+ # no overlap for the end of the last chunk
132
+ t_end -= ol
133
+ chunk_end -= ol
134
+ # paste the chunked audio into our y_final output audio
135
+ y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
136
+ return y_final
137
+
138
+ def encode_audio(audio, vae_model, chunked=False, overlap=32, chunk_size=128):
139
+ downsampling_ratio = 2048
140
+ latent_dim = 128
141
+ if not chunked:
142
+ # default behavior. Encode the entire audio in parallel
143
+ return vae_model.encode_export(audio)
144
+ else:
145
+ # CHUNKED ENCODING
146
+ # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
147
+ samples_per_latent = downsampling_ratio
148
+ total_size = audio.shape[2] # in samples
149
+ batch_size = audio.shape[0]
150
+ chunk_size *= samples_per_latent # converting metric in latents to samples
151
+ overlap *= samples_per_latent # converting metric in latents to samples
152
+ hop_size = chunk_size - overlap
153
+ chunks = []
154
+ for i in range(0, total_size - chunk_size + 1, hop_size):
155
+ chunk = audio[:,:,i:i+chunk_size]
156
+ chunks.append(chunk)
157
+ if i+chunk_size != total_size:
158
+ # Final chunk
159
+ chunk = audio[:,:,-chunk_size:]
160
+ chunks.append(chunk)
161
+ chunks = torch.stack(chunks)
162
+ num_chunks = chunks.shape[0]
163
+ # Note: y_size might be a different value from the latent length used in diffusion training
164
+ # because we can encode audio of varying lengths
165
+ # However, the audio should've been padded to a multiple of samples_per_latent by now.
166
+ y_size = total_size // samples_per_latent
167
+ # Create an empty latent, we will populate it with chunks as we encode them
168
+ y_final = torch.zeros((batch_size,latent_dim,y_size)).to(audio.device)
169
+ for i in range(num_chunks):
170
+ x_chunk = chunks[i,:]
171
+ # encode the chunk
172
+ y_chunk = vae_model.encode_export(x_chunk)
173
+ # figure out where to put the audio along the time domain
174
+ if i == num_chunks-1:
175
+ # final chunk always goes at the end
176
+ t_end = y_size
177
+ t_start = t_end - y_chunk.shape[2]
178
+ else:
179
+ t_start = i * hop_size // samples_per_latent
180
+ t_end = t_start + chunk_size // samples_per_latent
181
+ # remove the edges of the overlaps
182
+ ol = overlap//samples_per_latent//2
183
+ chunk_start = 0
184
+ chunk_end = y_chunk.shape[2]
185
+ if i > 0:
186
+ # no overlap for the start of the first chunk
187
+ t_start += ol
188
+ chunk_start += ol
189
+ if i < num_chunks-1:
190
+ # no overlap for the end of the last chunk
191
+ t_end -= ol
192
+ chunk_end -= ol
193
+ # paste the chunked audio into our y_final output audio
194
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
195
+ return y_final
196
 
197
  def prepare_model(device):
198
  # prepare cfm model
199
+
200
+ dit_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-1_2", filename="cfm_model.pt")
201
+ dit_config_path = "./diffrhythm/config/config.json"
202
  with open(dit_config_path) as f:
203
  model_config = json.load(f)
204
  dit_model_cls = DiT
205
  cfm = CFM(
206
+ transformer=dit_model_cls(**model_config["model"], max_frames=2048),
207
  num_channels=model_config["model"]['mel_dim'],
 
208
  )
209
  cfm = cfm.to(device)
210
  cfm = load_checkpoint(cfm, dit_ckpt_path, device=device, use_ema=False)
211
+
212
  # prepare tokenizer
213
  tokenizer = CNENTokenizer()
214
+
215
  # prepare muq
216
+ muq = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large", cache_dir="./pretrained")
217
  muq = muq.to(device).eval()
218
+
219
  # prepare vae
220
  vae_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-vae", filename="vae_model.pt")
221
+ vae = torch.jit.load(vae_ckpt_path, map_location="cpu").to(device)
222
+
223
 
224
+ # prepare eval model
225
+ train_config = OmegaConf.load("./pretrained/eval.yaml")
226
+ checkpoint_path = "./pretrained/eval.safetensors"
227
+
228
+ eval_model = instantiate(train_config.generator).to(device).eval()
229
+ state_dict = load_file(checkpoint_path, device="cpu")
230
+ eval_model.load_state_dict(state_dict)
231
+
232
+ eval_muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
233
+ eval_muq = eval_muq.to(device).eval()
234
+
235
+ return cfm, tokenizer, muq, vae, eval_model, eval_muq
236
 
237
 
238
  # for song edit, will be added in the future
239
+ def get_reference_latent(device, max_frames, edit, pred_segments, ref_song, vae_model):
240
+ sampling_rate = 44100
241
+ downsample_rate = 2048
242
+ io_channels = 2
243
+ if edit:
244
+ input_audio, in_sr = torchaudio.load(ref_song)
245
+ input_audio = prepare_audio(input_audio, in_sr=in_sr, target_sr=sampling_rate, target_length=None, target_channels=io_channels, device=device)
246
+ input_audio = normalize_audio(input_audio, -6)
247
+
248
+ with torch.no_grad():
249
+ latent = encode_audio(input_audio, vae_model, chunked=True) # [b d t]
250
+ mean, scale = latent.chunk(2, dim=1)
251
+ prompt, _ = vae_sample(mean, scale)
252
+ prompt = prompt.transpose(1, 2) # [b t d]
253
+
254
+ pred_segments = json.loads(pred_segments)
255
+ # import pdb; pdb.set_trace()
256
+ pred_frames = []
257
+ for st, et in pred_segments:
258
+ sf = 0 if st == -1 else int(st * sampling_rate / downsample_rate)
259
+ # if st == -1:
260
+ # sf = 0
261
+ # else:
262
+ # sf = int(st * sampling_rate / downsample_rate )
263
+
264
+ ef = max_frames if et == -1 else int(et * sampling_rate / downsample_rate)
265
+ # if et == -1:
266
+ # ef = max_frames
267
+ # else:
268
+ # ef = int(et * sampling_rate / downsample_rate )
269
+ pred_frames.append((sf, ef))
270
+ # import pdb; pdb.set_trace()
271
+ return prompt, pred_frames
272
+ else:
273
+ prompt = torch.zeros(1, max_frames, 64).to(device)
274
+ pred_frames = [(0, max_frames)]
275
+ return prompt, pred_frames
276
+
277
 
278
  def get_negative_style_prompt(device):
279
  file_path = "./src/negative_prompt.npy"
280
  vocal_stlye = np.load(file_path)
281
+
282
+ vocal_stlye = torch.from_numpy(vocal_stlye).to(device) # [1, 512]
283
  vocal_stlye = vocal_stlye.half()
284
+
285
  return vocal_stlye
286
 
287
+ @torch.no_grad()
288
+ def eval_song(eval_model, eval_muq, songs):
289
+
290
+ resampled_songs = [torchaudio.functional.resample(song.mean(dim=0, keepdim=True), 44100, 24000) for song in songs]
291
+ ssl_list = []
292
+ for i in range(len(resampled_songs)):
293
+ output = eval_muq(resampled_songs[i], output_hidden_states=True)
294
+ muq_ssl = output["hidden_states"][6]
295
+ ssl_list.append(muq_ssl.squeeze(0))
296
+
297
+ ssl = torch.stack(ssl_list)
298
+ scores_g = eval_model(ssl)
299
+ score = torch.mean(scores_g, dim=1)
300
+ idx = score.argmax(dim=0)
301
+
302
+ return songs[idx]
303
+
304
+
305
+ @torch.no_grad()
306
  def get_audio_style_prompt(model, wav_path):
307
  vocal_flag = False
308
  mulan = model
 
327
 
328
  return audio_emb, vocal_flag
329
 
330
+
331
+ @torch.no_grad()
332
  def get_text_style_prompt(model, text_prompt):
333
  mulan = model
334
 
 
339
  return text_emb
340
 
341
 
342
+ @torch.no_grad()
343
+ def get_style_prompt(model, wav_path=None, prompt=None):
344
+ mulan = model
345
+
346
+ if prompt is not None:
347
+ return mulan(texts=prompt).half()
348
+
349
+ ext = os.path.splitext(wav_path)[-1].lower()
350
+ if ext == ".mp3":
351
+ meta = MP3(wav_path)
352
+ audio_len = meta.info.length
353
+ elif ext in [".wav", ".flac"]:
354
+ audio_len = librosa.get_duration(path=wav_path)
355
+ else:
356
+ raise ValueError("Unsupported file format: {}".format(ext))
357
+
358
+ if audio_len < 10:
359
+ print(
360
+ f"Warning: The audio file {wav_path} is too short ({audio_len:.2f} seconds). Expected at least 10 seconds."
361
+ )
362
+
363
+ assert audio_len >= 10
364
+
365
+ mid_time = audio_len // 2
366
+ start_time = mid_time - 5
367
+ wav, _ = librosa.load(wav_path, sr=24000, offset=start_time, duration=10)
368
+
369
+ wav = torch.tensor(wav).unsqueeze(0).to(model.device)
370
+
371
+ with torch.no_grad():
372
+ audio_emb = mulan(wavs=wav) # [1, 512]
373
+
374
+ audio_emb = audio_emb
375
+ audio_emb = audio_emb.half()
376
+
377
+ return audio_emb
378
 
379
  def parse_lyrics(lyrics: str):
380
  lyrics_with_time = []
381
  lyrics = lyrics.strip()
382
+ for line in lyrics.split("\n"):
383
  try:
384
  time, lyric = line[1:9], line[10:]
385
  lyric = lyric.strip()
386
+ mins, secs = time.split(":")
387
  secs = int(mins) * 60 + float(secs)
388
  lyrics_with_time.append((secs, lyric))
389
  except:
390
  continue
391
  return lyrics_with_time
392
 
393
+
394
+ class CNENTokenizer:
395
  def __init__(self):
396
+ with open("./diffrhythm/g2p/g2p/vocab.json", "r", encoding='utf-8') as file:
397
+ self.phone2id: dict = json.load(file)["vocab"]
398
+ self.id2phone = {v: k for (k, v) in self.phone2id.items()}
399
  from diffrhythm.g2p.g2p_generation import chn_eng_g2p
400
+
401
  self.tokenizer = chn_eng_g2p
402
+
403
  def encode(self, text):
404
  phone, token = self.tokenizer(text)
405
+ token = [x + 1 for x in token]
406
  return token
407
+
408
  def decode(self, token):
409
+ return "|".join([self.id2phone[x - 1] for x in token])
410
+
411
+
412
  def get_lrc_token(max_frames, text, tokenizer, device):
413
 
414
  lyrics_shift = 0
415
  sampling_rate = 44100
416
  downsample_rate = 2048
417
  max_secs = max_frames / (sampling_rate / downsample_rate)
418
+
 
419
  comma_token_id = 1
420
+ period_token_id = 2
 
 
421
 
422
  lrc_with_time = parse_lyrics(text)
423
+
424
  modified_lrc_with_time = []
425
  for i in range(len(lrc_with_time)):
426
  time, line = lrc_with_time[i]
 
428
  modified_lrc_with_time.append((time, line_token))
429
  lrc_with_time = modified_lrc_with_time
430
 
431
+ lrc_with_time = [
432
+ (time_start, line)
433
+ for (time_start, line) in lrc_with_time
434
+ if time_start < max_secs
435
+ ]
436
+ if max_frames == 2048:
437
+ lrc_with_time = lrc_with_time[:-1] if len(lrc_with_time) >= 1 else lrc_with_time
438
+
439
+ normalized_start_time = 0.0
440
 
441
  lrc = torch.zeros((max_frames,), dtype=torch.long)
442
 
443
  tokens_count = 0
444
  last_end_pos = 0
445
  for time_start, line in lrc_with_time:
446
+ tokens = [
447
+ token if token != period_token_id else comma_token_id for token in line
448
+ ] + [period_token_id]
449
  tokens = torch.tensor(tokens, dtype=torch.long)
450
  num_tokens = tokens.shape[0]
451
 
452
  gt_frame_start = int(time_start * sampling_rate / downsample_rate)
453
+
454
+ frame_shift = random.randint(int(-lyrics_shift), int(lyrics_shift))
455
+
456
  frame_start = max(gt_frame_start - frame_shift, last_end_pos)
457
  frame_len = min(num_tokens, max_frames - frame_start)
458
 
459
+ lrc[frame_start : frame_start + frame_len] = tokens[:frame_len]
 
 
460
 
461
  tokens_count += num_tokens
462
+ last_end_pos = frame_start + frame_len
463
+
464
  lrc_emb = lrc.unsqueeze(0).to(device)
465
+
466
  normalized_start_time = torch.tensor(normalized_start_time).unsqueeze(0).to(device)
467
  normalized_start_time = normalized_start_time.half()
468
+
469
  return lrc_emb, normalized_start_time
470
 
471
+
472
  def load_checkpoint(model, ckpt_path, device, use_ema=True):
473
+ model = model.half()
 
474
 
475
  ckpt_type = ckpt_path.split(".")[-1]
476
  if ckpt_type == "safetensors":
 
494
  checkpoint = {"model_state_dict": checkpoint}
495
  model.load_state_dict(checkpoint["model_state_dict"], strict=False)
496
 
497
+ return model.to(device)
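Note on the chunked VAE decode/encode added above: both functions slide a fixed window (chunk_size latents, overlapping by overlap latents) across the sequence, right-align a final window if needed, and trim half the overlap from each interior edge when pasting results back. A small illustrative helper (ours, not in the commit) that lists the decode windows decode_audio() would visit:

    def chunk_plan(total_size, chunk_size=128, overlap=32):
        """Latent windows visited by chunked decoding, mirroring its hop/overlap bookkeeping."""
        hop = chunk_size - overlap
        starts = list(range(0, total_size - chunk_size + 1, hop))
        if not starts or starts[-1] + chunk_size != total_size:
            starts.append(max(0, total_size - chunk_size))  # final window is right-aligned
        return [(s, s + chunk_size) for s in starts]

    # chunk_plan(2048) -> [(0, 128), (96, 224), ..., (1920, 2048)]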
diffrhythm/model/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (290 Bytes)
 
diffrhythm/model/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (508 Bytes)
 
diffrhythm/model/__pycache__/cfm.cpython-310.pyc DELETED
Binary file (6.28 kB)
 
diffrhythm/model/__pycache__/cfm.cpython-312.pyc DELETED
Binary file (10.7 kB)
 
diffrhythm/model/__pycache__/custom_dataset.cpython-310.pyc DELETED
Binary file (11.5 kB)
 
diffrhythm/model/__pycache__/custom_dataset_lrc_emb.cpython-310.pyc DELETED
Binary file (10.5 kB)
 
diffrhythm/model/__pycache__/dataset.cpython-310.pyc DELETED
Binary file (8.04 kB)
 
diffrhythm/model/__pycache__/dit.cpython-310.pyc DELETED
Binary file (5.61 kB)
 
diffrhythm/model/__pycache__/modules.cpython-310.pyc DELETED
Binary file (15.9 kB)
 
diffrhythm/model/__pycache__/trainer.cpython-310.pyc DELETED
Binary file (9.13 kB)
 
diffrhythm/model/__pycache__/utils.cpython-310.pyc DELETED
Binary file (6.03 kB)
 
diffrhythm/model/cfm.py CHANGED
@@ -1,10 +1,22 @@
1
- """
2
- ein notation:
3
- b - batch
4
- n - sequence
5
- nt - text sequence
6
- nw - raw wave length
7
- d - dimension
8
  """
9
 
10
  from __future__ import annotations
@@ -19,9 +31,7 @@ from torch.nn.utils.rnn import pad_sequence
19
 
20
  from torchdiffeq import odeint
21
 
22
- from diffrhythm.model.modules import MelSpec
23
  from diffrhythm.model.utils import (
24
- default,
25
  exists,
26
  list_str_to_idx,
27
  list_str_to_tensor,
@@ -29,12 +39,25 @@ from diffrhythm.model.utils import (
29
  mask_from_frac_lengths,
30
  )
31
 
32
- def custom_mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"], device, max_seq_len): # noqa: F722 F821
 
 
 
 
 
33
  max_seq_len = max_seq_len
34
  seq = torch.arange(max_seq_len, device=device).long()
35
- start_mask = seq[None, :] >= start[:, None]
36
- end_mask = seq[None, :] < end[:, None]
37
- return start_mask & end_mask
 
39
  class CFM(nn.Module):
40
  def __init__(
@@ -42,7 +65,7 @@ class CFM(nn.Module):
42
  transformer: nn.Module,
43
  sigma=0.0,
44
  odeint_kwargs: dict = dict(
45
- method="euler" # 'midpoint'
46
  ),
47
  odeint_options: dict = dict(
48
  min_step=0.05
@@ -54,7 +77,7 @@ class CFM(nn.Module):
54
  num_channels=None,
55
  frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
56
  vocab_char_map: dict[str:int] | None = None,
57
- use_style_prompt: bool = False
58
  ):
59
  super().__init__()
60
 
@@ -83,8 +106,8 @@ class CFM(nn.Module):
83
 
84
  # vocab map for tokenization
85
  self.vocab_char_map = vocab_char_map
86
-
87
- self.use_style_prompt = use_style_prompt
88
 
89
  @property
90
  def device(self):
@@ -112,10 +135,10 @@ class CFM(nn.Module):
112
  t_inter=0.1,
113
  edit_mask=None,
114
  start_time=None,
115
- latent_pred_start_frame=0,
116
- latent_pred_end_frame=2048,
117
  vocal_flag=False,
118
- odeint_method="euler"
 
119
  ):
120
  self.eval()
121
 
@@ -125,7 +148,6 @@ class CFM(nn.Module):
125
  cond = cond.half()
126
 
127
  # raw wave
128
-
129
  if cond.shape[1] > duration:
130
  cond = cond[:, :duration, :]
131
 
@@ -139,7 +161,6 @@ class CFM(nn.Module):
139
  lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
140
 
141
  # text
142
-
143
  if isinstance(text, list):
144
  if exists(self.vocab_char_map):
145
  text = list_str_to_idx(text, self.vocab_char_map).to(device)
@@ -147,26 +168,18 @@ class CFM(nn.Module):
147
  text = list_str_to_tensor(text).to(device)
148
  assert text.shape[0] == batch
149
 
150
- if exists(text):
151
- text_lens = (text != -1).sum(dim=-1)
152
-
153
-
154
  # duration
155
  cond_mask = lens_to_mask(lens)
156
  if edit_mask is not None:
157
  cond_mask = cond_mask & edit_mask
158
 
159
- latent_pred_start_frame = torch.tensor([latent_pred_start_frame]).to(cond.device)
160
- latent_pred_end_frame = duration
161
- latent_pred_end_frame = torch.tensor([latent_pred_end_frame]).to(cond.device)
162
- fixed_span_mask = custom_mask_from_start_end_indices(cond_seq_len, latent_pred_start_frame, latent_pred_end_frame, device=cond.device, max_seq_len=duration)
163
-
164
  fixed_span_mask = fixed_span_mask.unsqueeze(-1)
165
  step_cond = torch.where(fixed_span_mask, torch.zeros_like(cond), cond)
166
 
167
  if isinstance(duration, int):
168
- duration = torch.full((batch,), duration, device=device, dtype=torch.long)
169
-
170
 
171
  duration = duration.clamp(max=max_duration)
172
  max_duration = duration.amax()
@@ -175,7 +188,6 @@ class CFM(nn.Module):
175
  if duplicate_test:
176
  test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
177
 
178
-
179
  if batch > 1:
180
  mask = lens_to_mask(duration)
181
  else: # save memory and speed up, as single inference need no mask currently
@@ -184,20 +196,27 @@ class CFM(nn.Module):
184
  # test for no ref audio
185
  if no_ref_audio:
186
  cond = torch.zeros_like(cond)
187
 
188
  start_time_embed, positive_text_embed, positive_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=False, start_time=start_time)
189
  _, negative_text_embed, negative_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=True, start_time=start_time)
190
 
191
- if vocal_flag:
192
- style_prompt = negative_style_prompt
193
- negative_style_prompt = torch.zeros_like(style_prompt)
194
-
195
  text_embed = torch.cat([positive_text_embed, negative_text_embed], 0)
196
  text_residuals = [torch.cat([a, b], 0) for a, b in zip(positive_text_residuals, negative_text_residuals)]
197
  step_cond = torch.cat([step_cond, step_cond], 0)
198
  style_prompt = torch.cat([style_prompt, negative_style_prompt], 0)
199
  start_time_embed = torch.cat([start_time_embed, start_time_embed], 0)
200
-
201
 
202
  def fn(t, x):
203
  x = torch.cat([x, x], 0)
@@ -228,7 +247,7 @@ class CFM(nn.Module):
228
  t_start = t_inter
229
  y0 = (1 - t_start) * y0 + t_start * test_cond
230
  steps = int(steps * (1 - t_start))
231
-
232
  t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype)
233
  if sway_sampling_coef is not None:
234
  t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
@@ -243,6 +262,7 @@ class CFM(nn.Module):
243
  out = out.permute(0, 2, 1)
244
  out = vocoder(out)
245
 
 
246
  return out, trajectory
247
 
248
  def forward(
@@ -267,11 +287,10 @@ class CFM(nn.Module):
267
 
268
  # get a random span to mask out for training conditionally
269
  frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
270
- rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
271
 
272
  if exists(mask):
273
  rand_span_mask = mask
274
- # rand_span_mask &= mask
275
 
276
  # mel is x1
277
  x1 = inp
@@ -301,7 +320,7 @@ class CFM(nn.Module):
301
  # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
302
  pred = self.transformer(
303
  x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, drop_prompt=drop_prompt,
304
- style_prompt=style_prompt, style_prompt_lens=style_prompt_lens, grad_ckpt=grad_ckpt, start_time=start_time
305
  )
306
 
307
  # flow matching loss
 
1
+ # Copyright (c) 2025 ASLP-LAB
2
+ # 2025 Ziqian Ning (ningziqian@mail.nwpu.edu.cn)
3
+ # 2025 Huakang Chen (huakang@mail.nwpu.edu.cn)
4
+ # 2025 Guobin Ma (guobin.ma@mail.nwpu.edu.cn)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """ This implementation is adapted from github repo:
19
+ https://github.com/SWivid/F5-TTS.
20
  """
21
 
22
  from __future__ import annotations
 
31
 
32
  from torchdiffeq import odeint
33
 
 
34
  from diffrhythm.model.utils import (
 
35
  exists,
36
  list_str_to_idx,
37
  list_str_to_tensor,
 
39
  mask_from_frac_lengths,
40
  )
41
 
42
+ def custom_mask_from_start_end_indices(
43
+ seq_len: int["b"], # noqa: F821
44
+ latent_pred_segments,
45
+ device,
46
+ max_seq_len
47
+ ):
48
  max_seq_len = max_seq_len
49
  seq = torch.arange(max_seq_len, device=device).long()
50
+
51
+ res_mask = torch.zeros(max_seq_len, device=device, dtype=torch.bool)
52
+
53
+ for start, end in latent_pred_segments:
54
+ start = start.unsqueeze(0)
55
+ end = end.unsqueeze(0)
56
+ start_mask = seq[None, :] >= start[:, None]
57
+ end_mask = seq[None, :] < end[:, None]
58
+ res_mask = res_mask | (start_mask & end_mask)
59
+
60
+ return res_mask
61
 
62
  class CFM(nn.Module):
63
  def __init__(
 
65
  transformer: nn.Module,
66
  sigma=0.0,
67
  odeint_kwargs: dict = dict(
68
+ method="euler"
69
  ),
70
  odeint_options: dict = dict(
71
  min_step=0.05
 
77
  num_channels=None,
78
  frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
79
  vocab_char_map: dict[str:int] | None = None,
80
+ max_frames=2048
81
  ):
82
  super().__init__()
83
 
 
106
 
107
  # vocab map for tokenization
108
  self.vocab_char_map = vocab_char_map
109
+
110
+ self.max_frames = max_frames
111
 
112
  @property
113
  def device(self):
 
135
  t_inter=0.1,
136
  edit_mask=None,
137
  start_time=None,
138
+ latent_pred_segments=None,
 
139
  vocal_flag=False,
140
+ odeint_method="euler",
141
+ batch_infer_num=5
142
  ):
143
  self.eval()
144
 
 
148
  cond = cond.half()
149
 
150
  # raw wave
 
151
  if cond.shape[1] > duration:
152
  cond = cond[:, :duration, :]
153
 
 
161
  lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
162
 
163
  # text
 
164
  if isinstance(text, list):
165
  if exists(self.vocab_char_map):
166
  text = list_str_to_idx(text, self.vocab_char_map).to(device)
 
168
  text = list_str_to_tensor(text).to(device)
169
  assert text.shape[0] == batch
170
 
 
 
 
 
171
  # duration
172
  cond_mask = lens_to_mask(lens)
173
  if edit_mask is not None:
174
  cond_mask = cond_mask & edit_mask
175
 
176
+ latent_pred_segments = torch.tensor(latent_pred_segments).to(cond.device)
177
+ fixed_span_mask = custom_mask_from_start_end_indices(cond_seq_len, latent_pred_segments, device=cond.device, max_seq_len=duration)
 
 
 
178
  fixed_span_mask = fixed_span_mask.unsqueeze(-1)
179
  step_cond = torch.where(fixed_span_mask, torch.zeros_like(cond), cond)
180
 
181
  if isinstance(duration, int):
182
+ duration = torch.full((batch_infer_num,), duration, device=device, dtype=torch.long)
 
183
 
184
  duration = duration.clamp(max=max_duration)
185
  max_duration = duration.amax()
 
188
  if duplicate_test:
189
  test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
190
 
 
191
  if batch > 1:
192
  mask = lens_to_mask(duration)
193
  else: # save memory and speed up, as single inference need no mask currently
 
196
  # test for no ref audio
197
  if no_ref_audio:
198
  cond = torch.zeros_like(cond)
199
+
200
+ if vocal_flag:
201
+ style_prompt = negative_style_prompt
202
+ negative_style_prompt = torch.zeros_like(style_prompt)
203
+
204
+ cond = cond.repeat(batch_infer_num, 1, 1)
205
+ step_cond = step_cond.repeat(batch_infer_num, 1, 1)
206
+ text = text.repeat(batch_infer_num, 1)
207
+ style_prompt = style_prompt.repeat(batch_infer_num, 1)
208
+ negative_style_prompt = negative_style_prompt.repeat(batch_infer_num, 1)
209
+ start_time = start_time.repeat(batch_infer_num)
210
+ fixed_span_mask = fixed_span_mask.repeat(batch_infer_num, 1, 1)
211
 
212
  start_time_embed, positive_text_embed, positive_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=False, start_time=start_time)
213
  _, negative_text_embed, negative_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=True, start_time=start_time)
214
 
 
 
 
 
215
  text_embed = torch.cat([positive_text_embed, negative_text_embed], 0)
216
  text_residuals = [torch.cat([a, b], 0) for a, b in zip(positive_text_residuals, negative_text_residuals)]
217
  step_cond = torch.cat([step_cond, step_cond], 0)
218
  style_prompt = torch.cat([style_prompt, negative_style_prompt], 0)
219
  start_time_embed = torch.cat([start_time_embed, start_time_embed], 0)
 
220
 
221
  def fn(t, x):
222
  x = torch.cat([x, x], 0)
 
247
  t_start = t_inter
248
  y0 = (1 - t_start) * y0 + t_start * test_cond
249
  steps = int(steps * (1 - t_start))
250
+
251
  t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype)
252
  if sway_sampling_coef is not None:
253
  t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
 
262
  out = out.permute(0, 2, 1)
263
  out = vocoder(out)
264
 
265
+ out = torch.chunk(out, batch_infer_num, dim=0)
266
  return out, trajectory
267
 
268
  def forward(
 
287
 
288
  # get a random span to mask out for training conditionally
289
  frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
290
+ rand_span_mask = mask_from_frac_lengths(lens, frac_lengths, self.max_frames)
291
 
292
  if exists(mask):
293
  rand_span_mask = mask
 
294
 
295
  # mel is x1
296
  x1 = inp
 
320
  # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
321
  pred = self.transformer(
322
  x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, drop_prompt=drop_prompt,
323
+ style_prompt=style_prompt, start_time=start_time
324
  )
325
 
326
  # flow matching loss
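Note on the reworked edit mask: custom_mask_from_start_end_indices() now ORs one boolean span per (start, end) pair from latent_pred_segments instead of taking a single start/end pair, so only the selected spans are re-generated. A standalone sketch of the same idea (the diff passes tensor segments; this plain-int version is illustrative only):

    import torch

    def segments_mask(segments, max_seq_len):
        """True inside any [start, end) span, False elsewhere."""
        seq = torch.arange(max_seq_len)
        mask = torch.zeros(max_seq_len, dtype=torch.bool)
        for start, end in segments:
            mask |= (seq >= start) & (seq < end)
        return mask

    # segments_mask([(0, 4), (8, 10)], 12)
    # -> tensor([ True,  True,  True,  True, False, False, False, False,  True,  True, False, False])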
diffrhythm/model/dit.py CHANGED
@@ -1,10 +1,22 @@
1
- """
2
- ein notation:
3
- b - batch
4
- n - sequence
5
- nt - text sequence
6
- nw - raw wave length
7
- d - dimension
8
  """
9
 
10
  from __future__ import annotations
@@ -12,22 +24,19 @@ from __future__ import annotations
12
  import torch
13
  from torch import nn
14
  import torch
15
- import torch.nn.functional as F
16
  from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRotaryEmbedding
17
  from transformers.models.llama import LlamaConfig
18
- from torch.utils.checkpoint import checkpoint
19
 
20
  from diffrhythm.model.modules import (
21
  TimestepEmbedding,
22
  ConvNeXtV2Block,
23
  ConvPositionEmbedding,
24
- DiTBlock,
25
  AdaLayerNormZero_Final,
26
  precompute_freqs_cis,
27
  get_pos_embed_indices,
 
28
  )
29
- # from liger_kernel.transformers import apply_liger_kernel_to_llama
30
- # apply_liger_kernel_to_llama()
31
 
32
  # Text embedding
33
  class TextEmbedding(nn.Module):
@@ -77,7 +86,6 @@ class InputEmbedding(nn.Module):
77
  def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], style_emb, time_emb, drop_audio_cond=False): # noqa: F722
78
  if drop_audio_cond: # cfg for cond audio
79
  cond = torch.zeros_like(cond)
80
-
81
  style_emb = style_emb.unsqueeze(1).repeat(1, x.shape[1], 1)
82
  time_emb = time_emb.unsqueeze(1).repeat(1, x.shape[1], 1)
83
  x = self.proj(torch.cat((x, cond, text_embed, style_emb, time_emb), dim=-1))
@@ -85,9 +93,7 @@ class InputEmbedding(nn.Module):
85
  return x
86
 
87
 
88
- # Transformer backbone using DiT blocks
89
-
90
-
91
  class DiT(nn.Module):
92
  def __init__(
93
  self,
@@ -103,26 +109,25 @@ class DiT(nn.Module):
103
  text_dim=None,
104
  conv_layers=0,
105
  long_skip_connection=False,
106
- use_style_prompt=False,
107
- max_pos=2048,
108
  ):
109
  super().__init__()
 
 
110
 
111
  cond_dim = 512
112
  self.time_embed = TimestepEmbedding(cond_dim)
113
  self.start_time_embed = TimestepEmbedding(cond_dim)
114
  if text_dim is None:
115
  text_dim = mel_dim
116
- self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers, max_pos=max_pos)
117
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim, cond_dim=cond_dim)
118
 
119
-
120
  self.dim = dim
121
  self.depth = depth
122
 
123
- llama_config = LlamaConfig(hidden_size=dim, intermediate_size=dim * ff_mult, hidden_act='silu', max_position_embeddings=max_pos)
124
  llama_config._attn_implementation = 'sdpa'
125
-
126
  self.transformer_blocks = nn.ModuleList(
127
  [LlamaDecoderLayer(llama_config, layer_idx=i) for i in range(depth)]
128
  )
@@ -144,7 +149,6 @@ class DiT(nn.Module):
144
  self.norm_out = AdaLayerNormZero_Final(dim, cond_dim) # final modulation
145
  self.proj_out = nn.Linear(dim, mel_dim)
146
 
147
-
148
  def forward_timestep_invariant(self, text, seq_len, drop_text, start_time):
149
  s_t = self.start_time_embed(start_time)
150
  text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
@@ -187,11 +191,22 @@ class DiT(nn.Module):
187
  pos_ids = torch.arange(x.shape[1], device=x.device)
188
  pos_ids = pos_ids.unsqueeze(0).repeat(x.shape[0], 1)
189
  rotary_embed = self.rotary_emb(x, pos_ids)
190
 
191
  for i, block in enumerate(self.transformer_blocks):
192
- x, *_ = block(x, position_embeddings=rotary_embed)
193
  if i < self.depth // 2:
194
- x = x + text_residuals[i]
195
 
196
  if self.long_skip_connection is not None:
197
  x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
 
1
+ # Copyright (c) 2025 ASLP-LAB
2
+ # 2025 Ziqian Ning (ningziqian@mail.nwpu.edu.cn)
3
+ # 2025 Huakang Chen (huakang@mail.nwpu.edu.cn)
4
+ # 2025 Yuepeng Jiang (Jiangyp@mail.nwpu.edu.cn)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """ This implementation is adapted from github repo:
19
+ https://github.com/SWivid/F5-TTS.
20
  """
21
 
22
  from __future__ import annotations
 
24
  import torch
25
  from torch import nn
26
  import torch
27
+
28
  from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRotaryEmbedding
29
  from transformers.models.llama import LlamaConfig
 
30
 
31
  from diffrhythm.model.modules import (
32
  TimestepEmbedding,
33
  ConvNeXtV2Block,
34
  ConvPositionEmbedding,
 
35
  AdaLayerNormZero_Final,
36
  precompute_freqs_cis,
37
  get_pos_embed_indices,
38
+ _prepare_decoder_attention_mask,
39
  )
 
 
40
 
41
  # Text embedding
42
  class TextEmbedding(nn.Module):
 
86
  def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], style_emb, time_emb, drop_audio_cond=False): # noqa: F722
87
  if drop_audio_cond: # cfg for cond audio
88
  cond = torch.zeros_like(cond)
 
89
  style_emb = style_emb.unsqueeze(1).repeat(1, x.shape[1], 1)
90
  time_emb = time_emb.unsqueeze(1).repeat(1, x.shape[1], 1)
91
  x = self.proj(torch.cat((x, cond, text_embed, style_emb, time_emb), dim=-1))
 
93
  return x
94
 
95
 
96
+ # Transformer backbone using Llama blocks
 
 
97
  class DiT(nn.Module):
98
  def __init__(
99
  self,
 
109
  text_dim=None,
110
  conv_layers=0,
111
  long_skip_connection=False,
112
+ max_frames=2048
 
113
  ):
114
  super().__init__()
115
+
116
+ self.max_frames = max_frames
117
 
118
  cond_dim = 512
119
  self.time_embed = TimestepEmbedding(cond_dim)
120
  self.start_time_embed = TimestepEmbedding(cond_dim)
121
  if text_dim is None:
122
  text_dim = mel_dim
123
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers, max_pos=self.max_frames)
124
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim, cond_dim=cond_dim)
125
 
 
126
  self.dim = dim
127
  self.depth = depth
128
 
129
+ llama_config = LlamaConfig(hidden_size=dim, intermediate_size=dim * ff_mult, hidden_act='silu', max_position_embeddings=self.max_frames)
130
  llama_config._attn_implementation = 'sdpa'
 
131
  self.transformer_blocks = nn.ModuleList(
132
  [LlamaDecoderLayer(llama_config, layer_idx=i) for i in range(depth)]
133
  )
 
149
  self.norm_out = AdaLayerNormZero_Final(dim, cond_dim) # final modulation
150
  self.proj_out = nn.Linear(dim, mel_dim)
151
 
 
152
  def forward_timestep_invariant(self, text, seq_len, drop_text, start_time):
153
  s_t = self.start_time_embed(start_time)
154
  text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
 
191
  pos_ids = torch.arange(x.shape[1], device=x.device)
192
  pos_ids = pos_ids.unsqueeze(0).repeat(x.shape[0], 1)
193
  rotary_embed = self.rotary_emb(x, pos_ids)
194
+
195
+ attention_mask = torch.ones(
196
+ (batch, seq_len),
197
+ dtype=torch.bool,
198
+ device=x.device,
199
+ )
200
+ attention_mask = _prepare_decoder_attention_mask(
201
+ attention_mask,
202
+ (batch, seq_len),
203
+ x,
204
+ )
205
 
206
  for i, block in enumerate(self.transformer_blocks):
207
+ x, *_ = block(x, attention_mask=attention_mask, position_embeddings=rotary_embed)
208
  if i < self.depth // 2:
209
+ x = x + self.text_fusion_linears[i](text_embed)
210
 
211
  if self.long_skip_connection is not None:
212
  x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
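For orientation, a toy sketch of the text-fusion pattern in the new forward pass: a per-layer projection of the timestep-invariant text embedding is added to the hidden states of the first half of the decoder stack only. The dimensions and the identity stand-ins for the Llama blocks below are placeholders, not the model's real configuration:

import torch
from torch import nn

dim, depth, batch, seq_len = 64, 8, 2, 16                              # placeholder sizes
blocks = nn.ModuleList([nn.Identity() for _ in range(depth)])          # stand-in for LlamaDecoderLayer
text_fusion_linears = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth // 2)])

x = torch.randn(batch, seq_len, dim)             # hidden states
text_embed = torch.randn(batch, seq_len, dim)    # timestep-invariant text embedding

for i, block in enumerate(blocks):
    x = block(x)
    if i < depth // 2:                           # only the first half of the stack gets text residuals
        x = x + text_fusion_linears[i](text_embed)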
diffrhythm/model/modules.py CHANGED
@@ -609,3 +609,44 @@ class TimestepEmbedding(nn.Module):
609
  time_hidden = time_hidden.to(timestep.dtype)
610
  time = self.time_mlp(time_hidden) # b d
611
  return time
612
+
613
+
614
+ # attention mask related
615
+
616
+
617
+ def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds):
618
+ # create noncausal mask
619
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
620
+ combined_attention_mask = None
621
+
622
+ def _expand_mask(
623
+ mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None
624
+ ):
625
+ """
626
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
627
+ """
628
+ bsz, src_len = mask.size()
629
+ tgt_len = tgt_len if tgt_len is not None else src_len
630
+
631
+ expanded_mask = (
632
+ mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
633
+ )
634
+
635
+ inverted_mask = 1.0 - expanded_mask
636
+
637
+ return inverted_mask.masked_fill(
638
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
639
+ )
640
+
641
+ if attention_mask is not None:
642
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
643
+ expanded_attn_mask = _expand_mask(
644
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
645
+ ).to(inputs_embeds.device)
646
+ combined_attention_mask = (
647
+ expanded_attn_mask
648
+ if combined_attention_mask is None
649
+ else expanded_attn_mask + combined_attention_mask
650
+ )
651
+
652
+ return combined_attention_mask
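A quick sanity check of what this helper produces, re-implemented inline so the snippet runs without the repo installed (the padding pattern is made up): positions marked 0 in the [bsz, seq_len] mask become large negative additive biases for every query row, and no causal structure is imposed.

import torch

def expand_padding_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Same expansion as _expand_mask above: [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
    bsz, src_len = mask.size()
    expanded = mask[:, None, None, :].expand(bsz, 1, src_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)

padding_mask = torch.tensor([[1, 1, 1, 0]])   # last position is padding (illustrative)
bias = expand_padding_mask(padding_mask, torch.float32)
print(bias.shape)      # torch.Size([1, 1, 4, 4])
print(bias[0, 0, 0])   # tensor([0., 0., 0., -3.4028e+38]) -> padded key masked for every query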
diffrhythm/model/utils.py CHANGED
@@ -44,15 +44,15 @@ def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]: # noqa
44
  return seq[None, :] < t[:, None]
45
 
46
 
47
- def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]): # noqa: F722 F821
48
- max_seq_len = 2048
49
  seq = torch.arange(max_seq_len, device=start.device).long()
50
  start_mask = seq[None, :] >= start[:, None]
51
  end_mask = seq[None, :] < end[:, None]
52
  return start_mask & end_mask
53
 
54
 
55
- def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa: F722 F821
56
  lengths = (frac_lengths * seq_len).long()
57
  max_start = seq_len - lengths
58
 
@@ -60,7 +60,7 @@ def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]): # noqa
60
  start = (max_start * rand).long().clamp(min=0)
61
  end = start + lengths
62
 
63
- return mask_from_start_end_indices(seq_len, start, end)
64
 
65
 
66
  def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722
 
44
  return seq[None, :] < t[:, None]
45
 
46
 
47
+ def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"], max_frames): # noqa: F722 F821
48
+ max_seq_len = max_frames
49
  seq = torch.arange(max_seq_len, device=start.device).long()
50
  start_mask = seq[None, :] >= start[:, None]
51
  end_mask = seq[None, :] < end[:, None]
52
  return start_mask & end_mask
53
 
54
 
55
+ def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"], max_frames): # noqa: F722 F821
56
  lengths = (frac_lengths * seq_len).long()
57
  max_start = seq_len - lengths
58
 
 
60
  start = (max_start * rand).long().clamp(min=0)
61
  end = start + lengths
62
 
63
+ return mask_from_start_end_indices(seq_len, start, end, max_frames)
64
 
65
 
66
  def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]: # noqa: F722
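For context on the max_frames argument threaded through these helpers, a standalone illustration of frac-length span masking; the helper name and example values below are made up, and the repo's versions live in diffrhythm/model/utils.py:

import torch

def span_mask(seq_len, frac_lengths, max_frames):
    # Pick one random contiguous span per sequence covering ~frac of its length,
    # expressed as a boolean mask padded out to max_frames columns.
    lengths = (frac_lengths * seq_len).long()
    max_start = seq_len - lengths
    start = (max_start * torch.rand_like(frac_lengths)).long().clamp(min=0)
    end = start + lengths
    seq = torch.arange(max_frames, device=start.device)
    return (seq[None, :] >= start[:, None]) & (seq[None, :] < end[:, None])

seq_len = torch.tensor([1500, 2048])
frac = torch.tensor([0.7, 0.9])
mask = span_mask(seq_len, frac, max_frames=2048)
print(mask.shape, mask.sum(dim=-1))   # torch.Size([2, 2048]); span lengths ~= frac * seq_len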
pretrained/eval.py ADDED
@@ -0,0 +1,66 @@
1
+ from einops import rearrange
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class Generator(nn.Module):
8
+
9
+ def __init__(self,
10
+ in_features,
11
+ ffd_hidden_size,
12
+ num_classes,
13
+ attn_layer_num,
14
+
15
+ ):
16
+ super(Generator, self).__init__()
17
+
18
+ self.attn = nn.ModuleList(
19
+ [
20
+ nn.MultiheadAttention(
21
+ embed_dim=in_features,
22
+ num_heads=8,
23
+ dropout=0.2,
24
+ batch_first=True,
25
+ )
26
+ for _ in range(attn_layer_num)
27
+ ]
28
+ )
29
+
30
+ self.ffd = nn.Sequential(
31
+ nn.Linear(in_features, ffd_hidden_size),
32
+ nn.ReLU(),
33
+ nn.Linear(ffd_hidden_size, in_features)
34
+ )
35
+
36
+ self.dropout = nn.Dropout(0.2)
37
+
38
+ self.fc = nn.Linear(in_features * 2, num_classes)
39
+
40
+ self.proj = nn.Tanh()
41
+
42
+
43
+ def forward(self, ssl_feature, judge_id=None):
44
+ '''
45
+ ssl_feature: [B, T, D]
46
+ output: [B, num_classes]
47
+ '''
48
+
49
+ B, T, D = ssl_feature.shape
50
+
51
+ ssl_feature = self.ffd(ssl_feature)
52
+
53
+ tmp_ssl_feature = ssl_feature
54
+
55
+ for attn in self.attn:
56
+ tmp_ssl_feature, _ = attn(tmp_ssl_feature, tmp_ssl_feature, tmp_ssl_feature)
57
+
58
+ ssl_feature = self.dropout(torch.concat([torch.mean(tmp_ssl_feature, dim=1), torch.max(ssl_feature, dim=1)[0]], dim=1)) # B, 2D
59
+
60
+ x = self.fc(ssl_feature) # B, num_classes
61
+
62
+ x = self.proj(x) * 2.0 + 3
63
+
64
+ return x
65
+
66
+
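Below the new evaluator, a hedged usage sketch: instantiate the Generator with the hyperparameters from pretrained/eval.yaml and score a dummy feature batch. The input here is random; in practice the [B, T, D] features would come from a self-supervised audio encoder whose dimension matches in_features:

import torch
from pretrained.eval import Generator   # import path mirrors eval.yaml's _target_

model = Generator(in_features=1024, ffd_hidden_size=4096, num_classes=5, attn_layer_num=4)
model.eval()

ssl_feature = torch.randn(2, 250, 1024)   # [B, T, D], placeholder values
with torch.no_grad():
    scores = model(ssl_feature)           # [B, num_classes]
print(scores.shape)                        # tanh(x) * 2 + 3 keeps each score in roughly [1, 5]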
pretrained/eval.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:81cbd54af8b103251e425fcbd8f5313975cb742e760c3dae1e10f99969933fd6
3
+ size 100792276
pretrained/eval.yaml ADDED
@@ -0,0 +1,6 @@
1
+ generator:
2
+ _target_: pretrained.eval.Generator
3
+ in_features: 1024
4
+ ffd_hidden_size: 4096
5
+ num_classes: 5
6
+ attn_layer_num: 4
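The _target_ key suggests a Hydra-style instantiation config. A hedged sketch of wiring the three new pretrained/ files together, assuming the config follows that convention and the safetensors keys line up with the module's state dict:

from omegaconf import OmegaConf
from hydra.utils import instantiate
from safetensors.torch import load_file

cfg = OmegaConf.load("pretrained/eval.yaml")
eval_model = instantiate(cfg.generator)    # builds pretrained.eval.Generator(...)
eval_model.load_state_dict(load_file("pretrained/eval.safetensors"))
eval_model.eval()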