Narsil HF staff commited on
Commit
a3c5547
1 Parent(s): dd5ad09

Update convert.py

Browse files
Files changed (1) hide show
  1. convert.py +62 -117
convert.py CHANGED
@@ -11,9 +11,7 @@ import torch
11
 
12
  from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
13
  from huggingface_hub.file_download import repo_folder_name
14
- from safetensors.torch import load_file, save_file
15
- from transformers import AutoConfig
16
- from transformers.pipelines.base import infer_framework_load_model
17
 
18
 
19
  COMMIT_DESCRIPTION = """
@@ -40,30 +38,6 @@ class AlreadyExists(Exception):
40
  pass
41
 
42
 
43
- def shared_pointers(tensors):
44
- ptrs = defaultdict(list)
45
- for k, v in tensors.items():
46
- ptrs[v.data_ptr()].append(k)
47
- failing = []
48
- for ptr, names in ptrs.items():
49
- if len(names) > 1:
50
- failing.append(names)
51
- return failing
52
-
53
-
54
- def check_file_size(sf_filename: str, pt_filename: str):
55
- sf_size = os.stat(sf_filename).st_size
56
- pt_size = os.stat(pt_filename).st_size
57
-
58
- if (sf_size - pt_size) / pt_size > 0.01:
59
- raise RuntimeError(
60
- f"""The file size different is more than 1%:
61
- - {sf_filename}: {sf_size}
62
- - {pt_filename}: {pt_size}
63
- """
64
- )
65
-
66
-
67
  def rename(pt_filename: str) -> str:
68
  filename, ext = os.path.splitext(pt_filename)
69
  local = f"{filename}.safetensors"
@@ -77,29 +51,40 @@ def convert_multi(model_id: str, folder: str) -> ConversionResult:
77
  data = json.load(f)
78
 
79
  filenames = set(data["weight_map"].values())
80
- local_filenames = []
81
- for filename in filenames:
82
- pt_filename = hf_hub_download(repo_id=model_id, filename=filename)
83
 
84
- sf_filename = rename(pt_filename)
85
- sf_filename = os.path.join(folder, sf_filename)
86
- convert_file(pt_filename, sf_filename)
87
- local_filenames.append(sf_filename)
88
 
 
89
  index = os.path.join(folder, "model.safetensors.index.json")
90
  with open(index, "w") as f:
91
  newdata = {k: v for k, v in data.items()}
92
  newmap = {k: rename(v) for k, v in data["weight_map"].items()}
93
  newdata["weight_map"] = newmap
94
  json.dump(newdata, f, indent=4)
95
- local_filenames.append(index)
96
 
97
- operations = [
98
- CommitOperationAdd(path_in_repo=local.split("/")[-1], path_or_fileobj=local) for local in local_filenames
99
- ]
100
- errors: List[Tuple[str, "Exception"]] = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
- return operations, errors
103
 
104
 
105
  def convert_single(model_id: str, folder: str) -> ConversionResult:
@@ -108,9 +93,15 @@ def convert_single(model_id: str, folder: str) -> ConversionResult:
108
  sf_name = "model.safetensors"
109
  sf_filename = os.path.join(folder, sf_name)
110
  convert_file(pt_filename, sf_filename)
111
- operations = [CommitOperationAdd(path_in_repo=sf_name, path_or_fileobj=sf_filename)]
112
- errors: List[Tuple[str, "Exception"]] = []
113
- return operations, errors
 
 
 
 
 
 
114
 
115
 
116
  def convert_file(
@@ -120,18 +111,13 @@ def convert_file(
120
  loaded = torch.load(pt_filename, map_location="cpu")
121
  if "state_dict" in loaded:
122
  loaded = loaded["state_dict"]
123
- shared = shared_pointers(loaded)
124
- for shared_weights in shared:
125
- for name in shared_weights[1:]:
126
- loaded.pop(name)
127
-
128
  # For tensors to be contiguous
129
  loaded = {k: v.contiguous() for k, v in loaded.items()}
130
 
131
  dirname = os.path.dirname(sf_filename)
132
  os.makedirs(dirname, exist_ok=True)
133
  save_file(loaded, sf_filename, metadata={"format": "pt"})
134
- check_file_size(sf_filename, pt_filename)
135
  reloaded = load_file(sf_filename)
136
  for k in loaded:
137
  pt_tensor = loaded[k]
@@ -156,56 +142,6 @@ def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]])
156
  return "\n".join(errors)
157
 
158
 
159
- def check_final_model(model_id: str, folder: str):
160
- config = hf_hub_download(repo_id=model_id, filename="config.json")
161
- shutil.copy(config, os.path.join(folder, "config.json"))
162
- config = AutoConfig.from_pretrained(folder)
163
-
164
- _, (pt_model, pt_infos) = infer_framework_load_model(model_id, config, output_loading_info=True)
165
- _, (sf_model, sf_infos) = infer_framework_load_model(folder, config, output_loading_info=True)
166
-
167
- if pt_infos != sf_infos:
168
- error_string = create_diff(pt_infos, sf_infos)
169
- raise ValueError(f"Different infos when reloading the model: {error_string}")
170
-
171
- pt_params = pt_model.state_dict()
172
- sf_params = sf_model.state_dict()
173
-
174
- pt_shared = shared_pointers(pt_params)
175
- sf_shared = shared_pointers(sf_params)
176
- if pt_shared != sf_shared:
177
- raise RuntimeError("The reconstructed model is wrong, shared tensors are different {shared_pt} != {shared_tf}")
178
-
179
- sig = signature(pt_model.forward)
180
- input_ids = torch.arange(10).unsqueeze(0)
181
- pixel_values = torch.randn(1, 3, 224, 224)
182
- input_values = torch.arange(1000).float().unsqueeze(0)
183
- kwargs = {}
184
- if "input_ids" in sig.parameters:
185
- kwargs["input_ids"] = input_ids
186
- if "decoder_input_ids" in sig.parameters:
187
- kwargs["decoder_input_ids"] = input_ids
188
- if "pixel_values" in sig.parameters:
189
- kwargs["pixel_values"] = pixel_values
190
- if "input_values" in sig.parameters:
191
- kwargs["input_values"] = input_values
192
- if "bbox" in sig.parameters:
193
- kwargs["bbox"] = torch.zeros((1, 10, 4)).long()
194
- if "image" in sig.parameters:
195
- kwargs["image"] = pixel_values
196
-
197
- if torch.cuda.is_available():
198
- pt_model = pt_model.cuda()
199
- sf_model = sf_model.cuda()
200
- kwargs = {k: v.cuda() for k, v in kwargs.items()}
201
-
202
- pt_logits = pt_model(**kwargs)[0]
203
- sf_logits = sf_model(**kwargs)[0]
204
-
205
- torch.testing.assert_close(sf_logits, pt_logits)
206
- print(f"Model {model_id} is ok !")
207
-
208
-
209
  def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
210
  try:
211
  main_commit = api.list_repo_commits(model_id)[0].commit_id
@@ -226,6 +162,8 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> Conversi
226
  errors = []
227
 
228
  extensions = set([".bin", ".ckpt"])
 
 
229
  for filename in filenames:
230
  prefix, ext = os.path.splitext(filename)
231
  if ext in extensions:
@@ -240,10 +178,28 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> Conversi
240
  sf_filename = os.path.join(folder, sf_in_repo)
241
  try:
242
  convert_file(pt_filename, sf_filename)
243
- operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  except Exception as e:
245
  errors.append((pt_filename, e))
246
- return operations, errors
247
 
248
 
249
  def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitInfo", List["Exception"]]:
@@ -268,26 +224,15 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn
268
  raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
269
  elif library_name == "transformers":
270
  if "pytorch_model.bin" in filenames:
271
- operations, errors = convert_single(model_id, folder)
272
  elif "pytorch_model.bin.index.json" in filenames:
273
- operations, errors = convert_multi(model_id, folder)
274
  else:
275
  raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
276
- check_final_model(model_id, folder)
277
- else:
278
- operations, errors = convert_generic(model_id, folder, filenames)
279
-
280
- if operations:
281
- new_pr = api.create_commit(
282
- repo_id=model_id,
283
- operations=operations,
284
- commit_message=pr_title,
285
- commit_description=COMMIT_DESCRIPTION,
286
- create_pr=True,
287
- )
288
- print(f"Pr created at {new_pr.pr_url}")
289
  else:
290
- print("No files to convert")
 
 
291
  finally:
292
  shutil.rmtree(folder)
293
  return new_pr, errors
 
11
 
12
  from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
13
  from huggingface_hub.file_download import repo_folder_name
14
+ from safetensors.torch import load_file, save_file, _remove_duplicate_names
 
 
15
 
16
 
17
  COMMIT_DESCRIPTION = """
 
