# ztrain/model.py
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted

from collections import defaultdict
from collections.abc import Iterator
import re
from typing import Optional

# Per-layer parameter keys. The four-part form is tried first; the three-part
# form covers keys with only two dotted components after the layer index
# (e.g. "model.layers.0.input_layernorm.weight").
_LAYER_KEY_LONG = re.compile(r"model\.layers\.(\d+)\.(.+)\.(.+)\.(.+)")
_LAYER_KEY_SHORT = re.compile(r"model\.layers\.(\d+)\.(.+)\.(.+)")

# Keys that do not belong to a numbered layer, mapped to their result tuple.
_TOP_LEVEL_KEYS = {
    "model.norm.weight": (-1, "norm", "", "weight"),
    "model.embed_tokens.weight": (-1, "embed_tokens", "", "weight"),
    "lm_head.weight": (-1, "lm_head", "", "weight"),
}


def generate_merge_group(
    group_data: list, parents: Optional[list[int]] = None
) -> Iterator[tuple[object, list[int]]]:
    """Walk nested lists and yield every non-list item with its index path.

    Drills down through ``group_data`` until it reaches items that are not
    lists, then yields each one together with the list of indices leading
    to it from the root.

    Args:
        group_data: An arbitrarily nested list.
        parents: Index path of ``group_data`` itself within the outer
            structure. Defaults to an empty path.

    Yields:
        ``(item, path)`` pairs, where ``path`` is ``parents`` followed by the
        item's index inside its enclosing list.
    """
    # A None sentinel avoids sharing one default list object across calls.
    path = [] if parents is None else parents
    for i, g in enumerate(group_data):
        if isinstance(g, list):
            yield from generate_merge_group(g, path + [i])
        else:
            yield g, path + [i]


def merge_groups(group_data: list) -> defaultdict[tuple[int, ...], list]:
    """Group the leaf items of a nested list by their enclosing list's path.

    Args:
        group_data: An arbitrarily nested list.

    Returns:
        A ``defaultdict(list)`` keyed by the index path (as a tuple) of the
        list that directly contains each item. Items sitting directly in
        ``group_data`` are stored under the empty tuple.
    """
    results = defaultdict(list)
    for g, k in generate_merge_group(group_data):
        # Drop the item's own index so siblings share a key.
        results[tuple(k[:-1])].append(g)
    return results


def get_layer_type(k: str) -> tuple[int, str, str, str]:
    """Split a parameter key into a layer index and three name components.

    Args:
        k: A parameter key such as ``"model.layers.3.self_attn.q_proj.weight"``.

    Returns:
        A 4-tuple ``(layer_index, first, second, last)``:

        * For ``model.layers.<n>.<a>.<b>.<c>`` keys: ``(n, a, b, c)``.
        * For ``model.layers.<n>.<a>.<b>`` keys: ``(n, a, "", b)``.
        * For ``model.norm.weight``, ``model.embed_tokens.weight`` and
          ``lm_head.weight``: ``(-1, <name>, "", "weight")``.
        * Otherwise a message is printed and
          ``(-1, "unknown", "unknown", "unknown")`` is returned.
    """
    m = _LAYER_KEY_LONG.match(k)
    if m is not None:
        return int(m.group(1)), m.group(2), m.group(3), m.group(4)

    # The shorter pattern has to be matched explicitly; reusing the failed
    # result from the longer pattern would make this branch unreachable.
    m = _LAYER_KEY_SHORT.match(k)
    if m is not None:
        return int(m.group(1)), m.group(2), "", m.group(3)

    known = _TOP_LEVEL_KEYS.get(k)
    if known is not None:
        return known

    print(f"Unknown key {k}")
    return -1, "unknown", "unknown", "unknown"