Spaces:
Running
on
Zero
Running
on
Zero
Bobby
commited on
Commit
·
838a1f4
1
Parent(s):
9b7a27f
profiler part 2
Browse files- README.md +1 -1
- preprocess_anime.py +19 -0
- profiler.py +157 -171
README.md
CHANGED
@@ -6,7 +6,7 @@ colorTo: pink
|
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.31.4
|
8 |
#app_file: anime_app.py
|
9 |
-
app_file:
|
10 |
pinned: true
|
11 |
license: apache-2.0
|
12 |
short_description: Turn yourself into a weeb
|
|
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.31.4
|
8 |
#app_file: anime_app.py
|
9 |
+
app_file: profiler.py
|
10 |
pinned: true
|
11 |
license: apache-2.0
|
12 |
short_description: Turn yourself into a weeb
|
preprocess_anime.py
CHANGED
@@ -49,3 +49,22 @@ class Preprocessor:
|
|
49 |
return PIL.Image.fromarray(image)
|
50 |
else:
|
51 |
return self.model(image, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
return PIL.Image.fromarray(image)
|
50 |
else:
|
51 |
return self.model(image, **kwargs)
|
52 |
+
|
53 |
+
def manage_memory(self):
|
54 |
+
torch.cuda.empty_cache()
|
55 |
+
gc.collect()
|
56 |
+
|
57 |
+
# Additional helper function to manage memory less frequently
|
58 |
+
def conditionally_manage_memory(memory_threshold=0.8):
|
59 |
+
"""
|
60 |
+
Frees up GPU memory if usage exceeds the threshold.
|
61 |
+
:param memory_threshold: Fraction of memory usage to trigger cleanup.
|
62 |
+
"""
|
63 |
+
if torch.cuda.is_available():
|
64 |
+
total_memory = torch.cuda.get_device_properties(0).total_memory
|
65 |
+
reserved_memory = torch.cuda.memory_reserved(0)
|
66 |
+
allocated_memory = torch.cuda.memory_allocated(0)
|
67 |
+
|
68 |
+
if reserved_memory / total_memory > memory_threshold:
|
69 |
+
torch.cuda.empty_cache()
|
70 |
+
gc.collect()
|
profiler.py
CHANGED
@@ -1,40 +1,78 @@
|
|
1 |
import cProfile
|
2 |
import pstats
|
3 |
import io
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
show_options = False
|
11 |
-
import gc
|
12 |
-
import random
|
13 |
-
import time
|
14 |
-
import gradio as gr
|
15 |
-
import spaces
|
16 |
-
import imageio
|
17 |
-
from huggingface_hub import HfApi
|
18 |
-
import torch
|
19 |
-
from PIL import Image
|
20 |
-
from diffusers import (
|
21 |
-
ControlNetModel,
|
22 |
-
DPMSolverMultistepScheduler,
|
23 |
-
StableDiffusionControlNetPipeline,
|
24 |
-
)
|
25 |
-
from preprocess_anime import Preprocessor
|
26 |
-
from settings import API_KEY, MAX_NUM_IMAGES, MAX_SEED
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
preprocessor = Preprocessor()
|
36 |
|
37 |
-
|
38 |
model_id = "lllyasviel/control_v11p_sd15_normalbae"
|
39 |
print("initializing controlnet")
|
40 |
controlnet = ControlNetModel.from_pretrained(
|
@@ -42,8 +80,8 @@ def main():
|
|
42 |
torch_dtype=torch.float16,
|
43 |
attn_implementation="flash_attention_2",
|
44 |
).to("cuda")
|
45 |
-
|
46 |
-
|
47 |
scheduler = DPMSolverMultistepScheduler.from_pretrained(
|
48 |
"runwayml/stable-diffusion-v1-5",
|
49 |
solver_order=2,
|
@@ -57,9 +95,8 @@ def main():
|
|
57 |
device_map="cuda",
|
58 |
)
|
59 |
|
60 |
-
|
61 |
base_model_url = "https://huggingface.co/broyang/hentaidigitalart_v20/blob/main/realcartoon3d_v15.safetensors"
|
62 |
-
|
63 |
pipe = StableDiffusionControlNetPipeline.from_single_file(
|
64 |
base_model_url,
|
65 |
safety_checker=None,
|
@@ -67,8 +104,7 @@ def main():
|
|
67 |
scheduler=scheduler,
|
68 |
torch_dtype=torch.float16,
|
69 |
)
|
70 |
-
|
71 |
-
pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="EasyNegativeV2.safetensors", token="EasyNegativeV2",)
|
72 |
pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="badhandv4.pt", token="badhandv4")
|
73 |
pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Ahegao.pt", token="HDA_Ahegao")
|
74 |
pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Bondage.pt", token="HDA_Bondage")
|
@@ -79,123 +115,84 @@ def main():
|
|
79 |
pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_NunDress.pt", token="HDA_NunDress")
|
80 |
pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Shibari.pt", token="HDA_Shibari")
|
81 |
pipe.to("cuda")
|
82 |
-
|
83 |
-
torch.cuda.empty_cache()
|
84 |
-
gc.collect()
|
85 |
-
print("---------------Loaded controlnet pipeline---------------")
|
86 |
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
-
|
93 |
-
prompt = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
|
94 |
-
top = ["tank top", "blouse", "button up shirt", "sweater", "corset top"]
|
95 |
-
bottom = ["short skirt", "athletic shorts", "jean shorts", "pleated skirt", "short skirt", "leggings", "high-waisted shorts"]
|
96 |
-
accessory = ["knee-high boots", "gloves", "Thigh-high stockings", "Garter belt", "choker", "necklace", "headband", "headphones"]
|
97 |
-
return f"{prompt}, {random.choice(top)}, {random.choice(bottom)}, {random.choice(accessory)}, score_9"
|
98 |
-
# outfit = ["schoolgirl outfit", "playboy outfit", "red dress", "gala dress", "cheerleader outfit", "nurse outfit", "Kimono"]
|
99 |
|
100 |
-
|
101 |
-
|
102 |
-
randomize = get_additional_prompt()
|
103 |
-
nude = "NSFW,((nude)),medium bare breasts,hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
|
104 |
-
bodypaint = "((fully naked with no clothes)),nude naked seethroughxray,invisiblebodypaint,rating_newd,NSFW"
|
105 |
-
lab_girl = "hyperrealistic photography, extremely detailed, shy assistant wearing minidress boots and gloves, laboratory background, score_9, 1girl"
|
106 |
-
pet_play = "hyperrealistic photography, extremely detailed, playful, blush, glasses, collar, score_9, HDA_pet_play"
|
107 |
-
bondage = "hyperrealistic photography, extremely detailed, submissive, glasses, score_9, HDA_Bondage"
|
108 |
-
ahegao = "((invisible clothing)), hyperrealistic photography,exposed vagina,sexy,nsfw,HDA_Ahegao"
|
109 |
-
ahegao2 = "(invisiblebodypaint),rating_newd,HDA_Ahegao"
|
110 |
-
athleisure = "hyperrealistic photography, extremely detailed, 1girl athlete, exhausted embarrassed sweaty,outdoors, ((athleisure clothing)), score_9"
|
111 |
-
atompunk = "((atompunk world)), hyperrealistic photography, extremely detailed, short hair, bodysuit, glasses, neon cyberpunk background, score_9"
|
112 |
-
maid = "hyperrealistic photography, extremely detailed, shy, blushing, score_9, pastel background, HDA_unconventional_maid"
|
113 |
-
nundress = "hyperrealistic photography, extremely detailed, shy, blushing, fantasy background, score_9, HDA_NunDress"
|
114 |
-
naked_hoodie = "hyperrealistic photography, extremely detailed, medium hair, cityscape, (neon lights), score_9, HDA_NakedHoodie"
|
115 |
-
abg = "(1girl, asian body covered in words, words on body, tattoos of (words) on body),(masterpiece, best quality),medium breasts,(intricate details),unity 8k wallpaper,ultra detailed,(pastel colors),beautiful and aesthetic,see-through (clothes),detailed,solo"
|
116 |
-
shibari = "extremely detailed, hyperrealistic photography, earrings, blushing, lace choker, tattoo, medium hair, score_9, HDA_Shibari"
|
117 |
-
|
118 |
-
if prompt == "":
|
119 |
-
prompts = [randomize, nude, bodypaint, pet_play, bondage, ahegao2, lab_girl, athleisure, atompunk, maid, nundress, naked_hoodie, abg, shibari]
|
120 |
-
prompts_nsfw = [nude, bodypaint, abg, ahegao2, shibari]
|
121 |
-
preset = random.choice(prompts)
|
122 |
-
prompt = f"{preset}"
|
123 |
-
# print(f"-------------{preset}-------------")
|
124 |
-
else:
|
125 |
-
# prompt = f"{prompt}, {randomize}"
|
126 |
-
prompt = f"{default},{prompt}"
|
127 |
-
print(f"{prompt}")
|
128 |
-
return prompt
|
129 |
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
image_resolution,
|
139 |
-
preprocess_resolution,
|
140 |
-
num_steps,
|
141 |
-
guidance_scale,
|
142 |
-
seed,
|
143 |
-
):
|
144 |
-
print("processing image")
|
145 |
-
start = time.time()
|
146 |
-
preprocessor.load("NormalBae")
|
147 |
-
# preprocessor.load("Canny") #20 steps, 9 guidance, 512, 512
|
148 |
-
control_image = preprocessor(
|
149 |
-
image=image,
|
150 |
-
image_resolution=image_resolution,
|
151 |
-
detect_resolution=preprocess_resolution,
|
152 |
-
)
|
153 |
-
custom_prompt=str(get_prompt(prompt, a_prompt))
|
154 |
-
negative_prompt=str(n_prompt)
|
155 |
-
global compiled
|
156 |
-
generator = torch.cuda.manual_seed(seed)
|
157 |
-
if not compiled:
|
158 |
-
print("-----------------------------------Not Compiled-----------------------------------")
|
159 |
-
compiled = True
|
160 |
-
results = pipe(
|
161 |
-
prompt=custom_prompt,
|
162 |
-
negative_prompt=negative_prompt,
|
163 |
-
guidance_scale=guidance_scale,
|
164 |
-
num_images_per_prompt=num_images,
|
165 |
-
num_inference_steps=num_steps,
|
166 |
-
generator=generator,
|
167 |
-
image=control_image,
|
168 |
-
).images[0]
|
169 |
-
print(f"Inference done in: {time.time() - start:.2f} seconds")
|
170 |
-
|
171 |
-
timestamp = int(time.time())
|
172 |
-
img_path = f"{timestamp}.jpg"
|
173 |
-
results_path = f"{timestamp}_out.jpg"
|
174 |
-
imageio.imsave(img_path, image)
|
175 |
-
results.save(results_path)
|
176 |
-
|
177 |
-
api.upload_file(
|
178 |
-
path_or_fileobj=img_path,
|
179 |
-
path_in_repo=img_path,
|
180 |
-
repo_id="broyang/anime-ai-outputs",
|
181 |
-
repo_type="dataset",
|
182 |
-
token=API_KEY,
|
183 |
-
run_as_future=True,
|
184 |
-
)
|
185 |
-
api.upload_file(
|
186 |
-
path_or_fileobj=results_path,
|
187 |
-
path_in_repo=results_path,
|
188 |
-
repo_id="broyang/anime-ai-outputs",
|
189 |
-
repo_type="dataset",
|
190 |
-
token=API_KEY,
|
191 |
-
run_as_future=True,
|
192 |
-
)
|
193 |
-
|
194 |
-
torch.cuda.empty_cache()
|
195 |
-
gc.collect()
|
196 |
-
|
197 |
-
results.save("temp_image.png")
|
198 |
-
return results
|
199 |
|
200 |
css = """
|
201 |
h1 {
|
@@ -213,7 +210,6 @@ def main():
|
|
213 |
footer {visibility: hidden}
|
214 |
"""
|
215 |
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
216 |
-
#############################################################################
|
217 |
with gr.Row():
|
218 |
with gr.Accordion("Advanced options", open=show_options, visible=show_options):
|
219 |
num_images = gr.Slider(
|
@@ -235,10 +231,10 @@ def main():
|
|
235 |
)
|
236 |
num_steps = gr.Slider(
|
237 |
label="Number of steps", minimum=1, maximum=100, value=12, step=1
|
238 |
-
)
|
239 |
guidance_scale = gr.Slider(
|
240 |
label="Guidance scale", minimum=0.1, maximum=30.0, value=5.5, step=0.1
|
241 |
-
)
|
242 |
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
|
243 |
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
244 |
a_prompt = gr.Textbox(
|
@@ -249,14 +245,11 @@ def main():
|
|
249 |
label="Negative prompt",
|
250 |
value="EasyNegativeV2, fcNeg, (badhandv4:1.4), (worst quality, low quality, bad quality, normal quality:2.0), (bad hands, missing fingers, extra fingers:2.0)",
|
251 |
)
|
252 |
-
#############################################################################
|
253 |
-
# input text
|
254 |
with gr.Column():
|
255 |
prompt = gr.Textbox(
|
256 |
label="Description",
|
257 |
placeholder="Leave empty for something spicy 👀",
|
258 |
)
|
259 |
-
# input image
|
260 |
with gr.Row():
|
261 |
with gr.Column():
|
262 |
image = gr.Image(
|
@@ -265,19 +258,16 @@ def main():
|
|
265 |
show_label=True,
|
266 |
format="webp",
|
267 |
)
|
268 |
-
# run button
|
269 |
with gr.Column():
|
270 |
run_button = gr.Button(value="Use this one", size=["lg"], visible=False)
|
271 |
-
# output image
|
272 |
with gr.Column():
|
273 |
-
result = gr.Image(
|
274 |
label="Anime AI",
|
275 |
interactive=False,
|
276 |
format="webp",
|
277 |
visible = True,
|
278 |
show_share_button= False,
|
279 |
)
|
280 |
-
# Use this image button
|
281 |
with gr.Column():
|
282 |
use_ai_button = gr.Button(value="Use this one", size=["lg"], visible=False)
|
283 |
config = [
|
@@ -292,16 +282,14 @@ def main():
|
|
292 |
guidance_scale,
|
293 |
seed,
|
294 |
]
|
295 |
-
|
296 |
-
# inputs=image,
|
297 |
-
# fn=process_image(image, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed),
|
298 |
-
# run_on_click=True
|
299 |
-
# )
|
300 |
-
|
301 |
-
@gr.on(triggers=[image.upload], inputs=config, outputs=[result])
|
302 |
def auto_process_image(image, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed):
|
303 |
return process_image(image, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)
|
304 |
|
|
|
|
|
|
|
|
|
305 |
@gr.on(triggers=[image.upload], inputs=None, outputs=[use_ai_button, run_button])
|
306 |
def turn_buttons_off():
|
307 |
return gr.update(visible=False), gr.update(visible=False)
|
@@ -329,7 +317,7 @@ def main():
|
|
329 |
api_name=False,
|
330 |
show_progress="none",
|
331 |
).then(
|
332 |
-
fn=
|
333 |
inputs=config,
|
334 |
outputs=result,
|
335 |
api_name=False,
|
@@ -344,7 +332,7 @@ def main():
|
|
344 |
api_name=False,
|
345 |
show_progress="none",
|
346 |
).then(
|
347 |
-
fn=
|
348 |
inputs=config,
|
349 |
outputs=result,
|
350 |
show_progress="minimal",
|
@@ -353,7 +341,6 @@ def main():
|
|
353 |
def update_config():
|
354 |
try:
|
355 |
print("Updating image to AI Temp Image")
|
356 |
-
# Read the image from the file
|
357 |
ai_temp_image = Image.open("temp_image.png")
|
358 |
return ai_temp_image
|
359 |
except FileNotFoundError:
|
@@ -373,13 +360,12 @@ def main():
|
|
373 |
outputs=image,
|
374 |
show_progress="minimal",
|
375 |
).then(
|
376 |
-
fn=
|
377 |
inputs=[image, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed],
|
378 |
outputs=result,
|
379 |
show_progress="minimal",
|
380 |
)
|
381 |
|
382 |
-
|
383 |
demo.launch()
|
384 |
|
385 |
if __name__ == "__main__":
|
@@ -392,4 +378,4 @@ if __name__ == "__main__":
|
|
392 |
sortby = 'cumulative'
|
393 |
ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
|
394 |
ps.print_stats()
|
395 |
-
print(s.getvalue())
|
|
|
1 |
import cProfile
|
2 |
import pstats
|
3 |
import io
|
4 |
+
import gc
|
5 |
+
import random
|
6 |
+
import time
|
7 |
+
import gradio as gr
|
8 |
+
import spaces
|
9 |
+
import imageio
|
10 |
+
from huggingface_hub import HfApi
|
11 |
+
import torch
|
12 |
+
from PIL import Image
|
13 |
+
from diffusers import (
|
14 |
+
ControlNetModel,
|
15 |
+
DPMSolverMultistepScheduler,
|
16 |
+
StableDiffusionControlNetPipeline,
|
17 |
+
)
|
18 |
+
from preprocess_anime import Preprocessor, conditionally_manage_memory
|
19 |
+
from settings import API_KEY, MAX_NUM_IMAGES, MAX_SEED
|
20 |
|
21 |
+
preprocessor = None
|
22 |
+
controlnet = None
|
23 |
+
scheduler = None
|
24 |
+
pipe = None
|
25 |
+
api = HfApi()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
+
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
28 |
+
if randomize_seed:
|
29 |
+
seed = random.randint(0, MAX_SEED)
|
30 |
+
return seed
|
31 |
+
|
32 |
+
def get_additional_prompt():
|
33 |
+
prompt = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
|
34 |
+
top = ["tank top", "blouse", "button up shirt", "sweater", "corset top"]
|
35 |
+
bottom = ["short skirt", "athletic shorts", "jean shorts", "pleated skirt", "short skirt", "leggings", "high-waisted shorts"]
|
36 |
+
accessory = ["knee-high boots", "gloves", "Thigh-high stockings", "Garter belt", "choker", "necklace", "headband", "headphones"]
|
37 |
+
return f"{prompt}, {random.choice(top)}, {random.choice(bottom)}, {random.choice(accessory)}, score_9"
|
38 |
+
# outfit = ["schoolgirl outfit", "playboy outfit", "red dress", "gala dress", "cheerleader outfit", "nurse outfit", "Kimono"]
|
39 |
|
40 |
+
def get_prompt(prompt, additional_prompt):
|
41 |
+
default = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
|
42 |
+
randomize = get_additional_prompt()
|
43 |
+
nude = "NSFW,((nude)),medium bare breasts,hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
|
44 |
+
bodypaint = "((fully naked with no clothes)),nude naked seethroughxray,invisiblebodypaint,rating_newd,NSFW"
|
45 |
+
lab_girl = "hyperrealistic photography, extremely detailed, shy assistant wearing minidress boots and gloves, laboratory background, score_9, 1girl"
|
46 |
+
pet_play = "hyperrealistic photography, extremely detailed, playful, blush, glasses, collar, score_9, HDA_pet_play"
|
47 |
+
bondage = "hyperrealistic photography, extremely detailed, submissive, glasses, score_9, HDA_Bondage"
|
48 |
+
ahegao = "((invisible clothing)), hyperrealistic photography,exposed vagina,sexy,nsfw,HDA_Ahegao"
|
49 |
+
ahegao2 = "(invisiblebodypaint),rating_newd,HDA_Ahegao"
|
50 |
+
athleisure = "hyperrealistic photography, extremely detailed, 1girl athlete, exhausted embarrassed sweaty,outdoors, ((athleisure clothing)), score_9"
|
51 |
+
atompunk = "((atompunk world)), hyperrealistic photography, extremely detailed, short hair, bodysuit, glasses, neon cyberpunk background, score_9"
|
52 |
+
maid = "hyperrealistic photography, extremely detailed, shy, blushing, score_9, pastel background, HDA_unconventional_maid"
|
53 |
+
nundress = "hyperrealistic photography, extremely detailed, shy, blushing, fantasy background, score_9, HDA_NunDress"
|
54 |
+
naked_hoodie = "hyperrealistic photography, extremely detailed, medium hair, cityscape, (neon lights), score_9, HDA_NakedHoodie"
|
55 |
+
abg = "(1girl, asian body covered in words, words on body, tattoos of (words) on body),(masterpiece, best quality),medium breasts,(intricate details),unity 8k wallpaper,ultra detailed,(pastel colors),beautiful and aesthetic,see-through (clothes),detailed,solo"
|
56 |
+
shibari = "extremely detailed, hyperrealistic photography, earrings, blushing, lace choker, tattoo, medium hair, score_9, HDA_Shibari"
|
57 |
+
|
58 |
+
if prompt == "":
|
59 |
+
prompts = [randomize, nude, bodypaint, pet_play, bondage, ahegao2, lab_girl, athleisure, atompunk, maid, nundress, naked_hoodie, abg, shibari]
|
60 |
+
prompts_nsfw = [nude, bodypaint, abg, ahegao2, shibari]
|
61 |
+
preset = random.choice(prompts)
|
62 |
+
prompt = f"{preset}"
|
63 |
+
# print(f"-------------{preset}-------------")
|
64 |
+
else:
|
65 |
+
# prompt = f"{prompt}, {randomize}"
|
66 |
+
prompt = f"{default},{prompt}"
|
67 |
+
print(f"{prompt}")
|
68 |
+
return prompt
|
69 |
+
|
70 |
+
def initialize_models():
|
71 |
+
global preprocessor, controlnet, scheduler, pipe
|
72 |
+
if preprocessor is None:
|
73 |
preprocessor = Preprocessor()
|
74 |
|
75 |
+
if controlnet is None:
|
76 |
model_id = "lllyasviel/control_v11p_sd15_normalbae"
|
77 |
print("initializing controlnet")
|
78 |
controlnet = ControlNetModel.from_pretrained(
|
|
|
80 |
torch_dtype=torch.float16,
|
81 |
attn_implementation="flash_attention_2",
|
82 |
).to("cuda")
|
83 |
+
|
84 |
+
if scheduler is None:
|
85 |
scheduler = DPMSolverMultistepScheduler.from_pretrained(
|
86 |
"runwayml/stable-diffusion-v1-5",
|
87 |
solver_order=2,
|
|
|
95 |
device_map="cuda",
|
96 |
)
|
97 |
|
98 |
+
if pipe is None:
|
99 |
base_model_url = "https://huggingface.co/broyang/hentaidigitalart_v20/blob/main/realcartoon3d_v15.safetensors"
|
|
|
100 |
pipe = StableDiffusionControlNetPipeline.from_single_file(
|
101 |
base_model_url,
|
102 |
safety_checker=None,
|
|
|
104 |
scheduler=scheduler,
|
105 |
torch_dtype=torch.float16,
|
106 |
)
|
107 |
+
pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="EasyNegativeV2.safetensors", token="EasyNegativeV2")
|
|
|
108 |
pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="badhandv4.pt", token="badhandv4")
|
109 |
pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Ahegao.pt", token="HDA_Ahegao")
|
110 |
pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Bondage.pt", token="HDA_Bondage")
|
|
|
115 |
pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_NunDress.pt", token="HDA_NunDress")
|
116 |
pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Shibari.pt", token="HDA_Shibari")
|
117 |
pipe.to("cuda")
|
118 |
+
print("---------------Loaded controlnet pipeline---------------")
|
|
|
|
|
|
|
119 |
|
120 |
+
@spaces.GPU(duration=11)
|
121 |
+
@torch.inference_mode()
|
122 |
+
def process_image(
|
123 |
+
image,
|
124 |
+
prompt,
|
125 |
+
a_prompt,
|
126 |
+
n_prompt,
|
127 |
+
num_images,
|
128 |
+
image_resolution,
|
129 |
+
preprocess_resolution,
|
130 |
+
num_steps,
|
131 |
+
guidance_scale,
|
132 |
+
seed,
|
133 |
+
):
|
134 |
+
initialize_models()
|
135 |
+
preprocessor.load("NormalBae")
|
136 |
+
control_image = preprocessor(
|
137 |
+
image=image,
|
138 |
+
image_resolution=image_resolution,
|
139 |
+
detect_resolution=preprocess_resolution,
|
140 |
+
)
|
141 |
+
custom_prompt = str(get_prompt(prompt, a_prompt))
|
142 |
+
negative_prompt = str(n_prompt)
|
143 |
+
global compiled
|
144 |
+
generator = torch.cuda.manual_seed(seed)
|
145 |
+
if not compiled:
|
146 |
+
print("-----------------------------------Not Compiled-----------------------------------")
|
147 |
+
compiled = True
|
148 |
+
start = time.time()
|
149 |
+
results = pipe(
|
150 |
+
prompt=custom_prompt,
|
151 |
+
negative_prompt=negative_prompt,
|
152 |
+
guidance_scale=guidance_scale,
|
153 |
+
num_images_per_prompt=num_images,
|
154 |
+
num_inference_steps=num_steps,
|
155 |
+
generator=generator,
|
156 |
+
image=control_image,
|
157 |
+
).images[0]
|
158 |
+
print(f"Inference done in: {time.time() - start:.2f} seconds")
|
159 |
+
|
160 |
+
timestamp = int(time.time())
|
161 |
+
img_path = f"{timestamp}.jpg"
|
162 |
+
results_path = f"{timestamp}_out.jpg"
|
163 |
+
imageio.imsave(img_path, image)
|
164 |
+
results.save(results_path)
|
165 |
+
|
166 |
+
api.upload_file(
|
167 |
+
path_or_fileobj=img_path,
|
168 |
+
path_in_repo=img_path,
|
169 |
+
repo_id="broyang/anime-ai-outputs",
|
170 |
+
repo_type="dataset",
|
171 |
+
token=API_KEY,
|
172 |
+
run_as_future=True,
|
173 |
+
)
|
174 |
+
api.upload_file(
|
175 |
+
path_or_fileobj=results_path,
|
176 |
+
path_in_repo=results_path,
|
177 |
+
repo_id="broyang/anime-ai-outputs",
|
178 |
+
repo_type="dataset",
|
179 |
+
token=API_KEY,
|
180 |
+
run_as_future=True,
|
181 |
+
)
|
182 |
|
183 |
+
conditionally_manage_memory()
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
|
185 |
+
results.save("temp_image.png")
|
186 |
+
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
|
188 |
+
def main():
|
189 |
+
prod = True
|
190 |
+
show_options = True
|
191 |
+
if prod:
|
192 |
+
show_options = False
|
193 |
+
|
194 |
+
print("CUDA version:", torch.version.cuda)
|
195 |
+
print("loading pipe")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
|
197 |
css = """
|
198 |
h1 {
|
|
|
210 |
footer {visibility: hidden}
|
211 |
"""
|
212 |
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
|
|
213 |
with gr.Row():
|
214 |
with gr.Accordion("Advanced options", open=show_options, visible=show_options):
|
215 |
num_images = gr.Slider(
|
|
|
231 |
)
|
232 |
num_steps = gr.Slider(
|
233 |
label="Number of steps", minimum=1, maximum=100, value=12, step=1
|
234 |
+
)
|
235 |
guidance_scale = gr.Slider(
|
236 |
label="Guidance scale", minimum=0.1, maximum=30.0, value=5.5, step=0.1
|
237 |
+
)
|
238 |
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
|
239 |
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
240 |
a_prompt = gr.Textbox(
|
|
|
245 |
label="Negative prompt",
|
246 |
value="EasyNegativeV2, fcNeg, (badhandv4:1.4), (worst quality, low quality, bad quality, normal quality:2.0), (bad hands, missing fingers, extra fingers:2.0)",
|
247 |
)
|
|
|
|
|
248 |
with gr.Column():
|
249 |
prompt = gr.Textbox(
|
250 |
label="Description",
|
251 |
placeholder="Leave empty for something spicy 👀",
|
252 |
)
|
|
|
253 |
with gr.Row():
|
254 |
with gr.Column():
|
255 |
image = gr.Image(
|
|
|
258 |
show_label=True,
|
259 |
format="webp",
|
260 |
)
|
|
|
261 |
with gr.Column():
|
262 |
run_button = gr.Button(value="Use this one", size=["lg"], visible=False)
|
|
|
263 |
with gr.Column():
|
264 |
+
result = gr.Image(
|
265 |
label="Anime AI",
|
266 |
interactive=False,
|
267 |
format="webp",
|
268 |
visible = True,
|
269 |
show_share_button= False,
|
270 |
)
|
|
|
271 |
with gr.Column():
|
272 |
use_ai_button = gr.Button(value="Use this one", size=["lg"], visible=False)
|
273 |
config = [
|
|
|
282 |
guidance_scale,
|
283 |
seed,
|
284 |
]
|
285 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
286 |
def auto_process_image(image, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed):
|
287 |
return process_image(image, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)
|
288 |
|
289 |
+
@gr.on(triggers=[image.upload], inputs=config, outputs=[result])
|
290 |
+
def turn_buttons_off():
|
291 |
+
return gr.update(visible=False), gr.update(visible=False)
|
292 |
+
|
293 |
@gr.on(triggers=[image.upload], inputs=None, outputs=[use_ai_button, run_button])
|
294 |
def turn_buttons_off():
|
295 |
return gr.update(visible=False), gr.update(visible=False)
|
|
|
317 |
api_name=False,
|
318 |
show_progress="none",
|
319 |
).then(
|
320 |
+
fn=auto_process_image,
|
321 |
inputs=config,
|
322 |
outputs=result,
|
323 |
api_name=False,
|
|
|
332 |
api_name=False,
|
333 |
show_progress="none",
|
334 |
).then(
|
335 |
+
fn=auto_process_image,
|
336 |
inputs=config,
|
337 |
outputs=result,
|
338 |
show_progress="minimal",
|
|
|
341 |
def update_config():
|
342 |
try:
|
343 |
print("Updating image to AI Temp Image")
|
|
|
344 |
ai_temp_image = Image.open("temp_image.png")
|
345 |
return ai_temp_image
|
346 |
except FileNotFoundError:
|
|
|
360 |
outputs=image,
|
361 |
show_progress="minimal",
|
362 |
).then(
|
363 |
+
fn=auto_process_image,
|
364 |
inputs=[image, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed],
|
365 |
outputs=result,
|
366 |
show_progress="minimal",
|
367 |
)
|
368 |
|
|
|
369 |
demo.launch()
|
370 |
|
371 |
if __name__ == "__main__":
|
|
|
378 |
sortby = 'cumulative'
|
379 |
ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
|
380 |
ps.print_stats()
|
381 |
+
print(s.getvalue())
|