38
  pass
39
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def rename(pt_filename: str) -> str:
42
  filename, ext = os.path.splitext(pt_filename)
43
  local = f"{filename}.safetensors"
 
51
  data = json.load(f)
52
 
53
  filenames = set(data["weight_map"].values())
 
 
 
54
 
 
 
 
 
55
 
56
+
57
  index = os.path.join(folder, "model.safetensors.index.json")
58
  with open(index, "w") as f:
59
  newdata = {k: v for k, v in data.items()}
60
  newmap = {k: rename(v) for k, v in data["weight_map"].items()}
61
  newdata["weight_map"] = newmap
62
  json.dump(newdata, f, indent=4)
 
63
 
64
+
65
+ new_pr = api.create_commit(
66
+ repo_id=model_id,
67
+ operations=[CommitOperationAdd(path_in_repo=index.split("/")[-1], path_or_fileobj=index)],
68
+ commit_message=pr_title,
69
+ commit_description=COMMIT_DESCRIPTION,
70
+ create_pr=True,
71
+ )
72
+
73
+ for filename in filenames:
74
+ pt_filename = hf_hub_download(repo_id=model_id, filename=filename)
75
+ sf_filename = rename(pt_filename)
76
+ sf_filename = os.path.join(folder, sf_filename)
77
+ convert_file(pt_filename, sf_filename)
78
+ api.create_commit(
79
+ repo_id=repo_id,
80
+ commit_message=f"Adds {sf_filename}",
81
+ revision=new_pr.git_reference,
82
+ operations=[CommitOperationAdd(path_in_repo=sf_filename.split("/")[-1], path_or_fileobj=sf_filename)],
83
+ create_pr=False,
84
+ )
85
+ os.remove(pt_filename)
86
+ os.remove(sf_filename)
87
 
 
88
 
