File size: 5,491 Bytes
210bb23
 
 
eb01dd5
 
 
 
210bb23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7163f3d
eb01dd5
 
 
7163f3d
210bb23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb01dd5
210bb23
 
 
 
 
fc73940
210bb23
 
fc73940
 
6ce586a
fc73940
210bb23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from typing import Dict, List, Any
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging

logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger(__name__)


# Prompts for the different tasks
START_PROMPT_TASK1 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.\nHér er textinn:\n\n"
END_PROMPT_TASK1 = "Sérðu eitthvað sem mætti betur fara í textanum? Búðu til lista af öllum slíkum tilvikum þar sem hver lína tilgreinir hver villan er, hvar hún er, og hvað væri gert í staðinn fyrir villuna.\n\n"

START_PROMPT_TASK2 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.Ég er með tvær útgáfur af textanum, A og B, og önnur þeirra gæti verið betri en hin á einhvern hátt, t.d. hvað varðar stafsetningu, málfræði o.s.frv.\nHér er texti A:\n\n"
MIDDLE_PROMPT_TASK2 = "Hér er texti B:\n\n"
END_PROMPT_TASK2 = "Hvorn textann líst þér betur á?\n\n"

START_PROMPT_TASK3 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.\nHér er textinn:\n\n"
END_PROMPT_TASK3 = "Reyndu nú að laga textann þannig að hann líti betur út, eins og þér finnst best við hæfi.\n\n"

START_PROMPT_TASK = {
    1: START_PROMPT_TASK1,
    2: START_PROMPT_TASK2,
    3: START_PROMPT_TASK3,
}
END_PROMPT_TASK = {1: END_PROMPT_TASK1, 2: END_PROMPT_TASK2, 3: END_PROMPT_TASK3}

SEP = "\n\n"


class EndpointHandler:
    def __init__(self, path=""):
        self.model = AutoModelForCausalLM.from_pretrained(
            path, device_map="auto", torch_dtype=torch.bfloat16
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        LOGGER.info(f"Inference model loaded from {path}")
        LOGGER.info(f"Model device: {self.model.device}")

    def check_valid_inputs(self, input_a: str, input_b: str, task: int) -> bool:
        """
        Check if the inputs are valid
        """
        if task not in [1, 2, 3]:
            return False
        if task == 1 or task == 3:
            if input_a is None:
                return False
        elif task == 2:
            if input_a is None or input_b is None:
                return False
        return True

    def tokenize_input(self, input_a: str, input_b: str, task: int) -> List[int]:
        """
        Tokenize the input
        """
        if task == 1 or task == 3:
            tokenized_start = self.tokenizer(START_PROMPT_TASK[task])["input_ids"]
            tokenized_end = self.tokenizer(END_PROMPT_TASK[task])["input_ids"]
            tokenized_sentence = self.tokenizer(input_a + SEP)["input_ids"]
            concatted_data = (
                [self.tokenizer.bos_token_id]
                + tokenized_start
                + tokenized_sentence
                + tokenized_end
            )
        elif task == 2:
            tokenized_start = self.tokenizer(START_PROMPT_TASK[task])["input_ids"]
            tokenized_middle = self.tokenizer(MIDDLE_PROMPT_TASK2)["input_ids"]
            tokenized_end = self.tokenizer(END_PROMPT_TASK[task])["input_ids"]
            tokenized_sentence_a = self.tokenizer(input_a + SEP)["input_ids"]
            tokenized_sentence_b = self.tokenizer(input_b + SEP)["input_ids"]
            concatted_data = (
                [self.tokenizer.bos_token_id]
                + tokenized_start
                + tokenized_sentence_a
                + tokenized_middle
                + tokenized_sentence_b
                + tokenized_end
            )
        return concatted_data

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
         data args:
              inputs (:obj: `str` | `PIL.Image` | `np.array`)
              kwargs
        Return:
              A :obj:`list` | `dict`: will be serialized and returned
        """
        LOGGER.info(f"Received data: {data}")

        # Get inputs
        input_a = data.pop("input_a", None)
        input_b = data.pop("input_b", None)
        task = data.pop("task", None)
        parameters = data.pop("parameters", {})

        # Check valid inputs
        if not self.check_valid_inputs(input_a, input_b, task):
            return [{"error": "Invalid inputs"}]
        if "max_new_tokens" not in parameters and "max_length" not in parameters:
            parameters["max_new_tokens"] = 512

        # Tokenize the input
        tokenized_input = self.tokenize_input(input_a, input_b, task)

        # Move the input to the device
        input_ids = torch.tensor(tokenized_input).to(self.model.device)
        input_ids = input_ids.unsqueeze(0)

        # Generate the output
        output = self.model.generate(input_ids, **parameters)

        # Decode only the new part of the output
        decoded_output = self.tokenizer.decode(
            output[0][len(tokenized_input) :], skip_special_tokens=True
        ).strip()
        return [{"output": decoded_output}]