from contextlib import contextmanager
from types import MethodType
from typing import Any, List, Optional

import torch
from peft.tuners import lora
from peft.tuners.lora import LoraLayer


def round_robin(num_reqs, num_workers):
    """Distribute requests evenly across workers using a round-robin algorithm.

    Args:
        num_reqs (int): Total number of requests to distribute.
        num_workers (int): Number of available workers.

    Returns:
        list: A list of lists where each sublist contains the request indices
            assigned to that particular worker.
    """
    distribution = [[] for _ in range(num_workers)]
    for idx in range(num_reqs):
        worker_id = idx % num_workers
        distribution[worker_id].append(idx)
    return distribution
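
# Illustrative example (not part of the original module): distributing 5 requests
# across 2 workers assigns the even-indexed requests to worker 0 and the
# odd-indexed ones to worker 1.
#     >>> round_robin(5, 2)
#     [[0, 2, 4], [1, 3]]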


@contextmanager
def patch_lora_merge(model, parameter_group=None):
    """Patch LoraLayer's merge, get_delta_weight, and _cache_pop methods for controlled merging.

    Args:
        model: The PEFT model to patch.
        parameter_group: Optional list of parameter names to restrict merging.

    Yields:
        The patched model (the context manager restores the original methods on exit).
    """
    from peft.tuners.tuners_utils import check_adapters_to_merge

    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
        # Skip layers whose name does not appear in any entry of the parameter group.
        if parameter_group and all(self.name not in pg for pg in parameter_group):
            return
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            return

        for active_adapter in adapter_names:
            if active_adapter in self.lora_A.keys():
                base_layer = self.get_base_layer()
                if self.use_dora.get(active_adapter, False):
                    # Keep the DoRA magnitude vector on the same device as the base layer.
                    self.lora_magnitude_vector[active_adapter].weight.data = \
                        self.lora_magnitude_vector[active_adapter].weight.data.to(base_layer.weight.device)

        return self.merge_origin(safe_merge, adapter_names)

    def get_delta_weight(self, adapter) -> torch.Tensor:
        # Move the LoRA factors onto the base layer's device before computing the delta weight.
        if isinstance(self, lora.Embedding):
            self.lora_embedding_A[adapter].data = self.lora_embedding_A[adapter].data.to(self.base_layer.weight.device)
            self.lora_embedding_B[adapter].data = self.lora_embedding_B[adapter].data.to(self.base_layer.weight.device)
        else:
            self.lora_A[adapter].weight.data = self.lora_A[adapter].weight.data.to(self.base_layer.weight.device)
            self.lora_B[adapter].weight.data = self.lora_B[adapter].weight.data.to(self.base_layer.weight.device)
        return self.get_delta_weight_origin(adapter).to(self.base_layer.weight.device)

    def _cache_pop(self, key: str) -> Any:
        # Pop cached tensors onto the base layer's device.
        value = self._caches.pop(key).to(self.base_layer.weight.device)
        return value

    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            module.name = name
            if not hasattr(module, 'merge_origin') and hasattr(module, 'base_layer'):
                # Keep references to the original methods so they can be restored on exit.
                module.merge_origin = module.merge
                module.merge = MethodType(merge, module)
                module.get_delta_weight_origin = module.get_delta_weight
                module.get_delta_weight = MethodType(get_delta_weight, module)
                module._cache_pop_origin = module._cache_pop
                module._cache_pop = MethodType(_cache_pop, module)

    try:
        yield model
    finally:
        for module in model.modules():
            if isinstance(module, LoraLayer):
                if hasattr(module, 'merge_origin'):
                    module.merge = module.merge_origin
                    del module.merge_origin
                    module.get_delta_weight = module.get_delta_weight_origin
                    del module.get_delta_weight_origin
                    module._cache_pop = module._cache_pop_origin
                    del module._cache_pop_origin
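
# Illustrative usage (assumed workflow, not part of this module): fold the LoRA
# weights of a hypothetical `peft_model` into its base weights via the standard
# PEFT `merge_adapter` call while the patch is active.
#
#     with patch_lora_merge(peft_model) as patched_model:
#         patched_model.merge_adapter()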


@contextmanager
def patch_lora_unmerge(model):
    """Patch LoraLayer's unmerge and _cache_pop methods to ensure proper device handling."""

    def _cache_pop_patched(self, key: str) -> Any:
        # Pop cached tensors onto the base layer's device.
        value = self._caches.pop(key).to(self.base_layer.weight.device)
        return value

    def unmerge_patched(self):
        if not self.merged:
            return

        for adapter in list(self.merged_adapters):
            if self.use_dora.get(adapter, False):
                # Keep the DoRA magnitude vector on the same device as the base layer.
                self.lora_magnitude_vector[adapter].weight.data = \
                    self.lora_magnitude_vector[adapter].weight.data.to(self.base_layer.weight.device)

        return self.unmerge_origin()

    for module in model.modules():
        if isinstance(module, LoraLayer) and not hasattr(module, 'unmerge_origin'):
            module.unmerge_origin = module.unmerge
            module.unmerge = MethodType(unmerge_patched, module)
            module._cache_pop_origin = module._cache_pop
            module._cache_pop = MethodType(_cache_pop_patched, module)

    try:
        yield model
    finally:
        for module in model.modules():
            if isinstance(module, LoraLayer) and hasattr(module, 'unmerge_origin'):
                module.unmerge = module.unmerge_origin
                del module.unmerge_origin
                module._cache_pop = module._cache_pop_origin
                del module._cache_pop_origin
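
# Illustrative usage (assumed workflow, not part of this module): revert a
# previous merge on a hypothetical `peft_model` via the standard PEFT
# `unmerge_adapter` call while the patch is active.
#
#     with patch_lora_unmerge(peft_model) as patched_model:
#         patched_model.unmerge_adapter()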