John6666 commited on
Commit
f1d6334
1 Parent(s): fc513dd

Upload 49 files

Browse files
README.md CHANGED
@@ -11,6 +11,8 @@ license: mit
11
  duplicated_from:
12
  - multimodalart/flux-lora-the-explorer
13
  - gokaygokay/FLUX-Prompt-Generator
 
 
14
  models:
15
  - black-forest-labs/FLUX.1-dev
16
  - alvdansen/frosting_lane_flux
 
11
  duplicated_from:
12
  - multimodalart/flux-lora-the-explorer
13
  - gokaygokay/FLUX-Prompt-Generator
14
+ - jiuface/FLUX.1-dev-Controlnet-Union
15
+ - DamarJati/FLUX.1-DEV-Canny
16
  models:
17
  - black-forest-labs/FLUX.1-dev
18
  - alvdansen/frosting_lane_flux
app.py CHANGED
@@ -12,7 +12,8 @@ import time
12
 
13
  from mod import (models, clear_cache, get_repo_safetensors, change_base_model,
14
  description_ui, num_loras, compose_lora_json, is_valid_lora, fuse_loras,
15
- get_trigger_word, pipe, enhance_prompt)
 
16
  from flux import (search_civitai_lora, select_civitai_lora, search_civitai_lora_json,
17
  download_my_lora, get_all_lora_tupled_list, apply_lora_prompt,
18
  update_loras)
@@ -64,26 +65,42 @@ def update_selection(evt: gr.SelectData, width, height):
64
  )
65
 
66
  @spaces.GPU(duration=70)
67
- def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress):
68
  pipe.to("cuda")
69
  generator = torch.Generator(device="cuda").manual_seed(seed)
70
 
71
  progress(0, desc="Start Inference.")
72
  with calculateDuration("Generating image"):
73
  # Generate image
74
- image = pipe(
75
- prompt=prompt_mash,
76
- num_inference_steps=steps,
77
- guidance_scale=cfg_scale,
78
- width=width,
79
- height=height,
80
- generator=generator,
81
- joint_attention_kwargs={"scale": lora_scale},
82
- ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  return image
84
 
85
  def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
86
- lora_scale, lora_json, progress=gr.Progress(track_tqdm=True)):
87
  if selected_index is None and not is_valid_lora(lora_json):
88
  gr.Info("LoRA isn't selected.")
89
  # raise gr.Error("You must select a LoRA before proceeding.")
@@ -123,7 +140,7 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, wid
123
 
124
  progress(1, desc="Preparing Inference.")
125
 
126
- image = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
127
  if is_valid_lora(lora_json):
128
  pipe.unfuse_lora()
129
  pipe.unload_lora_weights()
@@ -318,7 +335,24 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
318
  lora_download = [None] * num_loras
319
  for i in range(num_loras):
320
  lora_download[i] = gr.Button(f"Get and set LoRA to {int(i+1)}")
321
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  gallery.select(
323
  update_selection,
324
  inputs=[width, height],
@@ -336,16 +370,21 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
336
  gr.on(
337
  triggers=[generate_button.click, prompt.submit],
338
  fn=change_base_model,
339
- inputs=[model_name],
340
  outputs=[result]
341
  ).success(
342
  fn=run_lora,
343
  inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
344
- lora_scale, lora_repo_json],
345
  outputs=[result, seed]
346
  )
347
 
348
- model_name.change(change_base_model, [model_name], [result])
 
 
 
 
 
349
  prompt_enhance.click(enhance_prompt, [prompt], [prompt])
350
 
351
  gr.on(
@@ -382,6 +421,16 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
382
  ).success(apply_lora_prompt, [lora_info[i]], [lora_trigger[i]], queue=False, show_api=False
383
  ).success(compose_lora_json, [lora_repo_json, lora_num[i], lora_repo[i], lora_wt[i], lora_weights[i], lora_trigger[i]], [lora_repo_json], queue=False, show_api=False)
384
 
 
 
 
 
 
 
 
 
 
 
385
 
