Johannes Kolbe commited on
Commit
7db1e87
1 Parent(s): ed6b6d6

added better functionality

Browse files
Files changed (3) hide show
  1. README.md +2 -2
  2. app.py +14 -4
  3. interface.py +17 -4
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: Sefa
3
- emoji: 💻
4
  colorFrom: blue
5
- colorTo: green
6
  sdk: streamlit
7
  sdk_version: 1.2.0
8
  app_file: app.py
 
1
  ---
2
  title: Sefa
3
+ emoji: 🔮
4
  colorFrom: blue
5
+ colorTo: red
6
  sdk: streamlit
7
  sdk_version: 1.2.0
8
  app_file: app.py
app.py CHANGED
@@ -54,11 +54,16 @@ def synthesize(model, gan_type, code):
54
  image = postprocess(image)[0]
55
  return image
56
 
 
 
 
 
 
57
 
58
  """Main function (loop for StreamLit)."""
59
  st.title('Closed-Form Factorization of Latent Semantics in GANs')
60
  st.sidebar.title('Options')
61
- reset = st.sidebar.button('Reset')
62
 
63
  model_name = st.sidebar.selectbox(
64
  'Model to Interpret',
@@ -72,7 +77,7 @@ layer_idx = st.sidebar.selectbox(
72
  layers, boundaries, eigen_values = factorize_model(model, layer_idx)
73
 
74
  num_semantics = st.sidebar.number_input(
75
- 'Number of semantics', value=5, min_value=0, max_value=None, step=1)
76
  steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
77
  if gan_type == 'pggan':
78
  max_step = 5.0
@@ -87,10 +92,12 @@ for sem_idx in steps:
87
  value=0.0,
88
  min_value=-max_step,
89
  max_value=max_step,
90
- step=0.04 * max_step if not reset else 0.0)
 
91
 
92
  image_placeholder = st.empty()
93
  button_placeholder = st.empty()
 
94
 
95
  try:
96
  base_codes = np.load(f'latent_codes/{model_name}_latents.npy')
@@ -105,13 +112,16 @@ if state.model_name != model_name:
105
  state.code_idx = 0
106
  state.codes = base_codes[0:1]
107
 
108
- if button_placeholder.button('Random', key=0):
109
  state.code_idx += 1
110
  if state.code_idx < base_codes.shape[0]:
111
  state.codes = base_codes[state.code_idx][np.newaxis]
112
  else:
113
  state.codes = sample(model, gan_type)
114
 
 
 
 
115
  code = state.codes.copy()
116
  for sem_idx, step in steps.items():
117
  if gan_type == 'pggan':
 
54
  image = postprocess(image)[0]
55
  return image
56
 
57
+ def _update_slider():
58
+ num_semantics = st.session_state["num_semantics"]
59
+ for sem_idx in range(num_semantics):
60
+ st.session_state[f"semantic_slider_{sem_idx}"] = 0
61
+
62
 
63
  """Main function (loop for StreamLit)."""
64
  st.title('Closed-Form Factorization of Latent Semantics in GANs')
65
  st.sidebar.title('Options')
66
+ st.sidebar.button('Reset', on_click=_update_slider, kwargs={})
67
 
68
  model_name = st.sidebar.selectbox(
69
  'Model to Interpret',
 
77
  layers, boundaries, eigen_values = factorize_model(model, layer_idx)
78
 
79
  num_semantics = st.sidebar.number_input(
80
+ 'Number of semantics', value=5, min_value=0, max_value=None, step=1, key="num_semantics")
81
  steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
82
  if gan_type == 'pggan':
83
  max_step = 5.0
 
92
  value=0.0,
93
  min_value=-max_step,
94
  max_value=max_step,
95
+ step=0.04 * max_step,
96
+ key=f"semantic_slider_{sem_idx}")
97
 
98
  image_placeholder = st.empty()
99
  button_placeholder = st.empty()
100
+ button_totally_random = st.empty()
101
 
102
  try:
103
  base_codes = np.load(f'latent_codes/{model_name}_latents.npy')
 
112
  state.code_idx = 0
113
  state.codes = base_codes[0:1]
114
 
115
+ if button_placeholder.button('Next Sample'):
116
  state.code_idx += 1
117
  if state.code_idx < base_codes.shape[0]:
118
  state.codes = base_codes[state.code_idx][np.newaxis]
119
  else:
120
  state.codes = sample(model, gan_type)
121
 
122
+ if button_totally_random.button('Totally Random'):
123
+ state.codes = sample(model, gan_type)
124
+
125
  code = state.codes.copy()
126
  for sem_idx, step in steps.items():
127
  if gan_type == 'pggan':
interface.py CHANGED
@@ -1,5 +1,6 @@
1
  # python 3.7
2
  """Demo."""
 
3
 
4
  import numpy as np
5
  import torch
@@ -55,11 +56,18 @@ def synthesize(model, gan_type, code):
55
  return image
56
 
57
 
 
 
 
 
 
 
58
  def main():
59
  """Main function (loop for StreamLit)."""
 
60
  st.title('Closed-Form Factorization of Latent Semantics in GANs')
61
  st.sidebar.title('Options')
62
- reset = st.sidebar.button('Reset')
63
 
64
  model_name = st.sidebar.selectbox(
65
  'Model to Interpret',
@@ -73,7 +81,7 @@ def main():
73
  layers, boundaries, eigen_values = factorize_model(model, layer_idx)
74
 
75
  num_semantics = st.sidebar.number_input(
76
- 'Number of semantics', value=5, min_value=0, max_value=None, step=1)
77
  steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
78
  if gan_type == 'pggan':
79
  max_step = 5.0
@@ -88,10 +96,12 @@ def main():
88
  value=0.0,
89
  min_value=-max_step,
90
  max_value=max_step,
91
- step=0.04 * max_step if not reset else 0.0)
 
92
 
93
  image_placeholder = st.empty()
94
  button_placeholder = st.empty()
 
95
 
96
  try:
97
  base_codes = np.load(f'latent_codes/{model_name}_latents.npy')
@@ -106,13 +116,16 @@ def main():
106
  state.code_idx = 0
107
  state.codes = base_codes[0:1]
108
 
109
- if button_placeholder.button('Random', key=0):
110
  state.code_idx += 1
111
  if state.code_idx < base_codes.shape[0]:
112
  state.codes = base_codes[state.code_idx][np.newaxis]
113
  else:
114
  state.codes = sample(model, gan_type)
115
 
 
 
 
116
  code = state.codes.copy()
117
  for sem_idx, step in steps.items():
118
  if gan_type == 'pggan':
 
1
  # python 3.7
2
  """Demo."""
3
+ import random
4
 
5
  import numpy as np
6
  import torch
 
56
  return image
57
 
58
 
59
+ def _update_slider():
60
+ num_semantics = st.session_state["num_semantics"]
61
+ for sem_idx in range(num_semantics):
62
+ st.session_state[f"semantic_slider_{sem_idx}"] = 0
63
+
64
+
65
  def main():
66
  """Main function (loop for StreamLit)."""
67
+
68
  st.title('Closed-Form Factorization of Latent Semantics in GANs')
69
  st.sidebar.title('Options')
70
+ st.sidebar.button('Reset', on_click=_update_slider, kwargs={})
71
 
72
  model_name = st.sidebar.selectbox(
73
  'Model to Interpret',
 
81
  layers, boundaries, eigen_values = factorize_model(model, layer_idx)
82
 
83
  num_semantics = st.sidebar.number_input(
84
+ 'Number of semantics', value=5, min_value=0, max_value=None, step=1, key="num_semantics")
85
  steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
86
  if gan_type == 'pggan':
87
  max_step = 5.0
 
96
  value=0.0,
97
  min_value=-max_step,
98
  max_value=max_step,
99
+ step=0.04 * max_step,
100
+ key=f"semantic_slider_{sem_idx}")
101
 
102
  image_placeholder = st.empty()
103
  button_placeholder = st.empty()
104
+ button_totally_random = st.empty()
105
 
106
  try:
107
  base_codes = np.load(f'latent_codes/{model_name}_latents.npy')
 
116
  state.code_idx = 0
117
  state.codes = base_codes[0:1]
118
 
119
+ if button_placeholder.button('Next Sample'):
120
  state.code_idx += 1
121
  if state.code_idx < base_codes.shape[0]:
122
  state.codes = base_codes[state.code_idx][np.newaxis]
123
  else:
124
  state.codes = sample(model, gan_type)
125
 
126
+ if button_totally_random.button('Totally Random'):
127
+ state.codes = sample(model, gan_type)
128
+
129
  code = state.codes.copy()
130
  for sem_idx, step in steps.items():
131
  if gan_type == 'pggan':