Spaces:
Runtime error
Runtime error
Commit
•
34e88c0
1
Parent(s):
6fa3e0e
fix device setting
Browse files
models/hyperstyle/hypernetworks/hypernetwork.py
CHANGED
@@ -91,7 +91,7 @@ class SharedWeightsHyperNetResNetSeparable(Module):
|
|
91 |
self.layers_to_tune = [int(l) for l in opts.layers_to_tune.split(',')]
|
92 |
|
93 |
self.shared_layers = [0, 2, 3, 5, 6, 8, 9, 11, 12]
|
94 |
-
self.shared_weight_hypernet = SharedWeightsHypernet(in_size=512, out_size=512, mode=None)
|
95 |
|
96 |
self.refinement_blocks = nn.ModuleList()
|
97 |
self.n_outputs = opts.n_hypernet_outputs
|
|
|
91 |
self.layers_to_tune = [int(l) for l in opts.layers_to_tune.split(',')]
|
92 |
|
93 |
self.shared_layers = [0, 2, 3, 5, 6, 8, 9, 11, 12]
|
94 |
+
self.shared_weight_hypernet = SharedWeightsHypernet(in_size=512, out_size=512, mode=None, device=opts.device)
|
95 |
|
96 |
self.refinement_blocks = nn.ModuleList()
|
97 |
self.n_outputs = opts.n_hypernet_outputs
|
models/hyperstyle/utils/model_utils.py
CHANGED
@@ -32,7 +32,7 @@ def load_model(checkpoint_path, device='cuda', update_opts=None, is_restyle_enco
|
|
32 |
net = HyperStyle(opts)
|
33 |
|
34 |
net.eval()
|
35 |
-
net.to(device)
|
36 |
return net, opts
|
37 |
|
38 |
|
|
|
32 |
net = HyperStyle(opts)
|
33 |
|
34 |
net.eval()
|
35 |
+
net.to(opts.device)
|
36 |
return net, opts
|
37 |
|
38 |
|