[feat] add torch.compile
pipeline_ace_step.py  CHANGED  (+21 -25)
@@ -2,17 +2,14 @@ import random
 import time
 import os
 import re
-import glob

 import torch
-import torch.nn as nn
 from loguru import logger
 from tqdm import tqdm
 import json
 import math
 from huggingface_hub import hf_hub_download

-# from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
 from schedulers.scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler
 from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps
@@ -64,12 +61,12 @@ class ACEStepPipeline:

     def load_checkpoint(self, checkpoint_dir=None):
         device = self.device
-
+
         dcae_model_path = os.path.join(checkpoint_dir, "music_dcae_f8c8")
         vocoder_model_path = os.path.join(checkpoint_dir, "music_vocoder")
         ace_step_model_path = os.path.join(checkpoint_dir, "ace_step_transformer")
         text_encoder_model_path = os.path.join(checkpoint_dir, "umt5-base")
-
+
         files_exist = (
             os.path.exists(os.path.join(dcae_model_path, "config.json")) and
             os.path.exists(os.path.join(dcae_model_path, "diffusion_pytorch_model.safetensors")) and
@@ -154,9 +151,9 @@ class ACEStepPipeline:
         self.loaded = True

         # compile
-
-
-
+        self.music_dcae = torch.compile(self.music_dcae)
+        self.ace_step_transformer = torch.compile(self.ace_step_transformer)
+        self.text_encoder_model = torch.compile(self.text_encoder_model)

     def get_text_embeddings(self, texts, device, text_max_length=256):
         inputs = self.text_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=text_max_length)
@@ -223,7 +220,7 @@ class ACEStepPipeline:

     def get_lang(self, text):
         language = "en"
-        try:
+        try:
             _ = self.lang_segment.getTexts(text)
             langCounts = self.lang_segment.getCounts()
             language = langCounts[0][0]
@@ -341,10 +338,10 @@ class ACEStepPipeline:
         retake_latents = randn_tensor(shape=(bsz, 8, 16, frame_length), generator=retake_random_generators, device=device, dtype=dtype)
         # to make sure mean = 0, std = 1
         target_latents = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents
-
+
         attention_mask = torch.ones(bsz, frame_length, device=device, dtype=dtype)
-
-        # guidance interval
+
+        # guidance interval
         start_idx = int(num_inference_steps * ((1 - guidance_interval) / 2))
         end_idx = int(num_inference_steps * (guidance_interval / 2 + 0.5))
         logger.info(f"start_idx: {start_idx}, end_idx: {end_idx}, num_inference_steps: {num_inference_steps}")
@@ -353,20 +350,20 @@ class ACEStepPipeline:

         def forward_encoder_with_temperature(self, inputs, tau=0.01, l_min=4, l_max=6):
             handlers = []
-
+
             def hook(module, input, output):
                 output[:] *= tau
                 return output
-
+
             for i in range(l_min, l_max):
                 handler = self.ace_step_transformer.lyric_encoder.encoders[i].self_attn.linear_q.register_forward_hook(hook)
                 handlers.append(handler)
-
+
             encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode(**inputs)
-
+
             for hook in handlers:
                 hook.remove()
-
+
             return encoder_hidden_states

         # P(speaker, text, lyric)
@@ -399,7 +396,7 @@ class ACEStepPipeline:
             torch.zeros_like(lyric_token_ids),
             lyric_mask,
         )
-
+
         encoder_hidden_states_no_lyric = None
         if do_double_condition_guidance:
             # P(null_speaker, text, lyric_weaker)
@@ -426,11 +423,11 @@ class ACEStepPipeline:

         def forward_diffusion_with_temperature(self, hidden_states, timestep, inputs, tau=0.01, l_min=15, l_max=20):
             handlers = []
-
+
             def hook(module, input, output):
                 output[:] *= tau
                 return output
-
+
             for i in range(l_min, l_max):
                 handler = self.ace_step_transformer.transformer_blocks[i].attn.to_q.register_forward_hook(hook)
                 handlers.append(handler)
@@ -438,13 +435,12 @@ class ACEStepPipeline:
                 handlers.append(handler)

             sample = self.ace_step_transformer.decode(hidden_states=hidden_states, timestep=timestep, **inputs).sample
-
+
             for hook in handlers:
                 hook.remove()
-
+
             return sample

-
         for i, t in tqdm(enumerate(timesteps), total=num_inference_steps):
             # expand the latents if we are doing classifier free guidance
             latents = target_latents
@@ -549,7 +545,7 @@ class ACEStepPipeline:
                 ).sample

                 target_latents = scheduler.step(model_output=noise_pred, timestep=t, sample=target_latents, return_dict=False, omega=omega_scale)[0]
-
+
         return target_latents

     def latents2audio(self, latents, target_wav_duration_second=30, sample_rate=48000, save_path=None, format="flac"):
@@ -624,7 +620,7 @@ class ACEStepPipeline:
             oss_steps = list(map(int, oss_steps.split(",")))
         else:
             oss_steps = []
-
+
         texts = [prompt]
         encoder_text_hidden_states, text_attention_mask = self.get_text_embeddings(texts, self.device)
         encoder_text_hidden_states = encoder_text_hidden_states.repeat(batch_size, 1, 1)