add controlnet inpaint
- app.py +79 -1
- preprocessor.py +84 -0
- requirements.txt +3 -1
app.py
CHANGED
@@ -20,6 +20,8 @@ from io import BytesIO
 from datetime import datetime
 from diffusers.utils import load_image
 import json
+from preprocessor import Preprocessor
+from diffusers.pipelines.flux.pipeline_flux_controlnet_inpaint import FluxControlNetInpaintPipeline

 HF_TOKEN = os.environ.get("HF_TOKEN")

@@ -33,9 +35,27 @@ dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
 base_model = "black-forest-labs/FLUX.1-dev"

+controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Union-alpha'
+controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
+
 taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
 good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
-pipe =
+pipe = FluxControlNetInpaintPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=dtype, vae=taef1).to(device)
+
+control_mode_ids = {
+    "scribble_hed": 0,
+    "canny": 0,  # supported
+    "mlsd": 0,  # supported
+    "tile": 1,  # supported
+    "depth_midas": 2,  # supported
+    "blur": 3,  # supported
+    "openpose": 4,  # supported
+    "gray": 5,  # supported
+    "low_quality": 6,  # supported
+}
+


 class calculateDuration:
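For context: the InstantX Union ControlNet multiplexes several conditioning types through one checkpoint, selected by an integer control_mode, which is what the control_mode_ids table above encodes. Below is a minimal calling sketch, assuming the pinned diffusers fork exposes FluxControlNetInpaintPipeline with the keyword arguments this commit passes in run_flux; init_image, mask, and canny_map are hypothetical placeholders.

# Sketch, not app code: pass a preprocessed control image plus its
# Union mode id alongside the usual inpainting inputs.
result = pipe(
    prompt="a red leather sofa",
    image=init_image,                        # image to inpaint
    mask_image=mask,                         # white pixels get repainted
    control_image=canny_map,                 # matching preprocessor output
    control_mode=control_mode_ids["canny"],  # 0 in the table above
    width=1024,
    height=1024,
).images[0]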
@@ -129,6 +149,8 @@ def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
 def run_flux(
     image: Image.Image,
     mask: Image.Image,
+    control_image: Image.Image,
+    control_mode: int,
     prompt: str,
     lora_path: str,
     lora_weights: str,
@@ -157,6 +179,8 @@ def run_flux(
         prompt=prompt,
         image=image,
         mask_image=mask,
+        control_image=control_image,
+        control_mode=control_mode,
         width=width,
         height=height,
         strength=strength_slider,
@@ -175,6 +199,7 @@ def process(
     inpainting_prompt_text: str,
     mask_inflation_slider: int,
     mask_blur_slider: int,
+    control_mode: str,
     seed_slicer: int,
     randomize_seed_checkbox: bool,
     strength_slider: float,
@@ -217,10 +242,58 @@ def process(
     mask = mask.resize((width, height), Image.LANCZOS)
     mask = process_mask(mask, mask_inflation=mask_inflation_slider, mask_blur=mask_blur_slider)

+
+    # generate the control image for the selected ControlNet mode
+    with calculateDuration("Preprocessor Image"):
+        print("start to generate control image")
+        preprocessor = Preprocessor()
+        if control_mode == "depth_midas":
+            preprocessor.load("Midas")
+            control_image = preprocessor(
+                image=image,
+                image_resolution=width,
+                detect_resolution=512,
+            )
+        if control_mode == "openpose":
+            preprocessor.load("Openpose")
+            control_image = preprocessor(
+                image=image,
+                hand_and_face=True,
+                image_resolution=width,
+                detect_resolution=512,
+            )
+        if control_mode == "canny":
+            preprocessor.load("Canny")
+            control_image = preprocessor(
+                image=image,
+                image_resolution=width,
+                detect_resolution=512,
+            )
+
+        if control_mode == "mlsd":
+            preprocessor.load("MLSD")
+            control_image = preprocessor(
+                image=image,
+                image_resolution=width,
+                detect_resolution=512,
+            )
+
+        if control_mode == "scribble_hed":
+            # Preprocessor.load() currently supports Midas/MLSD/Openpose/Canny
+            # only, so "HED" still needs to be wired up in preprocessor.py
+            preprocessor.load("HED")
+            control_image = preprocessor(
+                image=image,
+                image_resolution=width,
+                detect_resolution=512,
+            )
+
+    control_mode_id = control_mode_ids[control_mode]
+
     try:
         generated_image = run_flux(
             image=image,
             mask=mask,
+            control_image=control_image,
+            control_mode=control_mode_id,
             prompt=inpainting_prompt_text,
             lora_path=lora_path,
             lora_scale=lora_scale,
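The five near-duplicate branches above are easy to let drift apart (the mlsd and scribble_hed arms originally referenced variables that do not exist inside process). A table-driven dispatch is one way to keep them consistent; this is a sketch only, assuming the same Preprocessor interface added in preprocessor.py below.

# Sketch: map each UI mode to its annotator name and extra kwargs.
# Modes without an annotator (tile, blur, gray, low_quality) would be
# handled separately.
PREPROCESSOR_FOR_MODE = {
    "depth_midas": ("Midas", {}),
    "openpose": ("Openpose", {"hand_and_face": True}),
    "canny": ("Canny", {}),
    "mlsd": ("MLSD", {}),
}

def make_control_image(image, control_mode, width):
    name, extra_kwargs = PREPROCESSOR_FOR_MODE[control_mode]
    preprocessor = Preprocessor()
    preprocessor.load(name)
    return preprocessor(
        image=image,
        image_resolution=width,
        detect_resolution=512,
        **extra_kwargs,
    )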
@@ -275,6 +348,10 @@ with gr.Blocks() as demo:
                 placeholder="Enter text to generate inpainting",
                 container=False,
             )
+
+            control_mode = gr.Dropdown(
+                ["canny", "depth_midas", "openpose", "mlsd", "low_quality", "gray", "blur", "tile"],
+                label="Controlnet Model", info="choose a ControlNet mode", value="canny",
+            )

             submit_button_component = gr.Button(value='Submit', variant='primary', scale=0)

@@ -382,6 +459,7 @@ with gr.Blocks() as demo:
             inpainting_prompt_text_component,
             mask_inflation_slider_component,
             mask_blur_slider_component,
+            control_mode,
             seed_slicer_component,
             randomize_seed_checkbox_component,
             strength_slider_component,
preprocessor.py
ADDED
@@ -0,0 +1,84 @@
+import gc
+
+import numpy as np
+import PIL.Image
+import torch
+import torchvision
+from controlnet_aux import (
+    CannyDetector,
+    ContentShuffleDetector,
+    HEDdetector,
+    LineartAnimeDetector,
+    LineartDetector,
+    MidasDetector,
+    MLSDdetector,
+    NormalBaeDetector,
+    OpenposeDetector,
+    PidiNetDetector,
+)
+from controlnet_aux.util import HWC3
+
+from cv_utils import resize_image
+from depth_estimator import DepthEstimator
+from image_segmentor import ImageSegmentor
+
+from kornia.core import Tensor
+
+# load preprocessors
+
+# HED = HEDdetector.from_pretrained("lllyasviel/Annotators")
+Midas = MidasDetector.from_pretrained("lllyasviel/Annotators")
+MLSD = MLSDdetector.from_pretrained("lllyasviel/Annotators")
+Canny = CannyDetector()
+OPENPOSE = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
+
+
+class Preprocessor:
+    MODEL_ID = "lllyasviel/Annotators"
+
+    def __init__(self):
+        self.model = None
+        self.name = ""
+
+    def load(self, name: str) -> None:
+        if name == self.name:
+            return
+
+        if name == "Midas":
+            self.model = Midas
+        elif name == "MLSD":
+            self.model = MLSD
+        elif name == "Openpose":
+            self.model = OPENPOSE
+        elif name == "Canny":
+            self.model = Canny
+        else:
+            raise ValueError(f"unsupported preprocessor: {name}")
+        torch.cuda.empty_cache()
+        gc.collect()
+        self.name = name
+
+    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
+        if self.name == "Canny" or self.name == "MLSD":
+            detect_resolution = kwargs.pop("detect_resolution")
+            image_resolution = kwargs.pop("image_resolution", 512)
+            image = np.array(image)
+            image = HWC3(image)
+            image = resize_image(image, resolution=detect_resolution)
+            image = self.model(image, **kwargs)
+            image = np.array(image)
+            image = HWC3(image)
+            image = resize_image(image, resolution=image_resolution)
+            return PIL.Image.fromarray(image).convert('RGB')
+
+        else:
+            detect_resolution = kwargs.pop("detect_resolution", 512)
+            image_resolution = kwargs.pop("image_resolution", 512)
+            image = np.array(image)
+            image = HWC3(image)
+            image = resize_image(image, resolution=detect_resolution)
+            image = self.model(image, **kwargs)
+            image = np.array(image)
+            image = HWC3(image)
+            image = resize_image(image, resolution=image_resolution)
+            return PIL.Image.fromarray(image)
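A quick usage sketch for the new class (input.png is a hypothetical file; the annotator weights download from lllyasviel/Annotators on first use):

from PIL import Image
from preprocessor import Preprocessor

preprocessor = Preprocessor()
preprocessor.load("Midas")          # swap the active annotator in place
depth = preprocessor(
    image=Image.open("input.png"),
    image_resolution=1024,          # resolution of the returned map
    detect_resolution=512,          # resolution fed to the detector
)
depth.save("depth.png")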
requirements.txt
CHANGED
@@ -16,4 +16,6 @@ requests
 git+https://github.com/mylovelycodes/diffusers.git
 boto3
 sentencepiece
-peft
+peft
+controlnet-aux
+kornia