Joseph Catrambone committed on
Commit
d491fdb
1 Parent(s): a1a7f32

Automatically scale input images to 512x512 with center crop if they're non-square.

Browse files
Files changed (1) hide show
  1. app.py +13 -1
app.py CHANGED
@@ -24,9 +24,21 @@ ddim_sampler = DDIMSampler(model) # ControlNet _only_ works with DDIM.
24
 
25
  def process(input_image: Image.Image, prompt, a_prompt, n_prompt, max_faces: int, min_confidence: float, num_samples, ddim_steps, guess_mode, strength, scale, seed: int, eta):
26
  with torch.no_grad():
 
 
 
 
 
 
 
 
 
 
 
27
  empty = generate_annotation(input_image, max_faces, min_confidence)
28
  visualization = Image.fromarray(empty) # Save to help debug.
29
 
 
30
  empty = numpy.moveaxis(empty, 2, 0) # h, w, c -> c, h, w
31
  control = torch.from_numpy(empty.copy()).float().to(device) / 255.0
32
  control = torch.stack([control for _ in range(num_samples)], dim=0)
@@ -81,7 +93,7 @@ with block:
81
  gr.Markdown("## Control Stable Diffusion with a Facial Pose")
82
  with gr.Row():
83
  with gr.Column():
84
- input_image = gr.Image(source='upload', type="numpy")
85
  prompt = gr.Textbox(label="Prompt")
86
  run_button = gr.Button(label="Run")
87
  with gr.Accordion("Advanced options", open=False):
 
24
 
25
  def process(input_image: Image.Image, prompt, a_prompt, n_prompt, max_faces: int, min_confidence: float, num_samples, ddim_steps, guess_mode, strength, scale, seed: int, eta):
26
  with torch.no_grad():
27
+ # Scale to 512x512.
28
+ img_size = input_image.size
29
+ scale_factor = 512/min(img_size)
30
+ input_image = input_image.resize((1+int(img_size[0]*scale_factor), 1+int(img_size[1]*scale_factor)))
31
+ img_size = input_image.size
32
+ left_padding = (img_size[0] - 512)//2
33
+ top_padding = (img_size[1] - 512)//2
34
+ input_image = input_image.crop((left_padding, top_padding, left_padding+512, top_padding+512))
35
+
36
+ # Generate annotation
37
+ input_image = numpy.asarray(input_image)
38
  empty = generate_annotation(input_image, max_faces, min_confidence)
39
  visualization = Image.fromarray(empty) # Save to help debug.
40
 
41
+ # Prep for network:
42
  empty = numpy.moveaxis(empty, 2, 0) # h, w, c -> c, h, w
43
  control = torch.from_numpy(empty.copy()).float().to(device) / 255.0
44
  control = torch.stack([control for _ in range(num_samples)], dim=0)
 
93
  gr.Markdown("## Control Stable Diffusion with a Facial Pose")
94
  with gr.Row():
95
  with gr.Column():
96
+ input_image = gr.Image(source='upload', type="pil")
97
  prompt = gr.Textbox(label="Prompt")
98
  run_button = gr.Button(label="Run")
99
  with gr.Accordion("Advanced options", open=False):