robin-courant committed
Commit 3521db3
Parent: 0d05803

Update utils/common_viz.py

Files changed (1): utils/common_viz.py (+6 −6)
utils/common_viz.py CHANGED
@@ -106,23 +106,23 @@ def get_batch(
 
 def init(
     config_name: str,
-) -> Tuple[Diffuser, clip.model.CLIP, MultimodalDataset, torch.device]:
+) -> Tuple[Diffuser, clip.model.CLIP, MultimodalDataset]:
     with initialize(version_base="1.3", config_path="../configs"):
         config = compose(config_name=config_name)
 
     OmegaConf.register_new_resolver("eval", eval)
 
     # Initialize model
-    device = torch.device(config.compnode.device)
+    # device = torch.device(config.compnode.device)
     diffuser = instantiate(config.diffuser)
-    state_dict = torch.load(config.checkpoint_path, map_location=device)["state_dict"]
+    state_dict = torch.load(config.checkpoint_path, map_location="cpu")["state_dict"]
     state_dict["ema.initted"] = diffuser.ema.initted
     state_dict["ema.step"] = diffuser.ema.step
     diffuser.load_state_dict(state_dict, strict=False)
-    diffuser.to(device).eval()
+    diffuser.to("cpu").eval()
 
     # Initialize CLIP model
-    clip_model = load_clip_model("ViT-B/32", device)
+    clip_model = load_clip_model("ViT-B/32", "cpu")
 
     # Initialize dataset
     config.dataset.char.load_vertices = True
@@ -133,4 +133,4 @@ def init(
     diffuser.get_matrix = dataset.get_matrix
     diffuser.v_get_matrix = dataset.get_matrix
 
-    return diffuser, clip_model, dataset, device
+    return diffuser, clip_model, dataset, config.compnode.device
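
In effect, the commit makes init() do its heavy lifting on CPU (checkpoint loading, CLIP weights) and hands back the configured device as a plain config value rather than a torch.device. A minimal caller-side sketch follows; the import path, config name, and the CUDA-availability guard are assumptions for illustration, not part of the commit:

import torch

from utils.common_viz import init

# Hypothetical usage sketch: everything is loaded on CPU inside init(), and the
# last return value is config.compnode.device (presumably a device name string
# such as "cuda"), so the caller decides if and when to move models to the GPU.
diffuser, clip_model, dataset, device_name = init("config")  # config name is illustrative

device = torch.device(device_name if torch.cuda.is_available() else "cpu")
diffuser.to(device)    # move the diffusion model only when it is actually needed
clip_model.to(device)  # CLIP is an nn.Module, so .to(device) works the same way

Deferring the device transfer like this keeps GPU memory untouched at startup, which is presumably the point of loading the checkpoint and CLIP on CPU in shared or CPU-only deployment environments.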