Spaces:
Running
Running
Update CatVTON_model.py
Browse files- CatVTON_model.py +4 -4
CatVTON_model.py
CHANGED
|
@@ -16,7 +16,7 @@ class CatVTONPix2PixPipeline:
|
|
| 16 |
def __init__(
|
| 17 |
self,
|
| 18 |
weight_dtype=torch.float32,
|
| 19 |
-
device='
|
| 20 |
compile=False,
|
| 21 |
skip_safety_check=True,
|
| 22 |
use_tf32=True,
|
|
@@ -35,9 +35,9 @@ class CatVTONPix2PixPipeline:
|
|
| 35 |
|
| 36 |
self.unet=models.get('diffusion', None)
|
| 37 |
# # Enable TF32 for faster training on Ampere GPUs (A100 and RTX 30 series).
|
| 38 |
-
if use_tf32:
|
| 39 |
-
|
| 40 |
-
|
| 41 |
|
| 42 |
@torch.no_grad()
|
| 43 |
def __call__(
|
|
|
|
| 16 |
def __init__(
|
| 17 |
self,
|
| 18 |
weight_dtype=torch.float32,
|
| 19 |
+
device='cpu',
|
| 20 |
compile=False,
|
| 21 |
skip_safety_check=True,
|
| 22 |
use_tf32=True,
|
|
|
|
| 35 |
|
| 36 |
self.unet=models.get('diffusion', None)
|
| 37 |
# # Enable TF32 for faster training on Ampere GPUs (A100 and RTX 30 series).
|
| 38 |
+
# if use_tf32:
|
| 39 |
+
# torch.set_float32_matmul_precision("high")
|
| 40 |
+
# torch.backends.cuda.matmul.allow_tf32 = True
|
| 41 |
|
| 42 |
@torch.no_grad()
|
| 43 |
def __call__(
|