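"""Custom inference handler for a Hugging Face Inference Endpoint.

Inference Endpoints load a class named `EndpointHandler` from `handler.py`,
construct it with the model repository path, and call it with each request
payload; the code below follows that contract.
"""
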
from typing import Any, Dict, List

from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList


class EndpointHandler:
    def __init__(self, path=""):
        tokenizer = AutoTokenizer.from_pretrained(path)
        # Reuse the EOS token for padding (the checkpoint may lack a dedicated pad token)
        tokenizer.pad_token = tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(path).to('cuda')
        self.tokenizer = tokenizer
        self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str`)
            kwargs
        Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        inputs = data.pop("inputs", data)
        additional_bad_words_ids = data.pop("additional_bad_words_ids", [])

        # Token IDs banned from the output:
        # [3070], [10456], and [313, 334] correspond to "(*", and we do not want to output a comment
        # [13] is a newline character
        # [1976, 441, 29889] and [4920, 441, 29889] are "Abort."; [4920, 18054, 29889] is "Aborted."
        # [2087, 29885, 4430, 29889] is "Admitted."
        bad_words_ids = [[3070], [313, 334], [10456], [13], [1976, 441, 29889], [2087, 29885, 4430, 29889], [4920, 441], [4920, 441, 29889], [4920, 18054, 29889]]
        bad_words_ids.extend(additional_bad_words_ids)
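        # Note: these hard-coded IDs assume the checkpoint's own tokenizer (they
        # look Llama-style); IDs for further phrases can be derived with, e.g.,
        # self.tokenizer("Abort.", add_special_tokens=False).input_ids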

        input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to('cuda')
        max_generation_length = 75  # Desired number of tokens to generate
        # max_input_length = 4092 - max_generation_length  # Maximum input length to allow space for generation

        # # Truncate input_ids to the most recent tokens that fit within the max_input_length
        # if input_ids.shape[1] > max_input_length:
        #     input_ids = input_ids[:, -max_input_length:]

        max_length = input_ids.shape[1] + max_generation_length
        
        generated_ids = self.model.generate(
            input_ids,
            max_length=max_length,  # prompt length plus up to 75 new tokens
            bad_words_ids=bad_words_ids,
            temperature=1,
            top_k=40,
            do_sample=True,
            stopping_criteria=self.stopping_criteria,
        )

        # Return only the newly generated tokens, stripping the prompt prefix
        new_token_ids = generated_ids[0][input_ids.shape[1]:]
        generated_text = self.tokenizer.decode(new_token_ids, skip_special_tokens=True)
        prediction = [{"generated_text": generated_text, "generated_ids": new_token_ids.tolist()}]
        return prediction


class StopAtPeriodCriteria(StoppingCriteria):
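    """Stopping criterion that halts generation as soon as the most recently
    generated token decodes to text containing a period."""
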
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        # Decode the last generated token to text
        last_token_text = self.tokenizer.decode(input_ids[:, -1], skip_special_tokens=True)
        # Stop as soon as the decoded token text contains a period
        return '.' in last_token_text
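

# Minimal local smoke test (a sketch only, not part of the endpoint contract).
# Assumptions: a CUDA device is available, "./model" is a placeholder path
# holding the checkpoint, and the Coq-style prompt is purely illustrative.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    output = handler({"inputs": "Lemma add_0_r : forall n : nat, n + 0 = n.\nProof."})
    print(output[0]["generated_text"])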