echarlaix HF staff commited on
Commit
12e01c3
1 Parent(s): 04e8a16

fix-onnx-export (#3)

Browse files

- Fix ONNX export (16686c5c06ad6a49f2aac8d22414b2554533c3b5)

Files changed (1) hide show
  1. onnx_export.py +6 -7
onnx_export.py CHANGED
@@ -41,7 +41,11 @@ def convert_onnx(model_id: str, task: str, folder: str, opset: int) -> List:
41
  model_name = getattr(model, "name", None)
42
 
43
  onnx_config_constructor = TasksManager.get_exporter_config_constructor(
44
- model_type, "onnx", task=task, model_name=model_name
 
 
 
 
45
  )
46
  onnx_config = onnx_config_constructor(model.config)
47
 
@@ -66,12 +70,7 @@ def convert_onnx(model_id: str, task: str, folder: str, opset: int) -> List:
66
  opset = onnx_config.DEFAULT_ONNX_OPSET
67
 
68
  output = Path(folder).joinpath("model.onnx")
69
- onnx_inputs, onnx_outputs = export(
70
- model,
71
- onnx_config,
72
- opset,
73
- output,
74
- )
75
 
76
  atol = onnx_config.ATOL_FOR_VALIDATION
77
  if isinstance(atol, dict):
 
41
  model_name = getattr(model, "name", None)
42
 
43
  onnx_config_constructor = TasksManager.get_exporter_config_constructor(
44
+ exporter="onnx",
45
+ model=model,
46
+ task=task,
47
+ model_name=model_name,
48
+ model_type=model_type,
49
  )
50
  onnx_config = onnx_config_constructor(model.config)
51
 
 
70
  opset = onnx_config.DEFAULT_ONNX_OPSET
71
 
72
  output = Path(folder).joinpath("model.onnx")
73
+ onnx_inputs, onnx_outputs = export(model, onnx_config, output, opset)
 
 
 
 
 
74
 
75
  atol = onnx_config.ATOL_FOR_VALIDATION
76
  if isinstance(atol, dict):