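"""Streaming chat inference for medalpaca/medalpaca-7b.

Builds a <Question>/<Answer>-tagged prompt from the chat history, runs
generation on a background thread, and yields partial text through a
TextIteratorStreamer, with custom stopping criteria for turn markers and
stop token ids.
"""
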
import os
from threading import Thread
from typing import Iterator

import torch
from huggingface_hub import login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList,
)

# Authenticate with the Hugging Face Hub; the read token is expected in the
# `hf_read_token` environment variable (e.g. a Space secret).
login(token=os.environ["hf_read_token"])


class StopWordsCriteria(StoppingCriteria):
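    """Stop generation on stop words (turn markers) or stop token ids.

    When `stream_callback` is set, decoded text is forwarded to it as tokens
    arrive, buffering any suffix that could be the prefix of a stop word so
    fragments like "<Ques" are never emitted.
    """
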
    def __init__(self, tokenizer, stop_words, stop_ids, stream_callback):
        self._tokenizer = tokenizer
        self._stop_words = stop_words
        self._stop_ids = stop_ids
        self._partial_result = ''
        self._stream_buffer = ''
        self._stream_callback = stream_callback

    # Stop on either stop words (turn markers such as "<Question>") or stop
    # token ids (EOS-style tokens).
    def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
            ) -> bool:
        first = not self._partial_result
        text = self._tokenizer.decode(input_ids[0, -1])
        self._partial_result += text
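        # Note: decoding one token at a time can drop SentencePiece word
        # boundaries (leading spaces), but character-level matching still
        # works for markers like "<Question>".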
        # Check stop words
        for stop_word in self._stop_words:
            if stop_word in self._partial_result:
                return True
        # Check stop ids
        for stop_id in self._stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        if self._stream_callback:
            if first:
                text = text.lstrip()
            # Buffer the token if the partial result ends with a prefix of a
            # stop word (e.g. "<Ques"), so stop-word fragments never reach
            # the callback.
            for stop_word in self._stop_words:
                for i in range(1, len(stop_word)):
                    if self._partial_result.endswith(stop_word[0:i]):
                        self._stream_buffer += text
                        return False
            self._stream_callback(self._stream_buffer + text)
            self._stream_buffer = ''
        return False
        
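
# Illustrative use of StopWordsCriteria with live streaming (this app passes
# stream_callback=None in run() below; the callback here is a placeholder):
#
#   criteria = StopWordsCriteria(
#       tokenizer=tokenizer,
#       stop_words=["<Question>", "<Answer>"],
#       stop_ids=[2],  # LLaMA EOS
#       stream_callback=lambda chunk: print(chunk, end="", flush=True),
#   )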

model_id = "medalpaca/medalpaca-7b"

if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map='auto',
        use_auth_token=True,
    )
else:
    # CPU-only host: the 7B model is not loaded, and run() below will fail
    # without it.
    model = None

tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=True)


def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    texts = [f'<<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    # The first user input is _not_ stripped
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} <Answer>: {response.strip()} <Question>: ')
    message = message.strip() if do_strip else message
    texts.append(f'{message} <Answer>:')
    print(texts)
    print('---------------------------------------------')
    return ''.join(texts)
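
# For reference, the assembled prompt has this shape (one prior turn, then
# the new message; note the first user input carries no "<Question>: " marker):
#
#   <<SYS>>
#   {system_prompt}
#   <</SYS>>
#
#   {user_input} <Answer>: {response} <Question>: {message} <Answer>: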


def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
    prompt = get_prompt(message, chat_history, system_prompt)
    input_ids = tokenizer(
        [prompt], 
        return_token_type_ids=False, 
        return_tensors='np', 
        add_special_tokens=False)['input_ids']
    return input_ids.shape[-1]


def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.8,
        top_p: float = 0.90,
        top_k: int = 20) -> Iterator[str]:
    prompt = get_prompt(message, chat_history, system_prompt)
    print(prompt)
    print('=================================================')
    inputs = tokenizer(
        [prompt], 
        return_token_type_ids=False, 
        return_tensors='pt', 
        add_special_tokens=False).to('cuda')

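    # Generation runs on a worker thread below; the streamer then yields
    # decoded text on this thread as tokens arrive, skipping the prompt and
    # special tokens.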
    streamer = TextIteratorStreamer(tokenizer,
                                    timeout=10.,
                                    skip_prompt=True,
                                    skip_special_tokens=True)
    stop_criteria = StopWordsCriteria(
        tokenizer=tokenizer, 
        stop_words=["<Question>", "<Answer>"], 
        stop_ids=[1, 2, 32001, 32002],  # LLaMA BOS/EOS plus, presumably, added special token ids
        stream_callback=None
    )
        
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        stopping_criteria=StoppingCriteriaList([stop_criteria]),
        num_beams=1,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield ''.join(outputs)
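
# Minimal smoke test (sketch): requires a CUDA device, since `model` stays
# None on CPU-only hosts. The question and system prompt are illustrative.
if __name__ == '__main__':
    final = ''
    for partial in run(
        message='What are common symptoms of iron-deficiency anemia?',
        chat_history=[],
        system_prompt='You are a helpful medical assistant.',
    ):
        final = partial  # run() yields cumulative text; keep the latest
    print(final)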