gardarjuto
commited on
Commit
•
210bb23
1
Parent(s):
fc28ced
add custom handler for inference endpoint
Browse files- handler.py +118 -0
handler.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Any
|
2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
# Prompts for the different tasks
|
7 |
+
START_PROMPT_TASK1 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.\nHér er textinn:\n\n"
|
8 |
+
END_PROMPT_TASK1 = "Sérðu eitthvað sem mætti betur fara í textanum? Búðu til lista af öllum slíkum tilvikum þar sem hver lína tilgreinir hver villan er, hvar hún er, og hvað væri gert í staðinn fyrir villuna.\n\n"
|
9 |
+
|
10 |
+
START_PROMPT_TASK2 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.Ég er með tvær útgáfur af textanum, A og B, og önnur þeirra gæti verið betri en hin á einhvern hátt, t.d. hvað varðar stafsetningu, málfræði o.s.frv.\nHér er texti A:\n\n"
|
11 |
+
MIDDLE_PROMPT_TASK2 = "Hér er texti B:\n\n"
|
12 |
+
END_PROMPT_TASK2 = "Hvorn textann líst þér betur á?\n\n"
|
13 |
+
|
14 |
+
START_PROMPT_TASK3 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.\nHér er textinn:\n\n"
|
15 |
+
END_PROMPT_TASK3 = "Reyndu nú að laga textann þannig að hann líti betur út, eins og þér finnst best við hæfi.\n\n"
|
16 |
+
|
17 |
+
START_PROMPT_TASK = {
|
18 |
+
1: START_PROMPT_TASK1,
|
19 |
+
2: START_PROMPT_TASK2,
|
20 |
+
3: START_PROMPT_TASK3,
|
21 |
+
}
|
22 |
+
END_PROMPT_TASK = {1: END_PROMPT_TASK1, 2: END_PROMPT_TASK2, 3: END_PROMPT_TASK3}
|
23 |
+
|
24 |
+
SEP = "\n\n"
|
25 |
+
|
26 |
+
|
27 |
+
class EndpointHandler:
|
28 |
+
def __init__(self, path=""):
|
29 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
30 |
+
path, device_map="auto", torch_dtype=torch.bfloat16
|
31 |
+
)
|
32 |
+
# Fix the pad and bos tokens to avoid bug in the tokenizer
|
33 |
+
pad_token = "<unk>"
|
34 |
+
bos_token = "<|endoftext|>"
|
35 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
36 |
+
"AI-Sweden-Models/gpt-sw3-6.7b", pad_token=pad_token, bos_token=bos_token
|
37 |
+
)
|
38 |
+
|
39 |
+
def check_valid_inputs(
|
40 |
+
self, input_a: str, input_b: str, task: int, parameters: Dict[str, Any]
|
41 |
+
) -> bool:
|
42 |
+
"""
|
43 |
+
Check if the inputs are valid
|
44 |
+
"""
|
45 |
+
if task not in [1, 2, 3]:
|
46 |
+
return False
|
47 |
+
if task == 1 or task == 3:
|
48 |
+
if input_a is None:
|
49 |
+
return False
|
50 |
+
elif task == 2:
|
51 |
+
if input_a is None or input_b is None:
|
52 |
+
return False
|
53 |
+
return True
|
54 |
+
|
55 |
+
def tokenize_input(self, input_a: str, input_b: str, task: int) -> List[int]:
|
56 |
+
"""
|
57 |
+
Tokenize the input
|
58 |
+
"""
|
59 |
+
if task == 1 or task == 3:
|
60 |
+
tokenized_start = self.tokenizer(START_PROMPT_TASK[task])["input_ids"]
|
61 |
+
tokenized_end = self.tokenizer(END_PROMPT_TASK[task])["input_ids"]
|
62 |
+
tokenized_sentence = self.tokenizer(input_a + SEP)["input_ids"]
|
63 |
+
concatted_data = (
|
64 |
+
[self.tokenizer.bos_token_id]
|
65 |
+
+ tokenized_start
|
66 |
+
+ tokenized_sentence
|
67 |
+
+ tokenized_end
|
68 |
+
)
|
69 |
+
elif task == 2:
|
70 |
+
tokenized_start = self.tokenizer(START_PROMPT_TASK[task])["input_ids"]
|
71 |
+
tokenized_middle = self.tokenizer(MIDDLE_PROMPT_TASK2)["input_ids"]
|
72 |
+
tokenized_end = self.tokenizer(END_PROMPT_TASK[task])["input_ids"]
|
73 |
+
tokenized_sentence_a = self.tokenizer(input_a + SEP)["input_ids"]
|
74 |
+
tokenized_sentence_b = self.tokenizer(input_b + SEP)["input_ids"]
|
75 |
+
concatted_data = (
|
76 |
+
[self.tokenizer.bos_token_id]
|
77 |
+
+ tokenized_start
|
78 |
+
+ tokenized_sentence_a
|
79 |
+
+ tokenized_middle
|
80 |
+
+ tokenized_sentence_b
|
81 |
+
+ tokenized_end
|
82 |
+
)
|
83 |
+
return concatted_data
|
84 |
+
|
85 |
+
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
86 |
+
"""
|
87 |
+
data args:
|
88 |
+
inputs (:obj: `str` | `PIL.Image` | `np.array`)
|
89 |
+
kwargs
|
90 |
+
Return:
|
91 |
+
A :obj:`list` | `dict`: will be serialized and returned
|
92 |
+
"""
|
93 |
+
|
94 |
+
# Get inputs
|
95 |
+
input_a = data.pop("input_a", None)
|
96 |
+
input_b = data.pop("input_b", None)
|
97 |
+
task = data.pop("task", None)
|
98 |
+
parameters = data.pop("parameters", None)
|
99 |
+
|
100 |
+
# Check valid inputs
|
101 |
+
if not self.check_valid_inputs(input_a, input_b, task, parameters):
|
102 |
+
return []
|
103 |
+
|
104 |
+
# Tokenize the input
|
105 |
+
tokenized_input = self.tokenize_input(input_a, input_b, task)
|
106 |
+
|
107 |
+
# Move the input to the device
|
108 |
+
input_ids = torch.tensor(tokenized_input).to(self.model.device)
|
109 |
+
input_ids = input_ids.unsqueeze(0)
|
110 |
+
|
111 |
+
# Generate the output
|
112 |
+
output = self.model.generate(input_ids, **parameters)
|
113 |
+
|
114 |
+
# Decode only the new part of the output
|
115 |
+
decoded_output = self.tokenizer.decode(
|
116 |
+
output[0][len(tokenized_input) :], skip_special_tokens=True
|
117 |
+
).strip()
|
118 |
+
return [{"output": decoded_output}]
|