ethanNeuralImage commited on
Commit
34e88c0
1 Parent(s): 6fa3e0e

fix device setting

Browse files
models/hyperstyle/hypernetworks/hypernetwork.py CHANGED
@@ -91,7 +91,7 @@ class SharedWeightsHyperNetResNetSeparable(Module):
91
  self.layers_to_tune = [int(l) for l in opts.layers_to_tune.split(',')]
92
 
93
  self.shared_layers = [0, 2, 3, 5, 6, 8, 9, 11, 12]
94
- self.shared_weight_hypernet = SharedWeightsHypernet(in_size=512, out_size=512, mode=None)
95
 
96
  self.refinement_blocks = nn.ModuleList()
97
  self.n_outputs = opts.n_hypernet_outputs
 
91
  self.layers_to_tune = [int(l) for l in opts.layers_to_tune.split(',')]
92
 
93
  self.shared_layers = [0, 2, 3, 5, 6, 8, 9, 11, 12]
94
+ self.shared_weight_hypernet = SharedWeightsHypernet(in_size=512, out_size=512, mode=None, device=opts.device)
95
 
96
  self.refinement_blocks = nn.ModuleList()
97
  self.n_outputs = opts.n_hypernet_outputs
models/hyperstyle/utils/model_utils.py CHANGED
@@ -32,7 +32,7 @@ def load_model(checkpoint_path, device='cuda', update_opts=None, is_restyle_enco
32
  net = HyperStyle(opts)
33
 
34
  net.eval()
35
- net.to(device)
36
  return net, opts
37
 
38
 
 
32
  net = HyperStyle(opts)
33
 
34
  net.eval()
35
+ net.to(opts.device)
36
  return net, opts
37
 
38