Sayoyo committed on
Commit
851cdce
·
1 Parent(s): 7b2e525

[fix] extend

Browse files
Files changed (3) hide show
  1. apg_guidance.py +6 -5
  2. app.py +1 -1
  3. pipeline_ace_step.py +13 -13
apg_guidance.py CHANGED
@@ -17,14 +17,15 @@ def project(
17
  dims=[-1, -2],
18
  ):
19
  dtype = v0.dtype
20
- if v0.device.type == "mps":
21
- v0, v1 = v0.float(), v1.float()
22
- else:
23
- v0, v1 = v0.double(), v1.double()
 
24
  v1 = torch.nn.functional.normalize(v1, dim=dims)
25
  v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
26
  v0_orthogonal = v0 - v0_parallel
27
- return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
28
 
29
 
30
  def apg_forward(
 
17
  dims=[-1, -2],
18
  ):
19
  dtype = v0.dtype
20
+ device_type = v0.device.type
21
+ if device_type == "mps":
22
+ v0, v1 = v0.cpu(), v1.cpu()
23
+
24
+ v0, v1 = v0.double(), v1.double()
25
  v1 = torch.nn.functional.normalize(v1, dim=dims)
26
  v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
27
  v0_orthogonal = v0 - v0_parallel
28
+ return v0_parallel.to(dtype).to(device_type), v0_orthogonal.to(dtype).to(device_type)
29
 
30
 
31
  def apg_forward(
app.py CHANGED
@@ -7,7 +7,7 @@ import os
7
 
8
  parser = argparse.ArgumentParser()
9
  parser.add_argument("--checkpoint_path", type=str, default=None)
10
- parser.add_argument("--server_name", type=str, default="0.0.0.0")
11
  parser.add_argument("--port", type=int, default=7860)
12
  parser.add_argument("--device_id", type=int, default=0)
13
  parser.add_argument("--share", action='store_true', default=False)
 
7
 
8
  parser = argparse.ArgumentParser()
9
  parser.add_argument("--checkpoint_path", type=str, default=None)
10
+ parser.add_argument("--server_name", type=str, default="127.0.0.1")
11
  parser.add_argument("--port", type=int, default=7860)
12
  parser.add_argument("--device_id", type=int, default=0)
13
  parser.add_argument("--share", action='store_true', default=False)
pipeline_ace_step.py CHANGED
@@ -68,8 +68,8 @@ class ACEStepPipeline:
68
  if device.type == "cpu" and torch.backends.mps.is_available():
69
  device = torch.device("mps")
70
  self.dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32
71
- if device.type == "mps" and self.dtype == torch.bfloat16:
72
- self.dtype = torch.float16
73
  self.device = device
74
  self.loaded = False
75
  self.torch_compile = torch_compile
@@ -181,33 +181,33 @@ class ACEStepPipeline:
181
  last_hidden_states = outputs.last_hidden_state
182
  attention_mask = inputs["attention_mask"]
183
  return last_hidden_states, attention_mask
184
-
185
  def get_text_embeddings_null(self, texts, device, text_max_length=256, tau=0.01, l_min=8, l_max=10):
186
  inputs = self.text_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=text_max_length)
187
  inputs = {key: value.to(device) for key, value in inputs.items()}
188
  if self.text_encoder_model.device != device:
189
  self.text_encoder_model.to(device)
190
-
191
  def forward_with_temperature(inputs, tau=0.01, l_min=8, l_max=10):
192
  handlers = []
193
-
194
  def hook(module, input, output):
195
  output[:] *= tau
196
  return output
197
-
198
  for i in range(l_min, l_max):
199
  handler = self.text_encoder_model.encoder.block[i].layer[0].SelfAttention.q.register_forward_hook(hook)
200
  handlers.append(handler)
201
-
202
  with torch.no_grad():
203
  outputs = self.text_encoder_model(**inputs)
204
  last_hidden_states = outputs.last_hidden_state
205
-
206
  for hook in handlers:
207
  hook.remove()
208
-
209
  return last_hidden_states
210
-
211
  last_hidden_states = forward_with_temperature(inputs, tau, l_min, l_max)
212
  return last_hidden_states
213
 
@@ -236,7 +236,7 @@ class ACEStepPipeline:
236
 
237
  def get_lang(self, text):
238
  language = "en"
239
- try:
240
  _ = self.lang_segment.getTexts(text)
241
  langCounts = self.lang_segment.getCounts()
242
  language = langCounts[0][0]
@@ -912,9 +912,9 @@ class ACEStepPipeline:
912
 
913
  if is_extend:
914
  if to_right_pad_gt_latents is not None:
915
- target_latents = torch.cate([target_latents, to_right_pad_gt_latents], dim=-1)
916
  if to_left_pad_gt_latents is not None:
917
- target_latents = torch.cate([to_right_pad_gt_latents, target_latents], dim=0)
918
  return target_latents
919
 
920
  def latents2audio(self, latents, target_wav_duration_second=30, sample_rate=48000, save_path=None, format="flac"):
 
68
  if device.type == "cpu" and torch.backends.mps.is_available():
69
  device = torch.device("mps")
70
  self.dtype = torch.bfloat16 if dtype == "bfloat16" else torch.float32
71
+ if device.type == "mps":
72
+ self.dtype = torch.float32
73
  self.device = device
74
  self.loaded = False
75
  self.torch_compile = torch_compile
 
181
  last_hidden_states = outputs.last_hidden_state
182
  attention_mask = inputs["attention_mask"]
183
  return last_hidden_states, attention_mask
184
+
185
  def get_text_embeddings_null(self, texts, device, text_max_length=256, tau=0.01, l_min=8, l_max=10):
186
  inputs = self.text_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=text_max_length)
187
  inputs = {key: value.to(device) for key, value in inputs.items()}
188
  if self.text_encoder_model.device != device:
189
  self.text_encoder_model.to(device)
190
+
191
  def forward_with_temperature(inputs, tau=0.01, l_min=8, l_max=10):
192
  handlers = []
193
+
194
  def hook(module, input, output):
195
  output[:] *= tau
196
  return output
197
+
198
  for i in range(l_min, l_max):
199
  handler = self.text_encoder_model.encoder.block[i].layer[0].SelfAttention.q.register_forward_hook(hook)
200
  handlers.append(handler)
201
+
202
  with torch.no_grad():
203
  outputs = self.text_encoder_model(**inputs)
204
  last_hidden_states = outputs.last_hidden_state
205
+
206
  for hook in handlers:
207
  hook.remove()
208
+
209
  return last_hidden_states
210
+
211
  last_hidden_states = forward_with_temperature(inputs, tau, l_min, l_max)
212
  return last_hidden_states
213
 
 
236
 
237
  def get_lang(self, text):
238
  language = "en"
239
+ try:
240
  _ = self.lang_segment.getTexts(text)
241
  langCounts = self.lang_segment.getCounts()
242
  language = langCounts[0][0]
 
912
 
913
  if is_extend:
914
  if to_right_pad_gt_latents is not None:
915
+ target_latents = torch.cat([target_latents, to_right_pad_gt_latents], dim=-1)
916
  if to_left_pad_gt_latents is not None:
917
+ target_latents = torch.cat([to_right_pad_gt_latents, target_latents], dim=0)
918
  return target_latents
919
 
920
  def latents2audio(self, latents, target_wav_duration_second=30, sample_rate=48000, save_path=None, format="flac"):