TheMaisk committed
Commit d8b357d · 1 Parent(s): 9035577

Create app.py

Files changed (1): app.py (+88, -0)
app.py ADDED
@@ -0,0 +1,88 @@
+ import os
+
+ import streamlit as st
+ from huggingface_hub import InferenceClient
+
+ # Client for the hosted Mistral-7B-Instruct-v0.3 model on the HF Inference API.
+ client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
+
+ # The system prompt is kept out of the source and read from the environment
+ # (e.g. a Space secret); default to "" so a missing secret does not crash
+ # format_prompt() with a None concatenation.
+ secret_prompt = os.getenv("SECRET_PROMPT", "")
+
+ def format_prompt(message, history):
+     # Build a Mistral-instruct prompt: the system prompt, then each past
+     # (user, bot) turn wrapped as [INST] ... [/INST] ... </s>, then the
+     # new user message.
+     prompt = secret_prompt
+     for user_prompt, bot_response in history:
+         prompt += f"[INST] {user_prompt} [/INST]"
+         prompt += f" {bot_response}</s> "
+     prompt += f"[INST] {message} [/INST]"
+     return prompt
+
+ def generate(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
+     temperature = float(temperature)
+     if temperature < 1e-2:
+         temperature = 1e-2  # the API rejects a temperature of 0
+     top_p = float(top_p)
+
+     generate_kwargs = dict(
+         temperature=temperature,
+         max_new_tokens=max_new_tokens,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         do_sample=True,
+         seed=42,
+     )
+
+     formatted_prompt = format_prompt(prompt, history)
+
+     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+     output = ""
+     # Yield the accumulated text after every token so the UI can stream it.
+     for response in stream:
+         output += response.token.text
+         yield output
+
+ st.title("Einfach.Mistral 7B v0.3")
+
+ # Streamlit reruns the whole script on every interaction, so the chat
+ # history must live in st.session_state to survive reruns; a plain
+ # module-level list would be reset to [] on every click.
+ if "history" not in st.session_state:
+     st.session_state.history = []
+
+ with st.sidebar:
+     temperature = st.slider(
+         "Temperature",
+         value=0.9,
+         min_value=0.0,
+         max_value=1.0,
+         step=0.05,
+         help="Higher values produce more diverse outputs",
+     )
+     max_new_tokens = st.slider(
+         "Max new tokens",
+         value=256,
+         min_value=0,
+         max_value=1024,
+         step=64,
+         help="The maximum number of new tokens",
+     )
+     top_p = st.slider(
+         "Top-p (nucleus sampling)",
+         value=0.90,
+         min_value=0.0,
+         max_value=1.0,
+         step=0.05,
+         help="Higher values sample more low-probability tokens",
+     )
+     repetition_penalty = st.slider(
+         "Repetition penalty",
+         value=1.2,
+         min_value=1.0,
+         max_value=2.0,
+         step=0.05,
+         help="Penalize repeated tokens",
+     )
+
+ message = st.text_input("Your message:", "")
+
+ if st.button("Generate"):
+     if message:
+         # Write each partial output into a single st.empty() slot; calling
+         # st.text_area directly in the loop would stack one widget per token.
+         placeholder = st.empty()
+         output = ""
+         for output in generate(message, st.session_state.history, temperature, max_new_tokens, top_p, repetition_penalty):
+             placeholder.text_area("Generated Text", value=output, height=400)
+         st.session_state.history.append((message, output))
+     else:
+         st.warning("Please enter a message.")
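
For reference, format_prompt() assembles the conversation into a single Mistral-instruct string; with one prior turn and an empty secret_prompt it would look like:

    [INST] Hi [/INST] Hello!</s> [INST] How are you? [/INST]

A minimal sketch for trying the app locally (the SECRET_PROMPT name comes from app.py above; the HF_TOKEN variable is the usual huggingface_hub convention for Inference API auth, and the prompt text is an assumed example, not part of this commit):

    pip install streamlit huggingface_hub
    export SECRET_PROMPT="You are a helpful assistant. "   # assumed example value
    export HF_TOKEN=...                                    # only if your API access requires a token
    streamlit run app.py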