erwann commited on
Commit
663705e
1 Parent(s): ca3c3e9

update state for spaces

Browse files
Files changed (2) hide show
  1. ImageState.py +3 -3
  2. app.py +55 -26
ImageState.py CHANGED
@@ -65,7 +65,7 @@ class ImageState:
65
  # current_vector_transforms = (self.blue_eyes, self.lip_size, self.hair_gp, self.asian_transform, sum(self.current_prompt_transforms))
66
  # return (self.blend_latent, current_vector_transforms)
67
  # @cache
68
- def get_mask(self, img, mask=None):
69
  if img and "mask" in img and img["mask"] is not None:
70
  attn_mask = torchvision.transforms.ToTensor()(img["mask"])
71
  attn_mask = torch.ceil(attn_mask[0].to(self.device))
@@ -81,7 +81,7 @@ class ImageState:
81
  print("mask in apply ", get_resized_tensor(attn_mask), get_resized_tensor(attn_mask).shape)
82
  return attn_mask
83
  def set_mask(self, img):
84
- attn_mask = self.get_mask(img)
85
  self.attn_mask = attn_mask
86
  # attn_mask = torch.ones_like(img, device=self.device)
87
  x = attn_mask.clone()
@@ -119,7 +119,7 @@ class ImageState:
119
  print(f"val = {val}")
120
  self.quant = val
121
  return self._render_all_transformations()
122
- def apply_gender_vector(self, weight):
123
  self.asian_transform = weight * self.asian_vector
124
  return self._render_all_transformations()
125
  def update_images(self, path1, path2, blend_weight):
 
65
  # current_vector_transforms = (self.blue_eyes, self.lip_size, self.hair_gp, self.asian_transform, sum(self.current_prompt_transforms))
66
  # return (self.blend_latent, current_vector_transforms)
67
  # @cache
68
+ def _get_mask(self, img, mask=None):
69
  if img and "mask" in img and img["mask"] is not None:
70
  attn_mask = torchvision.transforms.ToTensor()(img["mask"])
71
  attn_mask = torch.ceil(attn_mask[0].to(self.device))
 
81
  print("mask in apply ", get_resized_tensor(attn_mask), get_resized_tensor(attn_mask).shape)
82
  return attn_mask
83
  def set_mask(self, img):
84
+ attn_mask = self._get_mask(img)
85
  self.attn_mask = attn_mask
86
  # attn_mask = torch.ones_like(img, device=self.device)
87
  x = attn_mask.clone()
 
119
  print(f"val = {val}")
120
  self.quant = val
121
  return self._render_all_transformations()
122
+ def apply_asian_vector(self, weight):
123
  self.asian_transform = weight * self.asian_vector
124
  return self._render_all_transformations()
125
  def update_images(self, path1, path2, blend_weight):
app.py CHANGED
@@ -13,8 +13,6 @@ import gradio as gr
13
  from transformers import CLIPModel, CLIPProcessor
14
 
15
  import edit
16
- # import importlib
17
- # importlib.reload(edit)
18
  from app_backend import ImagePromptOptimizer, ProcessorGradientFlow
19
  from ImageState import ImageState
20
  from loaders import load_default
@@ -27,14 +25,45 @@ vqgan.eval()
27
  processor = ProcessorGradientFlow(device=device)
28
  clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
29
  clip.to(device)
30
- def set_img_from_example(img):
31
  return state.update_images(img, img, 0)
32
  def get_cleared_mask():
33
  return gr.Image.update(value=None)
34
  # mask.clear()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  with gr.Blocks(css="styles.css") as demo:
36
- promptoptim = ImagePromptOptimizer(vqgan, clip, processor, quantize=True)
37
- state = ImageState(vqgan, promptoptim)
38
  with gr.Row():
39
  with gr.Column(scale=1):
