hysts (HF staff) committed
Commit 98df5b4
Parent: 40eb6a4

Some fixes (#6)


- Fix (fb53399598611d22359a6ee808da275d4b733d41)

Files changed (2)
  1. app.py +24 -6
  2. model.py +8 -2
app.py CHANGED
@@ -3,13 +3,14 @@ from typing import Iterator
 import gradio as gr
 import torch
 
-from model import run
+from model import get_input_token_length, run
 
 DEFAULT_SYSTEM_PROMPT = """\
 You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
 """
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = 4000
 
 DESCRIPTION = """
 # Llama-2 13B Chat
@@ -34,6 +35,7 @@ this demo is governed by the original [license](https://huggingface.co/spaces/hu
 if not torch.cuda.is_available():
     DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'
 
+
 def clear_and_save_textbox(message: str) -> tuple[str, str]:
     return '', message
 
@@ -58,16 +60,15 @@ def generate(
     history_with_input: list[tuple[str, str]],
     system_prompt: str,
     max_new_tokens: int,
-    top_p: float,
     temperature: float,
+    top_p: float,
     top_k: int,
 ) -> Iterator[list[tuple[str, str]]]:
     if max_new_tokens > MAX_MAX_NEW_TOKENS:
         raise ValueError
 
     history = history_with_input[:-1]
-    generator = run(message, history, system_prompt, max_new_tokens,
-                    temperature, top_p, top_k)
+    generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
     try:
         first_response = next(generator)
         yield history + [(message, first_response)]
@@ -78,13 +79,18 @@ def generate(
 
 
 def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
-    generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 0.95, 1,
-                         1000)
+    generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
     for x in generator:
         pass
     return '', x
 
 
+def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
+    input_token_length = get_input_token_length(message, chat_history, system_prompt)
+    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
+        raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
+
+
 with gr.Blocks(css='style.css') as demo:
     gr.Markdown(DESCRIPTION)
     gr.DuplicateButton(value='Duplicate Space for private use',
@@ -156,6 +162,7 @@ with gr.Blocks(css='style.css') as demo:
         fn=process_example,
         cache_examples=True,
     )
+
     gr.Markdown(LICENSE)
 
     textbox.submit(
@@ -171,6 +178,11 @@ with gr.Blocks(css='style.css') as demo:
         api_name=False,
         queue=False,
     ).then(
+        fn=check_input_token_length,
+        inputs=[saved_input, chatbot, system_prompt],
+        api_name=False,
+        queue=False,
+    ).success(
         fn=generate,
         inputs=[
             saved_input,
@@ -198,6 +210,11 @@ with gr.Blocks(css='style.css') as demo:
         api_name=False,
         queue=False,
     ).then(
+        fn=check_input_token_length,
+        inputs=[saved_input, chatbot, system_prompt],
+        api_name=False,
+        queue=False,
+    ).success(
         fn=generate,
         inputs=[
             saved_input,
@@ -229,6 +246,7 @@ with gr.Blocks(css='style.css') as demo:
         inputs=[
             saved_input,
             chatbot,
+            system_prompt,
             max_new_tokens,
             temperature,
             top_p,
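
The wiring above chains check_input_token_length onto each event with .then(...) and attaches generate with .success(...), so generation is skipped whenever the token-length check raises gr.Error. A minimal sketch of that gating pattern, separate from this Space (the names validate, echo, and MAX_LEN are illustrative, not from the diff):

import gradio as gr

MAX_LEN = 20  # stand-in for MAX_INPUT_TOKEN_LENGTH


def validate(message: str) -> None:
    # Raising gr.Error marks this step as failed and shows the message to the user.
    if len(message) > MAX_LEN:
        raise gr.Error(f'Input too long ({len(message)} > {MAX_LEN}).')


def echo(message: str) -> str:
    return f'You said: {message}'


with gr.Blocks() as demo:
    textbox = gr.Textbox()
    output = gr.Textbox()
    textbox.submit(
        fn=validate,
        inputs=textbox,
        queue=False,
    ).success(  # runs only if validate() did not raise
        fn=echo,
        inputs=textbox,
        outputs=output,
    )

demo.launch()
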
model.py CHANGED
@@ -25,11 +25,17 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]],
                system_prompt: str) -> str:
     texts = [f'[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
     for user_input, response in chat_history:
-        texts.append(f'{user_input} [/INST] {response} [INST] ')
+        texts.append(f'{user_input.strip()} [/INST] {response.strip()} </s><s> [INST] ')
     texts.append(f'{message.strip()} [/INST]')
     return ''.join(texts)
 
 
+def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
+    prompt = get_prompt(message, chat_history, system_prompt)
+    input_ids = tokenizer([prompt], return_tensors='np')['input_ids']
+    return input_ids.shape[-1]
+
+
 def run(message: str,
         chat_history: list[tuple[str, str]],
         system_prompt: str,
@@ -38,7 +44,7 @@ def run(message: str,
         top_p: float = 0.95,
         top_k: int = 50) -> Iterator[str]:
     prompt = get_prompt(message, chat_history, system_prompt)
-    inputs = tokenizer([prompt], return_tensors='pt').to("cuda")
+    inputs = tokenizer([prompt], return_tensors='pt').to('cuda')
 
     streamer = TextIteratorStreamer(tokenizer,
                                     timeout=10.,
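
For reference, the get_prompt change adds .strip() and inserts the </s><s> separator between completed turns, following the Llama-2 chat convention of closing each exchange with an end-of-sequence token before opening the next [INST] block. A standalone illustration of what the fixed function produces for one prior exchange (the sample conversation is made up; the function body is copied from the new model.py above):

def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    texts = [f'[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    for user_input, response in chat_history:
        texts.append(f'{user_input.strip()} [/INST] {response.strip()} </s><s> [INST] ')
    texts.append(f'{message.strip()} [/INST]')
    return ''.join(texts)


history = [('Hi', 'Hello! How can I help you today?')]
print(get_prompt('Tell me a joke.', history, 'You are a helpful assistant.'))
# Output:
# [INST] <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#
# Hi [/INST] Hello! How can I help you today? </s><s> [INST] Tell me a joke. [/INST]

The companion get_input_token_length helper tokenizes the same prompt with return_tensors='np', so the pre-generation length check counts tokens without moving anything onto the GPU.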