nseq committed on
Commit c1c3eae · verified · 1 Parent(s): 9cc10e6

Update memory_management.py

Files changed (1)
  1. memory_management.py +114 -156
memory_management.py CHANGED
@@ -9,8 +9,9 @@ import platform
 from enum import Enum
 from backend import stream, utils
 from backend.args import args
-from modules_forge.main_thread import gpu_sync_lock, unload_complete, load_complete, current_gpu_id
+import threading

+global_model_lock = threading.Lock()

 cpu = torch.device('cpu')

@@ -74,11 +75,6 @@ except:
 if args.always_cpu:
     cpu_state = CPUState.CPU

-def get_current_gpu_id():
-    if torch.cuda.is_available():
-        return torch.cuda.current_device()
-    return None
-

 def is_intel_xpu():
     global cpu_state
@@ -566,66 +562,45 @@ def unload_model_clones(model):


 def free_memory(memory_required, device, keep_loaded=[], free_all=False):
-    # this check fully unloads any 'abandoned' models
-    global current_gpu_id, unload_complete, load_complete
-
-    with gpu_sync_lock:
-        gpu_id = get_current_gpu_id()
-        if current_gpu_id is None:
-            current_gpu_id = gpu_id
-        elif current_gpu_id != gpu_id:
-            # Wait for our turn
-            while current_gpu_id != gpu_id:
-                time.sleep(0.1)
-
-    for i in range(len(current_loaded_models) - 1, -1, -1):
-        if sys.getrefcount(current_loaded_models[i].model) <= 2:
-            current_loaded_models.pop(i).model_unload(avoid_model_moving=True)
-
-    if free_all:
-        memory_required = 1e30
-        print(f"[Unload] Trying to free all memory for {device} with {len(keep_loaded)} models keep loaded ... ", end="")
-    else:
-        print(f"[Unload] Trying to free {memory_required / (1024 * 1024):.2f} MB for {device} with {len(keep_loaded)} models keep loaded ... ", end="")
-
-    offload_everything = ALWAYS_VRAM_OFFLOAD or vram_state == VRAMState.NO_VRAM
-    unloaded_model = False
-    for i in range(len(current_loaded_models) - 1, -1, -1):
-        if not offload_everything:
-            free_memory = get_free_memory(device)
-            print(f"Current free memory is {free_memory / (1024 * 1024):.2f} MB ... ", end="")
-            if free_memory > memory_required:
-                break
-        shift_model = current_loaded_models[i]
-        if shift_model.device == device:
-            if shift_model not in keep_loaded:
-                m = current_loaded_models.pop(i)
-                print(f"Unload model {m.model.model.__class__.__name__} ", end="")
-                m.model_unload()
-                del m
-                unloaded_model = True
-
-    if unloaded_model:
-        soft_empty_cache()
-    else:
-        if vram_state != VRAMState.HIGH_VRAM:
-            mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
-            if mem_free_torch > mem_free_total * 0.25:
-                soft_empty_cache()
-
-    print('Done.')
-
-    if gpu_id == 0:  # First GPU
-        unload_complete.set()
-        # Wait for second GPU to complete unloading
-        while not load_complete.is_set():
-            time.sleep(0.1)
-    else:  # Second GPU
-        # Wait for first GPU to complete unloading
-        while not unload_complete.is_set():
-            time.sleep(0.1)
-        load_complete.set()
-    return
+    with global_model_lock:
+        # this check fully unloads any 'abandoned' models
+        for i in range(len(current_loaded_models) - 1, -1, -1):
+            if sys.getrefcount(current_loaded_models[i].model) <= 2:
+                current_loaded_models.pop(i).model_unload(avoid_model_moving=True)
+
+        if free_all:
+            memory_required = 1e30
+            print(f"[Unload] Trying to free all memory for {device} with {len(keep_loaded)} models keep loaded ... ", end="")
+        else:
+            print(f"[Unload] Trying to free {memory_required / (1024 * 1024):.2f} MB for {device} with {len(keep_loaded)} models keep loaded ... ", end="")
+
+        offload_everything = ALWAYS_VRAM_OFFLOAD or vram_state == VRAMState.NO_VRAM
+        unloaded_model = False
+        for i in range(len(current_loaded_models) - 1, -1, -1):
+            if not offload_everything:
+                free_memory = get_free_memory(device)
+                print(f"Current free memory is {free_memory / (1024 * 1024):.2f} MB ... ", end="")
+                if free_memory > memory_required:
+                    break
+            shift_model = current_loaded_models[i]
+            if shift_model.device == device:
+                if shift_model not in keep_loaded:
+                    m = current_loaded_models.pop(i)
+                    print(f"Unload model {m.model.model.__class__.__name__} ", end="")
+                    m.model_unload()
+                    del m
+                    unloaded_model = True
+
+        if unloaded_model:
+            soft_empty_cache()
+        else:
+            if vram_state != VRAMState.HIGH_VRAM:
+                mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
+                if mem_free_torch > mem_free_total * 0.25:
+                    soft_empty_cache()
+
+        print('Done.')
+        return


 def compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory):
@@ -640,101 +615,84 @@ def compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, inference_memory):


 def load_models_gpu(models, memory_required=0, hard_memory_preservation=0):
-    global vram_state
-    global current_gpu_id, unload_complete, load_complete
-
-    gpu_id = get_current_gpu_id()
-
-    # Wait for unloading to complete on both GPUs
-    if gpu_id == 1:  # Second GPU
-        while not unload_complete.is_set() or not load_complete.is_set():
-            time.sleep(0.1)
-
-    with gpu_sync_lock:
-        if current_gpu_id is None or current_gpu_id == gpu_id:
-
-            execution_start_time = time.perf_counter()
-            memory_to_free = max(minimum_inference_memory(), memory_required) + hard_memory_preservation
-            memory_for_inference = minimum_inference_memory() + hard_memory_preservation
-
-            models_to_load = []
-            models_already_loaded = []
-            for x in models:
-                loaded_model = LoadedModel(x)
-
-                if loaded_model in current_loaded_models:
-                    index = current_loaded_models.index(loaded_model)
-                    current_loaded_models.insert(0, current_loaded_models.pop(index))
-                    models_already_loaded.append(loaded_model)
-                else:
-                    models_to_load.append(loaded_model)
-
-            if len(models_to_load) == 0:
-                devs = set(map(lambda a: a.device, models_already_loaded))
-                for d in devs:
-                    if d != torch.device("cpu"):
-                        free_memory(memory_to_free, d, models_already_loaded)
-
-                moving_time = time.perf_counter() - execution_start_time
-                if moving_time > 0.1:
-                    print(f'Memory cleanup has taken {moving_time:.2f} seconds')
-
-                return
-
-            for loaded_model in models_to_load:
-                unload_model_clones(loaded_model.model)
-
-            total_memory_required = {}
-            for loaded_model in models_to_load:
-                loaded_model.compute_inclusive_exclusive_memory()
-                total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.exclusive_memory + loaded_model.inclusive_memory * 0.25
-
-            for device in total_memory_required:
-                if device != torch.device("cpu"):
-                    free_memory(total_memory_required[device] * 1.3 + memory_to_free, device, models_already_loaded)
-
-            for loaded_model in models_to_load:
-                model = loaded_model.model
-                torch_dev = model.load_device
-                if is_device_cpu(torch_dev):
-                    vram_set_state = VRAMState.DISABLED
-                else:
-                    vram_set_state = vram_state
-
-                model_gpu_memory_when_using_cpu_swap = -1
-
-                if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
-                    model_require = loaded_model.exclusive_memory
-                    previously_loaded = loaded_model.inclusive_memory
-                    current_free_mem = get_free_memory(torch_dev)
-                    estimated_remaining_memory = current_free_mem - model_require - memory_for_inference
-
-                    print(f"[Memory Management] Target: {loaded_model.model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_require / (1024 * 1024):.2f} MB, Previously Loaded: {previously_loaded / (1024 * 1024):.2f} MB, Inference Require: {memory_for_inference / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="")
-
-                    if estimated_remaining_memory < 0:
-                        vram_set_state = VRAMState.LOW_VRAM
-                        model_gpu_memory_when_using_cpu_swap = compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, memory_for_inference)
-                        if previously_loaded > 0:
-                            model_gpu_memory_when_using_cpu_swap = previously_loaded
-
-                if vram_set_state == VRAMState.NO_VRAM:
-                    model_gpu_memory_when_using_cpu_swap = 0
-
-                loaded_model.model_load(model_gpu_memory_when_using_cpu_swap)
-                current_loaded_models.insert(0, loaded_model)
-
-            moving_time = time.perf_counter() - execution_start_time
-            print(f'Moving model(s) has taken {moving_time:.2f} seconds')
-
-            if gpu_id == 0:  # First GPU
-                current_gpu_id = 1  # Signal second GPU to start
-            else:  # Second GPU
-                # Reset synchronization
-                current_gpu_id = None
-                unload_complete.clear()
-                load_complete.clear()
-            return
+    with global_model_lock:  # Add this line
+        global vram_state
+
+        execution_start_time = time.perf_counter()
+        memory_to_free = max(minimum_inference_memory(), memory_required) + hard_memory_preservation
+        memory_for_inference = minimum_inference_memory() + hard_memory_preservation
+
+        models_to_load = []
+        models_already_loaded = []
+        for x in models:
+            loaded_model = LoadedModel(x)
+
+            if loaded_model in current_loaded_models:
+                index = current_loaded_models.index(loaded_model)
+                current_loaded_models.insert(0, current_loaded_models.pop(index))
+                models_already_loaded.append(loaded_model)
+            else:
+                models_to_load.append(loaded_model)
+
+        if len(models_to_load) == 0:
+            devs = set(map(lambda a: a.device, models_already_loaded))
+            for d in devs:
+                if d != torch.device("cpu"):
+                    free_memory(memory_to_free, d, models_already_loaded)
+
+            moving_time = time.perf_counter() - execution_start_time
+            if moving_time > 0.1:
+                print(f'Memory cleanup has taken {moving_time:.2f} seconds')
+
+            return
+
+        for loaded_model in models_to_load:
+            unload_model_clones(loaded_model.model)
+
+        total_memory_required = {}
+        for loaded_model in models_to_load:
+            loaded_model.compute_inclusive_exclusive_memory()
+            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.exclusive_memory + loaded_model.inclusive_memory * 0.25
+
+        for device in total_memory_required:
+            if device != torch.device("cpu"):
+                free_memory(total_memory_required[device] * 1.3 + memory_to_free, device, models_already_loaded)
+
+        for loaded_model in models_to_load:
+            model = loaded_model.model
+            torch_dev = model.load_device
+            if is_device_cpu(torch_dev):
+                vram_set_state = VRAMState.DISABLED
+            else:
+                vram_set_state = vram_state
+
+            model_gpu_memory_when_using_cpu_swap = -1
+
+            if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
+                model_require = loaded_model.exclusive_memory
+                previously_loaded = loaded_model.inclusive_memory
+                current_free_mem = get_free_memory(torch_dev)
+                estimated_remaining_memory = current_free_mem - model_require - memory_for_inference
+
+                print(f"[Memory Management] Target: {loaded_model.model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_require / (1024 * 1024):.2f} MB, Previously Loaded: {previously_loaded / (1024 * 1024):.2f} MB, Inference Require: {memory_for_inference / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="")
+
+                if estimated_remaining_memory < 0:
+                    vram_set_state = VRAMState.LOW_VRAM
+                    model_gpu_memory_when_using_cpu_swap = compute_model_gpu_memory_when_using_cpu_swap(current_free_mem, memory_for_inference)
+                    if previously_loaded > 0:
+                        model_gpu_memory_when_using_cpu_swap = previously_loaded
+
+            if vram_set_state == VRAMState.NO_VRAM:
+                model_gpu_memory_when_using_cpu_swap = 0
+
+            loaded_model.model_load(model_gpu_memory_when_using_cpu_swap)
+            current_loaded_models.insert(0, loaded_model)
+
+        moving_time = time.perf_counter() - execution_start_time
+        print(f'Moving model(s) has taken {moving_time:.2f} seconds')
+
+        return


 def load_model_gpu(model):
     return load_models_gpu([model])
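
Note on the pattern: this commit drops the per-GPU handshake (gpu_sync_lock, unload_complete, load_complete, current_gpu_id) in favour of a single module-level threading.Lock() that serializes free_memory and load_models_gpu across threads. One detail worth watching is that load_models_gpu calls free_memory while it already holds global_model_lock, and a plain threading.Lock is not reentrant, so that nested acquisition would block on itself; threading.RLock lets the same thread re-acquire the lock. The sketch below is only an illustration of that locking shape, not the repository's code; the names model_lock, model_registry, free_models and load_model are made up for the example.

import threading

# Reentrant lock: the same thread may re-acquire it, so a load routine
# that already holds the lock can call the free routine without deadlocking.
model_lock = threading.RLock()
model_registry = []  # stand-in for current_loaded_models


def free_models(keep=0):
    with model_lock:
        while len(model_registry) > keep:
            print(f"[Unload] {model_registry.pop()}")


def load_model(name, keep=1):
    with model_lock:
        free_models(keep)              # nested acquisition, fine with RLock
        model_registry.insert(0, name)
        print(f"[Load] {name}")


if __name__ == "__main__":
    # Two workers contend for the lock; their load/unload steps never interleave.
    workers = [threading.Thread(target=load_model, args=(f"model_{i}",)) for i in range(2)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()

Whether the nested free_memory call is actually reached under the lock in the committed version depends on the devices involved; the sketch only shows why a reentrant lock is the safer default when one locked function calls another.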