ludusc commited on
Commit
e79558d
1 Parent(s): 895260d

third page

Browse files
backend/disentangle_concepts.py CHANGED
@@ -7,7 +7,7 @@ from PIL import Image
7
 
8
 
9
 
10
- def generate_composite_images(model, z, decision_boundaries, lambdas, latent_space='W'):
11
  """
12
  The regenerate_images function takes a model, z, and decision_boundary as input. It then
13
  constructs an inverse rotation/translation matrix and passes it to the generator. The generator
@@ -33,9 +33,19 @@ def generate_composite_images(model, z, decision_boundaries, lambdas, latent_spa
33
  repetitions = 16
34
  z_0 = z
35
 
36
- for decision_boundary, lmbd in zip(decision_boundaries, lambdas):
37
- decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)
38
- z_0 = z_0 + int(lmbd) * decision_boundary
 
 
 
 
 
 
 
 
 
 
39
 
40
 
41
  if latent_space == 'Z':
 
7
 
8
 
9
 
10
+ def generate_composite_images(model, z, decision_boundaries, lambdas, latent_space='W', negative_colors=None):
11
  """
12
  The regenerate_images function takes a model, z, and decision_boundary as input. It then
13
  constructs an inverse rotation/translation matrix and passes it to the generator. The generator
 
33
  repetitions = 16
34
  z_0 = z
35
 
36
+ if negative_colors:
37
+ for decision_boundary, lmbd, neg_boundary in zip(decision_boundaries, lambdas, negative_colors):
38
+ decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)
39
+ if neg_boundary != 'None':
40
+ neg_boundary = torch.from_numpy(neg_boundary.copy()).to(device)
41
+
42
+ z_0 = z_0 + int(lmbd) * (decision_boundary - (neg_boundary.T * decision_boundary) * neg_boundary)
43
+ else:
44
+ z_0 = z_0 + int(lmbd) * decision_boundary
45
+ else:
46
+ for decision_boundary, lmbd in zip(decision_boundaries, lambdas):
47
+ decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)
48
+ z_0 = z_0 + int(lmbd) * decision_boundary
49
 
50
 
51
  if latent_space == 'Z':
pages/1_Textiles_Disentanglement.py CHANGED
@@ -139,6 +139,11 @@ with input_col_4:
139
  with st.form('text_form_2'):
140
  st.write('Use best options')
141
  best = st.selectbox('Option:', tuple([True, False]), index=0)
 
 
 
 
 
142
  if st.session_state.best is False:
143
  st.write('Options for StyleSpace (not available for Saturation and Value)')
144
  sign = st.selectbox('Sign option:', tuple([True, False]), index=1)
 
139
  with st.form('text_form_2'):
140
  st.write('Use best options')
141
  best = st.selectbox('Option:', tuple([True, False]), index=0)
142
+ sign = True
143
+ num_factors=10
144
+ cl_method='LR'
145
+ regularization=0.1
146
+ extremes=True
147
  if st.session_state.best is False:
148
  st.write('Options for StyleSpace (not available for Saturation and Value)')
149
  sign = st.selectbox('Sign option:', tuple([True, False]), index=1)
