linoyts HF staff commited on
Commit
1d4a57d
1 Parent(s): 4ea931b

Update clip_slider_pipeline.py

Browse files
Files changed (1) hide show
  1. clip_slider_pipeline.py +8 -4
clip_slider_pipeline.py CHANGED
@@ -14,7 +14,7 @@ class CLIPSlider:
14
  opposite: str = "",
15
  target_word_2nd: str = "",
16
  opposite_2nd: str = "",
17
- iterations: int = 300,
18
  ):
19
 
20
  self.device = device
@@ -32,7 +32,8 @@ class CLIPSlider:
32
 
33
  def find_latent_direction(self,
34
  target_word:str,
35
- opposite:str, num_iterations: int = None):
 
36
 
37
  # lets identify a latent direction by taking differences between opposites
38
  # target_word = "happy"
@@ -357,12 +358,15 @@ class T5SliderFlux(CLIPSlider):
357
 
358
  def find_latent_direction(self,
359
  target_word:str,
360
- opposite:str):
361
 
362
  # lets identify a latent direction by taking differences between opposites
363
  # target_word = "happy"
364
  # opposite = "sad"
365
-
 
 
 
366
 
367
  with torch.no_grad():
368
  positives = []
 
14
  opposite: str = "",
15
  target_word_2nd: str = "",
16
  opposite_2nd: str = "",
17
+
18
  ):
19
 
20
  self.device = device
 
32
 
33
  def find_latent_direction(self,
34
  target_word:str,
35
+ opposite:str,
36
+ num_iterations: int = None):
37
 
38
  # lets identify a latent direction by taking differences between opposites
39
  # target_word = "happy"
 
358
 
359
  def find_latent_direction(self,
360
  target_word:str,
361
+ opposite:str,num_iterations:int=300 ):
362
 
363
  # lets identify a latent direction by taking differences between opposites
364
  # target_word = "happy"
365
  # opposite = "sad"
366
+ if num_iterations is not None:
367
+ iterations = num_iterations
368
+ else:
369
+ iterations = self.iterations
370
 
371
  with torch.no_grad():
372
  positives = []