386
  tagger_generate_from_image.click(
387
  lambda: ("", "", ""), None, [v2_series, v2_character, prompt], queue=False, show_api=False,
 
12
 
13
  from mod import (models, clear_cache, get_repo_safetensors, change_base_model,
14
  description_ui, num_loras, compose_lora_json, is_valid_lora, fuse_loras,
15
+ get_trigger_word, enhance_prompt, pipe, controlnet, num_cns, set_control_union_image,
16
+ get_control_union_mode, set_control_union_mode, get_control_params)
17
  from flux import (search_civitai_lora, select_civitai_lora, search_civitai_lora_json,
18
  download_my_lora, get_all_lora_tupled_list, apply_lora_prompt,
19
  update_loras)
 
65
  )
66
 
67
  @spaces.GPU(duration=70)
68
+ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress):
69
  pipe.to("cuda")
70
  generator = torch.Generator(device="cuda").manual_seed(seed)
71
 
72
  progress(0, desc="Start Inference.")
73
  with calculateDuration("Generating image"):
74
  # Generate image
75
+ modes, images, scales = get_control_params()
76
+ if not cn_on or controlnet is None or len(modes) == 0:
77
+ image = pipe(
78
+ prompt=prompt_mash,
79
+ num_inference_steps=steps,
80
+ guidance_scale=cfg_scale,
81
+ width=width,
82
+ height=height,
83
+ generator=generator,
84
+ joint_attention_kwargs={"scale": lora_scale},
85
+ ).images[0]
86
+ else:
87
+
88
+ image = pipe(
89
+ prompt=prompt_mash,
90
+ control_image=images,
91
+ control_mode=modes,
92
+ num_inference_steps=steps,
93
+ guidance_scale=cfg_scale,
94
+ width=width,
95
+ height=height,
96
+ controlnet_conditioning_scale=scales,
97
+ generator=generator,
98
+ joint_attention_kwargs={"scale": lora_scale},
99
+ ).images[0]
100
  return image
101
 
102
  def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
103
+ lora_scale, lora_json, cn_on, progress=gr.Progress(track_tqdm=True)):
104
  if selected_index is None and not is_valid_lora(lora_json):
105
  gr.Info("LoRA isn't selected.")
106
  # raise gr.Error("You must select a LoRA before proceeding.")
 
140
 
141
  progress(1, desc="Preparing Inference.")
142
 
143
+ image = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress)
144
  if is_valid_lora(lora_json):
145
  pipe.unfuse_lora()
146
  pipe.unload_lora_weights()
 
335
  lora_download = [None] * num_loras
336
  for i in range(num_loras):
337
  lora_download[i] = gr.Button(f"Get and set LoRA to {int(i+1)}")
338
+
339
+ with gr.Accordion("ControlNet", open=True):
340
+ with gr.Column():
341
+ cn_on = gr.Checkbox(False, label="Use ControlNet")
342
+ cn_mode = [None] * num_cns
343
+ cn_scale = [None] * num_cns
344
+ cn_image = [None] * num_cns
345
+ cn_res = [None] * num_cns
346
+ cn_num = [None] * num_cns
347
+ for i in range(num_cns):
348
+ with gr.Group():
349
+ with gr.Row():
350
+ cn_mode[i] = gr.Dropdown(label=f"ControlNet {int(i+1)} Mode", choices=get_control_union_mode(), value=get_control_union_mode()[0], allow_custom_value=False)
351
+ cn_scale[i] = gr.Slider(label=f"ControlNet {int(i+1)} Weight", minimum=0.0, maximum=1.0, step=0.01, value=0.75)
352
+ cn_res[i] = gr.Slider(label=f"ControlNet {int(i+1)} Preprocess resolution", minimum=128, maximum=512, value=384, step=1)
353
+ cn_num[i] = gr.Number(i, visible=False)
354
+ cn_image[i] = gr.Image(type="pil", label="Control Image", height=256)
355
+
356
  gallery.select(
357
  update_selection,
358
  inputs=[width, height],
 
370
  gr.on(
371
  triggers=[generate_button.click, prompt.submit],
372
  fn=change_base_model,
373
+ inputs=[model_name, cn_on],
374
  outputs=[result]
375
  ).success(
376
  fn=run_lora,
377
  inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
378
+ lora_scale, lora_repo_json, cn_on],
379
  outputs=[result, seed]
380
  )
381
 
