ludusc commited on
Commit
a5df893
1 Parent(s): ae2da92
DisentanglementBase.py CHANGED
@@ -175,8 +175,8 @@ class DisentanglementBase:
175
  return X
176
 
177
  def get_train_val(self, extremes=False):
178
- X = self.get_encoded_latent()
179
  y = np.array(self.df[self.variable].values)
 
180
  if self.categorical:
181
  bins = [(x-1) * 360 / (len(self.colors_list) - 1) if x != 1
182
  else 1 for x in range(len(self.colors_list) + 1)]
@@ -443,14 +443,14 @@ class DisentanglementBase:
443
  axs[i].imshow(image)
444
  axs[i].set_title(np.round(lambd, 2))
445
  plt.tight_layout()
446
- plt.savefig(join(self.repo_folder, 'figures', 'examples', name+'.jpg'))
447
  plt.close()
448
 
449
  if save_separately:
450
  for i, (image, lambd) in enumerate(zip(images, lambdas)):
451
  plt.imshow(image)
452
  plt.tight_layout()
453
- plt.savefig(join(self.repo_folder, 'figures', 'examples', name + '_' + str(lambd) + '.jpg'))
454
  plt.close()
455
 
456
  return images, lambdas
@@ -556,11 +556,11 @@ def continous_experiment(name, var, repo_folder, model, annotations, df, space,
556
 
557
  def main():
558
  repo_folder = '.'
559
- annotations_file = join(repo_folder, 'data/textile_annotated_files/seeds0000-100000_S.pkl')
560
  with open(annotations_file, 'rb') as f:
561
  annotations = pickle.load(f)
562
 
563
- df_file = join(repo_folder, 'data/textile_annotated_files/top_three_colours.csv')
564
  df = pd.read_csv(df_file).fillna('#000000')
565
 
566
  model_file = join(repo_folder, 'data/textile_model_files/network-snapshot-005000.pkl')
@@ -571,7 +571,7 @@ def main():
571
  'Kelly Green', 'Green Blue Seafoam', 'Cyan Blue',
572
  'Warm Blue', 'Indigo', 'Purple Magenta', 'Magenta Pink']
573
  colors_list = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue',
574
- 'Blue', 'Purple', 'Pink']
575
 
576
  scores = []
577
  kwargs = {'CL method':['LR', 'SVM'], 'C':[0.1, 1], 'sign':[True, False],
@@ -584,12 +584,12 @@ def main():
584
  if specific_examples is not None:
585
  disentanglemnet_exp = DisentanglementBase(repo_folder, model, annotations, df, space='w', colors_list=colors_list, compute_s=False)
586
 
587
- separation_vectors = disentanglemnet_exp.StyleSpace_separation_vector(sign=True, num_factors=10, cutout=None)
588
- # separation_vectors = disentanglemnet_exp.InterFaceGAN_separation_vector(method='LR', C=0.1)
589
  for specific_example in specific_examples:
590
  seed = specific_example
591
  for i, color in enumerate(colors_list):
592
- disentanglemnet_exp.generate_changes(seed, separation_vectors[i], min_epsilon=-9, max_epsilon=9, savefig=True, save_separately=True, feature=color, method='StyleSpace' + '_' + str(True) + '_' + str(10) + '_' + str(None))
593
 
594
  return
595
 
 
175
  return X
176
 
177
  def get_train_val(self, extremes=False):
 
178
  y = np.array(self.df[self.variable].values)
179
+ X = self.get_encoded_latent()[:y.shape[0], :]
180
  if self.categorical:
181
  bins = [(x-1) * 360 / (len(self.colors_list) - 1) if x != 1
182
  else 1 for x in range(len(self.colors_list) + 1)]
 
443
  axs[i].imshow(image)
444
  axs[i].set_title(np.round(lambd, 2))
445
  plt.tight_layout()
446
+ plt.savefig(join(self.repo_folder, 'figures', 'examples_new', name+'.jpg'))
447
  plt.close()
448
 
449
  if save_separately:
450
  for i, (image, lambd) in enumerate(zip(images, lambdas)):
451
  plt.imshow(image)
452
  plt.tight_layout()
453
+ plt.savefig(join(self.repo_folder, 'figures', 'examples_new', name + '_' + str(lambd) + '.jpg'))
454
  plt.close()
455
 
456
  return images, lambdas
 
556
 
557
  def main():
558
  repo_folder = '.'
559
+ annotations_file = join(repo_folder, 'data/textile_annotated_files/seeds0000-1000000.pkl')
560
  with open(annotations_file, 'rb') as f:
561
  annotations = pickle.load(f)
562
 
563
+ df_file = join(repo_folder, 'data/textile_annotated_files/top_three_colours_00000-730003.csv')
564
  df = pd.read_csv(df_file).fillna('#000000')
565
 
566
  model_file = join(repo_folder, 'data/textile_model_files/network-snapshot-005000.pkl')
 
571
  'Kelly Green', 'Green Blue Seafoam', 'Cyan Blue',
572
  'Warm Blue', 'Indigo', 'Purple Magenta', 'Magenta Pink']
573
  colors_list = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue',
574
+ 'Blue', 'Violet', 'Pink']
575
 
576
  scores = []
577
  kwargs = {'CL method':['LR', 'SVM'], 'C':[0.1, 1], 'sign':[True, False],
 
584
  if specific_examples is not None:
585
  disentanglemnet_exp = DisentanglementBase(repo_folder, model, annotations, df, space='w', colors_list=colors_list, compute_s=False)
586
 
587
+ # separation_vectors = disentanglemnet_exp.StyleSpace_separation_vector(sign=True, num_factors=10, cutout=None)
588
+ separation_vectors = disentanglemnet_exp.InterFaceGAN_separation_vector(method='LR', C=0.1)
589
  for specific_example in specific_examples:
590
  seed = specific_example
591
  for i, color in enumerate(colors_list):
592
+ disentanglemnet_exp.generate_changes(seed, separation_vectors[i], min_epsilon=-18, max_epsilon=18, savefig=True, save_separately=True, feature=color, method='InterFaceGAN' + '_' + str('LR') + '_' + str(0.1) + '_' + str(None))
593
 
594
  return
595
 
test_disentanglement.sh CHANGED
@@ -1,5 +1,5 @@
1
  #!/bin/bash
2
- #SBATCH --time=1-00:00:00
3
  #SBATCH --mem=32GB
4
  #SBATCH --gres gpu:1
5
 
 
1
  #!/bin/bash
2
+ #SBATCH --time=02:00:00
3
  #SBATCH --mem=32GB
4
  #SBATCH --gres gpu:1
5