ludusc commited on
Commit
78d8811
1 Parent(s): 1253aaa
Files changed (1) hide show
  1. pages/1_Textiles_Disentanglement.py +9 -10
pages/1_Textiles_Disentanglement.py CHANGED
@@ -35,7 +35,6 @@ concept_vectors = pd.read_csv('./data/stored_vectors/scores_colors_hsv.csv')
35
  concept_vectors['vector'] = [np.array([float(xx) for xx in x]) for x in concept_vectors['vector'].str.split(', ')]
36
  concept_vectors['score'] = concept_vectors['score'].astype(float)
37
  concept_vectors = concept_vectors.sort_values('score', ascending=False).reset_index()
38
- print(concept_vectors[['vector', 'score']], concept_vectors.loc[0, 'vector'], concept_vectors.loc[0, 'vector'].shape)
39
 
40
  with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pkl') as f:
41
  model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore
@@ -77,7 +76,7 @@ with input_col_1:
77
  st.session_state.image_id = int(image_id)
78
 
79
  with input_col_2:
80
- with st.form('text_form'):
81
 
82
  st.write('**Choose color to vary**')
83
  type_col = st.selectbox('Color:', tuple(COLORS_LIST))
@@ -88,7 +87,7 @@ with input_col_2:
88
  color_lambda = chosen_color_lambda_input.number_input('Lambda:', min_value=0, step=1, value=7)
89
  color_lambda_button = st.form_submit_button('Choose the defined lambda')
90
 
91
- if choose_text_button:
92
  st.session_state.concept_ids = type_col
93
  st.session_state.space_id = space_id
94
 
@@ -97,13 +96,13 @@ with input_col_3:
97
 
98
  st.write('**Saturation variation**')
99
  chosen_saturation_lambda_input = st.empty()
100
- saturation_lambda = chosen_saturation_lambda_input.number_input('Lambda:', min_value=0, step=1)
101
- saturation_lambda_button = st.form_submit_button('Choose the defined lambda')
102
 
103
  st.write('**Value variation**')
104
  chosen_value_lambda_input = st.empty()
105
- value_lambda = chosen_value_lambda_input.number_input('Lambda:', min_value=0, step=1)
106
- value_lambda_button = st.form_submit_button('Choose the defined lambda')
107
 
108
  # with input_col_4:
109
  # with st.form('Network specifics:'):
@@ -140,9 +139,9 @@ with header_col_1:
140
  st.write(f'Original image')
141
 
142
  with header_col_2:
143
- color_separation_vector, performance_color = concept_vectors[concept_vectors['color'] == st.session_state.concept_ids].loc[0, ['vector', 'score']]
144
- saturation_separation_vector, performance_saturation = concept_vectors[concept_vectors['color'] == 'Saturation'].loc[0, ['vector', 'score']]
145
- value_separation_vector, performance_value = concept_vectors[concept_vectors['color'] == 'Value'].loc[0, ['vector', 'score']]
146
  st.write(f'Change in {st.session_state.concept_ids} of {np.round(color_lambda, 2)}, in saturation of {np.round(saturation_lambda, 2)}, in value of {np.round(value_lambda, 2)}. - Performance color vector: {performance_color}, saturation vector: {performance_saturation}, value vector: {performance_value}')
147
 
148
  # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
 
35
  concept_vectors['vector'] = [np.array([float(xx) for xx in x]) for x in concept_vectors['vector'].str.split(', ')]
36
  concept_vectors['score'] = concept_vectors['score'].astype(float)
37
  concept_vectors = concept_vectors.sort_values('score', ascending=False).reset_index()
 
38
 
39
  with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pkl') as f:
40
  model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore
 
76
  st.session_state.image_id = int(image_id)
77
 
78
  with input_col_2:
79
+ with st.form('text_form_1'):
80
 
81
  st.write('**Choose color to vary**')
82
  type_col = st.selectbox('Color:', tuple(COLORS_LIST))
 
87
  color_lambda = chosen_color_lambda_input.number_input('Lambda:', min_value=0, step=1, value=7)
88
  color_lambda_button = st.form_submit_button('Choose the defined lambda')
89
 
90
+ if colors_button:
91
  st.session_state.concept_ids = type_col
92
  st.session_state.space_id = space_id
93
 
 
96
 
97
  st.write('**Saturation variation**')
98
  chosen_saturation_lambda_input = st.empty()
99
+ saturation_lambda = chosen_saturation_lambda_input.number_input('Lambda:', min_value=0, step=1, key=0)
100
+ saturation_lambda_button = st.form_submit_button('Choose the defined lambda for Saturation')
101
 
102
  st.write('**Value variation**')
103
  chosen_value_lambda_input = st.empty()
104
+ value_lambda = chosen_value_lambda_input.number_input('Lambda:', min_value=0, step=1, key=1)
105
+ value_lambda_button = st.form_submit_button('Choose the defined lambda for Value')
106
 
107
  # with input_col_4:
108
  # with st.form('Network specifics:'):
 
139
  st.write(f'Original image')
140
 
141
  with header_col_2:
142
+ color_separation_vector, performance_color = concept_vectors[concept_vectors['color'] == st.session_state.concept_ids].reset_index().loc[0, ['vector', 'score']]
143
+ saturation_separation_vector, performance_saturation = concept_vectors[concept_vectors['color'] == 'Saturation'].reset_index().loc[0, ['vector', 'score']]
144
+ value_separation_vector, performance_value = concept_vectors[concept_vectors['color'] == 'Value'].reset_index().loc[0, ['vector', 'score']]
145
  st.write(f'Change in {st.session_state.concept_ids} of {np.round(color_lambda, 2)}, in saturation of {np.round(saturation_lambda, 2)}, in value of {np.round(value_lambda, 2)}. - Performance color vector: {performance_color}, saturation vector: {performance_saturation}, value vector: {performance_value}')
146
 
147
  # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------