|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import importlib.resources |
|
import string |
|
from abc import ABC, abstractmethod |
|
from typing import ClassVar, Dict, List, Optional, Tuple, Union |
|
|
|
from pydantic import BaseModel, Field |
|
from transformers import PretrainedConfig |
|
from typing_extensions import Literal |
|
|
|
import mergekit._data.architectures |
|
|
|
|
|
class WeightInfo(BaseModel, frozen=True):
    """Information about an individual weight tensor in a model.

    Immutable (frozen) pydantic model; instances are hashable and safe to
    share across threads.

    Attributes:
        name (str):
            The name of the tensor representing the weight.
        is_embed (bool):
            Indicates whether the weight is for an embedding or language model head.
            Defaults to False.
        input_space (Optional[str]):
            The name of the input space associated with the weight, if applicable.
        output_space (Optional[str]):
            The name of the output space associated with the weight, if applicable.
        optional (bool):
            Indicates whether the weight can be omitted from a model.
            Defaults to False.
        aliases (Optional[List[str]]):
            List of alternative names for the weight, if applicable.
    """

    name: str
    is_embed: bool = False
    input_space: Optional[str] = None
    output_space: Optional[str] = None
    optional: bool = False
    aliases: Optional[List[str]] = None
|
|
|
|
|
class ProceduralSpaceInfo(BaseModel, frozen=True):
    """Defines a procedural space computed from one or more other spaces.

    Currently only supports residual connections (the `type` literal admits
    only "residual").

    Attributes:
        name (str): The name of the space defined.
        type (str): The type of procedural space.
        inputs (List[str]): List of names of spaces used to define this space.
    """

    name: str
    type: Literal["residual"]
    inputs: List[str]
|
|
|
|
|
class ArchitectureInfo(ABC):
    """Abstract description of a model architecture's weight layout.

    Subclasses enumerate the weight tensors of a model family and, where
    available, the activation-space information used by matching-based
    merge methods.
    """

    @abstractmethod
    def name(self) -> str:
        """Return the name of the architecture."""

    @abstractmethod
    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        """Return a list of all weights preceding the first layer."""

    @abstractmethod
    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        """Return a list of all weights following the final layer."""

    @abstractmethod
    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        """Return a list of all weights associated with a given layer."""

    @abstractmethod
    def sliceable(self) -> bool:
        """Return True if the layers of this architecture can be meaningfully sliced."""

    def num_layers_config_key(self) -> str:
        """Name of the config attribute that holds the layer count."""
        return "num_hidden_layers"

    def num_layers(self, config: PretrainedConfig) -> int:
        """Return the number of layers in the model described by `config`."""
        return getattr(config, self.num_layers_config_key())

    def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        """Return every weight of the model: pre-layer, per-layer, then post-layer."""
        weights = list(self.pre_weights(config))
        for idx in range(self.num_layers(config)):
            weights.extend(self.layer_weights(idx, config))
        weights.extend(self.post_weights(config))
        return weights

    def procedural_spaces(self, config: PretrainedConfig) -> List[ProceduralSpaceInfo]:
        """Return all procedurally defined spaces in the model (none by default)."""
        return []

    def has_defined_spaces(self) -> bool:
        """Return True if this architecture defines space information needed
        for matching-based merge methods (False by default)."""
        return False
|
|
|
|
|
class ConfiguredArchitectureInfo(BaseModel, frozen=True, arbitrary_types_allowed=True):
    """An ArchitectureInfo paired with a concrete model config.

    Convenience wrapper: every method delegates to `info`, implicitly
    passing `config`, so callers don't have to thread the config through
    each call. `arbitrary_types_allowed` is required because
    PretrainedConfig is not a pydantic type.
    """

    info: ArchitectureInfo
    config: PretrainedConfig

    def name(self) -> str:
        """Name of the wrapped architecture."""
        return self.info.name()

    def num_layers(self) -> int:
        """Number of layers in the configured model."""
        return self.info.num_layers(self.config)

    def pre_weights(self) -> List[WeightInfo]:
        """Weights preceding the first layer."""
        return self.info.pre_weights(self.config)

    def post_weights(self) -> List[WeightInfo]:
        """Weights following the final layer."""
        return self.info.post_weights(self.config)

    def layer_weights(self, index: int) -> List[WeightInfo]:
        """Weights belonging to layer `index`."""
        return self.info.layer_weights(index, self.config)

    def procedural_spaces(self) -> List[ProceduralSpaceInfo]:
        """Procedurally defined spaces for the configured model."""
        return self.info.procedural_spaces(self.config)

    def all_weights(self) -> List[WeightInfo]:
        """All weights of the configured model, in order."""
        return self.info.all_weights(self.config)
|
|
|
|
|
class JSONLayerTemplates(BaseModel, frozen=True):
    """Templated per-layer weights (and optional spaces) from a JSON definition."""

    # Weight templates instantiated once per layer; names may contain ${layer_index}.
    weights: List[WeightInfo]
    # Optional per-layer procedural space templates.
    procedural_spaces: Optional[List[ProceduralSpaceInfo]] = None
