KK-dev commited on
Commit
84c4d0e
1 Parent(s): 07c9d88

model-file

Browse files
Files changed (1) hide show
  1. model.py +140 -0
model.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from threading import Thread
4
+ from typing import Iterator
5
+ from transformers import (
6
+ AutoModelForCausalLM,
7
+ AutoTokenizer,
8
+ TextIteratorStreamer,
9
+ StoppingCriteria,
10
+ StoppingCriteriaList
11
+ )
12
+
13
+ from huggingface_hub import login
14
+ login(token=os.environ["hf_read_token"])
15
+
16
+
17
+ class StopWordsCriteria(StoppingCriteria):
18
+ def __init__(self, tokenizer, stop_words, stop_ids, stream_callback):
19
+ self._tokenizer = tokenizer
20
+ self._stop_words = stop_words
21
+ self._stop_ids = stop_ids
22
+ self._partial_result = ''
23
+ self._stream_buffer = ''
24
+ self._stream_callback = stream_callback
25
+
26
+ # use both stop words (human id) and stop token ids (EOS tokens)
27
+ def __call__(
28
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
29
+ ) -> bool:
30
+ first = not self._partial_result
31
+ text = self._tokenizer.decode(input_ids[0, -1])
32
+ self._partial_result += text
33
+ # Check stop words
34
+ for stop_word in self._stop_words:
35
+ if stop_word in self._partial_result:
36
+ return True
37
+ # Check stop ids
38
+ for stop_id in self._stop_ids:
39
+ if input_ids[0][-1] == stop_id:
40
+ return True
41
+ if self._stream_callback:
42
+ if first:
43
+ text = text.lstrip()
44
+ # buffer tokens if the partial result ends with a prefix of a stop word, e.g. "<hu"
45
+ for stop_word in self._stop_words:
46
+ for i in range(1, len(stop_word)):
47
+ if self._partial_result.endswith(stop_word[0:i]):
48
+ self._stream_buffer += text
49
+ return False
50
+ self._stream_callback(self._stream_buffer + text)
51
+ self._stream_buffer = ''
52
+ return False
53
+
54
+
55
+ model_id = "medalpaca/medalpaca-7b"
56
+
57
+ if torch.cuda.is_available():
58
+ model = AutoModelForCausalLM.from_pretrained(
59
+ model_id,
60
+ torch_dtype=torch.float16,
61
+ device_map='auto',
62
+ use_auth_token=True,
63
+ )
64
+ else:
65
+ model = None
66
+
67
+ tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=True)
68
+
69
+
70
+ def get_prompt(message: str, chat_history: list[tuple[str, str]],
71
+ system_prompt: str) -> str:
72
+ texts = [f'<<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
73
+ # The first user input is _not_ stripped
74
+ do_strip = False
75
+ for user_input, response in chat_history:
76
+ user_input = user_input.strip() if do_strip else user_input
77
+ do_strip = True
78
+ texts.append(f'{user_input} <Answer>: {response.strip()} <Question>: ')
79
+ message = message.strip() if do_strip else message
80
+ texts.append(f'{message} <Answer>:')
81
+ print(texts)
82
+ print('---------------------------------------------')
83
+ return ''.join(texts)
84
+
85
+
86
+ def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
87
+ prompt = get_prompt(message, chat_history, system_prompt)
88
+ input_ids = tokenizer(
89
+ [prompt],
90
+ return_token_type_ids=False,
91
+ return_tensors='np',
92
+ add_special_tokens=False)['input_ids']
93
+ return input_ids.shape[-1]
94
+
95
+
96
+ def run(message: str,
97
+ chat_history: list[tuple[str, str]],
98
+ system_prompt: str,
99
+ max_new_tokens: int = 1024,
100
+ temperature: float = 0.8,
101
+ top_p: float = 0.90,
102
+ top_k: int = 20) -> Iterator[str]:
103
+ prompt = get_prompt(message, chat_history, system_prompt)
104
+ print(prompt)
105
+ print('=================================================')
106
+ inputs = tokenizer(
107
+ [prompt],
108
+ return_token_type_ids=False,
109
+ return_tensors='pt',
110
+ add_special_tokens=False).to('cuda')
111
+
112
+ streamer = TextIteratorStreamer(tokenizer,
113
+ timeout=10.,
114
+ skip_prompt=True,
115
+ skip_special_tokens=True)
116
+ stop_criteria = StopWordsCriteria(
117
+ tokenizer=tokenizer,
118
+ stop_words=["<Question>", "<Answer>"],
119
+ stop_ids=[1,2,32001,32002],
120
+ stream_callback=None
121
+ )
122
+
123
+ generate_kwargs = dict(
124
+ inputs,
125
+ streamer=streamer,
126
+ max_new_tokens=max_new_tokens,
127
+ do_sample=True,
128
+ top_p=top_p,
129
+ top_k=top_k,
130
+ temperature=temperature,
131
+ stopping_criteria=StoppingCriteriaList([stop_criteria]),
132
+ num_beams=1,
133
+ )
134
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
135
+ t.start()
136
+
137
+ outputs = []
138
+ for text in streamer:
139
+ outputs.append(text)
140
+ yield ''.join(outputs)