Max Reimann commited on
Commit
dc6a058
·
1 Parent(s): 591e364

Improve parameters

Browse files
Files changed (1) hide show
  1. Whitebox_style_transfer.py +9 -6
Whitebox_style_transfer.py CHANGED
@@ -264,22 +264,25 @@ def on_slider():
264
 
265
 
266
  with coll2:
267
- show_params_names = [ 'bumpScale', "bumpOpacity", "contourOpacity"]
268
  display_means = []
 
269
  def create_slider(name):
270
- mean = torch.mean(vp[:, effect.vpd.name2idx[name]]).item()
271
- display_mean = mean + 0.5
 
272
  display_means.append(display_mean)
273
  if "slider_" + name not in st.session_state or st.session_state["action"] != "slider_change":
274
  st.session_state["slider_" + name] = display_mean
275
  slider = st.slider(f"Mean {name}: ", 0.0, 1.0, step=0.05, key="slider_" + name, on_change=on_slider)
276
- vp[:, effect.vpd.name2idx[name]] += slider - display_mean
277
- vp.clamp_(-0.5, 0.5)
 
278
 
279
  for name in show_params_names:
280
  create_slider(name)
281
 
282
- others_idx = set(range(len(effect.vpd.vp_ranges))) - set([effect.vpd.name2idx[name] for name in show_params_names])
283
  others_names = [effect.vpd.vp_ranges[i][0] for i in sorted(list(others_idx))]
284
  other_param = st.selectbox("Other parameters: ", others_names)
285
  create_slider(other_param)
 
264
 
265
 
266
  with coll2:
267
+ show_params_names = [ 'bumpiness',"bumpSpecular", "contours"]
268
  display_means = []
269
+ params_mapping = {"bumpiness": ['bumpScale', "bumpOpacity"], "bumpSpecular": ["bumpSpecular"], "contours": [ "contourOpacity", "contour"]}
270
  def create_slider(name):
271
+ params = params_mapping[name] if name in params_mapping else [name]
272
+ means = [torch.mean(vp[:, effect.vpd.name2idx[n]]).item() for n in params]
273
+ display_mean = np.average(means) + 0.5
274
  display_means.append(display_mean)
275
  if "slider_" + name not in st.session_state or st.session_state["action"] != "slider_change":
276
  st.session_state["slider_" + name] = display_mean
277
  slider = st.slider(f"Mean {name}: ", 0.0, 1.0, step=0.05, key="slider_" + name, on_change=on_slider)
278
+ for i, param_name in enumerate(params):
279
+ vp[:, effect.vpd.name2idx[param_name]] += slider - (means[i] + 0.5)
280
+ vp.clamp_(-0.5, 0.5)
281
 
282
  for name in show_params_names:
283
  create_slider(name)
284
 
285
+ others_idx = set(range(len(effect.vpd.vp_ranges))) - set([effect.vpd.name2idx[name] for name in sum(params_mapping.values(), [])])
286
  others_names = [effect.vpd.vp_ranges[i][0] for i in sorted(list(others_idx))]
287
  other_param = st.selectbox("Other parameters: ", others_names)
288
  create_slider(other_param)