Spaces:
Runtime error
Runtime error
changes0
Browse files- DisentanglementBase.py +9 -9
- test_disentanglement.sh +1 -1
DisentanglementBase.py
CHANGED
@@ -175,8 +175,8 @@ class DisentanglementBase:
|
|
175 |
return X
|
176 |
|
177 |
def get_train_val(self, extremes=False):
|
178 |
-
X = self.get_encoded_latent()
|
179 |
y = np.array(self.df[self.variable].values)
|
|
|
180 |
if self.categorical:
|
181 |
bins = [(x-1) * 360 / (len(self.colors_list) - 1) if x != 1
|
182 |
else 1 for x in range(len(self.colors_list) + 1)]
|
@@ -443,14 +443,14 @@ class DisentanglementBase:
|
|
443 |
axs[i].imshow(image)
|
444 |
axs[i].set_title(np.round(lambd, 2))
|
445 |
plt.tight_layout()
|
446 |
-
plt.savefig(join(self.repo_folder, 'figures', '
|
447 |
plt.close()
|
448 |
|
449 |
if save_separately:
|
450 |
for i, (image, lambd) in enumerate(zip(images, lambdas)):
|
451 |
plt.imshow(image)
|
452 |
plt.tight_layout()
|
453 |
-
plt.savefig(join(self.repo_folder, 'figures', '
|
454 |
plt.close()
|
455 |
|
456 |
return images, lambdas
|
@@ -556,11 +556,11 @@ def continous_experiment(name, var, repo_folder, model, annotations, df, space,
|
|
556 |
|
557 |
def main():
|
558 |
repo_folder = '.'
|
559 |
-
annotations_file = join(repo_folder, 'data/textile_annotated_files/seeds0000-
|
560 |
with open(annotations_file, 'rb') as f:
|
561 |
annotations = pickle.load(f)
|
562 |
|
563 |
-
df_file = join(repo_folder, 'data/textile_annotated_files/
|
564 |
df = pd.read_csv(df_file).fillna('#000000')
|
565 |
|
566 |
model_file = join(repo_folder, 'data/textile_model_files/network-snapshot-005000.pkl')
|
@@ -571,7 +571,7 @@ def main():
|
|
571 |
'Kelly Green', 'Green Blue Seafoam', 'Cyan Blue',
|
572 |
'Warm Blue', 'Indigo', 'Purple Magenta', 'Magenta Pink']
|
573 |
colors_list = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue',
|
574 |
-
'Blue', '
|
575 |
|
576 |
scores = []
|
577 |
kwargs = {'CL method':['LR', 'SVM'], 'C':[0.1, 1], 'sign':[True, False],
|
@@ -584,12 +584,12 @@ def main():
|
|
584 |
if specific_examples is not None:
|
585 |
disentanglemnet_exp = DisentanglementBase(repo_folder, model, annotations, df, space='w', colors_list=colors_list, compute_s=False)
|
586 |
|
587 |
-
separation_vectors = disentanglemnet_exp.StyleSpace_separation_vector(sign=True, num_factors=10, cutout=None)
|
588 |
-
|
589 |
for specific_example in specific_examples:
|
590 |
seed = specific_example
|
591 |
for i, color in enumerate(colors_list):
|
592 |
-
disentanglemnet_exp.generate_changes(seed, separation_vectors[i], min_epsilon=-
|
593 |
|
594 |
return
|
595 |
|
|
|
175 |
return X
|
176 |
|
177 |
def get_train_val(self, extremes=False):
|
|
|
178 |
y = np.array(self.df[self.variable].values)
|
179 |
+
X = self.get_encoded_latent()[:y.shape[0], :]
|
180 |
if self.categorical:
|
181 |
bins = [(x-1) * 360 / (len(self.colors_list) - 1) if x != 1
|
182 |
else 1 for x in range(len(self.colors_list) + 1)]
|
|
|
443 |
axs[i].imshow(image)
|
444 |
axs[i].set_title(np.round(lambd, 2))
|
445 |
plt.tight_layout()
|
446 |
+
plt.savefig(join(self.repo_folder, 'figures', 'examples_new', name+'.jpg'))
|
447 |
plt.close()
|
448 |
|
449 |
if save_separately:
|
450 |
for i, (image, lambd) in enumerate(zip(images, lambdas)):
|
451 |
plt.imshow(image)
|
452 |
plt.tight_layout()
|
453 |
+
plt.savefig(join(self.repo_folder, 'figures', 'examples_new', name + '_' + str(lambd) + '.jpg'))
|
454 |
plt.close()
|
455 |
|
456 |
return images, lambdas
|
|
|
556 |
|
557 |
def main():
|
558 |
repo_folder = '.'
|
559 |
+
annotations_file = join(repo_folder, 'data/textile_annotated_files/seeds0000-1000000.pkl')
|
560 |
with open(annotations_file, 'rb') as f:
|
561 |
annotations = pickle.load(f)
|
562 |
|
563 |
+
df_file = join(repo_folder, 'data/textile_annotated_files/top_three_colours_00000-730003.csv')
|
564 |
df = pd.read_csv(df_file).fillna('#000000')
|
565 |
|
566 |
model_file = join(repo_folder, 'data/textile_model_files/network-snapshot-005000.pkl')
|
|
|
571 |
'Kelly Green', 'Green Blue Seafoam', 'Cyan Blue',
|
572 |
'Warm Blue', 'Indigo', 'Purple Magenta', 'Magenta Pink']
|
573 |
colors_list = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue',
|
574 |
+
'Blue', 'Violet', 'Pink']
|
575 |
|
576 |
scores = []
|
577 |
kwargs = {'CL method':['LR', 'SVM'], 'C':[0.1, 1], 'sign':[True, False],
|
|
|
584 |
if specific_examples is not None:
|
585 |
disentanglemnet_exp = DisentanglementBase(repo_folder, model, annotations, df, space='w', colors_list=colors_list, compute_s=False)
|
586 |
|
587 |
+
# separation_vectors = disentanglemnet_exp.StyleSpace_separation_vector(sign=True, num_factors=10, cutout=None)
|
588 |
+
separation_vectors = disentanglemnet_exp.InterFaceGAN_separation_vector(method='LR', C=0.1)
|
589 |
for specific_example in specific_examples:
|
590 |
seed = specific_example
|
591 |
for i, color in enumerate(colors_list):
|
592 |
+
disentanglemnet_exp.generate_changes(seed, separation_vectors[i], min_epsilon=-18, max_epsilon=18, savefig=True, save_separately=True, feature=color, method='InterFaceGAN' + '_' + str('LR') + '_' + str(0.1) + '_' + str(None))
|
593 |
|
594 |
return
|
595 |
|
test_disentanglement.sh
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
#!/bin/bash
|
2 |
-
#SBATCH --time=
|
3 |
#SBATCH --mem=32GB
|
4 |
#SBATCH --gres gpu:1
|
5 |
|
|
|
1 |
#!/bin/bash
|
2 |
+
#SBATCH --time=02:00:00
|
3 |
#SBATCH --mem=32GB
|
4 |
#SBATCH --gres gpu:1
|
5 |
|