harsh99 commited on
Commit
ef4558b
·
verified ·
1 Parent(s): 915e59a

Update CatVTON_model.py

Browse files
Files changed (1) hide show
  1. CatVTON_model.py +4 -4
CatVTON_model.py CHANGED
@@ -16,7 +16,7 @@ class CatVTONPix2PixPipeline:
16
  def __init__(
17
  self,
18
  weight_dtype=torch.float32,
19
- device='cuda',
20
  compile=False,
21
  skip_safety_check=True,
22
  use_tf32=True,
@@ -35,9 +35,9 @@ class CatVTONPix2PixPipeline:
35
 
36
  self.unet=models.get('diffusion', None)
37
  # # Enable TF32 for faster training on Ampere GPUs (A100 and RTX 30 series).
38
- if use_tf32:
39
- torch.set_float32_matmul_precision("high")
40
- torch.backends.cuda.matmul.allow_tf32 = True
41
 
42
  @torch.no_grad()
43
  def __call__(
 
16
  def __init__(
17
  self,
18
  weight_dtype=torch.float32,
19
+ device='cpu',
20
  compile=False,
21
  skip_safety_check=True,
22
  use_tf32=True,
 
35
 
36
  self.unet=models.get('diffusion', None)
37
  # # Enable TF32 for faster training on Ampere GPUs (A100 and RTX 30 series).
38
+ # if use_tf32:
39
+ # torch.set_float32_matmul_precision("high")
40
+ # torch.backends.cuda.matmul.allow_tf32 = True
41
 
42
  @torch.no_grad()
43
  def __call__(