rinong commited on
Commit
9a29e97
1 Parent(s): 6f52ac4

Added tensor dim expansion for edit directions.

Browse files
Files changed (1) hide show
  1. styleclip/styleclip_global.py +14 -1
styleclip/styleclip_global.py CHANGED
@@ -89,6 +89,7 @@ imagenet_templates = [
89
  'a tattoo of the {}.',
90
  ]
91
 
 
92
  FFHQ_CODE_INDICES = [(0, 512), (512, 1024), (1024, 1536), (1536, 2048), (2560, 3072), (3072, 3584), (4096, 4608), (4608, 5120), (5632, 6144), (6144, 6656), (7168, 7680), (7680, 7936), (8192, 8448), (8448, 8576), (8704, 8832), (8832, 8896), (8960, 9024), (9024, 9056)] + \
93
  [(2048, 2560), (3584, 4096), (5120, 5632), (6656, 7168), (7936, 8192), (8576, 8704), (8896, 8960), (9056, 9088)]
94
 
@@ -107,6 +108,16 @@ def zeroshot_classifier(model, classnames, templates, device):
107
  zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
108
  return zeroshot_weights
109
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  def get_direction(neutral_class, target_class, beta, di, clip_model=None):
112
 
@@ -157,6 +168,8 @@ def style_dict_to_style_tensor(style_dict, reference_generator):
157
  def project_code_with_styleclip(source_latent, source_class, target_class, alpha, beta, reference_generator, di, clip_model=None):
158
  edit_direction = get_direction(source_class, target_class, beta, di, clip_model)
159
 
 
 
160
  source_s = style_dict_to_style_tensor(source_latent, reference_generator)
161
 
162
- return source_s + alpha * edit_direction
 
89
  'a tattoo of the {}.',
90
  ]
91
 
92
+ CONV_CODE_INDICES = [(0, 512), (1024, 1536), (1536, 2048), (2560, 3072), (3072, 3584), (4096, 4608), (4608, 5120), (5632, 6144), (6144, 6656), (7168, 7680), (7680, 7936), (8192, 8448), (8448, 8576), (8704, 8832), (8832, 8896), (8960, 9024), (9024, 9056)]
93
  FFHQ_CODE_INDICES = [(0, 512), (512, 1024), (1024, 1536), (1536, 2048), (2560, 3072), (3072, 3584), (4096, 4608), (4608, 5120), (5632, 6144), (6144, 6656), (7168, 7680), (7680, 7936), (8192, 8448), (8448, 8576), (8704, 8832), (8832, 8896), (8960, 9024), (9024, 9056)] + \
94
  [(2048, 2560), (3584, 4096), (5120, 5632), (6656, 7168), (7936, 8192), (8576, 8704), (8896, 8960), (9056, 9088)]
95
 
 
108
  zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
109
  return zeroshot_weights
110
 
111
+ def expand_to_full_dim(partial_tensor):
112
+ full_dim_tensor = torch.zeros(size=(1, 9088))
113
+
114
+ start_idx = 0
115
+ for conv_start, conv_end in CONV_CODE_INDICES:
116
+ length = conv_end - conv_start
117
+ full_dim_tensor[:, conv_start:conv_end] = partial_tensor[:, start_idx:start_idx + length]
118
+ start_idx += length
119
+
120
+ return full_dim_tensor
121
 
122
  def get_direction(neutral_class, target_class, beta, di, clip_model=None):
123
 
 
168
  def project_code_with_styleclip(source_latent, source_class, target_class, alpha, beta, reference_generator, di, clip_model=None):
169
  edit_direction = get_direction(source_class, target_class, beta, di, clip_model)
170
 
171
+ edit_full_dim = expand_to_full_dim(edit_direction)
172
+
173
  source_s = style_dict_to_style_tensor(source_latent, reference_generator)
174
 
175
+ return source_s + alpha * edit_full_dim