jeremiebasso commited on
Commit
2ffc97e
1 Parent(s): 8fe5582

fix: no need torch

Browse files
Files changed (1) hide show
  1. onnx_model.py +4 -2
onnx_model.py CHANGED
@@ -8,7 +8,6 @@ from typing import Any
8
  import numpy as np
9
  import onnxruntime as ort
10
  from loguru import logger
11
- from onnxruntime.transformers.io_binding_helper import TypeHelper
12
 
13
 
14
  @dataclass
@@ -36,7 +35,10 @@ class ONNXModel:
36
  else:
37
  self.device = "cpu"
38
 
39
- self.io_types = TypeHelper.get_io_numpy_type_map(model)
 
 
 
40
 
41
  self.input_names = [el.name for el in model.get_inputs()]
42
  self.output_name = model.get_outputs()[0].name
 
8
  import numpy as np
9
  import onnxruntime as ort
10
  from loguru import logger
 
11
 
12
 
13
  @dataclass
 
35
  else:
36
  self.device = "cpu"
37
 
38
+ self.io_types = {
39
+ "input_ids": np.int32,
40
+ "attention_mask": np.bool_
41
+ }
42
 
43
  self.input_names = [el.name for el in model.get_inputs()]
44
  self.output_name = model.get_outputs()[0].name