|
|
|
|
|
class JSONArchitectureDefinition(BaseModel, frozen=True):
    """Schema for a JSON architecture definition document.

    Mirrors the on-disk JSON files bundled in mergekit._data.architectures.
    """

    # The HF `model_type` this definition targets ("model_type" key in JSON).
    expected_model_type: str = Field(alias="model_type")
    # HF architecture class names (e.g. "LlamaForCausalLM") this applies to.
    architectures: List[str]
    # Weights that come before the first transformer layer.
    pre_weights: List[WeightInfo]
    # Templated weights repeated for every layer.
    layer_templates: JSONLayerTemplates
    # Weights that come after the final transformer layer.
    post_weights: List[WeightInfo]
    # Model-level procedural spaces, if any.
    procedural_spaces: Optional[List[ProceduralSpaceInfo]] = None
    # Config attribute holding the layer count; None means use the default.
    num_layers_config_key: Optional[str] = None
|
|
|
|
|
class TemplateWithArithmetic(string.Template): |
|
idpattern = r"(?a:[_a-z][_a-z0-9]*([+-]1)?)" |
|
|
|
|
|
def _template_substitution( |
|
template: str, num_layers: int, layer_idx: Optional[int] = None |
|
) -> str: |
|
if "{" not in template: |
|
return template |
|
|
|
substitutions = { |
|
"num_layers": num_layers, |
|
"num_layers+1": num_layers + 1, |
|
"num_layers-1": num_layers - 1, |
|
} |
|
|
|
if layer_idx is not None: |
|
substitutions.update( |
|
{ |
|
"layer_index": layer_idx, |
|
"layer_index+1": layer_idx + 1, |
|
"layer_index-1": layer_idx - 1, |
|
} |
|
) |
|
|
|
return TemplateWithArithmetic(template).substitute(substitutions) |
|
|
|
|
|
class JsonArchitectureInfo(ArchitectureInfo, BaseModel, frozen=True):
    """ArchitectureInfo backed by a declarative JSON definition.

    Weight and space names in the definition may contain ``${...}``
    placeholders (``layer_index``, ``num_layers``, optionally suffixed with
    ``+1``/``-1``) that are expanded against a concrete model config.
    """

    definition: JSONArchitectureDefinition

    def _substitute(
        self,
        item: Union[WeightInfo, ProceduralSpaceInfo],
        config: PretrainedConfig,
        layer_idx: Optional[int] = None,
    ) -> Union[WeightInfo, ProceduralSpaceInfo]:
        """Return a copy of `item` with all template placeholders expanded."""
        num_layers = self.num_layers(config)

        obj_dict = item.model_dump(mode="json", exclude_unset=True)
        for key in obj_dict:
            if isinstance(obj_dict[key], str):
                obj_dict[key] = _template_substitution(
                    obj_dict[key], num_layers, layer_idx
                )
            elif isinstance(obj_dict[key], list):
                # Lists of strings (e.g. aliases, space inputs) are expanded
                # element-wise; non-string entries pass through untouched.
                obj_dict[key] = [
                    (
                        _template_substitution(s, num_layers, layer_idx)
                        if isinstance(s, str)
                        else s
                    )
                    for s in obj_dict[key]
                ]
        return type(item).model_validate(obj_dict)

    def name(self) -> str:
        """Return the definition's expected model_type as the architecture name."""
        return self.definition.expected_model_type

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return [
            self._substitute(wi, config=config) for wi in self.definition.pre_weights
        ]

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        return [
            self._substitute(wi, config=config, layer_idx=index)
            for wi in self.definition.layer_templates.weights
        ]

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return [
            self._substitute(wi, config=config) for wi in self.definition.post_weights
        ]

    def sliceable(self) -> bool:
        return True

    def procedural_spaces(self, config: PretrainedConfig) -> List[ProceduralSpaceInfo]:
        """Model-level spaces plus per-layer space templates, fully expanded."""
        res = []
        for s in self.definition.procedural_spaces or []:
            res.append(self._substitute(s, config=config))
        for idx in range(self.num_layers(config)):
            for s in self.definition.layer_templates.procedural_spaces or []:
                res.append(self._substitute(s, config=config, layer_idx=idx))
        return res

    def has_defined_spaces(self) -> bool:
        """True if any procedural space or weight-level space name is defined."""
        if (
            self.definition.procedural_spaces
            or self.definition.layer_templates.procedural_spaces
        ):
            return True
        for wi in (
            self.definition.layer_templates.weights
            + self.definition.pre_weights
            + self.definition.post_weights
        ):
            if wi.input_space or wi.output_space:
                return True
        return False

    def num_layers_config_key(self) -> str:
        # BUGFIX: num_layers_config_key is Optional in the JSON schema
        # (default None). If a definition omits it, returning None here would
        # make num_layers() call getattr(config, None), which raises
        # TypeError. Fall back to the ArchitectureInfo default
        # ("num_hidden_layers") instead.
        if self.definition.num_layers_config_key is not None:
            return self.definition.num_layers_config_key
        return super().num_layers_config_key()
