jiuface committed
Commit 5903bf0
1 Parent(s): db8c7ad

add controlnet inpaint

Files changed (3)
  1. app.py +79 -1
  2. preprocessor.py +84 -0
  3. requirements.txt +3 -1
app.py CHANGED
@@ -20,6 +20,8 @@ from io import BytesIO
 from datetime import datetime
 from diffusers.utils import load_image
 import json
+from preprocessor import Preprocessor
+from diffusers.pipelines.flux.pipeline_flux_controlnet_inpaint import FluxControlNetInpaintPipeline
 
 HF_TOKEN = os.environ.get("HF_TOKEN")
 
@@ -33,9 +35,27 @@ dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
 base_model = "black-forest-labs/FLUX.1-dev"
 
+controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Union-alpha'
+controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
+
 taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
 good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
-pipe = FluxInpaintPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
+pipe = FluxControlNetInpaintPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=dtype, vae=taef1).to(device)
+
+
+# map the dropdown value to the control_mode id expected by the Union ControlNet
+control_mode_ids = {
+    "scribble_hed": 0,
+    "canny": 0,  # supported
+    "mlsd": 0,  # supported
+    "tile": 1,  # supported
+    "depth_midas": 2,  # supported
+    "blur": 3,  # supported
+    "openpose": 4,  # supported
+    "gray": 5,  # supported
+    "low_quality": 6,  # supported
+}
+
 
 
 class calculateDuration:
@@ -129,6 +149,8 @@ def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
 def run_flux(
     image: Image.Image,
     mask: Image.Image,
+    control_image: Image.Image,
+    control_mode: int,
     prompt: str,
     lora_path: str,
     lora_weights: str,
@@ -157,6 +179,8 @@ def run_flux(
         prompt=prompt,
         image=image,
         mask_image=mask,
+        control_image=control_image,
+        control_mode=control_mode,
         width=width,
         height=height,
         strength=strength_slider,
@@ -175,6 +199,7 @@ def process(
     inpainting_prompt_text: str,
     mask_inflation_slider: int,
     mask_blur_slider: int,
+    control_mode: str,
     seed_slicer: int,
     randomize_seed_checkbox: bool,
     strength_slider: float,
@@ -217,10 +242,58 @@ def process(
     mask = mask.resize((width, height), Image.LANCZOS)
     mask = process_mask(mask, mask_inflation=mask_inflation_slider, mask_blur=mask_blur_slider)
 
+
+    # generate the control image for the selected control mode
+    with calculateDuration("Preprocessor Image"):
+        print("start to generate control image")
+        preprocessor = Preprocessor()
+        if control_mode == "depth_midas":
+            preprocessor.load("Midas")
+            control_image = preprocessor(
+                image=image,
+                image_resolution=width,
+                detect_resolution=512,
+            )
+        if control_mode == "openpose":
+            preprocessor.load("Openpose")
+            control_image = preprocessor(
+                image=image,
+                hand_and_face=True,
+                image_resolution=width,
+                detect_resolution=512,
+            )
+        if control_mode == "canny":
+            preprocessor.load("Canny")
+            control_image = preprocessor(
+                image=image,
+                image_resolution=width,
+                detect_resolution=512,
+            )
+
+        if control_mode == "mlsd":
+            preprocessor.load("MLSD")
+            control_image = preprocessor(
+                image=image,
+                image_resolution=width,
+                detect_resolution=512,
+            )
+
+        if control_mode == "scribble_hed":
+            preprocessor.load("HED")
+            control_image = preprocessor(
+                image=image,
+                image_resolution=width,
+                detect_resolution=512,
+            )
+
+    control_mode_id = control_mode_ids[control_mode]
+
     try:
         generated_image = run_flux(
             image=image,
             mask=mask,
+            control_image=control_image,
+            control_mode=control_mode_id,
             prompt=inpainting_prompt_text,
             lora_path=lora_path,
             lora_scale=lora_scale,
@@ -275,6 +348,10 @@ with gr.Blocks() as demo:
                 placeholder="Enter text to generate inpainting",
                 container=False,
             )
+
+            control_mode = gr.Dropdown(
+                ["canny", "depth_midas", "openpose", "mlsd", "low_quality", "gray", "blur", "tile"], label="ControlNet Model", info="Choose the ControlNet mode", value="canny"
+            )
 
             submit_button_component = gr.Button(value='Submit', variant='primary', scale=0)
 
@@ -382,6 +459,7 @@ with gr.Blocks() as demo:
             inpainting_prompt_text_component,
             mask_inflation_slider_component,
             mask_blur_slider_component,
+            control_mode,
             seed_slicer_component,
             randomize_seed_checkbox_component,
             strength_slider_component,
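Taken together, these hunks route the new Gradio dropdown through a preprocessor step and into the ControlNet inpainting pipeline: the selected mode picks a detector, the detector produces `control_image`, and `control_mode_ids` translates the mode name into the integer id the Union ControlNet expects (the setup code also assumes `FluxControlNetModel` is importable from `diffusers`). Because the Union checkpoint packs several conditionings into one model, that integer tells it which conditioning the supplied control image encodes. A minimal sketch of the flow outside the Gradio app is below; the helper name, the default `strength`, and the restriction to modes whose detectors `preprocessor.py` actually loads are illustrative assumptions, not part of the diff.

```python
# Sketch only: condenses the control-image flow added in this commit.
# `pipe` is the FluxControlNetInpaintPipeline built in app.py; the function name
# and the default strength are assumptions for illustration.
from PIL import Image
from preprocessor import Preprocessor

# same ids as control_mode_ids in app.py (Union ControlNet modes)
CONTROL_MODE_IDS = {"canny": 0, "mlsd": 0, "scribble_hed": 0, "tile": 1,
                    "depth_midas": 2, "blur": 3, "openpose": 4, "gray": 5, "low_quality": 6}
# modes whose detectors preprocessor.py actually loads (HED is commented out there)
PREPROCESSOR_NAMES = {"canny": "Canny", "mlsd": "MLSD", "depth_midas": "Midas", "openpose": "Openpose"}

def controlnet_inpaint(pipe, image: Image.Image, mask: Image.Image, prompt: str,
                       control_mode: str, width: int, height: int, strength: float = 0.85):
    # build the control image with the detector that matches the selected mode
    preprocessor = Preprocessor()
    preprocessor.load(PREPROCESSOR_NAMES[control_mode])
    control_image = preprocessor(image=image, image_resolution=width, detect_resolution=512)
    # inpaint, conditioning the Union ControlNet on the control image and mode id
    return pipe(
        prompt=prompt,
        image=image,
        mask_image=mask,
        control_image=control_image,
        control_mode=CONTROL_MODE_IDS[control_mode],
        width=width,
        height=height,
        strength=strength,
    ).images[0]
```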
preprocessor.py ADDED
@@ -0,0 +1,84 @@
+import gc
+
+import numpy as np
+import PIL.Image
+import torch
+import torchvision
+from controlnet_aux import (
+    CannyDetector,
+    ContentShuffleDetector,
+    HEDdetector,
+    LineartAnimeDetector,
+    LineartDetector,
+    MidasDetector,
+    MLSDdetector,
+    NormalBaeDetector,
+    OpenposeDetector,
+    PidiNetDetector,
+)
+from controlnet_aux.util import HWC3
+
+from cv_utils import resize_image
+from depth_estimator import DepthEstimator
+from image_segmentor import ImageSegmentor
+
+from kornia.core import Tensor
+
+# load preprocessor
+
+# HED = HEDdetector.from_pretrained("lllyasviel/Annotators")
+Midas = MidasDetector.from_pretrained("lllyasviel/Annotators")
+MLSD = MLSDdetector.from_pretrained("lllyasviel/Annotators")
+Canny = CannyDetector()
+OPENPOSE = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
+
+
+class Preprocessor:
+    MODEL_ID = "lllyasviel/Annotators"
+
+    def __init__(self):
+        self.model = None
+        self.name = ""
+
+    def load(self, name: str) -> None:
+        if name == self.name:
+            return
+
+        if name == "Midas":
+            self.model = Midas
+        elif name == "MLSD":
+            self.model = MLSD
+        elif name == "Openpose":
+            self.model = OPENPOSE
+        elif name == "Canny":
+            self.model = Canny
+        else:
+            raise ValueError
+        torch.cuda.empty_cache()
+        gc.collect()
+        self.name = name
+
+    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
+        if self.name == "Canny" or self.name == "MLSD":
+            detect_resolution = kwargs.pop("detect_resolution")
+            image_resolution = kwargs.pop("image_resolution", 512)
+            image = np.array(image)
+            image = HWC3(image)
+            image = resize_image(image, resolution=detect_resolution)
+            image = self.model(image, **kwargs)
+            image = np.array(image)
+            image = HWC3(image)
+            image = resize_image(image, resolution=image_resolution)
+            return PIL.Image.fromarray(image).convert('RGB')
+
+        else:
+            detect_resolution = kwargs.pop("detect_resolution", 512)
+            image_resolution = kwargs.pop("image_resolution", 512)
+            image = np.array(image)
+            image = HWC3(image)
+            image = resize_image(image, resolution=detect_resolution)
+            image = self.model(image, **kwargs)
+            image = np.array(image)
+            image = HWC3(image)
+            image = resize_image(image, resolution=image_resolution)
+            return PIL.Image.fromarray(image)
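`Preprocessor.load()` switches which of the module-level detectors the instance uses and clears the CUDA cache; `__call__` resizes the input to `detect_resolution`, runs the detector, and returns a PIL control image resized to `image_resolution`. A minimal usage sketch (the input file name and the resolutions are placeholders):

```python
# Usage sketch for the Preprocessor class above; file name and sizes are placeholders.
import PIL.Image
from preprocessor import Preprocessor

preprocessor = Preprocessor()
preprocessor.load("Canny")  # one of: "Midas", "MLSD", "Openpose", "Canny"
control_image = preprocessor(
    image=PIL.Image.open("input.png"),
    detect_resolution=512,    # resolution the detector runs at
    image_resolution=1024,    # resolution of the returned control image
)
control_image.save("control_canny.png")
```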
requirements.txt CHANGED
@@ -16,4 +16,6 @@ requests
 git+https://github.com/mylovelycodes/diffusers.git
 boto3
 sentencepiece
-peft
+peft
+controlnet-aux
+kornia