jZoNg commited on
Commit
e4f695b
1 Parent(s): 9a00733

fix chat msg

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +234 -233
  3. model.py +60 -62
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Baichuan2 13B Chat
3
- emoji: 📉
4
  colorFrom: gray
5
  colorTo: indigo
6
  sdk: gradio
 
1
  ---
2
  title: Baichuan2 13B Chat
3
+ emoji: 🔥
4
  colorFrom: gray
5
  colorTo: indigo
6
  sdk: gradio
app.py CHANGED
@@ -3,269 +3,270 @@ from typing import Iterator
3
  import gradio as gr
4
  import torch
5
 
6
- from model import get_input_token_length, run
7
 
8
- # DEFAULT_SYSTEM_PROMPT = """\
9
- # You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
10
- # """
11
-
12
- DEFAULT_SYSTEM_PROMPT = """
13
- """
14
  MAX_MAX_NEW_TOKENS = 2048
15
  DEFAULT_MAX_NEW_TOKENS = 1024
16
  MAX_INPUT_TOKEN_LENGTH = 4000
17
-
18
  DESCRIPTION = """
19
  # Baichuan2-13B-Chat
20
  Baichuan 2 is the new generation of open-source large language models launched by Baichuan Intelligent Technology. It was trained on a high-quality corpus with 2.6 trillion tokens.
21
  """
22
-
23
- LICENSE = """
24
- """
25
 
26
  if not torch.cuda.is_available():
27
- DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'
28
 
29
 
30
  def clear_and_save_textbox(message: str) -> tuple[str, str]:
31
- return '', message
32
-
33
-
34
- def display_input(message: str,
35
- history: list[tuple[str, str]]) -> list[tuple[str, str]]:
36
- history.append((message, ''))
37
- return history
38
 
 
 
 
 
 
 
39
 
40
  def delete_prev_fn(
41
- history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
42
- try:
43
- message, _ = history.pop()
44
- except IndexError:
45
- message = ''
46
- return history, message or ''
47
-
48
 
49
  def generate(
50
- message: str,
51
- history_with_input: list[tuple[str, str]],
52
- system_prompt: str,
53
- max_new_tokens: int,
54
- temperature: float,
55
- top_p: float,
56
- top_k: int,
57
  ) -> Iterator[list[tuple[str, str]]]:
58
- if max_new_tokens > MAX_MAX_NEW_TOKENS:
59
- raise ValueError
60
-
61
- history = history_with_input[:-1]
62
- generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
63
- try:
64
- first_response = next(generator)
65
- yield history + [(message, first_response)]
66
- except StopIteration:
67
- yield history + [(message, '')]
68
- for response in generator:
69
- yield history + [(message, response)]
70
-
71
 
72
  def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
73
- generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 5)
74
- for x in generator:
75
- pass
76
- return '', x
77
-
78
-
79
- def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
80
- input_token_length = get_input_token_length(message, chat_history, system_prompt)
81
- if input_token_length > MAX_INPUT_TOKEN_LENGTH:
82
- raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
 
 
 
 
83
 
84
 
85
  with gr.Blocks(css='style.css') as demo:
86
- gr.Markdown(DESCRIPTION)
87
- gr.DuplicateButton(value='Duplicate Space for private use',
88
- elem_id='duplicate-button')
89
-
90
- with gr.Group():
91
- chatbot = gr.Chatbot(label='Chatbot')
92
- with gr.Row():
93
- textbox = gr.Textbox(
94
- container=False,
95
- show_label=False,
96
- placeholder='Type a message...',
97
- scale=10,
98
- )
99
- submit_button = gr.Button('Submit',
100
- variant='primary',
101
- scale=1,
102
- min_width=0)
103
  with gr.Row():