382
+ gr.on(
383
+ triggers=[model_name.change, cn_on.change],
384
+ fn=change_base_model,
385
+ inputs=[model_name, cn_on],
386
+ outputs=[result]
387
+ )
388
  prompt_enhance.click(enhance_prompt, [prompt], [prompt])
389
 
390
  gr.on(
 
421
  ).success(apply_lora_prompt, [lora_info[i]], [lora_trigger[i]], queue=False, show_api=False
422
  ).success(compose_lora_json, [lora_repo_json, lora_num[i], lora_repo[i], lora_wt[i], lora_weights[i], lora_trigger[i]], [lora_repo_json], queue=False, show_api=False)
423
 
424
+ for i, m in enumerate(cn_mode):
425
+ gr.on(
426
+ triggers=[cn_mode[i].change, cn_scale[i].change],
427
+ fn=set_control_union_mode,
428
+ inputs=[cn_num[i], cn_mode[i], cn_scale[i]],
429
+ outputs=[cn_on],
430
+ queue=True,
431
+ show_api=False,
432
+ )
433
+ cn_image[i].upload(set_control_union_image, [cn_num[i], cn_mode[i], cn_image[i], height, width, cn_res[i]], [cn_image[i]])
434
 
435
  tagger_generate_from_image.click(
436
  lambda: ("", "", ""), None, [v2_series, v2_character, prompt], queue=False, show_api=False,
cv_utils.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+ MAX_IMAGE_SIZE = 512
5
+
6
+ def resize_image(input_image, resolution=MAX_IMAGE_SIZE, interpolation=None):
7
+ H, W, C = input_image.shape
8
+ H = float(H)
9
+ W = float(W)
10
+ k = float(resolution) / max(H, W)
11
+ H *= k
12
+ W *= k
13
+ H = int(np.round(H / 64.0)) * 64
14
+ W = int(np.round(W / 64.0)) * 64
15
+ if interpolation is None:
16
+ interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
17
+ img = cv2.resize(input_image, (W, H), interpolation=interpolation)
18
+ return img
depth_estimator.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import PIL.Image
3
+ from controlnet_aux.util import HWC3
4
+ from transformers import pipeline
5
+
6
+ from cv_utils import resize_image
7
+
8
+
9
+ class DepthEstimator:
10
+ def __init__(self):
11
+ self.model = pipeline("depth-estimation")
12
+
13
+ def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
14
+ return image
image_datasets/canny_dataset.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ from torch.utils.data import Dataset, DataLoader
7
+ import json
8
+ import random
9
+ import cv2
10
+
11
+
12
+ def canny_processor(image, low_threshold=100, high_threshold=200):
13
+ image = np.array(image)
14
+ image = cv2.Canny(image, low_threshold, high_threshold)
15
+ image = image[:, :, None]
16
+ image = np.concatenate([image, image, image], axis=2)
17
+ canny_image = Image.fromarray(image)
18
+ return canny_image
19
+
20
+
21
+ def c_crop(image):
22
+ width, height = image.size
23
+ new_size = min(width, height)
24
+ left = (width - new_size) / 2
25
+ top = (height - new_size) / 2
26
+ right = (width + new_size) / 2
27
+ bottom = (height + new_size) / 2
28
+ return image.crop((left, top, right, bottom))
29
+
30
+ class CustomImageDataset(Dataset):
31
+ def __init__(self, img_dir, img_size=512):
32
+ self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i]
33
+ self.images.sort()
34
+ self.img_size = img_size
35
+
36
+ def __len__(self):
37
+ return len(self.images)
38
+
39
+ def __getitem__(self, idx):
40
+ try:
41
+ img = Image.open(self.images[idx])
42
+ img = c_crop(img)
43
+ img = img.resize((self.img_size, self.img_size))
44
+ hint = canny_processor(img)
45
+ img = torch.from_numpy((np.array(img) / 127.5) - 1)
46
+ img = img.permute(2, 0, 1)
47
+ hint = torch.from_numpy((np.array(hint) / 127.5) - 1)
48
+ hint = hint.permute(2, 0, 1)
49
+ json_path = self.images[idx].split('.')[0] + '.json'
50
+ prompt = json.load(open(json_path))['caption']
51
+ return img, hint, prompt
52
+ except Exception as e:
53
+ print(e)
54
+ return self.__getitem__(random.randint(0, len(self.images) - 1))
55
+
56
+
57
+ def loader(train_batch_size, num_workers, **args):
58
+ dataset = CustomImageDataset(**args)
59
+ return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers)
image_datasets/dataset.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ from torch.utils.data import Dataset, DataLoader
7
+ import json
8
+ import random
9
+
10
+ def c_crop(image):
11
+ width, height = image.size
12
+ new_size = min(width, height)
13
+ left = (width - new_size) / 2
14
+ top = (height - new_size) / 2
15
+ right = (width + new_size) / 2
16
+ bottom = (height + new_size) / 2
17
+ return image.crop((left, top, right, bottom))
18
+
19
+ class CustomImageDataset(Dataset):
20
+ def __init__(self, img_dir, img_size=512):
21
+ self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i]
22
+ self.images.sort()
23
+ self.img_size = img_size
24
+
25
+ def __len__(self):
26
+ return len(self.images)
27
+
28
+ def __getitem__(self, idx):
29
+ try:
30
+ img = Image.open(self.images[idx])
31
+ img = c_crop(img)
32
+ img = img.resize((self.img_size, self.img_size))
33
+ img = torch.from_numpy((np.array(img) / 127.5) - 1)
34
+ img = img.permute(2, 0, 1)
35
+ json_path = self.images[idx].split('.')[0] + '.json'
36
+ prompt = json.load(open(json_path))['caption']
37
+ return img, prompt
38
+ except Exception as e:
39
+ print(e)
40
+ return self.__getitem__(random.randint(0, len(self.images) - 1))
41
+
42
+
43
+ def loader(train_batch_size, num_workers, **args):
44
+ dataset = CustomImageDataset(**args)
45
+ return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True)
image_segmentor.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import PIL.Image
4
+ import torch
5
+ from controlnet_aux.util import HWC3, ade_palette
6
+ from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
7
+
8
+ from cv_utils import resize_image
9
+
10
+
11
+ class ImageSegmentor:
12
+
13
+ def __init__(self):
14
+ self.image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
15
+ self.image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
16
+
17
+ @torch.no_grad()
18
+ def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
19
+ detect_resolution = kwargs.pop("detect_resolution", 512)
20
+ image_resolution = kwargs.pop("image_resolution", 512)
21
+ image = HWC3(image)
22
+ image = resize_image(image, resolution=detect_resolution)
23
+ image = PIL.Image.fromarray(image)
24
+
25
+ pixel_values = self.image_processor(image, return_tensors="pt").pixel_values
26
+ outputs = self.image_segmentor(pixel_values)
27
+ seg = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
28
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
29
+ for label, color in enumerate(ade_palette()):
30
+ color_seg[seg == label, :] = color
31
+ color_seg = color_seg.astype(np.uint8)
32
+
33
+ color_seg = resize_image(color_seg, resolution=image_resolution, interpolation=cv2.INTER_NEAREST)
34
+ return PIL.Image.fromarray(color_seg)
mod.py CHANGED
@@ -2,9 +2,12 @@ import gradio as gr
2
  import torch
