rinong commited on
Commit
f2ea589
1 Parent(s): 4e5af99

Fixed typing bugs in styleclip projection

Browse files
Files changed (1) hide show
  1. styleclip/styleclip_global.py +2 -1
styleclip/styleclip_global.py CHANGED
@@ -120,6 +120,7 @@ def get_direction(neutral_class, target_class, beta, di, clip_model=None):
120
 
121
  dt = class_weights[:, 1] - class_weights[:, 0]
122
  dt = dt / dt.norm()
 
123
  relevance = di @ dt
124
  mask = relevance.abs() > beta
125
  direction = relevance * mask
@@ -151,7 +152,7 @@ def style_dict_to_style_tensor(style_dict, reference_generator):
151
  return style_tensor
152
 
153
  def project_code_with_styleclip(source_latent, source_class, target_class, alpha, beta, reference_generator, di, clip_model=None):
154
- edit_direction = get_direction(source_class, target_class, beta)
155
 
156
  source_s = style_dict_to_style_tensor(source_latent, reference_generator)
157
 
 
120
 
121
  dt = class_weights[:, 1] - class_weights[:, 0]
122
  dt = dt / dt.norm()
123
+ dt = dt.type(type(di))
124
  relevance = di @ dt
125
  mask = relevance.abs() > beta
126
  direction = relevance * mask
 
152
  return style_tensor
153
 
154
  def project_code_with_styleclip(source_latent, source_class, target_class, alpha, beta, reference_generator, di, clip_model=None):
155
+ edit_direction = get_direction(source_class, target_class, beta, di, clip_model)
156
 
157
  source_s = style_dict_to_style_tensor(source_latent, reference_generator)
158