104
- retry_button = gr.Button('🔄 Retry', variant='secondary')
105
- undo_button = gr.Button('↩️ Undo', variant='secondary')
106
- clear_button = gr.Button('🗑️ Clear', variant='secondary')
107
-
108
- saved_input = gr.State()
109
-
110
- with gr.Accordion(label='Advanced options', open=False):
111
- system_prompt = gr.Textbox(label='System prompt',
112
- value=DEFAULT_SYSTEM_PROMPT,
113
- lines=6)
114
- max_new_tokens = gr.Slider(
115
- label='Max new tokens',
116
- minimum=1,
117
- maximum=MAX_MAX_NEW_TOKENS,
118
- step=1,
119
- value=DEFAULT_MAX_NEW_TOKENS,
120
- )
121
- temperature = gr.Slider(
122
- label='Temperature',
123
- minimum=0.1,
124
- maximum=4.0,
125
- step=0.1,
126
- value=1.0,
127
- )
128
- top_p = gr.Slider(
129
- label='Top-p (nucleus sampling)',
130
- minimum=0.05,
131
- maximum=1.0,
132
- step=0.05,
133
- value=0.95,
134
- )
135
- top_k = gr.Slider(
136
- label='Top-k',
137
- minimum=1,
138
- maximum=1000,
139
- step=1,
140
- value=50,
141
- )
142
-
143
- gr.Examples(
144
- examples=[
145
- 'Hello there! How are you doing?',
146
- 'Can you explain briefly to me what is the Python programming language?',
147
- 'Explain the plot of Cinderella in a sentence.',
148
- 'How many hours does it take a man to eat a Helicopter?',
149
- "Write a 100-word article on 'Benefits of Open-Source in AI research'",
150
- ],
151
- inputs=textbox,
152
- outputs=[textbox, chatbot],
153
- fn=process_example,
154
- cache_examples=True,
155
  )
156
-
157
- gr.Markdown(LICENSE)
158
-
159
- textbox.submit(
160
- fn=clear_and_save_textbox,
161
- inputs=textbox,
162
- outputs=[textbox, saved_input],
163
- api_name=False,
164
- queue=False,
165
- ).then(
166
- fn=display_input,
167
- inputs=[saved_input, chatbot],
168
- outputs=chatbot,
169
- api_name=False,
170
- queue=False,
171
- ).then(
172
- fn=check_input_token_length,
173
- inputs=[saved_input, chatbot, system_prompt],
174
- api_name=False,
175
- queue=False,
176
- ).success(
177
- fn=generate,
178
- inputs=[
179
- saved_input,
180
- chatbot,
181
- system_prompt,
182
- max_new_tokens,
183
- temperature,
184
- top_p,
185
- top_k,
186
- ],
187
- outputs=chatbot,
188
- api_name=False,
189
  )
190
-
191
- button_event_preprocess = submit_button.click(
192
- fn=clear_and_save_textbox,
193
- inputs=textbox,
194
- outputs=[textbox, saved_input],
195
- api_name=False,
196
- queue=False,
197
- ).then(
198
- fn=display_input,
199
- inputs=[saved_input, chatbot],
200
- outputs=chatbot,
201
- api_name=False,
202
- queue=False,
203
- ).then(
204
- fn=check_input_token_length,
205
- inputs=[saved_input, chatbot, system_prompt],
206
- api_name=False,
207
- queue=False,
208
- ).success(
209
- fn=generate,
210
- inputs=[
211
- saved_input,
212
- chatbot,
213
- system_prompt,
214
- max_new_tokens,
215
- temperature,
216
- top_p,
217
- top_k,
218
- ],
219
- outputs=chatbot,
220
- api_name=False,
221
  )
222
-
223
- retry_button.click(
224
- fn=delete_prev_fn,
225
- inputs=chatbot,
226
- outputs=[chatbot, saved_input],
227
- api_name=False,
228
- queue=False,
229
- ).then(
230
- fn=display_input,
231
- inputs=[saved_input, chatbot],
232
- outputs=chatbot,
233
- api_name=False,
234
- queue=False,
235
- ).then(
236
- fn=generate,
237
- inputs=[
238
- saved_input,
239
- chatbot,
240
- system_prompt,
241
- max_new_tokens,
242
- temperature,
243
- top_p,
244
- top_k,
245
- ],
246
- outputs=chatbot,
247
- api_name=False,
248
  )