pages/3_Vectors_algebra.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pickle
3
+ import pandas as pd
4
+ import numpy as np
5
+ import random
6
+ import torch
7
+
8
+ from matplotlib.backends.backend_agg import RendererAgg
9
+
10
+ from backend.disentangle_concepts import *
11
+ import torch_utils
12
+ import dnnlib
13
+ import legacy
14
+
15
+ _lock = RendererAgg.lock
16
+
17
+
18
+ st.set_page_config(layout='wide')
19
+ BACKGROUND_COLOR = '#bcd0e7'
20
+ SECONDARY_COLOR = '#bce7db'
21
+
22
+
23
+ st.title('Disentanglement studies on the Textile Dataset')
24
+ st.markdown(
25
+ """
26
+ This is a demo of the Disentanglement studies on the [iMET Textiles Dataset](https://www.metmuseum.org/art/collection/search/85531).
27
+ """,
28
+ unsafe_allow_html=False,)
29
+
30
+ annotations_file = './data/textile_annotated_files/seeds0000-100000_S.pkl'
31
+ with open(annotations_file, 'rb') as f:
32
+ annotations = pickle.load(f)
33
+
34
+ concept_vectors = pd.read_csv('./data/stored_vectors/scores_colors_hsv.csv')
35
+ concept_vectors['vector'] = [np.array([float(xx) for xx in x]) for x in concept_vectors['vector'].str.split(', ')]
36
+ concept_vectors['score'] = concept_vectors['score'].astype(float)
37
+
38
+ concept_vectors['sign'] = [True if 'sign:True' in val else False for val in concept_vectors['kwargs']]
39
+ concept_vectors['extremes'] = [True if 'extremes method:True' in val else False for val in concept_vectors['kwargs']]
40
+ concept_vectors['regularization'] = [float(val.split(',')[1].strip('regularization: ')) if 'regularization:' in val else False for val in concept_vectors['kwargs']]
41
+ concept_vectors['cl_method'] = [val.split(',')[0].strip('classification method:') if 'classification method:' in val else False for val in concept_vectors['kwargs']]
42
+ concept_vectors['num_factors'] = [int(val.split(',')[1].strip('number of factors:')) if 'number of factors:' in val else False for val in concept_vectors['kwargs']]
43
+
44
+ concept_vectors = concept_vectors.sort_values('score', ascending=False).reset_index()
45
+
46
+ with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pkl') as f:
47
+ model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore
48
+
49
+ COLORS_LIST = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue', 'Blue', 'Purple', 'Pink', 'Saturation', 'Value']
50
+ COLORS_NEGATIVE = COLORS_LIST + ['None']
51
+
52
+ if 'image_id' not in st.session_state:
53
+ st.session_state.image_id = 52921
54
+ if 'colors' not in st.session_state:
55
+ st.session_state.colors = [COLORS_LIST[0]]
56
+ if 'non_colors' not in st.session_state:
57
+ st.session_state.non_colors = ['None']
58
+ if 'space_id' not in st.session_state:
59
+ st.session_state.space_id = 'W'
60
+ if 'color_lambda' not in st.session_state:
61
+ st.session_state.color_lambda = 7
62
+ if 'saturation_lambda' not in st.session_state:
63
+ st.session_state.saturation_lambda = 0
64
+ if 'value_lambda' not in st.session_state:
65
+ st.session_state.value_lambda = 0
66
+ if 'sign' not in st.session_state:
67
+ st.session_state.sign = False
68
+ if 'extremes' not in st.session_state:
69
+ st.session_state.extremes = False
70
+ if 'regularization' not in st.session_state:
71
+ st.session_state.regularization = False
72
+ if 'cl_method' not in st.session_state:
73
+ st.session_state.cl_method = False
74
+ if 'num_factors' not in st.session_state:
75
+ st.session_state.num_factors = False
76
+ if 'best' not in st.session_state:
77
+ st.session_state.best = True
78
+
79
+ # def on_change_random_input():
80
+ # st.session_state.image_id = st.session_state.image_id
81
+
82
+ # ----------------------------- INPUT ----------------------------------
83
+ epsilon_container = st.empty()
84
+ st.header('Image Manipulation with Vector Algebra')
85
+
86
+ header_col_1, header_col_2, header_col_3, header_col_4 = st.columns([1,2,2,1])
87
+ input_col_1, output_col_2, output_col_3, input_col_4 = st.columns([1,2,2,1])
88
+
89
+ # --------------------------- INPUT column 1 ---------------------------
90
+ with input_col_1:
91
+ with st.form('image_form'):
92
+
93
+ # image_id = st.number_input('Image ID: ', format='%d', step=1)
94
+ st.write('**Choose or generate a random image to test the disentanglement**')
95
+ chosen_image_id_input = st.empty()
96
+ image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
97
+
98
+ choose_image_button = st.form_submit_button('Choose the defined image')
99
+ random_id = st.form_submit_button('Generate a random image')
100
+
101
+ if random_id:
102
+ image_id = random.randint(0, 100000)
103
+ st.session_state.image_id = image_id
104
+ chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
105
+
106
+ if choose_image_button:
107
+ image_id = int(image_id)
108
+ st.session_state.image_id = image_id
109
+
110
+ with header_col_1:
111
+ st.write('Input image selection')
112
+
113
+ if st.session_state.space_id == 'Z':
114
+ original_image_vec = annotations['z_vectors'][st.session_state.image_id]
115
+ else:
116
+ original_image_vec = annotations['w_vectors'][st.session_state.image_id]
117
+
118
+ img = generate_original_image(original_image_vec, model, latent_space=st.session_state.space_id)
119
+
120
+ with output_col_2:
121
+ st.image(img)
122
+
123
+ with header_col_2:
124
+ st.write('Original image')
125
+
126
+ with input_col_4:
127
+ with st.form('text_form_1'):
128
+
129
+ st.write('**Positive colors to vary (including Saturation and Value)**')
130
+ colors = st.multiselect('Color:', tuple(COLORS_LIST), default=[COLORS_LIST[0], COLORS_LIST[1]])
131
+ colors_button = st.form_submit_button('Choose the defined colors')
132
+
133
+ st.session_state.image_id = image_id
134
+ st.session_state.colors = colors
135
+ st.session_state.color_lambda = [5]*len(colors)
136
+ st.session_state.non_colors = ['None']*len(colors)
137
+
138
+ lambdas = []
139
+ negative_cols = []
140
+ for color in colors:
141
+ st.write(color)
142
+ st.write('**Set range of change**')
143
+ chosen_color_lambda_input = st.empty()
144
+ color_lambda = chosen_color_lambda_input.number_input('Lambda:', min_value=-100, step=1, value=5, key=color+'_number')
145
+ lambdas.append(color_lambda)
146
+
147
+ st.write('**Set dimensions of change to not consider**')
148
+ chosen_color_negative_input = st.empty()
149
+ color_negative = chosen_color_negative_input.selectbox('Color:', tuple(COLORS_NEGATIVE), index=len(COLORS_NEGATIVE)-1, key=color+'_noncolor')
150
+ negative_cols.append(color_negative)
151
+
152
+ lambdas_button = st.form_submit_button('Submit options')
153
+ if lambdas_button:
154
+ st.session_state.color_lambda = lambdas
155
+ st.session_state.non_colors = negative_cols
156
+
157
+
158
+ # print(st.session_state.colors)
159
+ # print(st.session_state.color_lambda)
160
+ # print(st.session_state.non_colors)
161
+
162
+ # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
163
+
164
+ with header_col_3:
165
+ separation_vectors = []
166
+ for col in st.session_state.colors:
167
+ separation_vector, score_1 = concept_vectors[concept_vectors['color'] == col].reset_index().loc[0, ['vector', 'score']]
168
+ separation_vectors.append(separation_vector)
169
+
170
+ negative_separation_vectors = []
171
+ for non_col in st.session_state.non_colors:
172
+ if non_col != 'None':
173
+ negative_separation_vector, score_2 = concept_vectors[concept_vectors['color'] == non_col].reset_index().loc[0, ['vector', 'score']]
174
+ negative_separation_vectors.append(negative_separation_vector)
175
+ else:
176
+ negative_separation_vectors.append('None')
177
+ ## n1 − (n1T n2)n2
178
+ # print(negative_separation_vectors, separation_vectors)
179
+ st.write(f'Output Image, with positive {str(st.session_state.colors)}, and negative {str(st.session_state.non_colors)}')
180
+
181
+ # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
182
+
183
+
184
+ with output_col_3:
185
+ image_updated = generate_composite_images(model, original_image_vec, separation_vectors, lambdas=st.session_state.color_lambda, negative_colors=negative_separation_vectors)
186
+ st.image(image_updated)