Commit fd0f949
Parent(s): aecd012

wip

Files changed:
- README.md +5 -12
- app.py +162 -267
- model.py +0 -75
- requirements.txt +3 -8
README.md CHANGED
@@ -1,19 +1,12 @@
 ---
-title:
+title: Llama 7B Chat on Inf2
 emoji: 🦙
-colorFrom:
+colorFrom: yellow
-colorTo:
+colorTo: purple
 sdk: gradio
-sdk_version: 3.
+sdk_version: 3.46.1
 app_file: app.py
-pinned:
+pinned: true
 license: other
-suggested_hardware: a10g-small
 duplicated_from: huggingface-projects/llama-2-13b-chat
 ---
-
-# LLAMA v2 Models
-
-Llama v2 was introduced in [this paper](https://arxiv.org/abs/2307.09288).
-
-This Space demonstrates [Llama-2-13b-chat-hf](meta-llama/Llama-2-13b-chat-hf) from Meta. Please, check the original model card for details.
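For reference, applying the hunk above yields the following Space frontmatter (values exactly as added in this commit):

```yaml
---
title: Llama 7B Chat on Inf2
emoji: 🦙
colorFrom: yellow
colorTo: purple
sdk: gradio
sdk_version: 3.46.1
app_file: app.py
pinned: true
license: other
duplicated_from: huggingface-projects/llama-2-13b-chat
---
```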
app.py CHANGED
@@ -1,279 +1,174 @@
 from typing import Iterator
-
 import gradio as gr
-import
-… [old lines 5-56 removed; content not recoverable from this view]
+import boto3
+import io
+import json
+import os
+from transformers import AutoTokenizer
+
+aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID", None)
+aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY", None)
+aws_session_token = os.environ.get("AWS_SESSION_TOKEN", None)
+region = os.environ.get("AWS_REGION", None)
+endpoint_name = os.environ.get("SAGEMAKER_ENDPOINT_NAME", None)
+
+tokenizer = AutoTokenizer.from_pretrained(
+    "aws-neuron/Llama-2-7b-chat-hf-seqlen-2048-bs-4"
+)
+
+# if (
+#     aws_access_key_id is None
+#     or aws_secret_access_key is None
+#     or region is None
+#     or endpoint_name is None
+# ):
+#     raise Exception(
+#         "Please set AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION and SAGEMAKER_ENDPOINT_NAME environment variables"
+#     )
+
+# boto_session = boto3.Session(
+#     aws_access_key_id=aws_access_key_id,
+#     aws_secret_access_key=aws_secret_access_key,
+#     aws_session_token=aws_session_token,
+#     region_name=region,
+# )
+
+# smr = boto_session.client("sagemaker-runtime")
+
+
+DEFAULT_SYSTEM_PROMPT = (
+    "You are a helpful assistant called Llama, knowing everything about AWS."
+)
+MAX_INPUT_TOKEN_LENGTH = 1024
+
+# hyperparameters for the LLM
+parameters = {
+    "do_sample": True,
+    "top_p": 0.9,
+    "temperature": 0.8,
+    "max_new_tokens": 1024,
+    "repetition_penalty": 1.03,
+    "stop": ["</s>"],
+}
+
+
+# Helper for reading complete lines from a botocore event stream
+class LineIterator:
+    def __init__(self, stream):
+        self.byte_iterator = iter(stream)
+        self.buffer = io.BytesIO()
+        self.read_pos = 0
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        while True:
+            self.buffer.seek(self.read_pos)
+            line = self.buffer.readline()
+            if line and line[-1] == ord("\n"):
+                self.read_pos += len(line)
+                return line[:-1]
+            try:
+                chunk = next(self.byte_iterator)
+            except StopIteration:
+                if self.read_pos < self.buffer.getbuffer().nbytes:
+                    continue
+                raise
+            if "PayloadPart" not in chunk:
+                print("Unknown event type:", chunk)
+                continue
+            self.buffer.seek(0, io.SEEK_END)
+            self.buffer.write(chunk["PayloadPart"]["Bytes"])
+
+
+def format_prompt(message, history):
+    messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}]
+    for interaction in history:
+        messages.append({"role": "user", "content": interaction[0]})
+        messages.append({"role": "assistant", "content": interaction[1]})
+    messages.append({"role": "user", "content": message})
+    prompt = tokenizer.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+    return prompt
 
 
 def generate(
-… [old lines 60-90 removed; content not recoverable from this view]
+    prompt,
+    history,
+):
+    formatted_prompt = format_prompt(prompt, history)
+    check_input_token_length(formatted_prompt)
+
+    request = {"inputs": formatted_prompt, "parameters": parameters, "stream": True}
+    resp = {"Body": ""}
+    # resp = {"Body": open("test.json", "rb")}
+    # resp = smr.invoke_endpoint_with_response_stream(
+    #     EndpointName=endpoint_name,
+    #     Body=json.dumps(request),
+    #     ContentType="application/json",
+    # )
+
+    output = "offline"
+    # for c in LineIterator(resp["Body"]):
+    #     c = c.decode("utf-8")
+    #     if c.startswith("data:"):
+    #         chunk = json.loads(c.lstrip("data:").rstrip("\n"))
+    #         if chunk["token"]["special"]:
+    #             continue
+    #         if chunk["token"]["text"] in request["parameters"]["stop"]:
+    #             break
+    #         output += chunk["token"]["text"]
+    #     for stop_str in request["parameters"]["stop"]:
+    #         if output.endswith(stop_str):
+    #             output = output[: -len(stop_str)]
+    #     output = output.rstrip()
+    #     yield output
+
+    # yield output
+    return output
+
+
+def check_input_token_length(prompt: str) -> None:
+    input_token_length = len(tokenizer(prompt)["input_ids"])
     if input_token_length > MAX_INPUT_TOKEN_LENGTH:
-        raise gr.Error(
-… [old lines 93-94 removed; content not recoverable from this view]
-with gr.Blocks(css='style.css') as demo:
-    gr.Markdown(DESCRIPTION)
-    gr.DuplicateButton(value='Duplicate Space for private use',
-                       elem_id='duplicate-button')
-
-    with gr.Group():
-        chatbot = gr.Chatbot(label='Chatbot')
-        with gr.Row():
-            textbox = gr.Textbox(
-                container=False,
-                show_label=False,
-                placeholder='Type a message...',
-                scale=10,
-            )
-            submit_button = gr.Button('Submit',
-                                      variant='primary',
-                                      scale=1,
-                                      min_width=0)
-    with gr.Row():
-        retry_button = gr.Button('🔄 Retry', variant='secondary')
-        undo_button = gr.Button('↩️ Undo', variant='secondary')
-        clear_button = gr.Button('🗑️ Clear', variant='secondary')
-
-    saved_input = gr.State()
-
-    with gr.Accordion(label='Advanced options', open=False):
-        system_prompt = gr.Textbox(label='System prompt',
-                                   value=DEFAULT_SYSTEM_PROMPT,
-                                   lines=6)
-        max_new_tokens = gr.Slider(
-            label='Max new tokens',
-            minimum=1,
-            maximum=MAX_MAX_NEW_TOKENS,
-            step=1,
-            value=DEFAULT_MAX_NEW_TOKENS,
-        )
-        temperature = gr.Slider(
-            label='Temperature',
-            minimum=0.1,
-            maximum=4.0,
-            step=0.1,
-            value=0.1,
-        )
-        top_p = gr.Slider(
-            label='Top-p (nucleus sampling)',
-            minimum=0.05,
-            maximum=1.0,
-            step=0.05,
-            value=0.9,
-        )
-        top_k = gr.Slider(
-            label='Top-k',
-            minimum=1,
-            maximum=1000,
-            step=1,
-            value=10,
+        raise gr.Error(
+            f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again."
         )
 
-    gr.Examples(
-        examples=[
-            'What is the Fibonacci sequence?',
-            'Can you explain briefly what Python is good for?',
-            'How can I display a grid of images in SwiftUI?',
-        ],
-        inputs=textbox,
-        outputs=[textbox, chatbot],
-        fn=process_example,
-        cache_examples=True,
-    )
-
-    gr.Markdown(LICENSE)
-
-    textbox.submit(
-        fn=clear_and_save_textbox,
-        inputs=textbox,
-        outputs=[textbox, saved_input],
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=display_input,
-        inputs=[saved_input, chatbot],
-        outputs=chatbot,
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=check_input_token_length,
-        inputs=[saved_input, chatbot, system_prompt],
-        api_name=False,
-        queue=False,
-    ).success(
-        fn=generate,
-        inputs=[
-            saved_input,
-            chatbot,
-            system_prompt,
-            max_new_tokens,
-            temperature,
-            top_p,
-            top_k,
-        ],
-        outputs=chatbot,
-        api_name=False,
-    )
-
-    button_event_preprocess = submit_button.click(
-        fn=clear_and_save_textbox,
-        inputs=textbox,
-        outputs=[textbox, saved_input],
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=display_input,
-        inputs=[saved_input, chatbot],
-        outputs=chatbot,
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=check_input_token_length,
-        inputs=[saved_input, chatbot, system_prompt],
-        api_name=False,
-        queue=False,
-    ).success(
-        fn=generate,
-        inputs=[
-            saved_input,
-            chatbot,
-            system_prompt,
-            max_new_tokens,
-            temperature,
-            top_p,
-            top_k,
-        ],
-        outputs=chatbot,
-        api_name=False,
-    )
 
+theme = gr.themes.Monochrome(
+    primary_hue="indigo",
+    secondary_hue="blue",
+    neutral_hue="slate",
+    radius_size=gr.themes.sizes.radius_sm,
+    font=[
+        gr.themes.GoogleFont("Open Sans"),
+        "ui-sans-serif",
+        "system-ui",
+        "sans-serif",
+    ],
+)
+DESCRIPTION = """
+<div style="text-align: center; max-width: 650px; margin: 0 auto; display:grid; gap:25px;">
+  <img class="logo" src="https://huggingface.co/datasets/philschmid/assets/resolve/main/aws-neuron_hf.png" alt="Hugging Face Neuron Logo"
+    style="margin: auto; max-width: 14rem;">
+  <h1 style="font-weight: 900; margin-bottom: 7px;margin-top:5px">
+    Llama 2 7B Chat on AWS INF2 ⚡
+  </h1>
+  <p style="margin-bottom: 10px; font-size: 94%; line-height: 23px;">
+    Llama 2 is a collection of pretrained and fine-tuned generative text models ranging in scale from 7 billion to 70 billion parameters. This is the repository for the 7B fine-tuned model, optimized for dialogue use cases and converted for the Hugging Face Transformers format. Links to other models can be found in the index at the bottom. This demo is running on <a style="text-decoration: underline;" href="https://aws.amazon.com/ec2/instance-types/inf2/?nc1=h_ls">AWS Inferentia2</a>. <a href="https://www.philschmid.de/inferentia2-llama-7b" target="_blank">How does it work?</a>
+  </p>
+</div>
+"""
-… [old lines 231-254 removed; content not recoverable from this view]
-        api_name=False,
-    )
 
-    undo_button.click(
-        fn=delete_prev_fn,
-        inputs=chatbot,
-        outputs=[chatbot, saved_input],
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=lambda x: x,
-        inputs=[saved_input],
-        outputs=textbox,
-        api_name=False,
-        queue=False,
-    )
 
-… [old lines 272-277 removed; content not recoverable from this view]
+demo = gr.ChatInterface(
+    generate,
+    description=DESCRIPTION,
+    chatbot=gr.Chatbot(layout="panel"),
+    theme=theme,
+)
 
-demo.queue(
+demo.queue(concurrency_count=5).launch(share=False)
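The commit ships with the SageMaker calls commented out, so `generate` currently returns the hard-coded string "offline" and the Space runs without a live endpoint. As a rough sketch of what the streaming path looks like once re-enabled — assuming a deployed endpoint named by `SAGEMAKER_ENDPOINT_NAME` that emits the TGI-style `data: {...}` server-sent events the commented-out parser expects, and reusing `region`, `endpoint_name`, `parameters`, `format_prompt`, `check_input_token_length`, and `LineIterator` from the new app.py above:

```python
import json

import boto3

# Assumption: credentials come from the environment or an instance role;
# "sagemaker-runtime" is the boto3 client the commented-out code creates.
smr = boto3.Session(region_name=region).client("sagemaker-runtime")


def generate(prompt, history):
    formatted_prompt = format_prompt(prompt, history)
    check_input_token_length(formatted_prompt)

    request = {"inputs": formatted_prompt, "parameters": parameters, "stream": True}
    resp = smr.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        Body=json.dumps(request),
        ContentType="application/json",
    )

    output = ""
    for line in LineIterator(resp["Body"]):
        line = line.decode("utf-8")
        if not line.startswith("data:"):
            continue  # ignore anything that is not an SSE data event
        chunk = json.loads(line[len("data:"):].strip())
        token = chunk["token"]
        if token["special"]:
            continue  # drop special tokens such as </s>
        if token["text"] in parameters["stop"]:
            break  # stop sequence reached
        output += token["text"]
        yield output  # stream the growing string into the chat UI
```

Because `generate` is a generator, `gr.ChatInterface` renders each yielded prefix as a progressively updating assistant message.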
model.py DELETED
@@ -1,75 +0,0 @@
-from threading import Thread
-from typing import Iterator
-
-import torch
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-
-model_id = 'codellama/CodeLlama-13b-Instruct-hf'
-
-if torch.cuda.is_available():
-    config = AutoConfig.from_pretrained(model_id)
-    config.pretraining_tp = 1
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        config=config,
-        torch_dtype=torch.float16,
-        load_in_4bit=True,
-        device_map='auto',
-        use_safetensors=False,
-    )
-else:
-    model = None
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-
-
-def get_prompt(message: str, chat_history: list[tuple[str, str]],
-               system_prompt: str) -> str:
-    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
-    # The first user input is _not_ stripped
-    do_strip = False
-    for user_input, response in chat_history:
-        user_input = user_input.strip() if do_strip else user_input
-        do_strip = True
-        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
-    message = message.strip() if do_strip else message
-    texts.append(f'{message} [/INST]')
-    return ''.join(texts)
-
-
-def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
-    prompt = get_prompt(message, chat_history, system_prompt)
-    input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
-    return input_ids.shape[-1]
-
-
-def run(message: str,
-        chat_history: list[tuple[str, str]],
-        system_prompt: str,
-        max_new_tokens: int = 1024,
-        temperature: float = 0.1,
-        top_p: float = 0.9,
-        top_k: int = 50) -> Iterator[str]:
-    prompt = get_prompt(message, chat_history, system_prompt)
-    inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')
-
-    streamer = TextIteratorStreamer(tokenizer,
-                                    timeout=10.,
-                                    skip_prompt=True,
-                                    skip_special_tokens=True)
-    generate_kwargs = dict(
-        inputs,
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield ''.join(outputs)
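The deleted `get_prompt` assembled the Llama-2 `[INST]`/`<<SYS>>` chat format by hand; the new app.py delegates this to `tokenizer.apply_chat_template`. A quick trace of what the removed helper produced for a one-turn history (output derived from the f-strings above):

```python
# Trace of get_prompt() from the deleted model.py.
history = [("Hi", "Hello! How can I help?")]
prompt = get_prompt("What is AWS?", history, "You are helpful.")
assert prompt == (
    "<s>[INST] <<SYS>>\nYou are helpful.\n<</SYS>>\n\n"
    "Hi [/INST] Hello! How can I help? </s><s>[INST] "
    "What is AWS? [/INST]"
)
```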
requirements.txt CHANGED
@@ -1,8 +1,3 @@
-… [first three removed lines not recoverable from this view]
-protobuf
-scipy
-sentencepiece
-torch
-git+https://github.com/huggingface/transformers@main
+boto3
+sagemaker
+transformers