| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | import torch
|
| | import torch.nn.functional as F
|
| | import os
|
| | import shutil
|
| | import sys
|
| | import importlib.util
|
| | from typing import Optional, Tuple
|
| | from torch.autograd import Function
|
| |
|
| |
|
# Directory where Unsloth stages compiled-module copies. Read once at import
# time; _get_compile_location() re-reads the env var so later overrides still
# take effect.
UNSLOTH_COMPILE_LOCATION = os.environ.get(
    "UNSLOTH_COMPILE_LOCATION", "unsloth_compiled_cache"
)
|
| |
|
| |
|
def _get_compile_location() -> str:
    """Return the absolute path of the compiled-cache directory.

    The environment variable is consulted on every call (not just at import)
    so overrides made after import still take effect; the import-time
    constant serves as the fallback.
    """
    location = os.environ.get("UNSLOTH_COMPILE_LOCATION", UNSLOTH_COMPILE_LOCATION)
    return os.path.abspath(location)
|
| |
|
| |
|
| | def _log_info(message: str):
|
| | if os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1":
|
| | print(message)
|
| |
|
| |
|
def install_to_cache(source_path, destination_filename=None):
    """
    Best-effort copy of *source_path* into the unsloth_compiled_cache
    directory so it is importable by compiled modules.

    Args:
        source_path: Path of the file to stage.
        destination_filename: Name to use inside the cache; defaults to the
            basename of ``source_path``.

    Never raises: both directory creation and the copy are best-effort.
    """
    compile_location = _get_compile_location()
    # exist_ok avoids the check-then-create race of the old version, and
    # OSError replaces the previous bare `except:`, which also swallowed
    # KeyboardInterrupt/SystemExit.
    try:
        os.makedirs(compile_location, exist_ok=True)
    except OSError:
        pass

    current_file = os.path.abspath(source_path)
    if destination_filename is None:
        destination_filename = os.path.basename(current_file)

    destination = os.path.abspath(os.path.join(compile_location, destination_filename))

    # Skip the copy when the source already lives in the cache.
    if current_file != destination:
        try:
            shutil.copy(current_file, destination)
        except OSError:
            pass
|
| |
|
| |
|
# Stage this module into the compiled cache at import time so compiled code
# can import it as "moe_utils".
install_to_cache(__file__, "moe_utils.py")
|
| |
|
# Lazily-resolved caches: the dispatch function object and the imported
# cached-module object (both populated on first use).
_CACHED_FORWARD_MOE_BACKEND = None
_CACHED_MOE_UTILS_MODULE = None
|
| |
|
| |
|
def _load_cached_moe_utils_module():
    """Import the copy of this module staged in the compiled cache.

    Returns the cached module object, or None when the cache copy does not
    exist, is this very file, or fails to import. Successful loads are
    memoized in _CACHED_MOE_UTILS_MODULE and registered in sys.modules.
    """
    global _CACHED_MOE_UTILS_MODULE

    cache_file = os.path.abspath(os.path.join(_get_compile_location(), "moe_utils.py"))
    current_file = os.path.abspath(__file__)
    # Nothing to load if the cache copy is missing, or we ARE the cache copy.
    if not os.path.isfile(cache_file) or cache_file == current_file:
        return None

    try:
        module_name = "unsloth_cached_moe_utils"
        # Reuse a previously imported copy only when it came from the same path.
        module = sys.modules.get(module_name, None)
        if module is not None and os.path.abspath(getattr(module, "__file__", "")) == cache_file:
            _CACHED_MOE_UTILS_MODULE = module
            return module

        spec = importlib.util.spec_from_file_location(module_name, cache_file)
        if spec is None or spec.loader is None:
            return None
        module = importlib.util.module_from_spec(spec)
        # Register in sys.modules BEFORE exec so imports during exec resolve.
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        _CACHED_MOE_UTILS_MODULE = module
        return module
    except Exception:
        # Best-effort: any import failure falls back to the local definitions.
        return None
|
| |
|
| |
|
def get_forward_moe_backend():
    """
    Resolve forward_moe_backend from the compiled cache copy when available.
    Falls back to the local module definition.

    Returns:
        The callable to use for MoE forward dispatch (cached after first call).
    """
    global _CACHED_FORWARD_MOE_BACKEND
    module = _load_cached_moe_utils_module()
    # Prefer the cache copy so compiled modules and this module dispatch
    # through the same function object.
    if module is not None and hasattr(module, "forward_moe_backend"):
        _CACHED_FORWARD_MOE_BACKEND = module.forward_moe_backend
        return _CACHED_FORWARD_MOE_BACKEND

    # Fall back to the definition later in this file.
    _CACHED_FORWARD_MOE_BACKEND = forward_moe_backend
    return _CACHED_FORWARD_MOE_BACKEND
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def _grouped_mm_with_backward_fix(
    inputs: torch.Tensor, weight: torch.Tensor, offsets: torch.Tensor
) -> torch.Tensor:
    """
    Grouped matmul with working backward pass.

    Thin wrapper over the private torch._grouped_mm op. Callers are expected
    to pass contiguous tensors for correct gradients — this wrapper does NOT
    make them contiguous itself.
    """
    return torch._grouped_mm(inputs, weight, offs=offsets)
|
| |
|
| |
|
| |
|
# None = not yet probed; set by _check_grouped_gemm_available().
_GROUPED_GEMM_AVAILABLE = None
# Whether this torch build exposes the private grouped-mm op at all.
_TORCH_GROUPED_MM_AVAILABLE = hasattr(torch, "_grouped_mm")

