simplify tokenizer bug handling
#5
by
gardari
- opened
- handler.py +2 -10
- tokenizer_config.json +4 -2
handler.py
CHANGED
@@ -33,19 +33,11 @@ class EndpointHandler:
|
|
33 |
self.model = AutoModelForCausalLM.from_pretrained(
|
34 |
path, device_map="auto", torch_dtype=torch.bfloat16
|
35 |
)
|
|
|
36 |
LOGGER.info(f"Inference model loaded from {path}")
|
37 |
LOGGER.info(f"Model device: {self.model.device}")
|
38 |
|
39 |
-
|
40 |
-
pad_token = "<unk>"
|
41 |
-
bos_token = "<|endoftext|>"
|
42 |
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
43 |
-
"AI-Sweden-Models/gpt-sw3-6.7b", pad_token=pad_token, bos_token=bos_token
|
44 |
-
)
|
45 |
-
|
46 |
-
def check_valid_inputs(
|
47 |
-
self, input_a: str, input_b: str, task: int
|
48 |
-
) -> bool:
|
49 |
"""
|
50 |
Check if the inputs are valid
|
51 |
"""
|
|
|
33 |
self.model = AutoModelForCausalLM.from_pretrained(
|
34 |
path, device_map="auto", torch_dtype=torch.bfloat16
|
35 |
)
|
36 |
+
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
37 |
LOGGER.info(f"Inference model loaded from {path}")
|
38 |
LOGGER.info(f"Model device: {self.model.device}")
|
39 |
|
40 |
+
def check_valid_inputs(self, input_a: str, input_b: str, task: int) -> bool:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
"""
|
42 |
Check if the inputs are valid
|
43 |
"""
|
tokenizer_config.json
CHANGED
@@ -1,3 +1,5 @@
|
|
1 |
{
|
2 |
-
"name_or_path": "AI-Sweden-Models/gpt-sw3-6.7b"
|
3 |
-
|
|
|
|
|
|
1 |
{
|
2 |
+
"name_or_path": "AI-Sweden-Models/gpt-sw3-6.7b",
|
3 |
+
"bos_token": "<|endoftext|>",
|
4 |
+
"pad_token": "<unk>"
|
5 |
+
}
|