File size: 4,828 Bytes
66662af dede4f0 66662af dede4f0 ff90924 dede4f0 66662af dede4f0 ff90924 dede4f0 ff90924 dede4f0 66662af ff90924 dede4f0 66662af dede4f0 ff90924 dede4f0 66662af dede4f0 66662af dede4f0 66662af dede4f0 66662af ff90924 66662af dede4f0 ff90924 dede4f0 66662af dede4f0 ff90924 dede4f0 d800fd4 dede4f0 ff90924 dede4f0 ff90924 66662af dede4f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import argparse
import json
import os
import shutil
import torch
from huggingface_hub import CommitOperationAdd, HfApi, hf_hub_download
from huggingface_hub.file_download import repo_folder_name
from safetensors.torch import save_file
from transformers import AutoConfig
from transformers.pipelines.base import infer_framework_load_model
def check_file_size(sf_filename, pt_filename, tolerance=0.01):
    """Raise if the converted safetensors file grew too much relative to its source.

    Only growth is checked: a safetensors file that is *smaller* than the
    original PyTorch checkpoint is accepted. (NOTE(review): growth presumably
    signals tensors duplicated during conversion — confirm intent.)

    Args:
        sf_filename: Path of the freshly written ``.safetensors`` file.
        pt_filename: Path of the original PyTorch checkpoint it came from.
        tolerance: Maximum allowed relative growth; ``0.01`` means 1%.

    Raises:
        RuntimeError: if ``sf_filename`` is larger than ``pt_filename`` by more
            than ``tolerance`` of ``pt_filename``'s size.
    """
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size

    # Multiply instead of dividing so an empty source file cannot raise
    # ZeroDivisionError; for pt_size > 0 this is the same test as
    # (sf_size - pt_size) / pt_size > tolerance.
    if sf_size - pt_size > tolerance * pt_size:
        raise RuntimeError(
            f"The file size difference is more than {tolerance * 100:g}%:\n"
            f" - {sf_filename}: {sf_size}\n"
            f" - {pt_filename}: {pt_size}\n"
        )
def rename(pt_filename) -> str:
    """Map a PyTorch weight filename to its safetensors counterpart.

    ``pytorch_model.bin`` becomes ``model.safetensors`` and
    ``pytorch_model-00001-of-00002.bin`` becomes
    ``model-00001-of-00002.safetensors``. Only a trailing ``.bin`` extension
    is rewritten; a ``.bin`` appearing elsewhere in the name is left alone.
    Every ``pytorch_model`` substring is replaced with ``model``.
    """
    local = pt_filename
    if local.endswith(".bin"):
        local = local[: -len(".bin")] + ".safetensors"
    return local.replace("pytorch_model", "model")
def convert_multi(model_id, folder):
    """Convert a sharded PyTorch checkpoint of ``model_id`` into safetensors shards.

    Downloads ``pytorch_model.bin.index.json`` and every shard listed in its
    ``weight_map``. For each shard it writes a ``.safetensors`` file into
    ``folder`` (name mapped by ``rename``) and size-checks it against the
    source shard. It then writes ``model.safetensors.index.json`` with the
    renamed weight map.

    Returns:
        A list of ``CommitOperationAdd``, one per written file, each targeting
        the file's basename at the repository root.
    """
    index_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json")
    with open(index_filename, "r") as f:
        data = json.load(f)

    local_filenames = []
    # Sorted so shards are processed and committed in a deterministic order.
    for shard in sorted(set(data["weight_map"].values())):
        cached_filename = hf_hub_download(repo_id=model_id, filename=shard)
        # NOTE(review): torch.load unpickles repository content; only run this
        # on checkpoints you trust. map_location="cpu" lets weights saved on a
        # GPU load on CPU-only machines.
        loaded = torch.load(cached_filename, map_location="cpu")
        # save_file rejects non-contiguous tensors, so materialize contiguous ones.
        loaded = {k: v.contiguous() for k, v in loaded.items()}
        local = os.path.join(folder, rename(shard))
        save_file(loaded, local, metadata={"format": "pt"})
        check_file_size(local, cached_filename)
        local_filenames.append(local)

    index = os.path.join(folder, "model.safetensors.index.json")
    newdata = dict(data)  # shallow copy; only "weight_map" is replaced
    newdata["weight_map"] = {k: rename(v) for k, v in data["weight_map"].items()}
    with open(index, "w") as f:
        json.dump(newdata, f)
    local_filenames.append(index)

    return [
        CommitOperationAdd(path_in_repo=os.path.basename(local), path_or_fileobj=local)
        for local in local_filenames
    ]
