Gainward777 committed
Commit
9b843da
1 Parent(s): 6eaf8e4

Upload 5 files

Files changed (5)
  1. app.py +6 -154
  2. sd/sd_controller.py +74 -0
  3. sd/utils/utils.py +78 -0
  4. ui/gradio_ui.py +30 -0
  5. utils/utils.py +77 -0
app.py CHANGED
@@ -1,154 +1,6 @@
- import gradio as gr
- import numpy as np
- import random
-
- # import spaces #[uncomment to use ZeroGPU]
- from diffusers import DiffusionPipeline
- import torch
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use
-
- if torch.cuda.is_available():
-     torch_dtype = torch.float16
- else:
-     torch_dtype = torch.float32
-
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
- pipe = pipe.to(device)
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024
-
-
- # @spaces.GPU #[uncomment to use ZeroGPU]
- def infer(
-     prompt,
-     negative_prompt,
-     seed,
-     randomize_seed,
-     width,
-     height,
-     guidance_scale,
-     num_inference_steps,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator().manual_seed(seed)
-
-     image = pipe(
-         prompt=prompt,
-         negative_prompt=negative_prompt,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         width=width,
-         height=height,
-         generator=generator,
-     ).images[0]
-
-     return image, seed
-
-
- examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
- ]
-
- css = """
- #col-container {
-     margin: 0 auto;
-     max-width: 640px;
- }
- """
-
- with gr.Blocks(css=css) as demo:
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(" # Text-to-Image Gradio Template")
-
-         with gr.Row():
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )
-
-             run_button = gr.Button("Run", scale=0, variant="primary")
-
-         result = gr.Image(label="Result", show_label=False)
-
-         with gr.Accordion("Advanced Settings", open=False):
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
-             )
-
-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
-             )
-
-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-             with gr.Row():
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024,  # Replace with defaults that work for your model
-                 )
-
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024,  # Replace with defaults that work for your model
-                 )
-
-             with gr.Row():
-                 guidance_scale = gr.Slider(
-                     label="Guidance scale",
-                     minimum=0.0,
-                     maximum=10.0,
-                     step=0.1,
-                     value=0.0,  # Replace with defaults that work for your model
-                 )
-
-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=50,
-                     step=1,
-                     value=2,  # Replace with defaults that work for your model
-                 )
-
-         gr.Examples(examples=examples, inputs=[prompt])
-     gr.on(
-         triggers=[run_button.click, prompt.submit],
-         fn=infer,
-         inputs=[
-             prompt,
-             negative_prompt,
-             seed,
-             randomize_seed,
-             width,
-             height,
-             guidance_scale,
-             num_inference_steps,
-         ],
-         outputs=[result, seed],
-     )
-
- if __name__ == "__main__":
-     demo.launch()
 
+ from ui.gradio_ui import ui
+ from sd.sd_controller import Controller
+
+ controller = Controller()
+
+ ui(controller)
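
The new entry point is deliberately minimal: constructing Controller() loads both SDXL pipelines, and ui(controller) builds and launches the Gradio interface (ui/gradio_ui.py calls queue().launch() itself). A hedged sketch of an equivalent entry point with the conventional main guard, not part of this commit:

from ui.gradio_ui import ui
from sd.sd_controller import Controller

if __name__ == "__main__":
    controller = Controller()  # loads VAE, ControlNet, T2I-Adapter and both pipelines
    ui(controller)             # builds the Blocks UI and launches it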
sd/sd_controller.py ADDED
@@ -0,0 +1,74 @@
+ from sd.utils.utils import *
+ from utils.utils import sketch_process, prompt_preprocess
+ # from controlnet_aux.pidi import PidiNetDetector
+ import torch
+ import spaces
+
+
+ class Controller():
+
+     def __init__(self,
+                  models_names=["cagliostrolab/animagine-xl-3.1",
+                                "stabilityai/stable-diffusion-xl-base-1.0"],
+                  lora_path='sd/lora/lora.safetensors'):
+
+         self.models_names = models_names
+         self.lora_path = lora_path
+         self.vae = get_vae()
+         self.controlnet = get_controlnet()
+         self.adapter = get_adapter()
+         self.scheduler = get_scheduler(model_name=self.models_names[1])
+         self.detector = get_detector()
+
+         # Stage 1: ControlNet img2img pipeline on the anime model, with the LoRA applied.
+         self.first_pipe = get_pipe(vae=self.vae,
+                                    model_name=self.models_names[0],
+                                    controlnet=self.controlnet,
+                                    lora_path=self.lora_path)
+
+         # Stage 2: T2I-Adapter pipeline on the SDXL base model.
+         self.second_pipe = get_pipe(vae=self.vae,
+                                     model_name=self.models_names[1],
+                                     adapter=self.adapter,
+                                     scheduler=self.scheduler)
+
+     @spaces.GPU
+     def get_first_result(self, img, prompt, negative_prompt,
+                          controlnet_scale=0.5, strength=1.0, n_steps=30, eta=1.0):
+
+         substrate, resized_image = sketch_process(img)
+         prompt = prompt_preprocess(prompt)
+
+         result = self.first_pipe(image=substrate,
+                                  control_image=resized_image,
+                                  strength=strength,
+                                  prompt=prompt,
+                                  negative_prompt=negative_prompt,
+                                  controlnet_conditioning_scale=float(controlnet_scale),
+                                  generator=torch.manual_seed(0),
+                                  num_inference_steps=n_steps,
+                                  eta=eta)
+
+         return result.images[0]
+
+     @spaces.GPU
+     def get_second_result(self, img, prompt, negative_prompt,
+                           g_scale=7.5, n_steps=25,
+                           adapter_scale=0.9, adapter_factor=1.0):
+
+         # PidiNet sketch detection before the adapter pipeline.
+         preprocessed_img = self.detector(img,
+                                          detect_resolution=1024,
+                                          image_resolution=1024,
+                                          apply_filter=True).convert("L")
+
+         result = self.second_pipe(prompt=prompt,
+                                   negative_prompt=negative_prompt,
+                                   image=preprocessed_img,
+                                   guidance_scale=g_scale,
+                                   num_inference_steps=n_steps,
+                                   adapter_conditioning_scale=adapter_scale,
+                                   adapter_conditioning_factor=adapter_factor,
+                                   generator=torch.manual_seed(42))
+
+         return result.images[0]
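
For orientation, a minimal sketch of driving the two stages directly, mirroring the UI wiring (both stages receive the uploaded sketch); the file names and prompts here are hypothetical examples:

from PIL import Image
from sd.sd_controller import Controller

controller = Controller()
sketch = Image.open("sketch.png")           # hypothetical input image
prompt = "1girl, smile, looking at viewer"  # example tag-style prompt
negative = "lowres, worst quality, blurry"

# Stage 1: ControlNet img2img refinement of the sketch (LoRA applied in get_pipe).
improved = controller.get_first_result(sketch, prompt, negative)
# Stage 2: T2I-Adapter painting; PidiNet preprocessing happens inside the method.
painted = controller.get_second_result(sketch, prompt, negative)

improved.save("improved_sketch.png")
painted.save("result.png")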
sd/utils/utils.py ADDED
@@ -0,0 +1,78 @@
+ import torch
+ from diffusers import (ControlNetModel,
+                        StableDiffusionXLControlNetImg2ImgPipeline,
+                        AutoencoderKL,
+                        T2IAdapter,
+                        StableDiffusionXLAdapterPipeline,
+                        EulerAncestralDiscreteScheduler)
+
+ from controlnet_aux.pidi import PidiNetDetector
+
+ from PIL import Image
+ import os
+
+
+ # VAE = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+
+ # CONTROLNET = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16)
+
+ # ADAPTER = T2IAdapter.from_pretrained("Adapter/t2iadapter",
+ #                                      subfolder="sketch_sdxl_1.0",
+ #                                      torch_dtype=torch.float16,
+ #                                      adapter_type="full_adapter_xl")
+
+
+ def get_vae(model_name="madebyollin/sdxl-vae-fp16-fix"):
+     return AutoencoderKL.from_pretrained(model_name, torch_dtype=torch.float16)
+
+ def get_controlnet(model_name="diffusers/controlnet-canny-sdxl-1.0"):
+     return ControlNetModel.from_pretrained(model_name, torch_dtype=torch.float16)
+
+ def get_adapter(model_name="Adapter/t2iadapter", subfolder="sketch_sdxl_1.0",
+                 adapter_type="full_adapter_xl"):
+     if adapter_type == "full_adapter_xl":
+         return T2IAdapter.from_pretrained(model_name,
+                                           subfolder=subfolder,
+                                           torch_dtype=torch.float16,
+                                           adapter_type=adapter_type)
+
+ def get_scheduler(model_name, scheduler_type="discrete"):
+     if scheduler_type == "discrete":
+         return EulerAncestralDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler")
+
+
+ def get_detector(model_name="lllyasviel/Annotators", model_type='pidi'):
+     if model_type == 'pidi':
+         return PidiNetDetector.from_pretrained(model_name)
+
+
+ def load_lora(pipe, lora_path=None):
+     if lora_path is not None:
+         try:
+             lora_dir = './' + '/'.join(lora_path.split("/")[:-1])
+             lora_name = lora_path.split("/")[-1]
+             pipe.load_lora_weights(lora_dir, weight_name=lora_name)
+         except Exception as ex:
+             print(ex)
+     # return pipe
+
+
+ def get_pipe(vae, model_name, controlnet=None, adapter=None, scheduler=None, lora_path=None):
+     if controlnet is not None:
+         pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(model_name,
+                                                                           controlnet=controlnet,
+                                                                           vae=vae,
+                                                                           torch_dtype=torch.float16)
+         load_lora(pipe, lora_path)
+         return pipe
+
+     elif adapter is not None:
+         pipe = StableDiffusionXLAdapterPipeline.from_pretrained(model_name,
+                                                                 adapter=adapter,
+                                                                 vae=vae,
+                                                                 scheduler=scheduler,
+                                                                 torch_dtype=torch.float16,
+                                                                 variant="fp16")
+         load_lora(pipe, lora_path)
+         return pipe
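
A small worked example of how load_lora resolves the controller's default LoRA path; the values in the comments follow directly from the string operations above:

lora_path = 'sd/lora/lora.safetensors'
lora_dir = './' + '/'.join(lora_path.split("/")[:-1])  # -> './sd/lora'
lora_name = lora_path.split("/")[-1]                   # -> 'lora.safetensors'
# equivalent call: pipe.load_lora_weights('./sd/lora', weight_name='lora.safetensors')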
ui/gradio_ui.py ADDED
@@ -0,0 +1,30 @@
+ import gradio as gr
+
+
+ def ui(controller):
+     with gr.Blocks() as ui:
+         with gr.Row():
+             with gr.Column():
+                 sketch = gr.Image(sources='upload', label='Model image', type='pil')
+                 first_prompt = gr.Textbox(label="Prompt", lines=3)
+                 first_negative_prompt = gr.Textbox(label="Negative prompt", lines=3, value="sketch, lowres, error, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, blurry")
+                 # controlnet_scale = gr.Slider(minimum=0.5, maximum=1.25, value=0.5, step=0.01, label="Contr")
+                 improve_sketch = gr.Button(value="Improve Sketch", variant="primary")
+             with gr.Column():
+                 improved_sketch_view = gr.Image(type="pil", label="Improved Sketch")
+
+         improve_sketch.click(fn=controller.get_first_result,
+                              inputs=[sketch, first_prompt, first_negative_prompt],
+                              outputs=improved_sketch_view)
+
+         with gr.Row():
+             result = gr.Image(type="pil", label="Result")
+             second_prompt = gr.Textbox(label="Prompt", lines=3)
+             second_negative_prompt = gr.Textbox(label="Negative prompt", lines=3, value="disfigured, extra digit, fewer digits, cropped, worst quality, low quality")
+             result_button = gr.Button(value="Paint It", variant="primary")
+
+         result_button.click(fn=controller.get_second_result,
+                             inputs=[sketch, second_prompt, second_negative_prompt],
+                             outputs=result)
+
+     ui.queue().launch(debug=True)
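
The controlnet_scale slider is left commented out, so get_first_result always runs with its default of 0.5. If it were re-enabled, it could be passed as the fourth input of the first click handler (a hedged sketch, not part of this commit):

controlnet_scale = gr.Slider(minimum=0.5, maximum=1.25, value=0.5, step=0.01,
                             label="ControlNet scale")
improve_sketch.click(fn=controller.get_first_result,
                     inputs=[sketch, first_prompt, first_negative_prompt, controlnet_scale],
                     outputs=improved_sketch_view)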
utils/utils.py ADDED
@@ -0,0 +1,77 @@
+ from PIL import Image
+ import numpy as np
+ import cv2
+
+
+ # first stage sketch preprocess
+ def conventional_resize(img):
+     original_width, original_height = img.size
+     aspect_ratio = original_width / original_height
+
+     conventional_sizes = {
+         1: (1024, 1024),
+         4/3: (1152, 896),
+         3/2: (1216, 832),
+         16/9: (1344, 768),
+         21/9: (1568, 672),
+         3/1: (1728, 576),
+         1/4: (512, 2048),
+         1/3: (576, 1728),
+         9/16: (768, 1344),
+         2/3: (832, 1216),
+         3/4: (896, 1152)
+     }
+
+     closest_aspect_ratio = min(conventional_sizes.keys(), key=lambda x: abs(x - aspect_ratio))
+     new_width, new_height = conventional_sizes[closest_aspect_ratio]
+
+     resized_img = img.resize((new_width, new_height), Image.LANCZOS)
+
+     return resized_img
+
+
+ def get_substrate(img, color=(255, 255, 255, 255)):
+     size = img.size
+     substrate = Image.new("RGBA", size, color)
+     return substrate.convert("RGB")
+
+
+ def sketch_process(img):
+     substrate = conventional_resize(get_substrate(img))
+     resized_img = conventional_resize(img)
+     return substrate, resized_img
+
+
+ # first stage prompt preprocess
+ def remove_duplicates(base_prompt):
+     prompt_list = base_prompt.split(", ")
+     seen = set()
+     unique_tags = []
+     for tag in prompt_list:
+         tag_clean = tag.lower().strip()
+         if tag_clean not in seen and tag_clean != "":
+             unique_tags.append(tag)
+             seen.add(tag_clean)
+     return ", ".join(unique_tags)
+
+
+ def remove_color(base_prompt):
+     prompt_list = base_prompt.split(", ")
+     color_list = ["pink", "red", "orange", "brown", "yellow", "green", "blue", "purple", "blonde", "colored skin", "white hair"]
+     cleaned_tags = [tag for tag in prompt_list if all(color.lower() not in tag.lower() for color in color_list)]
+     return ", ".join(cleaned_tags)
+
+
+ def execute_prompt(base_prompt):
+     prompt_list = base_prompt.split(", ")
+     execute_tags = ["sketch", "transparent background"]
+     filtered_tags = [tag for tag in prompt_list if tag not in execute_tags]
+     return ", ".join(filtered_tags)
+
+
+ def prompt_preprocess(prompt):
+     result = execute_prompt(prompt)
+     result = remove_duplicates(result)
+     result = remove_color(result)
+     return result
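
A worked example of the first-stage prompt preprocessing chain (execute_prompt, then remove_duplicates, then remove_color) on a sample tag string:

prompt_preprocess("sketch, 1girl, smile, Smile, blue eyes, transparent background")
# execute_prompt    -> "1girl, smile, Smile, blue eyes"  (drops "sketch" and "transparent background")
# remove_duplicates -> "1girl, smile, blue eyes"         (case-insensitive de-duplication)
# remove_color      -> "1girl, smile"                    (drops tags containing a listed color word)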