updated predict for handling streaming events
Browse files
app.py
CHANGED
@@ -7,48 +7,6 @@ import requests
|
|
7 |
#Streaming endpoint
|
8 |
API_URL = os.getenv("API_URL") + "/generate_stream"
|
9 |
|
10 |
-
def predict_old(inputs, top_p, temperature, top_k, repetition_penalty, history=[]):
|
11 |
-
if not inputs.startswith("User: "):
|
12 |
-
inputs = "User: " + inputs + "\n"
|
13 |
-
payload = {
|
14 |
-
"inputs": inputs, #"My name is Jane and I",
|
15 |
-
"parameters": {
|
16 |
-
"details": True,
|
17 |
-
"do_sample": True,
|
18 |
-
"max_new_tokens": 100,
|
19 |
-
"repetition_penalty": repetition_penalty, #1.03,
|
20 |
-
"seed": 0,
|
21 |
-
"temperature": temperature, #0.5,
|
22 |
-
"top_k": top_k, #10,
|
23 |
-
"top_p": top_p #0.95
|
24 |
-
}
|
25 |
-
}
|
26 |
-
|
27 |
-
headers = {
|
28 |
-
'accept': 'text/event-stream',
|
29 |
-
'Content-Type': 'application/json'
|
30 |
-
}
|
31 |
-
|
32 |
-
history.append(inputs)
|
33 |
-
response = requests.post(API_URL, headers=headers, json=payload)
|
34 |
-
responses = response.text.split("\n\n")
|
35 |
-
|
36 |
-
partial_words = ""
|
37 |
-
for idx, resp in enumerate(responses):
|
38 |
-
if resp[:4] == 'data':
|
39 |
-
partial_words = partial_words + json.loads(resp[5:])['token']['text']
|
40 |
-
#print(partial_words)
|
41 |
-
time.sleep(0.05)
|
42 |
-
if idx == 0:
|
43 |
-
history.append(" " + partial_words)
|
44 |
-
else:
|
45 |
-
history[-1] = partial_words
|
46 |
-
|
47 |
-
chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2) ] # convert to tuples of list
|
48 |
-
|
49 |
-
yield chat, history #resembles {chatbot: chat, state: history}
|
50 |
-
|
51 |
-
|
52 |
def predict(inputs, top_p, temperature, top_k, repetition_penalty, history=[]):
|
53 |
if not inputs.startswith("User: "):
|
54 |
inputs = "User: " + inputs + "\n"
|
|
|
7 |
#Streaming endpoint
|
8 |
API_URL = os.getenv("API_URL") + "/generate_stream"
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
def predict(inputs, top_p, temperature, top_k, repetition_penalty, history=[]):
|
11 |
if not inputs.startswith("User: "):
|
12 |
inputs = "User: " + inputs + "\n"
|