def convert_single(model_id, folder):
    """Convert the single-file ``pytorch_model.bin`` of ``model_id`` to safetensors.

    Writes ``model.safetensors`` into ``folder`` and size-checks it against
    the downloaded PyTorch file.

    Returns:
        A one-element list with the ``CommitOperationAdd`` that uploads
        ``model.safetensors`` at the repository root.
    """
    sf_filename = "model.safetensors"
    filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
    # NOTE(review): torch.load unpickles repository content; only run this on
    # checkpoints you trust. map_location="cpu" lets weights saved on a GPU
    # load on CPU-only machines.
    loaded = torch.load(filename, map_location="cpu")
    # save_file rejects non-contiguous tensors, so materialize contiguous ones.
    loaded = {k: v.contiguous() for k, v in loaded.items()}
    local = os.path.join(folder, sf_filename)
    save_file(loaded, local, metadata={"format": "pt"})
    check_file_size(local, filename)
    return [CommitOperationAdd(path_in_repo=sf_filename, path_or_fileobj=local)]
def check_final_model(model_id, folder):
    """Check that the converted weights in ``folder`` reproduce the original model.

    Copies the repository's ``config.json`` next to the converted weights. It
    then loads one model from ``folder`` and one from ``model_id``, runs both
    on the same dummy ``input_ids`` and compares their outputs.

    Raises:
        AssertionError: from ``torch.testing.assert_close`` when the outputs
            of the two models differ.
    """
    config = hf_hub_download(repo_id=model_id, filename="config.json")
    shutil.copy(config, os.path.join(folder, "config.json"))
    config = AutoConfig.from_pretrained(folder)

    _, sf_model = infer_framework_load_model(folder, config)
    _, pt_model = infer_framework_load_model(model_id, config)

    input_ids = torch.arange(10).long().unsqueeze(0)
    # Pure inference comparison: disable autograd so neither forward pass
    # builds a graph or keeps activations alive.
    with torch.no_grad():
        sf_logits = sf_model(input_ids)
        pt_logits = pt_model(input_ids)

    torch.testing.assert_close(sf_logits, pt_logits)
    print(f"Model {model_id} is ok !")
def convert(api, model_id):
    """Convert ``model_id``'s PyTorch weights to safetensors and open a hub PR.

    Picks single-file or sharded conversion based on which PyTorch weight
    files the repository lists. It then verifies the converted model against
    the original and opens a pull request adding the new files. The local
    working folder is removed whether or not conversion succeeds.

    Returns:
        The value returned by ``api.create_commit`` for the new PR.

    Raises:
        RuntimeError: if the repository already has safetensors weights, or
            has no recognizable PyTorch weights.
    """
    info = api.model_info(model_id)
    filenames = {s.rfilename for s in info.siblings}

    folder = repo_folder_name(repo_id=model_id, repo_type="models")
    os.makedirs(folder)
    new_pr = None
    try:
        operations = None
        # Sharded safetensors checkpoints are indexed by
        # "model.safetensors.index.json" (the same name convert_multi writes),
        # so checking for it detects repos that were already converted.
        if "model.safetensors" in filenames or "model.safetensors.index.json" in filenames:
            raise RuntimeError(f"Model {model_id} is already converted, skipping..")
        elif "pytorch_model.bin" in filenames:
            operations = convert_single(model_id, folder)
        elif "pytorch_model.bin.index.json" in filenames:
            operations = convert_multi(model_id, folder)
        else:
            raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")

        if operations:
            check_final_model(model_id, folder)
            new_pr = api.create_commit(
                repo_id=model_id,
                operations=operations,
                commit_message="Adding `safetensors` variant of this model",
                create_pr=True,
            )
    finally:
        shutil.rmtree(folder)
    return new_pr
if __name__ == "__main__":
    # Help text shown by argparse; built from adjacent literals with exactly
    # the same content as the original triple-quoted string.
    DESCRIPTION = (
        "\n"
        "Simple utility tool to convert automatically some weights on the hub to `safetensors` format.\n"
        "It is PyTorch exclusive for now.\n"
        "It works by downloading the weights (PT), converting them locally, and uploading them back\n"
        "as a PR on the hub.\n"
    )

    # CLI: a single positional argument naming the hub repository to convert.
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    model_id_help = "The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`"
    parser.add_argument("model_id", type=str, help=model_id_help)
    args = parser.parse_args()

    api = HfApi()
    model_id = args.model_id
    convert(api, model_id)
|