|
|
|
|
|
class MixtralTensorNames(ArchitectureInfo, BaseModel):
    """Weight naming for Mixtral sparse-MoE models.

    Mixtral follows Mistral's layout except that each layer's dense MLP is
    replaced by a router gate plus per-expert w1/w2/w3 projections, so all
    non-MLP weights are delegated to MISTRAL_INFO.
    """

    ARCHITECTURE_NAME: ClassVar[str] = "MixtralForCausalLM"
    num_local_experts: int

    def name(self) -> str:
        return "mixtral"

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        """Build an instance sized to the config's expert count."""
        return MixtralTensorNames(num_local_experts=config.num_local_experts)

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return MISTRAL_INFO.pre_weights(config)

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return MISTRAL_INFO.post_weights(config)

    def num_layers_config_key(self) -> str:
        return MISTRAL_INFO.num_layers_config_key()

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        prefix = f"model.layers.{index}"
        # Per-expert projection weights, then the routing gate.
        moe_names = [
            f"{prefix}.block_sparse_moe.experts.{expert_idx}.{param}.weight"
            for expert_idx in range(self.num_local_experts)
            for param in ("w1", "w2", "w3")
        ]
        moe_names.append(f"{prefix}.block_sparse_moe.gate.weight")

        weights = [WeightInfo(name=tensor_name) for tensor_name in moe_names]
        # Everything except the dense MLP carries over from Mistral unchanged.
        weights.extend(
            wi
            for wi in MISTRAL_INFO.layer_weights(index, config)
            if ".mlp." not in wi.name
        )
        return weights

    def sliceable(self) -> bool:
        return True

    def has_defined_spaces(self) -> bool:
        return False
|
|
|
|
|
def _load_json_arch(name: str) -> JsonArchitectureInfo:
    """Load a single bundled JSON architecture definition by file name.

    Uses the importlib.resources.files() traversable API rather than
    read_text(), which is deprecated since Python 3.11.
    """
    text = (
        importlib.resources.files(mergekit._data.architectures)
        .joinpath(name)
        .read_text(encoding="utf-8")
    )
    return JsonArchitectureInfo(
        definition=JSONArchitectureDefinition.model_validate_json(text)
    )
|
|
|
|
|
def _load_all_architectures() -> (
    Tuple[List[JsonArchitectureInfo], Dict[str, List[JsonArchitectureInfo]]]
):
    """Load every bundled JSON architecture definition.

    Returns:
        A tuple of (all loaded definitions, mapping from HF architecture
        class name to the definitions that declare it).

    Uses importlib.resources.files().iterdir() rather than contents(),
    which is deprecated since Python 3.11.
    """
    architectures: List[JsonArchitectureInfo] = [
        _load_json_arch(entry.name)
        for entry in importlib.resources.files(mergekit._data.architectures).iterdir()
        if entry.name.lower().endswith(".json")
    ]

    name_to_arch: Dict[str, List[JsonArchitectureInfo]] = {}
    for arch_info in architectures:
        for name in arch_info.definition.architectures:
            # One architecture class name may map to several definitions,
            # disambiguated later by model_type.
            name_to_arch.setdefault(name, []).append(arch_info)
    return architectures, name_to_arch
|
|
|
|
|
# Load every bundled JSON architecture definition once at import time.
JSON_ARCHITECTURES, NAME_TO_ARCH = _load_all_architectures()

# Mistral is loaded by name because MixtralTensorNames delegates to it directly.
MISTRAL_INFO = _load_json_arch("mistral.json")
|
|
|
|
|
def get_architecture_info(config: PretrainedConfig) -> ArchitectureInfo:
    """Resolve the ArchitectureInfo for a model config.

    Mixtral gets a dedicated handler; everything else is looked up in the
    bundled JSON definitions, disambiguating by model_type when several
    definitions share an architecture class name.

    Raises:
        RuntimeError: if the config lists multiple architectures, or no
            matching definition is found.
    """
    if len(config.architectures) != 1:
        raise RuntimeError("More than one architecture in config?")

    arch_name = config.architectures[0]

    if arch_name == MixtralTensorNames.ARCHITECTURE_NAME:
        return MixtralTensorNames.from_config(config)

    candidates = NAME_TO_ARCH.get(arch_name)
    if not candidates:
        raise RuntimeError(f"Unsupported architecture {arch_name}")

    if len(candidates) == 1:
        return candidates[0]

    # Multiple definitions claim this architecture name; pick by model_type.
    match = next(
        (
            c
            for c in candidates
            if c.definition.expected_model_type == config.model_type
        ),
        None,
    )
    if match is None:
        raise RuntimeError(
            f"Unsupported model_type {config.model_type} for architecture {arch_name}"
        )
    return match
|
|