# Gemma4-Text / extract_text_model.py
# OpenMOSE — "Upload folder using huggingface_hub" (commit 15dcafa)
"""
Extract the language model (text-only) weights from Gemma 4 multimodal safetensors.
- Filters keys containing 'language_model'
- Renames: model.language_model.X -> model.X
- Saves as sharded safetensors (10GB per shard)
- Generates model.safetensors.index.json
"""
import glob
import json
import os
import torch
from safetensors import safe_open
from safetensors.torch import save_file
# Directory holding the original multimodal Gemma checkpoint shards.
SRC_DIR = "/workspace/llm/gemma-4-31B-it"
# Output directory for the extracted text-only checkpoint.
DST_DIR = "/workspace/llm/gemma-4-31B-Text"
# Upper bound on bytes per output shard.
MAX_SHARD_SIZE = 10 * 1024 * 1024 * 1024 # 10GB
def _collect_language_model_tensors(src_files):
    """Load every 'language_model' tensor from *src_files*, renaming keys.

    Keys of the form 'model.language_model.X' become 'model.X'; keys that
    contain 'language_model' but lack that exact prefix are kept unchanged.
    Returns a dict mapping renamed key -> CPU tensor.
    """
    tensors = {}
    for path in src_files:
        print(f"Reading {os.path.basename(path)}...")
        with safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                if "language_model" in key:
                    new_key = key.replace("model.language_model.", "model.")
                    tensors[new_key] = f.get_tensor(key)
    return tensors


def _split_into_shards(all_tensors):
    """Greedily pack tensors (in sorted-key order) into <= MAX_SHARD_SIZE shards.

    A single tensor larger than MAX_SHARD_SIZE still gets its own shard.
    Returns a list of {key: tensor} dicts.
    """
    shards = []
    current, current_size = {}, 0
    for key in sorted(all_tensors):
        tensor = all_tensors[key]
        size = tensor.nelement() * tensor.element_size()
        # Close out the current shard when this tensor would overflow it.
        if current and current_size + size > MAX_SHARD_SIZE:
            shards.append(current)
            current, current_size = {}, 0
        current[key] = tensor
        current_size += size
    if current:
        shards.append(current)
    return shards


def _save_shards(shards):
    """Write each shard to DST_DIR; return weight name -> shard filename map."""
    total_shards = len(shards)
    weight_map = {}
    for i, shard in enumerate(shards):
        filename = f"model-{i+1:05d}-of-{total_shards:05d}.safetensors"
        filepath = os.path.join(DST_DIR, filename)
        shard_size = sum(t.nelement() * t.element_size() for t in shard.values())
        # BUGFIX: the original printed the literal "(unknown)" here instead
        # of the shard filename (mangled f-string placeholder).
        print(f"Saving {filename} ({shard_size / 1e9:.2f} GB, {len(shard)} tensors)...")
        save_file(shard, filepath)
        for key in shard:
            weight_map[key] = filename
    return weight_map


def main():
    """Extract the text-only language model from a multimodal checkpoint.

    Pipeline:
      1. Collect all 'language_model' tensors from SRC_DIR, renaming
         'model.language_model.X' -> 'model.X'.
      2. Split them into shards of at most MAX_SHARD_SIZE bytes.
      3. Save the shards to DST_DIR and write a HF-style
         model.safetensors.index.json mapping each weight to its shard.
    """
    os.makedirs(DST_DIR, exist_ok=True)

    src_files = sorted(glob.glob(os.path.join(SRC_DIR, "*.safetensors")))
    print(f"Source files: {len(src_files)}")

    all_tensors = _collect_language_model_tensors(src_files)
    print(f"Extracted {len(all_tensors)} tensors")

    shards = _split_into_shards(all_tensors)
    print(f"Splitting into {len(shards)} shards")

    weight_map = _save_shards(shards)

    # HF loaders use 'total_size' and 'weight_map' to locate each weight.
    index = {
        "metadata": {"total_size": sum(t.nelement() * t.element_size() for t in all_tensors.values())},
        "weight_map": weight_map,
    }
    index_path = os.path.join(DST_DIR, "model.safetensors.index.json")
    with open(index_path, "w") as f:
        json.dump(index, f, indent=2)
    print(f"Done! Index written to {index_path}")
# Run the extraction only when executed as a script, not on import.
if __name__ == "__main__":
    main()