Sayoyo committed
Commit d39a7e5 · 1 Parent(s): 5488167

[feat] add torch.compile

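This commit switches on the torch.compile wrappers that were previously commented out in load_checkpoint, so the DCAE, the ACE-Step transformer, and the text encoder are compiled once the checkpoint has been loaded. For context, a minimal sketch of the pattern (TinyNet is a hypothetical stand-in, not one of the pipeline's real modules): torch.compile wraps an nn.Module, compiles it lazily on the first forward call, and reuses the compiled code on later calls.

import torch
import torch.nn as nn

class TinyNet(nn.Module):  # hypothetical stand-in for music_dcae / ace_step_transformer / text_encoder_model
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        return torch.relu(self.proj(x))

model = TinyNet().eval()
model = torch.compile(model)         # returns an optimized wrapper; parameters are shared with the original module
with torch.no_grad():
    out = model(torch.randn(2, 16))  # first call triggers compilation; later calls with the same shape reuse it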
Files changed (1)
  1. pipeline_ace_step.py +21 -25
pipeline_ace_step.py CHANGED
@@ -2,17 +2,14 @@ import random
 import time
 import os
 import re
-import glob
 
 import torch
-import torch.nn as nn
 from loguru import logger
 from tqdm import tqdm
 import json
 import math
 from huggingface_hub import hf_hub_download
 
-# from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
 from schedulers.scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler
 from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
@@ -64,12 +61,12 @@ class ACEStepPipeline:
 
     def load_checkpoint(self, checkpoint_dir=None):
         device = self.device
-
+
         dcae_model_path = os.path.join(checkpoint_dir, "music_dcae_f8c8")
         vocoder_model_path = os.path.join(checkpoint_dir, "music_vocoder")
         ace_step_model_path = os.path.join(checkpoint_dir, "ace_step_transformer")
         text_encoder_model_path = os.path.join(checkpoint_dir, "umt5-base")
-
+
         files_exist = (
             os.path.exists(os.path.join(dcae_model_path, "config.json")) and
             os.path.exists(os.path.join(dcae_model_path, "diffusion_pytorch_model.safetensors")) and
@@ -154,9 +151,9 @@ class ACEStepPipeline:
         self.loaded = True
 
         # compile
-        # self.music_dcae = torch.compile(self.music_dcae)
-        # self.ace_step_transformer = torch.compile(self.ace_step_transformer)
-        # self.text_encoder_model = torch.compile(self.text_encoder_model)
+        self.music_dcae = torch.compile(self.music_dcae)
+        self.ace_step_transformer = torch.compile(self.ace_step_transformer)
+        self.text_encoder_model = torch.compile(self.text_encoder_model)
 
     def get_text_embeddings(self, texts, device, text_max_length=256):
         inputs = self.text_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=text_max_length)
@@ -223,7 +220,7 @@ class ACEStepPipeline:
 
     def get_lang(self, text):
         language = "en"
-        try:
+        try:
             _ = self.lang_segment.getTexts(text)
             langCounts = self.lang_segment.getCounts()
             language = langCounts[0][0]
@@ -341,10 +338,10 @@ class ACEStepPipeline:
         retake_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=retake_random_generators, device=device, dtype=dtype)
         # to make sure mean = 0, std = 1
         target_latents = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
-
+
         attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
-
-        # guidance interval逻辑
+
+        # guidance interval
         start_idx = int(num_inference_steps * ((1 - guidance_interval) / 2))
         end_idx = int(num_inference_steps * (guidance_interval / 2 + 0.5))
         logger.info(f"start_idx: {start_idx}, end_idx: {end_idx}, num_inference_steps: {num_inference_steps}")
@@ -353,20 +350,20 @@ class ACEStepPipeline:
 
         def forward_encoder_with_temperature(self, inputs, tau=0.01, l_min=4, l_max=6):
             handlers = []
-
+
             def hook(module, input, output):
                 output[:] *= tau
                 return output
-
+
             for i in range(l_min, l_max):
                 handler = self.ace_step_transformer.lyric_encoder.encoders[i].self_attn.linear_q.register_forward_hook(hook)
                 handlers.append(handler)
-
+
             encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode(**inputs)
-
+
             for hook in handlers:
                 hook.remove()
-
+
             return encoder_hidden_states
 
         # P(speaker, text, lyric)
@@ -399,7 +396,7 @@ class ACEStepPipeline:
             torch.zeros_like(lyric_token_ids),
             lyric_mask,
         )
-
+
         encoder_hidden_states_no_lyric = None
         if do_double_condition_guidance:
             # P(null_speaker, text, lyric_weaker)
@@ -426,11 +423,11 @@ class ACEStepPipeline:
 
         def forward_diffusion_with_temperature(self, hidden_states, timestep, inputs, tau=0.01, l_min=15, l_max=20):
             handlers = []
-
+
             def hook(module, input, output):
                 output[:] *= tau
                 return output
-
+
             for i in range(l_min, l_max):
                 handler = self.ace_step_transformer.transformer_blocks[i].attn.to_q.register_forward_hook(hook)
                 handlers.append(handler)
@@ -438,13 +435,12 @@ class ACEStepPipeline:
                 handlers.append(handler)
 
             sample = self.ace_step_transformer.decode(hidden_states=hidden_states, timestep=timestep, **inputs).sample
-
+
             for hook in handlers:
                 hook.remove()
-
+
             return sample
 
-
         for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
             # expand the latents if we are doing classifier free guidance
             latents = target_latents
@@ -549,7 +545,7 @@ class ACEStepPipeline:
             ).sample
 
             target_latents = scheduler.step(model_output=noise_pred, timestep=t, sample=target_latents, return_dict=False, omega=omega_scale)[0]
-
+
         return target_latents
 
     def latents2audio(self, latents, target_wav_duration_second=30, sample_rate=48000, save_path=None, format="flac"):
@@ -624,7 +620,7 @@ class ACEStepPipeline:
             oss_steps = list(map(int, oss_steps.split(",")))
         else:
             oss_steps = []
-
+
         texts = [prompt]
         encoder_text_hidden_states, text_attention_mask = self.get_text_embeddings(texts, self.device)
         encoder_text_hidden_states = encoder_text_hidden_states.repeat(batch_size, 1, 1)
 
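A trade-off worth noting (an observation, not part of the diff above): the compiled submodules pay a one-time compilation cost on the first forward pass, and torch.compile may recompile when input shapes change. If that overhead matters for quick experiments, one variation is to gate compilation behind an opt-in switch. A minimal sketch, where maybe_compile and the ACE_TORCH_COMPILE environment variable are hypothetical and not part of the repository:

import os
import torch

def maybe_compile(module, flag_env="ACE_TORCH_COMPILE"):
    # Compile only when the (hypothetical) opt-in flag is set; otherwise return the module unchanged.
    if os.environ.get(flag_env, "0") == "1":
        return torch.compile(module)
    return module

# Inside load_checkpoint this would read, for example:
# self.music_dcae = maybe_compile(self.music_dcae)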