249
-
250
- undo_button.click(
251
- fn=delete_prev_fn,
252
- inputs=chatbot,
253
- outputs=[chatbot, saved_input],
254
- api_name=False,
255
- queue=False,
256
- ).then(
257
- fn=lambda x: x,
258
- inputs=[saved_input],
259
- outputs=textbox,
260
- api_name=False,
261
- queue=False,
262
  )
263
 
264
- clear_button.click(
265
- fn=lambda: ([], ''),
266
- outputs=[chatbot, saved_input],
267
- queue=False,
268
- api_name=False,
269
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
  demo.queue(max_size=20).launch()
 
3
  import gradio as gr
4
  import torch
5
 
6
+ from model import run
7
 
8
+ DEFAULT_SYSTEM_PROMPT = ""
 
 
 
 
 
9
  MAX_MAX_NEW_TOKENS = 2048
10
  DEFAULT_MAX_NEW_TOKENS = 1024
11
  MAX_INPUT_TOKEN_LENGTH = 4000
 
12
  DESCRIPTION = """
13
  # Baichuan2-13B-Chat
14
  Baichuan 2 is the new generation of open-source large language models launched by Baichuan Intelligent Technology. It was trained on a high-quality corpus with 2.6 trillion tokens.
15
  """
16
+ LICENSE = ""
 
 
17
 
18
  if not torch.cuda.is_available():
19
+ DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'
20
 
21
 
22
  def clear_and_save_textbox(message: str) -> tuple[str, str]:
23
+ return '', message
 
 
 
 
 
 
24
 
25
+ def display_input(
26
+ message: str,
27
+ history: list[tuple[str, str]]
28
+ ) -> list[tuple[str, str]]:
29
+ history.append((message, ''))
30
+ return history
31
 
32
  def delete_prev_fn(
33
+ history: list[tuple[str, str]]
34
+ ) -> tuple[list[tuple[str, str]], str]:
35
+ try:
36
+ message, _ = history.pop()
37
+ except IndexError:
38
+ message = ''
39
+ return history, message or ''
40
 
41
  def generate(
42
+ message: str,
43
+ history_with_input: list[tuple[str, str]],
44
+ system_prompt: str,
45
+ max_new_tokens: int,
46
+ temperature: float,
47
+ top_p: float,
48
+ top_k: int,
49
  ) -> Iterator[list[tuple[str, str]]]:
50
+ if max_new_tokens > MAX_MAX_NEW_TOKENS:
51
+ raise ValueError
52
+
53
+ history = history_with_input[:-1]
54
+ generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
55
+ try:
56
+ first_response = next(generator)
57
+ yield history + [(message, first_response)]
58
+ except StopIteration:
59
+ yield history + [(message, '')]
60
+ for response in generator:
61
+ yield history + [(message, response)]
 
62
 
63
  def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
64
+ generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, DEFAULT_MAX_NEW_TOKENS, 1, 0.95, 5)
65
+ for x in generator:
66
+ pass
67
+ return '', x
68
+
69
+ def check_input_token_length(
70
+ message: str,
71
+ chat_history: list[tuple[str, str]],
72
+ system_prompt: str
73
+ ) -> None:
74
+ a = 1
75
+ # input_token_length = get_input_token_length(message, chat_history, system_prompt)
76
+ # if input_token_length > MAX_INPUT_TOKEN_LENGTH:
77
+ # raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
78
 
79
 
80
  with gr.Blocks(css='style.css') as demo:
81
+ gr.Markdown(DESCRIPTION)
82
+ gr.DuplicateButton(
83
+ value='Duplicate Space for private use',
84
+ elem_id='duplicate-button'
85
+ )
86
+
87
+ with gr.Group():
88
+ chatbot = gr.Chatbot(label='Chatbot')
 
 
 
 
 
 
 
 
 
