Narsil HF staff commited on
Commit
16b2783
1 Parent(s): be4725f

Update convert.py

Browse files

Keeping one part of the final test still retaining the no RAM usage feature.

Files changed (1) hide show
  1. convert.py +12 -8
convert.py CHANGED
@@ -163,13 +163,17 @@ def check_final_model(model_id: str, folder: str, token: Optional[str]):
163
  import transformers
164
 
165
  class_ = getattr(transformers, config.architectures[0])
166
- (pt_model, pt_infos) = class_.from_pretrained(folder, output_loading_info=True)
167
- (sf_model, sf_infos) = class_.from_pretrained(folder, output_loading_info=True)
168
-
169
- if pt_infos != sf_infos:
170
- error_string = create_diff(pt_infos, sf_infos)
171
- raise ValueError(f"Different infos when reloading the model: {error_string}")
172
-
 
 
 
 
173
  pt_params = pt_model.state_dict()
174
  sf_params = sf_model.state_dict()
175
 
@@ -291,7 +295,7 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn
291
  operations, errors = convert_multi(model_id, folder, token=api.token)
292
  else:
293
  raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
294
- # check_final_model(model_id, folder, token=api.token)
295
  else:
296
  operations, errors = convert_generic(model_id, folder, filenames, token=api.token)
297
 
 
163
  import transformers
164
 
165
  class_ = getattr(transformers, config.architectures[0])
166
+ with torch.device("meta"):
167
+ (pt_model, pt_infos) = class_.from_pretrained(folder, output_loading_info=True)
168
+ (sf_model, sf_infos) = class_.from_pretrained(folder, output_loading_info=True)
169
+
170
+ if pt_infos != sf_infos:
171
+ error_string = create_diff(pt_infos, sf_infos)
172
+ raise ValueError(f"Different infos when reloading the model: {error_string}")
173
+
174
+ #### XXXXXXXXXXXXXXXXXXXXXXXXXXXXX
175
+ #### SKIPPING THE REST OF THE test to save RAM
176
+ return
177
  pt_params = pt_model.state_dict()
178
  sf_params = sf_model.state_dict()
179
 
 
295
  operations, errors = convert_multi(model_id, folder, token=api.token)
296
  else:
297
  raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
298
+ check_final_model(model_id, folder, token=api.token)
299
  else:
300
  operations, errors = convert_generic(model_id, folder, filenames, token=api.token)
301