# GPT-2-XL-Stripped / stripper-gpt2.py
# Uploaded by xzuyn — commit 31b540f ("Upload 8 files")
import torch
import sys
from safetensors.torch import save_file
# Load the original checkpoint (a mapping of parameter name -> tensor) onto the CPU.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only run this on checkpoints from a trusted source. torch >= 1.13 also accepts
# weights_only=True to restrict unpickling to tensors and plain containers.
model = torch.load('pytorch_model.bin', map_location="cpu")

# Print every key in checkpoint order so you can choose stop_key and final_keys below.
for key in model:
    print(key)

# Run with --list-keys to stop right after listing the keys, before anything is
# stripped or saved. This replaces manually (un)commenting a sys.exit() call.
if "--list-keys" in sys.argv[1:]:
    sys.exit()
# First key to drop. It and every key after it (in checkpoint order) are
# discarded, except those listed in final_keys. Only keys that come before it
# are kept. With GPT-2 key naming this presumably keeps blocks h.0-h.11;
# verify against the printed key list.
stop_key = "h.12.ln_1.weight"

# Keys that are always carried over, even though they come after stop_key
# (by name, presumably the final layer norm).
final_keys = [
    "ln_f.weight",
    "ln_f.bias",
]
# Build the stripped state dict, preserving the checkpoint's key order:
#   - every key before stop_key is kept;
#   - stop_key itself and everything after it are dropped,
#     except keys listed in final_keys, which are always kept.
stripped = {}
stopped = False
for key, value in model.items():
    if key == stop_key:
        # The stop key is never saved, even if it also appears in final_keys.
        stopped = True
        continue
    if not stopped or key in final_keys:
        stripped[key] = value

# A typo in stop_key or final_keys would otherwise silently produce an
# unstripped or incomplete model, so report it (without aborting).
if not stopped:
    print(
        f"Warning: stop key {stop_key!r} not found; no tensors were dropped.",
        file=sys.stderr,
    )
missing_final_keys = [key for key in final_keys if key not in stripped]
if missing_final_keys:
    print(
        f"Warning: final keys missing from the stripped output: {missing_final_keys}",
        file=sys.stderr,
    )
# Write the stripped weights in both safetensors and pickle (.bin) formats.
# The {"format": "pt"} metadata matches what save_pretrained writes. transformers'
# safetensors loader checks this field, and some versions reject files without it.
save_file(stripped, 'pytorch_model_stripped.safetensors', metadata={"format": "pt"})
torch.save(stripped, 'pytorch_model_stripped.bin')