gmastrapas
commited on
Commit
•
05e4a8c
1
Parent(s):
06150c7
fix: st backwards compatibility
Browse files- custom_st.py +8 -1
custom_st.py
CHANGED
@@ -2,7 +2,7 @@ import base64
|
|
2 |
import json
|
3 |
import os
|
4 |
from io import BytesIO
|
5 |
-
from typing import Any, Dict, List,
|
6 |
|
7 |
import requests
|
8 |
import torch
|
@@ -43,8 +43,15 @@ class Transformer(nn.Module):
|
|
43 |
cache_dir: Optional[str] = None,
|
44 |
do_lower_case: bool = False,
|
45 |
tokenizer_name_or_path: str = None,
|
|
|
|
|
46 |
) -> None:
|
47 |
super(Transformer, self).__init__()
|
|
|
|
|
|
|
|
|
|
|
48 |
self.config_keys = ["max_seq_length", "do_lower_case"]
|
49 |
self.do_lower_case = do_lower_case
|
50 |
if model_args is None:
|
|
|
2 |
import json
|
3 |
import os
|
4 |
from io import BytesIO
|
5 |
+
from typing import Any, Dict, List, Literal, Optional, Union
|
6 |
|
7 |
import requests
|
8 |
import torch
|
|
|
43 |
cache_dir: Optional[str] = None,
|
44 |
do_lower_case: bool = False,
|
45 |
tokenizer_name_or_path: str = None,
|
46 |
+
backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
|
47 |
+
**_,
|
48 |
) -> None:
|
49 |
super(Transformer, self).__init__()
|
50 |
+
if backend != 'torch':
|
51 |
+
raise ValueError(
|
52 |
+
f'Backend \'{backend}\' is not supported, please use \'torch\' instead'
|
53 |
+
)
|
54 |
+
|
55 |
self.config_keys = ["max_seq_length", "do_lower_case"]
|
56 |
self.do_lower_case = do_lower_case
|
57 |
if model_args is None:
|