ludusc commited on
Commit
3b262b0
1 Parent(s): f2b3196
Files changed (1) hide show
  1. pages/1_Omniart_Disentanglement.py +9 -1
pages/1_Omniart_Disentanglement.py CHANGED
@@ -126,11 +126,16 @@ with input_col_2:
126
 
127
  choose_image_button = st.form_submit_button('Choose the defined image')
128
  random_id = st.form_submit_button('Generate a random image')
 
129
 
130
- if random_id:
131
  image_id = random.randint(0, 50000)
132
  st.session_state.image_id = image_id
133
  chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
 
 
 
 
134
 
135
  if choose_image_button:
136
  image_id = int(image_id)
@@ -162,6 +167,9 @@ if st.session_state.space_id == 'Z':
162
  else:
163
  original_image_vec = annotations['w_vectors'][st.session_state.image_id]
164
 
 
 
 
165
  img = generate_original_image(original_image_vec, model, latent_space=st.session_state.space_id)
166
  # input_image = original_image_dict['image']
167
  # input_label = original_image_dict['label']
 
126
 
127
  choose_image_button = st.form_submit_button('Choose the defined image')
128
  random_id = st.form_submit_button('Generate a random image')
129
+ projection_id = st.form_submit_button('Generate an image on the boudary')
130
 
131
+ if random_id or projection_id:
132
  image_id = random.randint(0, 50000)
133
  st.session_state.image_id = image_id
134
  chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
135
+ projection = False
136
+
137
+ if projection_id:
138
+ projection = True
139
 
140
  if choose_image_button:
141
  image_id = int(image_id)
 
167
  else:
168
  original_image_vec = annotations['w_vectors'][st.session_state.image_id]
169
 
170
+ if projection:
171
+ original_image_vec = original_image_vec - np.dot(original_image_vec, separation_vector) * separation_vector
172
+
173
  img = generate_original_image(original_image_vec, model, latent_space=st.session_state.space_id)
174
  # input_image = original_image_dict['image']
175
  # input_label = original_image_dict['label']