shgao committed on
Commit
ae359e4
1 Parent(s): 24d750d

update demo

Browse files
Files changed (1) hide show
  1. sam2edit.py +17 -11
sam2edit.py CHANGED
@@ -29,8 +29,8 @@ def create_demo():
29
  from diffusers.utils import load_image
30
 
31
  base_model_path = "stabilityai/stable-diffusion-2-inpainting"
32
- config_dict = OrderedDict([('SAM Pretrained(v0-1)', 'shgao/edit-anything-v0-1-1'),
33
- ('LAION Pretrained(v0-3)', 'shgao/edit-anything-v0-3'),
34
  # ('LAION Pretrained(v0-3-1)', '../../edit/edit-anything-ckpt-v0-3'),
35
  ])
36
  def obtain_generation_model(controlnet_path):
@@ -48,7 +48,7 @@ def create_demo():
48
  return pipe
49
  global default_controlnet_path
50
  global pipe
51
- default_controlnet_path = config_dict['LAION Pretrained(v0-3)']
52
  pipe = obtain_generation_model(default_controlnet_path)
53
 
54
  # Segment-Anything init.
@@ -123,11 +123,16 @@ def create_demo():
123
  return full_img, res
124
 
125
 
126
- def process(condition_model, source_image, mask_image, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
127
 
128
  input_image = source_image["image"]
129
  if mask_image is None:
130
- mask_image = source_image["mask"]
 
 
 
 
 
131
  global default_controlnet_path
132
  print("To Use:", config_dict[condition_model], "Current:", default_controlnet_path)
133
  if default_controlnet_path!=config_dict[condition_model]:
@@ -248,18 +253,19 @@ def create_demo():
248
  "## Edit Anything")
249
  with gr.Row():
250
  with gr.Column():
251
- source_image = gr.Image(source='upload',label="Image (support sketch)", type="numpy", tool="sketch")
252
- mask_image = gr.Image(source='upload', label="Edit region (Optional)", type="numpy", value=None)
253
- prompt = gr.Textbox(label="Prompt")
254
- enable_auto_prompt = gr.Checkbox(label='Auto generated BLIP2 prompt', value=True)
255
  run_button = gr.Button(label="Run")
256
  condition_model = gr.Dropdown(choices=list(config_dict.keys()),
257
  value=list(config_dict.keys())[1],
258
  label='Model',
259
  multiselect=False)
260
  num_samples = gr.Slider(
261
- label="Images", minimum=1, maximum=12, value=1, step=1)
262
  with gr.Accordion("Advanced options", open=False):
 
263
  image_resolution = gr.Slider(
264
  label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
265
  strength = gr.Slider(
@@ -282,7 +288,7 @@ def create_demo():
282
  result_gallery = gr.Gallery(
283
  label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
284
  result_text = gr.Text(label='BLIP2+Human Prompt Text')
285
- ips = [condition_model, source_image, mask_image, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
286
  detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
287
  run_button.click(fn=process, inputs=ips, outputs=[result_gallery, result_text])
288
  return demo
 
29
  from diffusers.utils import load_image
30
 
31
  base_model_path = "stabilityai/stable-diffusion-2-inpainting"
32
+ config_dict = OrderedDict([('SAM Pretrained(v0-1): Good Natural Sense', 'shgao/edit-anything-v0-1-1'),
33
+ ('LAION Pretrained(v0-3): Good Face', 'shgao/edit-anything-v0-3'),
34
  # ('LAION Pretrained(v0-3-1)', '../../edit/edit-anything-ckpt-v0-3'),
35
  ])
36
  def obtain_generation_model(controlnet_path):
 
48
  return pipe
49
  global default_controlnet_path
50
  global pipe
51
+ default_controlnet_path = config_dict['LAION Pretrained(v0-3): Good Face']
52
  pipe = obtain_generation_model(default_controlnet_path)
53
 
54
  # Segment-Anything init.
 
123
  return full_img, res
124
 
125
 
126
+ def process(condition_model, source_image, enable_all_generate, mask_image, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
127
 
128
  input_image = source_image["image"]
129
  if mask_image is None:
130
+ if enable_all_generate:
131
+ print("source_image", source_image["mask"].shape, input_image.shape,)
132
+ print(source_image["mask"].max())
133
+ mask_image = np.ones((input_image.shape[0], input_image.shape[1], 3))*255
134
+ else:
135
+ mask_image = source_image["mask"]
136
  global default_controlnet_path
137
  print("To Use:", config_dict[condition_model], "Current:", default_controlnet_path)
138
  if default_controlnet_path!=config_dict[condition_model]:
 
253
  "## Edit Anything")
254
  with gr.Row():
255
  with gr.Column():
256
+ source_image = gr.Image(source='upload',label="Image (Upload an image and cover the region you want to edit with sketch)", type="numpy", tool="sketch")
257
+ enable_all_generate = gr.Checkbox(label='Auto generation on all region.', value=True)
258
+ prompt = gr.Textbox(label="Prompt (Text in the expected things of edited region)")
259
+ enable_auto_prompt = gr.Checkbox(label='Auto generate text prompt from input image with BLIP2: Warning: Enable this may makes your prompt not working.', value=True)
260
  run_button = gr.Button(label="Run")
261
  condition_model = gr.Dropdown(choices=list(config_dict.keys()),
262
  value=list(config_dict.keys())[1],
263
  label='Model',
264
  multiselect=False)
265
  num_samples = gr.Slider(
266
+ label="Images", minimum=1, maximum=12, value=2, step=1)
267
  with gr.Accordion("Advanced options", open=False):
268
+ mask_image = gr.Image(source='upload', label="(Optional) Upload a predefined mask of edit region if you do not want to write your prompt.", type="numpy", value=None)
269
  image_resolution = gr.Slider(
270
  label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
271
  strength = gr.Slider(
 
288
  result_gallery = gr.Gallery(
289
  label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
290
  result_text = gr.Text(label='BLIP2+Human Prompt Text')
291
+ ips = [condition_model, source_image, enable_all_generate, mask_image, enable_auto_prompt, prompt, a_prompt, n_prompt, num_samples, image_resolution,
292
  detect_resolution, ddim_steps, guess_mode, strength, scale, seed, eta]
293
  run_button.click(fn=process, inputs=ips, outputs=[result_gallery, result_text])
294
  return demo