ZYMPKU commited on
Commit
660f2af
1 Parent(s): 8876f9a
Files changed (2) hide show
  1. configs/demo.yaml +1 -0
  2. util.py +7 -6
configs/demo.yaml CHANGED
@@ -24,6 +24,7 @@ dual_conditioner: False
24
  steps: 50
25
  init_step: 0
26
  num_workers: 0
 
27
  gpu: 0
28
  max_iter: 100
29
 
 
24
  steps: 50
25
  init_step: 0
26
  num_workers: 0
27
+ use_gpu: False
28
  gpu: 0
29
  max_iter: 100
30
 
util.py CHANGED
@@ -32,18 +32,19 @@ SD_XL_BASE_RATIOS = {
32
  "3.0": (1728, 576),
33
  }
34
 
35
- def init_model(cfg):
36
 
37
- model_cfg = OmegaConf.load(cfg.model_cfg_path)
38
- ckpt = cfg.load_ckpt_path
39
 
40
  model = instantiate_from_config(model_cfg.model)
41
  model.init_from_ckpt(ckpt)
42
 
43
- if cfg.type == "train":
44
  model.train()
45
  else:
46
- model.to(torch.device("cuda", index=cfg.gpu))
 
47
  model.eval()
48
  model.freeze()
49
 
@@ -108,7 +109,7 @@ def deep_copy(batch):
108
  def prepare_batch(cfgs, batch):
109
 
110
  for key in batch:
111
- if isinstance(batch[key], torch.Tensor):
112
  batch[key] = batch[key].to(torch.device("cuda", index=cfgs.gpu))
113
 
114
  if not cfgs.dual_conditioner:
 
32
  "3.0": (1728, 576),
33
  }
34
 
35
+ def init_model(cfgs):
36
 
37
+ model_cfg = OmegaConf.load(cfgs.model_cfg_path)
38
+ ckpt = cfgs.load_ckpt_path
39
 
40
  model = instantiate_from_config(model_cfg.model)
41
  model.init_from_ckpt(ckpt)
42
 
43
+ if cfgs.type == "train":
44
  model.train()
45
  else:
46
+ if cfgs.use_gpu:
47
+ model.to(torch.device("cuda", index=cfgs.gpu))
48
  model.eval()
49
  model.freeze()
50
 
 
109
  def prepare_batch(cfgs, batch):
110
 
111
  for key in batch:
112
+ if isinstance(batch[key], torch.Tensor) and cfgs.use_gpu:
113
  batch[key] = batch[key].to(torch.device("cuda", index=cfgs.gpu))
114
 
115
  if not cfgs.dual_conditioner: