Spaces:
Running
on
A100
Running
on
A100
add use low vram option
Browse files
app.py
CHANGED
@@ -34,7 +34,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
34 |
print(device)
|
35 |
|
36 |
# Flag for low VRAM usage
|
37 |
-
low_vram = False
|
38 |
|
39 |
# Function definition for low VRAM usage
|
40 |
def models_to(model, device="cpu", excepts=None):
|
@@ -107,11 +107,13 @@ models_b = WurstCoreB.Models(
|
|
107 |
)
|
108 |
models_b.generator.bfloat16().eval().requires_grad_(False)
|
109 |
|
|
|
110 |
if low_vram:
|
111 |
# Off-load old generator (which is not used in models_rbm)
|
112 |
models.generator.to("cpu")
|
113 |
torch.cuda.empty_cache()
|
114 |
gc.collect()
|
|
|
115 |
|
116 |
generator_rbm = StageCRBM()
|
117 |
for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
|
@@ -128,10 +130,10 @@ models_rbm.generator.eval().requires_grad_(False)
|
|
128 |
|
129 |
|
130 |
|
131 |
-
def infer(ref_style_file, style_description, caption, progress):
|
132 |
global models_rbm, models_b, device
|
133 |
|
134 |
-
if
|
135 |
models_to(models_rbm, device=device, excepts=["generator", "previewer"])
|
136 |
try:
|
137 |
|
@@ -167,7 +169,7 @@ def infer(ref_style_file, style_description, caption, progress):
|
|
167 |
conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
|
168 |
unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
|
169 |
|
170 |
-
if
|
171 |
# The sampling process uses more vram, so we offload everything except two modules to the cpu.
|
172 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
173 |
|
@@ -236,10 +238,10 @@ def infer(ref_style_file, style_description, caption, progress):
|
|
236 |
torch.cuda.empty_cache()
|
237 |
gc.collect()
|
238 |
|
239 |
-
def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progress):
|
240 |
global models_rbm, models_b, device
|
241 |
sam_model = LangSAM()
|
242 |
-
if
|
243 |
models_to(models_rbm, device=device, excepts=["generator", "previewer"])
|
244 |
models_to(sam_model, device=device)
|
245 |
models_to(sam_model.sam, device=device)
|
@@ -288,7 +290,7 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progre
|
|
288 |
conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
|
289 |
unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
|
290 |
|
291 |
-
if
|
292 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
293 |
models_to(sam_model, device="cpu")
|
294 |
models_to(sam_model.sam, device="cpu")
|
@@ -363,13 +365,13 @@ def infer_compo(style_description, ref_style_file, caption, ref_sub_file, progre
|
|
363 |
torch.cuda.empty_cache()
|
364 |
gc.collect()
|
365 |
|
366 |
-
def run(style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref):
|
367 |
result = None
|
368 |
progress = gr.Progress(track_tqdm=True)
|
369 |
if use_subject_ref is True:
|
370 |
-
result = infer_compo(style_description, style_reference_image, subject_prompt, subject_reference, progress)
|
371 |
else:
|
372 |
-
result = infer(style_reference_image, style_description, subject_prompt, progress)
|
373 |
return result
|
374 |
|
375 |
def show_hide_subject_image_component(use_subject_ref):
|
@@ -406,7 +408,9 @@ with gr.Blocks(analytics_enabled=False) as demo:
|
|
406 |
subject_prompt = gr.Textbox(
|
407 |
label = "Subject Prompt"
|
408 |
)
|
409 |
-
|
|
|
|
|
410 |
|
411 |
with gr.Accordion("Advanced Settings", open=False) as sub_img_panel:
|
412 |
subject_reference = gr.Image(label="Subject Reference", type="filepath")
|
@@ -418,13 +422,13 @@ with gr.Blocks(analytics_enabled=False) as demo:
|
|
418 |
output_image = gr.Image(label="Output Image")
|
419 |
gr.Examples(
|
420 |
examples = [
|
421 |
-
["./data/cyberpunk.png", "cyberpunk art style", "a car", None, False],
|
422 |
-
["./data/mosaic.png", "mosaic art style", "a lighthouse", None, False],
|
423 |
-
["./data/glowing.png", "glowing style", "a dwarf", None, False],
|
424 |
-
["./data/melting_gold.png", "melting golden 3D rendering style", "a dog", "./data/dog.jpg", True]
|
425 |
],
|
426 |
fn=run,
|
427 |
-
inputs=[style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref],
|
428 |
outputs=[output_image],
|
429 |
cache_examples=False
|
430 |
|
@@ -439,7 +443,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
|
|
439 |
|
440 |
submit_btn.click(
|
441 |
fn = run,
|
442 |
-
inputs = [style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref],
|
443 |
outputs = [output_image],
|
444 |
show_api = False
|
445 |
)
|
|
|
34 |
print(device)
|
35 |
|
36 |
# Flag for low VRAM usage
|
37 |
+
# low_vram = False
|
38 |
|
39 |
# Function definition for low VRAM usage
|
40 |
def models_to(model, device="cpu", excepts=None):
|
|
|
107 |
)
|
108 |
models_b.generator.bfloat16().eval().requires_grad_(False)
|
109 |
|
110 |
+
"""
|
111 |
if low_vram:
|
112 |
# Off-load old generator (which is not used in models_rbm)
|
113 |
models.generator.to("cpu")
|
114 |
torch.cuda.empty_cache()
|
115 |
gc.collect()
|
116 |
+
"""
|
117 |
|
118 |
generator_rbm = StageCRBM()
|
119 |
for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
|
|
|
130 |
|
131 |
|
132 |
|
133 |
+
def infer(ref_style_file, style_description, caption, use_low_vram, progress):
|
134 |
global models_rbm, models_b, device
|
135 |
|
136 |
+
if use_low_vram:
|
137 |
models_to(models_rbm, device=device, excepts=["generator", "previewer"])
|
138 |
try:
|
139 |
|
|
|
169 |
conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
|
170 |
unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
|
171 |
|
172 |
+
if use_low_vram:
|
173 |
# The sampling process uses more vram, so we offload everything except two modules to the cpu.
|
174 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
175 |
|
|
|
238 |
torch.cuda.empty_cache()
|
239 |
gc.collect()
|
240 |
|
241 |
+
def infer_compo(style_description, ref_style_file, caption, ref_sub_file, use_low_vram, progress):
|
242 |
global models_rbm, models_b, device
|
243 |
sam_model = LangSAM()
|
244 |
+
if use_low_vram:
|
245 |
models_to(models_rbm, device=device, excepts=["generator", "previewer"])
|
246 |
models_to(sam_model, device=device)
|
247 |
models_to(sam_model.sam, device=device)
|
|
|
290 |
conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
|
291 |
unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
|
292 |
|
293 |
+
if use_low_vram:
|
294 |
models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
|
295 |
models_to(sam_model, device="cpu")
|
296 |
models_to(sam_model.sam, device="cpu")
|
|
|
365 |
torch.cuda.empty_cache()
|
366 |
gc.collect()
|
367 |
|
368 |
+
def run(style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref, use_low_vram):
|
369 |
result = None
|
370 |
progress = gr.Progress(track_tqdm=True)
|
371 |
if use_subject_ref is True:
|
372 |
+
result = infer_compo(style_description, style_reference_image, subject_prompt, subject_reference, use_low_vram, progress)
|
373 |
else:
|
374 |
+
result = infer(style_reference_image, style_description, subject_prompt, use_low_vram, progress)
|
375 |
return result
|
376 |
|
377 |
def show_hide_subject_image_component(use_subject_ref):
|
|
|
408 |
subject_prompt = gr.Textbox(
|
409 |
label = "Subject Prompt"
|
410 |
)
|
411 |
+
with gr.Row():
|
412 |
+
use_subject_ref = gr.Checkbox(label="Use Subject Image as Reference", value=False)
|
413 |
+
use_low_vram = gr.Checkbox(label="Use Low-VRAM", value=False)
|
414 |
|
415 |
with gr.Accordion("Advanced Settings", open=False) as sub_img_panel:
|
416 |
subject_reference = gr.Image(label="Subject Reference", type="filepath")
|
|
|
422 |
output_image = gr.Image(label="Output Image")
|
423 |
gr.Examples(
|
424 |
examples = [
|
425 |
+
["./data/cyberpunk.png", "cyberpunk art style", "a car", None, False, False],
|
426 |
+
["./data/mosaic.png", "mosaic art style", "a lighthouse", None, False, False],
|
427 |
+
["./data/glowing.png", "glowing style", "a dwarf", None, False, False],
|
428 |
+
["./data/melting_gold.png", "melting golden 3D rendering style", "a dog", "./data/dog.jpg", True, False]
|
429 |
],
|
430 |
fn=run,
|
431 |
+
inputs=[style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref, use_low_vram],
|
432 |
outputs=[output_image],
|
433 |
cache_examples=False
|
434 |
|
|
|
443 |
|
444 |
submit_btn.click(
|
445 |
fn = run,
|
446 |
+
inputs = [style_reference_image, style_description, subject_prompt, subject_reference, use_subject_ref, use_low_vram],
|
447 |
outputs = [output_image],
|
448 |
show_api = False
|
449 |
)
|