89
 
90
  def convert_single(model_id: str, folder: str) -> ConversionResult:
 
93
  sf_name = "model.safetensors"
94
  sf_filename = os.path.join(folder, sf_name)
95
  convert_file(pt_filename, sf_filename)
96
+
97
+ new_pr = api.create_commit(
98
+ repo_id=model_id,
99
+ operations=[CommitOperationAdd(path_in_repo=sf_name, path_or_fileobj=sf_filename)],
100
+ commit_message=pr_title,
101
+ commit_description=COMMIT_DESCRIPTION,
102
+ create_pr=True,
103
+ )
104
+ return new_pr
105
 
106
 
107
  def convert_file(
 
111
  loaded = torch.load(pt_filename, map_location="cpu")
112
  if "state_dict" in loaded:
113
  loaded = loaded["state_dict"]
114
+ loaded = _remove_duplicate_names(loaded)
 
 
 
 
115
  # For tensors to be contiguous
116
  loaded = {k: v.contiguous() for k, v in loaded.items()}
117
 
118
  dirname = os.path.dirname(sf_filename)
119
  os.makedirs(dirname, exist_ok=True)
120
  save_file(loaded, sf_filename, metadata={"format": "pt"})
 
121
  reloaded = load_file(sf_filename)
122
  for k in loaded:
123
  pt_tensor = loaded[k]
 
142
  return "\n".join(errors)
143
 
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
146
  try:
147
  main_commit = api.list_repo_commits(model_id)[0].commit_id
 
162
  errors = []
163
 
164
  extensions = set([".bin", ".ckpt"])
165
+
166
+ new_pr = None
167
  for filename in filenames:
168
  prefix, ext = os.path.splitext(filename)
169
  if ext in extensions:
 
178
  sf_filename = os.path.join(folder, sf_in_repo)
179
  try:
180
  convert_file(pt_filename, sf_filename)
181
+
182
+ if new_pr is None:
183
+ new_pr = api.create_commit(
184
+ repo_id=model_id,
185
+ operations=[CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename)],
186
+ commit_message=pr_title,
187
+ commit_description=COMMIT_DESCRIPTION,
188
+ create_pr=True,
189
+ )
190
+ else:
191
+ api.create_commit(
192
+ repo_id=repo_id,
193
+ commit_message=f"Adds {sf_filename}",
194
+ revision=new_pr.git_reference,
195
+ operations=[CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename)],
196
+ create_pr=False,
197
+ )
198
+ os.remove(pt_filename)
199
+ os.remove(sf_filename)
200
  except Exception as e:
201
  errors.append((pt_filename, e))
202
+ return new_pr
203
 
204
 
205
  def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitInfo", List["Exception"]]:
 
224
  raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
225
  elif library_name == "transformers":
226
  if "pytorch_model.bin" in filenames:
227
+ new_pr = convert_single(model_id, folder)
228
  elif "pytorch_model.bin.index.json" in filenames:
229
+ new_pr = convert_multi(model_id, folder)
230
  else:
231
  raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  else:
233
+ new_pr = convert_generic(model_id, folder, filenames)
234
+
235
+ print(f"Pr created at {new_pr.pr_url}")
236
  finally:
237
  shutil.rmtree(folder)
238
  return new_pr, errors