Yingxu He
committed on
Commit
•
5ee0c02
1
Parent(s):
cdaf751
Update app.py
Browse files
app.py
CHANGED
@@ -1,8 +1,7 @@
|
|
1 |
import os
|
2 |
-
import
|
3 |
import gradio as gr
|
4 |
-
|
5 |
-
import requests
|
6 |
|
7 |
endpoint_url = os.getenv('ENDPOINT_URL')
|
8 |
personal_secret_token = os.getenv('PERSONAL_HF_TOKEN')
|
@@ -12,15 +11,7 @@ system_symbol = os.getenv('SYSTEM_SYMBOL')
|
|
12 |
user_symbol = os.getenv('USER_SYMBOL')
|
13 |
assistant_symbol = os.getenv('ASSISTANT_SYMBOL')
|
14 |
|
15 |
-
|
16 |
-
"Accept" : "application/json",
|
17 |
-
"Authorization": f"Bearer {personal_secret_token}",
|
18 |
-
"Content-Type": "application/json"
|
19 |
-
}
|
20 |
-
|
21 |
-
def query(payload):
|
22 |
-
response = requests.post(endpoint_url, headers=headers, json=payload)
|
23 |
-
return response.json()
|
24 |
|
25 |
def respond(
|
26 |
message,
|
@@ -29,7 +20,16 @@ def respond(
|
|
29 |
max_new_tokens,
|
30 |
temperature,
|
31 |
top_p,
|
|
|
32 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
all_messages = [system_message]
|
34 |
|
35 |
for val in history:
|
@@ -48,11 +48,13 @@ def respond(
|
|
48 |
# stream=True,
|
49 |
)
|
50 |
|
51 |
-
response =
|
52 |
"inputs": turn_breaker.join(all_messages),
|
53 |
"parameters": generation_kwargs
|
54 |
})
|
55 |
|
|
|
|
|
56 |
return response
|
57 |
|
58 |
|
@@ -73,6 +75,7 @@ demo = gr.ChatInterface(
|
|
73 |
label="Top-p (nucleus sampling)",
|
74 |
),
|
75 |
],
|
|
|
76 |
)
|
77 |
|
78 |
|
|
|
1 |
import os
|
2 |
+
import time
|
3 |
import gradio as gr
|
4 |
+
from huggingface_hub import get_inference_endpoint
|
|
|
5 |
|
6 |
endpoint_url = os.getenv('ENDPOINT_URL')
|
7 |
personal_secret_token = os.getenv('PERSONAL_HF_TOKEN')
|
|
|
11 |
user_symbol = os.getenv('USER_SYMBOL')
|
12 |
assistant_symbol = os.getenv('ASSISTANT_SYMBOL')
|
13 |
|
14 |
+
endpoint = get_inference_endpoint(endpoint_url, token=personal_secret_token)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
def respond(
|
17 |
message,
|
|
|
20 |
max_new_tokens,
|
21 |
temperature,
|
22 |
top_p,
|
23 |
+
progress=gr.Progress()
|
24 |
):
|
25 |
+
progress(0, desc="Starting")
|
26 |
+
|
27 |
+
while endpoint.status != "running":
|
28 |
+
progress(0.25, desc="Waking up model")
|
29 |
+
time.sleep(1)
|
30 |
+
|
31 |
+
progress(0.5, desc="Generating")
|
32 |
+
|
33 |
all_messages = [system_message]
|
34 |
|
35 |
for val in history:
|
|
|
48 |
# stream=True,
|
49 |
)
|
50 |
|
51 |
+
response = endpoint.client.post({
|
52 |
"inputs": turn_breaker.join(all_messages),
|
53 |
"parameters": generation_kwargs
|
54 |
})
|
55 |
|
56 |
+
progress(1, desc="Generating")
|
57 |
+
|
58 |
return response
|
59 |
|
60 |
|
|
|
75 |
label="Top-p (nucleus sampling)",
|
76 |
),
|
77 |
],
|
78 |
+
show_progress="full"
|
79 |
)
|
80 |
|
81 |
|