3
  import spaces
4
  from diffusers import DiffusionPipeline
 
 
5
  from pathlib import Path
6
  import gc
7
  import subprocess
 
8
 
9
 
10
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
@@ -31,7 +34,17 @@ models = [
31
 
32
 
33
  num_loras = 3
34
-
 
 
 
 
 
 
 
 
 
 
35
 
36
  def is_repo_name(s):
37
  import re
@@ -70,26 +83,169 @@ def get_repo_safetensors(repo_id: str):
70
  else: return gr.update(value=files[0], choices=files)
71
 
72
 
73
- # Initialize the base model
74
- base_model = models[0]
75
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
76
- last_model = models[0]
77
-
78
- def change_base_model(repo_id: str, progress=gr.Progress(track_tqdm=True)):
79
  global pipe
 
80
  global last_model
 
81
  try:
82
- if repo_id == last_model or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return
83
- progress(0, desc=f"Loading model: {repo_id}")
84
- clear_cache()
85
- pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
86
- last_model = repo_id
87
- progress(1, desc=f"Model loaded: {repo_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  except Exception as e:
89
  print(e)
90
  return gr.update(visible=True)
91
 
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  def compose_lora_json(lorajson: list[dict], i: int, name: str, scale: float, filename: str, trigger: str):
94
  lorajson[i]["name"] = str(name) if name != "None" else ""
95
  lorajson[i]["scale"] = float(scale)
@@ -112,6 +268,7 @@ def get_trigger_word(lorajson: list[dict]):
112
  trigger += ", " + d["trigger"]
113
  return trigger
114
 
 
115
  # https://huggingface.co/docs/diffusers/v0.23.1/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora
116
  # https://github.com/huggingface/diffusers/issues/4919
117
  def fuse_loras(pipe, lorajson: list[dict]):
@@ -139,13 +296,12 @@ def fuse_loras(pipe, lorajson: list[dict]):
139
  #pipe.unload_lora_weights()
140
 
141
 
142
-
143
-
144
-
145
  def description_ui():
146
  gr.Markdown(
147
  """
148
  - Mod of [multimodalart/flux-lora-the-explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer),
 
 
149
  [gokaygokay/FLUX-Prompt-Generator](https://huggingface.co/spaces/gokaygokay/FLUX-Prompt-Generator).
150
  """
151
  )
 
2
  import torch
3
  import spaces
4
  from diffusers import DiffusionPipeline
5
+ from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
6
+ from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
7
  from pathlib import Path
8
  import gc
9
  import subprocess
10
+ from PIL import Image
11
 
12
 
13
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
34
 
35
 
36
  num_loras = 3
37
+ num_cns = 2
38
+ # Initialize the base model
39
+ base_model = models[0]
40
+ controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
41
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
42
+ controlnet = None
43
+ control_images = [None] * num_cns
44
+ control_modes = [-1] * num_cns
45
+ control_scales = [0] * num_cns
46
+ last_model = models[0]
47
+ last_cn_on = False
48
 
49
  def is_repo_name(s):
50
  import re
 
83
  else: return gr.update(value=files[0], choices=files)
84
 
85
 
86
+ # https://huggingface.co/spaces/DamarJati/FLUX.1-DEV-Canny
87
+ # https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union
88
+ # https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
89
+ def change_base_model(repo_id: str, cn_on: bool, progress=gr.Progress(track_tqdm=True)):
 
 
90
  global pipe
91
+ global controlnet
92
  global last_model
93
+ global last_cn_on
94
  try:
95
+ if (repo_id == last_model and cn_on is last_cn_on) or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return
96
+ if cn_on:
97
+ progress(0, desc=f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
98
+ print(f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
99
+ clear_cache()
100
+ controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=torch.bfloat16)
101
+ controlnet = FluxMultiControlNetModel([controlnet_union])
102
+ pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=torch.bfloat16)
103
+ last_model = repo_id
104
+ progress(1, desc=f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
105
+ print(f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
106
+ else:
107
+ progress(0, desc=f"Loading model: {repo_id}")
108
+ print(f"Loading model: {repo_id}")
109
+ clear_cache()
110
+ pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
111
+ last_model = repo_id
112
+ progress(1, desc=f"Model loaded: {repo_id}")
113
+ print(f"Model loaded: {repo_id}")
114
  except Exception as e:
115
  print(e)
116
  return gr.update(visible=True)
117
 
118
 
119
+ # https://huggingface.co/spaces/DamarJati/FLUX.1-DEV-Canny/blob/main/app.py
120
+ def resize_image(image, target_width, target_height, crop=True):
121
+ from image_datasets.canny_dataset import c_crop
122
+ if crop:
123
+ image = c_crop(image) # Crop the image to square
124
+ original_width, original_height = image.size
125
+
126
+ # Resize to match the target size without stretching
127
+ scale = max(target_width / original_width, target_height / original_height)
128
+ resized_width = int(scale * original_width)
129
+ resized_height = int(scale * original_height)
130
+
131
+ image = image.resize((resized_width, resized_height), Image.LANCZOS)
132
+
133
+ # Center crop to match the target dimensions
134
+ left = (resized_width - target_width) // 2
135
+ top = (resized_height - target_height) // 2
136
+ image = image.crop((left, top, left + target_width, top + target_height))
137
+ else:
138
+ image = image.resize((target_width, target_height), Image.LANCZOS)
139
+
140
+ return image
141
+
142
+
143
+ # https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union/blob/main/app.py
144
+ controlnet_union_modes = {
145
+ "None": -1,
146
+ #"scribble_hed": 0,
147
+ "canny": 0, # supported
148
+ "mlsd": 0, #supported
149
+ "tile": 1, #supported
150
+ "depth_midas": 2, # supported
151
+ "blur": 3, # supported
152
+ "openpose": 4, # supported
153
+ "gray": 5, # supported
154
+ "low_quality": 6, # supported
155
+ }
156
+
157
+
158
+ def get_control_params():
159
+ modes = []
160
+ images = []
161
+ scales = []
162
+ for i, mode in enumerate(control_modes):
163
+ if mode == -1 or control_images[i] is None: continue
164
+ modes.append(control_modes[i])
165
+ images.append(control_images[i])
166
+ scales.append(control_scales[i])
167
+ return modes, images, scales
168
+
169
+
170
+ from preprocessor import Preprocessor
171
+ def preprocess_image(image: Image.Image, control_mode: str, height: int, width: int, preprocess_resolution: int):
172
+ image_resolution = max(width, height)
173
+ image_before = resize_image(image, image_resolution, image_resolution, True)
174
+ # generated control_
175
+ print("start to generate control image")
176
+ preprocessor = Preprocessor()
177
+ if control_mode == "depth_midas":
178
+ preprocessor.load("Midas")
179
+ control_image = preprocessor(
180
+ image=image_before,
181
+ image_resolution=image_resolution,
182
+ detect_resolution=preprocess_resolution,
183
+ )
184
+ if control_mode == "openpose":
185
+ preprocessor.load("Openpose")
186
+ control_image = preprocessor(
187
+ image=image_before,
188
+ hand_and_face=True,
189
+ image_resolution=image_resolution,
190
+ detect_resolution=preprocess_resolution,
191
+ )
192
+ if control_mode == "canny":
193
+ preprocessor.load("Canny")
194
+ control_image = preprocessor(
195
+ image=image_before,
196
+ image_resolution=image_resolution,
197
+ detect_resolution=preprocess_resolution,
198
+ )
199
+
200
+ if control_mode == "mlsd":
201
+ preprocessor.load("MLSD")
202
+ control_image = preprocessor(
203
+ image=image_before,
204
+ image_resolution=image_resolution,
205
+ detect_resolution=preprocess_resolution,
206
+ )
207
+
208
+ if control_mode == "scribble_hed":
209
+ preprocessor.load("HED")
210
+ control_image = preprocessor(
211
+ image=image_before,
212
+ image_resolution=image_resolution,
213
+ detect_resolution=preprocess_resolution,
214
+ )
215
+
216
+ if control_mode == "low_quality" or control_mode == "gray" or control_mode == "blur" or control_mode == "tile":
217
+ control_image = image_before
218
+ image_width = 768
219
+ image_height = 768
220
+ else:
221
+ # make sure control image size is same as resized_image
222
+ image_width, image_height = control_image.size
223
+
224
+ image_after = resize_image(control_image, width, height, True)
225
+ print(f"generate control image success: {image_width}x{image_height} => {width}x{height}")
226
+
227
+ return image_after
228
+
229
+
230
+ def get_control_union_mode():
231
+ return list(controlnet_union_modes.keys())
232
+
233
+
234
+ def set_control_union_mode(i: int, mode: str, scale: str):
235
+ global control_modes
236
+ global control_scales
237
+ control_modes[i] = controlnet_union_modes.get(mode, 0)
238
+ control_scales[i] = scale
239
+ if mode != "None": return True
240
+ else: return gr.update(visible=True)
241
+
242
+
243
+ def set_control_union_image(i: int, mode: str, image: Image.Image, height: int, width: int, preprocess_resolution: int):
244
+ global control_images
245
+ control_images[i] = preprocess_image(image, mode, height, width, preprocess_resolution)
246
+ return control_images[i]
247
+
248
+
249
  def compose_lora_json(lorajson: list[dict], i: int, name: str, scale: float, filename: str, trigger: str):
250
  lorajson[i]["name"] = str(name) if name != "None" else ""
251
  lorajson[i]["scale"] = float(scale)
 
268
  trigger += ", " + d["trigger"]
269
  return trigger
270
 
271
+
272
  # https://huggingface.co/docs/diffusers/v0.23.1/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora
273
  # https://github.com/huggingface/diffusers/issues/4919
274
  def fuse_loras(pipe, lorajson: list[dict]):
 
296
  #pipe.unload_lora_weights()
297
 
298
 
 
 
 
299
  def description_ui():
300
  gr.Markdown(
301
  """
302
  - Mod of [multimodalart/flux-lora-the-explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer),
303
+ [jiuface/FLUX.1-dev-Controlnet-Union](https://huggingface.co/spaces/jiuface/),
304
+ [DamarJati/FLUX.1-DEV-Canny](https://huggingface.co/spaces/DamarJati/FLUX.1-DEV-Canny),
305
  [gokaygokay/FLUX-Prompt-Generator](https://huggingface.co/spaces/gokaygokay/FLUX-Prompt-Generator).
306
  """
307
  )
preprocessor.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+
3
+ import numpy as np
4
+ import PIL.Image
5
+ import torch
6
+ import torchvision
7
+ from controlnet_aux import (
8
+ CannyDetector,
9
+ ContentShuffleDetector,
10
+ HEDdetector,
11
+ LineartAnimeDetector,
12
+ LineartDetector,
13
+ MidasDetector,
14
+ MLSDdetector,
15
+ NormalBaeDetector,
16
+ OpenposeDetector,
17
+ PidiNetDetector,
18
+ )
19
+ from controlnet_aux.util import HWC3
20
+
21
+ from cv_utils import resize_image
22
+ from depth_estimator import DepthEstimator
23
+ from image_segmentor import ImageSegmentor
24
+
25
+ from kornia.core import Tensor
26
+
27
+ # load preprocessor
28
+
29
+ # HED = HEDdetector.from_pretrained("lllyasviel/Annotators")
30
+ Midas = MidasDetector.from_pretrained("lllyasviel/Annotators")
31
+ MLSD = MLSDdetector.from_pretrained("lllyasviel/Annotators")
32
+ Canny = CannyDetector()
33
+ OPENPOSE = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
34
+
35
+
36
+ class Preprocessor:
37
+ MODEL_ID = "lllyasviel/Annotators"
38
+
39
+ def __init__(self):
40
+ self.model = None
41
+ self.name = ""
42
+
43
+ def load(self, name: str) -> None:
44
+ if name == self.name:
45
+ return
46
+
47
+ if name == "Midas":
48
+ self.model = Midas
49
+ elif name == "MLSD":
50
+ self.model =MLSD
51
+ elif name == "Openpose":
52
+ self.model = OPENPOSE
53
+ elif name == "Canny":
54
+ self.model = Canny
55
+ else:
56
+ raise ValueError
57
+ torch.cuda.empty_cache()
58
+ gc.collect()
59
+ self.name = name
60
+
61
+ def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
62
+ if self.name == "Canny" or self.name == "MLSD":
63
+ detect_resolution = kwargs.pop("detect_resolution")
64
+ image_resolution = kwargs.pop("image_resolution", 512)
65
+ image = np.array(image)
66
+ image = HWC3(image)
67
+ image = resize_image(image, resolution=detect_resolution)
68
+ image = self.model(image, **kwargs)
69
+ image = np.array(image)
70
+ image = HWC3(image)
71
+ image = resize_image(image, resolution=image_resolution)
72
+ return PIL.Image.fromarray(image).convert('RGB')
73
+
74
+ else:
75
+ detect_resolution = kwargs.pop("detect_resolution", 512)
76
+ image_resolution = kwargs.pop("image_resolution", 512)
77
+ image = np.array(image)
78
+ image = HWC3(image)
79
+ image = resize_image(image, resolution=detect_resolution)
80
+ image = self.model(image, **kwargs)
81
+ image = np.array(image)
82
+ image = HWC3(image)
83
+ image = resize_image(image, resolution=image_resolution)
84
+ return PIL.Image.fromarray(image)
requirements.txt CHANGED
@@ -8,5 +8,7 @@ transformers
8
  peft
9
  sentencepiece
10
  timm
11
- xformers
12
- einops
 
 
 
8
  peft
9
  sentencepiece
10
  timm
11
+ einops
12
+ controlnet-aux
13
+ kornia
14
+ numpy