ethanNeuralImage committed
Commit: 6fa3e0e
Parent(s): 838d7c7

fix GPU usage to be optional
Files changed:
- hyperstyle_global_directions/edit.py +6 -6
- hyperstyle_global_directions/global_direction.py +6 -5
- hyperstyle_global_directions/stylespace_utils.py +2 -2
- models/hyperstyle/hypernetworks/hypernetwork.py +1 -1
- models/hyperstyle/hypernetworks/shared_weights_hypernet.py +6 -5
- models/hyperstyle/hyperstyle.py +1 -1
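
Every change below follows the same pattern: a hard-coded .cuda() call is replaced by .to(device), with the device threaded in through args, opts, or a constructor parameter. The commit does not show where that device value is first resolved; a minimal sketch of the usual resolution (resolve_device is a hypothetical helper name, not code from this repo) would be:

import torch

def resolve_device(requested: str = "cuda") -> torch.device:
    """Fall back to CPU when CUDA is requested but unavailable."""
    if requested.startswith("cuda") and not torch.cuda.is_available():
        return torch.device("cpu")  # GPU is now optional
    return torch.device(requested)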
hyperstyle_global_directions/edit.py
CHANGED

@@ -41,18 +41,18 @@ def parse_args(args_list=None):
 
 
 def load_direction_calculator(args):
-    delta_i_c = torch.from_numpy(np.load(args.delta_i_c)).float().cuda()
+    delta_i_c = torch.from_numpy(np.load(args.delta_i_c)).float().to(args.device)
     with open(args.s_statistics, "rb") as channels_statistics:
         _, s_std = pickle.load(channels_statistics)
-    s_std = [torch.from_numpy(s_i).float().cuda() for s_i in s_std]
+    s_std = [torch.from_numpy(s_i).float().to(args.device) for s_i in s_std]
     with open(args.text_prompt_templates, "r") as templates:
         text_prompt_templates = templates.readlines()
-    global_direction_calculator = StyleCLIPGlobalDirection(delta_i_c, s_std, text_prompt_templates)
+    global_direction_calculator = StyleCLIPGlobalDirection(delta_i_c, s_std, text_prompt_templates, args.device)
     return global_direction_calculator
 
 
 def load_stylegan_generator(args):
-    stylegan_model = Generator(args.stylegan_size, 512, 8, channel_multiplier=2).cuda()
+    stylegan_model = Generator(args.stylegan_size, 512, 8, channel_multiplier=2).to(args.device)
     checkpoint = torch.load(args.stylegan_weights)
     stylegan_model.load_state_dict(checkpoint['g_ema'])
     return stylegan_model

@@ -72,7 +72,7 @@ def run():
         if args.n_images is not None and idx >= args.n_images:
             break
         weight_deltas = np.load(os.path.join(args.weight_deltas_path, image_name.split(".")[0] + ".npy"), allow_pickle=True)
-        weight_deltas = [torch.from_numpy(w).cuda() if w is not None else None for w in weight_deltas]
+        weight_deltas = [torch.from_numpy(w).to(args.device) if w is not None else None for w in weight_deltas]
         latent = torch.from_numpy(latent)
         results, results_latent, source_img = edit_image(image_name, latent, stylegan_model, global_direction_calculator, args, weight_deltas)
         torchvision.utils.save_image(results, f"{args.output_path}/{image_name.split('.')[0]}.jpg",

@@ -80,7 +80,7 @@ def run():
 
 
 def edit_image(image_name, latent, stylegan_model, global_direction_calculator, args, weight_deltas=None):
-    latent_code = latent.cuda()
+    latent_code = latent.to(args.device)
     truncation = 1
     mean_latent = None
     input_is_latent = True
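
edit.py now reads args.device in four places, so the argument parser must expose a matching flag. The parse_args hunk itself is not part of this diff; a hypothetical flag consistent with the usage above would look like:

import argparse
import torch

parser = argparse.ArgumentParser()
# Hypothetical flag: the diff consumes args.device but does not show its definition.
parser.add_argument("--device",
                    default="cuda" if torch.cuda.is_available() else "cpu",
                    help="torch device for the generator, weight deltas, and latents")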
hyperstyle_global_directions/global_direction.py
CHANGED

@@ -7,12 +7,13 @@ from hyperstyle_global_directions.stylespace_utils import features_channels_to_s
 
 class StyleCLIPGlobalDirection:
 
-    def __init__(self, delta_i_c, s_std, text_prompts_templates):
+    def __init__(self, delta_i_c, s_std, text_prompts_templates, device='cuda'):
         super(StyleCLIPGlobalDirection, self).__init__()
+        self.device = device
         self.delta_i_c = delta_i_c
         self.s_std = s_std
         self.text_prompts_templates = text_prompts_templates
-        self.clip_model, _ = clip.load("ViT-B/32", device="cuda")
+        self.clip_model, _ = clip.load("ViT-B/32", device=device)
 
     def get_delta_s(self, neutral_text, target_text, beta):
         delta_i = self.get_delta_i([target_text, neutral_text]).float()

@@ -23,7 +24,7 @@ class StyleCLIPGlobalDirection:
         max_channel_value = torch.abs(delta_s).max()
         if max_channel_value > 0:
             delta_s /= max_channel_value
-        direction = features_channels_to_s(delta_s, self.s_std)
+        direction = features_channels_to_s(delta_s, self.s_std, self.device)
         return direction
 
     def get_delta_i(self, text_prompts):

@@ -37,11 +38,11 @@ class StyleCLIPGlobalDirection:
         text_features_list = []
         for text_prompt in text_prompts:
             formatted_text_prompts = [template.format(text_prompt) for template in self.text_prompts_templates]  # format with class
-            formatted_text_prompts = clip.tokenize(formatted_text_prompts).cuda()  # tokenize
+            formatted_text_prompts = clip.tokenize(formatted_text_prompts).to(self.device)  # tokenize
             text_embeddings = self.clip_model.encode_text(formatted_text_prompts)  # embed with text encoder
             text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
             text_embedding = text_embeddings.mean(dim=0)
             text_embedding /= text_embedding.norm()
             text_features_list.append(text_embedding)
-        text_features = torch.stack(text_features_list, dim=1).cuda()
+        text_features = torch.stack(text_features_list, dim=1).to(self.device)
         return text_features.t()
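
Passing device through to clip.load is what actually frees this class from the GPU: CLIP loads its weights onto the requested device, and the tokenized prompts follow it via .to(self.device). A quick standalone check of that behavior (assumes the clip package from openai/CLIP is installed):

import clip
import torch

device = "cpu"
model, _ = clip.load("ViT-B/32", device=device)  # weights land on CPU
tokens = clip.tokenize(["a photo of a smiling face"]).to(device)
with torch.no_grad():
    features = model.encode_text(tokens)
assert features.device.type == "cpu"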
hyperstyle_global_directions/stylespace_utils.py
CHANGED

@@ -5,7 +5,7 @@ STYLESPACE_DIMENSIONS = [512 for _ in range(15)] + [256, 256, 256] + [128, 128,
 TORGB_INDICES = list(range(1, len(STYLESPACE_DIMENSIONS), 3))
 STYLESPACE_INDICES_WITHOUT_TORGB = [i for i in range(len(STYLESPACE_DIMENSIONS)) if i not in TORGB_INDICES][:11]
 
-def features_channels_to_s(s_without_torgb, s_std):
+def features_channels_to_s(s_without_torgb, s_std, device='cuda'):
     s = []
     start_index_features = 0
     for c in range(len(STYLESPACE_DIMENSIONS)):

@@ -14,7 +14,7 @@ def features_channels_to_s(s_without_torgb, s_std):
             s_i = s_without_torgb[start_index_features:end_index_features] * s_std[c]
             start_index_features = end_index_features
         else:
-            s_i = torch.zeros(STYLESPACE_DIMENSIONS[c]).cuda()
+            s_i = torch.zeros(STYLESPACE_DIMENSIONS[c]).to(device)
         s_i = s_i.view(1, 1, -1, 1, 1)
         s.append(s_i)
     return s
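
One small aside on the pattern kept here: torch.zeros(n).to(device) allocates on CPU and then copies. Passing the device at construction is behaviorally identical and skips the staging copy; a sketch of the equivalent form:

import torch

# Equivalent to torch.zeros(STYLESPACE_DIMENSIONS[c]).to(device),
# but allocates directly on the target device with no intermediate CPU tensor.
s_i = torch.zeros(512, device="cpu")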
models/hyperstyle/hypernetworks/hypernetwork.py
CHANGED

@@ -34,7 +34,7 @@ class SharedWeightsHyperNetResNet(Module):
         self.layers_to_tune = [int(l) for l in opts.layers_to_tune.split(',')]
 
         self.shared_layers = [0, 2, 3, 5, 6, 8, 9, 11, 12]
-        self.shared_weight_hypernet = SharedWeightsHypernet(in_size=512, out_size=512, mode=None)
+        self.shared_weight_hypernet = SharedWeightsHypernet(in_size=512, out_size=512, mode=None, device=opts.device)
 
         self.refinement_blocks = nn.ModuleList()
         self.n_outputs = opts.n_hypernet_outputs
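
The hypernet now expects a device attribute on its options object. Illustrative only (the values below are placeholders, not taken from this repo's configs):

from argparse import Namespace

# Hypothetical minimal opts for SharedWeightsHyperNetResNet after this commit;
# device is the newly required field.
opts = Namespace(layers_to_tune="0,1,2,3,4", n_hypernet_outputs=26, device="cpu")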
models/hyperstyle/hypernetworks/shared_weights_hypernet.py
CHANGED

@@ -5,8 +5,9 @@ from torch.nn.parameter import Parameter
 
 class SharedWeightsHypernet(nn.Module):
 
-    def __init__(self, f_size=3, z_dim=512, out_size=512, in_size=512, mode=None):
+    def __init__(self, f_size=3, z_dim=512, out_size=512, in_size=512, mode=None, device='cuda'):
         super(SharedWeightsHypernet, self).__init__()
+        self.device = device
         self.mode = mode
         self.z_dim = z_dim
         self.f_size = f_size

@@ -15,11 +16,11 @@ class SharedWeightsHypernet(nn.Module):
         self.out_size = out_size
         self.in_size = in_size
 
-        self.w1 = Parameter(torch.fmod(torch.randn((self.z_dim, self.out_size * self.f_size * self.f_size)).cuda() / 40, 2))
-        self.b1 = Parameter(torch.fmod(torch.randn((self.out_size * self.f_size * self.f_size)).cuda() / 40, 2))
+        self.w1 = Parameter(torch.fmod(torch.randn((self.z_dim, self.out_size * self.f_size * self.f_size)).to(self.device) / 40, 2))
+        self.b1 = Parameter(torch.fmod(torch.randn((self.out_size * self.f_size * self.f_size)).to(self.device) / 40, 2))
 
-        self.w2 = Parameter(torch.fmod(torch.randn((self.z_dim, self.in_size * self.z_dim)).cuda() / 40, 2))
-        self.b2 = Parameter(torch.fmod(torch.randn((self.in_size * self.z_dim)).cuda() / 40, 2))
+        self.w2 = Parameter(torch.fmod(torch.randn((self.z_dim, self.in_size * self.z_dim)).to(self.device) / 40, 2))
+        self.b2 = Parameter(torch.fmod(torch.randn((self.in_size * self.z_dim)).to(self.device) / 40, 2))
 
     def forward(self, z):
         batch_size = z.shape[0]
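
Moving the randn tensor to the device before wrapping it in Parameter creates the parameters directly on that device instead of migrating them later with a module-wide .to(). A standalone check of the same initialization scheme (bounded Gaussian noise via torch.fmod, scaled by 1/40):

import torch
from torch.nn.parameter import Parameter

device = "cpu"  # GPU optional
w1 = Parameter(torch.fmod(torch.randn((512, 512 * 3 * 3)).to(device) / 40, 2))
assert w1.device.type == "cpu"
assert w1.abs().max() < 2  # fmod(x, 2) bounds every entry to (-2, 2)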
models/hyperstyle/hyperstyle.py
CHANGED

@@ -145,7 +145,7 @@ class HyperStyle(nn.Module):
         w_net = pSp(opts_w_encoder)
         w_net = w_net.encoder
         w_net.eval()
-        w_net.cuda()
+        w_net.to(self.opts.device)
         return w_net
 
     def __get_initial_inversion(self, x, resize=True):
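
Last, the PyTorch semantics that make the bare statement above valid: for nn.Module (unlike tensors) .to() moves parameters in place and returns the module itself, so w_net.to(self.opts.device) replaces w_net.cuda() one-for-one even with the return value discarded. A minimal demonstration:

import torch
import torch.nn as nn

net = nn.Linear(4, 4)
net.to("cpu")  # in-place for modules; the return value may be ignored
x = torch.randn(1, 4)
assert net(x).device.type == "cpu"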