Update memory_management.py
memory_management.py  (+114 -156)
Original file (lines removed by this commit are prefixed with -):

@@ -9,8 +9,9 @@ import platform
 from enum import Enum
 from backend import stream, utils
 from backend.args import args
-


 cpu = torch.device('cpu')

@@ -74,11 +75,6 @@ except:
 if args.always_cpu:
     cpu_state = CPUState.CPU

-def get_current_gpu_id():
-    if torch.cuda.is_available():
-        return torch.cuda.current_device()
-    return None
-

 def is_intel_xpu():
     global cpu_state
@@ -566,66 +562,45 @@ def unload_model_clones(model):


 def free_memory(memory_required, device, keep_loaded=[], free_all=False):
-    [... 15 removed lines not captured ...]

-    [... 2 removed lines not captured ...]
-        print(f"[Unload] Trying to free all memory for {device} with {len(keep_loaded)} models keep loaded ... ", end="")
-    else:
-        print(f"[Unload] Trying to free {memory_required / (1024 * 1024):.2f} MB for {device} with {len(keep_loaded)} models keep loaded ... ", end="")
-
-    offload_everything = ALWAYS_VRAM_OFFLOAD or vram_state == VRAMState.NO_VRAM
-    unloaded_model = False
-    for i in range(len(current_loaded_models) - 1, -1, -1):
-        if not offload_everything:
-            free_memory = get_free_memory(device)
-            print(f"Current free memory is {free_memory / (1024 * 1024):.2f} MB ... ", end="")
-            if free_memory > memory_required:
-                break
-        shift_model = current_loaded_models[i]
-        if shift_model.device == device:
-            if shift_model not in keep_loaded:
-                m = current_loaded_models.pop(i)
-                print(f"Unload model {m.model.model.__class__.__name__} ", end="")
-                m.model_unload()
-                del m
-                unloaded_model = True
-
-    if unloaded_model:
-        soft_empty_cache()
-    else:
-        if vram_state != VRAMState.HIGH_VRAM:
-            mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
-            if mem_free_torch > mem_free_total * 0.25:
-                soft_empty_cache()
-
-    print('Done.')
-
-    if gpu_id == 0: # First GPU
-        unload_complete.set()
-        # Wait for second GPU to complete unloading
-        while not load_complete.is_set():
-            time.sleep(0.1)
-    else: # Second GPU
-        # Wait for first GPU to complete unloading
-        while not unload_complete.is_set():
-            time.sleep(0.1)
-        load_complete.set()
-    return


 def compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory):
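The block removed above coordinated two GPUs with a pair of threading.Event objects and a 0.1 s polling loop. A minimal, self-contained sketch of that handshake pattern, for reference while reading the removal (the worker function and the thread setup are illustrative, not taken from this file):

    import threading
    import time

    unload_complete = threading.Event()
    load_complete = threading.Event()

    def worker(gpu_id):
        # ... unload this GPU's models here ...
        if gpu_id == 0:  # first GPU: signal, then wait for the peer
            unload_complete.set()
            while not load_complete.is_set():
                time.sleep(0.1)
        else:            # second GPU: wait for the peer, then signal
            while not unload_complete.is_set():
                time.sleep(0.1)
            load_complete.set()

    threads = [threading.Thread(target=worker, args=(i,)) for i in (0, 1)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

Event.wait() would avoid the busy polling; the sketch keeps the sleep loop only to mirror the removed code.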
@@ -640,101 +615,84 @@ def compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory):


 def load_models_gpu(models, memory_required=0, hard_memory_preservation=0):
-    [... 22 removed lines not captured ...]
-        if loaded_model in current_loaded_models:
-            index = current_loaded_models.index(loaded_model)
-            current_loaded_models.insert(0, current_loaded_models.pop(index))
-            models_already_loaded.append(loaded_model)
-        else:
-            models_to_load.append(loaded_model)
-
-    if len(models_to_load) == 0:
-        devs = set(map(lambda a: a.device, models_already_loaded))
-        for d in devs:
-            if d != torch.device("cpu"):
-                free_memory(memory_to_free, d, models_already_loaded)
-
-        moving_time = time.perf_counter() - execution_start_time
-        if moving_time > 0.1:
-            print(f'Memory cleanup has taken {moving_time:.2f} seconds')
-
-        return
-
-    for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model)
-
-    total_memory_required = {}
-    for loaded_model in models_to_load:
-        loaded_model.compute_inclusive_exclusive_memory()
-        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.exclusive_memory + loaded_model.inclusive_memory * 0.25
-
-    for device in total_memory_required:
-        if device != torch.device("cpu"):
-            free_memory(total_memory_required[device] * 1.3 + memory_to_free, device, models_already_loaded)
-
-    for loaded_model in models_to_load:
-        model = loaded_model.model
-        torch_dev = model.load_device
-        if is_device_cpu(torch_dev):
-            vram_set_state = VRAMState.DISABLED
-        else:
-            vram_set_state = vram_state
-
-        model_gpu_memory_when_using_cpu_swap = -1
-
-        if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
-            model_require = loaded_model.exclusive_memory
-            previously_loaded = loaded_model.inclusive_memory
-            current_free_mem = get_free_memory(torch_dev)
-            estimated_remaining_memory = current_free_mem - model_require - memory_for_inference
-
-            print(f"[Memory Management] Target: {loaded_model.model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_require / (1024 * 1024):.2f} MB, Previously Loaded: {previously_loaded / (1024 * 1024):.2f} MB, Inference Require: {memory_for_inference / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="")
-
-            if estimated_remaining_memory < 0:
-                vram_set_state = VRAMState.LOW_VRAM
-                model_gpu_memory_when_using_cpu_swap = compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, memory_for_inference)
-                if previously_loaded > 0:
-                    model_gpu_memory_when_using_cpu_swap = previously_loaded
-
-        if vram_set_state == VRAMState.NO_VRAM:
-            model_gpu_memory_when_using_cpu_swap = 0
-
-        loaded_model.model_load(model_gpu_memory_when_using_cpu_swap)
-        current_loaded_models.insert(0, loaded_model)

     moving_time = time.perf_counter() - execution_start_time
-    [... 3 removed lines not captured ...]
-        current_gpu_id = 1 # Signal second GPU to start
-    else: # Second GPU
-        # Reset synchronization
-        current_gpu_id = None
-        unload_complete.clear()
-        load_complete.clear()
     return


 def load_model_gpu(model):
     return load_models_gpu([model])

Updated file (lines added by this commit are prefixed with +):

@@ -9,8 +9,9 @@ import platform
 from enum import Enum
 from backend import stream, utils
 from backend.args import args
+import threading

+global_model_lock = threading.Lock()

 cpu = torch.device('cpu')

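global_model_lock serializes every caller that enters the memory-management routines changed below. A minimal sketch of the pattern, with guarded_section standing in for free_memory / load_models_gpu (the function name and print statements are illustrative):

    import threading

    global_model_lock = threading.Lock()

    def guarded_section(name):
        # Only one thread at a time executes the body; the others block on acquire.
        with global_model_lock:
            print(f"{name}: holding the lock")
            # ... move models between CPU and GPU here ...

    threads = [threading.Thread(target=guarded_section, args=(f"caller-{i}",)) for i in range(3)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

One caveat worth checking against the later hunks: threading.Lock is not reentrant, and the updated load_models_gpu calls free_memory while it already holds global_model_lock, so that inner acquire would block the calling thread; threading.RLock is the usual choice when the same thread may re-enter.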
@@ -74,11 +75,6 @@ except:
 if args.always_cpu:
     cpu_state = CPUState.CPU


 def is_intel_xpu():
     global cpu_state
@@ -566,66 +562,45 @@ def unload_model_clones(model):


 def free_memory(memory_required, device, keep_loaded=[], free_all=False):
+    with global_model_lock:
+        # this check fully unloads any 'abandoned' models
+        for i in range(len(current_loaded_models) - 1, -1, -1):
+            if sys.getrefcount(current_loaded_models[i].model) <= 2:
+                current_loaded_models.pop(i).model_unload(avoid_model_moving=True)
+
+        if free_all:
+            memory_required = 1e30
+            print(f"[Unload] Trying to free all memory for {device} with {len(keep_loaded)} models keep loaded ... ", end="")
+        else:
+            print(f"[Unload] Trying to free {memory_required / (1024 * 1024):.2f} MB for {device} with {len(keep_loaded)} models keep loaded ... ", end="")
+
+        offload_everything = ALWAYS_VRAM_OFFLOAD or vram_state == VRAMState.NO_VRAM
+        unloaded_model = False
+        for i in range(len(current_loaded_models) - 1, -1, -1):
+            if not offload_everything:
+                free_memory = get_free_memory(device)
+                print(f"Current free memory is {free_memory / (1024 * 1024):.2f} MB ... ", end="")
+                if free_memory > memory_required:
+                    break
+            shift_model = current_loaded_models[i]
+            if shift_model.device == device:
+                if shift_model not in keep_loaded:
+                    m = current_loaded_models.pop(i)
+                    print(f"Unload model {m.model.model.__class__.__name__} ", end="")
+                    m.model_unload()
+                    del m
+                    unloaded_model = True
+
+        if unloaded_model:
+            soft_empty_cache()
+        else:
+            if vram_state != VRAMState.HIGH_VRAM:
+                mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
+                if mem_free_torch > mem_free_total * 0.25:
+                    soft_empty_cache()

+        print('Done.')
+    return


 def compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory):
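The new "abandoned model" check leans on a CPython detail: sys.getrefcount counts the temporary reference created by its own argument, so a count of 2 means the tracked entry is the only remaining owner. A small illustration, assuming a throwaway Dummy class:

    import sys

    class Dummy:
        pass

    tracked = [Dummy()]

    # One reference from the list plus the temporary argument reference
    # inside getrefcount itself -> 2 for an otherwise "abandoned" object.
    print(sys.getrefcount(tracked[0]))  # typically 2

    external = tracked[0]               # an outside reference keeps it alive
    print(sys.getrefcount(tracked[0]))  # typically 3

    del external
    print(sys.getrefcount(tracked[0]))  # back to 2

That is why the threshold in the diff is <= 2: anything above it means some caller still holds the model.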
@@ -640,101 +615,84 @@ def compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory):


 def load_models_gpu(models, memory_required=0, hard_memory_preservation=0):
+    with global_model_lock: # Add this line
+        global vram_state
+
+        execution_start_time = time.perf_counter()
+        memory_to_free = max(minimum_inference_memory(), memory_required) + hard_memory_preservation
+        memory_for_inference = minimum_inference_memory() + hard_memory_preservation
+
+        models_to_load = []
+        models_already_loaded = []
+        for x in models:
+            loaded_model = LoadedModel(x)
+
+            if loaded_model in current_loaded_models:
+                index = current_loaded_models.index(loaded_model)
+                current_loaded_models.insert(0, current_loaded_models.pop(index))
+                models_already_loaded.append(loaded_model)
+            else:
+                models_to_load.append(loaded_model)
+
+        if len(models_to_load) == 0:
+            devs = set(map(lambda a: a.device, models_already_loaded))
+            for d in devs:
+                if d != torch.device("cpu"):
+                    free_memory(memory_to_free, d, models_already_loaded)

             moving_time = time.perf_counter() - execution_start_time
+            if moving_time > 0.1:
+                print(f'Memory cleanup has taken {moving_time:.2f} seconds')
+
             return

+        for loaded_model in models_to_load:
+            unload_model_clones(loaded_model.model)
+
+        total_memory_required = {}
+        for loaded_model in models_to_load:
+            loaded_model.compute_inclusive_exclusive_memory()
+            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.exclusive_memory + loaded_model.inclusive_memory * 0.25
+
+        for device in total_memory_required:
+            if device != torch.device("cpu"):
+                free_memory(total_memory_required[device] * 1.3 + memory_to_free, device, models_already_loaded)
+
+        for loaded_model in models_to_load:
+            model = loaded_model.model
+            torch_dev = model.load_device
+            if is_device_cpu(torch_dev):
+                vram_set_state = VRAMState.DISABLED
+            else:
+                vram_set_state = vram_state
+
+            model_gpu_memory_when_using_cpu_swap = -1
+
+            if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
+                model_require = loaded_model.exclusive_memory
+                previously_loaded = loaded_model.inclusive_memory
+                current_free_mem = get_free_memory(torch_dev)
+                estimated_remaining_memory = current_free_mem - model_require - memory_for_inference
+
+                print(f"[Memory Management] Target: {loaded_model.model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_require / (1024 * 1024):.2f} MB, Previously Loaded: {previously_loaded / (1024 * 1024):.2f} MB, Inference Require: {memory_for_inference / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="")
+
+                if estimated_remaining_memory < 0:
+                    vram_set_state = VRAMState.LOW_VRAM
+                    model_gpu_memory_when_using_cpu_swap = compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, memory_for_inference)
+                    if previously_loaded > 0:
+                        model_gpu_memory_when_using_cpu_swap = previously_loaded
+
+            if vram_set_state == VRAMState.NO_VRAM:
+                model_gpu_memory_when_using_cpu_swap = 0
+
+            loaded_model.model_load(model_gpu_memory_when_using_cpu_swap)
+            current_loaded_models.insert(0, loaded_model)
+
+        moving_time = time.perf_counter() - execution_start_time
+        print(f'Moving model(s) has taken {moving_time:.2f} seconds')
+
+        return
+

 def load_model_gpu(model):
     return load_models_gpu([model])
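For each target device, the loader budgets exclusive_memory plus a quarter of inclusive_memory per model, asks free_memory for 1.3x that total plus memory_to_free, and falls back to CPU swap when the estimated remainder goes negative. A worked sketch of that arithmetic; all byte figures are made up for the example:

    # Illustrative numbers only (bytes); not measurements from this file.
    MB = 1024 * 1024
    exclusive_memory = 4000 * MB      # weights that still have to be moved to the GPU
    inclusive_memory = 1000 * MB      # weights already resident on the GPU
    memory_to_free = 1024 * MB        # max(minimum_inference_memory(), memory_required) + preservation
    memory_for_inference = 1024 * MB  # minimum_inference_memory() + preservation

    total_memory_required = exclusive_memory + inclusive_memory * 0.25
    print(f"free_memory() is asked for {(total_memory_required * 1.3 + memory_to_free) / MB:.2f} MB")

    current_free_mem = 4500 * MB      # as reported by get_free_memory(torch_dev)
    estimated_remaining_memory = current_free_mem - exclusive_memory - memory_for_inference
    if estimated_remaining_memory < 0:
        print("Not enough VRAM: switch to LOW_VRAM and compute a CPU-swap budget")
    else:
        print("Model fits: load it fully on the GPU")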