echarlaix HF staff commited on
Commit
3a01f1b
1 Parent(s): 242da95

fix validation step

Browse files
Files changed (1) hide show
  1. export.py +2 -6
export.py CHANGED
@@ -71,15 +71,11 @@ def convert_openvino(model_id: str, task: str, folder: str) -> List:
71
  )
72
  openvino_config = exporter_config_class(model.config)
73
  inputs = openvino_config.generate_dummy_inputs(framework="pt")
74
-
75
  ov_outputs = ov_model(**inputs)
76
  outputs = model(**inputs)
77
 
78
- if isinstance(outputs, torch.Tensor):
79
- outputs = {"logits": outputs}
80
- ov_outputs = {"logits": ov_outputs}
81
- for output_name in outputs:
82
- if not torch.allclose(outputs[output_name], ov_outputs[output_name], atol=1e-3):
83
  raise ValueError(
84
  "The exported model does not have the same outputs as the original model. Export interrupted."
85
  )
 
71
  )
72
  openvino_config = exporter_config_class(model.config)
73
  inputs = openvino_config.generate_dummy_inputs(framework="pt")
 
74
  ov_outputs = ov_model(**inputs)
75
  outputs = model(**inputs)
76
 
77
+ for output_name in ov_outputs:
78
+ if isinstance(outputs, torch.Tensor) and not torch.allclose(outputs[output_name], ov_outputs[output_name], atol=1e-3):
 
 
 
79
  raise ValueError(
80
  "The exported model does not have the same outputs as the original model. Export interrupted."
81
  )