Upload 49 files
- README.md +2 -0
- app.py +66 -17
- cv_utils.py +18 -0
- depth_estimator.py +14 -0
- image_datasets/canny_dataset.py +59 -0
- image_datasets/dataset.py +45 -0
- image_segmentor.py +34 -0
- mod.py +172 -16
- preprocessor.py +84 -0
- requirements.txt +4 -2
README.md
CHANGED
@@ -11,6 +11,8 @@ license: mit
 duplicated_from:
 - multimodalart/flux-lora-the-explorer
 - gokaygokay/FLUX-Prompt-Generator
+- jiuface/FLUX.1-dev-Controlnet-Union
+- DamarJati/FLUX.1-DEV-Canny
 models:
 - black-forest-labs/FLUX.1-dev
 - alvdansen/frosting_lane_flux

app.py
CHANGED
@@ -12,7 +12,8 @@ import time
 
 from mod import (models, clear_cache, get_repo_safetensors, change_base_model,
                  description_ui, num_loras, compose_lora_json, is_valid_lora, fuse_loras,
-                 get_trigger_word, pipe,
+                 get_trigger_word, enhance_prompt, pipe, controlnet, num_cns, set_control_union_image,
+                 get_control_union_mode, set_control_union_mode, get_control_params)
 from flux import (search_civitai_lora, select_civitai_lora, search_civitai_lora_json,
                   download_my_lora, get_all_lora_tupled_list, apply_lora_prompt,
                   update_loras)
@@ -64,26 +65,42 @@ def update_selection(evt: gr.SelectData, width, height):
     )
 
 @spaces.GPU(duration=70)
-def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress):
+def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress):
     pipe.to("cuda")
     generator = torch.Generator(device="cuda").manual_seed(seed)
 
     progress(0, desc="Start Inference.")
     with calculateDuration("Generating image"):
         # Generate image
-        ...
+        modes, images, scales = get_control_params()
+        if not cn_on or controlnet is None or len(modes) == 0:
+            image = pipe(
+                prompt=prompt_mash,
+                num_inference_steps=steps,
+                guidance_scale=cfg_scale,
+                width=width,
+                height=height,
+                generator=generator,
+                joint_attention_kwargs={"scale": lora_scale},
+            ).images[0]
+        else:
+
+            image = pipe(
+                prompt=prompt_mash,
+                control_image=images,
+                control_mode=modes,
+                num_inference_steps=steps,
+                guidance_scale=cfg_scale,
+                width=width,
+                height=height,
+                controlnet_conditioning_scale=scales,
+                generator=generator,
+                joint_attention_kwargs={"scale": lora_scale},
+            ).images[0]
     return image
 
 def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
-             lora_scale, lora_json, progress=gr.Progress(track_tqdm=True)):
+             lora_scale, lora_json, cn_on, progress=gr.Progress(track_tqdm=True)):
     if selected_index is None and not is_valid_lora(lora_json):
         gr.Info("LoRA isn't selected.")
         # raise gr.Error("You must select a LoRA before proceeding.")
@@ -123,7 +140,7 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, wid
 
     progress(1, desc="Preparing Inference.")
 
-    image = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
+    image = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress)
     if is_valid_lora(lora_json):
         pipe.unfuse_lora()
         pipe.unload_lora_weights()
@@ -318,7 +335,24 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
     lora_download = [None] * num_loras
     for i in range(num_loras):
         lora_download[i] = gr.Button(f"Get and set LoRA to {int(i+1)}")
-    ...
+
+    with gr.Accordion("ControlNet", open=True):
+        with gr.Column():
+            cn_on = gr.Checkbox(False, label="Use ControlNet")
+            cn_mode = [None] * num_cns
+            cn_scale = [None] * num_cns
+            cn_image = [None] * num_cns
+            cn_res = [None] * num_cns
+            cn_num = [None] * num_cns
+            for i in range(num_cns):
+                with gr.Group():
+                    with gr.Row():
+                        cn_mode[i] = gr.Dropdown(label=f"ControlNet {int(i+1)} Mode", choices=get_control_union_mode(), value=get_control_union_mode()[0], allow_custom_value=False)
+                        cn_scale[i] = gr.Slider(label=f"ControlNet {int(i+1)} Weight", minimum=0.0, maximum=1.0, step=0.01, value=0.75)
+                        cn_res[i] = gr.Slider(label=f"ControlNet {int(i+1)} Preprocess resolution", minimum=128, maximum=512, value=384, step=1)
+                    cn_num[i] = gr.Number(i, visible=False)
+                    cn_image[i] = gr.Image(type="pil", label="Control Image", height=256)
+
     gallery.select(
         update_selection,
         inputs=[width, height],
@@ -336,16 +370,21 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
     gr.on(
         triggers=[generate_button.click, prompt.submit],
         fn=change_base_model,
-        inputs=[model_name],
+        inputs=[model_name, cn_on],
         outputs=[result]
     ).success(
         fn=run_lora,
         inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
-                lora_scale, lora_repo_json],
+                lora_scale, lora_repo_json, cn_on],
         outputs=[result, seed]
     )
 
-    ...
+    gr.on(
+        triggers=[model_name.change, cn_on.change],
+        fn=change_base_model,
+        inputs=[model_name, cn_on],
+        outputs=[result]
+    )
     prompt_enhance.click(enhance_prompt, [prompt], [prompt])
 
     gr.on(
@@ -382,6 +421,16 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css) as app:
     ).success(apply_lora_prompt, [lora_info[i]], [lora_trigger[i]], queue=False, show_api=False
     ).success(compose_lora_json, [lora_repo_json, lora_num[i], lora_repo[i], lora_wt[i], lora_weights[i], lora_trigger[i]], [lora_repo_json], queue=False, show_api=False)
 
+    for i, m in enumerate(cn_mode):
+        gr.on(
+            triggers=[cn_mode[i].change, cn_scale[i].change],
+            fn=set_control_union_mode,
+            inputs=[cn_num[i], cn_mode[i], cn_scale[i]],
+            outputs=[cn_on],
+            queue=True,
+            show_api=False,
+        )
+        cn_image[i].upload(set_control_union_image, [cn_num[i], cn_mode[i], cn_image[i], height, width, cn_res[i]], [cn_image[i]])
 
     tagger_generate_from_image.click(
         lambda: ("", "", ""), None, [v2_series, v2_character, prompt], queue=False, show_api=False,

cv_utils.py
ADDED
import cv2
import numpy as np

MAX_IMAGE_SIZE = 512

def resize_image(input_image, resolution=MAX_IMAGE_SIZE, interpolation=None):
    H, W, C = input_image.shape
    H = float(H)
    W = float(W)
    k = float(resolution) / max(H, W)
    H *= k
    W *= k
    H = int(np.round(H / 64.0)) * 64
    W = int(np.round(W / 64.0)) * 64
    if interpolation is None:
        interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
    img = cv2.resize(input_image, (W, H), interpolation=interpolation)
    return img

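A minimal usage sketch for the new resize helper; the input path is a placeholder:

```python
# Minimal usage sketch for cv_utils.resize_image; "input.png" is a placeholder path.
import cv2

from cv_utils import resize_image

img = cv2.imread("input.png")                # HWC uint8 array
resized = resize_image(img, resolution=512)  # longer side scaled toward 512, both sides rounded to multiples of 64
print(resized.shape)
```
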
depth_estimator.py
ADDED
import numpy as np
import PIL.Image
from controlnet_aux.util import HWC3
from transformers import pipeline

from cv_utils import resize_image


class DepthEstimator:
    def __init__(self):
        self.model = pipeline("depth-estimation")

    def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
        return image

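As committed, DepthEstimator.__call__ is a pass-through stub: it builds the depth-estimation pipeline but returns its input unchanged. A hypothetical variant that actually runs the pipeline, in the style of the other preprocessors, might look like the sketch below; the class name and the HWC3/resize handling are assumptions, not part of this commit:

```python
# Hypothetical variant (not in the commit): run the depth-estimation pipeline
# and resize the result like the other preprocessors in this repo do.
import numpy as np
import PIL.Image
from controlnet_aux.util import HWC3

from cv_utils import resize_image
from depth_estimator import DepthEstimator


class WorkingDepthEstimator(DepthEstimator):
    def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
        detect_resolution = kwargs.pop("detect_resolution", 512)
        image_resolution = kwargs.pop("image_resolution", 512)
        image = HWC3(image)
        image = resize_image(image, resolution=detect_resolution)
        depth = self.model(PIL.Image.fromarray(image))["depth"]  # the pipeline returns a dict with a PIL depth map
        depth = HWC3(np.array(depth))
        depth = resize_image(depth, resolution=image_resolution)
        return PIL.Image.fromarray(depth)
```
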
image_datasets/canny_dataset.py
ADDED
import os
import pandas as pd
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import json
import random
import cv2


def canny_processor(image, low_threshold=100, high_threshold=200):
    image = np.array(image)
    image = cv2.Canny(image, low_threshold, high_threshold)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    canny_image = Image.fromarray(image)
    return canny_image


def c_crop(image):
    width, height = image.size
    new_size = min(width, height)
    left = (width - new_size) / 2
    top = (height - new_size) / 2
    right = (width + new_size) / 2
    bottom = (height + new_size) / 2
    return image.crop((left, top, right, bottom))

class CustomImageDataset(Dataset):
    def __init__(self, img_dir, img_size=512):
        self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i]
        self.images.sort()
        self.img_size = img_size

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        try:
            img = Image.open(self.images[idx])
            img = c_crop(img)
            img = img.resize((self.img_size, self.img_size))
            hint = canny_processor(img)
            img = torch.from_numpy((np.array(img) / 127.5) - 1)
            img = img.permute(2, 0, 1)
            hint = torch.from_numpy((np.array(hint) / 127.5) - 1)
            hint = hint.permute(2, 0, 1)
            json_path = self.images[idx].split('.')[0] + '.json'
            prompt = json.load(open(json_path))['caption']
            return img, hint, prompt
        except Exception as e:
            print(e)
            return self.__getitem__(random.randint(0, len(self.images) - 1))


def loader(train_batch_size, num_workers, **args):
    dataset = CustomImageDataset(**args)
    return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers)

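The dataset pairs each image with a Canny edge map and a caption read from a sibling .json file. A minimal usage sketch, assuming a placeholder ./data/ directory of RGB .jpg/.png images with matching .json caption files:

```python
# Minimal usage sketch for image_datasets.canny_dataset.loader;
# "./data/" with image + sibling .json caption files is a placeholder layout.
from image_datasets.canny_dataset import loader

train_loader = loader(train_batch_size=4, num_workers=0, img_dir="./data/", img_size=512)
img, hint, prompt = next(iter(train_loader))
print(img.shape, hint.shape)  # torch.Size([4, 3, 512, 512]) each, values scaled to [-1, 1]
```
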
image_datasets/dataset.py
ADDED
import os
import pandas as pd
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import json
import random

def c_crop(image):
    width, height = image.size
    new_size = min(width, height)
    left = (width - new_size) / 2
    top = (height - new_size) / 2
    right = (width + new_size) / 2
    bottom = (height + new_size) / 2
    return image.crop((left, top, right, bottom))

class CustomImageDataset(Dataset):
    def __init__(self, img_dir, img_size=512):
        self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i]
        self.images.sort()
        self.img_size = img_size

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        try:
            img = Image.open(self.images[idx])
            img = c_crop(img)
            img = img.resize((self.img_size, self.img_size))
            img = torch.from_numpy((np.array(img) / 127.5) - 1)
            img = img.permute(2, 0, 1)
            json_path = self.images[idx].split('.')[0] + '.json'
            prompt = json.load(open(json_path))['caption']
            return img, prompt
        except Exception as e:
            print(e)
            return self.__getitem__(random.randint(0, len(self.images) - 1))


def loader(train_batch_size, num_workers, **args):
    dataset = CustomImageDataset(**args)
    return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True)

image_segmentor.py
ADDED
import cv2
import numpy as np
import PIL.Image
import torch
from controlnet_aux.util import HWC3, ade_palette
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation

from cv_utils import resize_image


class ImageSegmentor:

    def __init__(self):
        self.image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
        self.image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")

    @torch.no_grad()
    def __call__(self, image: np.ndarray, **kwargs) -> PIL.Image.Image:
        detect_resolution = kwargs.pop("detect_resolution", 512)
        image_resolution = kwargs.pop("image_resolution", 512)
        image = HWC3(image)
        image = resize_image(image, resolution=detect_resolution)
        image = PIL.Image.fromarray(image)

        pixel_values = self.image_processor(image, return_tensors="pt").pixel_values
        outputs = self.image_segmentor(pixel_values)
        seg = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(ade_palette()):
            color_seg[seg == label, :] = color
        color_seg = color_seg.astype(np.uint8)

        color_seg = resize_image(color_seg, resolution=image_resolution, interpolation=cv2.INTER_NEAREST)
        return PIL.Image.fromarray(color_seg)

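A minimal usage sketch for the segmentor; the file names are placeholders and the UperNet weights download on first use:

```python
# Minimal usage sketch for image_segmentor.ImageSegmentor; "room.jpg" is a placeholder path.
import numpy as np
import PIL.Image

from image_segmentor import ImageSegmentor

segmentor = ImageSegmentor()
image = np.array(PIL.Image.open("room.jpg").convert("RGB"))
seg_map = segmentor(image, detect_resolution=512, image_resolution=512)  # ADE20K color-coded PIL image
seg_map.save("room_seg.png")
```
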
mod.py
CHANGED
@@ -2,9 +2,12 @@ import gradio as gr
 import torch
 import spaces
 from diffusers import DiffusionPipeline
+from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
+from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
 from pathlib import Path
 import gc
 import subprocess
+from PIL import Image
 
 
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
@@ -31,7 +34,17 @@ models = [
 
 
 num_loras = 3
-...
+num_cns = 2
+# Initialize the base model
+base_model = models[0]
+controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
+pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
+controlnet = None
+control_images = [None] * num_cns
+control_modes = [-1] * num_cns
+control_scales = [0] * num_cns
+last_model = models[0]
+last_cn_on = False
 
 def is_repo_name(s):
     import re
@@ -70,26 +83,169 @@ def get_repo_safetensors(repo_id: str):
     else: return gr.update(value=files[0], choices=files)
 
 
-#
-...
-def change_base_model(repo_id: str, progress=gr.Progress(track_tqdm=True)):
+# https://huggingface.co/spaces/DamarJati/FLUX.1-DEV-Canny
+# https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union
+# https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
+def change_base_model(repo_id: str, cn_on: bool, progress=gr.Progress(track_tqdm=True)):
     global pipe
+    global controlnet
     global last_model
+    global last_cn_on
     try:
-        if repo_id == last_model or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return
-        ...
+        if (repo_id == last_model and cn_on is last_cn_on) or not is_repo_name(repo_id) or not is_repo_exists(repo_id): return
+        if cn_on:
+            progress(0, desc=f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
+            print(f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
+            clear_cache()
+            controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=torch.bfloat16)
+            controlnet = FluxMultiControlNetModel([controlnet_union])
+            pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=torch.bfloat16)
+            last_model = repo_id
+            progress(1, desc=f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
+            print(f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
+        else:
+            progress(0, desc=f"Loading model: {repo_id}")
+            print(f"Loading model: {repo_id}")
+            clear_cache()
+            pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
+            last_model = repo_id
+            progress(1, desc=f"Model loaded: {repo_id}")
+            print(f"Model loaded: {repo_id}")
     except Exception as e:
         print(e)
         return gr.update(visible=True)
 
 
+# https://huggingface.co/spaces/DamarJati/FLUX.1-DEV-Canny/blob/main/app.py
+def resize_image(image, target_width, target_height, crop=True):
+    from image_datasets.canny_dataset import c_crop
+    if crop:
+        image = c_crop(image)  # Crop the image to square
+        original_width, original_height = image.size
+
+        # Resize to match the target size without stretching
+        scale = max(target_width / original_width, target_height / original_height)
+        resized_width = int(scale * original_width)
+        resized_height = int(scale * original_height)
+
+        image = image.resize((resized_width, resized_height), Image.LANCZOS)
+
+        # Center crop to match the target dimensions
+        left = (resized_width - target_width) // 2
+        top = (resized_height - target_height) // 2
+        image = image.crop((left, top, left + target_width, top + target_height))
+    else:
+        image = image.resize((target_width, target_height), Image.LANCZOS)
+
+    return image
+
+
+# https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union/blob/main/app.py
+controlnet_union_modes = {
+    "None": -1,
+    #"scribble_hed": 0,
+    "canny": 0,  # supported
+    "mlsd": 0,  # supported
+    "tile": 1,  # supported
+    "depth_midas": 2,  # supported
+    "blur": 3,  # supported
+    "openpose": 4,  # supported
+    "gray": 5,  # supported
+    "low_quality": 6,  # supported
+}
+
+
+def get_control_params():
+    modes = []
+    images = []
+    scales = []
+    for i, mode in enumerate(control_modes):
+        if mode == -1 or control_images[i] is None: continue
+        modes.append(control_modes[i])
+        images.append(control_images[i])
+        scales.append(control_scales[i])
+    return modes, images, scales
+
+
+from preprocessor import Preprocessor
+def preprocess_image(image: Image.Image, control_mode: str, height: int, width: int, preprocess_resolution: int):
+    image_resolution = max(width, height)
+    image_before = resize_image(image, image_resolution, image_resolution, True)
+    # generated control_
+    print("start to generate control image")
+    preprocessor = Preprocessor()
+    if control_mode == "depth_midas":
+        preprocessor.load("Midas")
+        control_image = preprocessor(
+            image=image_before,
+            image_resolution=image_resolution,
+            detect_resolution=preprocess_resolution,
+        )
+    if control_mode == "openpose":
+        preprocessor.load("Openpose")
+        control_image = preprocessor(
+            image=image_before,
+            hand_and_face=True,
+            image_resolution=image_resolution,
+            detect_resolution=preprocess_resolution,
+        )
+    if control_mode == "canny":
+        preprocessor.load("Canny")
+        control_image = preprocessor(
+            image=image_before,
+            image_resolution=image_resolution,
+            detect_resolution=preprocess_resolution,
+        )
+
+    if control_mode == "mlsd":
+        preprocessor.load("MLSD")
+        control_image = preprocessor(
+            image=image_before,
+            image_resolution=image_resolution,
+            detect_resolution=preprocess_resolution,
+        )
+
+    if control_mode == "scribble_hed":
+        preprocessor.load("HED")
+        control_image = preprocessor(
+            image=image_before,
+            image_resolution=image_resolution,
+            detect_resolution=preprocess_resolution,
+        )
+
+    if control_mode == "low_quality" or control_mode == "gray" or control_mode == "blur" or control_mode == "tile":
+        control_image = image_before
+        image_width = 768
+        image_height = 768
+    else:
+        # make sure control image size is same as resized_image
+        image_width, image_height = control_image.size
+
+    image_after = resize_image(control_image, width, height, True)
+    print(f"generate control image success: {image_width}x{image_height} => {width}x{height}")
+
+    return image_after
+
+
+def get_control_union_mode():
+    return list(controlnet_union_modes.keys())
+
+
+def set_control_union_mode(i: int, mode: str, scale: str):
+    global control_modes
+    global control_scales
+    control_modes[i] = controlnet_union_modes.get(mode, 0)
+    control_scales[i] = scale
+    if mode != "None": return True
+    else: return gr.update(visible=True)
+
+
+def set_control_union_image(i: int, mode: str, image: Image.Image, height: int, width: int, preprocess_resolution: int):
+    global control_images
+    control_images[i] = preprocess_image(image, mode, height, width, preprocess_resolution)
+    return control_images[i]
+
+
 def compose_lora_json(lorajson: list[dict], i: int, name: str, scale: float, filename: str, trigger: str):
     lorajson[i]["name"] = str(name) if name != "None" else ""
     lorajson[i]["scale"] = float(scale)
@@ -112,6 +268,7 @@ def get_trigger_word(lorajson: list[dict]):
         trigger += ", " + d["trigger"]
     return trigger
 
+
 # https://huggingface.co/docs/diffusers/v0.23.1/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora
 # https://github.com/huggingface/diffusers/issues/4919
 def fuse_loras(pipe, lorajson: list[dict]):
@@ -139,13 +296,12 @@ def fuse_loras(pipe, lorajson: list[dict]):
     #pipe.unload_lora_weights()
 
 
-...
 def description_ui():
     gr.Markdown(
         """
        - Mod of [multimodalart/flux-lora-the-explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer),
+        [jiuface/FLUX.1-dev-Controlnet-Union](https://huggingface.co/spaces/jiuface/),
+        [DamarJati/FLUX.1-DEV-Canny](https://huggingface.co/spaces/DamarJati/FLUX.1-DEV-Canny),
         [gokaygokay/FLUX-Prompt-Generator](https://huggingface.co/spaces/gokaygokay/FLUX-Prompt-Generator).
         """
     )

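The new helpers keep per-slot ControlNet state (mode, preprocessed image, weight) in module-level lists and hand it to the pipeline at generation time, mirroring the call added in app.py. A minimal sketch of that flow, assuming change_base_model has already been run with cn_on=True so mod.pipe is a FluxControlNetPipeline; the reference image, prompt, and sizes are placeholders:

```python
# Sketch of the ControlNet helper flow in mod.py (placeholder file name and prompt).
from PIL import Image

import mod

# Assumes the UI has already called mod.change_base_model(<repo>, cn_on=True),
# so mod.pipe is a FluxControlNetPipeline and the union ControlNet is loaded.
ref = Image.open("reference.png")                              # placeholder control image
mod.set_control_union_mode(0, "canny", 0.75)                   # slot 0: canny edges, weight 0.75
mod.set_control_union_image(0, "canny", ref, 1024, 1024, 384)  # preprocess and store the control image
modes, images, scales = mod.get_control_params()               # only slots with both a mode and an image
image = mod.pipe(
    prompt="a city street at night",                           # placeholder prompt
    control_image=images,
    control_mode=modes,
    controlnet_conditioning_scale=scales,
    num_inference_steps=28,
    guidance_scale=3.5,
    width=1024,
    height=1024,
).images[0]
```
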
preprocessor.py
ADDED
import gc

import numpy as np
import PIL.Image
import torch
import torchvision
from controlnet_aux import (
    CannyDetector,
    ContentShuffleDetector,
    HEDdetector,
    LineartAnimeDetector,
    LineartDetector,
    MidasDetector,
    MLSDdetector,
    NormalBaeDetector,
    OpenposeDetector,
)
from controlnet_aux.util import HWC3

from cv_utils import resize_image
from depth_estimator import DepthEstimator
from image_segmentor import ImageSegmentor

from kornia.core import Tensor

# load preprocessor

# HED = HEDdetector.from_pretrained("lllyasviel/Annotators")
Midas = MidasDetector.from_pretrained("lllyasviel/Annotators")
MLSD = MLSDdetector.from_pretrained("lllyasviel/Annotators")
Canny = CannyDetector()
OPENPOSE = OpenposeDetector.from_pretrained("lllyasviel/Annotators")


class Preprocessor:
    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self):
        self.model = None
        self.name = ""

    def load(self, name: str) -> None:
        if name == self.name:
            return

        if name == "Midas":
            self.model = Midas
        elif name == "MLSD":
            self.model = MLSD
        elif name == "Openpose":
            self.model = OPENPOSE
        elif name == "Canny":
            self.model = Canny
        else:
            raise ValueError
        torch.cuda.empty_cache()
        gc.collect()
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        if self.name == "Canny" or self.name == "MLSD":
            detect_resolution = kwargs.pop("detect_resolution")
            image_resolution = kwargs.pop("image_resolution", 512)
            image = np.array(image)
            image = HWC3(image)
            image = resize_image(image, resolution=detect_resolution)
            image = self.model(image, **kwargs)
            image = np.array(image)
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            return PIL.Image.fromarray(image).convert('RGB')

        else:
            detect_resolution = kwargs.pop("detect_resolution", 512)
            image_resolution = kwargs.pop("image_resolution", 512)
            image = np.array(image)
            image = HWC3(image)
            image = resize_image(image, resolution=detect_resolution)
            image = self.model(image, **kwargs)
            image = np.array(image)
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            return PIL.Image.fromarray(image)

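A minimal usage sketch for the preprocessor; the file names are placeholders, and note that the Canny/MLSD branch requires an explicit detect_resolution:

```python
# Minimal usage sketch for preprocessor.Preprocessor; "photo.jpg" is a placeholder path.
import PIL.Image

from preprocessor import Preprocessor

preprocessor = Preprocessor()
preprocessor.load("Canny")
image = PIL.Image.open("photo.jpg").convert("RGB")
edges = preprocessor(image, detect_resolution=384, image_resolution=1024)  # RGB PIL edge map
edges.save("photo_canny.png")
```
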
requirements.txt
CHANGED
@@ -8,5 +8,7 @@ transformers
 peft
 sentencepiece
 timm
-...
+einops
+controlnet-aux
+kornia
+numpy