# None = not yet probed on this GPU; set by _check_torch_grouped_mm_supported().
_TORCH_GROUPED_MM_SUPPORTED = None
|
| |
|
| |
|
def _check_torch_grouped_mm_supported():
    """
    Check if torch._grouped_mm is actually supported on the current GPU.
    We check for existence and verify with a dummy call.
    A runtime probe is the only reliable check.

    Returns:
        bool, cached across calls in _TORCH_GROUPED_MM_SUPPORTED.
    """
    global _TORCH_GROUPED_MM_SUPPORTED
    # Return the memoized probe result when one exists.
    if _TORCH_GROUPED_MM_SUPPORTED is not None: return _TORCH_GROUPED_MM_SUPPORTED

    if not _TORCH_GROUPED_MM_AVAILABLE:
        _TORCH_GROUPED_MM_SUPPORTED = False
        return False

    if not torch.cuda.is_available():
        _TORCH_GROUPED_MM_SUPPORTED = False
        return False

    try:
        device = torch.cuda.current_device()
        dtype = torch.float16

        # Tiny single-expert probe: (1, 8) @ (1, 8, 8) with one offset.
        x = torch.ones((1, 8), device=device, dtype=dtype)
        w = torch.ones((1, 8, 8), device=device, dtype=dtype)
        offs = torch.tensor([1], device=device, dtype=torch.int32)

        torch._grouped_mm(x, w, offs=offs)
        del x, w, offs
        _TORCH_GROUPED_MM_SUPPORTED = True
    except Exception:
        # Any failure (unsupported arch, driver problem) marks it unsupported.
        _TORCH_GROUPED_MM_SUPPORTED = False

    return _TORCH_GROUPED_MM_SUPPORTED
