Erwann Millon commited on
Commit
28c5269
1 Parent(s): e37b9e5

refactoring and change default path

Browse files
Files changed (2) hide show
  1. app.py +2 -6
  2. loaders.py +1 -1
app.py CHANGED
@@ -5,7 +5,7 @@ import sys
5
  import wandb
6
  import torch
7
 
8
- from presets import set_major_global, set_major_local, set_small_local
9
 
10
  sys.path.append("taming-transformers")
11
 
@@ -36,7 +36,7 @@ def set_img_from_example(state, img):
36
  def get_cleared_mask():
37
  return gr.Image.update(value=None)
38
  class StateWrapper:
39
- """This extremely ugly code is a hacky fix to allow con"""
40
  def create_gif(state, *args, **kwargs):
41
  return state, state[0].create_gif(*args, **kwargs)
42
  def apply_asian_vector(state, *args, **kwargs):
@@ -191,15 +191,11 @@ with gr.Blocks(css="styles.css") as demo:
191
  clear.click(StateWrapper.clear_transforms, inputs=[state], outputs=[state, out, mask])
192
  asian_weight.change(StateWrapper.apply_asian_vector, inputs=[state, asian_weight], outputs=[state, out, mask])
193
  lip_size.change(StateWrapper.apply_lip_vector, inputs=[state, lip_size], outputs=[state, out, mask])
194
- # hair_green_purple.change(StateWrapper.apply_gp_vector, inputs=[state, hair_green_purple], outputs=[state, out, mask])
195
  blue_eyes.change(StateWrapper.apply_rb_vector, inputs=[state, blue_eyes], outputs=[state, out, mask])
196
  blend_weight.change(StateWrapper.blend, inputs=[state, blend_weight], outputs=[state, out, mask])
197
  # requantize.change(StateWrapper.update_requant, inputs=[state, requantize], outputs=[state, out, mask])
198
  base_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
199
  blend_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
200
- # small_local.click(set_small_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
201
- # major_local.click(set_major_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
202
- # major_global.click(set_major_global, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
203
  apply_prompts.click(StateWrapper.apply_prompts, inputs=[state, positive_prompts, negative_prompts, learning_rate, iterations, lpips_weight, reconstruction_steps], outputs=[state, out, mask])
204
  rewind.change(StateWrapper.rewind, inputs=[state, rewind], outputs=[state, out, mask])
205
  set_mask.click(StateWrapper.set_mask, inputs=[state, mask], outputs=[state, testim])
 
5
  import wandb
6
  import torch
7
 
8
+ from presets import set_preset
9
 
10
  sys.path.append("taming-transformers")
11
 
 
36
  def get_cleared_mask():
37
  return gr.Image.update(value=None)
38
  class StateWrapper:
39
+ """This extremely ugly code is a hacky fix to allow concurrent users on HF Spaces without instantiating new models for each user."""
40
  def create_gif(state, *args, **kwargs):
41
  return state, state[0].create_gif(*args, **kwargs)
42
  def apply_asian_vector(state, *args, **kwargs):
 
191
  clear.click(StateWrapper.clear_transforms, inputs=[state], outputs=[state, out, mask])
192
  asian_weight.change(StateWrapper.apply_asian_vector, inputs=[state, asian_weight], outputs=[state, out, mask])
193
  lip_size.change(StateWrapper.apply_lip_vector, inputs=[state, lip_size], outputs=[state, out, mask])
 
194
  blue_eyes.change(StateWrapper.apply_rb_vector, inputs=[state, blue_eyes], outputs=[state, out, mask])
195
  blend_weight.change(StateWrapper.blend, inputs=[state, blend_weight], outputs=[state, out, mask])
196
  # requantize.change(StateWrapper.update_requant, inputs=[state, requantize], outputs=[state, out, mask])
197
  base_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
198
  blend_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
 
 
 
199
  apply_prompts.click(StateWrapper.apply_prompts, inputs=[state, positive_prompts, negative_prompts, learning_rate, iterations, lpips_weight, reconstruction_steps], outputs=[state, out, mask])
200
  rewind.change(StateWrapper.rewind, inputs=[state, rewind], outputs=[state, out, mask])
201
  set_mask.click(StateWrapper.set_mask, inputs=[state, mask], outputs=[state, testim])
loaders.py CHANGED
@@ -17,7 +17,7 @@ def load_config(config_path, display=False):
17
 
18
 
19
  def load_default(device):
20
- conf_path = "./celeba_vqgan/unwrapped.yaml"
21
  config = load_config(conf_path, display=False)
22
  model = taming.models.vqgan.VQModel(**config.model.params)
23
  sd = torch.load("./celeba_vqgan/vqgan_only.pt", map_location=device)
 
17
 
18
 
19
  def load_default(device):
20
+ conf_path = "./celeba_vqgan/vqgan_only.yaml"
21
  config = load_config(conf_path, display=False)
22
  model = taming.models.vqgan.VQModel(**config.model.params)
23
  sd = torch.load("./celeba_vqgan/vqgan_only.pt", map_location=device)