Commit 6fa3e0e (parent: 838d7c7), committed by ethanNeuralImage

fix GPU usage to be optional

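The hunks below replace hard-coded .cuda() calls with .to(args.device) / .to(self.device) and thread a device argument through the affected constructors, so the pipeline can also run without a GPU. The change to parse_args itself is not part of the hunks shown here; a minimal sketch of how such a device option could be exposed (hypothetical flag name and default, not taken from the commit):

import argparse
import torch

def parse_args(args_list=None):
    parser = argparse.ArgumentParser()
    # ... existing options omitted ...
    # Hypothetical flag: prefer the GPU when available, otherwise fall back to CPU.
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu')
    return parser.parse_args(args_list)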
hyperstyle_global_directions/edit.py CHANGED
@@ -41,18 +41,18 @@ def parse_args(args_list=None):


def load_direction_calculator(args):
-    delta_i_c = torch.from_numpy(np.load(args.delta_i_c)).float().cuda()
+    delta_i_c = torch.from_numpy(np.load(args.delta_i_c)).float().to(args.device)
    with open(args.s_statistics, "rb") as channels_statistics:
        _, s_std = pickle.load(channels_statistics)
-    s_std = [torch.from_numpy(s_i).float().cuda() for s_i in s_std]
+    s_std = [torch.from_numpy(s_i).float().to(args.device) for s_i in s_std]
    with open(args.text_prompt_templates, "r") as templates:
        text_prompt_templates = templates.readlines()
-    global_direction_calculator = StyleCLIPGlobalDirection(delta_i_c, s_std, text_prompt_templates)
+    global_direction_calculator = StyleCLIPGlobalDirection(delta_i_c, s_std, text_prompt_templates, args.device)
    return global_direction_calculator


def load_stylegan_generator(args):
-    stylegan_model = Generator(args.stylegan_size, 512, 8, channel_multiplier=2).cuda()
+    stylegan_model = Generator(args.stylegan_size, 512, 8, channel_multiplier=2).to(args.device)
    checkpoint = torch.load(args.stylegan_weights)
    stylegan_model.load_state_dict(checkpoint['g_ema'])
    return stylegan_model
@@ -72,7 +72,7 @@ def run():
        if args.n_images is not None and idx >= args.n_images:
            break
        weight_deltas = np.load(os.path.join(args.weight_deltas_path, image_name.split(".")[0] + ".npy"), allow_pickle=True)
-        weight_deltas = [torch.from_numpy(w).cuda() if w is not None else None for w in weight_deltas]
+        weight_deltas = [torch.from_numpy(w).to(args.device) if w is not None else None for w in weight_deltas]
        latent = torch.from_numpy(latent)
        results, results_latent, source_img = edit_image(image_name, latent, stylegan_model, global_direction_calculator, args, weight_deltas)
        torchvision.utils.save_image(results, f"{args.output_path}/{image_name.split('.')[0]}.jpg",
@@ -80,7 +80,7 @@ def run():


def edit_image(image_name, latent, stylegan_model, global_direction_calculator, args, weight_deltas=None):
-    latent_code = latent.cuda()
+    latent_code = latent.to(args.device)
    truncation = 1
    mean_latent = None
    input_is_latent = True
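One practical caveat when args.device is 'cpu': torch.load(args.stylegan_weights) will try to restore CUDA-saved tensors onto a GPU unless a map_location is supplied. A hedged sketch of a CPU-safe load (not part of this commit; the checkpoint path is a placeholder):

import torch
from argparse import Namespace

# Hypothetical args for illustration only.
args = Namespace(device='cpu', stylegan_weights='pretrained_models/stylegan2-ffhq.pt')

# map_location remaps storages that were saved from a GPU onto the requested device.
checkpoint = torch.load(args.stylegan_weights, map_location=args.device)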
hyperstyle_global_directions/global_direction.py CHANGED
@@ -7,12 +7,13 @@ from hyperstyle_global_directions.stylespace_utils import features_channels_to_s

class StyleCLIPGlobalDirection:

-    def __init__(self, delta_i_c, s_std, text_prompts_templates):
+    def __init__(self, delta_i_c, s_std, text_prompts_templates, device='cuda'):
        super(StyleCLIPGlobalDirection, self).__init__()
+        self.device = device
        self.delta_i_c = delta_i_c
        self.s_std = s_std
        self.text_prompts_templates = text_prompts_templates
-        self.clip_model, _ = clip.load("ViT-B/32", device="cuda")
+        self.clip_model, _ = clip.load("ViT-B/32", device=device)

    def get_delta_s(self, neutral_text, target_text, beta):
        delta_i = self.get_delta_i([target_text, neutral_text]).float()
@@ -23,7 +24,7 @@ class StyleCLIPGlobalDirection:
        max_channel_value = torch.abs(delta_s).max()
        if max_channel_value > 0:
            delta_s /= max_channel_value
-        direction = features_channels_to_s(delta_s, self.s_std)
+        direction = features_channels_to_s(delta_s, self.s_std, self.device)
        return direction

    def get_delta_i(self, text_prompts):
@@ -37,11 +38,11 @@ class StyleCLIPGlobalDirection:
        text_features_list = []
        for text_prompt in text_prompts:
            formatted_text_prompts = [template.format(text_prompt) for template in self.text_prompts_templates]  # format with class
-            formatted_text_prompts = clip.tokenize(formatted_text_prompts).cuda()  # tokenize
+            formatted_text_prompts = clip.tokenize(formatted_text_prompts).to(self.device)  # tokenize
            text_embeddings = self.clip_model.encode_text(formatted_text_prompts)  # embed with text encoder
            text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
            text_embedding = text_embeddings.mean(dim=0)
            text_embedding /= text_embedding.norm()
            text_features_list.append(text_embedding)
-        text_features = torch.stack(text_features_list, dim=1).cuda()
+        text_features = torch.stack(text_features_list, dim=1).to(self.device)
        return text_features.t()
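clip.load already accepts a device string, so passing the new device argument through keeps the CLIP text encoder on the same device as the rest of the pipeline. A minimal usage sketch of the updated constructor (file paths are placeholders, not from the repository):

import pickle
import numpy as np
import torch
from hyperstyle_global_directions.global_direction import StyleCLIPGlobalDirection

device = 'cpu'  # or 'cuda'
delta_i_c = torch.from_numpy(np.load('delta_i_c.npy')).float().to(device)  # placeholder path
with open('s_stats', 'rb') as f:                                           # placeholder path
    _, s_std = pickle.load(f)
s_std = [torch.from_numpy(s_i).float().to(device) for s_i in s_std]
with open('templates.txt', 'r') as f:                                      # placeholder path
    templates = f.readlines()
calculator = StyleCLIPGlobalDirection(delta_i_c, s_std, templates, device)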
hyperstyle_global_directions/stylespace_utils.py CHANGED
@@ -5,7 +5,7 @@ STYLESPACE_DIMENSIONS = [512 for _ in range(15)] + [256, 256, 256] + [128, 128,
TORGB_INDICES = list(range(1, len(STYLESPACE_DIMENSIONS), 3))
STYLESPACE_INDICES_WITHOUT_TORGB = [i for i in range(len(STYLESPACE_DIMENSIONS)) if i not in TORGB_INDICES][:11]

-def features_channels_to_s(s_without_torgb, s_std):
+def features_channels_to_s(s_without_torgb, s_std, device='cuda'):
    s = []
    start_index_features = 0
    for c in range(len(STYLESPACE_DIMENSIONS)):
@@ -14,7 +14,7 @@ def features_channels_to_s(s_without_torgb, s_std):
            s_i = s_without_torgb[start_index_features:end_index_features] * s_std[c]
            start_index_features = end_index_features
        else:
-            s_i = torch.zeros(STYLESPACE_DIMENSIONS[c]).cuda()
+            s_i = torch.zeros(STYLESPACE_DIMENSIONS[c]).to(device)
        s_i = s_i.view(1, 1, -1, 1, 1)
        s.append(s_i)
    return s
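Since torch.zeros accepts a device keyword, the zero tensor can also be allocated directly on the target device instead of being created on CPU and then copied; behaviorally this matches the committed .to(device) version. A small sketch of that alternative:

import torch

dim = 512        # example channel count for one style layer
device = 'cpu'   # or 'cuda'

# Allocate on the target device up front rather than moving after creation.
s_i = torch.zeros(dim, device=device)
s_i = s_i.view(1, 1, -1, 1, 1)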
models/hyperstyle/hypernetworks/hypernetwork.py CHANGED
@@ -34,7 +34,7 @@ class SharedWeightsHyperNetResNet(Module):
        self.layers_to_tune = [int(l) for l in opts.layers_to_tune.split(',')]

        self.shared_layers = [0, 2, 3, 5, 6, 8, 9, 11, 12]
-        self.shared_weight_hypernet = SharedWeightsHypernet(in_size=512, out_size=512, mode=None)
+        self.shared_weight_hypernet = SharedWeightsHypernet(in_size=512, out_size=512, mode=None, device=opts.device)

        self.refinement_blocks = nn.ModuleList()
        self.n_outputs = opts.n_hypernet_outputs
models/hyperstyle/hypernetworks/shared_weights_hypernet.py CHANGED
@@ -5,8 +5,9 @@ from torch.nn.parameter import Parameter

class SharedWeightsHypernet(nn.Module):

-    def __init__(self, f_size=3, z_dim=512, out_size=512, in_size=512, mode=None):
+    def __init__(self, f_size=3, z_dim=512, out_size=512, in_size=512, mode=None, device='cuda'):
        super(SharedWeightsHypernet, self).__init__()
+        self.device = device
        self.mode = mode
        self.z_dim = z_dim
        self.f_size = f_size
@@ -15,11 +16,11 @@ class SharedWeightsHypernet(nn.Module):
        self.out_size = out_size
        self.in_size = in_size

-        self.w1 = Parameter(torch.fmod(torch.randn((self.z_dim, self.out_size * self.f_size * self.f_size)).cuda() / 40, 2))
-        self.b1 = Parameter(torch.fmod(torch.randn((self.out_size * self.f_size * self.f_size)).cuda() / 40, 2))
+        self.w1 = Parameter(torch.fmod(torch.randn((self.z_dim, self.out_size * self.f_size * self.f_size)).to(self.device) / 40, 2))
+        self.b1 = Parameter(torch.fmod(torch.randn((self.out_size * self.f_size * self.f_size)).to(self.device) / 40, 2))

-        self.w2 = Parameter(torch.fmod(torch.randn((self.z_dim, self.in_size * self.z_dim)).cuda() / 40, 2))
-        self.b2 = Parameter(torch.fmod(torch.randn((self.in_size * self.z_dim)).cuda() / 40, 2))
+        self.w2 = Parameter(torch.fmod(torch.randn((self.z_dim, self.in_size * self.z_dim)).to(self.device) / 40, 2))
+        self.b2 = Parameter(torch.fmod(torch.randn((self.in_size * self.z_dim)).to(self.device) / 40, 2))

    def forward(self, z):
        batch_size = z.shape[0]
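With the device argument, the w1/b1/w2/b2 parameters are created directly on the requested device, so constructing the hypernet no longer requires CUDA. A short usage sketch under that assumption:

import torch
from models.hyperstyle.hypernetworks.shared_weights_hypernet import SharedWeightsHypernet

device = 'cuda' if torch.cuda.is_available() else 'cpu'
hypernet = SharedWeightsHypernet(in_size=512, out_size=512, mode=None, device=device)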
models/hyperstyle/hyperstyle.py CHANGED
@@ -145,7 +145,7 @@ class HyperStyle(nn.Module):
        w_net = pSp(opts_w_encoder)
        w_net = w_net.encoder
        w_net.eval()
-        w_net.cuda()
+        w_net.to(self.opts.device)
        return w_net

    def __get_initial_inversion(self, x, resize=True):
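After these edits, placement is driven by a single opts.device / args.device value. A generic sanity check, not specific to this repository, for confirming that a loaded nn.Module really ended up on the intended device:

import torch
from torch import nn

def assert_on_device(module: nn.Module, device: str) -> None:
    # Every parameter and buffer should live on the requested device.
    target = torch.device(device)
    for name, tensor in list(module.named_parameters()) + list(module.named_buffers()):
        assert tensor.device.type == target.type, f"{name} is on {tensor.device}, expected {target}"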