hysts (HF staff) committed
Commit 09b3f75
1 Parent(s): 323df56

Migrate from yapf to black

Files changed (3)
  1. README.md +0 -2
  2. app.py +66 -67
  3. model.py +20 -26
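
All three files change only in formatting: black replaces yapf's single quotes and hanging-indent argument alignment with double quotes and its own line wrapping, without touching behavior. As a minimal sketch of how the conversion can be reproduced, the snippet below runs black programmatically on one line from this diff. It assumes black is installed (`pip install black`), uses black's Python API (which the project does not guarantee as stable), and guesses a non-default line length of 119, since several reformatted lines in this commit are longer than black's 88-character default.

import black

# One yapf-era line from app.py, held as a source string (the \\n stays a literal escape).
yapf_style_source = "DESCRIPTION += '\\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'\n"

# black.format_str is the programmatic entry point; line_length=119 is an assumption
# about this repo's configuration, not something stated in the commit.
formatted = black.format_str(yapf_style_source, mode=black.Mode(line_length=119))

print(formatted)  # DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

The same result for the whole repository comes from running black over the source tree (for example `black app.py model.py`), and `black --check` verifies formatting without rewriting files.
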
README.md CHANGED
@@ -17,5 +17,3 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-
 Llama v2 was introduced in [this paper](https://arxiv.org/abs/2307.09288).
 
 This Space demonstrates [Llama-2-7b-chat-hf](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/meta-llama/Llama-2-7b-chat-hf) from Meta. Please, check the original model card for details.
-
-
app.py CHANGED
@@ -33,26 +33,24 @@ this demo is governed by the original [license](https://huggingface.co/spaces/hu
 """
 
 if not torch.cuda.is_available():
-    DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'
+    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
 
 def clear_and_save_textbox(message: str) -> tuple[str, str]:
-    return '', message
+    return "", message
 
 
-def display_input(message: str,
-                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
-    history.append((message, ''))
+def display_input(message: str, history: list[tuple[str, str]]) -> list[tuple[str, str]]:
+    history.append((message, ""))
     return history
 
 
-def delete_prev_fn(
-        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
+def delete_prev_fn(history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
     try:
         message, _ = history.pop()
     except IndexError:
-        message = ''
-    return history, message or ''
+        message = ""
+    return history, message or ""
 
 
 def generate(
@@ -73,7 +71,7 @@ def generate(
         first_response = next(generator)
         yield history + [(message, first_response)]
     except StopIteration:
-        yield history + [(message, '')]
+        yield history + [(message, "")]
     for response in generator:
         yield history + [(message, response)]
 
@@ -82,67 +80,63 @@ def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
     generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
     for x in generator:
         pass
-    return '', x
+    return "", x
 
 
 def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
     input_token_length = get_input_token_length(message, chat_history, system_prompt)
     if input_token_length > MAX_INPUT_TOKEN_LENGTH:
-        raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
+        raise gr.Error(
+            f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again."
+        )
 
 
-with gr.Blocks(css='style.css') as demo:
+with gr.Blocks(css="style.css") as demo:
     gr.Markdown(DESCRIPTION)
-    gr.DuplicateButton(value='Duplicate Space for private use',
-                       elem_id='duplicate-button')
+    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
 
     with gr.Group():
-        chatbot = gr.Chatbot(label='Chatbot')
+        chatbot = gr.Chatbot(label="Chatbot")
         with gr.Row():
             textbox = gr.Textbox(
                 container=False,
                 show_label=False,
-                placeholder='Type a message...',
+                placeholder="Type a message...",
                 scale=10,
             )
-            submit_button = gr.Button('Submit',
-                                      variant='primary',
-                                      scale=1,
-                                      min_width=0)
+            submit_button = gr.Button("Submit", variant="primary", scale=1, min_width=0)
     with gr.Row():
-        retry_button = gr.Button('🔄 Retry', variant='secondary')
-        undo_button = gr.Button('↩️ Undo', variant='secondary')
-        clear_button = gr.Button('🗑️ Clear', variant='secondary')
+        retry_button = gr.Button("🔄 Retry", variant="secondary")
+        undo_button = gr.Button("↩️ Undo", variant="secondary")
+        clear_button = gr.Button("🗑️ Clear", variant="secondary")
 
     saved_input = gr.State()
 
-    with gr.Accordion(label='Advanced options', open=False):
-        system_prompt = gr.Textbox(label='System prompt',
-                                   value=DEFAULT_SYSTEM_PROMPT,
-                                   lines=6)
+    with gr.Accordion(label="Advanced options", open=False):
+        system_prompt = gr.Textbox(label="System prompt", value=DEFAULT_SYSTEM_PROMPT, lines=6)
         max_new_tokens = gr.Slider(
-            label='Max new tokens',
+            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature = gr.Slider(
-            label='Temperature',
+            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=1.0,
        )
        top_p = gr.Slider(
-            label='Top-p (nucleus sampling)',
+            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        )
        top_k = gr.Slider(
-            label='Top-k',
+            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
@@ -151,10 +145,10 @@ with gr.Blocks(css='style.css') as demo:
 
     gr.Examples(
         examples=[
-            'Hello there! How are you doing?',
-            'Can you explain briefly to me what is the Python programming language?',
-            'Explain the plot of Cinderella in a sentence.',
-            'How many hours does it take a man to eat a Helicopter?',
+            "Hello there! How are you doing?",
+            "Can you explain briefly to me what is the Python programming language?",
+            "Explain the plot of Cinderella in a sentence.",
+            "How many hours does it take a man to eat a Helicopter?",
             "Write a 100-word article on 'Benefits of Open-Source in AI research'",
         ],
         inputs=textbox,
@@ -197,36 +191,41 @@ with gr.Blocks(css='style.css') as demo:
         api_name=False,
     )
 
-    button_event_preprocess = submit_button.click(
-        fn=clear_and_save_textbox,
-        inputs=textbox,
-        outputs=[textbox, saved_input],
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=display_input,
-        inputs=[saved_input, chatbot],
-        outputs=chatbot,
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=check_input_token_length,
-        inputs=[saved_input, chatbot, system_prompt],
-        api_name=False,
-        queue=False,
-    ).success(
-        fn=generate,
-        inputs=[
-            saved_input,
-            chatbot,
-            system_prompt,
-            max_new_tokens,
-            temperature,
-            top_p,
-            top_k,
-        ],
-        outputs=chatbot,
-        api_name=False,
+    button_event_preprocess = (
+        submit_button.click(
+            fn=clear_and_save_textbox,
+            inputs=textbox,
+            outputs=[textbox, saved_input],
+            api_name=False,
+            queue=False,
+        )
+        .then(
+            fn=display_input,
+            inputs=[saved_input, chatbot],
+            outputs=chatbot,
+            api_name=False,
+            queue=False,
+        )
+        .then(
+            fn=check_input_token_length,
+            inputs=[saved_input, chatbot, system_prompt],
+            api_name=False,
+            queue=False,
+        )
+        .success(
+            fn=generate,
+            inputs=[
+                saved_input,
+                chatbot,
+                system_prompt,
+                max_new_tokens,
+                temperature,
+                top_p,
+                top_k,
+            ],
+            outputs=chatbot,
+            api_name=False,
+        )
     )
 
     retry_button.click(
@@ -271,7 +270,7 @@ with gr.Blocks(css='style.css') as demo:
     )
 
     clear_button.click(
-        fn=lambda: ([], ''),
+        fn=lambda: ([], ""),
         outputs=[chatbot, saved_input],
         queue=False,
         api_name=False,
model.py CHANGED
@@ -4,53 +4,47 @@ from typing import Iterator
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
-model_id = 'meta-llama/Llama-2-7b-chat-hf'
+model_id = "meta-llama/Llama-2-7b-chat-hf"
 
 if torch.cuda.is_available():
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        torch_dtype=torch.float16,
-        device_map='auto'
-    )
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
 else:
     model = None
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 
-def get_prompt(message: str, chat_history: list[tuple[str, str]],
-               system_prompt: str) -> str:
-    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
+def get_prompt(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> str:
+    texts = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
     # The first user input is _not_ stripped
     do_strip = False
     for user_input, response in chat_history:
         user_input = user_input.strip() if do_strip else user_input
         do_strip = True
-        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
+        texts.append(f"{user_input} [/INST] {response.strip()} </s><s>[INST] ")
     message = message.strip() if do_strip else message
-    texts.append(f'{message} [/INST]')
-    return ''.join(texts)
+    texts.append(f"{message} [/INST]")
+    return "".join(texts)
 
 
 def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
     prompt = get_prompt(message, chat_history, system_prompt)
-    input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
+    input_ids = tokenizer([prompt], return_tensors="np", add_special_tokens=False)["input_ids"]
     return input_ids.shape[-1]
 
 
-def run(message: str,
-        chat_history: list[tuple[str, str]],
-        system_prompt: str,
-        max_new_tokens: int = 1024,
-        temperature: float = 0.8,
-        top_p: float = 0.95,
-        top_k: int = 50) -> Iterator[str]:
+def run(
+    message: str,
+    chat_history: list[tuple[str, str]],
+    system_prompt: str,
+    max_new_tokens: int = 1024,
+    temperature: float = 0.8,
+    top_p: float = 0.95,
+    top_k: int = 50,
+) -> Iterator[str]:
     prompt = get_prompt(message, chat_history, system_prompt)
-    inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')
+    inputs = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).to("cuda")
 
-    streamer = TextIteratorStreamer(tokenizer,
-                                    timeout=10.,
-                                    skip_prompt=True,
-                                    skip_special_tokens=True)
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         inputs,
         streamer=streamer,
@@ -67,4 +61,4 @@ def run(message: str,
     outputs = []
     for text in streamer:
         outputs.append(text)
-    yield ''.join(outputs)
+    yield "".join(outputs)