"""Strip a PyTorch checkpoint down to its leading keys plus a few final keys.

Loads a state dict and prints every key, so you can pick the cut point. It then
keeps every key that comes before ``stop_key``, plus any keys listed in
``final_keys``. The result is written in both safetensors and ``torch.save``
formats.

Running with no arguments reproduces the original hard-coded behaviour. Use
``--list-only`` to just print the keys; this replaces the old
"temporarily enable sys.exit()" workflow.
"""
import argparse
import sys
import warnings
from typing import Any, Dict, Iterable, Mapping, Optional, Sequence

import torch

DEFAULT_INPUT = "pytorch_model.bin"
# The key you want to stop at. It and every key after it are dropped (the key
# just before it is the last one kept), except for the final keys below.
DEFAULT_STOP_KEY = "h.12.ln_1.weight"
# The final key/keys in the model which are always saved, even after the stop key.
DEFAULT_FINAL_KEYS = ("ln_f.weight", "ln_f.bias")
# Output files are written as <stem>.safetensors and <stem>.bin.
DEFAULT_OUTPUT_STEM = "pytorch_model_stripped"


def strip_state_dict(
    state_dict: Mapping[str, Any],
    stop_key: str,
    final_keys: Iterable[str],
) -> Dict[str, Any]:
    """Return the entries of *state_dict* before *stop_key*, plus *final_keys*.

    The iteration order of *state_dict* is preserved in the result.
    *stop_key* itself is never kept, even if it is also listed in
    *final_keys*.

    Args:
        state_dict: Mapping of parameter name to value.
        stop_key: First key to drop. Every later key is dropped too, except
            those in *final_keys*.
        final_keys: Keys kept regardless of where they appear.

    Returns:
        A new dict. Its values are the same objects as in *state_dict*
        (nothing is copied).

    Warns:
        UserWarning: if *stop_key* is absent (so nothing was stripped), or if
            any of *final_keys* did not end up in the result.
    """
    keep_always = frozenset(final_keys)
    stripped: Dict[str, Any] = {}
    stopped = False
    for key, value in state_dict.items():
        if key == stop_key:
            stopped = True
            continue
        if not stopped or key in keep_always:
            stripped[key] = value

    # A typo in the stop key would otherwise silently produce an unstripped copy.
    if not stopped:
        warnings.warn(
            f"stop key {stop_key!r} not found; no keys were stripped",
            stacklevel=2,
        )
    missing = sorted(keep_always.difference(stripped))
    if missing:
        warnings.warn(
            "final keys not kept (absent from the checkpoint or equal to the "
            f"stop key): {missing}",
            stacklevel=2,
        )
    return stripped


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Command-line entry point; returns the process exit status."""
    parser = argparse.ArgumentParser(
        description="Strip a PyTorch checkpoint down to its leading keys "
        "plus a few final keys."
    )
    parser.add_argument(
        "--input",
        default=DEFAULT_INPUT,
        help="checkpoint to load (default: %(default)s)",
    )
    parser.add_argument(
        "--stop-key",
        default=DEFAULT_STOP_KEY,
        help="first key to drop; later keys are dropped too (default: %(default)s)",
    )
    parser.add_argument(
        "--final-keys",
        nargs="*",
        default=list(DEFAULT_FINAL_KEYS),
        help="keys always kept, even after the stop key (default: %(default)s)",
    )
    parser.add_argument(
        "--output-stem",
        default=DEFAULT_OUTPUT_STEM,
        help="writes STEM.safetensors and STEM.bin (default: %(default)s)",
    )
    parser.add_argument(
        "--list-only",
        action="store_true",
        help="print the checkpoint's keys and exit without saving",
    )
    args = parser.parse_args(argv)

    # Imported here so strip_state_dict is usable without safetensors
    # installed. The import still runs before the (possibly slow) load, so a
    # missing dependency fails fast.
    from safetensors.torch import save_file

    # NOTE: torch.load unpickles the file, which can execute arbitrary code.
    # Only load checkpoints you trust; recent PyTorch versions also accept
    # weights_only=True to restrict what can be unpickled.
    model = torch.load(args.input, map_location="cpu")

    for key in model:
        print(key)
    if args.list_only:
        return 0

    stripped = strip_state_dict(model, args.stop_key, args.final_keys)
    # NOTE(review): safetensors' save_file rejects tensors that share storage
    # or are non-contiguous. Tied or viewed weights may need .clone() or
    # .contiguous() first; confirm for the checkpoints you strip.
    save_file(stripped, f"{args.output_stem}.safetensors")
    torch.save(stripped, f"{args.output_stem}.bin")
    return 0


if __name__ == "__main__":
    sys.exit(main())