Spaces:
Runtime error
Runtime error
Update wan/modules/vae.py
Browse files- wan/modules/vae.py +3 -3
wan/modules/vae.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
import logging
|
| 3 |
|
| 4 |
import torch
|
| 5 |
-
import torch.
|
| 6 |
import torch.nn as nn
|
| 7 |
import torch.nn.functional as F
|
| 8 |
from einops import rearrange
|
|
@@ -648,14 +648,14 @@ class WanVAE:
|
|
| 648 |
"""
|
| 649 |
videos: A list of videos each with shape [C, T, H, W].
|
| 650 |
"""
|
| 651 |
-
with amp.autocast(dtype=self.dtype):
|
| 652 |
return [
|
| 653 |
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
|
| 654 |
for u in videos
|
| 655 |
]
|
| 656 |
|
| 657 |
def decode(self, zs):
|
| 658 |
-
with amp.autocast(dtype=self.dtype):
|
| 659 |
return [
|
| 660 |
self.model.decode(u.unsqueeze(0),
|
| 661 |
self.scale).float().clamp_(-1, 1).squeeze(0)
|
|
|
|
| 2 |
import logging
|
| 3 |
|
| 4 |
import torch
|
| 5 |
+
import torch.amp as amp
|
| 6 |
import torch.nn as nn
|
| 7 |
import torch.nn.functional as F
|
| 8 |
from einops import rearrange
|
|
|
|
| 648 |
"""
|
| 649 |
videos: A list of videos each with shape [C, T, H, W].
|
| 650 |
"""
|
| 651 |
+
with amp.autocast("cuda", dtype=self.dtype):
|
| 652 |
return [
|
| 653 |
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
|
| 654 |
for u in videos
|
| 655 |
]
|
| 656 |
|
| 657 |
def decode(self, zs):
|
| 658 |
+
with amp.autocast("cuda", dtype=self.dtype):
|
| 659 |
return [
|
| 660 |
self.model.decode(u.unsqueeze(0),
|
| 661 |
self.scale).float().clamp_(-1, 1).squeeze(0)
|