Bobby committed
Commit 838a1f4 · 1 Parent(s): 9b7a27f

profiler part 2

Files changed (3)
  1. README.md +1 -1
  2. preprocess_anime.py +19 -0
  3. profiler.py +157 -171
README.md CHANGED
@@ -6,7 +6,7 @@ colorTo: pink
 sdk: gradio
 sdk_version: 4.31.4
 #app_file: anime_app.py
-app_file: anime_app.py
+app_file: profiler.py
 pinned: true
 license: apache-2.0
 short_description: Turn yourself into a weeb
preprocess_anime.py CHANGED
@@ -49,3 +49,22 @@ class Preprocessor:
             return PIL.Image.fromarray(image)
         else:
             return self.model(image, **kwargs)
+
+    def manage_memory(self):
+        torch.cuda.empty_cache()
+        gc.collect()
+
+# Additional helper function to manage memory less frequently
+def conditionally_manage_memory(memory_threshold=0.8):
+    """
+    Frees up GPU memory if usage exceeds the threshold.
+    :param memory_threshold: Fraction of memory usage to trigger cleanup.
+    """
+    if torch.cuda.is_available():
+        total_memory = torch.cuda.get_device_properties(0).total_memory
+        reserved_memory = torch.cuda.memory_reserved(0)
+        allocated_memory = torch.cuda.memory_allocated(0)
+
+        if reserved_memory / total_memory > memory_threshold:
+            torch.cuda.empty_cache()
+            gc.collect()
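
The new conditionally_manage_memory helper makes the cache check cheap enough to run after every generation, since torch.cuda.empty_cache() and gc.collect() only fire once reserved memory crosses the threshold. A minimal usage sketch, assuming an arbitrary GPU-bound callable (generate_many and generate_one are illustrative names, not part of this repo):

from preprocess_anime import conditionally_manage_memory

def generate_many(prompts, generate_one):
    # generate_one is any GPU-bound callable, e.g. a diffusers pipeline call.
    results = []
    for prompt in prompts:
        results.append(generate_one(prompt))
        # Cheap check each iteration; empty_cache()/gc.collect() only run when
        # reserved CUDA memory exceeds 80% of the device total.
        conditionally_manage_memory(memory_threshold=0.8)
    return results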
profiler.py CHANGED
@@ -1,40 +1,78 @@
 import cProfile
 import pstats
 import io
+import gc
+import random
+import time
+import gradio as gr
+import spaces
+import imageio
+from huggingface_hub import HfApi
+import torch
+from PIL import Image
+from diffusers import (
+    ControlNetModel,
+    DPMSolverMultistepScheduler,
+    StableDiffusionControlNetPipeline,
+)
+from preprocess_anime import Preprocessor, conditionally_manage_memory
+from settings import API_KEY, MAX_NUM_IMAGES, MAX_SEED

-# Assuming the main function or entry point is `main`
-def main():
-    prod = True
-    show_options = True
-    if prod:
-        show_options = False
-    import gc
-    import random
-    import time
-    import gradio as gr
-    import spaces
-    import imageio
-    from huggingface_hub import HfApi
-    import torch
-    from PIL import Image
-    from diffusers import (
-        ControlNetModel,
-        DPMSolverMultistepScheduler,
-        StableDiffusionControlNetPipeline,
-    )
-    from preprocess_anime import Preprocessor
-    from settings import API_KEY, MAX_NUM_IMAGES, MAX_SEED
+preprocessor = None
+controlnet = None
+scheduler = None
+pipe = None
+api = HfApi()

-    print("CUDA version:", torch.version.cuda)
-    print("loading pipe")
-    compiled = False
-    api = HfApi()
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    return seed
+
+def get_additional_prompt():
+    prompt = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
+    top = ["tank top", "blouse", "button up shirt", "sweater", "corset top"]
+    bottom = ["short skirt", "athletic shorts", "jean shorts", "pleated skirt", "short skirt", "leggings", "high-waisted shorts"]
+    accessory = ["knee-high boots", "gloves", "Thigh-high stockings", "Garter belt", "choker", "necklace", "headband", "headphones"]
+    return f"{prompt}, {random.choice(top)}, {random.choice(bottom)}, {random.choice(accessory)}, score_9"
+    # outfit = ["schoolgirl outfit", "playboy outfit", "red dress", "gala dress", "cheerleader outfit", "nurse outfit", "Kimono"]

-    if gr.NO_RELOAD:
-        #torch.cuda.max_memory_allocated(device="cuda")
+def get_prompt(prompt, additional_prompt):
+    default = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
+    randomize = get_additional_prompt()
+    nude = "NSFW,((nude)),medium bare breasts,hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
+    bodypaint = "((fully naked with no clothes)),nude naked seethroughxray,invisiblebodypaint,rating_newd,NSFW"
+    lab_girl = "hyperrealistic photography, extremely detailed, shy assistant wearing minidress boots and gloves, laboratory background, score_9, 1girl"
+    pet_play = "hyperrealistic photography, extremely detailed, playful, blush, glasses, collar, score_9, HDA_pet_play"
+    bondage = "hyperrealistic photography, extremely detailed, submissive, glasses, score_9, HDA_Bondage"
+    ahegao = "((invisible clothing)), hyperrealistic photography,exposed vagina,sexy,nsfw,HDA_Ahegao"
+    ahegao2 = "(invisiblebodypaint),rating_newd,HDA_Ahegao"
+    athleisure = "hyperrealistic photography, extremely detailed, 1girl athlete, exhausted embarrassed sweaty,outdoors, ((athleisure clothing)), score_9"
+    atompunk = "((atompunk world)), hyperrealistic photography, extremely detailed, short hair, bodysuit, glasses, neon cyberpunk background, score_9"
+    maid = "hyperrealistic photography, extremely detailed, shy, blushing, score_9, pastel background, HDA_unconventional_maid"
+    nundress = "hyperrealistic photography, extremely detailed, shy, blushing, fantasy background, score_9, HDA_NunDress"
+    naked_hoodie = "hyperrealistic photography, extremely detailed, medium hair, cityscape, (neon lights), score_9, HDA_NakedHoodie"
+    abg = "(1girl, asian body covered in words, words on body, tattoos of (words) on body),(masterpiece, best quality),medium breasts,(intricate details),unity 8k wallpaper,ultra detailed,(pastel colors),beautiful and aesthetic,see-through (clothes),detailed,solo"
+    shibari = "extremely detailed, hyperrealistic photography, earrings, blushing, lace choker, tattoo, medium hair, score_9, HDA_Shibari"
+
+    if prompt == "":
+        prompts = [randomize, nude, bodypaint, pet_play, bondage, ahegao2, lab_girl, athleisure, atompunk, maid, nundress, naked_hoodie, abg, shibari]
+        prompts_nsfw = [nude, bodypaint, abg, ahegao2, shibari]
+        preset = random.choice(prompts)
+        prompt = f"{preset}"
+        # print(f"-------------{preset}-------------")
+    else:
+        # prompt = f"{prompt}, {randomize}"
+        prompt = f"{default},{prompt}"
+    print(f"{prompt}")
+    return prompt
+
+def initialize_models():
+    global preprocessor, controlnet, scheduler, pipe
+    if preprocessor is None:
         preprocessor = Preprocessor()

-        # Controlnet Normal
+    if controlnet is None:
         model_id = "lllyasviel/control_v11p_sd15_normalbae"
         print("initializing controlnet")
         controlnet = ControlNetModel.from_pretrained(
@@ -42,8 +80,8 @@ def main():
             torch_dtype=torch.float16,
             attn_implementation="flash_attention_2",
         ).to("cuda")
-
-        # Scheduler
+
+    if scheduler is None:
         scheduler = DPMSolverMultistepScheduler.from_pretrained(
             "runwayml/stable-diffusion-v1-5",
             solver_order=2,
@@ -57,9 +95,8 @@ def main():
             device_map="cuda",
         )

-        # Stable Diffusion Pipeline URL
+    if pipe is None:
         base_model_url = "https://huggingface.co/broyang/hentaidigitalart_v20/blob/main/realcartoon3d_v15.safetensors"
-
         pipe = StableDiffusionControlNetPipeline.from_single_file(
             base_model_url,
             safety_checker=None,
@@ -67,8 +104,7 @@ def main():
             scheduler=scheduler,
             torch_dtype=torch.float16,
         )
-
-        pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="EasyNegativeV2.safetensors", token="EasyNegativeV2",)
+        pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="EasyNegativeV2.safetensors", token="EasyNegativeV2")
         pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="badhandv4.pt", token="badhandv4")
         pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Ahegao.pt", token="HDA_Ahegao")
         pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Bondage.pt", token="HDA_Bondage")
@@ -79,123 +115,84 @@ def main():
         pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_NunDress.pt", token="HDA_NunDress")
         pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Shibari.pt", token="HDA_Shibari")
         pipe.to("cuda")
-
-        torch.cuda.empty_cache()
-        gc.collect()
-        print("---------------Loaded controlnet pipeline---------------")
+        print("---------------Loaded controlnet pipeline---------------")

-    def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-        if randomize_seed:
-            seed = random.randint(0, MAX_SEED)
-        return seed
+@spaces.GPU(duration=11)
+@torch.inference_mode()
+def process_image(
+    image,
+    prompt,
+    a_prompt,
+    n_prompt,
+    num_images,
+    image_resolution,
+    preprocess_resolution,
+    num_steps,
+    guidance_scale,
+    seed,
+):
+    initialize_models()
+    preprocessor.load("NormalBae")
+    control_image = preprocessor(
+        image=image,
+        image_resolution=image_resolution,
+        detect_resolution=preprocess_resolution,
+    )
+    custom_prompt = str(get_prompt(prompt, a_prompt))
+    negative_prompt = str(n_prompt)
+    global compiled
+    generator = torch.cuda.manual_seed(seed)
+    if not compiled:
+        print("-----------------------------------Not Compiled-----------------------------------")
+        compiled = True
+    start = time.time()
+    results = pipe(
+        prompt=custom_prompt,
+        negative_prompt=negative_prompt,
+        guidance_scale=guidance_scale,
+        num_images_per_prompt=num_images,
+        num_inference_steps=num_steps,
+        generator=generator,
+        image=control_image,
+    ).images[0]
+    print(f"Inference done in: {time.time() - start:.2f} seconds")
+
+    timestamp = int(time.time())
+    img_path = f"{timestamp}.jpg"
+    results_path = f"{timestamp}_out.jpg"
+    imageio.imsave(img_path, image)
+    results.save(results_path)
+
+    api.upload_file(
+        path_or_fileobj=img_path,
+        path_in_repo=img_path,
+        repo_id="broyang/anime-ai-outputs",
+        repo_type="dataset",
+        token=API_KEY,
+        run_as_future=True,
+    )
+    api.upload_file(
+        path_or_fileobj=results_path,
+        path_in_repo=results_path,
+        repo_id="broyang/anime-ai-outputs",
+        repo_type="dataset",
+        token=API_KEY,
+        run_as_future=True,
+    )

-    def get_additional_prompt():
-        prompt = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
-        top = ["tank top", "blouse", "button up shirt", "sweater", "corset top"]
-        bottom = ["short skirt", "athletic shorts", "jean shorts", "pleated skirt", "short skirt", "leggings", "high-waisted shorts"]
-        accessory = ["knee-high boots", "gloves", "Thigh-high stockings", "Garter belt", "choker", "necklace", "headband", "headphones"]
-        return f"{prompt}, {random.choice(top)}, {random.choice(bottom)}, {random.choice(accessory)}, score_9"
-        # outfit = ["schoolgirl outfit", "playboy outfit", "red dress", "gala dress", "cheerleader outfit", "nurse outfit", "Kimono"]
+    conditionally_manage_memory()

-    def get_prompt(prompt, additional_prompt):
-        default = "hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
-        randomize = get_additional_prompt()
-        nude = "NSFW,((nude)),medium bare breasts,hyperrealistic photography,extremely detailed,(intricate details),unity 8k wallpaper,ultra detailed"
-        bodypaint = "((fully naked with no clothes)),nude naked seethroughxray,invisiblebodypaint,rating_newd,NSFW"
-        lab_girl = "hyperrealistic photography, extremely detailed, shy assistant wearing minidress boots and gloves, laboratory background, score_9, 1girl"
-        pet_play = "hyperrealistic photography, extremely detailed, playful, blush, glasses, collar, score_9, HDA_pet_play"
-        bondage = "hyperrealistic photography, extremely detailed, submissive, glasses, score_9, HDA_Bondage"
-        ahegao = "((invisible clothing)), hyperrealistic photography,exposed vagina,sexy,nsfw,HDA_Ahegao"
-        ahegao2 = "(invisiblebodypaint),rating_newd,HDA_Ahegao"
-        athleisure = "hyperrealistic photography, extremely detailed, 1girl athlete, exhausted embarrassed sweaty,outdoors, ((athleisure clothing)), score_9"
-        atompunk = "((atompunk world)), hyperrealistic photography, extremely detailed, short hair, bodysuit, glasses, neon cyberpunk background, score_9"
-        maid = "hyperrealistic photography, extremely detailed, shy, blushing, score_9, pastel background, HDA_unconventional_maid"
-        nundress = "hyperrealistic photography, extremely detailed, shy, blushing, fantasy background, score_9, HDA_NunDress"
-        naked_hoodie = "hyperrealistic photography, extremely detailed, medium hair, cityscape, (neon lights), score_9, HDA_NakedHoodie"
-        abg = "(1girl, asian body covered in words, words on body, tattoos of (words) on body),(masterpiece, best quality),medium breasts,(intricate details),unity 8k wallpaper,ultra detailed,(pastel colors),beautiful and aesthetic,see-through (clothes),detailed,solo"
-        shibari = "extremely detailed, hyperrealistic photography, earrings, blushing, lace choker, tattoo, medium hair, score_9, HDA_Shibari"
-
-        if prompt == "":
-            prompts = [randomize, nude, bodypaint, pet_play, bondage, ahegao2, lab_girl, athleisure, atompunk, maid, nundress, naked_hoodie, abg, shibari]
-            prompts_nsfw = [nude, bodypaint, abg, ahegao2, shibari]
-            preset = random.choice(prompts)
-            prompt = f"{preset}"
-            # print(f"-------------{preset}-------------")
-        else:
-            # prompt = f"{prompt}, {randomize}"
-            prompt = f"{default},{prompt}"
-        print(f"{prompt}")
-        return prompt
+    results.save("temp_image.png")
+    return results

-    @spaces.GPU(duration=11)
-    @torch.inference_mode()
-    def process_image(
-        image,
-        prompt,
-        a_prompt,
-        n_prompt,
-        num_images,
-        image_resolution,
-        preprocess_resolution,
-        num_steps,
-        guidance_scale,
-        seed,
-    ):
-        print("processing image")
-        start = time.time()
-        preprocessor.load("NormalBae")
-        # preprocessor.load("Canny") #20 steps, 9 guidance, 512, 512
-        control_image = preprocessor(
-            image=image,
-            image_resolution=image_resolution,
-            detect_resolution=preprocess_resolution,
-        )
-        custom_prompt=str(get_prompt(prompt, a_prompt))
-        negative_prompt=str(n_prompt)
-        global compiled
-        generator = torch.cuda.manual_seed(seed)
-        if not compiled:
-            print("-----------------------------------Not Compiled-----------------------------------")
-            compiled = True
-        results = pipe(
-            prompt=custom_prompt,
-            negative_prompt=negative_prompt,
-            guidance_scale=guidance_scale,
-            num_images_per_prompt=num_images,
-            num_inference_steps=num_steps,
-            generator=generator,
-            image=control_image,
-        ).images[0]
-        print(f"Inference done in: {time.time() - start:.2f} seconds")
-
-        timestamp = int(time.time())
-        img_path = f"{timestamp}.jpg"
-        results_path = f"{timestamp}_out.jpg"
-        imageio.imsave(img_path, image)
-        results.save(results_path)
-
-        api.upload_file(
-            path_or_fileobj=img_path,
-            path_in_repo=img_path,
-            repo_id="broyang/anime-ai-outputs",
-            repo_type="dataset",
-            token=API_KEY,
-            run_as_future=True,
-        )
-        api.upload_file(
-            path_or_fileobj=results_path,
-            path_in_repo=results_path,
-            repo_id="broyang/anime-ai-outputs",
-            repo_type="dataset",
-            token=API_KEY,
-            run_as_future=True,
-        )
-
-        torch.cuda.empty_cache()
-        gc.collect()
-
-        results.save("temp_image.png")
-        return results
+def main():
+    prod = True
+    show_options = True
+    if prod:
+        show_options = False
+
+    print("CUDA version:", torch.version.cuda)
+    print("loading pipe")

     css = """
     h1 {
@@ -213,7 +210,6 @@ def main():
     footer {visibility: hidden}
     """
     with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
-        #############################################################################
         with gr.Row():
             with gr.Accordion("Advanced options", open=show_options, visible=show_options):
                 num_images = gr.Slider(
@@ -235,10 +231,10 @@ def main():
                 )
                 num_steps = gr.Slider(
                     label="Number of steps", minimum=1, maximum=100, value=12, step=1
-                ) # 20/4.5 or 12 without lora, 4 with lora
+                )
                 guidance_scale = gr.Slider(
                     label="Guidance scale", minimum=0.1, maximum=30.0, value=5.5, step=0.1
-                ) # 5 without lora, 2 with lora
+                )
                 seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                 randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                 a_prompt = gr.Textbox(
@@ -249,14 +245,11 @@ def main():
                     label="Negative prompt",
                     value="EasyNegativeV2, fcNeg, (badhandv4:1.4), (worst quality, low quality, bad quality, normal quality:2.0), (bad hands, missing fingers, extra fingers:2.0)",
                 )
-            #############################################################################
-            # input text
             with gr.Column():
                 prompt = gr.Textbox(
                     label="Description",
                     placeholder="Leave empty for something spicy 👀",
                 )
-        # input image
         with gr.Row():
             with gr.Column():
                 image = gr.Image(
@@ -265,19 +258,16 @@ def main():
                     show_label=True,
                     format="webp",
                 )
-            # run button
             with gr.Column():
                 run_button = gr.Button(value="Use this one", size=["lg"], visible=False)
-            # output image
             with gr.Column():
-                result = gr.Image(
+                result = gr.Image(
                     label="Anime AI",
                     interactive=False,
                     format="webp",
                     visible = True,
                     show_share_button= False,
                 )
-            # Use this image button
             with gr.Column():
                 use_ai_button = gr.Button(value="Use this one", size=["lg"], visible=False)
         config = [
@@ -292,16 +282,14 @@ def main():
             guidance_scale,
             seed,
         ]
-        # examples = gr.Examples(examples=["./img/peleton.webp", "./img/peleton2.webp", "./img/peleton_anime.webp", "./img/peleton2_anime.webp", "./img/miku.webp"],
-        # inputs=image,
-        # fn=process_image(image, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed),
-        # run_on_click=True
-        # )
-
-        @gr.on(triggers=[image.upload], inputs=config, outputs=[result])
+
         def auto_process_image(image, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed):
             return process_image(image, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed)

+        @gr.on(triggers=[image.upload], inputs=config, outputs=[result])
+        def turn_buttons_off():
+            return gr.update(visible=False), gr.update(visible=False)
+
         @gr.on(triggers=[image.upload], inputs=None, outputs=[use_ai_button, run_button])
         def turn_buttons_off():
             return gr.update(visible=False), gr.update(visible=False)
@@ -329,7 +317,7 @@ def main():
             api_name=False,
             show_progress="none",
         ).then(
-            fn=process_image,
+            fn=auto_process_image,
             inputs=config,
             outputs=result,
             api_name=False,
@@ -344,7 +332,7 @@ def main():
             api_name=False,
             show_progress="none",
        ).then(
-            fn=process_image,
+            fn=auto_process_image,
             inputs=config,
             outputs=result,
             show_progress="minimal",
@@ -353,7 +341,6 @@ def main():
         def update_config():
             try:
                 print("Updating image to AI Temp Image")
-                # Read the image from the file
                 ai_temp_image = Image.open("temp_image.png")
                 return ai_temp_image
             except FileNotFoundError:
@@ -373,13 +360,12 @@ def main():
             outputs=image,
             show_progress="minimal",
         ).then(
-            fn=process_image,
+            fn=auto_process_image,
             inputs=[image, prompt, a_prompt, n_prompt, num_images, image_resolution, preprocess_resolution, num_steps, guidance_scale, seed],
             outputs=result,
             show_progress="minimal",
         )

-
         demo.launch()

 if __name__ == "__main__":
@@ -392,4 +378,4 @@ if __name__ == "__main__":
     sortby = 'cumulative'
     ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
     ps.print_stats()
-    print(s.getvalue())
+    print(s.getvalue())
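
The last hunk only shows the tail of the profiling harness in profiler.py; the profiler object and the string buffer are created outside the changed lines. For context, a sketch of the standard cProfile pattern those four lines belong to, with the setup lines assumed rather than taken from this diff:

import cProfile
import io
import pstats

if __name__ == "__main__":
    pr = cProfile.Profile()
    pr.enable()                 # start collecting call statistics
    main()                      # run the Gradio app under the profiler
    pr.disable()                # stop once the app exits
    s = io.StringIO()
    sortby = 'cumulative'
    ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
    ps.print_stats()
    print(s.getvalue())         # dump the report to stdout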