gardarjuto commited on
Commit
210bb23
1 Parent(s): fc28ced

add custom handler for inference endpoint

Browse files
Files changed (1) hide show
  1. handler.py +118 -0
handler.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
+
6
+ # Prompts for the different tasks
7
+ START_PROMPT_TASK1 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.\nHér er textinn:\n\n"
8
+ END_PROMPT_TASK1 = "Sérðu eitthvað sem mætti betur fara í textanum? Búðu til lista af öllum slíkum tilvikum þar sem hver lína tilgreinir hver villan er, hvar hún er, og hvað væri gert í staðinn fyrir villuna.\n\n"
9
+
10
+ START_PROMPT_TASK2 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.Ég er með tvær útgáfur af textanum, A og B, og önnur þeirra gæti verið betri en hin á einhvern hátt, t.d. hvað varðar stafsetningu, málfræði o.s.frv.\nHér er texti A:\n\n"
11
+ MIDDLE_PROMPT_TASK2 = "Hér er texti B:\n\n"
12
+ END_PROMPT_TASK2 = "Hvorn textann líst þér betur á?\n\n"
13
+
14
+ START_PROMPT_TASK3 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.\nHér er textinn:\n\n"
15
+ END_PROMPT_TASK3 = "Reyndu nú að laga textann þannig að hann líti betur út, eins og þér finnst best við hæfi.\n\n"
16
+
17
+ START_PROMPT_TASK = {
18
+ 1: START_PROMPT_TASK1,
19
+ 2: START_PROMPT_TASK2,
20
+ 3: START_PROMPT_TASK3,
21
+ }
22
+ END_PROMPT_TASK = {1: END_PROMPT_TASK1, 2: END_PROMPT_TASK2, 3: END_PROMPT_TASK3}
23
+
24
+ SEP = "\n\n"
25
+
26
+
27
+ class EndpointHandler:
28
+ def __init__(self, path=""):
29
+ self.model = AutoModelForCausalLM.from_pretrained(
30
+ path, device_map="auto", torch_dtype=torch.bfloat16
31
+ )
32
+ # Fix the pad and bos tokens to avoid bug in the tokenizer
33
+ pad_token = "<unk>"
34
+ bos_token = "<|endoftext|>"
35
+ self.tokenizer = AutoTokenizer.from_pretrained(
36
+ "AI-Sweden-Models/gpt-sw3-6.7b", pad_token=pad_token, bos_token=bos_token
37
+ )
38
+
39
+ def check_valid_inputs(
40
+ self, input_a: str, input_b: str, task: int, parameters: Dict[str, Any]
41
+ ) -> bool:
42
+ """
43
+ Check if the inputs are valid
44
+ """
45
+ if task not in [1, 2, 3]:
46
+ return False
47
+ if task == 1 or task == 3:
48
+ if input_a is None:
49
+ return False
50
+ elif task == 2:
51
+ if input_a is None or input_b is None:
52
+ return False
53
+ return True
54
+
55
+ def tokenize_input(self, input_a: str, input_b: str, task: int) -> List[int]:
56
+ """
57
+ Tokenize the input
58
+ """
59
+ if task == 1 or task == 3:
60
+ tokenized_start = self.tokenizer(START_PROMPT_TASK[task])["input_ids"]
61
+ tokenized_end = self.tokenizer(END_PROMPT_TASK[task])["input_ids"]
62
+ tokenized_sentence = self.tokenizer(input_a + SEP)["input_ids"]
63
+ concatted_data = (
64
+ [self.tokenizer.bos_token_id]
65
+ + tokenized_start
66
+ + tokenized_sentence
67
+ + tokenized_end
68
+ )
69
+ elif task == 2:
70
+ tokenized_start = self.tokenizer(START_PROMPT_TASK[task])["input_ids"]
71
+ tokenized_middle = self.tokenizer(MIDDLE_PROMPT_TASK2)["input_ids"]
72
+ tokenized_end = self.tokenizer(END_PROMPT_TASK[task])["input_ids"]
73
+ tokenized_sentence_a = self.tokenizer(input_a + SEP)["input_ids"]
74
+ tokenized_sentence_b = self.tokenizer(input_b + SEP)["input_ids"]
75
+ concatted_data = (
76
+ [self.tokenizer.bos_token_id]
77
+ + tokenized_start
78
+ + tokenized_sentence_a
79
+ + tokenized_middle
80
+ + tokenized_sentence_b
81
+ + tokenized_end
82
+ )
83
+ return concatted_data
84
+
85
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
86
+ """
87
+ data args:
88
+ inputs (:obj: `str` | `PIL.Image` | `np.array`)
89
+ kwargs
90
+ Return:
91
+ A :obj:`list` | `dict`: will be serialized and returned
92
+ """
93
+
94
+ # Get inputs
95
+ input_a = data.pop("input_a", None)
96
+ input_b = data.pop("input_b", None)
97
+ task = data.pop("task", None)
98
+ parameters = data.pop("parameters", None)
99
+
100
+ # Check valid inputs
101
+ if not self.check_valid_inputs(input_a, input_b, task, parameters):
102
+ return []
103
+
104
+ # Tokenize the input
105
+ tokenized_input = self.tokenize_input(input_a, input_b, task)
106
+
107
+ # Move the input to the device
108
+ input_ids = torch.tensor(tokenized_input).to(self.model.device)
109
+ input_ids = input_ids.unsqueeze(0)
110
+
111
+ # Generate the output
112
+ output = self.model.generate(input_ids, **parameters)
113
+
114
+ # Decode only the new part of the output
115
+ decoded_output = self.tokenizer.decode(
116
+ output[0][len(tokenized_input) :], skip_special_tokens=True
117
+ ).strip()
118
+ return [{"output": decoded_output}]