patrickvonplaten commited on
Commit
c6c5536
1 Parent(s): 38707b6

Update convert.py

Browse files
Files changed (1) hide show
  1. convert.py +1 -53
convert.py CHANGED
@@ -133,57 +133,6 @@ def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]])
133
  errors.append(f"{key} : SF warnings contain {sf_only} which are not present in PT warnings")
134
  return "\n".join(errors)
135
 
136
-
137
- def check_final_model(model_id: str, folder: str):
138
- config = hf_hub_download(repo_id=model_id, filename="config.json")
139
- shutil.copy(config, os.path.join(folder, "config.json"))
140
- config = AutoConfig.from_pretrained(folder)
141
-
142
- _, (pt_model, pt_infos) = infer_framework_load_model(model_id, config, output_loading_info=True)
143
- _, (sf_model, sf_infos) = infer_framework_load_model(folder, config, output_loading_info=True)
144
-
145
- if pt_infos != sf_infos:
146
- error_string = create_diff(pt_infos, sf_infos)
147
- raise ValueError(f"Different infos when reloading the model: {error_string}")
148
-
149
- pt_params = pt_model.state_dict()
150
- sf_params = sf_model.state_dict()
151
-
152
- pt_shared = shared_pointers(pt_params)
153
- sf_shared = shared_pointers(sf_params)
154
- if pt_shared != sf_shared:
155
- raise RuntimeError("The reconstructed model is wrong, shared tensors are different {shared_pt} != {shared_tf}")
156
-
157
- sig = signature(pt_model.forward)
158
- input_ids = torch.arange(10).unsqueeze(0)
159
- pixel_values = torch.randn(1, 3, 224, 224)
160
- input_values = torch.arange(1000).float().unsqueeze(0)
161
- kwargs = {}
162
- if "input_ids" in sig.parameters:
163
- kwargs["input_ids"] = input_ids
164
- if "decoder_input_ids" in sig.parameters:
165
- kwargs["decoder_input_ids"] = input_ids
166
- if "pixel_values" in sig.parameters:
167
- kwargs["pixel_values"] = pixel_values
168
- if "input_values" in sig.parameters:
169
- kwargs["input_values"] = input_values
170
- if "bbox" in sig.parameters:
171
- kwargs["bbox"] = torch.zeros((1, 10, 4)).long()
172
- if "image" in sig.parameters:
173
- kwargs["image"] = pixel_values
174
-
175
- if torch.cuda.is_available():
176
- pt_model = pt_model.cuda()
177
- sf_model = sf_model.cuda()
178
- kwargs = {k: v.cuda() for k, v in kwargs.items()}
179
-
180
- pt_logits = pt_model(**kwargs)[0]
181
- sf_logits = sf_model(**kwargs)[0]
182
-
183
- torch.testing.assert_close(sf_logits, pt_logits)
184
- print(f"Model {model_id} is ok !")
185
-
186
-
187
  def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
188
  try:
189
  discussions = api.get_repo_discussions(repo_id=model_id)
@@ -218,7 +167,7 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> List["Co
218
  def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
219
  pr_title = "Adding `safetensors` variant of this model"
220
  info = api.model_info(model_id)
221
- filenames = set(s.rfilename for s in info.siblings)
222
 
223
  with TemporaryDirectory() as d:
224
  folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
@@ -242,7 +191,6 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["Commi
242
  operations = convert_multi(model_id, folder)
243
  else:
244
  raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
245
- check_final_model(model_id, folder)
246
  else:
247
  operations = convert_generic(model_id, folder, filenames)
248
 
 
133
  errors.append(f"{key} : SF warnings contain {sf_only} which are not present in PT warnings")
134
  return "\n".join(errors)
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
137
  try:
138
  discussions = api.get_repo_discussions(repo_id=model_id)
 
167
  def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
168
  pr_title = "Adding `safetensors` variant of this model"
169
  info = api.model_info(model_id)
170
+ filenames = set(s.rfilename for s in info.siblings if len(s.rfilename.split("/")) > 1)
171
 
172
  with TemporaryDirectory() as d:
173
  folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
 
191
  operations = convert_multi(model_id, folder)
192
  else:
193
  raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
 
194
  else:
195
  operations = convert_generic(model_id, folder, filenames)
196