Update convert.py
Browse files- convert.py +3 -1
convert.py
CHANGED
@@ -183,6 +183,7 @@ def check_final_model(model_id: str, folder: str, token: Optional[str]):
|
|
183 |
pixel_values = torch.randn(1, 3, 224, 224)
|
184 |
input_values = torch.arange(1000).float().unsqueeze(0)
|
185 |
kwargs = {}
|
|
|
186 |
if "input_ids" in sig.parameters:
|
187 |
kwargs["input_ids"] = input_ids
|
188 |
if "decoder_input_ids" in sig.parameters:
|
@@ -213,7 +214,8 @@ def check_final_model(model_id: str, folder: str, token: Optional[str]):
|
|
213 |
kwargs["decoder_input_ids"] = decoder_input_ids
|
214 |
pt_logits = pt_model(**kwargs)[0]
|
215 |
except Exception:
|
216 |
-
|
|
|
217 |
sf_logits = sf_model(**kwargs)[0]
|
218 |
|
219 |
torch.testing.assert_close(sf_logits, pt_logits)
|
|
|
183 |
pixel_values = torch.randn(1, 3, 224, 224)
|
184 |
input_values = torch.arange(1000).float().unsqueeze(0)
|
185 |
kwargs = {}
|
186 |
+
import ipdb;ipdb.set_trace()
|
187 |
if "input_ids" in sig.parameters:
|
188 |
kwargs["input_ids"] = input_ids
|
189 |
if "decoder_input_ids" in sig.parameters:
|
|
|
214 |
kwargs["decoder_input_ids"] = decoder_input_ids
|
215 |
pt_logits = pt_model(**kwargs)[0]
|
216 |
except Exception:
|
217 |
+
print(f"Model {model_id} could not be checked, ignoring {e}")
|
218 |
+
return
|
219 |
sf_logits = sf_model(**kwargs)[0]
|
220 |
|
221 |
torch.testing.assert_close(sf_logits, pt_logits)
|