Zhengyi committed on
Commit
c5236d2
1 Parent(s): 0d9e593
Files changed (3) hide show
  1. .gitignore +1 -1
  2. app.py +34 -12
  3. pipelines.py +4 -4
.gitignore CHANGED
@@ -1,5 +1,5 @@
1
  # Byte-compiled / optimized / DLL files
2
- __pycache__/
3
  *.py[cod]
4
  *$py.class
5
 
 
1
  # Byte-compiled / optimized / DLL files
2
+ **/__pycache__/
3
  *.py[cod]
4
  *$py.class
5
 
app.py CHANGED
@@ -23,6 +23,17 @@ pipeline = None
23
  rembg_session = rembg.new_session()
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
26
  def check_input_image(input_image):
27
  if input_image is None:
28
  raise gr.Error("No image uploaded!")
@@ -67,13 +78,18 @@ def add_background(image, bg_color=(255, 255, 255)):
67
  return Image.alpha_composite(background, image)
68
 
69
 
70
- def preprocess_image(input_image, do_remove_background, force_remove, foreground_ratio, backgroud_color):
71
  """
72
  input image is a pil image in RGBA, return RGB image
73
  """
74
- if do_remove_background:
75
- image = remove_background(input_image, rembg_session, force_remove)
 
 
 
 
76
  image = do_resize_content(image, foreground_ratio)
 
77
  image = add_background(image, backgroud_color)
78
  return image.convert("RGB")
79
 
@@ -150,8 +166,13 @@ with gr.Blocks() as demo:
150
  with gr.Row():
151
  with gr.Column():
152
  with gr.Row():
153
- do_remove_background = gr.Checkbox(label="Remove Background", value=True)
154
- force_remove = gr.Checkbox(label="Force Remove", value=False)
 
 
 
 
 
155
  back_groud_color = gr.ColorPicker(label="Background Color", value="#7F7F7F", interactive=False)
156
  foreground_ratio = gr.Slider(
157
  label="Foreground Ratio",
@@ -163,9 +184,13 @@ with gr.Blocks() as demo:
163
 
164
  with gr.Column():
165
  seed = gr.Number(value=1234, label="seed", precision=0)
166
- guidance_scale = gr.Number(value=5.5, minimum=0, maximum=20, label="guidance_scale")
167
- step = gr.Number(value=50, minimum=1, maximum=100, label="sample steps", precision=0)
168
  text_button = gr.Button("Generate 3D shape")
 
 
 
 
169
  with gr.Column():
170
  image_output = gr.Image(interactive=False, label="Output RGB image")
171
  xyz_ouput = gr.Image(interactive=False, label="Output CCM image")
@@ -188,14 +213,11 @@ with gr.Blocks() as demo:
188
  output_model,
189
  output_obj,
190
  ]
191
- gr.Examples(
192
- examples=[os.path.join("examples", i) for i in os.listdir("examples")],
193
- inputs=[image_input],
194
- )
195
 
196
  text_button.click(fn=check_input_image, inputs=[image_input]).success(
197
  fn=preprocess_image,
198
- inputs=[image_input, do_remove_background, force_remove, foreground_ratio, back_groud_color],
199
  outputs=[processed_image],
200
  ).success(
201
  fn=gen_image,
 
23
  rembg_session = rembg.new_session()
24
 
25
 
26
+ def expand_to_square(image, bg_color=(0, 0, 0, 0)):
27
+ # expand image to 1:1
28
+ width, height = image.size
29
+ if width == height:
30
+ return image
31
+ new_size = (max(width, height), max(width, height))
32
+ new_image = Image.new("RGBA", new_size, bg_color)
33
+ paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2)
34
+ new_image.paste(image, paste_position)
35
+ return new_image
36
+
37
  def check_input_image(input_image):
38
  if input_image is None:
39
  raise gr.Error("No image uploaded!")
 
78
  return Image.alpha_composite(background, image)
79
 
80
 
81
+ def preprocess_image(image, background_choice, foreground_ratio, backgroud_color):
82
  """
83
  input image is a pil image in RGBA, return RGB image
84
  """
85
+ print(background_choice)
86
+ if background_choice == "Alpha as mask":
87
+ background = Image.new("RGBA", image.size, (0, 0, 0, 0))
88
+ image = Image.alpha_composite(background, image)
89
+ else:
90
+ image = remove_background(image, rembg_session, force_remove=True)
91
  image = do_resize_content(image, foreground_ratio)
