yyk19 commited on
Commit
ba96fba
1 Parent(s): 86f09db

fix bugs related to cuda

Browse files
Files changed (1) hide show
  1. cldm/ddim_hacked.py +1 -1
cldm/ddim_hacked.py CHANGED
@@ -16,7 +16,7 @@ class DDIMSampler(object):
16
 
17
  def register_buffer(self, name, attr):
18
  if type(attr) == torch.Tensor:
19
- if attr.device != torch.device("cuda"):
20
  attr = attr.to(torch.device("cuda"))
21
  setattr(self, name, attr)
22
 
 
16
 
17
  def register_buffer(self, name, attr):
18
  if type(attr) == torch.Tensor:
19
+ if attr.device != torch.device("cuda") and torch.cuda.is_available():
20
  attr = attr.to(torch.device("cuda"))
21
  setattr(self, name, attr)
22