89
  with gr.Row():
90
+ textbox = gr.Textbox(
91
+ container=False,
92
+ show_label=False,
93
+ placeholder='Type a message...',
94
+ scale=10,
95
+ )
96
+ submit_button = gr.Button(
97
+ 'Submit',
98
+ variant='primary',
99
+ scale=1,
100
+ min_width=0
101
+ )
102
+
103
+ with gr.Row():
104
+ retry_button = gr.Button('🔄 Retry', variant='secondary')
105
+ undo_button = gr.Button('↩️ Undo', variant='secondary')
106
+ clear_button = gr.Button('🗑️ Clear', variant='secondary')
107
+
108
+ saved_input = gr.State()
109
+
110
+ with gr.Accordion(label='Advanced options', open=False):
111
+ system_prompt = gr.Textbox(
112
+ label='System prompt',
113
+ value=DEFAULT_SYSTEM_PROMPT,
114
+ lines=6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  )
116
+ max_new_tokens = gr.Slider(
117
+ label='Max new tokens',
118
+ minimum=1,
119
+ maximum=MAX_MAX_NEW_TOKENS,
120
+ step=1,
121
+ value=DEFAULT_MAX_NEW_TOKENS,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  )
123
+ temperature = gr.Slider(
124
+ label='Temperature',
125
+ minimum=0.1,
126
+ maximum=4.0,
127
+ step=0.1,
128
+ value=1.0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  )
130
+ top_p = gr.Slider(
131
+ label='Top-p (nucleus sampling)',
132
+ minimum=0.05,
133
+ maximum=1.0,
134
+ step=0.05,
135
+ value=0.95,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  )
137
+ top_k = gr.Slider(
138
+ label='Top-k',
139
+ minimum=1,
140
+ maximum=1000,
141
+ step=1,
142
+ value=5,
 
 
 
 
 
 
 
143
  )
144
 
