# Copyright (c) Alibaba, Inc. and its affiliates.
import functools
import time
from contextlib import contextmanager
from types import MethodType
from typing import Any, Optional

import torch
import torch.nn.functional as F
from peft.tuners import lora
from peft.tuners.lora import LoraLayer
from torch import nn

from swift.utils import is_swanlab_available, is_wandb_available

if is_wandb_available():
    import wandb
if is_swanlab_available():
    import swanlab


def round_robin(num_reqs, num_workers):
    """Distribute requests evenly across workers using a round-robin algorithm.

    Args:
        num_reqs (int): Total number of requests to distribute.
        num_workers (int): Number of available workers.

    Returns:
        list: A list of lists where each sublist contains the request indices assigned to that particular worker.
    """
    distribution = [[] for _ in range(num_workers)]
    for idx in range(num_reqs):
        worker_id = idx % num_workers
        distribution[worker_id].append(idx)
    return distribution


@contextmanager
def patch_lora_merge(model, parameter_group=None):
    """Patch LoraLayer's merge and get_delta_weight methods for controlled merging.

    Args:
        model: The PEFT model to patch.
        parameter_group: Optional list of parameter names to restrict merging to.

    Yields:
        The patched model (the context manager restores the original methods on exit).
    """
    from peft.tuners.tuners_utils import check_adapters_to_merge

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        if parameter_group and all(self.name not in pg for pg in parameter_group):
            return  # Skip if not in the target parameter group
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            return
        # Move DoRA magnitude vectors to the base layer's device before merging
        for active_adapter in adapter_names:
            if active_adapter in self.lora_A.keys():
                base_layer = self.get_base_layer()
                if self.use_dora.get(active_adapter, False):
                    self.lora_magnitude_vector[active_adapter].weight.data = \
                        self.lora_magnitude_vector[active_adapter].weight.data.to(base_layer.weight.device)
        return self.merge_origin(safe_merge, adapter_names)

    def get_delta_weight(self, adapter) -> torch.Tensor:
        # Ensure LoRA tensors are on the same device as the base layer
        if isinstance(self, lora.Embedding):
            self.lora_embedding_A[adapter].data = \
                self.lora_embedding_A[adapter].data.to(self.base_layer.weight.device)
            self.lora_embedding_B[adapter].data = \
                self.lora_embedding_B[adapter].data.to(self.base_layer.weight.device)
        else:
            self.lora_A[adapter].weight.data = self.lora_A[adapter].weight.data.to(self.base_layer.weight.device)
            self.lora_B[adapter].weight.data = self.lora_B[adapter].weight.data.to(self.base_layer.weight.device)
        return self.get_delta_weight_origin(adapter).to(self.base_layer.weight.device)

    def _cache_pop(self, key: str) -> Any:
        value = self._caches.pop(key).to(self.base_layer.weight.device)
        return value

    # Patch all LoraLayer instances
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            module.name = name
            if not hasattr(module, 'merge_origin') and hasattr(module, 'base_layer'):
                module.merge_origin = module.merge
                module.merge = MethodType(merge, module)
                module.get_delta_weight_origin = module.get_delta_weight
                module.get_delta_weight = MethodType(get_delta_weight, module)
                module._cache_pop_origin = module._cache_pop
                module._cache_pop = MethodType(_cache_pop, module)
    try:
        yield model
    finally:
        # Cleanup: restore the original methods
        for module in model.modules():
            if isinstance(module, LoraLayer):
                if hasattr(module, 'merge_origin'):
                    module.merge = module.merge_origin
                    del module.merge_origin
                    module.get_delta_weight = module.get_delta_weight_origin
                    del module.get_delta_weight_origin
                    module._cache_pop = module._cache_pop_origin
                    del module._cache_pop_origin


@contextmanager
def patch_lora_unmerge(model):
    """Patch LoraLayer's unmerge so DoRA magnitude vectors are moved to the base layer's device first."""

    def unmerge_patched(self):
        if not self.merged:
            return
        # Move magnitude vectors to the correct device first
        for adapter in list(self.merged_adapters):
            if self.use_dora.get(adapter, False):
                self.lora_magnitude_vector[adapter].weight.data = \
                    self.lora_magnitude_vector[adapter].weight.data.to(self.base_layer.weight.device)
        return self.unmerge_origin()

    for module in model.modules():
        if isinstance(module, LoraLayer) and not hasattr(module, 'unmerge_origin'):
            module.unmerge_origin = module.unmerge
            module.unmerge = MethodType(unmerge_patched, module)
    try:
        yield model
    finally:
        for module in model.modules():
            if isinstance(module, LoraLayer) and hasattr(module, 'unmerge_origin'):
                module.unmerge = module.unmerge_origin
                del module.unmerge_origin


@contextmanager
def patch_profiling_context(trainer, name: str):
    """Measure the wall-clock time of the wrapped block and report it to wandb/swanlab when enabled."""
    start_time = time.perf_counter()
    yield
    end_time = time.perf_counter()
    duration = end_time - start_time

    profiling_metrics = {f'profiling/Time taken: {trainer.__class__.__name__}.{name}': duration}
    if 'wandb' in trainer.args.report_to and wandb.run is not None and trainer.accelerator.is_main_process:
        wandb.log(profiling_metrics)
    if 'swanlab' in trainer.args.report_to and swanlab.get_run() is not None and trainer.accelerator.is_main_process:
        swanlab.log(profiling_metrics)


def patch_profiling_decorator(func):
    """Decorator variant of `patch_profiling_context`: profiles every call of the decorated method."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        with patch_profiling_context(self, func.__name__):
            return func(self, *args, **kwargs)

    return wrapper


class _ForwardRedirection:
    """Implements the `forward-redirection`.

    Taken from PyTorch Lightning:
    https://github.com/Lightning-AI/pytorch-lightning/blob/02311d03fb982560246eead7c08104481fac9579/src/lightning/pytorch/strategies/strategy.py#L602

    A method call to a wrapped module gets rerouted through the wrapper's `forward` method instead.
    """

    def __call__(self, wrapper_module: nn.Module, original_module: nn.Module, method: callable, *args: Any,
                 **kwargs: Any):
        """Reroutes a method call through the `wrapper_module`'s `forward` method.

        Args:
            wrapper_module: The module that has `original_module` wrapped.
            original_module: The module that was wrapped inside `wrapper_module`.
            method: The method that should be called on the `original_module` after inputs get redirected
                through the `wrapper_module`'s `forward` method.
            *args: The positional arguments to `method`. They will get passed to a patched `forward`
                method instead.
            **kwargs: The keyword arguments to `method`. They will get passed to a patched `forward`
                method instead.
        """
        original_forward = original_module.forward

        def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
            # Unpatch ourselves immediately before calling `method`,
            # because the method itself may want to call the real `forward`
            original_module.forward = original_forward  # type: ignore[method-assign]
            # Call the actual method, e.g. `.training_step(...)`
            out = method(*_args, **_kwargs)
            self.on_after_inner_forward(wrapper_module, original_module)
            return out

        # Patch the original_module's forward so we can redirect the arguments back to the real method
        original_module.forward = wrapped_forward  # type: ignore[method-assign]

        wrapper_output = wrapper_module(*args, **kwargs)
        self.on_after_outer_forward(wrapper_module, original_module)
        return wrapper_output

    def on_after_inner_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
        pass

    def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
        pass


def entropy_from_logits(logits, chunk_size: int = 1) -> torch.Tensor:
    """Compute the Shannon entropy (in nats) for each row of `logits` without materialising the full softmax
    in memory.

    The batch dimension is processed in chunks of size `chunk_size` so that only a subset of rows is expanded
    to probabilities at any one time.

    Args:
        logits (`torch.Tensor`):
            Logits tensor of shape `(..., num_classes)`. Entropy is taken along the last axis; all leading
            dimensions are preserved.
        chunk_size (`int`, *optional*, defaults to `1`):
            Number of rows to process per iteration.

    Returns:
        `torch.Tensor`: Entropy values with shape `logits.shape[:-1]`.
    """
    per_token_entropies = []
    for logits_chunk in logits.split(chunk_size, dim=0):
        logps = F.log_softmax(logits_chunk, dim=-1)
        chunk_entropy = -(torch.exp(logps) * logps).sum(-1)
        per_token_entropies.append(chunk_entropy)
    return torch.cat(per_token_entropies, dim=0)
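

# Usage sketch: the `__main__` guard below is an assumption added for illustration, not part of the
# module's API. It is a no-op on import and shows the expected behaviour of `round_robin` and
# `entropy_from_logits` on toy inputs.
if __name__ == '__main__':
    # 7 requests spread over 3 workers -> [[0, 3, 6], [1, 4], [2, 5]]
    print(round_robin(7, 3))

    # Chunked entropy should match the naive full-softmax computation.
    logits = torch.randn(4, 10)
    chunked = entropy_from_logits(logits, chunk_size=2)
    full = -(F.softmax(logits, dim=-1) * F.log_softmax(logits, dim=-1)).sum(-1)
    assert torch.allclose(chunked, full, atol=1e-6)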