Narsil (HF staff) committed
Commit a5c2203
1 parent: 650d083

Support `diffusers` and `stable-diffusion` (pretty much everything)

Files changed (1)
  1. convert.py +90 -34
convert.py CHANGED
@@ -2,18 +2,18 @@ import argparse
 import json
 import os
 import shutil
-from tempfile import TemporaryDirectory
 from collections import defaultdict
 from inspect import signature
-from typing import Optional, List
+from tempfile import TemporaryDirectory
+from typing import Dict, List, Optional, Set
 
 import torch
 
-from huggingface_hub import CommitOperationAdd, HfApi, hf_hub_download, get_repo_discussions
+from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
 from huggingface_hub.file_download import repo_folder_name
+from safetensors.torch import load_file, save_file
 from transformers import AutoConfig
 from transformers.pipelines.base import infer_framework_load_model
-from safetensors.torch import save_file
 
 
 class AlreadyExists(Exception):
@@ -30,15 +30,18 @@ def shared_pointers(tensors):
             failing.append(names)
     return failing
 
+
 def check_file_size(sf_filename: str, pt_filename: str):
     sf_size = os.stat(sf_filename).st_size
     pt_size = os.stat(pt_filename).st_size
 
     if (sf_size - pt_size) / pt_size > 0.01:
-        raise RuntimeError(f"""The file size different is more than 1%:
+        raise RuntimeError(
+            f"""The file size different is more than 1%:
          - {sf_filename}: {sf_size}
          - {pt_filename}: {pt_size}
-         """)
+         """
+        )
 
 
 def rename(pt_filename: str) -> str:
@@ -53,15 +56,14 @@ def convert_multi(model_id: str, folder: str) -> List["CommitOperationAdd"]:
         data = json.load(f)
 
     filenames = set(data["weight_map"].values())
+    local_filenames = []
     for filename in filenames:
-        cached_filename = hf_hub_download(repo_id=model_id, filename=filename)
-        loaded = torch.load(cached_filename)
-        sf_filename = rename(filename)
+        pt_filename = hf_hub_download(repo_id=model_id, filename=filename)
 
-        local = os.path.join(folder, sf_filename)
-        save_file(loaded, local, metadata={"format": "pt"})
-        check_file_size(local, cached_filename)
-        local_filenames.append(local)
+        sf_filename = rename(pt_filename)
+        sf_filename = os.path.join(folder, sf_filename)
+        convert_file(pt_filename, sf_filename)
+        local_filenames.append(sf_filename)
 
     index = os.path.join(folder, "model.safetensors.index.json")
     with open(index, "w") as f:
@@ -71,17 +73,28 @@ def convert_multi(model_id: str, folder: str) -> List["CommitOperationAdd"]:
         json.dump(newdata, f)
     local_filenames.append(index)
 
-    operations = [CommitOperationAdd(path_in_repo=local.split("/")[-1], path_or_fileobj=local) for local in local_filenames]
+    operations = [
+        CommitOperationAdd(path_in_repo=local.split("/")[-1], path_or_fileobj=local) for local in local_filenames
+    ]
 
     return operations
 
 
 def convert_single(model_id: str, folder: str) -> List["CommitOperationAdd"]:
-    sf_filename = "model.safetensors"
-    filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
-    loaded = torch.load(filename)
+    pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
+
+    sf_name = "model.safetensors"
+    sf_filename = os.path.join(folder, sf_name)
+    convert_file(pt_filename, sf_filename)
+    operations = [CommitOperationAdd(path_in_repo=sf_name, path_or_fileobj=sf_filename)]
+    return operations
 
-    local = os.path.join(folder, sf_filename)
+
+def convert_file(
+    pt_filename: str,
+    sf_filename: str,
+):
+    loaded = torch.load(pt_filename)
     shared = shared_pointers(loaded)
     for shared_weights in shared:
         for name in shared_weights[1:]:
@@ -90,23 +103,45 @@ def convert_single(model_id: str, folder: str) -> List["CommitOperationAdd"]:
     # For tensors to be contiguous
     loaded = {k: v.contiguous() for k, v in loaded.items()}
 
-    save_file(loaded, local, metadata={"format": "pt"})
+    dirname = sf_filename.rsplit(os.path.sep, 1)[0]
+    os.makedirs(dirname, exist_ok=True)
+    save_file(loaded, sf_filename, metadata={"format": "pt"})
+    check_file_size(sf_filename, pt_filename)
+    reloaded = load_file(sf_filename)
+    for k in loaded:
+        pt_tensor = loaded[k]
+        sf_tensor = reloaded[k]
+        if not torch.equal(pt_tensor, sf_tensor):
+            raise RuntimeError(f"The output tensors do not match for key {k}")
 
-    check_file_size(local, filename)
 
-    operations = [CommitOperationAdd(path_in_repo=sf_filename, path_or_fileobj=local)]
-    return operations
+def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]]) -> str:
+    errors = []
+    for key in ["missing_keys", "mismatched_keys", "unexpected_keys"]:
+        pt_set = set(pt_infos[key])
+        sf_set = set(sf_infos[key])
+
+        pt_only = pt_set - sf_set
+        sf_only = sf_set - pt_set
+
+        if pt_only:
+            errors.append(f"{key} : PT warnings contain {pt_only} which are not present in SF warnings")
+        if sf_only:
+            errors.append(f"{key} : SF warnings contain {sf_only} which are not present in PT warnings")
+    return "\n".join(errors)
+
 
 def check_final_model(model_id: str, folder: str):
     config = hf_hub_download(repo_id=model_id, filename="config.json")
     shutil.copy(config, os.path.join(folder, "config.json"))
     config = AutoConfig.from_pretrained(folder)
 
-    _, pt_model = infer_framework_load_model(model_id, config)
-    _, sf_model = infer_framework_load_model(folder, config)
+    _, (pt_model, pt_infos) = infer_framework_load_model(model_id, config, output_loading_info=True)
+    _, (sf_model, sf_infos) = infer_framework_load_model(folder, config, output_loading_info=True)
 
-    pt_model = pt_model
-    sf_model = sf_model
+    if pt_infos != sf_infos:
+        error_string = create_diff(pt_infos, sf_infos)
+        raise ValueError(f"Different infos when reloading the model: {error_string}")
 
     pt_params = pt_model.state_dict()
     sf_params = sf_model.state_dict()
@@ -134,7 +169,6 @@ def check_final_model(model_id: str, folder: str):
     if "image" in sig.parameters:
         kwargs["image"] = pixel_values
 
-
     if torch.cuda.is_available():
        pt_model = pt_model.cuda()
        sf_model = sf_model.cuda()
@@ -146,6 +180,7 @@ def check_final_model(model_id: str, folder: str):
     torch.testing.assert_close(sf_logits, pt_logits)
     print(f"Model {model_id} is ok !")
 
+
 def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
     try:
         discussions = api.get_repo_discussions(repo_id=model_id)
@@ -156,7 +191,22 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discuss
             return discussion
 
 
-def convert(api: "HfApi", model_id: str, force: bool=False) -> Optional["CommitInfo"]:
+def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> List["CommitOperationAdd"]:
+    operations = []
+
+    extensions = set([".bin", ".ckpt"])
+    for filename in filenames:
+        prefix, ext = os.path.splitext(filename)
+        if ext in extensions:
+            pt_filename = hf_hub_download(model_id, filename=filename)
+            sf_in_repo = f"{filename}.safetensors"
+            sf_filename = os.path.join(folder, sf_in_repo)
+            convert_file(pt_filename, sf_filename)
+            operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
+    return operations
+
+
+def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
     pr_title = "Adding `safetensors` variant of this model"
     info = api.model_info(model_id)
     filenames = set(s.rfilename for s in info.siblings)
@@ -174,21 +224,27 @@ def convert(api: "HfApi", model_id: str, force: bool=False) -> Optional["CommitI
                 url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
                 new_pr = pr
                 raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
-            elif "pytorch_model.bin" in filenames:
-                operations = convert_single(model_id, folder)
-            elif "pytorch_model.bin.index.json" in filenames:
-                operations = convert_multi(model_id, folder)
+            elif info.library_name == "transformers":
+                if "pytorch_model.bin" in filenames:
+                    operations = convert_single(model_id, folder)
+                elif "pytorch_model.bin.index.json" in filenames:
+                    operations = convert_multi(model_id, folder)
+                else:
+                    raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
+                check_final_model(model_id, folder)
             else:
-                raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
+                operations = convert_generic(model_id, folder, filenames)
 
             if operations:
-                check_final_model(model_id, folder)
                 new_pr = api.create_commit(
                    repo_id=model_id,
                    operations=operations,
                    commit_message=pr_title,
                    create_pr=True,
                )
+                print(f"Pr created at {new_pr.pr_url}")
+            else:
+                print("No files to convert")
        finally:
            shutil.rmtree(folder)
        return new_pr
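With this change, repos whose `library_name` is not `transformers` (for example `diffusers` / `stable-diffusion` checkpoints) are routed through the new `convert_generic`, which converts every `.bin` / `.ckpt` file it finds, while `transformers` repos still go through `convert_single` / `convert_multi` plus `check_final_model`. A minimal driver sketch for the updated `convert` entry point is shown below; the repo id and the import path are placeholders, and the real script presumably wires `convert` up through `argparse` (imported at the top of the file).

# Hypothetical driver for the updated convert() entry point (not part of the commit).
from huggingface_hub import HfApi

from convert import convert  # assumes convert.py is importable from the working directory

if __name__ == "__main__":
    # An authenticated token is needed for api.create_commit(..., create_pr=True) to open a PR.
    api = HfApi()
    # Placeholder repo id: a stable-diffusion style repo now takes the convert_generic path.
    convert(api, "some-org/some-stable-diffusion-model", force=False)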