philschmid HF staff commited on
Commit
fd0f949
1 Parent(s): aecd012
Files changed (4) hide show
  1. README.md +5 -12
  2. app.py +162 -267
  3. model.py +0 -75
  4. requirements.txt +3 -8
README.md CHANGED
@@ -1,19 +1,12 @@
1
  ---
2
- title: Code Llama 13B Chat
3
  emoji: 🦙
4
- colorFrom: indigo
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 3.37.0
8
  app_file: app.py
9
- pinned: false
10
  license: other
11
- suggested_hardware: a10g-small
12
  duplicated_from: huggingface-projects/llama-2-13b-chat
13
  ---
14
-
15
- # LLAMA v2 Models
16
-
17
- Llama v2 was introduced in [this paper](https://arxiv.org/abs/2307.09288).
18
-
19
- This Space demonstrates [Llama-2-13b-chat-hf](meta-llama/Llama-2-13b-chat-hf) from Meta. Please, check the original model card for details.
 
1
  ---
2
+ title: Llama 7B Chat on Inf2
3
  emoji: 🦙
4
+ colorFrom: yellow
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 3.46.1
8
  app_file: app.py
9
+ pinned: true
10
  license: other
 
11
  duplicated_from: huggingface-projects/llama-2-13b-chat
12
  ---
 
 
 
 
 
 
app.py CHANGED
@@ -1,279 +1,174 @@
1
  from typing import Iterator
2
-
3
  import gradio as gr
4
- import torch
5
-
6
- from model import get_input_token_length, run
7
-
8
- DEFAULT_SYSTEM_PROMPT = """\
9
- You are a helpful, respectful and honest assistant with a deep knowledge of code and software design. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
10
- """
11
- MAX_MAX_NEW_TOKENS = 4096
12
- DEFAULT_MAX_NEW_TOKENS = 1024
13
- MAX_INPUT_TOKEN_LENGTH = 4000
14
-
15
- DESCRIPTION = """
16
- # Code Llama 13B Chat
17
-
18
- This Space demonstrates model [CodeLlama-13b-Instruct](https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf) by Meta, a Code Llama model with 13B parameters fine-tuned for chat instructions and specialized on code tasks. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
19
-
20
- 🔎 For more details about the Code Llama family of models and how to use them with `transformers`, take a look [at our blog post](https://huggingface.co/blog/codellama) or [the paper](https://huggingface.co/papers/2308.12950).
21
-
22
- 🏃🏻 Check out our [Playground](https://huggingface.co/spaces/codellama/codellama-playground) for a super-fast code completion demo that leverages a streaming [inference endpoint](https://huggingface.co/inference-endpoints).
23
-
24
- 💪 For a chat demo of the largest Code Llama model (34B parameters), you can [select Code Llama in Hugging Chat!](https://huggingface.co/chat/)
25
-
26
- """
27
-
28
- LICENSE = """
29
- <p/>
30
-
31
- ---
32
- As a derivate work of Code Llama by Meta,
33
- this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/codellama-2-13b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/codellama-2-13b-chat/blob/main/USE_POLICY.md).
34
- """
35
-
36
- if not torch.cuda.is_available():
37
- DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'
38
-
39
-
40
- def clear_and_save_textbox(message: str) -> tuple[str, str]:
41
- return '', message
42
-
43
-
44
- def display_input(message: str,
45
- history: list[tuple[str, str]]) -> list[tuple[str, str]]:
46
- history.append((message, ''))
47
- return history
48
-
49
-
50
- def delete_prev_fn(
51
- history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
52
- try:
53
- message, _ = history.pop()
54
- except IndexError:
55
- message = ''
56
- return history, message or ''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
 
59
  def generate(
60
- message: str,
61
- history_with_input: list[tuple[str, str]],
62
- system_prompt: str,
63
- max_new_tokens: int,
64
- temperature: float,
65
- top_p: float,
66
- top_k: int,
67
- ) -> Iterator[list[tuple[str, str]]]:
68
- if max_new_tokens > MAX_MAX_NEW_TOKENS:
69
- raise ValueError
70
-
71
- history = history_with_input[:-1]
72
- generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
73
- try:
74
- first_response = next(generator)
75
- yield history + [(message, first_response)]
76
- except StopIteration:
77
- yield history + [(message, '')]
78
- for response in generator:
79
- yield history + [(message, response)]
80
-
81
-
82
- def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
83
- generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
84
- for x in generator:
85
- pass
86
- return '', x
87
-
88
-
89
- def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
90
- input_token_length = get_input_token_length(message, chat_history, system_prompt)
 
 
 
 
 
 
91
  if input_token_length > MAX_INPUT_TOKEN_LENGTH:
92
- raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
93
-
94
-
95
- with gr.Blocks(css='style.css') as demo:
96
- gr.Markdown(DESCRIPTION)
97
- gr.DuplicateButton(value='Duplicate Space for private use',
98
- elem_id='duplicate-button')
99
-
100
- with gr.Group():
101
- chatbot = gr.Chatbot(label='Chatbot')
102
- with gr.Row():
103
- textbox = gr.Textbox(
104
- container=False,
105
- show_label=False,
106
- placeholder='Type a message...',
107
- scale=10,
108
- )
109
- submit_button = gr.Button('Submit',
110
- variant='primary',
111
- scale=1,
112
- min_width=0)
113
- with gr.Row():
114
- retry_button = gr.Button('🔄 Retry', variant='secondary')
115
- undo_button = gr.Button('↩️ Undo', variant='secondary')
116
- clear_button = gr.Button('🗑️ Clear', variant='secondary')
117
-
118
- saved_input = gr.State()
119
-
120
- with gr.Accordion(label='Advanced options', open=False):
121
- system_prompt = gr.Textbox(label='System prompt',
122
- value=DEFAULT_SYSTEM_PROMPT,
123
- lines=6)
124
- max_new_tokens = gr.Slider(
125
- label='Max new tokens',
126
- minimum=1,
127
- maximum=MAX_MAX_NEW_TOKENS,
128
- step=1,
129
- value=DEFAULT_MAX_NEW_TOKENS,
130
- )
131
- temperature = gr.Slider(
132
- label='Temperature',
133
- minimum=0.1,
134
- maximum=4.0,
135
- step=0.1,
136
- value=0.1,
137
- )
138
- top_p = gr.Slider(
139
- label='Top-p (nucleus sampling)',
140
- minimum=0.05,
141
- maximum=1.0,
142
- step=0.05,
143
- value=0.9,
144
- )
145
- top_k = gr.Slider(
146
- label='Top-k',
147
- minimum=1,
148
- maximum=1000,
149
- step=1,
150
- value=10,
151
  )
152
 
153
- gr.Examples(
154
- examples=[
155
- 'What is the Fibonacci sequence?',
156
- 'Can you explain briefly what Python is good for?',
157
- 'How can I display a grid of images in SwiftUI?',
158
- ],
159
- inputs=textbox,
160
- outputs=[textbox, chatbot],
161
- fn=process_example,
162
- cache_examples=True,
163
- )
164
-
165
- gr.Markdown(LICENSE)
166
-
167
- textbox.submit(
168
- fn=clear_and_save_textbox,
169
- inputs=textbox,
170
- outputs=[textbox, saved_input],
171
- api_name=False,
172
- queue=False,
173
- ).then(
174
- fn=display_input,
175
- inputs=[saved_input, chatbot],
176
- outputs=chatbot,
177
- api_name=False,
178
- queue=False,
179
- ).then(
180
- fn=check_input_token_length,
181
- inputs=[saved_input, chatbot, system_prompt],
182
- api_name=False,
183
- queue=False,
184
- ).success(
185
- fn=generate,
186
- inputs=[
187
- saved_input,
188
- chatbot,
189
- system_prompt,
190
- max_new_tokens,
191
- temperature,
192
- top_p,
193
- top_k,
194
- ],
195
- outputs=chatbot,
196
- api_name=False,
197
- )
198
-
199
- button_event_preprocess = submit_button.click(
200
- fn=clear_and_save_textbox,
201
- inputs=textbox,
202
- outputs=[textbox, saved_input],
203
- api_name=False,
204
- queue=False,
205
- ).then(
206
- fn=display_input,
207
- inputs=[saved_input, chatbot],
208
- outputs=chatbot,
209
- api_name=False,
210
- queue=False,
211
- ).then(
212
- fn=check_input_token_length,
213
- inputs=[saved_input, chatbot, system_prompt],
214
- api_name=False,
215
- queue=False,
216
- ).success(
217
- fn=generate,
218
- inputs=[
219
- saved_input,
220
- chatbot,
221
- system_prompt,
222
- max_new_tokens,
223
- temperature,
224
- top_p,
225
- top_k,
226
- ],
227
- outputs=chatbot,
228
- api_name=False,
229
- )
230
 
231
- retry_button.click(
232
- fn=delete_prev_fn,
233
- inputs=chatbot,
234
- outputs=[chatbot, saved_input],
235
- api_name=False,
236
- queue=False,
237
- ).then(
238
- fn=display_input,
239
- inputs=[saved_input, chatbot],
240
- outputs=chatbot,
241
- api_name=False,
242
- queue=False,
243
- ).then(
244
- fn=generate,
245
- inputs=[
246
- saved_input,
247
- chatbot,
248
- system_prompt,
249
- max_new_tokens,
250
- temperature,
251
- top_p,
252
- top_k,
253
- ],
254
- outputs=chatbot,
255
- api_name=False,
256
- )
257
 
258
- undo_button.click(
259
- fn=delete_prev_fn,
260
- inputs=chatbot,
261
- outputs=[chatbot, saved_input],
262
- api_name=False,
263
- queue=False,
264
- ).then(
265
- fn=lambda x: x,
266
- inputs=[saved_input],
267
- outputs=textbox,
268
- api_name=False,
269
- queue=False,
270
- )
271
 
272
- clear_button.click(
273
- fn=lambda: ([], ''),
274
- outputs=[chatbot, saved_input],
275
- queue=False,
276
- api_name=False,
277
- )
278
 
279
- demo.queue(max_size=20).launch()
 
1
  from typing import Iterator
 
2
  import gradio as gr
3
+ import boto3
4
+ import io
5
+ import json
6
+ import os
7
+ from transformers import AutoTokenizer
8
+
9
+ aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID", None)
10
+ aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY", None)
11
+ aws_session_token = os.environ.get("AWS_SESSION_TOKEN", None)
12
+ region = os.environ.get("AWS_REGION", None)
13
+ endpoint_name = os.environ.get("SAGEMAKER_ENDPOINT_NAME", None)
14
+
15
+ tokenizer = AutoTokenizer.from_pretrained(
16
+ "aws-neuron/Llama-2-7b-chat-hf-seqlen-2048-bs-4"
17
+ )
18
+
19
+ # if (
20
+ # aws_access_key_id is None
21
+ # or aws_secret_access_key is None
22
+ # or region is None
23
+ # or endpoint_name is None
24
+ # ):
25
+ # raise Exception(
26
+ # "Please set AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION and SAGEMAKER_ENDPOINT_NAME environment variables"
27
+ # )
28
+
29
+ # boto_session = boto3.Session(
30
+ # aws_access_key_id=aws_access_key_id,
31
+ # aws_secret_access_key=aws_secret_access_key,
32
+ # aws_session_token=aws_session_token,
33
+ # region_name=region,
34
+ # )
35
+
36
+ # smr = boto_session.client("sagemaker-runtime")
37
+
38
+
39
+ DEFAULT_SYSTEM_PROMPT = (
40
+ "You are an helpful Assistant, called Llama. Knowing everyting about AWS."
41
+ )
42
+ MAX_INPUT_TOKEN_LENGTH = 1024
43
+
44
+ # hyperparameters for llm
45
+ parameters = {
46
+ "do_sample": True,
47
+ "top_p": 0.9,
48
+ "temperature": 0.8,
49
+ "max_new_tokens": 1024,
50
+ "repetition_penalty": 1.03,
51
+ "stop": ["<\s>"],
52
+ }
53
+
54
+
55
+ # Helper for reading lines from a stream
56
+ class LineIterator:
57
+ def __init__(self, stream):
58
+ self.byte_iterator = iter(stream)
59
+ self.buffer = io.BytesIO()
60
+ self.read_pos = 0
61
+
62
+ def __iter__(self):
63
+ return self
64
+
65
+ def __next__(self):
66
+ while True:
67
+ self.buffer.seek(self.read_pos)
68
+ line = self.buffer.readline()
69
+ if line and line[-1] == ord("\n"):
70
+ self.read_pos += len(line)
71
+ return line[:-1]
72
+ try:
73
+ chunk = next(self.byte_iterator)
74
+ except StopIteration:
75
+ if self.read_pos < self.buffer.getbuffer().nbytes:
76
+ continue
77
+ raise
78
+ if "PayloadPart" not in chunk:
79
+ print("Unknown event type:" + chunk)
80
+ continue
81
+ self.buffer.seek(0, io.SEEK_END)
82
+ self.buffer.write(chunk["PayloadPart"]["Bytes"])
83
+
84
+
85
+ def format_prompt(message, history):
86
+ messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]
87
+ for interaction in history:
88
+ messages.append({"role": "user", "content": interaction[0]})
89
+ messages.append({"role": "assistant", "content": interaction[1]})
90
+ messages.append({"role": "user", "content": message})
91
+ prompt = tokenizer.apply_chat_template(
92
+ messages, tokenize=False, add_generation_prompt=True
93
+ )
94
+ return prompt
95
 
96
 
97
  def generate(
98
+ prompt,
99
+ history,
100
+ ):
101
+ formatted_prompt = format_prompt(prompt, history)
102
+ check_input_token_length(formatted_prompt)
103
+
104
+ request = {"inputs": formatted_prompt, "parameters": parameters, "stream": True}
105
+ resp = {"Body": ""}
106
+ # resp = {"Body": open("test.json", "rb")}
107
+ # resp = smr.invoke_endpoint_with_response_stream(
108
+ # EndpointName=endpoint_name,
109
+ # Body=json.dumps(request),
110
+ # ContentType="application/json",
111
+ # )
112
+
113
+ output = "offline"
114
+ # for c in LineIterator(resp["Body"]):
115
+ # c = c.decode("utf-8")
116
+ # if c.startswith("data:"):
117
+ # chunk = json.loads(c.lstrip("data:").rstrip("/n"))
118
+ # if chunk["token"]["special"]:
119
+ # continue
120
+ # if chunk["token"]["text"] in request["parameters"]["stop"]:
121
+ # break
122
+ # output += chunk["token"]["text"]
123
+ # for stop_str in request["parameters"]["stop"]:
124
+ # if output.endswith(stop_str):
125
+ # output = output[: -len(stop_str)]
126
+ # output = output.rstrip()
127
+ # yield output
128
+
129
+ # yield output
130
+ return output
131
+
132
+
133
+ def check_input_token_length(prompt: str) -> None:
134
+ input_token_length = len(tokenizer(prompt)["input_ids"])
135
  if input_token_length > MAX_INPUT_TOKEN_LENGTH:
136
+ raise gr.Error(
137
+ f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  )
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
+ theme = gr.themes.Monochrome(
142
+ primary_hue="indigo",
143
+ secondary_hue="blue",
144
+ neutral_hue="slate",
145
+ radius_size=gr.themes.sizes.radius_sm,
146
+ font=[
147
+ gr.themes.GoogleFont("Open Sans"),
148
+ "ui-sans-serif",
149
+ "system-ui",
150
+ "sans-serif",
151
+ ],
152
+ )
153
+ DESCRIPTION = """
154
+ <div style="text-align: center; max-width: 650px; margin: 0 auto; display:grid; gap:25px;">
155
+ <img class="logo" src="https://huggingface.co/datasets/philschmid/assets/resolve/main/aws-neuron_hf.png" alt="Hugging Face Neuron Logo"
156
+ style="margin: auto; max-width: 14rem;">
157
+ <h1 style="font-weight: 900; margin-bottom: 7px;margin-top:5px">
158
+ Llama 2 7B Chat on AWS INF2 ⚡
159
+ </h1>
160
+ <p style="margin-bottom: 10px; font-size: 94%; line-height: 23px;">
161
+ Llama 2 is a collection of pretrained and fine-tuned generative text models ranging in scale from 7 billion to 70 billion parameters. This is the repository for the 7B fine-tuned model, optimized for dialogue use cases and converted for the Hugging Face Transformers format. Links to other models can be found in the index at the bottom. This demo is running on <a style="text-decoration: underline;" href="https://aws.amazon.com/ec2/instance-types/inf2/?nc1=h_ls">AWS Inferentia2</a>, <a href="https://www.philschmid.de/inferentia2-llama-7b" target="_blank">How does it work?</a>
162
+ </p>
163
+ </div>
164
+ """
 
 
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
+ demo = gr.ChatInterface(
168
+ generate,
169
+ description=DESCRIPTION,
170
+ chatbot=gr.Chatbot(layout="panel"),
171
+ theme=theme,
172
+ )
173
 
174
+ demo.queue(concurrency_count=5).launch(share=False)
model.py DELETED
@@ -1,75 +0,0 @@
1
- from threading import Thread
2
- from typing import Iterator
3
-
4
- import torch
5
- from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
-
7
- model_id = 'codellama/CodeLlama-13b-Instruct-hf'
8
-
9
- if torch.cuda.is_available():
10
- config = AutoConfig.from_pretrained(model_id)
11
- config.pretraining_tp = 1
12
- model = AutoModelForCausalLM.from_pretrained(
13
- model_id,
14
- config=config,
15
- torch_dtype=torch.float16,
16
- load_in_4bit=True,
17
- device_map='auto',
18
- use_safetensors=False,
19
- )
20
- else:
21
- model = None
22
- tokenizer = AutoTokenizer.from_pretrained(model_id)
23
-
24
-
25
- def get_prompt(message: str, chat_history: list[tuple[str, str]],
26
- system_prompt: str) -> str:
27
- texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
28
- # The first user input is _not_ stripped
29
- do_strip = False
30
- for user_input, response in chat_history:
31
- user_input = user_input.strip() if do_strip else user_input
32
- do_strip = True
33
- texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
34
- message = message.strip() if do_strip else message
35
- texts.append(f'{message} [/INST]')
36
- return ''.join(texts)
37
-
38
-
39
- def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
40
- prompt = get_prompt(message, chat_history, system_prompt)
41
- input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
42
- return input_ids.shape[-1]
43
-
44
-
45
- def run(message: str,
46
- chat_history: list[tuple[str, str]],
47
- system_prompt: str,
48
- max_new_tokens: int = 1024,
49
- temperature: float = 0.1,
50
- top_p: float = 0.9,
51
- top_k: int = 50) -> Iterator[str]:
52
- prompt = get_prompt(message, chat_history, system_prompt)
53
- inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')
54
-
55
- streamer = TextIteratorStreamer(tokenizer,
56
- timeout=10.,
57
- skip_prompt=True,
58
- skip_special_tokens=True)
59
- generate_kwargs = dict(
60
- inputs,
61
- streamer=streamer,
62
- max_new_tokens=max_new_tokens,
63
- do_sample=True,
64
- top_p=top_p,
65
- top_k=top_k,
66
- temperature=temperature,
67
- num_beams=1,
68
- )
69
- t = Thread(target=model.generate, kwargs=generate_kwargs)
70
- t.start()
71
-
72
- outputs = []
73
- for text in streamer:
74
- outputs.append(text)
75
- yield ''.join(outputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,8 +1,3 @@
1
- accelerate
2
- bitsandbytes
3
- gradio
4
- protobuf
5
- scipy
6
- sentencepiece
7
- torch
8
- git+https://github.com/huggingface/transformers@main
 
1
+ boto3
2
+ sagemaker
3
+ transformers