Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,10 +1,23 @@
|
|
1 |
from huggingface_hub import InferenceClient
|
2 |
import gradio as gr
|
|
|
3 |
|
4 |
client = InferenceClient(
|
5 |
"mistralai/Mistral-7B-Instruct-v0.1"
|
6 |
)
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
def format_prompt(message, history):
|
10 |
prompt = "<s>"
|
@@ -14,32 +27,28 @@ def format_prompt(message, history):
|
|
14 |
prompt += f"[INST] {message} [/INST]"
|
15 |
return prompt
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
def generate(
|
18 |
-
prompt, history, temperature=0.9, max_new_tokens=2000, top_p=0.9, repetition_penalty=1.2,
|
19 |
):
|
20 |
temperature = float(temperature)
|
21 |
if temperature < 1e-2:
|
22 |
temperature = 1e-2
|
23 |
top_p = float(top_p)
|
24 |
|
25 |
-
generate_kwargs = dict(
|
26 |
-
temperature=temperature,
|
27 |
-
max_new_tokens=max_new_tokens,
|
28 |
-
top_p=top_p,
|
29 |
-
repetition_penalty=repetition_penalty,
|
30 |
-
do_sample=True,
|
31 |
-
seed=42,
|
32 |
-
)
|
33 |
-
|
34 |
formatted_prompt = format_prompt(prompt, history)
|
|
|
|
|
35 |
|
36 |
-
|
37 |
-
output = ""
|
38 |
-
|
39 |
-
for response in stream:
|
40 |
-
output += response.token.text
|
41 |
-
yield output
|
42 |
-
return output
|
43 |
|
44 |
css = """
|
45 |
#mkd {
|
@@ -50,12 +59,10 @@ css = """
|
|
50 |
"""
|
51 |
|
52 |
with gr.Blocks(css=css) as demo:
|
53 |
-
gr.HTML("<h1><center><h1><center>")
|
54 |
-
gr.HTML("<h3><center><h3><center>")
|
55 |
-
gr.HTML("<h3><center><h3><center>")
|
56 |
gr.ChatInterface(
|
57 |
generate,
|
58 |
-
examples=[["What is the secret to life?"], ["Write me a recipe for pancakes."], ["Write a short story about Paris."]]
|
|
|
59 |
)
|
60 |
|
61 |
demo.launch(debug=True)
|
|
|
1 |
from huggingface_hub import InferenceClient
|
2 |
import gradio as gr
|
3 |
+
import json
|
4 |
|
5 |
client = InferenceClient(
|
6 |
"mistralai/Mistral-7B-Instruct-v0.1"
|
7 |
)
|
8 |
|
9 |
+
DATABASE_PATH = "database.json"
|
10 |
+
|
11 |
+
def load_database():
|
12 |
+
try:
|
13 |
+
with open(DATABASE_PATH, "r") as file:
|
14 |
+
return json.load(file)
|
15 |
+
except FileNotFoundError:
|
16 |
+
return {}
|
17 |
+
|
18 |
+
def save_database(database):
|
19 |
+
with open(DATABASE_PATH, "w") as file:
|
20 |
+
json.dump(database, file)
|
21 |
|
22 |
def format_prompt(message, history):
|
23 |
prompt = "<s>"
|
|
|
27 |
prompt += f"[INST] {message} [/INST]"
|
28 |
return prompt
|
29 |
|
30 |
+
def generate_response(prompt, database):
|
31 |
+
if prompt in database:
|
32 |
+
return database[prompt]
|
33 |
+
else:
|
34 |
+
response = next(client.text_generation(prompt, details=True, return_full_text=False)).token.text
|
35 |
+
database[prompt] = response
|
36 |
+
save_database(database)
|
37 |
+
return response
|
38 |
+
|
39 |
def generate(
|
40 |
+
prompt, history, database, temperature=0.9, max_new_tokens=2000, top_p=0.9, repetition_penalty=1.2,
|
41 |
):
|
42 |
temperature = float(temperature)
|
43 |
if temperature < 1e-2:
|
44 |
temperature = 1e-2
|
45 |
top_p = float(top_p)
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
formatted_prompt = format_prompt(prompt, history)
|
48 |
+
response = generate_response(formatted_prompt, database)
|
49 |
+
yield response
|
50 |
|
51 |
+
database = load_database()
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
css = """
|
54 |
#mkd {
|
|
|
59 |
"""
|
60 |
|
61 |
with gr.Blocks(css=css) as demo:
    # Bug fixed: gr.ChatInterface has no `database` keyword argument, so
    # passing database=database raised TypeError while building the UI.
    # The handler obtains the cache itself, so only fn and examples are passed.
    gr.ChatInterface(
        generate,
        examples=[["What is the secret to life?"], ["Write me a recipe for pancakes."], ["Write a short story about Paris."]],
    )

demo.launch(debug=True)
|