Sayoyo committed
Commit 48bb372 · 1 Parent(s): 917a84f

[feat] add repainting & edit

Files changed (2)
  1. pipeline_ace_step.py +401 -50
  2. ui/components.py +254 -13
pipeline_ace_step.py CHANGED
@@ -4,6 +4,7 @@ import os
4
  import re
5
 
6
  import torch
 
7
  from loguru import logger
8
  from tqdm import tqdm
9
  import json
@@ -22,11 +23,11 @@ from models.ace_step_transformer import ACEStepTransformer2DModel
22
  from models.lyrics_utils.lyric_tokenizer import VoiceBpeTokenizer
23
  from apg_guidance import apg_forward, MomentumBuffer, cfg_forward, cfg_zero_star, cfg_double_condition_forward
24
  import torchaudio
 
 
25
  torch.backends.cudnn.benchmark = False
26
  torch.set_float32_matmul_precision('high')
27
-
28
- # Enable TF32 for faster training on Ampere GPUs,
29
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
30
  torch.backends.cuda.matmul.allow_tf32 = True
31
 
32
 
@@ -49,9 +50,10 @@ def ensure_directory_exists(directory):
49
  REPO_ID = "ACE-Step/ACE-Step-v1-3.5B"
50
 
51
 
 
52
  class ACEStepPipeline:
53
 
54
- def __init__(self, checkpoint_dir=None, device_id=0, dtype="bfloat16", text_encoder_checkpoint_path=None, persistent_storage_path=None, **kwargs):
55
  if checkpoint_dir is None:
56
  if persistent_storage_path is None:
57
  checkpoint_dir = os.path.join(os.path.dirname(__file__), "checkpoints")
@@ -64,6 +66,7 @@ class ACEStepPipeline:
64
  self.dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32
65
  self.device = device
66
  self.loaded = False
 
67
 
68
  def load_checkpoint(self, checkpoint_dir=None):
69
  device = self.device
@@ -157,9 +160,10 @@ class ACEStepPipeline:
157
  self.loaded = True
158
 
159
  # compile
160
- self.music_dcae = torch.compile(self.music_dcae)
161
- self.ace_step_transformer = torch.compile(self.ace_step_transformer)
162
- self.text_encoder_model = torch.compile(self.text_encoder_model)
 
163
 
164
  def get_text_embeddings(self, texts, device, text_max_length=256):
165
  inputs = self.text_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=text_max_length)
@@ -226,7 +230,7 @@ class ACEStepPipeline:
226
 
227
  def get_lang(self, text):
228
  language = "en"
229
- try:
230
  _ = self.lang_segment.getTexts(text)
231
  langCounts = self.lang_segment.getCounts()
232
  language = langCounts[0][0]
@@ -267,6 +271,250 @@ class ACEStepPipeline:
267
  print("tokenize error", e, "for line", line, "major_language", lang)
268
  return lyric_token_idx
269

270
  @torch.no_grad()
271
  def text2music_diffusion_process(
272
  self,
@@ -296,13 +544,16 @@ class ACEStepPipeline:
296
  add_retake_noise=False,
297
  guidance_scale_text=0.0,
298
  guidance_scale_lyric=0.0,
299
  ):
300
 
301
  logger.info("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale))
302
  do_classifier_free_guidance = True
303
  if guidance_scale == 0.0 or guidance_scale == 1.0:
304
  do_classifier_free_guidance = False
305
-
306
  do_double_condition_guidance = False
307
  if guidance_scale_text is not None and guidance_scale_text > 1.0 and guidance_scale_lyric is not None and guidance_scale_lyric > 1.0:
308
  do_double_condition_guidance = True
@@ -322,7 +573,10 @@ class ACEStepPipeline:
322
  num_train_timesteps=1000,
323
  shift=3.0,
324
  )
 
325
  frame_length = int(duration * 44100 / 512 / 8)
 
 
326
 
327
  if len(oss_steps) > 0:
328
  infer_steps = max(oss_steps)
@@ -337,16 +591,30 @@ class ACEStepPipeline:
337
  logger.info(f"oss_steps: {oss_steps}, num_inference_steps: {num_inference_steps} after remapping to timesteps {timesteps}")
338
  else:
339
  timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=infer_steps, device=device, timesteps=None)
340
-
341
  target_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=random_generators, device=device, dtype=dtype)
 
 
342
  if add_retake_noise:
343
  retake_variance = torch.tensor(retake_variance * math.pi/2).to(device).to(dtype)
344
  retake_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=retake_random_generators, device=device, dtype=dtype)
345
  # to make sure mean = 0, std = 1
346
- target_latents = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
347
 
348
  attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
349
-
350
  # guidance interval
351
  start_idx = int(num_inference_steps * ((1 - guidance_interval) / 2))
352
  end_idx = int(num_inference_steps * (guidance_interval / 2 + 0.5))
@@ -356,20 +624,20 @@ class ACEStepPipeline:
356
 
357
  def forward_encoder_with_temperature(self, inputs, tau=0.01, l_min=4, l_max=6):
358
  handlers = []
359
-
360
  def hook(module, input, output):
361
  output[:] *= tau
362
  return output
363
-
364
  for i in range(l_min, l_max):
365
  handler = self.ace_step_transformer.lyric_encoder.encoders[i].self_attn.linear_q.register_forward_hook(hook)
366
  handlers.append(handler)
367
-
368
  encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode(**inputs)
369
-
370
  for hook in handlers:
371
  hook.remove()
372
-
373
  return encoder_hidden_states
374
 
375
  # P(speaker, text, lyric)
@@ -402,7 +670,7 @@ class ACEStepPipeline:
402
  torch.zeros_like(lyric_token_ids),
403
  lyric_mask,
404
  )
405
-
406
  encoder_hidden_states_no_lyric = None
407
  if do_double_condition_guidance:
408
  # P(null_speaker, text, lyric_weaker)
@@ -429,11 +697,11 @@ class ACEStepPipeline:
429
 
430
  def forward_diffusion_with_temperature(self, hidden_states, timestep, inputs, tau=0.01, l_min=15, l_max=20):
431
  handlers = []
432
-
433
  def hook(module, input, output):
434
  output[:] *= tau
435
  return output
436
-
437
  for i in range(l_min, l_max):
438
  handler = self.ace_step_transformer.transformer_blocks[i].attn.to_q.register_forward_hook(hook)
439
  handlers.append(handler)
@@ -441,10 +709,10 @@ class ACEStepPipeline:
441
  handlers.append(handler)
442
 
443
  sample = self.ace_step_transformer.decode(hidden_states=hidden_states, timestep=timestep, **inputs).sample
444
-
445
  for hook in handlers:
446
  hook.remove()
447
-
448
  return sample
449
 
450
  for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
@@ -551,7 +819,13 @@ class ACEStepPipeline:
551
  ).sample
552
 
553
  target_latents = scheduler.step(model_output=noise_pred, timestep=t, sample=target_latents, return_dict=False, omega=omega_scale)[0]
554

555
  return target_latents
556
 
557
  def latents2audio(self, latents, target_wav_duration_second=30, sample_rate=48000, save_path=None, format="flac"):
@@ -570,7 +844,7 @@ class ACEStepPipeline:
570
  def save_wav_file(self, target_wav, idx, save_path=None, sample_rate=48000, format="flac"):
571
  if save_path is None:
572
  logger.warning("save_path is None, using default path ./outputs/")
573
- base_path = f"./outputs/"
574
  ensure_directory_exists(base_path)
575
  else:
576
  base_path = save_path
@@ -581,6 +855,16 @@ class ACEStepPipeline:
581
  torchaudio.save(output_path_flac, target_wav, sample_rate=sample_rate, format=format)
582
  return output_path_flac
583

584
  def __call__(
585
  self,
586
  audio_duration: float = 60.0,
@@ -604,6 +888,14 @@ class ACEStepPipeline:
604
  retake_seeds: list = None,
605
  retake_variance: float = 0.5,
606
  task: str = "text2music",
607
  save_path: str = None,
608
  format: str = "flac",
609
  batch_size: int = 1,
@@ -626,7 +918,7 @@ class ACEStepPipeline:
626
  oss_steps = list(map(int, oss_steps.split(",")))
627
  else:
628
  oss_steps = []
629
-
630
  texts = [prompt]
631
  encoder_text_hidden_states, text_attention_mask = self.get_text_embeddings(texts, self.device)
632
  encoder_text_hidden_states = encoder_text_hidden_states.repeat(batch_size, 1, 1)
@@ -657,32 +949,83 @@ class ACEStepPipeline:
657
  preprocess_time_cost = end_time - start_time
658
  start_time = end_time
659
 
660
- target_latents = self.text2music_diffusion_process(
661
- duration=audio_duration,
662
- encoder_text_hidden_states=encoder_text_hidden_states,
663
- text_attention_mask=text_attention_mask,
664
- speaker_embds=speaker_embeds,
665
- lyric_token_ids=lyric_token_idx,
666
- lyric_mask=lyric_mask,
667
- guidance_scale=guidance_scale,
668
- omega_scale=omega_scale,
669
- infer_steps=infer_step,
670
- random_generators=random_generators,
671
- scheduler_type=scheduler_type,
672
- cfg_type=cfg_type,
673
- guidance_interval=guidance_interval,
674
- guidance_interval_decay=guidance_interval_decay,
675
- min_guidance_scale=min_guidance_scale,
676
- oss_steps=oss_steps,
677
- encoder_text_hidden_states_null=encoder_text_hidden_states_null,
678
- use_erg_lyric=use_erg_lyric,
679
- use_erg_diffusion=use_erg_diffusion,
680
- retake_random_generators=retake_random_generators,
681
- retake_variance=retake_variance,
682
- add_retake_noise=task == "retake",
683
- guidance_scale_text=guidance_scale_text,
684
- guidance_scale_lyric=guidance_scale_lyric,
685
- )
686
 
687
  end_time = time.time()
688
  diffusion_time_cost = end_time - start_time
@@ -726,6 +1069,14 @@ class ACEStepPipeline:
726
  "retake_variance": retake_variance,
727
  "guidance_scale_text": guidance_scale_text,
728
  "guidance_scale_lyric": guidance_scale_lyric,
729
  }
730
  # save input_params_json
731
  for output_audio_path in output_paths:

pipeline_ace_step.py (after changes)
4
  import re
5
 
6
  import torch
7
+ import torch.nn as nn
8
  from loguru import logger
9
  from tqdm import tqdm
10
  import json
 
23
  from models.lyrics_utils.lyric_tokenizer import VoiceBpeTokenizer
24
  from apg_guidance import apg_forward, MomentumBuffer, cfg_forward, cfg_zero_star, cfg_double_condition_forward
25
  import torchaudio
26
+
27
+
28
  torch.backends.cudnn.benchmark = False
29
  torch.set_float32_matmul_precision('high')
30
+ torch.backends.cudnn.deterministic = True
 
 
31
  torch.backends.cuda.matmul.allow_tf32 = True
32
 
33
 
 
50
  REPO_ID = "ACE-Step/ACE-Step-v1-3.5B"
51
 
52
 
53
+ # class ACEStepPipeline(DiffusionPipeline):
54
  class ACEStepPipeline:
55
 
56
+ def __init__(self, checkpoint_dir=None, device_id=0, dtype="bfloat16", text_encoder_checkpoint_path=None, persistent_storage_path=None, torch_compile=False, **kwargs):
57
  if checkpoint_dir is None:
58
  if persistent_storage_path is None:
59
  checkpoint_dir = os.path.join(os.path.dirname(__file__), "checkpoints")
 
66
  self.dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32
67
  self.device = device
68
  self.loaded = False
69
+ self.torch_compile = torch_compile
70
 
71
  def load_checkpoint(self, checkpoint_dir=None):
72
  device = self.device
 
160
  self.loaded = True
161
 
162
  # compile
163
+ if self.torch_compile:
164
+ self.music_dcae = torch.compile(self.music_dcae)
165
+ self.ace_step_transformer = torch.compile(self.ace_step_transformer)
166
+ self.text_encoder_model = torch.compile(self.text_encoder_model)
167
 
168
  def get_text_embeddings(self, texts, device, text_max_length=256):
169
  inputs = self.text_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=text_max_length)
 
230
 
231
  def get_lang(self, text):
232
  language = "en"
233
+ try:
234
  _ = self.lang_segment.getTexts(text)
235
  langCounts = self.lang_segment.getCounts()
236
  language = langCounts[0][0]
 
271
  print("tokenize error", e, "for line", line, "major_language", lang)
272
  return lyric_token_idx
273
 
274
+ def calc_v(
275
+ self,
276
+ zt_src,
277
+ zt_tar,
278
+ t,
279
+ encoder_text_hidden_states,
280
+ text_attention_mask,
281
+ target_encoder_text_hidden_states,
282
+ target_text_attention_mask,
283
+ speaker_embds,
284
+ target_speaker_embeds,
285
+ lyric_token_ids,
286
+ lyric_mask,
287
+ target_lyric_token_ids,
288
+ target_lyric_mask,
289
+ do_classifier_free_guidance=False,
290
+ guidance_scale=1.0,
291
+ target_guidance_scale=1.0,
292
+ cfg_type="apg",
293
+ attention_mask=None,
294
+ momentum_buffer=None,
295
+ momentum_buffer_tar=None,
296
+ return_src_pred=True
297
+ ):
298
+ noise_pred_src = None
299
+ if return_src_pred:
300
+ src_latent_model_input = torch.cat([zt_src, zt_src]) if do_classifier_free_guidance else zt_src
301
+ timestep = t.expand(src_latent_model_input.shape[0])
302
+ # source
303
+ noise_pred_src = self.ace_step_transformer(
304
+ hidden_states=src_latent_model_input,
305
+ attention_mask=attention_mask,
306
+ encoder_text_hidden_states=encoder_text_hidden_states,
307
+ text_attention_mask=text_attention_mask,
308
+ speaker_embeds=speaker_embds,
309
+ lyric_token_idx=lyric_token_ids,
310
+ lyric_mask=lyric_mask,
311
+ timestep=timestep,
312
+ ).sample
313
+
314
+ if do_classifier_free_guidance:
315
+ noise_pred_with_cond_src, noise_pred_uncond_src = noise_pred_src.chunk(2)
316
+ if cfg_type == "apg":
317
+ noise_pred_src = apg_forward(
318
+ pred_cond=noise_pred_with_cond_src,
319
+ pred_uncond=noise_pred_uncond_src,
320
+ guidance_scale=guidance_scale,
321
+ momentum_buffer=momentum_buffer,
322
+ )
323
+ elif cfg_type == "cfg":
324
+ noise_pred_src = cfg_forward(
325
+ cond_output=noise_pred_with_cond_src,
326
+ uncond_output=noise_pred_uncond_src,
327
+ cfg_strength=guidance_scale,
328
+ )
329
+
330
+ tar_latent_model_input = torch.cat([zt_tar, zt_tar]) if do_classifier_free_guidance else zt_tar
331
+ timestep = t.expand(tar_latent_model_input.shape[0])
332
+ # target
333
+ noise_pred_tar = self.ace_step_transformer(
334
+ hidden_states=tar_latent_model_input,
335
+ attention_mask=attention_mask,
336
+ encoder_text_hidden_states=target_encoder_text_hidden_states,
337
+ text_attention_mask=target_text_attention_mask,
338
+ speaker_embeds=target_speaker_embeds,
339
+ lyric_token_idx=target_lyric_token_ids,
340
+ lyric_mask=target_lyric_mask,
341
+ timestep=timestep,
342
+ ).sample
343
+
344
+ if do_classifier_free_guidance:
345
+ noise_pred_with_cond_tar, noise_pred_uncond_tar = noise_pred_tar.chunk(2)
346
+ if cfg_type == "apg":
347
+ noise_pred_tar = apg_forward(
348
+ pred_cond=noise_pred_with_cond_tar,
349
+ pred_uncond=noise_pred_uncond_tar,
350
+ guidance_scale=target_guidance_scale,
351
+ momentum_buffer=momentum_buffer_tar,
352
+ )
353
+ elif cfg_type == "cfg":
354
+ noise_pred_tar = cfg_forward(
355
+ cond_output=noise_pred_with_cond_tar,
356
+ uncond_output=noise_pred_uncond_tar,
357
+ cfg_strength=target_guidance_scale,
358
+ )
359
+ return noise_pred_src, noise_pred_tar
360
+
361
+ @torch.no_grad()
362
+ def flowedit_diffusion_process(
363
+ self,
364
+ encoder_text_hidden_states,
365
+ text_attention_mask,
366
+ speaker_embds,
367
+ lyric_token_ids,
368
+ lyric_mask,
369
+ target_encoder_text_hidden_states,
370
+ target_text_attention_mask,
371
+ target_speaker_embeds,
372
+ target_lyric_token_ids,
373
+ target_lyric_mask,
374
+ src_latents,
375
+ random_generators=None,
376
+ infer_steps=60,
377
+ guidance_scale=15.0,
378
+ n_min=0,
379
+ n_max=1.0,
380
+ n_avg=1,
381
+ ):
382
+
383
+ do_classifier_free_guidance = True
384
+ if guidance_scale == 0.0 or guidance_scale == 1.0:
385
+ do_classifier_free_guidance = False
386
+
387
+ target_guidance_scale = guidance_scale
388
+ device = encoder_text_hidden_states.device
389
+ dtype = encoder_text_hidden_states.dtype
390
+ bsz = encoder_text_hidden_states.shape[0]
391
+
392
+ scheduler = FlowMatchEulerDiscreteScheduler(
393
+ num_train_timesteps=1000,
394
+ shift=3.0,
395
+ )
396
+
397
+ T_steps = infer_steps
398
+ frame_length = src_latents.shape[-1]
399
+ attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
400
+
401
+ timesteps, T_steps = retrieve_timesteps(scheduler, T_steps, device, timesteps=None)
402
+
403
+ if do_classifier_free_guidance:
404
+ attention_mask = torch.cat([attention_mask] * 2, dim=0)
405
+
406
+ encoder_text_hidden_states = torch.cat([encoder_text_hidden_states, torch.zeros_like(encoder_text_hidden_states)], 0)
407
+ text_attention_mask = torch.cat([text_attention_mask] * 2, dim=0)
408
+
409
+ target_encoder_text_hidden_states = torch.cat([target_encoder_text_hidden_states, torch.zeros_like(target_encoder_text_hidden_states)], 0)
410
+ target_text_attention_mask = torch.cat([target_text_attention_mask] * 2, dim=0)
411
+
412
+ speaker_embds = torch.cat([speaker_embds, torch.zeros_like(speaker_embds)], 0)
413
+ target_speaker_embeds = torch.cat([target_speaker_embeds, torch.zeros_like(target_speaker_embeds)], 0)
414
+
415
+ lyric_token_ids = torch.cat([lyric_token_ids, torch.zeros_like(lyric_token_ids)], 0)
416
+ lyric_mask = torch.cat([lyric_mask, torch.zeros_like(lyric_mask)], 0)
417
+
418
+ target_lyric_token_ids = torch.cat([target_lyric_token_ids, torch.zeros_like(target_lyric_token_ids)], 0)
419
+ target_lyric_mask = torch.cat([target_lyric_mask, torch.zeros_like(target_lyric_mask)], 0)
420
+
421
+ momentum_buffer = MomentumBuffer()
422
+ momentum_buffer_tar = MomentumBuffer()
423
+ x_src = src_latents
424
+ zt_edit = x_src.clone()
425
+ xt_tar = None
426
+ n_min = int(infer_steps * n_min)
427
+ n_max = int(infer_steps * n_max)
428
+
429
+ logger.info("flowedit start from {} to {}".format(n_min, n_max))
430
+
431
+ for i, t in tqdm(enumerate(timesteps), total=T_steps):
432
+
433
+ if i < n_min:
434
+ continue
435
+
436
+ t_i = t/1000
437
+
438
+ if i+1 < len(timesteps):
439
+ t_im1 = (timesteps[i+1])/1000
440
+ else:
441
+ t_im1 = torch.zeros_like(t_i).to(t_i.device)
442
+
443
+ if i < n_max:
444
+ # Calculate the average of the V predictions
445
+ V_delta_avg = torch.zeros_like(x_src)
446
+ for k in range(n_avg):
447
+ fwd_noise = randn_tensor(shape=x_src.shape, generator=random_generators, device=device, dtype=dtype)
448
+
449
+ zt_src = (1 - t_i) * x_src + (t_i) * fwd_noise
450
+
451
+ zt_tar = zt_edit + zt_src - x_src
452
+
453
+ Vt_src, Vt_tar = self.calc_v(
454
+ zt_src=zt_src,
455
+ zt_tar=zt_tar,
456
+ t=t,
457
+ encoder_text_hidden_states=encoder_text_hidden_states,
458
+ text_attention_mask=text_attention_mask,
459
+ target_encoder_text_hidden_states=target_encoder_text_hidden_states,
460
+ target_text_attention_mask=target_text_attention_mask,
461
+ speaker_embds=speaker_embds,
462
+ target_speaker_embeds=target_speaker_embeds,
463
+ lyric_token_ids=lyric_token_ids,
464
+ lyric_mask=lyric_mask,
465
+ target_lyric_token_ids=target_lyric_token_ids,
466
+ target_lyric_mask=target_lyric_mask,
467
+ do_classifier_free_guidance=do_classifier_free_guidance,
468
+ guidance_scale=guidance_scale,
469
+ target_guidance_scale=target_guidance_scale,
470
+ attention_mask=attention_mask,
471
+ momentum_buffer=momentum_buffer
472
+ )
473
+ V_delta_avg += (1 / n_avg) * (Vt_tar - Vt_src) # - (hfg-1)*( x_src))
474
+
475
+ # propagate direct ODE
476
+ zt_edit = zt_edit.to(torch.float32)
477
+ zt_edit = zt_edit + (t_im1 - t_i) * V_delta_avg
478
+ zt_edit = zt_edit.to(V_delta_avg.dtype)
479
+ else: # i >= T_steps-n_min # regular sampling for last n_min steps
480
+ if i == n_max:
481
+ fwd_noise = randn_tensor(shape=x_src.shape, generator=random_generators, device=device, dtype=dtype)
482
+ scheduler._init_step_index(t)
483
+ sigma = scheduler.sigmas[scheduler.step_index]
484
+ xt_src = sigma * fwd_noise + (1.0 - sigma) * x_src
485
+ xt_tar = zt_edit + xt_src - x_src
486
+
487
+ _, Vt_tar = self.calc_v(
488
+ zt_src=None,
489
+ zt_tar=xt_tar,
490
+ t=t,
491
+ encoder_text_hidden_states=encoder_text_hidden_states,
492
+ text_attention_mask=text_attention_mask,
493
+ target_encoder_text_hidden_states=target_encoder_text_hidden_states,
494
+ target_text_attention_mask=target_text_attention_mask,
495
+ speaker_embds=speaker_embds,
496
+ target_speaker_embeds=target_speaker_embeds,
497
+ lyric_token_ids=lyric_token_ids,
498
+ lyric_mask=lyric_mask,
499
+ target_lyric_token_ids=target_lyric_token_ids,
500
+ target_lyric_mask=target_lyric_mask,
501
+ do_classifier_free_guidance=do_classifier_free_guidance,
502
+ guidance_scale=guidance_scale,
503
+ target_guidance_scale=target_guidance_scale,
504
+ attention_mask=attention_mask,
505
+ momentum_buffer_tar=momentum_buffer_tar,
506
+ return_src_pred=False,
507
+ )
508
+
509
+ dtype = Vt_tar.dtype
510
+ xt_tar = xt_tar.to(torch.float32)
511
+ prev_sample = xt_tar + (t_im1 - t_i) * Vt_tar
512
+ prev_sample = prev_sample.to(dtype)
513
+ xt_tar = prev_sample
514
+
515
+ target_latents = zt_edit if xt_tar is None else xt_tar
516
+ return target_latents
517
+
518
  @torch.no_grad()
519
  def text2music_diffusion_process(
520
  self,
 
544
  add_retake_noise=False,
545
  guidance_scale_text=0.0,
546
  guidance_scale_lyric=0.0,
547
+ repaint_start=0,
548
+ repaint_end=0,
549
+ src_latents=None,
550
  ):
551
 
552
  logger.info("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale))
553
  do_classifier_free_guidance = True
554
  if guidance_scale == 0.0 or guidance_scale == 1.0:
555
  do_classifier_free_guidance = False
556
+
557
  do_double_condition_guidance = False
558
  if guidance_scale_text is not None and guidance_scale_text > 1.0 and guidance_scale_lyric is not None and guidance_scale_lyric > 1.0:
559
  do_double_condition_guidance = True
 
573
  num_train_timesteps=1000,
574
  shift=3.0,
575
  )
576
+
577
  frame_length = int(duration * 44100 / 512 / 8)
578
+ if src_latents is not None:
579
+ frame_length = src_latents.shape[-1]
580
 
581
  if len(oss_steps) > 0:
582
  infer_steps = max(oss_steps)
 
591
  logger.info(f"oss_steps: {oss_steps}, num_inference_steps: {num_inference_steps} after remapping to timesteps {timesteps}")
592
  else:
593
  timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=infer_steps, device=device, timesteps=None)
594
+
595
  target_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=random_generators, device=device, dtype=dtype)
596
+
597
+ is_repaint = False
598
  if add_retake_noise:
599
  retake_variance = torch.tensor(retake_variance * math.pi/2).to(device).to(dtype)
600
  retake_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=retake_random_generators, device=device, dtype=dtype)
601
+ repaint_start_frame = int(repaint_start * 44100 / 512 / 8)
602
+ repaint_end_frame = int(repaint_end * 44100 / 512 / 8)
603
+
604
+ # retake
605
+ is_repaint = repaint_end_frame - repaint_start_frame != frame_length
606
  # to make sure mean = 0, std = 1
607
+ if not is_repaint:
608
+ target_latents = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
609
+ else:
610
+ repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=device, dtype=dtype)
611
+ repaint_mask[:, :, :, repaint_start_frame:repaint_end_frame] = 1.0
612
+ repaint_noise = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
613
+ repaint_noise = torch.where(repaint_mask == 1.0, repaint_noise, target_latents)
614
+ z0 = repaint_noise
615
 
616
  attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
617
+
618
  # guidance interval
619
  start_idx = int(num_inference_steps * ((1 - guidance_interval) / 2))
620
  end_idx = int(num_inference_steps * (guidance_interval / 2 + 0.5))
 
624
 
625
  def forward_encoder_with_temperature(self, inputs, tau=0.01, l_min=4, l_max=6):
626
  handlers = []
627
+
628
  def hook(module, input, output):
629
  output[:] *= tau
630
  return output
631
+
632
  for i in range(l_min, l_max):
633
  handler = self.ace_step_transformer.lyric_encoder.encoders[i].self_attn.linear_q.register_forward_hook(hook)
634
  handlers.append(handler)
635
+
636
  encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode(**inputs)
637
+
638
  for hook in handlers:
639
  hook.remove()
640
+
641
  return encoder_hidden_states
642
 
643
  # P(speaker, text, lyric)
 
670
  torch.zeros_like(lyric_token_ids),
671
  lyric_mask,
672
  )
673
+
674
  encoder_hidden_states_no_lyric = None
675
  if do_double_condition_guidance:
676
  # P(null_speaker, text, lyric_weaker)
 
697
 
698
  def forward_diffusion_with_temperature(self, hidden_states, timestep, inputs, tau=0.01, l_min=15, l_max=20):
699
  handlers = []
700
+
701
  def hook(module, input, output):
702
  output[:] *= tau
703
  return output
704
+
705
  for i in range(l_min, l_max):
706
  handler = self.ace_step_transformer.transformer_blocks[i].attn.to_q.register_forward_hook(hook)
707
  handlers.append(handler)
 
709
  handlers.append(handler)
710
 
711
  sample = self.ace_step_transformer.decode(hidden_states=hidden_states, timestep=timestep, **inputs).sample
712
+
713
  for hook in handlers:
714
  hook.remove()
715
+
716
  return sample
717
 
718
  for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
 
819
  ).sample
820
 
821
  target_latents = scheduler.step(model_output=noise_pred, timestep=t, sample=target_latents, return_dict=False, omega=omega_scale)[0]
822
+ if is_repaint:
823
+ t_i = t / 1000
824
+ x0 = src_latents
825
+ xt = (1 - t_i) * x0 + t_i * z0
826
+ target_latents = torch.where(repaint_mask == 1.0, target_latents, xt)
827
 
828
+
829
  return target_latents
830
 
831
  def latents2audio(self, latents, target_wav_duration_second=30, sample_rate=48000, save_path=None, format="flac"):
 
844
  def save_wav_file(self, target_wav, idx, save_path=None, sample_rate=48000, format="flac"):
845
  if save_path is None:
846
  logger.warning("save_path is None, using default path ./outputs/")
847
+ base_path = "./outputs"
848
  ensure_directory_exists(base_path)
849
  else:
850
  base_path = save_path
 
855
  torchaudio.save(output_path_flac, target_wav, sample_rate=sample_rate, format=format)
856
  return output_path_flac
857
 
858
+ def infer_latents(self, input_audio_path):
859
+ if input_audio_path is None:
860
+ return None
861
+ input_audio, sr = self.music_dcae.load_audio(input_audio_path)
862
+ input_audio = input_audio.unsqueeze(0)
863
+ device, dtype = self.device, self.dtype
864
+ input_audio = input_audio.to(device=device, dtype=dtype)
865
+ latents, _ = self.music_dcae.encode(input_audio, sr=sr)
866
+ return latents
867
+
868
  def __call__(
869
  self,
870
  audio_duration: float = 60.0,
 
888
  retake_seeds: list = None,
889
  retake_variance: float = 0.5,
890
  task: str = "text2music",
891
+ repaint_start: int = 0,
892
+ repaint_end: int = 0,
893
+ src_audio_path: str = None,
894
+ edit_target_prompt: str = None,
895
+ edit_target_lyrics: str = None,
896
+ edit_n_min: float = 0.0,
897
+ edit_n_max: float = 1.0,
898
+ edit_n_avg: int = 1,
899
  save_path: str = None,
900
  format: str = "flac",
901
  batch_size: int = 1,
 
918
  oss_steps = list(map(int, oss_steps.split(",")))
919
  else:
920
  oss_steps = []
921
+
922
  texts = [prompt]
923
  encoder_text_hidden_states, text_attention_mask = self.get_text_embeddings(texts, self.device)
924
  encoder_text_hidden_states = encoder_text_hidden_states.repeat(batch_size, 1, 1)
 
949
  preprocess_time_cost = end_time - start_time
950
  start_time = end_time
951
 
952
+ add_retake_noise = task in ("retake", "repaint")
953
+ # retake is equivalent to repainting the whole clip
954
+ if task == "retake":
955
+ repaint_start = 0
956
+ repaint_end = audio_duration
957
+
958
+ src_latents = None
959
+ if src_audio_path is not None:
960
+ assert task in ("repaint", "edit"), "src_audio_path is only used for the repaint and edit tasks"
961
+ assert os.path.exists(src_audio_path), f"src_audio_path {src_audio_path} does not exist"
962
+ src_latents = self.infer_latents(src_audio_path)
963
+
964
+ if task == "edit":
965
+ texts = [edit_target_prompt]
966
+ target_encoder_text_hidden_states, target_text_attention_mask = self.get_text_embeddings(texts, self.device)
967
+ target_encoder_text_hidden_states = target_encoder_text_hidden_states.repeat(batch_size, 1, 1)
968
+ target_text_attention_mask = target_text_attention_mask.repeat(batch_size, 1)
969
+
970
+ target_lyric_token_idx = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
971
+ target_lyric_mask = torch.tensor([0]).repeat(batch_size, 1).to(self.device).long()
972
+ if len(edit_target_lyrics) > 0:
973
+ target_lyric_token_idx = self.tokenize_lyrics(edit_target_lyrics, debug=True)
974
+ target_lyric_mask = [1] * len(target_lyric_token_idx)
975
+ target_lyric_token_idx = torch.tensor(target_lyric_token_idx).unsqueeze(0).to(self.device).repeat(batch_size, 1)
976
+ target_lyric_mask = torch.tensor(target_lyric_mask).unsqueeze(0).to(self.device).repeat(batch_size, 1)
977
+
978
+ target_speaker_embeds = speaker_embeds.clone()
979
+
980
+ target_latents = self.flowedit_diffusion_process(
981
+ encoder_text_hidden_states=encoder_text_hidden_states,
982
+ text_attention_mask=text_attention_mask,
983
+ speaker_embds=speaker_embeds,
984
+ lyric_token_ids=lyric_token_idx,
985
+ lyric_mask=lyric_mask,
986
+ target_encoder_text_hidden_states=target_encoder_text_hidden_states,
987
+ target_text_attention_mask=target_text_attention_mask,
988
+ target_speaker_embeds=target_speaker_embeds,
989
+ target_lyric_token_ids=target_lyric_token_idx,
990
+ target_lyric_mask=target_lyric_mask,
991
+ src_latents=src_latents,
992
+ random_generators=random_generators,
993
+ infer_steps=infer_step,
994
+ guidance_scale=guidance_scale,
995
+ n_min=edit_n_min,
996
+ n_max=edit_n_max,
997
+ n_avg=edit_n_avg,
998
+ )
999
+ else:
1000
+ target_latents = self.text2music_diffusion_process(
1001
+ duration=audio_duration,
1002
+ encoder_text_hidden_states=encoder_text_hidden_states,
1003
+ text_attention_mask=text_attention_mask,
1004
+ speaker_embds=speaker_embeds,
1005
+ lyric_token_ids=lyric_token_idx,
1006
+ lyric_mask=lyric_mask,
1007
+ guidance_scale=guidance_scale,
1008
+ omega_scale=omega_scale,
1009
+ infer_steps=infer_step,
1010
+ random_generators=random_generators,
1011
+ scheduler_type=scheduler_type,
1012
+ cfg_type=cfg_type,
1013
+ guidance_interval=guidance_interval,
1014
+ guidance_interval_decay=guidance_interval_decay,
1015
+ min_guidance_scale=min_guidance_scale,
1016
+ oss_steps=oss_steps,
1017
+ encoder_text_hidden_states_null=encoder_text_hidden_states_null,
1018
+ use_erg_lyric=use_erg_lyric,
1019
+ use_erg_diffusion=use_erg_diffusion,
1020
+ retake_random_generators=retake_random_generators,
1021
+ retake_variance=retake_variance,
1022
+ add_retake_noise=add_retake_noise,
1023
+ guidance_scale_text=guidance_scale_text,
1024
+ guidance_scale_lyric=guidance_scale_lyric,
1025
+ repaint_start=repaint_start,
1026
+ repaint_end=repaint_end,
1027
+ src_latents=src_latents,
1028
+ )
1029
 
1030
  end_time = time.time()
1031
  diffusion_time_cost = end_time - start_time
 
1069
  "retake_variance": retake_variance,
1070
  "guidance_scale_text": guidance_scale_text,
1071
  "guidance_scale_lyric": guidance_scale_lyric,
1072
+ "repaint_start": repaint_start,
1073
+ "repaint_end": repaint_end,
1074
+ "edit_n_min": edit_n_min,
1075
+ "edit_n_max": edit_n_max,
1076
+ "edit_n_avg": edit_n_avg,
1077
+ "src_audio_path": src_audio_path,
1078
+ "edit_target_prompt": edit_target_prompt,
1079
+ "edit_target_lyrics": edit_target_lyrics,
1080
  }
1081
  # save input_params_json
1082
  for output_audio_path in output_paths:
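For orientation, the following is a minimal usage sketch of the two tasks this commit adds to ACEStepPipeline.__call__ (repaint and edit). It is an illustration, not official usage: the checkpoint directory, audio path, prompt, and lyrics values are placeholders, the prompt/lyrics keyword names are assumed from the function body above, and the task-specific parameters (repaint_start, repaint_end, src_audio_path, edit_target_prompt, edit_target_lyrics, edit_n_min, edit_n_max, retake_variance) are taken from the new signature in this diff.

from pipeline_ace_step import ACEStepPipeline

# Construction mirrors __init__ above; checkpoint_dir=None falls back to the default checkpoints folder.
pipeline = ACEStepPipeline(checkpoint_dir=None, dtype="bfloat16", torch_compile=False)

# Repaint: regenerate only the 10-30 second window of an existing result.
pipeline(
    audio_duration=180.0,
    prompt="pop, synth, female vocal",            # assumed keyword name
    lyrics="[verse] ...",                          # assumed keyword name
    task="repaint",
    repaint_start=10,
    repaint_end=30,
    src_audio_path="./outputs/output_001.flac",    # placeholder path
    retake_variance=0.2,
    save_path="./outputs",
)

# Edit: keep the source song but steer it toward new tags and lyrics via the FlowEdit branch.
pipeline(
    audio_duration=180.0,
    prompt="pop, synth, female vocal",
    lyrics="[verse] original lyric line ...",
    task="edit",
    src_audio_path="./outputs/output_001.flac",    # placeholder path
    edit_target_prompt="acoustic folk, warm, male vocal",
    edit_target_lyrics="[verse] edited lyric line ...",
    edit_n_min=0.8,   # "only_lyrics" preset used by the UI below
    edit_n_max=1.0,
    save_path="./outputs",
)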
ui/components.py CHANGED
@@ -63,15 +63,15 @@ def create_text2music_ui(
63
  ):
64
  with gr.Row():
65
  with gr.Column():
66
-
67
  with gr.Row(equal_height=True):
 
68
  audio_duration = gr.Slider(-1, 240.0, step=0.00001, value=180, label="Audio Duration", interactive=True, info="-1 means random duration (30 ~ 240).", scale=9)
69
  sample_bnt = gr.Button("Sample", variant="primary", scale=1)
70
 
71
- prompt = gr.Textbox(lines=2, label="Tags", max_lines=4, placeholder=TAG_PLACEHOLDER, info="Support tags, descriptions, and scene. Use commas to separate different tags.")
72
  lyrics = gr.Textbox(lines=9, label="Lyrics", max_lines=13, placeholder=LYRIC_PLACEHOLDER, info="Support lyric structure tags like [verse], [chorus], and [bridge] to separate different parts of the lyrics.\nUse [instrumental] or [inst] to generate instrumental music. Not support genre structure tag in lyrics")
73
 
74
- with gr.Accordion("Basic Settings", open=True):
75
  infer_step = gr.Slider(minimum=1, maximum=1000, step=1, value=60, label="Infer Steps", interactive=True)
76
  guidance_scale = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=15.0, label="Guidance Scale", interactive=True, info="When guidance_scale_lyric > 1 and guidance_scale_text > 1, the guidance scale will not be applied.")
77
  guidance_scale_text = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=5.0, label="Guidance Scale Text", interactive=True, info="Guidance scale for text condition. It can only apply to cfg. set guidance_scale_text=5.0, guidance_scale_lyric=1.5 for start")
@@ -93,14 +93,14 @@ def create_text2music_ui(
93
  min_guidance_scale = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=3.0, label="Min Guidance Scale", interactive=True, info="Min guidance scale for guidance interval decay's end scale")
94
  oss_steps = gr.Textbox(label="OSS Steps", placeholder="16, 29, 52, 96, 129, 158, 172, 183, 189, 200", value=None, info="Optimal Steps for the generation. But not test well")
95
 
96
- text2music_bnt = gr.Button(variant="primary")
97
 
98
  with gr.Column():
99
  outputs, input_params_json = create_output_ui()
100
  with gr.Tab("retake"):
101
- retake_variance = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.2, label="variance", info="Variance for the retake. 0.0 means no variance. 1.0 means full variance.")
102
- retake_seeds = gr.Textbox(label="retake seeds (default None)", placeholder="", value=None, info="Seed for the retake.")
103
- retake_bnt = gr.Button(variant="primary")
104
  retake_outputs, retake_input_params_json = create_output_ui("Retake")
105
 
106
  def retake_process_func(json_data, retake_variance, retake_seeds):
@@ -138,9 +138,251 @@ def create_text2music_ui(
138
  outputs=retake_outputs + [retake_input_params_json],
139
  )
140
  with gr.Tab("repainting"):
141
- pass
142
  with gr.Tab("edit"):
143
- pass
144
 
145
  def sample_data():
146
  json_data = sample_data_func()
@@ -219,13 +461,12 @@ def create_main_demo_ui(
219
  sample_data_func=dump_func,
220
  ):
221
  with gr.Blocks(
222
- title="FusicModel 1.0 DEMO",
223
  ) as demo:
224
  gr.Markdown(
225
  """
226
- <h1 style="text-align: center;">FusicModel 1.0 DEMO</h1>
227
- """
228
- )
229
 
230
  with gr.Tab("text2music"):
231
  create_text2music_ui(

ui/components.py (after changes)
63
  ):
64
  with gr.Row():
65
  with gr.Column():
 
66
  with gr.Row(equal_height=True):
67
+ # tags and lyrics examples are from the AI music generation community
68
  audio_duration = gr.Slider(-1, 240.0, step=0.00001, value=180, label="Audio Duration", interactive=True, info="-1 means random duration (30 ~ 240).", scale=9)
69
  sample_bnt = gr.Button("Sample", variant="primary", scale=1)
70
 
71
+ prompt = gr.Textbox(lines=2, label="Tags", max_lines=4, placeholder=TAG_PLACEHOLDER, info="Support tags, descriptions, and scene. Use commas to separate different tags.\nTags and lyrics examples are from the AI music generation community.")
72
  lyrics = gr.Textbox(lines=9, label="Lyrics", max_lines=13, placeholder=LYRIC_PLACEHOLDER, info="Support lyric structure tags like [verse], [chorus], and [bridge] to separate different parts of the lyrics.\nUse [instrumental] or [inst] to generate instrumental music. Not support genre structure tag in lyrics")
73
 
74
+ with gr.Accordion("Basic Settings", open=False):
75
  infer_step = gr.Slider(minimum=1, maximum=1000, step=1, value=60, label="Infer Steps", interactive=True)
76
  guidance_scale = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=15.0, label="Guidance Scale", interactive=True, info="When guidance_scale_lyric > 1 and guidance_scale_text > 1, the guidance scale will not be applied.")
77
  guidance_scale_text = gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=5.0, label="Guidance Scale Text", interactive=True, info="Guidance scale for text condition. It can only apply to cfg. set guidance_scale_text=5.0, guidance_scale_lyric=1.5 for start")
 
93
  min_guidance_scale = gr.Slider(minimum=0.0, maximum=200.0, step=0.1, value=3.0, label="Min Guidance Scale", interactive=True, info="Min guidance scale for guidance interval decay's end scale")
94
  oss_steps = gr.Textbox(label="OSS Steps", placeholder="16, 29, 52, 96, 129, 158, 172, 183, 189, 200", value=None, info="Optimal Steps for the generation. But not test well")
95
 
96
+ text2music_bnt = gr.Button("Generate", variant="primary")
97
 
98
  with gr.Column():
99
  outputs, input_params_json = create_output_ui()
100
  with gr.Tab("retake"):
101
+ retake_variance = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.2, label="variance")
102
+ retake_seeds = gr.Textbox(label="retake seeds (default None)", placeholder="", value=None)
103
+ retake_bnt = gr.Button("Retake", variant="primary")
104
  retake_outputs, retake_input_params_json = create_output_ui("Retake")
105
 
106
  def retake_process_func(json_data, retake_variance, retake_seeds):
 
138
  outputs=retake_outputs + [retake_input_params_json],
139
  )
140
  with gr.Tab("repainting"):
141
+ retake_variance = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.2, label="variance")
142
+ retake_seeds = gr.Textbox(label="retake seeds (default None)", placeholder="", value=None)
143
+ repaint_start = gr.Slider(minimum=0.0, maximum=240.0, step=0.01, value=0.0, label="Repaint Start Time", interactive=True)
144
+ repaint_end = gr.Slider(minimum=0.0, maximum=240.0, step=0.01, value=30.0, label="Repaint End Time", interactive=True)
145
+ repaint_source = gr.Radio(["text2music", "last_repaint", "upload"], value="text2music", label="Repaint Source", elem_id="repaint_source")
146
+
147
+ repaint_source_audio_upload = gr.Audio(label="Upload Audio", type="filepath", visible=False, elem_id="repaint_source_audio_upload")
148
+ repaint_source.change(
149
+ fn=lambda x: gr.update(visible=x == "upload", elem_id="repaint_source_audio_upload"),
150
+ inputs=[repaint_source],
151
+ outputs=[repaint_source_audio_upload],
152
+ )
153
+
154
+ repaint_bnt = gr.Button("Repaint", variant="primary")
155
+ repaint_outputs, repaint_input_params_json = create_output_ui("Repaint")
156
+
157
+ def repaint_process_func(
158
+ text2music_json_data,
159
+ repaint_json_data,
160
+ retake_variance,
161
+ retake_seeds,
162
+ repaint_start,
163
+ repaint_end,
164
+ repaint_source,
165
+ repaint_source_audio_upload,
166
+ prompt,
167
+ lyrics,
168
+ infer_step,
169
+ guidance_scale,
170
+ scheduler_type,
171
+ cfg_type,
172
+ omega_scale,
173
+ manual_seeds,
174
+ guidance_interval,
175
+ guidance_interval_decay,
176
+ min_guidance_scale,
177
+ use_erg_tag,
178
+ use_erg_lyric,
179
+ use_erg_diffusion,
180
+ oss_steps,
181
+ guidance_scale_text,
182
+ guidance_scale_lyric,
183
+ ):
184
+ if repaint_source == "upload":
185
+ src_audio_path = repaint_source_audio_upload
186
+ json_data = text2music_json_data
187
+ elif repaint_source == "text2music":
188
+ json_data = text2music_json_data
189
+ src_audio_path = json_data["audio_path"]
190
+ elif repaint_source == "last_repaint":
191
+ json_data = repaint_json_data
192
+ src_audio_path = json_data["audio_path"]
193
+
194
+ return text2music_process_func(
195
+ json_data["audio_duration"],
196
+ prompt,
197
+ lyrics,
198
+ infer_step,
199
+ guidance_scale,
200
+ scheduler_type,
201
+ cfg_type,
202
+ omega_scale,
203
+ manual_seeds,
204
+ guidance_interval,
205
+ guidance_interval_decay,
206
+ min_guidance_scale,
207
+ use_erg_tag,
208
+ use_erg_lyric,
209
+ use_erg_diffusion,
210
+ oss_steps,
211
+ guidance_scale_text,
212
+ guidance_scale_lyric,
213
+ retake_seeds=retake_seeds,
214
+ retake_variance=retake_variance,
215
+ task="repaint",
216
+ repaint_start=repaint_start,
217
+ repaint_end=repaint_end,
218
+ src_audio_path=src_audio_path,
219
+ )
220
+
221
+ repaint_bnt.click(
222
+ fn=repaint_process_func,
223
+ inputs=[
224
+ input_params_json,
225
+ repaint_input_params_json,
226
+ retake_variance,
227
+ retake_seeds,
228
+ repaint_start,
229
+ repaint_end,
230
+ repaint_source,
231
+ repaint_source_audio_upload,
232
+ prompt,
233
+ lyrics,
234
+ infer_step,
235
+ guidance_scale,
236
+ scheduler_type,
237
+ cfg_type,
238
+ omega_scale,
239
+ manual_seeds,
240
+ guidance_interval,
241
+ guidance_interval_decay,
242
+ min_guidance_scale,
243
+ use_erg_tag,
244
+ use_erg_lyric,
245
+ use_erg_diffusion,
246
+ oss_steps,
247
+ guidance_scale_text,
248
+ guidance_scale_lyric,
249
+ ],
250
+ outputs=repaint_outputs + [repaint_input_params_json],
251
+ )
252
  with gr.Tab("edit"):
253
+ edit_prompt = gr.Textbox(lines=2, label="Edit Tags", max_lines=4)
254
+ edit_lyrics = gr.Textbox(lines=9, label="Edit Lyrics", max_lines=13)
255
+
256
+ edit_type = gr.Radio(["only_lyrics", "remix"], value="only_lyrics", label="Edit Type", elem_id="edit_type", info="`only_lyrics` keeps the whole song the same except for the lyrics; keep the difference small, e.g. change one lyric line.\n`remix` can change the song's melody and genre")
257
+ edit_n_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.8, label="edit_n_min", interactive=True)
258
+ edit_n_max = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label="edit_n_max", interactive=True)
259
+
260
+ def edit_type_change_func(edit_type):
261
+ if edit_type == "only_lyrics":
262
+ n_min = 0.8
263
+ n_max = 1.0
264
+ elif edit_type == "remix":
265
+ n_min = 0.2
266
+ n_max = 0.4
267
+ return n_min, n_max
268
+
269
+ edit_type.change(
270
+ edit_type_change_func,
271
+ inputs=[edit_type],
272
+ outputs=[edit_n_min, edit_n_max]
273
+ )
274
+
275
+ edit_source = gr.Radio(["text2music", "last_edit", "upload"], value="text2music", label="Edit Source", elem_id="edit_source")
276
+ edit_source_audio_upload = gr.Audio(label="Upload Audio", type="filepath", visible=False, elem_id="edit_source_audio_upload")
277
+ edit_source.change(
278
+ fn=lambda x: gr.update(visible=x == "upload", elem_id="edit_source_audio_upload"),
279
+ inputs=[edit_source],
280
+ outputs=[edit_source_audio_upload],
281
+ )
282
+
283
+ edit_bnt = gr.Button("Edit", variant="primary")
284
+ edit_outputs, edit_input_params_json = create_output_ui("Edit")
285
+
286
+ def edit_process_func(
287
+ text2music_json_data,
288
+ edit_input_params_json,
289
+ edit_source,
290
+ edit_source_audio_upload,
291
+ prompt,
292
+ lyrics,
293
+ edit_prompt,
294
+ edit_lyrics,
295
+ edit_n_min,
296
+ edit_n_max,
297
+ infer_step,
298
+ guidance_scale,
299
+ scheduler_type,
300
+ cfg_type,
301
+ omega_scale,
302
+ manual_seeds,
303
+ guidance_interval,
304
+ guidance_interval_decay,
305
+ min_guidance_scale,
306
+ use_erg_tag,
307
+ use_erg_lyric,
308
+ use_erg_diffusion,
309
+ oss_steps,
310
+ guidance_scale_text,
311
+ guidance_scale_lyric,
312
+ ):
313
+ if edit_source == "upload":
314
+ src_audio_path = edit_source_audio_upload
315
+ json_data = text2music_json_data
316
+ elif edit_source == "text2music":
317
+ json_data = text2music_json_data
318
+ src_audio_path = json_data["audio_path"]
319
+ elif edit_source == "last_edit":
320
+ json_data = edit_input_params_json
321
+ src_audio_path = json_data["audio_path"]
322
+
323
+ if not edit_prompt:
324
+ edit_prompt = prompt
325
+ if not edit_lyrics:
326
+ edit_lyrics = lyrics
327
+
328
+ return text2music_process_func(
329
+ json_data["audio_duration"],
330
+ prompt,
331
+ lyrics,
332
+ infer_step,
333
+ guidance_scale,
334
+ scheduler_type,
335
+ cfg_type,
336
+ omega_scale,
337
+ manual_seeds,
338
+ guidance_interval,
339
+ guidance_interval_decay,
340
+ min_guidance_scale,
341
+ use_erg_tag,
342
+ use_erg_lyric,
343
+ use_erg_diffusion,
344
+ oss_steps,
345
+ guidance_scale_text,
346
+ guidance_scale_lyric,
347
+ task="edit",
348
+ src_audio_path=src_audio_path,
349
+ edit_target_prompt=edit_prompt,
350
+ edit_target_lyrics=edit_lyrics,
351
+ edit_n_min=edit_n_min,
352
+ edit_n_max=edit_n_max
353
+ )
354
+
355
+ edit_bnt.click(
356
+ fn=edit_process_func,
357
+ inputs=[
358
+ input_params_json,
359
+ edit_input_params_json,
360
+ edit_source,
361
+ edit_source_audio_upload,
362
+ prompt,
363
+ lyrics,
364
+ edit_prompt,
365
+ edit_lyrics,
366
+ edit_n_min,
367
+ edit_n_max,
368
+ infer_step,
369
+ guidance_scale,
370
+ scheduler_type,
371
+ cfg_type,
372
+ omega_scale,
373
+ manual_seeds,
374
+ guidance_interval,
375
+ guidance_interval_decay,
376
+ min_guidance_scale,
377
+ use_erg_tag,
378
+ use_erg_lyric,
379
+ use_erg_diffusion,
380
+ oss_steps,
381
+ guidance_scale_text,
382
+ guidance_scale_lyric,
383
+ ],
384
+ outputs=edit_outputs + [edit_input_params_json],
385
+ )
386
 
387
  def sample_data():
388
  json_data = sample_data_func()
 
461
  sample_data_func=dump_func,
462
  ):
463
  with gr.Blocks(
464
+ title="ACE-Step Model 1.0 DEMO",
465
  ) as demo:
466
  gr.Markdown(
467
  """
468
+ <h1 style="text-align: center;">ACE-Step: A Step Towards Music Generation Foundation Model</h1>
469
+ """)
 
470
 
471
  with gr.Tab("text2music"):
472
  create_text2music_ui(
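As a quick reference for the edit tab above: edit_n_min and edit_n_max are fractions of the diffusion schedule, and flowedit_diffusion_process converts them to step indices with int(infer_steps * n). The sketch below only restates the presets wired up by edit_type_change_func; the helper function and its name are illustrative and not part of the repository.

# Illustrative helper mirroring the presets from edit_type_change_func above.
EDIT_PRESETS = {
    "only_lyrics": (0.8, 1.0),  # keep melody and genre, change only the lyrics
    "remix": (0.2, 0.4),        # allow melody and genre to change
}

def edit_step_window(edit_type, infer_steps=60):
    # Same conversion as flowedit_diffusion_process: n_min/n_max are fractions of infer_steps.
    n_min_frac, n_max_frac = EDIT_PRESETS[edit_type]
    return int(infer_steps * n_min_frac), int(infer_steps * n_max_frac)

print(edit_step_window("only_lyrics"))  # (48, 60) with the default 60 inference steps
print(edit_step_window("remix"))        # (12, 24)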