Sayoyo commited on
Commit
a8bbbf9
·
1 Parent(s): 65fbe9a
Files changed (2) hide show
  1. apg_guidance.py +6 -5
  2. pipeline_ace_step.py +1 -0
apg_guidance.py CHANGED
@@ -17,14 +17,15 @@ def project(
17
  dims=[-1, -2],
18
  ):
19
  dtype = v0.dtype
20
- if v0.device.type == "mps":
21
- v0, v1 = v0.float(), v1.float()
22
- else:
23
- v0, v1 = v0.double(), v1.double()
 
24
  v1 = torch.nn.functional.normalize(v1, dim=dims)
25
  v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
26
  v0_orthogonal = v0 - v0_parallel
27
- return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
28
 
29
 
30
  def apg_forward(
 
17
  dims=[-1, -2],
18
  ):
19
  dtype = v0.dtype
20
+ device_type = v0.device.type
21
+ if device_type == "mps":
22
+ v0, v1 = v0.cpu(), v1.cpu()
23
+
24
+ v0, v1 = v0.double(), v1.double()
25
  v1 = torch.nn.functional.normalize(v1, dim=dims)
26
  v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
27
  v0_orthogonal = v0 - v0_parallel
28
+ return v0_parallel.to(dtype).to(device_type), v0_orthogonal.to(dtype).to(device_type)
29
 
30
 
31
  def apg_forward(
pipeline_ace_step.py CHANGED
@@ -955,6 +955,7 @@ class ACEStepPipeline:
955
  latents, _ = self.music_dcae.encode(input_audio, sr=sr)
956
  return latents
957
 
 
958
  def __call__(
959
  self,
960
  audio_duration: float = 60.0,
 
955
  latents, _ = self.music_dcae.encode(input_audio, sr=sr)
956
  return latents
957
 
958
+ @spaces.GPU
959
  def __call__(
960
  self,
961
  audio_duration: float = 60.0,