"""Track which models currently live on the GPU and swap them with the CPU.

The module keeps a single registry, ``models_in_gpu``, and exposes
``load_models_to_gpu`` / ``unload_all_models`` to move models between the
``gpu`` and ``cpu`` devices defined below.
"""

import torch
from contextlib import contextmanager


# When True, ``load_models_to_gpu`` never evicts models that are already
# registered as being on the GPU; it only adds the newly requested ones.
high_vram = False

gpu = torch.device('cuda')
cpu = torch.device('cpu')

# Move a tiny tensor to the GPU at import time, then release cached blocks.
# NOTE(review): presumably this initialises the CUDA context up front (and
# makes the import fail fast on machines without CUDA) - confirm.
torch.zeros((1, 1)).to(gpu, torch.float32)
torch.cuda.empty_cache()

# Registry of the models this module believes are currently on the GPU.
models_in_gpu = []


@contextmanager
def movable_bnb_model(m):
    """Temporarily hide ``m.quantization_method`` for the duration of the block.

    If ``m`` has a ``quantization_method`` attribute it is stashed in
    ``m.quantization_method_backup`` and deleted; on exit (normal or via an
    exception) the attribute is restored and the backup removed.

    NOTE(review): judging by the name, this exists so that ``.to(device)`` is
    not rejected for bitsandbytes-quantized models - confirm against the
    model classes actually passed in.

    Args:
        m: Any object; objects without ``quantization_method`` are untouched.

    Yields:
        None.
    """
    if hasattr(m, 'quantization_method'):
        m.quantization_method_backup = m.quantization_method
        del m.quantization_method
    try:
        yield None
    finally:
        if hasattr(m, 'quantization_method_backup'):
            m.quantization_method = m.quantization_method_backup
            del m.quantization_method_backup


def load_models_to_gpu(models):
    """Make ``models`` resident on the GPU, evicting others unless ``high_vram``.

    Models already registered in ``models_in_gpu`` are left alone. When
    ``high_vram`` is False, every registered model that is not requested is
    first moved to the CPU. The requested models that are not yet registered
    are then moved to the GPU, and the CUDA cache is emptied.

    The registry is updated after each individual move and committed in a
    ``finally`` block, so if a ``.to()`` call raises (for example on an
    out-of-memory error) ``models_in_gpu`` still reflects the moves that did
    complete instead of a stale snapshot.

    Args:
        models: A single model, or a list/tuple of models. Each model must be
            hashable and provide ``.to(device)``. Duplicates are ignored.

    Returns:
        None.
    """
    global models_in_gpu

    if not isinstance(models, (tuple, list)):
        models = [models]

    # De-duplicate while keeping the caller's order so that load/unload order
    # (and the printed log) is deterministic. Working on fresh lists also
    # avoids the ``list + tuple`` TypeError when a tuple is passed in.
    requested = list(dict.fromkeys(models))
    resident = list(dict.fromkeys(models_in_gpu))

    to_load = [m for m in requested if m not in resident]

    try:
        if not high_vram:
            to_unload = [m for m in resident if m not in requested]
            for m in to_unload:
                with movable_bnb_model(m):
                    m.to(cpu)
                print('Unload to CPU:', m.__class__.__name__)
                resident.remove(m)

        for m in to_load:
            with movable_bnb_model(m):
                m.to(gpu)
            print('Load to GPU:', m.__class__.__name__)
            # Register only after the move succeeded, so a model whose move
            # failed is retried on the next request instead of being skipped.
            resident.append(m)
    finally:
        models_in_gpu = resident
        torch.cuda.empty_cache()


def unload_all_models(extra_models=None):
    """Register ``extra_models`` as GPU-resident, then request an empty GPU set.

    This delegates to ``load_models_to_gpu([])``, so every registered model
    (including ``extra_models``) is moved to the CPU when ``high_vram`` is
    False. When ``high_vram`` is True nothing is moved and the models merely
    stay registered.

    Args:
        extra_models: Optional single model or list/tuple of models that are
            not in the registry yet but should be unloaded as well.

    Returns:
        None (the return value of ``load_models_to_gpu``).
    """
    global models_in_gpu

    if extra_models is None:
        extra_models = []

    if not isinstance(extra_models, (tuple, list)):
        extra_models = [extra_models]

    # Unpack instead of using ``+`` so a tuple of extra models is accepted.
    models_in_gpu = list(dict.fromkeys([*models_in_gpu, *extra_models]))

    return load_models_to_gpu([])