import gc

import torch
from torch import nn
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
import bitsandbytes as bnb


def torch_gc():
    """Release cached CUDA memory and run the Python garbage collector."""
    if torch.cuda.is_available():
        with torch.cuda.device('cuda'):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
    gc.collect()


def restart_cpu_offload(pipe, load_mode):
    """Tear down any existing offload hooks on `pipe`, then re-enable model CPU offload."""
    #if load_mode != '4bit':
    #    pipe.disable_xformers_memory_efficient_attention()
    optionally_disable_offloading(pipe)
    gc.collect()
    torch.cuda.empty_cache()
    pipe.enable_model_cpu_offload()
    #if load_mode != '4bit':
    #    pipe.enable_xformers_memory_efficient_attention()


def optionally_disable_offloading(_pipeline):
    """
    Remove accelerate offload hooks if the pipeline has already been offloaded to CPU,
    either via model CPU offload or sequential CPU offload.

    Args:
        _pipeline (`DiffusionPipeline`):
            The pipeline to disable offloading for.

    Returns:
        tuple: `(is_model_cpu_offload, is_sequential_cpu_offload)` flags indicating
        which kind of offloading was found and removed.
    """
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    if _pipeline is not None:
        # Only dereference `_pipeline` after the None check.
        print(f"Restarting CPU offloading for {_pipeline.unet_name}...")
        for _, component in _pipeline.components.items():
            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                if not is_model_cpu_offload:
                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                if not is_sequential_cpu_offload:
                    is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
                remove_hook_from_module(component, recurse=True)
    return (is_model_cpu_offload, is_sequential_cpu_offload)


def quantize_4bit(module):
    """Recursively replace every `nn.Linear` in `module` with a bitsandbytes 4-bit layer."""
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            in_features = child.in_features
            out_features = child.out_features
            device = child.weight.data.device

            # Create and configure the 4-bit Linear replacement.
            has_bias = child.bias is not None

            # TODO: Make these configurable.
            # fp16 as the compute dtype leads to faster inference,
            # and nf4 is almost always the right quant type as a rule of thumb.
            bnb_4bit_compute_dtype = torch.float16
            quant_type = "nf4"

            new_layer = bnb.nn.Linear4bit(
                in_features,
                out_features,
                bias=has_bias,
                compute_dtype=bnb_4bit_compute_dtype,
                quant_type=quant_type,
            )

            # Copy the original weights; quantization happens when the layer moves to CUDA.
            new_layer.load_state_dict(child.state_dict())
            new_layer = new_layer.to(device)

            # Swap the quantized layer in place of the original.
            setattr(module, name, new_layer)
        else:
            # Recursively apply to child modules.
            quantize_4bit(child)
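

# --- Usage sketch (illustrative only) ---
# A minimal sketch of how these helpers might fit together, assuming a diffusers
# `StableDiffusionPipeline`. The model id "runwayml/stable-diffusion-v1-5" and the
# '4bit' load_mode value are assumptions for the example, not requirements of this
# module.
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline  # assumption: diffusers is installed

    # Load the pipeline in fp16, quantize the UNet's Linear layers to 4-bit,
    # and enable model CPU offload to keep VRAM usage low.
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # hypothetical model id
        torch_dtype=torch.float16,
    )
    quantize_4bit(pipe.unet)
    pipe.enable_model_cpu_offload()

    image = pipe("a photo of an astronaut riding a horse").images[0]

    # After e.g. swapping weights, rebuild the offload hooks and free cached VRAM.
    restart_cpu_offload(pipe, load_mode='4bit')
    torch_gc()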