from contextlib import contextmanager
from types import MethodType
from typing import Any, Optional

import torch
from peft.tuners import lora
from peft.tuners.lora import LoraLayer


def round_robin(num_reqs, num_workers):
    """Distribute requests evenly across workers using a round-robin scheme.

    Args:
        num_reqs (int): Total number of requests to distribute.
        num_workers (int): Number of available workers.

    Returns:
        list: A list of lists where each sublist contains the request indices
            assigned to that worker.
    """
    distribution = [[] for _ in range(num_workers)]
    for idx in range(num_reqs):
        worker_id = idx % num_workers
        distribution[worker_id].append(idx)
    return distribution
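

def _example_round_robin():
    # Illustrative sketch (hypothetical helper, not part of the original file):
    # five requests spread over two workers land on alternating workers, so the
    # even request indices go to worker 0 and the odd ones to worker 1.
    assert round_robin(5, 2) == [[0, 2, 4], [1, 3]]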


@contextmanager
def patch_lora_merge(model, parameter_group=None):
    """Patch LoraLayer's merge and get_delta_weight methods for controlled merging.

    Args:
        model: The PEFT model to patch.
        parameter_group: Optional list of parameter names to restrict merging.

    Yields:
        The patched model (the context manager ensures cleanup).
    """
    from peft.tuners.tuners_utils import check_adapters_to_merge

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        if parameter_group and all(self.name not in pg for pg in parameter_group):
            return  # Skip if not in target parameter group
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            return
        for active_adapter in adapter_names:
            if active_adapter in self.lora_A.keys():
                base_layer = self.get_base_layer()
                if self.use_dora.get(active_adapter, False):
                    self.lora_magnitude_vector[active_adapter].weight.data = \
                        self.lora_magnitude_vector[active_adapter].weight.data.to(base_layer.weight.device)
        return self.merge_origin(safe_merge, adapter_names)

    def get_delta_weight(self, adapter) -> torch.Tensor:
        # Ensure tensors are on the correct device
        if isinstance(self, lora.Embedding):
            self.lora_embedding_A[adapter].data = self.lora_embedding_A[adapter].data.to(self.base_layer.weight.device)
            self.lora_embedding_B[adapter].data = self.lora_embedding_B[adapter].data.to(self.base_layer.weight.device)
        else:
            self.lora_A[adapter].weight.data = self.lora_A[adapter].weight.data.to(self.base_layer.weight.device)
            self.lora_B[adapter].weight.data = self.lora_B[adapter].weight.data.to(self.base_layer.weight.device)
        return self.get_delta_weight_origin(adapter).to(self.base_layer.weight.device)

    def _cache_pop(self, key: str) -> Any:
        value = self._caches.pop(key).to(self.base_layer.weight.device)
        return value

    # Patch all LoraLayer instances
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            module.name = name
            if not hasattr(module, 'merge_origin') and hasattr(module, 'base_layer'):
                module.merge_origin = module.merge
                module.merge = MethodType(merge, module)
                module.get_delta_weight_origin = module.get_delta_weight
                module.get_delta_weight = MethodType(get_delta_weight, module)
                module._cache_pop_origin = module._cache_pop
                module._cache_pop = MethodType(_cache_pop, module)
    try:
        yield model
    finally:
        # Cleanup: restore original methods
        for module in model.modules():
            if isinstance(module, LoraLayer):
                if hasattr(module, 'merge_origin'):
                    module.merge = module.merge_origin
                    del module.merge_origin
                    module.get_delta_weight = module.get_delta_weight_origin
                    del module.get_delta_weight_origin
                    module._cache_pop = module._cache_pop_origin
                    del module._cache_pop_origin
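

def _example_merge_parameter_group(model, parameter_group=None):
    # Illustrative usage sketch (hypothetical helper, not part of the original
    # file): merge LoRA weights into the base weights inside the patched
    # context. `model` is assumed to be a peft.PeftModel; `parameter_group`, if
    # given, is assumed to be a list of parameter names, and only LoraLayers
    # whose module names occur in those names get merged.
    with patch_lora_merge(model, parameter_group) as patched_model:
        # PeftModel.merge_adapter() calls merge() on every LoraLayer; the
        # patched merge() skips layers outside `parameter_group` and moves
        # DoRA magnitude vectors onto the base layer's device first.
        patched_model.merge_adapter()
    return patched_model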


@contextmanager
def patch_lora_unmerge(model):
    """Patch the unmerge method to ensure proper device handling."""

    def _cache_pop_patched(self, key: str) -> Any:
        value = self._caches.pop(key).to(self.base_layer.weight.device)
        return value

    def unmerge_patched(self):
        if not self.merged:
            return
        # Move magnitude vectors to the correct device first
        for adapter in list(self.merged_adapters):
            if self.use_dora.get(adapter, False):
                self.lora_magnitude_vector[adapter].weight.data = \
                    self.lora_magnitude_vector[adapter].weight.data.to(self.base_layer.weight.device)
        return self.unmerge_origin()

    for module in model.modules():
        if isinstance(module, LoraLayer) and not hasattr(module, 'unmerge_origin'):
            module.unmerge_origin = module.unmerge
            module.unmerge = MethodType(unmerge_patched, module)
            module._cache_pop_origin = module._cache_pop
            module._cache_pop = MethodType(_cache_pop_patched, module)
    try:
        yield model
    finally:
        for module in model.modules():
            if isinstance(module, LoraLayer) and hasattr(module, 'unmerge_origin'):
                module.unmerge = module.unmerge_origin
                del module.unmerge_origin
                module._cache_pop = module._cache_pop_origin
                del module._cache_pop_origin
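

def _example_export_merged_state_dict(model):
    # Illustrative end-to-end sketch (hypothetical helper, not part of the
    # original file): temporarily merge the LoRA adapters to read out a merged
    # state_dict (e.g. for syncing weights to an inference engine), then
    # unmerge so the adapters can keep training. Assumes `model` is a
    # peft.PeftModel.
    with patch_lora_merge(model):
        model.merge_adapter()
        merged_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}
    with patch_lora_unmerge(model):
        model.unmerge_adapter()
    return merged_state_dict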