Narsil HF staff commited on
Commit
b2f4808
1 Parent(s): ebcde60

Update convert.py

Browse files
Files changed (1) hide show
  1. convert.py +4 -2
convert.py CHANGED
@@ -95,7 +95,9 @@ def convert_file(
95
  pt_filename: str,
96
  sf_filename: str,
97
  ):
98
- loaded = torch.load(pt_filename, map_location="cpu")
 
 
99
  shared = shared_pointers(loaded)
100
  for shared_weights in shared:
101
  for name in shared_weights[1:]:
@@ -238,7 +240,7 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["Commi
238
  operations = convert_multi(model_id, folder)
239
  else:
240
  raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
241
- # check_final_model(model_id, folder)
242
  else:
243
  operations = convert_generic(model_id, folder, filenames)
244
 
 
95
  pt_filename: str,
96
  sf_filename: str,
97
  ):
98
+ loaded = torch.load(pt_filename)
99
+ if "state_dict" in loaded:
100
+ loaded = loaded["state_dict"]
101
  shared = shared_pointers(loaded)
102
  for shared_weights in shared:
103
  for name in shared_weights[1:]:
 
240
  operations = convert_multi(model_id, folder)
241
  else:
242
  raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
243
+ check_final_model(model_id, folder)
244
  else:
245
  operations = convert_generic(model_id, folder, filenames)
246