40
  blue_eyes = gr.Slider(
@@ -90,8 +119,8 @@ with gr.Blocks(css="styles.css") as demo:
90
  set_mask = gr.Button(value="Set mask")
91
  gr.Text(value="this image shows the mask passed to the model when you press set mask (debugging purposes)")
92
  testim = gr.Image()
93
- clear_mask = gr.Button(value="Clear mask")
94
- clear_mask.click(get_cleared_mask, outputs=mask)
95
  with gr.Row():
96
  gr.Examples(
97
  examples=glob.glob("test_pics/*"),
@@ -107,9 +136,9 @@ with gr.Blocks(css="styles.css") as demo:
107
  minimum=0,
108
  maximum=100)
109
 
110
- apply_prompts = gr.Button(value="Apply Prompts", elem_id="apply")
111
- clear = gr.Button(value="Clear all transformations (irreversible)", elem_id="warning")
112
- with gr.Accordion(label="Save Animation", open=False):
113
  gr.Text(value="Creates an animation of all the steps in the editing process", show_label=False)
114
  duration = gr.Number(value=10, label="Duration of the animation in seconds")
115
  extend_frames = gr.Checkbox(value=True, label="Make first and last frame longer")
@@ -128,7 +157,7 @@ with gr.Blocks(css="styles.css") as demo:
128
  gen_prompts.click(get_random_prompts, outputs=[positive_prompts, negative_prompts])
129
  with gr.Row():
130
  with gr.Column():
131
- gr.Text(value="Prompt Editing Configuration", show_label=False)
132
  with gr.Row():
133
  gr.Markdown(value="## Preset Configs", show_label=False)
134
  with gr.Row():
@@ -162,20 +191,20 @@ with gr.Blocks(css="styles.css") as demo:
162
  # step=1,
163
  # value=0,
164
  # label="Steps to run at the end, optimizing only the discriminator loss. This helps to reduce artefacts, but because the model is trained on CelebA, this will make your generations look more like generic white celebrities")
165
- clear.click(state.clear_transforms, outputs=[out, mask])
166
- asian_weight.change(state.apply_gender_vector, inputs=[asian_weight], outputs=[out, mask])
167
- lip_size.change(state.apply_lip_vector, inputs=[lip_size], outputs=[out, mask])
168
- # hair_green_purple.change(state.apply_gp_vector, inputs=[hair_green_purple], outputs=[out, mask])
169
- blue_eyes.change(state.apply_rb_vector, inputs=[blue_eyes], outputs=[out, mask])
170
- blend_weight.change(state.blend, inputs=[blend_weight], outputs=[out, mask])
171
- # requantize.change(state.update_requant, inputs=[requantize], outputs=[out, mask])
172
- base_img.change(state.update_images, inputs=[base_img, blend_img, blend_weight], outputs=[out, mask])
173
- blend_img.change(state.update_images, inputs=[base_img, blend_img, blend_weight], outputs=[out, mask])
174
- small_local.click(set_small_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
175
- major_local.click(set_major_local, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
176
- major_global.click(set_major_global, outputs=[iterations, learning_rate, lpips_weight, reconstruction_steps])
177
- apply_prompts.click(state.apply_prompts, inputs=[positive_prompts, negative_prompts, learning_rate, iterations, lpips_weight, reconstruction_steps], outputs=[out, mask])
178
- rewind.change(state.rewind, inputs=[rewind], outputs=[out, mask])
179
- set_mask.click(state.set_mask, inputs=mask, outputs=testim)
180
  demo.queue()
181
  demo.launch(debug=True, enable_queue=True)
 
13
  from transformers import CLIPModel, CLIPProcessor
14
 
15
  import edit
 
 
16
  from app_backend import ImagePromptOptimizer, ProcessorGradientFlow
17
  from ImageState import ImageState
18
  from loaders import load_default
 
25
  processor = ProcessorGradientFlow(device=device)
26
  clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
27
  clip.to(device)
28
+ def set_img_from_example(state, img):
29
  return state.update_images(img, img, 0)
30
  def get_cleared_mask():
31
  return gr.Image.update(value=None)
32
  # mask.clear()
33
+
34
+ class StateWrapper:
35
+ def apply_asian_vector(state, *args, **kwargs):
36
+ return state, *state[0].apply_asian_vector(*args, **kwargs)
37
+ def apply_gp_vector(state, *args, **kwargs):
38
+ return state, *state[0].apply_gp_vector(*args, **kwargs)
39
+ def apply_lip_vector(state, *args, **kwargs):
40
+ return state, *state[0].apply_lip_vector(*args, **kwargs)
41
+ def apply_prompts(state, *args, **kwargs):
42
+ return state, *state[0].apply_prompts(*args, **kwargs)
43
+ def apply_rb_vector(state, *args, **kwargs):
44
+ return state, *state[0].apply_rb_vector(*args, **kwargs)
45
+ def blend(state, *args, **kwargs):
46
+ return state, *state[0].blend(*args, **kwargs)
47
+ def clear_transforms(state, *args, **kwargs):
48
+ return state, *state[0].clear_transforms(*args, **kwargs)
49
+ def init_transforms(state, *args, **kwargs):
50
+ return state, *state[0].init_transforms(*args, **kwargs)
51
+ def prompt_optim(state, *args, **kwargs):
52
+ return state, *state[0].prompt_optim(*args, **kwargs)
53
+ def rescale_mask(state, *args, **kwargs):
54
+ return state, *state[0].rescale_mask(*args, **kwargs)
55
+ def rewind(state, *args, **kwargs):
56
+ return state, *state[0].rewind(*args, **kwargs)
57
+ def set_mask(state, *args, **kwargs):
58
+ return state, *state[0].set_mask(*args, **kwargs)
59
+ def update_images(state, *args, **kwargs):
60
+ return state, *state[0].update_images(*args, **kwargs)
61
+ def update_requant(state, *args, **kwargs):
62
+ return state, *state[0].update_requant(*args, **kwargs)
63
+
64
  with gr.Blocks(css="styles.css") as demo:
65
+ promptoptim = gr.State([ImagePromptOptimizer(vqgan, clip, processor, quantize=True)])
66
+ state = gr.State([ImageState(vqgan, promptoptim)])
67
  with gr.Row():
68
  with gr.Column(scale=1):
69
  blue_eyes = gr.Slider(
 
119
  set_mask = gr.Button(value="Set mask")
120
  gr.Text(value="this image shows the mask passed to the model when you press set mask (debugging purposes)")
121
  testim = gr.Image()
122
+ # # clear_mask = gr.Button(value="Clear mask")
123
+ # clear_mask.click(get_cleared_mask, outputs=mask)
124
  with gr.Row():
125
  gr.Examples(
126
  examples=glob.glob("test_pics/*"),
 
136
  minimum=0,
137
  maximum=100)
138
 
139
+ apply_prompts = gr.Button(value="🎨 Apply Prompts", elem_id="apply")
140
+ clear = gr.Button(value="Clear all transformations (irreversible)", elem_id="warning")
141
+ with gr.Accordion(label="💾 Save Animation", open=False):
142
  gr.Text(value="Creates an animation of all the steps in the editing process", show_label=False)
143
  duration = gr.Number(value=10, label="Duration of the animation in seconds")
144
  extend_frames = gr.Checkbox(value=True, label="Make first and last frame longer")
 
157
  gen_prompts.click(get_random_prompts, outputs=[positive_prompts, negative_prompts])
158
  with gr.Row():
159
  with gr.Column():
160
+ gr.Text(value="⚙️ Prompt Editing Configuration", show_label=False)
161
  with gr.Row():
162
  gr.Markdown(value="## Preset Configs", show_label=False)
163
  with gr.Row():
 
191
  # step=1,
192
  # value=0,
193
  # label="Steps to run at the end, optimizing only the discriminator loss. This helps to reduce artefacts, but because the model is trained on CelebA, this will make your generations look more like generic white celebrities")
194
+ # clear.click(state.clear_transforms, inputs=[state], outputs=[state, out, mask])
195
+ asian_weight.change(StateWrapper.apply_asian_vector, inputs=[state, asian_weight], outputs=[state, out, mask])
196
+ lip_size.change(StateWrapper.apply_lip_vector, inputs=[state, lip_size], outputs=[state, out, mask])
197
+ # hair_green_purple.change(StateWrapper.apply_gp_vector, inputs=[state, hair_green_purple], outputs=[state, out, mask])
198
+ blue_eyes.change(StateWrapper.apply_rb_vector, inputs=[state, blue_eyes], outputs=[state, out, mask])
199
+ blend_weight.change(StateWrapper.blend, inputs=[state, blend_weight], outputs=[state, out, mask])
200
+ # requantize.change(StateWrapper.update_requant, inputs=[state, requantize], outputs=[state, out, mask])
201
+ base_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
202
+ blend_img.change(StateWrapper.update_images, inputs=[state, base_img, blend_img, blend_weight], outputs=[state, out, mask])
203
+ small_local.click(set_small_local, outputs=[state, iterations, learning_rate, lpips_weight, reconstruction_steps])
204
+ major_local.click(set_major_local, outputs=[state, iterations, learning_rate, lpips_weight, reconstruction_steps])
205
+ major_global.click(set_major_global, outputs=[state, iterations, learning_rate, lpips_weight, reconstruction_steps])
206
+ apply_prompts.click(StateWrapper.apply_prompts, inputs=[state, positive_prompts, negative_prompts, learning_rate, iterations, lpips_weight, reconstruction_steps], outputs=[state, out, mask])
207
+ rewind.change(StateWrapper.rewind, inputs=[state, rewind], outputs=[state, out, mask])
208
+ set_mask.click(StateWrapper.set_mask, inputs=mask, outputs=testim)
209
  demo.queue()
210
  demo.launch(debug=True, enable_queue=True)