| | """Merge LoRA adapter into base model weights. |
| | |
| | Usage: |
| | pip install torch transformers safetensors tqdm |
| | python merge.py --output ./merged_model |
| | |
| | Loads the MXFP4 base model, dequantizes to bf16, applies LoRA deltas, saves merged model. |
| | Requires ~300GB RAM. No GPU needed. |
| | """ |

import argparse
import json
import shutil
from pathlib import Path

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import AutoModelForCausalLM

BASE_MODEL = "openai/gpt-oss-120b"
ADAPTER_REPO = "LightningRodLabs/Trump-Forecaster"


def merge(output_dir: str):
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("Downloading adapter...")
    adapter_dir = Path(snapshot_download(ADAPTER_REPO))
    adapter_config = json.loads((adapter_dir / "adapter_config.json").read_text())
    scaling = adapter_config["lora_alpha"] / adapter_config["r"]
    adapter_weights = load_file(str(adapter_dir / "adapter_model.safetensors"))
    print(f"Adapter: {len(adapter_weights)} keys, scaling={scaling}")

    print("Loading base model (this takes a while; ~240GB in bf16)...")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, torch_dtype=torch.bfloat16, device_map="cpu", trust_remote_code=True,
    )
    state_dict = base_model.state_dict()
    del base_model
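    # state_dict() returns references to the live parameter tensors, so deleting
    # the module wrapper here frees only the nn.Module overhead, not the weights.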

    # Pair up lora_A / lora_B tensors by their module path.
    lora_pairs = {}
    for key, tensor in adapter_weights.items():
        clean = key.replace("base_model.model.", "", 1)
        if ".lora_A.weight" in clean:
            lora_pairs.setdefault(clean.replace(".lora_A.weight", ""), {})["A"] = tensor
        elif ".lora_B.weight" in clean:
            lora_pairs.setdefault(clean.replace(".lora_B.weight", ""), {})["B"] = tensor

    # Map each adapter module to the base-model tensor it updates, tagged with
    # how its delta must be applied: "add" directly, "add_t"/"even_t"/"odd_t"
    # transposed, with w1/w3 interleaved into the fused gate_up_proj.
    base_key_ops = {}
    for adapter_path in lora_pairs:
        if "unembed_tokens" in adapter_path:
            base_key_ops.setdefault("lm_head.weight", []).append(("add", adapter_path))
        elif ".attn." in adapter_path:
            base_key = adapter_path.replace(".attn.", ".self_attn.") + ".weight"
            base_key_ops.setdefault(base_key, []).append(("add", adapter_path))
        elif ".mlp.experts.w1" in adapter_path:
            prefix = adapter_path.split(".mlp.experts.w1")[0]
            base_key_ops.setdefault(prefix + ".mlp.experts.gate_up_proj", []).append(("even_t", adapter_path))
        elif ".mlp.experts.w3" in adapter_path:
            prefix = adapter_path.split(".mlp.experts.w3")[0]
            base_key_ops.setdefault(prefix + ".mlp.experts.gate_up_proj", []).append(("odd_t", adapter_path))
        elif ".mlp.experts.w2" in adapter_path:
            prefix = adapter_path.split(".mlp.experts.w2")[0]
            base_key_ops.setdefault(prefix + ".mlp.experts.down_proj", []).append(("add_t", adapter_path))

    # Apply the deltas, accumulating in fp32 and casting back to bf16 at the end.
    for base_key, ops in tqdm(sorted(base_key_ops.items()), desc="Merging LoRA"):
        w = state_dict[base_key].float()
        for op_type, adapter_path in ops:
            A = lora_pairs[adapter_path]["A"].float()
            B = lora_pairs[adapter_path]["B"].float()
            delta = torch.matmul(B, A) * scaling
            if op_type == "add":
                w += delta
            elif op_type == "even_t":
                w[:, :, ::2] += delta.transpose(1, 2)
            elif op_type == "odd_t":
                w[:, :, 1::2] += delta.transpose(1, 2)
            elif op_type == "add_t":
                w += delta.transpose(1, 2)
        state_dict[base_key] = w.to(torch.bfloat16)
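    # Every adapter pair should have been consumed by exactly one op above.
    applied = sum(len(ops) for ops in base_key_ops.values())
    print(f"Applied {applied}/{len(lora_pairs)} LoRA pairs across {len(base_key_ops)} base tensors")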

    # Greedily pack tensors into ~5GB shards.
    print(f"Saving to {output_dir}...")
    max_shard = 5 * 1024**3
    shards, current, size = [], {}, 0
    for k, v in state_dict.items():
        nbytes = v.numel() * v.element_size()
        if size + nbytes > max_shard and current:
            shards.append(current)
            current, size = {}, 0
        current[k] = v
        size += nbytes
    if current:
        shards.append(current)
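    # 5GB mirrors transformers' default max_shard_size; any cap works as long as
    # the index file written below maps every tensor name to its shard.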

    weight_map, total = {}, 0
    for i, shard in enumerate(shards):
        fname = f"model-{i+1:05d}-of-{len(shards):05d}.safetensors"
        save_file(shard, str(output_dir / fname))
        for k, v in shard.items():
            weight_map[k] = fname
            total += v.numel() * v.element_size()

    (output_dir / "model.safetensors.index.json").write_text(
        json.dumps({"metadata": {"total_size": total}, "weight_map": weight_map}, indent=2)
    )

    # Copy config/tokenizer files from the base repo (skipping its shard index so
    # it doesn't clobber ours), then drop the MXFP4 quantization config since the
    # merged weights are plain bf16 now.
    base_cache = Path(snapshot_download(BASE_MODEL, allow_patterns=["*.py", "*.json", "tokenizer*", "*.model"]))
    for f in base_cache.iterdir():
        if f.is_file() and f.name != "model.safetensors.index.json":
            shutil.copy2(f, output_dir / f.name)
    cfg = json.loads((output_dir / "config.json").read_text())
    cfg.pop("quantization_config", None)
    cfg["torch_dtype"] = "bfloat16"
    (output_dir / "config.json").write_text(json.dumps(cfg, indent=2))

    print(f"Done! Merged model saved to {output_dir} ({len(shards)} shards)")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output", required=True, help="Output directory for merged model")
    merge(parser.parse_args().output)