bg removal
Browse files- app.py +136 -172
- briarmbg.py +460 -0
- utils.py +114 -0
app.py
CHANGED
@@ -1,11 +1,18 @@
|
|
1 |
import sys
|
2 |
import os
|
3 |
import torch
|
4 |
-
|
5 |
-
from
|
6 |
-
from PIL import Image, ImageSequence, ImageOps
|
7 |
from typing import List
|
8 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
sys.path.append(os.path.dirname("./ComfyUI/"))
|
11 |
from ComfyUI.nodes import (
|
@@ -27,20 +34,11 @@ from ComfyUI.custom_nodes.layerdiffuse.layered_diffusion import (
|
|
27 |
LayeredDiffusionCond,
|
28 |
)
|
29 |
import gradio as gr
|
|
|
30 |
|
|
|
31 |
|
32 |
-
|
33 |
-
repo_id="lllyasviel/fav_models",
|
34 |
-
subfolder="fav",
|
35 |
-
filename="juggernautXL_v8Rundiffusion.safetensors",
|
36 |
-
)
|
37 |
-
try:
|
38 |
-
os.symlink(
|
39 |
-
MODEL_PATH,
|
40 |
-
Path("./ComfyUI/models/checkpoints/juggernautXL_v8Rundiffusion.safetensors"),
|
41 |
-
)
|
42 |
-
except FileExistsError:
|
43 |
-
pass
|
44 |
|
45 |
with torch.inference_mode():
|
46 |
ckpt_load_checkpoint = CheckpointLoaderSimple().load_checkpoint
|
@@ -58,73 +56,14 @@ ld_decode = LayeredDiffusionDecode().decode
|
|
58 |
mask_to_image = MaskToImage().mask_to_image
|
59 |
invert_mask = InvertMask().invert
|
60 |
join_image_with_alpha = JoinImageWithAlpha().join_image_with_alpha
|
61 |
-
|
62 |
-
|
63 |
-
def tensor_to_pil(images: torch.Tensor | List[torch.Tensor]) -> List[Image.Image]:
|
64 |
-
if not isinstance(images, list):
|
65 |
-
images = [images]
|
66 |
-
imgs = []
|
67 |
-
for image in images:
|
68 |
-
i = 255.0 * image.cpu().numpy()
|
69 |
-
img = Image.fromarray(np.clip(np.squeeze(i), 0, 255).astype(np.uint8))
|
70 |
-
imgs.append(img)
|
71 |
-
return imgs
|
72 |
-
|
73 |
-
|
74 |
-
def pad_image(input_image):
|
75 |
-
pad_w, pad_h = (
|
76 |
-
np.max(((2, 2), np.ceil(np.array(input_image.size) / 64).astype(int)), axis=0)
|
77 |
-
* 64
|
78 |
-
- input_image.size
|
79 |
-
)
|
80 |
-
im_padded = Image.fromarray(
|
81 |
-
np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode="edge")
|
82 |
-
)
|
83 |
-
w, h = im_padded.size
|
84 |
-
if w == h:
|
85 |
-
return im_padded
|
86 |
-
elif w > h:
|
87 |
-
new_image = Image.new(im_padded.mode, (w, w), (0, 0, 0))
|
88 |
-
new_image.paste(im_padded, (0, (w - h) // 2))
|
89 |
-
return new_image
|
90 |
-
else:
|
91 |
-
new_image = Image.new(im_padded.mode, (h, h), (0, 0, 0))
|
92 |
-
new_image.paste(im_padded, ((h - w) // 2, 0))
|
93 |
-
return new_image
|
94 |
-
|
95 |
-
|
96 |
-
def pil_to_tensor(image: Image.Image) -> tuple[torch.Tensor, torch.Tensor]:
|
97 |
-
output_images = []
|
98 |
-
output_masks = []
|
99 |
-
for i in ImageSequence.Iterator(image):
|
100 |
-
i = ImageOps.exif_transpose(i)
|
101 |
-
if i.mode == "I":
|
102 |
-
i = i.point(lambda i: i * (1 / 255))
|
103 |
-
image = i.convert("RGB")
|
104 |
-
image = np.array(image).astype(np.float32) / 255.0
|
105 |
-
image = torch.from_numpy(image)[None,]
|
106 |
-
if "A" in i.getbands():
|
107 |
-
mask = np.array(i.getchannel("A")).astype(np.float32) / 255.0
|
108 |
-
mask = 1.0 - torch.from_numpy(mask)
|
109 |
-
else:
|
110 |
-
mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
|
111 |
-
output_images.append(image)
|
112 |
-
output_masks.append(mask.unsqueeze(0))
|
113 |
-
|
114 |
-
if len(output_images) > 1:
|
115 |
-
output_image = torch.cat(output_images, dim=0)
|
116 |
-
output_mask = torch.cat(output_masks, dim=0)
|
117 |
-
else:
|
118 |
-
output_image = output_images[0]
|
119 |
-
output_mask = output_masks[0]
|
120 |
-
|
121 |
-
return (output_image, output_mask)
|
122 |
|
123 |
|
124 |
def predict(
|
125 |
prompt: str,
|
126 |
negative_prompt: str,
|
127 |
input_image: Image.Image | None,
|
|
|
128 |
cond_mode: str,
|
129 |
seed: int,
|
130 |
sampler_name: str,
|
@@ -133,95 +72,115 @@ def predict(
|
|
133 |
cfg: float,
|
134 |
denoise: float,
|
135 |
):
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
cliptextencode_negative_prompt = cliptextencode(
|
142 |
-
text=negative_prompt,
|
143 |
-
clip=ckpt[1],
|
144 |
-
)
|
145 |
-
emptylatentimage_sample = emptylatentimage_generate(
|
146 |
-
width=1024, height=1024, batch_size=1
|
147 |
-
)
|
148 |
-
|
149 |
-
if input_image is not None:
|
150 |
-
img_tensor = pil_to_tensor(pad_image(input_image).resize((1024, 1024)))
|
151 |
-
img_latent = vae_encode(pixels=img_tensor[0], vae=ckpt[2])
|
152 |
-
layereddiffusionapply_sample = ld_cond_apply_layered_diffusion(
|
153 |
-
config=cond_mode,
|
154 |
-
weight=1,
|
155 |
-
model=ckpt[0],
|
156 |
-
cond=cliptextencode_prompt[0],
|
157 |
-
uncond=cliptextencode_negative_prompt[0],
|
158 |
-
latent=img_latent[0],
|
159 |
)
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
sampler_name=sampler_name,
|
164 |
-
scheduler=scheduler,
|
165 |
-
seed=seed,
|
166 |
-
model=layereddiffusionapply_sample[0],
|
167 |
-
positive=layereddiffusionapply_sample[1],
|
168 |
-
negative=layereddiffusionapply_sample[2],
|
169 |
-
latent_image=emptylatentimage_sample[0],
|
170 |
-
denoise=denoise,
|
171 |
-
)
|
172 |
-
|
173 |
-
vaedecode_sample = vae_decode(
|
174 |
-
samples=ksampler[0],
|
175 |
-
vae=ckpt[2],
|
176 |
)
|
177 |
-
|
178 |
-
|
179 |
-
sub_batch_size=16,
|
180 |
-
samples=ksampler[0],
|
181 |
-
images=vaedecode_sample[0],
|
182 |
)
|
183 |
|
184 |
-
|
185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
model=layereddiffusionapply_sample[0],
|
198 |
-
positive=cliptextencode_prompt[0],
|
199 |
-
negative=cliptextencode_negative_prompt[0],
|
200 |
-
latent_image=emptylatentimage_sample[0],
|
201 |
-
denoise=denoise,
|
202 |
-
)
|
203 |
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
rgb_img = tensor_to_pil(vaedecode_sample[0])
|
223 |
|
224 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
225 |
|
226 |
|
227 |
examples = [["An old men sit on a chair looking at the sky"]]
|
@@ -233,18 +192,18 @@ def flatten(l: List[List[any]]) -> List[any]:
|
|
233 |
|
234 |
def predict_examples(prompt, negative_prompt):
|
235 |
return predict(
|
236 |
-
prompt, negative_prompt, None, None, 0, "euler", "normal", 20, 8.0, 1.0
|
237 |
)
|
238 |
|
239 |
|
240 |
css = """
|
241 |
.gradio-container{
|
242 |
-
max-width:
|
243 |
}
|
244 |
"""
|
245 |
with gr.Blocks(css=css) as blocks:
|
246 |
gr.Markdown("""# LayerDiffuse (unofficial)
|
247 |
-
|
248 |
""")
|
249 |
|
250 |
with gr.Row():
|
@@ -253,12 +212,18 @@ with gr.Blocks(css=css) as blocks:
|
|
253 |
negative_prompt = gr.Text(label="Negative Prompt")
|
254 |
button = gr.Button("Generate")
|
255 |
with gr.Accordion(open=False, label="Input Images (Optional)"):
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
with gr.Accordion(open=False, label="Advanced Options"):
|
263 |
seed = gr.Slider(
|
264 |
label="Seed",
|
@@ -278,8 +243,8 @@ with gr.Blocks(css=css) as blocks:
|
|
278 |
label="Scheduler",
|
279 |
value=samplers.KSampler.SCHEDULERS[0],
|
280 |
)
|
281 |
-
steps = gr.
|
282 |
-
label="Steps", value=20, minimum=1, maximum=
|
283 |
)
|
284 |
cfg = gr.Number(
|
285 |
label="CFG", value=8.0, minimum=0.0, maximum=100.0, step=0.1
|
@@ -289,14 +254,13 @@ with gr.Blocks(css=css) as blocks:
|
|
289 |
)
|
290 |
|
291 |
with gr.Column(scale=1.8):
|
292 |
-
gallery = gr.Gallery(
|
293 |
-
columns=[2], rows=[2], object_fit="contain", height="unset"
|
294 |
-
)
|
295 |
|
296 |
inputs = [
|
297 |
prompt,
|
298 |
negative_prompt,
|
299 |
input_image,
|
|
|
300 |
cond_mode,
|
301 |
seed,
|
302 |
sampler_name,
|
|
|
1 |
import sys
|
2 |
import os
|
3 |
import torch
|
4 |
+
|
5 |
+
from PIL import Image
|
|
|
6 |
from typing import List
|
7 |
import numpy as np
|
8 |
+
from utils import (
|
9 |
+
tensor_to_pil,
|
10 |
+
pil_to_tensor,
|
11 |
+
pad_image,
|
12 |
+
postprocess_image,
|
13 |
+
preprocess_image,
|
14 |
+
downloadModels,
|
15 |
+
)
|
16 |
|
17 |
sys.path.append(os.path.dirname("./ComfyUI/"))
|
18 |
from ComfyUI.nodes import (
|
|
|
34 |
LayeredDiffusionCond,
|
35 |
)
|
36 |
import gradio as gr
|
37 |
+
from briarmbg import BriaRMBG
|
38 |
|
39 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
40 |
|
41 |
+
downloadModels()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
with torch.inference_mode():
|
44 |
ckpt_load_checkpoint = CheckpointLoaderSimple().load_checkpoint
|
|
|
56 |
mask_to_image = MaskToImage().mask_to_image
|
57 |
invert_mask = InvertMask().invert
|
58 |
join_image_with_alpha = JoinImageWithAlpha().join_image_with_alpha
|
59 |
+
rmbg_model = BriaRMBG.from_pretrained("briaai/RMBG-1.4").to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
|
62 |
def predict(
|
63 |
prompt: str,
|
64 |
negative_prompt: str,
|
65 |
input_image: Image.Image | None,
|
66 |
+
remove_bg: bool,
|
67 |
cond_mode: str,
|
68 |
seed: int,
|
69 |
sampler_name: str,
|
|
|
72 |
cfg: float,
|
73 |
denoise: float,
|
74 |
):
|
75 |
+
try:
|
76 |
+
with torch.inference_mode():
|
77 |
+
cliptextencode_prompt = cliptextencode(
|
78 |
+
text=prompt,
|
79 |
+
clip=ckpt[1],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
)
|
81 |
+
cliptextencode_negative_prompt = cliptextencode(
|
82 |
+
text=negative_prompt,
|
83 |
+
clip=ckpt[1],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
)
|
85 |
+
emptylatentimage_sample = emptylatentimage_generate(
|
86 |
+
width=1024, height=1024, batch_size=1
|
|
|
|
|
|
|
87 |
)
|
88 |
|
89 |
+
if input_image is not None:
|
90 |
+
input_image = pad_image(input_image).resize((1024, 1024))
|
91 |
+
if remove_bg:
|
92 |
+
orig_im_size = input_image.size
|
93 |
+
image = preprocess_image(np.array(input_image), [1024, 1024]).to(
|
94 |
+
device
|
95 |
+
)
|
96 |
+
|
97 |
+
result = rmbg_model(image)
|
98 |
+
# post process
|
99 |
+
result_mask_image = postprocess_image(result[0][0], orig_im_size)
|
100 |
+
|
101 |
+
# save result
|
102 |
+
pil_mask = Image.fromarray(result_mask_image)
|
103 |
+
no_bg_image = Image.new("RGBA", pil_mask.size, (0, 0, 0, 0))
|
104 |
+
no_bg_image.paste(input_image, mask=pil_mask)
|
105 |
+
input_image = no_bg_image
|
106 |
+
|
107 |
+
img_tensor = pil_to_tensor(input_image)
|
108 |
+
img_latent = vae_encode(pixels=img_tensor[0], vae=ckpt[2])
|
109 |
+
layereddiffusionapply_sample = ld_cond_apply_layered_diffusion(
|
110 |
+
config=cond_mode,
|
111 |
+
weight=1,
|
112 |
+
model=ckpt[0],
|
113 |
+
cond=cliptextencode_prompt[0],
|
114 |
+
uncond=cliptextencode_negative_prompt[0],
|
115 |
+
latent=img_latent[0],
|
116 |
+
)
|
117 |
+
ksampler = ksampler_sample(
|
118 |
+
steps=steps,
|
119 |
+
cfg=cfg,
|
120 |
+
sampler_name=sampler_name,
|
121 |
+
scheduler=scheduler,
|
122 |
+
seed=seed,
|
123 |
+
model=layereddiffusionapply_sample[0],
|
124 |
+
positive=layereddiffusionapply_sample[1],
|
125 |
+
negative=layereddiffusionapply_sample[2],
|
126 |
+
latent_image=emptylatentimage_sample[0],
|
127 |
+
denoise=denoise,
|
128 |
+
)
|
129 |
|
130 |
+
vaedecode_sample = vae_decode(
|
131 |
+
samples=ksampler[0],
|
132 |
+
vae=ckpt[2],
|
133 |
+
)
|
134 |
+
layereddiffusiondecode_sample = ld_decode(
|
135 |
+
sd_version="SDXL",
|
136 |
+
sub_batch_size=16,
|
137 |
+
samples=ksampler[0],
|
138 |
+
images=vaedecode_sample[0],
|
139 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
+
rgb_img = tensor_to_pil(vaedecode_sample[0])
|
142 |
+
return flatten([rgb_img])
|
143 |
+
else:
|
144 |
+
layereddiffusionapply_sample = ld_fg_apply_layered_diffusion(
|
145 |
+
config="SDXL, Conv Injection", weight=1, model=ckpt[0]
|
146 |
+
)
|
147 |
+
ksampler = ksampler_sample(
|
148 |
+
steps=steps,
|
149 |
+
cfg=cfg,
|
150 |
+
sampler_name=sampler_name,
|
151 |
+
scheduler=scheduler,
|
152 |
+
seed=seed,
|
153 |
+
model=layereddiffusionapply_sample[0],
|
154 |
+
positive=cliptextencode_prompt[0],
|
155 |
+
negative=cliptextencode_negative_prompt[0],
|
156 |
+
latent_image=emptylatentimage_sample[0],
|
157 |
+
denoise=denoise,
|
158 |
+
)
|
|
|
159 |
|
160 |
+
vaedecode_sample = vae_decode(
|
161 |
+
samples=ksampler[0],
|
162 |
+
vae=ckpt[2],
|
163 |
+
)
|
164 |
+
layereddiffusiondecode_sample = ld_decode(
|
165 |
+
sd_version="SDXL",
|
166 |
+
sub_batch_size=16,
|
167 |
+
samples=ksampler[0],
|
168 |
+
images=vaedecode_sample[0],
|
169 |
+
)
|
170 |
+
mask = mask_to_image(mask=layereddiffusiondecode_sample[1])
|
171 |
+
ld_image = tensor_to_pil(layereddiffusiondecode_sample[0][0])
|
172 |
+
inverted_mask = invert_mask(mask=layereddiffusiondecode_sample[1])
|
173 |
+
rgba_img = join_image_with_alpha(
|
174 |
+
image=layereddiffusiondecode_sample[0], alpha=inverted_mask[0]
|
175 |
+
)
|
176 |
+
rgba_img = tensor_to_pil(rgba_img[0])
|
177 |
+
mask = tensor_to_pil(mask[0])
|
178 |
+
rgb_img = tensor_to_pil(vaedecode_sample[0])
|
179 |
+
|
180 |
+
return flatten([rgba_img, mask])
|
181 |
+
# return flatten([rgba_img, mask, rgb_img, ld_image])
|
182 |
+
except Exception as e:
|
183 |
+
raise gr.Error(e)
|
184 |
|
185 |
|
186 |
examples = [["An old men sit on a chair looking at the sky"]]
|
|
|
192 |
|
193 |
def predict_examples(prompt, negative_prompt):
|
194 |
return predict(
|
195 |
+
prompt, negative_prompt, None, False, None, 0, "euler", "normal", 20, 8.0, 1.0
|
196 |
)
|
197 |
|
198 |
|
199 |
css = """
|
200 |
.gradio-container{
|
201 |
+
max-width: 50rem;
|
202 |
}
|
203 |
"""
|
204 |
with gr.Blocks(css=css) as blocks:
|
205 |
gr.Markdown("""# LayerDiffuse (unofficial)
|
206 |
+
Using ComfyUI building blocks with custom node by [huchenlei](https://github.com/huchenlei/ComfyUI-layerdiffuse)
|
207 |
""")
|
208 |
|
209 |
with gr.Row():
|
|
|
212 |
negative_prompt = gr.Text(label="Negative Prompt")
|
213 |
button = gr.Button("Generate")
|
214 |
with gr.Accordion(open=False, label="Input Images (Optional)"):
|
215 |
+
with gr.Group():
|
216 |
+
cond_mode = gr.Radio(
|
217 |
+
value="SDXL, Foreground",
|
218 |
+
choices=["SDXL, Foreground", "SDXL, Background"],
|
219 |
+
info="Whether to use input image as foreground or background",
|
220 |
+
)
|
221 |
+
remove_bg = gr.Checkbox(
|
222 |
+
info="Remove background using BriaRMBG",
|
223 |
+
label="Remove Background",
|
224 |
+
value=False,
|
225 |
+
)
|
226 |
+
input_image = gr.Image(label="Input Image", type="pil")
|
227 |
with gr.Accordion(open=False, label="Advanced Options"):
|
228 |
seed = gr.Slider(
|
229 |
label="Seed",
|
|
|
243 |
label="Scheduler",
|
244 |
value=samplers.KSampler.SCHEDULERS[0],
|
245 |
)
|
246 |
+
steps = gr.Slider(
|
247 |
+
label="Steps", value=20, minimum=1, maximum=30, step=1
|
248 |
)
|
249 |
cfg = gr.Number(
|
250 |
label="CFG", value=8.0, minimum=0.0, maximum=100.0, step=0.1
|
|
|
254 |
)
|
255 |
|
256 |
with gr.Column(scale=1.8):
|
257 |
+
gallery = gr.Gallery(columns=[2], object_fit="contain", height="unset")
|
|
|
|
|
258 |
|
259 |
inputs = [
|
260 |
prompt,
|
261 |
negative_prompt,
|
262 |
input_image,
|
263 |
+
remove_bg,
|
264 |
cond_mode,
|
265 |
seed,
|
266 |
sampler_name,
|
briarmbg.py
ADDED
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from huggingface_hub import PyTorchModelHubMixin
|
5 |
+
|
6 |
+
|
7 |
+
class REBNCONV(nn.Module):
|
8 |
+
def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
|
9 |
+
super(REBNCONV, self).__init__()
|
10 |
+
|
11 |
+
self.conv_s1 = nn.Conv2d(
|
12 |
+
in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
|
13 |
+
)
|
14 |
+
self.bn_s1 = nn.BatchNorm2d(out_ch)
|
15 |
+
self.relu_s1 = nn.ReLU(inplace=True)
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
hx = x
|
19 |
+
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
|
20 |
+
|
21 |
+
return xout
|
22 |
+
|
23 |
+
|
24 |
+
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
25 |
+
def _upsample_like(src, tar):
|
26 |
+
src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
|
27 |
+
|
28 |
+
return src
|
29 |
+
|
30 |
+
|
31 |
+
### RSU-7 ###
|
32 |
+
class RSU7(nn.Module):
|
33 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
|
34 |
+
super(RSU7, self).__init__()
|
35 |
+
|
36 |
+
self.in_ch = in_ch
|
37 |
+
self.mid_ch = mid_ch
|
38 |
+
self.out_ch = out_ch
|
39 |
+
|
40 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
|
41 |
+
|
42 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
43 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
44 |
+
|
45 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
46 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
47 |
+
|
48 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
49 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
50 |
+
|
51 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
52 |
+
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
53 |
+
|
54 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
55 |
+
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
56 |
+
|
57 |
+
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
58 |
+
|
59 |
+
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
60 |
+
|
61 |
+
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
62 |
+
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
63 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
64 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
65 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
66 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
b, c, h, w = x.shape
|
70 |
+
|
71 |
+
hx = x
|
72 |
+
hxin = self.rebnconvin(hx)
|
73 |
+
|
74 |
+
hx1 = self.rebnconv1(hxin)
|
75 |
+
hx = self.pool1(hx1)
|
76 |
+
|
77 |
+
hx2 = self.rebnconv2(hx)
|
78 |
+
hx = self.pool2(hx2)
|
79 |
+
|
80 |
+
hx3 = self.rebnconv3(hx)
|
81 |
+
hx = self.pool3(hx3)
|
82 |
+
|
83 |
+
hx4 = self.rebnconv4(hx)
|
84 |
+
hx = self.pool4(hx4)
|
85 |
+
|
86 |
+
hx5 = self.rebnconv5(hx)
|
87 |
+
hx = self.pool5(hx5)
|
88 |
+
|
89 |
+
hx6 = self.rebnconv6(hx)
|
90 |
+
|
91 |
+
hx7 = self.rebnconv7(hx6)
|
92 |
+
|
93 |
+
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
|
94 |
+
hx6dup = _upsample_like(hx6d, hx5)
|
95 |
+
|
96 |
+
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
|
97 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
98 |
+
|
99 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
100 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
101 |
+
|
102 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
103 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
104 |
+
|
105 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
106 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
107 |
+
|
108 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
109 |
+
|
110 |
+
return hx1d + hxin
|
111 |
+
|
112 |
+
|
113 |
+
### RSU-6 ###
|
114 |
+
class RSU6(nn.Module):
|
115 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
116 |
+
super(RSU6, self).__init__()
|
117 |
+
|
118 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
119 |
+
|
120 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
121 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
122 |
+
|
123 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
124 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
125 |
+
|
126 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
127 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
128 |
+
|
129 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
130 |
+
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
131 |
+
|
132 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
133 |
+
|
134 |
+
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
135 |
+
|
136 |
+
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
137 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
138 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
139 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
140 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
141 |
+
|
142 |
+
def forward(self, x):
|
143 |
+
hx = x
|
144 |
+
|
145 |
+
hxin = self.rebnconvin(hx)
|
146 |
+
|
147 |
+
hx1 = self.rebnconv1(hxin)
|
148 |
+
hx = self.pool1(hx1)
|
149 |
+
|
150 |
+
hx2 = self.rebnconv2(hx)
|
151 |
+
hx = self.pool2(hx2)
|
152 |
+
|
153 |
+
hx3 = self.rebnconv3(hx)
|
154 |
+
hx = self.pool3(hx3)
|
155 |
+
|
156 |
+
hx4 = self.rebnconv4(hx)
|
157 |
+
hx = self.pool4(hx4)
|
158 |
+
|
159 |
+
hx5 = self.rebnconv5(hx)
|
160 |
+
|
161 |
+
hx6 = self.rebnconv6(hx5)
|
162 |
+
|
163 |
+
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
|
164 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
165 |
+
|
166 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
167 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
168 |
+
|
169 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
170 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
171 |
+
|
172 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
173 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
174 |
+
|
175 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
176 |
+
|
177 |
+
return hx1d + hxin
|
178 |
+
|
179 |
+
|
180 |
+
### RSU-5 ###
|
181 |
+
class RSU5(nn.Module):
|
182 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
183 |
+
super(RSU5, self).__init__()
|
184 |
+
|
185 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
186 |
+
|
187 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
188 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
189 |
+
|
190 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
191 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
192 |
+
|
193 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
194 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
195 |
+
|
196 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
197 |
+
|
198 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
199 |
+
|
200 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
201 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
202 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
203 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
204 |
+
|
205 |
+
def forward(self, x):
|
206 |
+
hx = x
|
207 |
+
|
208 |
+
hxin = self.rebnconvin(hx)
|
209 |
+
|
210 |
+
hx1 = self.rebnconv1(hxin)
|
211 |
+
hx = self.pool1(hx1)
|
212 |
+
|
213 |
+
hx2 = self.rebnconv2(hx)
|
214 |
+
hx = self.pool2(hx2)
|
215 |
+
|
216 |
+
hx3 = self.rebnconv3(hx)
|
217 |
+
hx = self.pool3(hx3)
|
218 |
+
|
219 |
+
hx4 = self.rebnconv4(hx)
|
220 |
+
|
221 |
+
hx5 = self.rebnconv5(hx4)
|
222 |
+
|
223 |
+
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
|
224 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
225 |
+
|
226 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
227 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
228 |
+
|
229 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
230 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
231 |
+
|
232 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
233 |
+
|
234 |
+
return hx1d + hxin
|
235 |
+
|
236 |
+
|
237 |
+
### RSU-4 ###
|
238 |
+
class RSU4(nn.Module):
|
239 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
240 |
+
super(RSU4, self).__init__()
|
241 |
+
|
242 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
243 |
+
|
244 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
245 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
246 |
+
|
247 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
248 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
249 |
+
|
250 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
251 |
+
|
252 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
253 |
+
|
254 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
255 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
256 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
257 |
+
|
258 |
+
def forward(self, x):
|
259 |
+
hx = x
|
260 |
+
|
261 |
+
hxin = self.rebnconvin(hx)
|
262 |
+
|
263 |
+
hx1 = self.rebnconv1(hxin)
|
264 |
+
hx = self.pool1(hx1)
|
265 |
+
|
266 |
+
hx2 = self.rebnconv2(hx)
|
267 |
+
hx = self.pool2(hx2)
|
268 |
+
|
269 |
+
hx3 = self.rebnconv3(hx)
|
270 |
+
|
271 |
+
hx4 = self.rebnconv4(hx3)
|
272 |
+
|
273 |
+
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
274 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
275 |
+
|
276 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
277 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
278 |
+
|
279 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
280 |
+
|
281 |
+
return hx1d + hxin
|
282 |
+
|
283 |
+
|
284 |
+
### RSU-4F ###
|
285 |
+
class RSU4F(nn.Module):
|
286 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
287 |
+
super(RSU4F, self).__init__()
|
288 |
+
|
289 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
290 |
+
|
291 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
292 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
293 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
|
294 |
+
|
295 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
|
296 |
+
|
297 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
|
298 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
|
299 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
300 |
+
|
301 |
+
def forward(self, x):
|
302 |
+
hx = x
|
303 |
+
|
304 |
+
hxin = self.rebnconvin(hx)
|
305 |
+
|
306 |
+
hx1 = self.rebnconv1(hxin)
|
307 |
+
hx2 = self.rebnconv2(hx1)
|
308 |
+
hx3 = self.rebnconv3(hx2)
|
309 |
+
|
310 |
+
hx4 = self.rebnconv4(hx3)
|
311 |
+
|
312 |
+
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
313 |
+
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
|
314 |
+
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
|
315 |
+
|
316 |
+
return hx1d + hxin
|
317 |
+
|
318 |
+
|
319 |
+
class myrebnconv(nn.Module):
|
320 |
+
def __init__(
|
321 |
+
self,
|
322 |
+
in_ch=3,
|
323 |
+
out_ch=1,
|
324 |
+
kernel_size=3,
|
325 |
+
stride=1,
|
326 |
+
padding=1,
|
327 |
+
dilation=1,
|
328 |
+
groups=1,
|
329 |
+
):
|
330 |
+
super(myrebnconv, self).__init__()
|
331 |
+
|
332 |
+
self.conv = nn.Conv2d(
|
333 |
+
in_ch,
|
334 |
+
out_ch,
|
335 |
+
kernel_size=kernel_size,
|
336 |
+
stride=stride,
|
337 |
+
padding=padding,
|
338 |
+
dilation=dilation,
|
339 |
+
groups=groups,
|
340 |
+
)
|
341 |
+
self.bn = nn.BatchNorm2d(out_ch)
|
342 |
+
self.rl = nn.ReLU(inplace=True)
|
343 |
+
|
344 |
+
def forward(self, x):
|
345 |
+
return self.rl(self.bn(self.conv(x)))
|
346 |
+
|
347 |
+
|
348 |
+
class BriaRMBG(nn.Module, PyTorchModelHubMixin):
|
349 |
+
def __init__(self, config: dict = {"in_ch": 3, "out_ch": 1}):
|
350 |
+
super(BriaRMBG, self).__init__()
|
351 |
+
in_ch = config["in_ch"]
|
352 |
+
out_ch = config["out_ch"]
|
353 |
+
self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
|
354 |
+
self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
355 |
+
|
356 |
+
self.stage1 = RSU7(64, 32, 64)
|
357 |
+
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
358 |
+
|
359 |
+
self.stage2 = RSU6(64, 32, 128)
|
360 |
+
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
361 |
+
|
362 |
+
self.stage3 = RSU5(128, 64, 256)
|
363 |
+
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
364 |
+
|
365 |
+
self.stage4 = RSU4(256, 128, 512)
|
366 |
+
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
367 |
+
|
368 |
+
self.stage5 = RSU4F(512, 256, 512)
|
369 |
+
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
370 |
+
|
371 |
+
self.stage6 = RSU4F(512, 256, 512)
|
372 |
+
|
373 |
+
# decoder
|
374 |
+
self.stage5d = RSU4F(1024, 256, 512)
|
375 |
+
self.stage4d = RSU4(1024, 128, 256)
|
376 |
+
self.stage3d = RSU5(512, 64, 128)
|
377 |
+
self.stage2d = RSU6(256, 32, 64)
|
378 |
+
self.stage1d = RSU7(128, 16, 64)
|
379 |
+
|
380 |
+
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
|
381 |
+
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
|
382 |
+
self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
|
383 |
+
self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
|
384 |
+
self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
|
385 |
+
self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
|
386 |
+
|
387 |
+
# self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
|
388 |
+
|
389 |
+
def forward(self, x):
|
390 |
+
hx = x
|
391 |
+
|
392 |
+
hxin = self.conv_in(hx)
|
393 |
+
# hx = self.pool_in(hxin)
|
394 |
+
|
395 |
+
# stage 1
|
396 |
+
hx1 = self.stage1(hxin)
|
397 |
+
hx = self.pool12(hx1)
|
398 |
+
|
399 |
+
# stage 2
|
400 |
+
hx2 = self.stage2(hx)
|
401 |
+
hx = self.pool23(hx2)
|
402 |
+
|
403 |
+
# stage 3
|
404 |
+
hx3 = self.stage3(hx)
|
405 |
+
hx = self.pool34(hx3)
|
406 |
+
|
407 |
+
# stage 4
|
408 |
+
hx4 = self.stage4(hx)
|
409 |
+
hx = self.pool45(hx4)
|
410 |
+
|
411 |
+
# stage 5
|
412 |
+
hx5 = self.stage5(hx)
|
413 |
+
hx = self.pool56(hx5)
|
414 |
+
|
415 |
+
# stage 6
|
416 |
+
hx6 = self.stage6(hx)
|
417 |
+
hx6up = _upsample_like(hx6, hx5)
|
418 |
+
|
419 |
+
# -------------------- decoder --------------------
|
420 |
+
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
|
421 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
422 |
+
|
423 |
+
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
|
424 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
425 |
+
|
426 |
+
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
|
427 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
428 |
+
|
429 |
+
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
|
430 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
431 |
+
|
432 |
+
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
|
433 |
+
|
434 |
+
# side output
|
435 |
+
d1 = self.side1(hx1d)
|
436 |
+
d1 = _upsample_like(d1, x)
|
437 |
+
|
438 |
+
d2 = self.side2(hx2d)
|
439 |
+
d2 = _upsample_like(d2, x)
|
440 |
+
|
441 |
+
d3 = self.side3(hx3d)
|
442 |
+
d3 = _upsample_like(d3, x)
|
443 |
+
|
444 |
+
d4 = self.side4(hx4d)
|
445 |
+
d4 = _upsample_like(d4, x)
|
446 |
+
|
447 |
+
d5 = self.side5(hx5d)
|
448 |
+
d5 = _upsample_like(d5, x)
|
449 |
+
|
450 |
+
d6 = self.side6(hx6)
|
451 |
+
d6 = _upsample_like(d6, x)
|
452 |
+
|
453 |
+
return [
|
454 |
+
F.sigmoid(d1),
|
455 |
+
F.sigmoid(d2),
|
456 |
+
F.sigmoid(d3),
|
457 |
+
F.sigmoid(d4),
|
458 |
+
F.sigmoid(d5),
|
459 |
+
F.sigmoid(d6),
|
460 |
+
], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
|
utils.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torchvision.transforms.functional import normalize
|
6 |
+
from PIL import Image, ImageOps, ImageSequence
|
7 |
+
from typing import List
|
8 |
+
from pathlib import Path
|
9 |
+
from huggingface_hub import snapshot_download, hf_hub_download
|
10 |
+
|
11 |
+
|
12 |
+
def tensor_to_pil(images: torch.Tensor | List[torch.Tensor]) -> List[Image.Image]:
|
13 |
+
if not isinstance(images, list):
|
14 |
+
images = [images]
|
15 |
+
imgs = []
|
16 |
+
for image in images:
|
17 |
+
i = 255.0 * image.cpu().numpy()
|
18 |
+
img = Image.fromarray(np.clip(np.squeeze(i), 0, 255).astype(np.uint8))
|
19 |
+
imgs.append(img)
|
20 |
+
return imgs
|
21 |
+
|
22 |
+
|
23 |
+
def pad_image(input_image):
|
24 |
+
pad_w, pad_h = (
|
25 |
+
np.max(((2, 2), np.ceil(np.array(input_image.size) / 64).astype(int)), axis=0)
|
26 |
+
* 64
|
27 |
+
- input_image.size
|
28 |
+
)
|
29 |
+
im_padded = Image.fromarray(
|
30 |
+
np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode="edge")
|
31 |
+
)
|
32 |
+
w, h = im_padded.size
|
33 |
+
if w == h:
|
34 |
+
return im_padded
|
35 |
+
elif w > h:
|
36 |
+
new_image = Image.new(im_padded.mode, (w, w), (0, 0, 0))
|
37 |
+
new_image.paste(im_padded, (0, (w - h) // 2))
|
38 |
+
return new_image
|
39 |
+
else:
|
40 |
+
new_image = Image.new(im_padded.mode, (h, h), (0, 0, 0))
|
41 |
+
new_image.paste(im_padded, ((h - w) // 2, 0))
|
42 |
+
return new_image
|
43 |
+
|
44 |
+
|
45 |
+
def pil_to_tensor(image: Image.Image) -> tuple[torch.Tensor, torch.Tensor]:
|
46 |
+
output_images = []
|
47 |
+
output_masks = []
|
48 |
+
for i in ImageSequence.Iterator(image):
|
49 |
+
i = ImageOps.exif_transpose(i)
|
50 |
+
if i.mode == "I":
|
51 |
+
i = i.point(lambda i: i * (1 / 255))
|
52 |
+
image = i.convert("RGB")
|
53 |
+
image = np.array(image).astype(np.float32) / 255.0
|
54 |
+
image = torch.from_numpy(image)[None,]
|
55 |
+
if "A" in i.getbands():
|
56 |
+
mask = np.array(i.getchannel("A")).astype(np.float32) / 255.0
|
57 |
+
mask = 1.0 - torch.from_numpy(mask)
|
58 |
+
else:
|
59 |
+
mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
|
60 |
+
output_images.append(image)
|
61 |
+
output_masks.append(mask.unsqueeze(0))
|
62 |
+
|
63 |
+
if len(output_images) > 1:
|
64 |
+
output_image = torch.cat(output_images, dim=0)
|
65 |
+
output_mask = torch.cat(output_masks, dim=0)
|
66 |
+
else:
|
67 |
+
output_image = output_images[0]
|
68 |
+
output_mask = output_masks[0]
|
69 |
+
|
70 |
+
return (output_image, output_mask)
|
71 |
+
|
72 |
+
|
73 |
+
def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
|
74 |
+
if len(im.shape) < 3:
|
75 |
+
im = im[:, :, np.newaxis]
|
76 |
+
# orig_im_size=im.shape[0:2]
|
77 |
+
im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
|
78 |
+
im_tensor = F.interpolate(
|
79 |
+
torch.unsqueeze(im_tensor, 0), size=model_input_size, mode="bilinear"
|
80 |
+
).type(torch.uint8)
|
81 |
+
image = torch.divide(im_tensor, 255.0)
|
82 |
+
image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
|
83 |
+
return image
|
84 |
+
|
85 |
+
|
86 |
+
def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
|
87 |
+
result = torch.squeeze(F.interpolate(result, size=im_size, mode="bilinear"), 0)
|
88 |
+
ma = torch.max(result)
|
89 |
+
mi = torch.min(result)
|
90 |
+
result = (result - mi) / (ma - mi)
|
91 |
+
im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
|
92 |
+
im_array = np.squeeze(im_array)
|
93 |
+
return im_array
|
94 |
+
|
95 |
+
|
96 |
+
def downloadModels():
|
97 |
+
MODEL_PATH = hf_hub_download(
|
98 |
+
repo_id="lllyasviel/fav_models",
|
99 |
+
subfolder="fav",
|
100 |
+
filename="juggernautXL_v8Rundiffusion.safetensors",
|
101 |
+
)
|
102 |
+
LAYERS_PATH = snapshot_download(
|
103 |
+
repo_id="LayerDiffusion/layerdiffusion-v1", allow_patterns="*.safetensors"
|
104 |
+
)
|
105 |
+
for file in Path(LAYERS_PATH).glob("*.safetensors"):
|
106 |
+
target_path = Path(f"./ComfyUI/models/layer_model/{file.name}")
|
107 |
+
if not target_path.exists():
|
108 |
+
os.symlink(file, target_path)
|
109 |
+
|
110 |
+
model_target_path = Path(
|
111 |
+
"./ComfyUI/models/checkpoints/juggernautXL_v8Rundiffusion.safetensors"
|
112 |
+
)
|
113 |
+
if not model_target_path.exists():
|
114 |
+
os.symlink(MODEL_PATH, model_target_path)
|