multimodalart HF Staff commited on
Commit
3df2fdf
·
verified ·
1 Parent(s): 33d7a35

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -2
app.py CHANGED
@@ -43,19 +43,34 @@ pipe = StableDiffusionXLFillPipeline.from_pretrained(
43
 
44
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
45
 
 
46
  def can_expand(source_width, source_height, target_width, target_height, alignment):
 
47
  if alignment in ("Left", "Right") and source_width >= target_width:
48
  return False
49
  if alignment in ("Top", "Bottom") and source_height >= target_height:
50
  return False
51
  return True
52
 
53
- @spaces.GPU
54
  def infer(image, width, height, overlap_width, num_inference_steps, resize_option, custom_resize_size, prompt_input=None, alignment="Middle"):
55
  source = image
56
  target_size = (width, height)
57
  overlap = overlap_width
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  if resize_option == "Full":
60
  resize_size = max(source.width, source.height)
61
  elif resize_option == "1/2":
@@ -74,7 +89,7 @@ def infer(image, width, height, overlap_width, num_inference_steps, resize_optio
74
 
75
  if not can_expand(source.width, source.height, target_size[0], target_size[1], alignment):
76
  alignment = "Middle"
77
-
78
  if alignment == "Middle":
79
  margin_x = (target_size[0] - source.width) // 2
80
  margin_y = (target_size[1] - source.height) // 2
@@ -97,6 +112,7 @@ def infer(image, width, height, overlap_width, num_inference_steps, resize_optio
97
  mask = Image.new('L', target_size, 255)
98
  mask_draw = ImageDraw.Draw(mask)
99
 
 
100
  if alignment == "Middle":
101
  mask_draw.rectangle([
102
  (margin_x + overlap, margin_y + overlap),
@@ -151,9 +167,11 @@ def infer(image, width, height, overlap_width, num_inference_steps, resize_optio
151
  yield background, cnet_image
152
 
153
  def clear_result():
 
154
  return gr.update(value=None)
155
 
156
  def preload_presets(target_ratio, ui_width, ui_height):
 
157
  if target_ratio == "9:16":
158
  changed_width = 720
159
  changed_height = 1280
@@ -174,6 +192,8 @@ def select_the_right_preset(user_width, user_height):
174
  return "9:16"
175
  elif user_width == 1280 and user_height == 720:
176
  return "16:9"
 
 
177
  else:
178
  return "Custom"
179
 
 
43
 
44
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
45
 
46
+
47
  def can_expand(source_width, source_height, target_width, target_height, alignment):
48
+ """Checks if the image can be expanded based on the alignment."""
49
  if alignment in ("Left", "Right") and source_width >= target_width:
50
  return False
51
  if alignment in ("Top", "Bottom") and source_height >= target_height:
52
  return False
53
  return True
54
 
55
+ @spaces.GPU(duration=24)
56
  def infer(image, width, height, overlap_width, num_inference_steps, resize_option, custom_resize_size, prompt_input=None, alignment="Middle"):
57
  source = image
58
  target_size = (width, height)
59
  overlap = overlap_width
60
 
61
+ # Upscale if source is smaller than target in both dimensions
62
+ if source.width < target_size[0] and source.height < target_size[1]:
63
+ scale_factor = min(target_size[0] / source.width, target_size[1] / source.height)
64
+ new_width = int(source.width * scale_factor)
65
+ new_height = int(source.height * scale_factor)
66
+ source = source.resize((new_width, new_height), Image.LANCZOS)
67
+
68
+ if source.width > target_size[0] or source.height > target_size[1]:
69
+ scale_factor = min(target_size[0] / source.width, target_size[1] / source.height)
70
+ new_width = int(source.width * scale_factor)
71
+ new_height = int(source.height * scale_factor)
72
+ source = source.resize((new_width, new_height), Image.LANCZOS)
73
+
74
  if resize_option == "Full":
75
  resize_size = max(source.width, source.height)
76
  elif resize_option == "1/2":
 
89
 
90
  if not can_expand(source.width, source.height, target_size[0], target_size[1], alignment):
91
  alignment = "Middle"
92
+ # Calculate margins based on alignment
93
  if alignment == "Middle":
94
  margin_x = (target_size[0] - source.width) // 2
95
  margin_y = (target_size[1] - source.height) // 2
 
112
  mask = Image.new('L', target_size, 255)
113
  mask_draw = ImageDraw.Draw(mask)
114
 
115
+ # Adjust mask generation based on alignment
116
  if alignment == "Middle":
117
  mask_draw.rectangle([
118
  (margin_x + overlap, margin_y + overlap),
 
167
  yield background, cnet_image
168
 
169
  def clear_result():
170
+ """Clears the result ImageSlider."""
171
  return gr.update(value=None)
172
 
173
  def preload_presets(target_ratio, ui_width, ui_height):
174
+ """Updates the width and height sliders based on the selected aspect ratio."""
175
  if target_ratio == "9:16":
176
  changed_width = 720
177
  changed_height = 1280
 
192
  return "9:16"
193
  elif user_width == 1280 and user_height == 720:
194
  return "16:9"
195
+ elif user_width == 1024 and user_height == 1024:
196
+ return "1:1"
197
  else:
198
  return "Custom"
199