145
+ gr.Examples(
146
+ examples=[
147
+ '介绍下你自己',
148
+ '找到下列数组的中位数[3.1,6.2,1.3,8.4,10.5,11.6,2.1],请用python代码完成以上功能',
149
+ '鸡和兔在一个笼子里,共有26个头,68只脚,那么鸡有多少只,兔有多少只?',
150
+ '以下物理常识题目,哪一个是错误的?A.在自然环境下,声音在固体中传播速度最快。B.牛顿第一定律:一个物体如果不受力作用,将保持静止或匀速直线运动的状态。C.牛顿第三定律:对于每个作用力,都有一个相等而反向的反作用力。D.声音在空气中的传播速度为1000m/s。',
151
+ ],
152
+ inputs=textbox,
153
+ outputs=[textbox, chatbot],
154
+ fn=process_example,
155
+ cache_examples=True,
156
+ )
157
+
158
+ gr.Markdown(LICENSE)
159
+
160
+ textbox.submit(
161
+ fn=clear_and_save_textbox,
162
+ inputs=textbox,
163
+ outputs=[textbox, saved_input],
164
+ api_name=False,
165
+ queue=False,
166
+ ).then(
167
+ fn=display_input,
168
+ inputs=[saved_input, chatbot],
169
+ outputs=chatbot,
170
+ api_name=False,
171
+ queue=False,
172
+ ).then(
173
+ fn=check_input_token_length,
174
+ inputs=[saved_input, chatbot, system_prompt],
175
+ api_name=False,
176
+ queue=False,
177
+ ).success(
178
+ fn=generate,
179
+ inputs=[
180
+ saved_input,
181
+ chatbot,
182
+ system_prompt,
183
+ max_new_tokens,
184
+ temperature,
185
+ top_p,
186
+ top_k,
187
+ ],
188
+ outputs=chatbot,
189
+ api_name=False,
190
+ )
191
+
192
+ button_event_preprocess = submit_button.click(
193
+ fn=clear_and_save_textbox,
194
+ inputs=textbox,
195
+ outputs=[textbox, saved_input],
196
+ api_name=False,
197
+ queue=False,
198
+ ).then(
199
+ fn=display_input,
200
+ inputs=[saved_input, chatbot],
201
+ outputs=chatbot,
202
+ api_name=False,
203
+ queue=False,
204
+ ).then(
205
+ fn=check_input_token_length,
206
+ inputs=[saved_input, chatbot, system_prompt],
207
+ api_name=False,
208
+ queue=False,
209
+ ).success(
210
+ fn=generate,
211
+ inputs=[
212
+ saved_input,
213
+ chatbot,
214
+ system_prompt,
215
+ max_new_tokens,
216
+ temperature,
217
+ top_p,
218
+ top_k,
219
+ ],
220
+ outputs=chatbot,
221
+ api_name=False,
222
+ )
223
+
224
+ retry_button.click(
225
+ fn=delete_prev_fn,
226
+ inputs=chatbot,
227
+ outputs=[chatbot, saved_input],
228
+ api_name=False,
229
+ queue=False,
230
+ ).then(
231
+ fn=display_input,
232
+ inputs=[saved_input, chatbot],
233
+ outputs=chatbot,
234
+ api_name=False,
235
+ queue=False,
236
+ ).then(
237
+ fn=generate,
238
+ inputs=[
239
+ saved_input,
240
+ chatbot,
241
+ system_prompt,
242
+ max_new_tokens,
243
+ temperature,
244
+ top_p,
245
+ top_k,
246
+ ],
247
+ outputs=chatbot,
248
+ api_name=False,
249
+ )
250
+
251
+ undo_button.click(
252
+ fn=delete_prev_fn,
253
+ inputs=chatbot,
254
+ outputs=[chatbot, saved_input],
255
+ api_name=False,
256
+ queue=False,
257
+ ).then(
258
+ fn=lambda x: x,
259
+ inputs=[saved_input],
260
+ outputs=textbox,
261
+ api_name=False,
262
+ queue=False,
263
+ )
264
+
265
+ clear_button.click(
266
+ fn=lambda: ([], ''),
267
+ outputs=[chatbot, saved_input],
268
+ queue=False,
269
+ api_name=False,
270
+ )
271
 
272
  demo.queue(max_size=20).launch()
model.py CHANGED
@@ -2,79 +2,77 @@ from threading import Thread
2
  from typing import Iterator
3
 
4
  import torch
5
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
  from transformers.generation.utils import GenerationConfig
7
 
8
  model_id = 'baichuan-inc/Baichuan2-13B-Chat'
9
 
10
  if torch.cuda.is_available():
11
- model = AutoModelForCausalLM.from_pretrained(
12
- model_id,
13
- # device_map='auto',
14
- torch_dtype=torch.float16,
15
- trust_remote_code=True
16
- )
17
- model = model.quantize(4).cuda()
18
- model.generation_config = GenerationConfig.from_pretrained(model_id)
19
- else:
20
- model = None
21
- tokenizer = AutoTokenizer.from_pretrained(
22
  model_id,
23
- use_fast=False,
 
24
  trust_remote_code=True
 
 
 
 
 
 
 
 
 
25
  )
26
 
27
- def get_prompt(message: str, chat_history: list[tuple[str, str]],
28
- system_prompt: str) -> str:
29
- texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
30
- # The first user input is _not_ stripped
31
- do_strip = False
32
- for user_input, response in chat_history:
33
- user_input = user_input.strip() if do_strip else user_input
34
- do_strip = True
35
- texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
36
- message = message.strip() if do_strip else message
37
- texts.append(f'{message} [/INST]')
38
- return ''.join(texts)
39
-
40
-
41
- def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
42
- prompt = get_prompt(message, chat_history, system_prompt)
43
- input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
44
- return input_ids.shape[-1]
45
 
 
 
 
 
 
 
 
 
