Staticaliza committed
Commit: c02118d
Parent(s): 120bd05
Update app.py
app.py
CHANGED
@@ -2,12 +2,15 @@ import gradio as gr
 from huggingface_hub import Repository, InferenceClient
 import os
 import json
+import re
 
 API_TOKEN = os.environ.get("API_TOKEN")
 API_ENDPOINT = os.environ.get("API_ENDPOINT")
 
 KEY = os.environ.get("KEY")
 
+SPECIAL_SYMBOLS = ["‹", "›"]
+
 API_ENDPOINTS = {
     "Falcon": "tiiuae/falcon-180B-chat",
     "Llama": "meta-llama/Llama-2-70b-chat-hf"
@@ -20,25 +23,24 @@ for model_name, model_endpoint in API_ENDPOINTS.items():
     CHOICES.append(model_name)
     CLIENTS[model_name] = InferenceClient(model_endpoint, headers = { "Authorization": f"Bearer {API_TOKEN}" })
 
-def format(
-
-
-
-    user_message, bot_message = turn
-    prompt = f"{prompt}\n{USER_NAME}: {user_message}\n{BOT_NAME}: {bot_message}"
-    prompt = f"{prompt}\n{USER_NAME}: {message}\n{BOT_NAME}:"
+def format(instruction = "", history = "", input = "", preinput = ""):
+    sy_l, sy_r = SPECIAL_SYMBOLS[0], SPECIAL_SYMBOLS[1]
+    formatted_history = '\n'.join(f"{sy_l}{message}{sy_r}" for message in history)
+    task_message = f"{instruction}\n{formatted_history}\n{sy_l}{input}{sy_r}\n{preinput}"
     return prompt
 
-def predict(instruction, history, input, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):
+def predict(instruction, history, input, preinput, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):
 
     if (access_key != KEY):
         print(">>> MODEL FAILED: Input: " + input + ", Attempted Key: " + access_key)
         return ("[UNAUTHORIZED ACCESS]", input);
 
     stops = json.loads(stop_seqs)
+
+    formatted_input = format(instruction, history, input, preinput)
 
     response = CLIENTS[model].text_generation(
-
+        formatted_input,
         temperature = temperature,
         max_new_tokens = max_tokens,
         top_p = top_p,
@@ -52,9 +54,15 @@ def predict(instruction, history, input, access_key, model, temperature, top_p,
         return_full_text = False
     )
 
-
+    sy_l, sy_r = SPECIAL_SYMBOLS[0], SPECIAL_SYMBOLS[1]
+    pre_result = f"{sy_l}{response}{sy_r}{''.join(SPECIAL_SYMBOLS)}"
+    pattern = re.compile(f"{sy_l}(.*?){sy_r}", re.DOTALL)
+    match = pattern.search(pre_result)
+    get_result = match.group(1).strip() if match else ""
+
+    print(f"---\nUSER: {input}\nBOT: {get_result}\n---")
 
-    return (
+    return (get_result, input)
 
 def maintain_cloud():
     print(">>> SPACE MAINTAINED!")
@@ -68,6 +76,7 @@ with gr.Blocks() as demo:
     with gr.Column():
         history = gr.Chatbot(elem_id = "chatbot")
         input = gr.Textbox(label = "Input", lines = 2)
+        preinput = gr.Textbox(label = "Pre-Input", lines = 1)
         instruction = gr.Textbox(label = "Instruction", lines = 4)
         access_key = gr.Textbox(label = "Access Key", lines = 1)
         run = gr.Button("▶")
@@ -87,7 +96,7 @@ with gr.Blocks() as demo:
     with gr.Column():
         output = gr.Textbox(label = "Output", value = "", lines = 50)
 
-    run.click(predict, inputs = [instruction, history, input, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed], outputs = [output, input])
+    run.click(predict, inputs = [instruction, history, input, preinput, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed], outputs = [output, input])
     cloud.click(maintain_cloud, inputs = [], outputs = [input, output])
 
 demo.queue(concurrency_count = 500, api_open = True).launch(show_api = True)
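The change above introduces a delimiter scheme: format() wraps the history and the current input in the SPECIAL_SYMBOLS pair, and predict() recovers the model's reply by closing the delimiters around the raw response and taking the first ‹…›-delimited span with a regex. The following minimal sketch reproduces that round trip outside the Space and is an illustration only, not part of the commit: the extract() helper, the example prompt, and fake_model_reply are invented for the demonstration, history is treated as a flat list of strings, and the sketch assumes format() is meant to return the string it builds.

# Hypothetical stand-alone sketch of the delimiter scheme; it never calls an
# inference endpoint. extract() and fake_model_reply are invented for the demo.
import re

SPECIAL_SYMBOLS = ["‹", "›"]  # same pair the commit defines in app.py

def format(instruction = "", history = (), input = "", preinput = ""):
    # Wrap each past message and the new input in the delimiters, as the
    # commit's format() does (here returning the built string directly).
    sy_l, sy_r = SPECIAL_SYMBOLS
    formatted_history = '\n'.join(f"{sy_l}{message}{sy_r}" for message in history)
    return f"{instruction}\n{formatted_history}\n{sy_l}{input}{sy_r}\n{preinput}"

def extract(response):
    # Mirror the post-processing added to predict(): append the closing pair so
    # a match is guaranteed, then keep only the first delimited span.
    sy_l, sy_r = SPECIAL_SYMBOLS
    pre_result = f"{sy_l}{response}{sy_r}{''.join(SPECIAL_SYMBOLS)}"
    match = re.search(f"{sy_l}(.*?){sy_r}", pre_result, re.DOTALL)
    return match.group(1).strip() if match else ""

prompt = format("Answer briefly.", ["Hi!", "Hello."], "What is 2 + 2?", "‹")
fake_model_reply = "4› and here the model keeps rambling"
print(prompt)
print(extract(fake_model_reply))  # prints "4"

Because the delimiters are appended to the raw response before matching, the regex still yields a clean reply when the model omits the closing symbol or keeps generating past it.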