wop committed on
Commit cf25a56 · verified · 1 Parent(s): e5ec746

Update app.py

Files changed (1)
  1. app.py +28 -21
app.py CHANGED
@@ -1,10 +1,23 @@
 from huggingface_hub import InferenceClient
 import gradio as gr
+import json
 
 client = InferenceClient(
     "mistralai/Mistral-7B-Instruct-v0.1"
 )
 
+DATABASE_PATH = "database.json"
+
+def load_database():
+    try:
+        with open(DATABASE_PATH, "r") as file:
+            return json.load(file)
+    except FileNotFoundError:
+        return {}
+
+def save_database(database):
+    with open(DATABASE_PATH, "w") as file:
+        json.dump(database, file)
 
 def format_prompt(message, history):
     prompt = "<s>"
@@ -14,32 +27,28 @@ def format_prompt(message, history):
     prompt += f"[INST] {message} [/INST]"
     return prompt
 
+def generate_response(prompt, database):
+    if prompt in database:
+        return database[prompt]
+    else:
+        response = next(client.text_generation(prompt, details=True, return_full_text=False)).token.text
+        database[prompt] = response
+        save_database(database)
+        return response
+
 def generate(
-    prompt, history, temperature=0.9, max_new_tokens=2000, top_p=0.9, repetition_penalty=1.2,
+    prompt, history, database, temperature=0.9, max_new_tokens=2000, top_p=0.9, repetition_penalty=1.2,
 ):
     temperature = float(temperature)
     if temperature < 1e-2:
         temperature = 1e-2
     top_p = float(top_p)
 
-    generate_kwargs = dict(
-        temperature=temperature,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        do_sample=True,
-        seed=42,
-    )
-
     formatted_prompt = format_prompt(prompt, history)
+    response = generate_response(formatted_prompt, database)
+    yield response
 
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
-
-    for response in stream:
-        output += response.token.text
-        yield output
-    return output
+database = load_database()
 
 css = """
 #mkd {
@@ -50,12 +59,10 @@ css = """
 """
 
 with gr.Blocks(css=css) as demo:
-    gr.HTML("<h1><center><h1><center>")
-    gr.HTML("<h3><center><h3><center>")
-    gr.HTML("<h3><center><h3><center>")
     gr.ChatInterface(
         generate,
-        examples=[["What is the secret to life?"], ["Write me a recipe for pancakes."], ["Write a short story about Paris."]]
+        examples=[["What is the secret to life?"], ["Write me a recipe for pancakes."], ["Write a short story about Paris."]],
+        database=database
     )
 
 demo.launch(debug=True)
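
Note on the wiring above: `database` is not a keyword accepted by `gr.ChatInterface`, and by default Gradio calls the chat function with `(message, history)` only, so the new `database` positional parameter on `generate` would not be filled in at runtime. Below is a minimal sketch, not part of the commit, of one way the JSON-backed cache could be attached instead: keep it at module scope and leave the chat function with Gradio's standard two arguments. The prompt template and `max_new_tokens` value are placeholders.

# Sketch only: cache wiring without the unsupported `database=` keyword.
# Assumes the same model and the load/save helpers introduced in this commit.
import json

import gradio as gr
from huggingface_hub import InferenceClient

client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")
DATABASE_PATH = "database.json"

def load_database():
    try:
        with open(DATABASE_PATH, "r") as file:
            return json.load(file)
    except FileNotFoundError:
        return {}

def save_database(database):
    with open(DATABASE_PATH, "w") as file:
        json.dump(database, file)

database = load_database()  # module-level cache, reused across requests

def generate(message, history):
    # ChatInterface calls fn(message, history); the cache is reached through
    # module scope rather than an extra positional argument.
    prompt = f"<s>[INST] {message} [/INST]"
    if prompt in database:
        return database[prompt]
    response = client.text_generation(prompt, max_new_tokens=256)  # plain string when stream/details are off
    database[prompt] = response
    save_database(database)
    return response

with gr.Blocks() as demo:
    gr.ChatInterface(
        generate,
        examples=[["What is the secret to life?"], ["Write me a recipe for pancakes."]],
    )

demo.launch()

Rewriting the whole dict on every cache miss is fine for a demo-sized store; a larger deployment would more likely want an append-only log or a bounded cache.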
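
Separately, the streaming loop removed by this commit was what consumed `client.text_generation(..., stream=True, details=True)` token by token. `generate_response` drops `stream=True`, and with `details=True` alone the client returns a single detailed response object rather than an iterator, so the `next(...)` call would likely fail; even against a streaming call it would keep only the first token. A hedged sketch of the previous streaming pattern, assuming the `client` defined in the sketch above:

def generate_streaming(message, history):
    # Accumulate token texts from the streaming generator and yield the
    # growing string so Gradio can update the chat message incrementally.
    prompt = f"<s>[INST] {message} [/INST]"
    stream = client.text_generation(
        prompt,
        max_new_tokens=512,
        stream=True,
        details=True,
        return_full_text=False,
    )
    output = ""
    for chunk in stream:
        output += chunk.token.text
        yield output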