Narsil HF staff commited on
Commit
65428ae
1 Parent(s): 0349d80

Update convert.py

Browse files
Files changed (1) hide show
  1. convert.py +40 -5
convert.py CHANGED
@@ -161,8 +161,11 @@ def check_final_model(model_id: str, folder: str):
161
  shutil.copy(config, os.path.join(folder, "config.json"))
162
  config = AutoConfig.from_pretrained(folder)
163
 
164
- _, (pt_model, pt_infos) = infer_framework_load_model(model_id, config, output_loading_info=True)
165
- _, (sf_model, sf_infos) = infer_framework_load_model(folder, config, output_loading_info=True)
 
 
 
166
 
167
  if pt_infos != sf_infos:
168
  error_string = create_diff(pt_infos, sf_infos)
@@ -199,7 +202,19 @@ def check_final_model(model_id: str, folder: str):
199
  sf_model = sf_model.cuda()
200
  kwargs = {k: v.cuda() for k, v in kwargs.items()}
201
 
202
- pt_logits = pt_model(**kwargs)[0]
 
 
 
 
 
 
 
 
 
 
 
 
203
  sf_logits = sf_model(**kwargs)[0]
204
 
205
  torch.testing.assert_close(sf_logits, pt_logits)
@@ -246,7 +261,7 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> Conversi
246
  return operations, errors
247
 
248
 
249
- def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitInfo", List["Exception"]]:
250
  pr_title = "Adding `safetensors` variant of this model"
251
  info = api.model_info(model_id)
252
  filenames = set(s.rfilename for s in info.siblings)
@@ -328,6 +343,26 @@ if __name__ == "__main__":
328
  " Continue [Y/n] ?"
329
  )
330
  if txt.lower() in {"", "y"}:
331
- _commit_info, _errors = convert(api, model_id, force=args.force)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  else:
333
  print(f"Answer was `{txt}` aborting.")
 
161
  shutil.copy(config, os.path.join(folder, "config.json"))
162
  config = AutoConfig.from_pretrained(folder)
163
 
164
+ import transformers
165
+
166
+ class_ = getattr(transformers, config.architectures[0])
167
+ (pt_model, pt_infos) = class_.from_pretrained(folder, output_loading_info=True)
168
+ (sf_model, sf_infos) = class_.from_pretrained(folder, output_loading_info=True)
169
 
170
  if pt_infos != sf_infos:
171
  error_string = create_diff(pt_infos, sf_infos)
 
202
  sf_model = sf_model.cuda()
203
  kwargs = {k: v.cuda() for k, v in kwargs.items()}
204
 
205
+ try:
206
+ pt_logits = pt_model(**kwargs)[0]
207
+ except Exception as e:
208
+ try:
209
+ # Musicgen special exception.
210
+ decoder_input_ids = torch.ones((input_ids.shape[0] * pt_model.decoder.num_codebooks, 1), dtype=torch.long)
211
+ if torch.cuda.is_available():
212
+ decoder_input_ids = decoder_input_ids.cuda()
213
+
214
+ kwargs["decoder_input_ids"] = decoder_input_ids
215
+ pt_logits = pt_model(**kwargs)[0]
216
+ except Exception:
217
+ raise e
218
  sf_logits = sf_model(**kwargs)[0]
219
 
220
  torch.testing.assert_close(sf_logits, pt_logits)
 
261
  return operations, errors
262
 
263
 
264
+ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
265
  pr_title = "Adding `safetensors` variant of this model"
266
  info = api.model_info(model_id)
267
  filenames = set(s.rfilename for s in info.siblings)
 
343
  " Continue [Y/n] ?"
344
  )
345
  if txt.lower() in {"", "y"}:
346
+ try:
347
+ commit_info, errors = convert(api, model_id, force=args.force)
348
+ string = f"""
349
+ ### Success 🔥
350
+ Yay! This model was successfully converted and a PR was open using your token, here:
351
+ [{commit_info.pr_url}]({commit_info.pr_url})
352
+ """
353
+ if errors:
354
+ string += "\nErrors during conversion:\n"
355
+ string += "\n".join(
356
+ f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors
357
+ )
358
+ print(string)
359
+ except Exception as e:
360
+ print(
361
+ f"""
362
+ ### Error 😢😢😢
363
+
364
+ {e}
365
+ """
366
+ )
367
  else:
368
  print(f"Answer was `{txt}` aborting.")