freddyaboulton (HF staff) committed
Commit: e34e07c
1 Parent(s): 41f8286
Files changed (3)
  1. app.py +49 -245
  2. model.py +14 -6
  3. style.css +0 -16
app.py CHANGED
@@ -1,7 +1,6 @@
-from typing import Iterator
-
 import gradio as gr
 import torch
+import os

 from model import get_input_token_length, run

@@ -12,17 +11,6 @@ MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = 4000

-DESCRIPTION = """
-# Llama-2 13B Chat
-
-This Space demonstrates model [Llama-2-13b-chat](https://huggingface.co/meta-llama/Llama-2-13b-chat) by Meta, a Llama 2 model with 13B parameters fine-tuned for chat instructions. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
-
-🔎 For more details about the Llama 2 family of models and how to use them with `transformers`, take a look [at our blog post](https://huggingface.co/blog/llama2).
-
-🔨 Looking for an even more powerful model? Check out the large [**70B** model demo](https://huggingface.co/spaces/ysharma/Explore_llamav2_with_TGI).
-🐇 For a smaller model that you can run on many GPUs, check our [7B model demo](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat).
-
-"""

 LICENSE = """
 <p/>
@@ -32,249 +20,65 @@ As a derivate work of [Llama-2-13b-chat](https://huggingface.co/meta-llama/Llama
 this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/USE_POLICY.md).
 """

-if not torch.cuda.is_available():
-    DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'
-
-
-def clear_and_save_textbox(message: str) -> tuple[str, str]:
-    return '', message
-
-
-def display_input(message: str,
-                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
-    history.append((message, ''))
-    return history
-
-
-def delete_prev_fn(
-        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
-    try:
-        message, _ = history.pop()
-    except IndexError:
-        message = ''
-    return history, message or ''
+is_spaces = True if "SPACE_ID" in os.environ else False
+if is_spaces :
+    is_shared_ui = True if "gradio-discord-bots/llama-2-13b-chat-transformers" in os.environ['SPACE_ID'] else False
+else:
+    is_shared_ui = False
+is_gpu_associated = torch.cuda.is_available()


 def generate(
     message: str,
     history_with_input: list[tuple[str, str]],
-    system_prompt: str,
-    max_new_tokens: int,
-    temperature: float,
-    top_p: float,
-    top_k: int,
-) -> Iterator[list[tuple[str, str]]]:
+    system_prompt=DEFAULT_SYSTEM_PROMPT,
+    max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
+    temperature=1.0,
+    top_p=0.95,
+    top_k=50,
+) -> tuple[str, list[tuple[str, str]]]:
+    if is_shared_ui:
+        raise ValueError("Cannot use demo running in shared_ui. Must duplicate your own space.")
     if max_new_tokens > MAX_MAX_NEW_TOKENS:
         raise ValueError
-
+
     history = history_with_input[:-1]
-    generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
-    try:
-        first_response = next(generator)
-        yield history + [(message, first_response)]
-    except StopIteration:
-        yield history + [(message, '')]
-    for response in generator:
-        yield history + [(message, response)]
-
-
-def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
-    generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
-    for x in generator:
-        pass
-    return '', x
-
-
-def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
-    input_token_length = get_input_token_length(message, chat_history, system_prompt)
+    input_token_length = get_input_token_length(message, history, system_prompt)
     if input_token_length > MAX_INPUT_TOKEN_LENGTH:
-        raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
-
-
-with gr.Blocks(css='style.css') as demo:
-    gr.Markdown(DESCRIPTION)
-    gr.DuplicateButton(value='Duplicate Space for private use',
-                       elem_id='duplicate-button')
-
-    with gr.Group():
-        chatbot = gr.Chatbot(label='Chatbot')
-        with gr.Row():
-            textbox = gr.Textbox(
-                container=False,
-                show_label=False,
-                placeholder='Type a message...',
-                scale=10,
-            )
-            submit_button = gr.Button('Submit',
-                                      variant='primary',
-                                      scale=1,
-                                      min_width=0)
-    with gr.Row():
-        retry_button = gr.Button('🔄 Retry', variant='secondary')
-        undo_button = gr.Button('↩️ Undo', variant='secondary')
-        clear_button = gr.Button('🗑️ Clear', variant='secondary')
-
-    saved_input = gr.State()
-
-    with gr.Accordion(label='Advanced options', open=False):
-        system_prompt = gr.Textbox(label='System prompt',
-                                   value=DEFAULT_SYSTEM_PROMPT,
-                                   lines=6)
-        max_new_tokens = gr.Slider(
-            label='Max new tokens',
-            minimum=1,
-            maximum=MAX_MAX_NEW_TOKENS,
-            step=1,
-            value=DEFAULT_MAX_NEW_TOKENS,
-        )
-        temperature = gr.Slider(
-            label='Temperature',
-            minimum=0.1,
-            maximum=4.0,
-            step=0.1,
-            value=1.0,
-        )
-        top_p = gr.Slider(
-            label='Top-p (nucleus sampling)',
-            minimum=0.05,
-            maximum=1.0,
-            step=0.05,
-            value=0.95,
-        )
-        top_k = gr.Slider(
-            label='Top-k',
-            minimum=1,
-            maximum=1000,
-            step=1,
-            value=50,
-        )
-
-    gr.Examples(
-        examples=[
-            'Hello there! How are you doing?',
-            'Can you explain briefly to me what is the Python programming language?',
-            'Explain the plot of Cinderella in a sentence.',
-            'How many hours does it take a man to eat a Helicopter?',
-            "Write a 100-word article on 'Benefits of Open-Source in AI research'",
-        ],
-        inputs=textbox,
-        outputs=[textbox, chatbot],
-        fn=process_example,
-        cache_examples=True,
+        response = f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Please create a new thread.'
+    else:
+        response = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
+    return response, history + [(message, response)]
+
+
+with gr.Blocks() as demo:
+
+    gr.Markdown(
+        """
+        # Llama-2-13b-chat-hf Discord Bot Powered by Gradio and Hugging Face Transformers
+
+        ### First install the `gradio_client`
+
+        ```bash
+        pip install gradio_client
+        ```
+
+        ### Then deploy to discord in one line! ⚡️
+
+        ```python
+        secrets = {"HUGGING_FACE_HUB_TOKEN": "<your-key-here>",}
+        client = grc.Client.duplicate("gradio-discord-bots/llama-2-13b-chat-transformers", secrets=secrets, hardware="a10g-small")
+        client.deploy_discord(api_names=["chat"])
+        ```
+        """
     )

     gr.Markdown(LICENSE)
-
-    textbox.submit(
-        fn=clear_and_save_textbox,
-        inputs=textbox,
-        outputs=[textbox, saved_input],
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=display_input,
-        inputs=[saved_input, chatbot],
-        outputs=chatbot,
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=check_input_token_length,
-        inputs=[saved_input, chatbot, system_prompt],
-        api_name=False,
-        queue=False,
-    ).success(
-        fn=generate,
-        inputs=[
-            saved_input,
-            chatbot,
-            system_prompt,
-            max_new_tokens,
-            temperature,
-            top_p,
-            top_k,
-        ],
-        outputs=chatbot,
-        api_name=False,
-    )
-
-    button_event_preprocess = submit_button.click(
-        fn=clear_and_save_textbox,
-        inputs=textbox,
-        outputs=[textbox, saved_input],
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=display_input,
-        inputs=[saved_input, chatbot],
-        outputs=chatbot,
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=check_input_token_length,
-        inputs=[saved_input, chatbot, system_prompt],
-        api_name=False,
-        queue=False,
-    ).success(
-        fn=generate,
-        inputs=[
-            saved_input,
-            chatbot,
-            system_prompt,
-            max_new_tokens,
-            temperature,
-            top_p,
-            top_k,
-        ],
-        outputs=chatbot,
-        api_name=False,
-    )
-
-    retry_button.click(
-        fn=delete_prev_fn,
-        inputs=chatbot,
-        outputs=[chatbot, saved_input],
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=display_input,
-        inputs=[saved_input, chatbot],
-        outputs=chatbot,
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=generate,
-        inputs=[
-            saved_input,
-            chatbot,
-            system_prompt,
-            max_new_tokens,
-            temperature,
-            top_p,
-            top_k,
-        ],
-        outputs=chatbot,
-        api_name=False,
-    )
-
-    undo_button.click(
-        fn=delete_prev_fn,
-        inputs=chatbot,
-        outputs=[chatbot, saved_input],
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=lambda x: x,
-        inputs=[saved_input],
-        outputs=textbox,
-        api_name=False,
-        queue=False,
-    )
-
-    clear_button.click(
-        fn=lambda: ([], ''),
-        outputs=[chatbot, saved_input],
-        queue=False,
-        api_name=False,
-    )
+    with gr.Row(visible=False):
+        state = gr.State([])
+        msg = gr.Textbox()
+        output = gr.Textbox()
+        btn = gr.Button()
+        btn.click(generate, [msg, state], [output, state], api_name="chat")

 demo.queue(max_size=20).launch()
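The rewritten app.py exposes generation only through the hidden `chat` endpoint wired up at the bottom of the file. Below is a minimal sketch of calling that endpoint with `gradio_client`, assuming the Space has already been duplicated (the shared Space rejects direct calls via the `is_shared_ui` guard). The duplicate name is a placeholder, and exact handling of the `gr.State` input can vary with the client version, so treat this as an illustration rather than the canonical call.

```python
# Minimal sketch: querying the hidden `chat` API of a duplicated Space.
# "your-username/llama-2-13b-chat-transformers" is a placeholder; the shared
# Space itself raises ValueError for direct calls (see is_shared_ui above).
from gradio_client import Client

client = Client("your-username/llama-2-13b-chat-transformers")
reply = client.predict(
    "Explain the plot of Cinderella in a sentence.",  # message
    api_name="/chat",
)
print(reply)
```

Note that the in-app instructions call `grc.Client.duplicate(...)`, which assumes `import gradio_client as grc` on the caller's side.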
model.py CHANGED
@@ -1,12 +1,18 @@
 from threading import Thread
-from typing import Iterator
-
+import os
 import torch
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

 model_id = 'meta-llama/Llama-2-13b-chat-hf'

-if torch.cuda.is_available():
+is_spaces = True if "SPACE_ID" in os.environ else False
+if is_spaces :
+    is_shared_ui = True if "gradio-discord-bots/llama-2-13b-chat-transformers" in os.environ['SPACE_ID'] else False
+else:
+    is_shared_ui = False
+is_gpu_associated = torch.cuda.is_available()
+
+if torch.cuda.is_available() and not is_shared_ui:
     config = AutoConfig.from_pretrained(model_id)
     config.pretraining_tp = 1
     model = AutoModelForCausalLM.from_pretrained(
@@ -16,9 +22,10 @@ if torch.cuda.is_available():
         load_in_4bit=True,
         device_map='auto'
     )
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
 else:
     model = None
-tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer = None


 def get_prompt(message: str, chat_history: list[tuple[str, str]],
@@ -47,7 +54,7 @@ def run(message: str,
         max_new_tokens: int = 1024,
         temperature: float = 0.8,
         top_p: float = 0.95,
-        top_k: int = 50) -> Iterator[str]:
+        top_k: int = 50) -> str:
     prompt = get_prompt(message, chat_history, system_prompt)
     inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')

@@ -71,4 +78,5 @@ def run(message: str,
     outputs = []
     for text in streamer:
         outputs.append(text)
-        yield ''.join(outputs)
+
+    return "".join(outputs)
style.css DELETED
@@ -1,16 +0,0 @@
-h1 {
-  text-align: center;
-}
-
-#duplicate-button {
-  margin: auto;
-  color: white;
-  background: #1565c0;
-  border-radius: 100vh;
-}
-
-#component-0 {
-  max-width: 900px;
-  margin: auto;
-  padding-top: 1.5rem;
-}
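style.css is deleted because the Blocks app no longer renders a visible chat UI, and the `css='style.css'` argument is gone from `gr.Blocks()` as well. If a duplicated Space wanted the old centered-heading styling back, the rules could be passed inline; a sketch for illustration only, not part of this commit:

```python
# Illustration only (not part of this commit): re-applying the dropped CSS
# inline instead of shipping a separate style.css file.
import gradio as gr

CSS = """
h1 {
  text-align: center;
}
"""

with gr.Blocks(css=CSS) as demo:
    gr.Markdown("# Llama-2-13b-chat-hf Discord Bot")

demo.launch()
```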