|
| |
|
| |
|
# One-time flag and grow-only scratch buffer for the persistent Triton
# allocator installed by _init_triton_allocator().
_TRITON_ALLOCATOR_INITIALIZED = False
_PERSISTENT_BUFFER = None
|
| |
|
| |
|
def _init_triton_allocator():
    """
    Initialize a persistent Triton allocator to avoid memory allocation overhead per call.
    This significantly reduces GPU utilization fluctuation.

    Best-effort: silently does nothing when triton is missing or setup fails.
    """
    global _TRITON_ALLOCATOR_INITIALIZED, _PERSISTENT_BUFFER
    if _TRITON_ALLOCATOR_INITIALIZED: return

    try:
        import triton

        def persistent_alloc_fn(size: int, alignment: int, stream):
            # Grow-only scratch buffer shared across Triton launches.
            global _PERSISTENT_BUFFER

            # Round the request up to a 128-byte multiple.
            rounded_size = ((size + 128 - 1) // 128) * 128

            if (
                _PERSISTENT_BUFFER is None
                or _PERSISTENT_BUFFER.numel() * _PERSISTENT_BUFFER.element_size()
                < rounded_size
            ):
                # Over-allocate by 10% to reduce future reallocations.
                _PERSISTENT_BUFFER = torch.empty(
                    int(rounded_size * 1.1), device="cuda", dtype=torch.uint8
                )
                # Marker — presumably read by checkpoint/hibernate tooling to
                # skip this buffer; not consumed anywhere in this file.
                _PERSISTENT_BUFFER.__hibernate__ = {"type": "ignore"}
            return _PERSISTENT_BUFFER

        triton.set_allocator(persistent_alloc_fn)
        # Flag on the triton module so other code can detect our allocator.
        triton._unsloth_allocator_set = True
        _TRITON_ALLOCATOR_INITIALIZED = True
    except Exception:
        pass
|
| |
|
| |
|
| | def _check_grouped_gemm_available():
|
| | """Check if Unsloth grouped GEMM kernels are available."""
|
| | if os.environ.get("UNSLOTH_DISABLE_MOE_TRITON", "0") == "1": return False
|
| |
|
| | global _GROUPED_GEMM_AVAILABLE
|
| | if _GROUPED_GEMM_AVAILABLE is not None: return _GROUPED_GEMM_AVAILABLE
|
| |
|
| | try:
|
| | from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm, supports_tma
|
| | _GROUPED_GEMM_AVAILABLE = True
|
| | _init_triton_allocator()
|
| | except (ImportError, ModuleNotFoundError):
|
| | _GROUPED_GEMM_AVAILABLE = False
|
| | return _GROUPED_GEMM_AVAILABLE
|
| |
|
| |
|
| | from functools import lru_cache
|
| |
|
| |
|
@lru_cache(maxsize=1)
def select_moe_backend():
    """
    Selects the MoE backend based on UNSLOTH_MOE_BACKEND environment variable and availability.
    Choices: "grouped_mm", "unsloth_triton", "native_torch".
    Default if unspecified: "grouped_mm".
    """
    requested = os.environ.get("UNSLOTH_MOE_BACKEND")
    if requested:
        # Honor an explicit request only when that backend is actually usable.
        availability_check = {
            "grouped_mm": _check_torch_grouped_mm_supported,
            "unsloth_triton": _check_grouped_gemm_available,
            "native_torch": lambda: True,
        }.get(requested)
        if availability_check is not None and availability_check():
            return requested
        _log_info(f"Unsloth: '{requested}' backend requested but is not available. Falling back to next available.")

    # Preference order: grouped_mm > unsloth_triton > native_torch.
    if _check_torch_grouped_mm_supported():
        _log_info("Unsloth: Using MoE backend 'grouped_mm'")
        return "grouped_mm"
    if _check_grouped_gemm_available():
        _log_info("Unsloth: Using MoE backend 'unsloth_triton'")
        return "unsloth_triton"
    return "native_torch"
|
| |
|
| |
|
def forward_moe_backend(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """
    Dispatch the MoE forward pass to the currently selected backend.
    Centralizing selection here keeps model-specific patches minimal.
    """
    implementations = {
        "grouped_mm": forward_native_grouped_mm,
        "unsloth_triton": forward_triton_grouped_gemm,
    }
    # Anything not in the table (i.e. "native_torch") uses the loop fallback.
    forward_fn = implementations.get(select_moe_backend(), forward_native_moe_loop)
    return forward_fn(self, hidden_states, top_k_index, top_k_weights)
|
| |
|
| |
|
| | @torch.no_grad()
|
| | def _get_routing_indices(selected_experts, num_experts):
|
| | """
|
| | Compute token→expert mapping for grouped GEMM.
|
| | Uses bincount instead of histc to avoid float conversion overhead.
|
| |
|
| | Returns:
|
| | token_counts_by_expert: (num_experts,) token counts per expert
|
| | gather_indices: (total_tokens,) indices for gathering tokens in expert order
|
| | """
|
| |
|
| |
|
| | flat_experts = selected_experts.view(-1)
|
| |
|
| |
|
| | token_counts_by_expert = torch.bincount(flat_experts, minlength=num_experts).to(torch.int32)
|
| |
|
| |
|
| | gather_indices = flat_experts.argsort(stable=True)
|
| |
|
| | return token_counts_by_expert, gather_indices
|
| |
|
| |
|
| | def _silu_and_mul(x):
|
| | """Fused SiLU activation and element-wise multiply for gate/up projections."""
|
| | gate, up = x.chunk(2, dim=-1)
|
| | return F.silu(gate) * up
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | def _has_lora_adapters(param) -> bool:
|
| | """Check if parameter has active LoRA adapters (PEFT ParamWrapper)."""
|
| |
|
| | if not hasattr(param, "lora_A") or not hasattr(param, "lora_B"):
|
| | return False
|
| | if hasattr(param, "disable_adapters") and param.disable_adapters:
|
| | return False
|
| | if hasattr(param, "merged") and param.merged:
|
| | return False
|
| | return len(param.lora_A) > 0
|
| |
|
| |
|
def _extract_lora_from_wrapper(
    wrapper, adapter_name: str = "default", experts_module=None
) -> Optional[Tuple[torch.Tensor, torch.Tensor, float, int]]:
    """
    Extract LoRA weights from PEFT ParamWrapper for MoE separated computation.

    PEFT ParamWrapper for 3D parameters creates:
    - lora_A: nn.Linear(in_dim, E*R) -> weight: (E*R, in_dim)
    - lora_B: nn.Linear(E*R, out_dim) -> weight: (out_dim, E*R)

    For grouped_mm: X @ first_weight @ second_weight

    STANDARD FORMAT (Qwen3-MoE): weights stored as (E, out_dim, in_dim) for F.linear
        gate_up_proj: (E, 2*I, H) - input X is (N, H), output is (N, 2*I)
        down_proj:    (E, H, I)   - input X is (N, I), output is (N, H)

    TRANSPOSED FORMAT (Qwen3-VL-MoE): weights stored as (E, in_dim, out_dim) for grouped_mm
        gate_up_proj: (E, H, 2*I) - input X is (N, H), output is (N, 2*I)
        down_proj:    (E, I, H)   - input X is (N, I), output is (N, H)

    In both layouts, for gate_up the LoRA pair is reshaped so that
        X (N, H) @ first (E, H, R) @ second (E, R, 2*I) -> (N, 2*I)
    with first_weight built from lora_A (E*R, H) and second_weight from
    lora_B (2*I, E*R) via view + permute.

    Returns:
        (first_weight, second_weight, scaling, num_experts) or None when the
        wrapper has no usable adapters (best-effort: any error yields None).
    """
    try:
        # Must look like a PEFT ParamWrapper with adapters attached.
        if not hasattr(wrapper, "lora_A") or not hasattr(wrapper, "lora_B"):
            return None

        if hasattr(wrapper, "disable_adapters") and wrapper.disable_adapters:
            return None
        if hasattr(wrapper, "merged") and wrapper.merged:
            return None

        if not wrapper.lora_A:
            return None

        # Fall back to whichever adapter exists when the name is absent.
        if adapter_name not in wrapper.lora_A:
            adapter_name = list(wrapper.lora_A.keys())[0]

        lora_A_module = wrapper.lora_A[adapter_name]
        lora_B_module = wrapper.lora_B[adapter_name]

        weight_A = lora_A_module.weight
        weight_B = lora_B_module.weight
        scaling = wrapper.scaling[adapter_name]
        num_experts = getattr(wrapper, "num_experts", 1)

        if experts_module is None:
            experts_module = wrapper.get_base_layer() if hasattr(wrapper, "get_base_layer") else None

        # A model-specific extractor hook takes precedence over the
        # generic view/permute logic below.
        extractor_fn = getattr(experts_module, "_unsloth_lora_extractor_fn", None)

        if extractor_fn is not None:
            return extractor_fn(wrapper, weight_A, weight_B, scaling, num_experts)

        if num_experts > 1:
            # Split the stacked (E*R, ...) LoRA weights into per-expert blocks.
            total_rank = weight_A.shape[0]
            rank_per_expert = total_rank // num_experts
            dim1 = weight_A.shape[1]
            dim2 = weight_B.shape[0]

            # (E*R, in) -> (E, R, in) -> (E, in, R): X @ first gives rank outs.
            first_weight = weight_A.view(num_experts, rank_per_expert, dim1)
            first_weight = first_weight.permute(0, 2, 1).contiguous()

            # (out, E*R) -> (out, E, R) -> (E, R, out) for the second matmul.
            second_weight = weight_B.view(dim2, num_experts, rank_per_expert)
            second_weight = second_weight.permute(1, 2, 0).contiguous()
        else:
            # Single expert: plain transposes suffice.
            first_weight = weight_A.T
            second_weight = weight_B.T

        return first_weight, second_weight, scaling, num_experts
    except Exception:
        return None
|
| |
|
| |
|
def _extract_lora_weights(
    param, adapter_name: str = "default", num_experts: Optional[int] = None, experts_module=None
) -> Optional[Tuple[torch.Tensor, torch.Tensor, float]]:
    """
    Extract LoRA A and B weights from PEFT ParamWrapper.

    This is a compatibility wrapper around _extract_lora_from_wrapper.
    Use _extract_lora_from_wrapper directly for new code.

    Args:
        param: the PEFT ParamWrapper to read from.
        adapter_name: adapter key to extract (falls back inside the extractor).
        num_experts: optional expert count, stashed onto the wrapper when the
            wrapper does not already carry one.
        experts_module: optional experts module passed through to the extractor.

    Returns:
        (first_weight, second_weight, scaling) for (X @ first) @ second,
        or None when extraction fails.
    """
    # Seed num_experts on the wrapper so the extractor can split the ranks.
    if num_experts is not None and not hasattr(param, "num_experts"):
        param.num_experts = num_experts

    result = _extract_lora_from_wrapper(param, adapter_name, experts_module=experts_module)
    if result is None:
        return None

    # Drop the trailing num_experts element of the extractor's 4-tuple.
    return result[0], result[1], result[2]
|
| |
|
| |
|
| | def _get_base_weight(param):
|
| | """Get base weight from potentially wrapped parameter or module."""
|
| |
|
| |
|
| |
|
| | while hasattr(param, "base_layer"):
|
| | param = param.base_layer
|
| |
|
| | if hasattr(param, "get_param"):
|
| | return param.get_param()
|
| |
|
| |
|
| | if hasattr(param, "weight"):
|
| | return param.weight
|
| |
|
| | return param
|
| |
|
| |
|
| | def _get_lora_wrapper_for_param(experts_module, param_name):
|
| | """
|
| | Get the PEFT ParamWrapper for a specific parameter (gate_up_proj or down_proj).
|
| | Uses the explicit key stored in __dict__ if available.
|
| | Does NOT lazily setup wrappers as that requires traversing logic not present here.
|
| | """
|
| |
|
| |
|
| | if hasattr(experts_module, f"{param_name}_lora_wrapper"):
|
| | return getattr(experts_module, f"{param_name}_lora_wrapper")
|
| |
|
| |
|
| | if hasattr(experts_module, param_name):
|
| | attr = getattr(experts_module, param_name)
|
| | if hasattr(attr, "lora_A"):
|
| | return attr
|
| |
|
| | return None
|
| |
|
| |
|
def native_moe_grouped_mm(
    inputs: torch.Tensor, weight: torch.Tensor, offsets: torch.Tensor
) -> torch.Tensor:
    """Run the native grouped-matmul path.

    Delegates to _grouped_mm_with_backward_fix so gradients flow correctly
    through torch._grouped_mm.
    """
    return _grouped_mm_with_backward_fix(inputs, weight, offsets)
|
| |
|
| |
|
def _apply_lora_grouped_mm(
    inputs: torch.Tensor,
    lora_B: torch.Tensor,
    lora_A: torch.Tensor,
    offsets: torch.Tensor,
    scaling: float,
    grouped_mm_func=native_moe_grouped_mm,
) -> torch.Tensor:
    """
    Apply a LoRA update via two grouped GEMMs: ((X @ B) @ A) * scaling.

    Args:
        inputs: (total_tokens, in_dim) tokens already permuted by expert.
        lora_B: (num_experts, in_dim, rank) first projection.
        lora_A: (num_experts, rank, out_dim) second projection.
        offsets: grouped GEMM offsets (cumulative token counts per expert).
        scaling: LoRA scaling factor applied to the final delta.
        grouped_mm_func: grouped GEMM implementation to use.
    """
    # X @ B -> (total_tokens, rank); contiguous() guards gradient correctness.
    projected = grouped_mm_func(inputs, lora_B.contiguous(), offsets)

    # (X @ B) @ A -> (total_tokens, out_dim)
    delta = grouped_mm_func(projected, lora_A.contiguous(), offsets)

    return delta * scaling
|
| |
|
| |
|
| | def _should_use_separated_lora() -> bool:
|
| | """
|
| | Check if separated LoRA approach should be used (default: True).
|
| | Set UNSLOTH_MOE_LORA_MERGED=1 to use merged approach instead.
|
| | """
|
| | return os.environ.get("UNSLOTH_MOE_LORA_MERGED", "0") != "1"
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
# Registry of model-specific weight-layout preprocessors, keyed by model type.
_WEIGHT_PREPROCESSORS = {}
|
| |
|
| |
|
def register_weight_preprocessor(model_type: str, preprocessor_fn):
    """
    Register a weight-layout preprocessor for one model type.

    Args:
        model_type: Model identifier (e.g. "qwen3_moe", "qwen3_vl_moe").
        preprocessor_fn: Callable(weight, proj_type, hidden_dim) -> tensor,
            where proj_type is "gate_up" or "down". Replaces any existing
            entry for the same model type.
    """
    _WEIGHT_PREPROCESSORS[model_type] = preprocessor_fn
|
| |
|
| |
|
def get_weight_preprocessor(model_type: str):
    """Return the preprocessor registered for *model_type*, or None."""
    return _WEIGHT_PREPROCESSORS.get(model_type)
|
| |
|
| |
|
def preprocess_weight(
    weight: torch.Tensor, proj_type: str, hidden_dim: int, model_type=None
):
    """
    Normalize an expert weight tensor into (E, in_dim, out_dim) for grouped_mm.

    A registered model-specific preprocessor wins; otherwise the layout is
    inferred from where hidden_dim appears in the shape.

    Args:
        weight: stacked expert weights, (E, dim1, dim2) or similar.
        proj_type: "gate_up" (consumes hidden_dim on axis 1) or "down"
            (produces hidden_dim on axis 2).
        hidden_dim: model hidden size used for shape inference.
        model_type: optional key into the preprocessor registry.

    Returns:
        Weight tensor laid out as (E, in_dim, out_dim).
    """
    if model_type and model_type in _WEIGHT_PREPROCESSORS:
        return _WEIGHT_PREPROCESSORS[model_type](weight, proj_type, hidden_dim)

    # gate_up expects hidden_dim on axis 1; down expects it on axis 2.
    # When it is not there, the weight is in F.linear layout — transpose.
    expected_axis = 1 if proj_type == "gate_up" else 2
    if weight.shape[expected_axis] == hidden_dim:
        return weight
    return weight.transpose(-2, -1)
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | def _is_moe_experts_module(module) -> bool:
|
| | """
|
| | Check if module is an MoE experts layer (generic, not model-specific).
|
| |
|
| | Detects modules with stacked expert weights as 3D nn.Parameter:
|
| | - gate_up_proj/down_proj pattern (Qwen3-MoE, Qwen3-VL-MoE, etc.)
|
| | - w1/w2/w3 pattern (older MoE models)
|
| | """
|
| |
|
| |
|
| | import torch.nn as nn
|
| |
|
| |
|
| |
|
| |
|
| | if hasattr(module, "gate_up_proj"):
|
| | param = module.gate_up_proj
|
| |
|
| |
|
| | if isinstance(param, (nn.Parameter, torch.Tensor)) and param.ndim in (2, 3):
|
| | return True
|
| |
|
| |
|
| | if hasattr(module, "w1") and hasattr(module, "w2"):
|
| | w1 = module.w1
|
| | if isinstance(w1, (nn.Parameter, torch.Tensor)) and w1.ndim in (2, 3):
|
| | return True
|
| |
|
| | return False
|
| |
|
| |
|
| |
|
# Backwards-compatible alias for the extractor's older name.
_get_moe_lora_weights = _extract_lora_from_wrapper
|
| |
|
| |
|
| |
|
# Original PEFT ParamWrapper.forward, captured once by patch_param_wrapper_for_moe().
_original_param_wrapper_forward = None
|
| |
|
| |
|
def _patched_param_wrapper_forward(
    self, x: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
    """
    Patched ParamWrapper.forward for MoE separated LoRA.

    For MoE expert modules:
    - Bypasses PEFTs _activate_lora parametrization context
    - Stores LoRA data by parameter_name for forward_native_grouped_mm to use

    For non-MoE modules:
    - Falls back to original PEFT forward
    """
    # The layer directly below this wrapper (may itself be another wrapper)...
    immediate_base_layer = self.base_layer

    # ...versus the fully unwrapped experts module.
    experts_module = self.get_base_layer()

    use_separated = _should_use_separated_lora()
    param_name = getattr(self, "parameter_name", None)

    if (
        use_separated
        and param_name in ("gate_up_proj", "down_proj")
        and _is_moe_experts_module(experts_module)
    ):
        # Adapters disabled/merged: behave exactly like plain PEFT.
        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            return immediate_base_layer(x, *args, **kwargs)

        if self.merged:
            return immediate_base_layer(x, *args, **kwargs)

        # Ensure num_experts is known before extracting per-expert LoRA.
        if not hasattr(self, "num_experts"):
            if hasattr(experts_module, "num_experts"):
                self.num_experts = experts_module.num_experts
            elif hasattr(experts_module, param_name):
                p = getattr(experts_module, param_name)
                if hasattr(p, "shape") and len(p.shape) >= 1:
                    # Leading dim of the stacked weight is the expert count.
                    self.num_experts = p.shape[0]

        lora_data = _extract_lora_from_wrapper(self)

        if lora_data is not None and param_name:
            # Stash the extracted LoRA weights on the experts module so the
            # grouped-mm forward can pick them up during this call.
            lora_attr = f"_unsloth_lora_{param_name}"
            setattr(experts_module, lora_attr, lora_data)

        try:
            result = immediate_base_layer(x, *args, **kwargs)
        finally:
            # Always clear the stash, even when the forward raises.
            if param_name:
                lora_attr = f"_unsloth_lora_{param_name}"
                if hasattr(experts_module, lora_attr):
                    delattr(experts_module, lora_attr)

        return result

    # Non-MoE parameter: defer to the original PEFT implementation.
    return _original_param_wrapper_forward(self, x, *args, **kwargs)
|
| |
|
| |
|
def patch_param_wrapper_for_moe():
    """
    Patch PEFT's ParamWrapper.forward to use separated LoRA for MoE.

    This should be called after PEFT is imported.

    Returns:
        True when the patch was applied (possibly by the compiled-cache copy),
        False when PEFT is not importable.
    """
    global _original_param_wrapper_forward

    # Prefer the compiled-cache copy of this function so the cache and this
    # module patch consistently.
    module = _load_cached_moe_utils_module()
    if module is not None and hasattr(module, "patch_param_wrapper_for_moe"):
        try:
            return module.patch_param_wrapper_for_moe()
        except Exception:
            pass

    try:
        from peft.tuners.lora.layer import ParamWrapper

        # Capture the original only once so repeated patching stays reversible.
        if _original_param_wrapper_forward is None:
            _original_param_wrapper_forward = ParamWrapper.forward

        ParamWrapper.forward = _patched_param_wrapper_forward

        return True
    except ImportError:
        return False
|
| |
|
| |
|
def forward_native_grouped_mm(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """
    Native Pytorch grouped GEMM MoE forward pass.
    Uses torch._grouped_mm which is significantly faster than loop and works without Triton dependencies.
    Requires torch._grouped_mm support (verified via runtime check).

    Args:
        hidden_states: (batch, seq, hidden) or already-flattened (tokens, hidden).
        top_k_index: (tokens, top_k) expert ids chosen per token.
        top_k_weights: routing weights matching top_k_index.

    Returns:
        Tensor with the same shape as the input hidden_states.
    """
    if not _check_torch_grouped_mm_supported():
        # NOTE(review): this assumes CUDA is present to report the capability;
        # on a CPU-only box the get_device_capability call itself would fail.
        major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
        raise RuntimeError(
            f"torch._grouped_mm is not supported on this device (Compute Capability {major}.{minor}). "
            f"Set UNSLOTH_MOE_BACKEND='unsloth_triton' or 'native_torch' to use a compatible backend."
        )

    # Flatten (batch, seq, hidden) to (tokens, hidden).
    is_2d_input = hidden_states.dim() == 2
    if is_2d_input:
        sequence_length, hidden_dim = hidden_states.shape
        batch_size = 1
    else:
        batch_size, sequence_length, hidden_dim = hidden_states.shape

    hidden_states = hidden_states.view(-1, hidden_dim)

    # Count (token, expert) assignments per expert.
    flat_top_k = top_k_index.view(-1)
    num_tokens_per_expert = torch.bincount(flat_top_k, minlength=self.num_experts).int()

    # Stable sort groups assignments by expert id; dividing by top_k maps
    # each sorted assignment slot back to its source token row.
    sorted_indices = torch.argsort(flat_top_k, stable=True)
    token_indices = sorted_indices // top_k_index.shape[-1]

    # Gather tokens into expert-contiguous order for the grouped GEMM.
    permuted_input = hidden_states[token_indices]

    # Cumulative (exclusive-end) offsets per expert for torch._grouped_mm.
    offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)

    use_separated_lora = _should_use_separated_lora()
    gate_up_lora = None

    # LoRA weights may have been stashed by _patched_param_wrapper_forward...
    if getattr(self, "_unsloth_lora_gate_up_proj", None) is not None:
        gate_up_lora = self._unsloth_lora_gate_up_proj[
            :3
        ]
    # ...otherwise extract them directly from the wrapped parameter.
    elif (
        use_separated_lora
        and hasattr(self, "gate_up_proj")
        and _has_lora_adapters(self.gate_up_proj)
    ):
        gate_up_lora = _extract_lora_weights(
            self.gate_up_proj, num_experts=self.num_experts, experts_module=self
        )

    if hasattr(self, "gate_up_proj"):
        gate_up_base = _get_base_weight(self.gate_up_proj)

        model_type = getattr(self, "_unsloth_model_type", None)

        # Normalize the stacked weight to (E, in_dim, out_dim).
        w1 = preprocess_weight(gate_up_base, "gate_up", hidden_dim, model_type)

        mm1_out = _grouped_mm_with_backward_fix(permuted_input, w1, offsets)

        if gate_up_lora is not None:
            first_weight, second_weight, scaling = gate_up_lora

            # Match activation dtype; contiguity matters for grouped_mm grads.
            first_weight = first_weight.to(permuted_input.dtype).contiguous()
            second_weight = second_weight.to(permuted_input.dtype).contiguous()

            try:
                lora_out = _grouped_mm_with_backward_fix(permuted_input, first_weight, offsets)
                lora_out = lora_out.contiguous()
            except RuntimeError as e:
                raise e

            try:
                # The second grouped GEMM needs its last dim padded to a
                # multiple of 8; slice the padding back off afterwards.
                if second_weight.shape[-1] % 8 != 0:
                    pad_size = 8 - (second_weight.shape[-1] % 8)
                    second_weight_padded = F.pad(
                        second_weight, (0, pad_size)
                    ).contiguous()
                    lora_delta = _grouped_mm_with_backward_fix(
                        lora_out, second_weight_padded, offsets
                    )
                    lora_delta = lora_delta[:, :-pad_size]
                else:
                    lora_delta = _grouped_mm_with_backward_fix(
                        lora_out, second_weight, offsets
                    )
            except RuntimeError:
                # Fallback: explicit per-expert matmul loop over CPU offsets.
                lora_delta = torch.empty(
                    (lora_out.shape[0], second_weight.shape[-1]),
                    dtype=lora_out.dtype,
                    device=lora_out.device,
                )
                cpu_offsets = offsets.cpu().tolist()
                prev_offset = 0
                for i, end in enumerate(cpu_offsets):
                    if prev_offset < end:
                        lora_delta[prev_offset:end] = torch.matmul(
                            lora_out[prev_offset:end], second_weight[i]
                        )
                    prev_offset = end

            mm1_out = mm1_out + lora_delta * scaling

        if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None:
            # Expand the per-expert bias to one row per routed token.
            num_repeats = num_tokens_per_expert.to(self.gate_up_proj_bias.device)
            bias_expanded = self.gate_up_proj_bias.repeat_interleave(num_repeats, dim=0)
            mm1_out = mm1_out + bias_expanded.to(mm1_out.dtype)

        # GPT-OSS interleaves gate/up columns; others store them as halves.
        if "GptOssExperts" in self.__class__.__name__:
            gate = mm1_out[..., ::2]
            up = mm1_out[..., 1::2]
        else:
            gate, up = mm1_out.chunk(2, dim=-1)

    elif hasattr(self, "w1") and hasattr(self, "w3"):
        # Older split-weight layout: w1 = gate projection, w3 = up projection.
        w1_base = _get_base_weight(self.w1)
        w3_base = _get_base_weight(self.w3)

        w1 = w1_base.transpose(-2, -1)
        w3 = w3_base.transpose(-2, -1)

        gate = _grouped_mm_with_backward_fix(permuted_input, w1, offsets)
        up = _grouped_mm_with_backward_fix(permuted_input, w3, offsets)

        if use_separated_lora:
            if _has_lora_adapters(self.w1):
                w1_lora = _extract_lora_weights(self.w1, experts_module=self)
                if w1_lora is not None:
                    lora_A, lora_B, scaling = w1_lora
                    lora_A_t = lora_A.transpose(-2, -1)
                    lora_A_out = _grouped_mm_with_backward_fix(
                        permuted_input, lora_A_t, offsets
                    )
                    lora_B_t = lora_B.transpose(-2, -1)
                    lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets)
                    gate = gate + lora_B_out * scaling

            if _has_lora_adapters(self.w3):
                w3_lora = _extract_lora_weights(self.w3, experts_module=self)
                if w3_lora is not None:
                    lora_A, lora_B, scaling = w3_lora
                    lora_A_t = lora_A.transpose(-2, -1)
                    lora_A_out = _grouped_mm_with_backward_fix(
                        permuted_input, lora_A_t, offsets
                    )
                    lora_B_t = lora_B.transpose(-2, -1)
                    lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets)
                    up = up + lora_B_out * scaling
    else:
        raise AttributeError("MoE layer must have 'gate_up_proj' or 'w1'/'w3'.")

    if "GptOssExperts" in self.__class__.__name__:
        # GPT-OSS uses a clamped, sigmoid-gated GLU instead of plain SiLU.
        limit = getattr(self, "limit", 7.0)
        alpha = getattr(self, "alpha", 1.702)

        gate = gate.clamp(min=None, max=limit)
        up = up.clamp(min=-limit, max=limit)
        glu = gate * torch.sigmoid(gate * alpha)
        inter = (up + 1.0) * glu
    else:
        inter = F.silu(gate) * up

    down_lora = None

    # Same stash-then-extract preference as for gate_up above.
    if getattr(self, "_unsloth_lora_down_proj", None) is not None:
        down_lora = self._unsloth_lora_down_proj[
            :3
        ]
    elif (
        use_separated_lora
        and hasattr(self, "down_proj")
        and _has_lora_adapters(self.down_proj)
    ):
        down_lora = _extract_lora_weights(self.down_proj, num_experts=self.num_experts, experts_module=self)

    if hasattr(self, "down_proj"):
        down_base = _get_base_weight(self.down_proj)

        model_type = getattr(self, "_unsloth_model_type", None)

        w2 = preprocess_weight(down_base, "down", hidden_dim, model_type)

        mm2_out = _grouped_mm_with_backward_fix(inter, w2, offsets)

        if down_lora is not None:
            first_weight, second_weight, scaling = down_lora

            first_weight = first_weight.to(inter.dtype).contiguous()
            second_weight = second_weight.to(inter.dtype).contiguous()

            lora_out = _grouped_mm_with_backward_fix(inter, first_weight, offsets)
            lora_out = lora_out.contiguous()

            try:
                lora_delta = _grouped_mm_with_backward_fix(lora_out, second_weight, offsets)
            except RuntimeError:
                # Same per-expert fallback loop as the gate_up path.
                lora_delta = torch.empty(
                    (lora_out.shape[0], second_weight.shape[-1]),
                    dtype=lora_out.dtype,
                    device=lora_out.device,
                )
                cpu_offsets = offsets.cpu().tolist()
                prev_offset = 0
                for i, end in enumerate(cpu_offsets):
                    if prev_offset < end:
                        lora_delta[prev_offset:end] = torch.matmul(
                            lora_out[prev_offset:end], second_weight[i]
                        )
                    prev_offset = end

            mm2_out = mm2_out + lora_delta * scaling

        if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None:
            bias_expanded = self.down_proj_bias.repeat_interleave(
                num_tokens_per_expert.to(self.down_proj_bias.device), dim=0
            ).to(mm2_out.device)
            mm2_out = mm2_out + bias_expanded.to(mm2_out.dtype)

    elif hasattr(self, "w2"):
        w2_base = _get_base_weight(self.w2)
        w2 = w2_base.transpose(-2, -1)

        mm2_out = _grouped_mm_with_backward_fix(inter, w2, offsets)

        if use_separated_lora and _has_lora_adapters(self.w2):
            w2_lora = _extract_lora_weights(self.w2, experts_module=self)
            if w2_lora is not None:
                lora_A, lora_B, scaling = w2_lora
                lora_A_t = lora_A.transpose(-2, -1).contiguous()
                lora_A_out = _grouped_mm_with_backward_fix(inter, lora_A_t, offsets)
                lora_B_t = lora_B.transpose(-2, -1).contiguous()
                lora_B_out = _grouped_mm_with_backward_fix(lora_A_out, lora_B_t, offsets)
                mm2_out = mm2_out + lora_B_out * scaling
    else:
        raise AttributeError("MoE layer must have 'down_proj' or 'w2'.")

    # Scale each expert output by its routing weight, then scatter-add the
    # results back into original token order.
    flat_weights = top_k_weights.view(-1)
    permuted_weights = flat_weights[sorted_indices]
    mm2_out = mm2_out * permuted_weights.unsqueeze(-1)

    final_hidden_states = torch.zeros(
        (batch_size * sequence_length, hidden_dim),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )

    final_hidden_states.index_add_(0, token_indices, mm2_out.to(hidden_states.dtype))

    if is_2d_input:
        return final_hidden_states

    return final_hidden_states.view(batch_size, sequence_length, hidden_dim)
|
| |
|
| |
|
def forward_triton_grouped_gemm(
    self,
    hidden_states: torch.Tensor,
    top_k_index: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    """
    Grouped GEMM MoE forward pass using Triton kernels.
    Compatible with torch.compile (recommended mode="max-autotune" with cudagraph_mark_step_begin).

    Args:
        hidden_states: (batch, seq, hidden) or (tokens, hidden) input activations.
        top_k_index:   (tokens, top_k) expert indices chosen by the router.
        top_k_weights: (tokens, top_k) routing weights for the chosen experts.

    Returns:
        Tensor with the same shape as ``hidden_states``.
    """
    # Local imports: Triton kernel interface and the autotune-config cache.
    from unsloth.kernels.moe.grouped_gemm.interface import grouped_gemm
    from unsloth.kernels.moe.autotune_cache import get_or_autotune_moe_kernels

    # Lazily created per-module cache of (intermediate_dim, gemm1, gemm2) configs.
    if not hasattr(self, "_unsloth_moe_configs"):
        self._unsloth_moe_configs = None

    use_separated_lora = _should_use_separated_lora()

    # Flatten a 3D (batch, seq, hidden) input to 2D (tokens, hidden) for the GEMMs.
    is_3d = hidden_states.dim() == 3
    if is_3d:
        batch_size, seq_len, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        num_tokens = batch_size * seq_len

        if top_k_index.dim() == 3:
            top_k_index = top_k_index.view(-1, top_k_index.shape[-1])
        if top_k_weights.dim() == 3:
            top_k_weights = top_k_weights.view(-1, top_k_weights.shape[-1])
    else:
        num_tokens, hidden_dim = hidden_states.shape

    top_k = top_k_index.shape[1]

    if self._unsloth_moe_configs is None:
        # BUGFIX: derive intermediate_dim from whichever axis is NOT hidden_dim,
        # mirroring the layout check used to pick w1 below. The previous
        # `self.gate_up_proj.shape[1] // 2` silently assumed an (E, 2*I, H)
        # layout and produced hidden_dim // 2 for the (E, H, 2*I) layout.
        if self.gate_up_proj.shape[-1] == hidden_dim:
            intermediate_dim = self.gate_up_proj.shape[-2] // 2
        else:
            intermediate_dim = self.gate_up_proj.shape[-1] // 2

        # Autotune (or load cached) kernel configs for gate_up (GEMM 1).
        gemm1_configs = get_or_autotune_moe_kernels(
            num_experts=self.num_experts,
            hidden_dim=hidden_dim,
            intermediate_dim=intermediate_dim * 2,
            top_k=top_k,
            dtype=hidden_states.dtype,
        )

        # Autotune (or load cached) kernel configs for down_proj (GEMM 2).
        gemm2_configs = get_or_autotune_moe_kernels(
            num_experts=self.num_experts,
            hidden_dim=intermediate_dim,
            intermediate_dim=hidden_dim,
            top_k=top_k,
            dtype=hidden_states.dtype,
        )

        self._unsloth_moe_configs = (intermediate_dim, gemm1_configs, gemm2_configs)

        # Free autotuning scratch memory once, on the first forward only —
        # an unconditional per-forward empty_cache() would serialize the GPU.
        torch.cuda.empty_cache()

    intermediate_dim, gemm1_configs, gemm2_configs = self._unsloth_moe_configs

    fwd_config_1, bwd_dX_config_1, bwd_dW_config_1 = gemm1_configs
    fwd_config_2, bwd_dX_config_2, bwd_dW_config_2 = gemm2_configs

    # Per-expert token counts and the permutation that groups tokens by expert.
    token_counts_by_expert, gather_indices = _get_routing_indices(
        top_k_index, self.num_experts
    )
    offsets = torch.cumsum(token_counts_by_expert, dim=0, dtype=torch.int32)

    # grouped_gemm consumes expert weights as (E, out_features, in_features);
    # transpose if gate_up_proj is stored as (E, hidden, 2*intermediate).
    if self.gate_up_proj.shape[-1] == hidden_dim:
        w1 = self.gate_up_proj
    else:
        w1 = self.gate_up_proj.transpose(-2, -1).contiguous()

    # GEMM 1: permute tokens into expert order (permute_x) and project to 2*I.
    first_gemm_output = grouped_gemm(
        X=hidden_states,
        W=w1,
        m_sizes=token_counts_by_expert,
        topk=top_k,
        gather_indices=gather_indices,
        permute_x=True,
        permute_y=False,
        autotune=False,
        kernel_config_fwd=fwd_config_1,
        kernel_config_bwd_dX=bwd_dX_config_1,
        kernel_config_bwd_dW=bwd_dW_config_1,
        is_first_gemm=True,
    )

    # SwiGLU: silu(gate) * up, halving the feature dim from 2*I to I.
    intermediate = _silu_and_mul(first_gemm_output)

    # Resolve LoRA weights for down_proj: prefer a pre-extracted cache, else
    # extract on the fly when separated-LoRA mode is active.
    down_lora = None
    if getattr(self, "_unsloth_lora_down_proj", None) is not None:
        down_lora = self._unsloth_lora_down_proj[:3]
    elif (
        use_separated_lora
        and hasattr(self, "down_proj")
        and _has_lora_adapters(self.down_proj)
    ):
        down_lora = _extract_lora_weights(self.down_proj, num_experts=self.num_experts)

    # Same (E, out, in) layout normalization for the down projection.
    if self.down_proj.shape[-1] == intermediate.shape[-1]:
        w2 = self.down_proj
    else:
        w2 = self.down_proj.transpose(-2, -1).contiguous()

    # GEMM 2: project back to hidden_dim and un-permute outputs (permute_y).
    second_gemm_output = grouped_gemm(
        X=intermediate,
        W=w2,
        m_sizes=token_counts_by_expert,
        topk=top_k,
        gather_indices=gather_indices,
        permute_x=False,
        permute_y=True,
        autotune=False,
        kernel_config_fwd=fwd_config_2,
        kernel_config_bwd_dX=bwd_dX_config_2,
        kernel_config_bwd_dW=bwd_dW_config_2,
        is_first_gemm=False,
    )

    if down_lora is not None:
        first_weight, second_weight, scaling = down_lora

        # Match activation dtype so the grouped matmul does not upcast.
        first_weight = first_weight.to(intermediate.dtype)
        second_weight = second_weight.to(intermediate.dtype)

        lora_delta = _apply_lora_grouped_mm(
            intermediate,
            first_weight,
            second_weight,
            offsets,
            scaling,
            grouped_mm_func=native_moe_grouped_mm
        )

        second_gemm_output = second_gemm_output + lora_delta

    # Weight each token's top-k expert outputs and reduce over the k dimension.
    top_k_weights_casted = top_k_weights.to(hidden_states.dtype)
    final_hidden_states = (
        second_gemm_output.view(num_tokens, top_k, hidden_dim)
        * top_k_weights_casted[..., None]
    )
    final_hidden_states = final_hidden_states.sum(dim=1)

    if is_3d:
        final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)

    return final_hidden_states
|
| |
|
| |
|
| | @torch.compiler.disable
|
| | def forward_native_moe_loop(
|
| | self,
|
| | hidden_states: torch.Tensor,
|
| | top_k_index: torch.Tensor,
|
| | top_k_weights: torch.Tensor,
|
| | ) -> torch.Tensor:
|
| | """
|
| | Loop-based MoE forward pass. Loops over experts that have tokens routed to them.
|
| | Explicitly disabled for torch.compile to prevent graph breaks/recompilation issues with dynamic control flow.
|
| | """
|
| |
|
| | final_hidden_states = torch.zeros_like(hidden_states)
|
| |
|
| |
|
| | with torch.no_grad():
|
| | expert_mask = F.one_hot(top_k_index, num_classes=self.num_experts)
|
| | expert_mask = expert_mask.permute(2, 1, 0)
|
| | expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
| |
|
| |
|
| | for expert_idx_t in expert_hit:
|
| | expert_idx = expert_idx_t.item()
|
| |
|
| |
|
| | top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
|
| |
|
| |
|
| | current_state = hidden_states[token_idx]
|
| |
|
| |
|
| |
|
| | if hasattr(self, "gate_up_proj"):
|
| | gate, up = F.linear(current_state, self.gate_up_proj[expert_idx]).chunk(
|
| | 2, dim=-1
|
| | )
|
| | else:
|
| | gate = F.linear(current_state, self.w1[expert_idx])
|
| | up = F.linear(current_state, self.w3[expert_idx])
|
| |
|
| | current_hidden_states = self.act_fn(gate) * up
|
| |
|
| |
|
| | if hasattr(self, "down_proj"):
|
| | current_hidden_states = F.linear(
|
| | current_hidden_states, self.down_proj[expert_idx]
|
| | )
|
| | else:
|
| | current_hidden_states = F.linear(current_hidden_states, self.w2[expert_idx])
|
| |
|
| |
|
| | current_hidden_states = (
|
| | current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
|
| | )
|
| |
|
| |
|
| | final_hidden_states.index_add_(
|
| | 0, token_idx, current_hidden_states.to(final_hidden_states.dtype)
|
| | )
|
| |
|
| | return final_hidden_states
|
| |
|