92
+ image = expand_to_square(image)
93
  image = add_background(image, backgroud_color)
94
  return image.convert("RGB")
95
 
 
166
  with gr.Row():
167
  with gr.Column():
168
  with gr.Row():
169
+ background_choice = gr.Radio([
170
+ "Alpha as mask",
171
+ "Auto Remove background"
172
+ ], value="Alpha as mask",
173
+ label="backgroud choice")
174
+ # do_remove_background = gr.Checkbox(label=, value=True)
175
+ # force_remove = gr.Checkbox(label=, value=False)
176
  back_groud_color = gr.ColorPicker(label="Background Color", value="#7F7F7F", interactive=False)
177
  foreground_ratio = gr.Slider(
178
  label="Foreground Ratio",
 
184
 
185
  with gr.Column():
186
  seed = gr.Number(value=1234, label="seed", precision=0)
187
+ guidance_scale = gr.Number(value=5.5, minimum=3, maximum=10, label="guidance_scale")
188
+ step = gr.Number(value=50, minimum=30, maximum=100, label="sample steps", precision=0)
189
  text_button = gr.Button("Generate 3D shape")
190
+ gr.Examples(
191
+ examples=[os.path.join("examples", i) for i in os.listdir("examples")],
192
+ inputs=[image_input],
193
+ )
194
  with gr.Column():
195
  image_output = gr.Image(interactive=False, label="Output RGB image")
196
  xyz_ouput = gr.Image(interactive=False, label="Output CCM image")
 
213
  output_model,
214
  output_obj,
215
  ]
216
+
 
 
 
217
 
218
  text_button.click(fn=check_input_image, inputs=[image_input]).success(
219
  fn=preprocess_image,
220
+ inputs=[image_input, background_choice, foreground_ratio, back_groud_color],
221
  outputs=[processed_image],
222
  ).success(
223
  fn=gen_image,
pipelines.py CHANGED
@@ -92,7 +92,7 @@ class TwoStagePipeline(object):
92
  stage1_images.pop(self.stage1_sampler.ref_position)
93
  return stage1_images
94
 
95
- def stage2_sample(self, pixel_img, stage1_images):
96
  if type(pixel_img) == str:
97
  pixel_img = Image.open(pixel_img)
98
 
@@ -112,8 +112,8 @@ class TwoStagePipeline(object):
112
  self.stage2_sampler.sampler,
113
  pixel_images=stage1_images,
114
  ip=pixel_img,
115
- step=50,
116
- scale=5,
117
  batch_size=self.stage2_sampler.batch_size,
118
  ddim_eta=0.0,
119
  dtype=self.stage2_sampler.dtype,
@@ -134,7 +134,7 @@ class TwoStagePipeline(object):
134
  def __call__(self, pixel_img, prompt="3D assets", scale=5, step=50):
135
  pixel_img = do_resize_content(pixel_img, self.resize_rate)
136
  stage1_images = self.stage1_sample(pixel_img, prompt, scale=scale, step=step)
137
- stage2_images = self.stage2_sample(pixel_img, stage1_images)
138
 
139
  return {
140
  "ref_img": pixel_img,
 
92
  stage1_images.pop(self.stage1_sampler.ref_position)
93
  return stage1_images
94
 
95
+ def stage2_sample(self, pixel_img, stage1_images, scale=5, step=50):
96
  if type(pixel_img) == str:
97
  pixel_img = Image.open(pixel_img)
98
 
 
112
  self.stage2_sampler.sampler,
113
  pixel_images=stage1_images,
114
  ip=pixel_img,
115
+ step=step,
116
+ scale=scale,
117
  batch_size=self.stage2_sampler.batch_size,
118
  ddim_eta=0.0,
119
  dtype=self.stage2_sampler.dtype,
 
134
  def __call__(self, pixel_img, prompt="3D assets", scale=5, step=50):
135
  pixel_img = do_resize_content(pixel_img, self.resize_rate)
136
  stage1_images = self.stage1_sample(pixel_img, prompt, scale=scale, step=step)
137
+ stage2_images = self.stage2_sample(pixel_img, stage1_images, scale=scale, step=step)
138
 
139
  return {
140
  "ref_img": pixel_img,