46
 
47
- def run(message: str,
48
- chat_history: list[tuple[str, str]],
49
- system_prompt: str,
50
- max_new_tokens: int = 2048,
51
- temperature: float = 0.3,
52
- top_p: float = 0.85,
53
- top_k: int = 5) -> Iterator[str]:
54
- prompt = get_prompt(message, chat_history, system_prompt)
55
- inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')
 
56
 
57
- streamer = TextIteratorStreamer(
58
- tokenizer,
59
- timeout=10.,
60
- skip_prompt=True,
61
- skip_special_tokens=True
62
- )
63
 
64
- generate_kwargs = dict(
65
- inputs,
66
- streamer=streamer,
67
- max_new_tokens=max_new_tokens,
68
- do_sample=True,
69
- top_p=top_p,
70
- top_k=top_k,
71
- temperature=temperature,
72
- num_beams=1,
73
- )
74
- t = Thread(target=model.generate, kwargs=generate_kwargs)
75
- t.start()
76
 
77
- outputs = []
78
- for text in streamer:
79
- outputs.append(text)
80
- yield ''.join(outputs)
 
 
 
 
2
  from typing import Iterator
3
 
4
  import torch
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
  from transformers.generation.utils import GenerationConfig
7
 
8
  model_id = 'baichuan-inc/Baichuan2-13B-Chat'
9
 
10
  if torch.cuda.is_available():
11
+ model = AutoModelForCausalLM.from_pretrained(
 
 
 
 
 
 
 
 
 
 
12
  model_id,
13
+ # device_map='auto',
14
+ torch_dtype=torch.float16,
15
  trust_remote_code=True
16
+ )
17
+ model = model.quantize(4).cuda()
18
+ model.generation_config = GenerationConfig.from_pretrained(model_id)
19
+ else:
20
+ model = None
21
+ tokenizer = AutoTokenizer.from_pretrained(
22
+ model_id,
23
+ use_fast=False,
24
+ trust_remote_code=True
25
  )
26
 
27
+ def get_prompt(
28
+ message: str,
29
+ chat_history: list[tuple[str, str]],
30
+ system_prompt: str
31
+ ) -> str:
32
+ texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
33
+ # The first user input is _not_ stripped
34
+ do_strip = False
35
+ for user_input, response in chat_history:
36
+ user_input = user_input.strip() if do_strip else user_input
37
+ do_strip = True
38
+ texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
39
+ message = message.strip() if do_strip else message
40
+ texts.append(f'{message} [/INST]')
41
+ return ''.join(texts)
 
 
 
42
 
43
+ def get_input_token_length(
44
+ message: str,
45
+ chat_history: list[tuple[str, str]],
46
+ system_prompt: str
47
+ ) -> int:
48
+ prompt = get_prompt(message, chat_history, system_prompt)
49
+ input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
50
+ return input_ids.shape[-1]
51
 
52
+ def run(
53
+ message: str,
54
+ chat_history: list[tuple[str, str]],
55
+ system_prompt: str,
56
+ max_new_tokens: int = 1024,
57
+ temperature: float = 1.0,
58
+ top_p: float = 0.95,
59
+ top_k: int = 5
60
+ ) -> Iterator[str]:
61
+ print(chat_history)
62
 
63
+ history = []
64
+ result=""
 
 
 
 
65
 
66
+ for i in chat_history:
67
+ history.append({"role": "user", "content": i[0]})
68
+ history.append({"role": "assistant", "content": i[1]})
69
+
70
+ print(history)
 
 
 
 
 
 
 
71
 
72
+ history.append({"role": "user", "content": message})
73
+
74
+ for response in model.chat(tokenizer, history, stream=True):
75
+ print(response)
76
+ if "content" in response["choices"][0]["delta"]:
77
+ result = result + response["choices"][0]["delta"]["content"]
78
+ yield result