gmastrapas commited on
Commit
05e4a8c
1 Parent(s): 06150c7

fix: st backwards compatibility

Browse files
Files changed (1) hide show
  1. custom_st.py +8 -1
custom_st.py CHANGED
@@ -2,7 +2,7 @@ import base64
2
  import json
3
  import os
4
  from io import BytesIO
5
- from typing import Any, Dict, List, Optional, Tuple, Union
6
 
7
  import requests
8
  import torch
@@ -43,8 +43,15 @@ class Transformer(nn.Module):
43
  cache_dir: Optional[str] = None,
44
  do_lower_case: bool = False,
45
  tokenizer_name_or_path: str = None,
 
 
46
  ) -> None:
47
  super(Transformer, self).__init__()
 
 
 
 
 
48
  self.config_keys = ["max_seq_length", "do_lower_case"]
49
  self.do_lower_case = do_lower_case
50
  if model_args is None:
 
2
  import json
3
  import os
4
  from io import BytesIO
5
+ from typing import Any, Dict, List, Literal, Optional, Union
6
 
7
  import requests
8
  import torch
 
43
  cache_dir: Optional[str] = None,
44
  do_lower_case: bool = False,
45
  tokenizer_name_or_path: str = None,
46
+ backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
47
+ **_,
48
  ) -> None:
49
  super(Transformer, self).__init__()
50
+ if backend != 'torch':
51
+ raise ValueError(
52
+ f'Backend \'{backend}\' is not supported, please use \'torch\' instead'
53
+ )
54
+
55
  self.config_keys = ["max_seq_length", "do_lower_case"]
56
  self.do_lower_case = do_lower_case
57
  if model_args is None: