zero gpu
Browse files- apg_guidance.py +6 -5
- pipeline_ace_step.py +1 -0
apg_guidance.py
CHANGED
@@ -17,14 +17,15 @@ def project(
|
|
17 |
dims=[-1, -2],
|
18 |
):
|
19 |
dtype = v0.dtype
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
24 |
v1 = torch.nn.functional.normalize(v1, dim=dims)
|
25 |
v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
|
26 |
v0_orthogonal = v0 - v0_parallel
|
27 |
-
return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
|
28 |
|
29 |
|
30 |
def apg_forward(
|
|
|
17 |
dims=[-1, -2],
|
18 |
):
|
19 |
dtype = v0.dtype
|
20 |
+
device_type = v0.device.type
|
21 |
+
if device_type == "mps":
|
22 |
+
v0, v1 = v0.cpu(), v1.cpu()
|
23 |
+
|
24 |
+
v0, v1 = v0.double(), v1.double()
|
25 |
v1 = torch.nn.functional.normalize(v1, dim=dims)
|
26 |
v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
|
27 |
v0_orthogonal = v0 - v0_parallel
|
28 |
+
return v0_parallel.to(dtype).to(device_type), v0_orthogonal.to(dtype).to(device_type)
|
29 |
|
30 |
|
31 |
def apg_forward(
|
pipeline_ace_step.py
CHANGED
@@ -955,6 +955,7 @@ class ACEStepPipeline:
|
|
955 |
latents, _ = self.music_dcae.encode(input_audio, sr=sr)
|
956 |
return latents
|
957 |
|
|
|
958 |
def __call__(
|
959 |
self,
|
960 |
audio_duration: float = 60.0,
|
|
|
955 |
latents, _ = self.music_dcae.encode(input_audio, sr=sr)
|
956 |
return latents
|
957 |
|
958 |
+
@spaces.GPU
|
959 |
def __call__(
|
960 |
self,
|
961 |
audio_duration: float = 60.0,
|