add all streamlit
Browse files
app.py
CHANGED
@@ -10,6 +10,7 @@ def load():
|
|
10 |
trust_remote_code=True,
|
11 |
)
|
12 |
"""
|
|
|
13 |
tokenizer = LlamaTokenizer.from_pretrained(
|
14 |
"novelai/nerdstash-tokenizer-v1",
|
15 |
additional_special_tokens=['▁▁'],
|
@@ -22,6 +23,24 @@ def generate():
|
|
22 |
|
23 |
st.header(":dna: 遺伝カウンセリング対話AI")
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
st.sidebar.header("Options")
|
26 |
st.session_state["options"]["temperature"] = st.sidebar.slider("temperature", min_value=0.0, max_value=2.0, step=0.1, value=st.session_state["options"]["temperature"])
|
27 |
st.session_state["options"]["top_k"] = st.sidebar.slider("top_k", min_value=0, max_value=100, step=1, value=st.session_state["options"]["top_k"])
|
@@ -29,4 +48,23 @@ st.session_state["options"]["top_p"] = st.sidebar.slider("top_p", min_value=0.0,
|
|
29 |
st.session_state["options"]["repetition_penalty"] = st.sidebar.slider("repetition_penalty", min_value=1.0, max_value=2.0, step=0.01, value=st.session_state["options"]["repetition_penalty"])
|
30 |
st.session_state["options"]["system_prompt"] = st.sidebar.text_area("System Prompt", value=st.session_state["options"]["system_prompt"])
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
model, tokenizer = load()
|
|
|
10 |
trust_remote_code=True,
|
11 |
)
|
12 |
"""
|
13 |
+
model = None
|
14 |
tokenizer = LlamaTokenizer.from_pretrained(
|
15 |
"novelai/nerdstash-tokenizer-v1",
|
16 |
additional_special_tokens=['▁▁'],
|
|
|
23 |
|
24 |
st.header(":dna: 遺伝カウンセリング対話AI")
|
25 |
|
26 |
+
|
27 |
+
# 初期化
|
28 |
+
if "messages" not in st.session_state:
|
29 |
+
st.session_state["messages"] = []
|
30 |
+
if "options" not in st.session_state:
|
31 |
+
st.session_state["options"] = {
|
32 |
+
"temperature": 0.0,
|
33 |
+
"top_k": 50,
|
34 |
+
"top_p": 0.95,
|
35 |
+
"repetition_penalty": 1.1,
|
36 |
+
"system_prompt": """あなたは誠実かつ優秀な遺伝子カウンセリングのカウンセラーです。
|
37 |
+
常に安全を考慮し、できる限り有益な回答を心がけてください。
|
38 |
+
あなたの回答には、有害、非倫理的、人種差別的、性差別的、有害、危険、違法な内容が含まれてはいけません。
|
39 |
+
社会的に偏りのない、前向きな回答を心がけてください。
|
40 |
+
質問が意味をなさない場合、または事実に一貫性がない場合は、正しくないことを答えるのではなく、その理由を説明してください。
|
41 |
+
質問の答えを知らない場合は、誤った情報を共有しないでください。"""}
|
42 |
+
|
43 |
+
# サイドバー
|
44 |
st.sidebar.header("Options")
|
45 |
st.session_state["options"]["temperature"] = st.sidebar.slider("temperature", min_value=0.0, max_value=2.0, step=0.1, value=st.session_state["options"]["temperature"])
|
46 |
st.session_state["options"]["top_k"] = st.sidebar.slider("top_k", min_value=0, max_value=100, step=1, value=st.session_state["options"]["top_k"])
|
|
|
48 |
st.session_state["options"]["repetition_penalty"] = st.sidebar.slider("repetition_penalty", min_value=1.0, max_value=2.0, step=0.01, value=st.session_state["options"]["repetition_penalty"])
|
49 |
st.session_state["options"]["system_prompt"] = st.sidebar.text_area("System Prompt", value=st.session_state["options"]["system_prompt"])
|
50 |
|
51 |
+
# リセット
|
52 |
+
if clear_chat:
|
53 |
+
st.session_state["messages"] = []
|
54 |
+
|
55 |
+
# チャット履歴の表示
|
56 |
+
for message in st.session_state["messages"]:
|
57 |
+
with st.chat_message(message["role"]):
|
58 |
+
st.markdown(message["content"])
|
59 |
+
|
60 |
+
# 現在のチャット
|
61 |
+
if user_prompt := st.chat_input("質問を送信してください"):
|
62 |
+
with st.chat_message("user"):
|
63 |
+
st.text(user_prompt)
|
64 |
+
st.session_state["messages"].append({"role": "user", "content": user_prompt})
|
65 |
+
response = None
|
66 |
+
with st.chat_message("assistant"):
|
67 |
+
st.text(response)
|
68 |
+
st.session_state["messages"].append({"role": "assistant", "content": user_prompt})
|
69 |
+
|
70 |
model, tokenizer = load()
|