import contextlib
import time

import torch

from ldm_patched.modules import model_management
from ldm_patched.modules.ops import use_patched_ops


@contextlib.contextmanager
def automatic_memory_management():
    """Track torch modules touched inside the ``with`` body and move them to CPU on exit.

    On entry, ``model_management.free_memory`` is called with
    ``memory_required`` set to 3 GiB for the device returned by
    ``model_management.get_torch_device()``.

    While the body runs, ``torch.nn.Module.__init__`` and
    ``torch.nn.Module.to`` are replaced by thin wrappers that record the
    module instance and then delegate to the original implementation, so
    every module that is constructed or moved with ``.to()`` is remembered.

    On exit -- whether the body finished normally or raised -- the original
    methods are restored, ``.cpu()`` is called on each distinct recorded
    module, ``model_management.soft_empty_cache()`` is called, and a one-line
    timing summary is printed. An exception raised by the body still
    propagates to the caller after this cleanup.

    Yields:
        None.
    """
    model_management.free_memory(
        memory_required=3 * 1024 * 1024 * 1024,
        device=model_management.get_torch_device(),
    )

    module_list = []

    original_init = torch.nn.Module.__init__
    original_to = torch.nn.Module.to

    def patched_init(self, *args, **kwargs):
        # Record before delegating so the instance is tracked as soon as
        # construction begins.
        module_list.append(self)
        return original_init(self, *args, **kwargs)

    def patched_to(self, *args, **kwargs):
        # Modules moved with .to() inside the body are tracked too, even if
        # they were constructed before the context was entered.
        module_list.append(self)
        return original_to(self, *args, **kwargs)

    try:
        torch.nn.Module.__init__ = patched_init
        torch.nn.Module.to = patched_to
        yield
    finally:
        # Always undo the monkey-patching first so that a failure during the
        # offload below cannot leave torch.nn.Module patched.
        torch.nn.Module.__init__ = original_init
        torch.nn.Module.to = original_to

        # The offload lives inside ``finally`` so that an exception raised by
        # the ``with`` body does not skip it and leave the tracked modules
        # where they were.
        start = time.perf_counter()

        # A module may be recorded several times (once by __init__, again by
        # each .to() call); de-duplicate before offloading.
        unique_modules = set(module_list)
        for module in unique_modules:
            module.cpu()

        model_management.soft_empty_cache()
        end = time.perf_counter()

        print(f'Automatic Memory Management: {len(unique_modules)} Modules in {(end - start):.2f} seconds.')