File size: 4,451 Bytes
0bf9df3
 
 
 
 
 
388871f
b3f598f
 
 
 
 
 
 
 
 
 
 
 
 
388871f
0bf9df3
b3f598f
0bf9df3
 
 
20734fc
 
 
 
 
 
b3f598f
20734fc
 
 
 
 
 
 
 
 
 
0bf9df3
 
 
 
 
 
6d7bada
 
20734fc
6d7bada
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18b3695
 
0bf9df3
 
 
 
 
 
 
6d7bada
 
 
 
 
 
 
 
 
 
 
 
 
 
20734fc
 
6d7bada
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import streamlit as st
from transformers import AutoModelForCausalLM, LlamaTokenizer


@st.cache_resource
def load():
    """
    base_model = AutoModelForCausalLM.from_pretrained(
        "stabilityai/japanese-stablelm-instruct-alpha-7b", 
        device_map="auto",
        low_cpu_mem_usage=True,
        variant="int8",
        load_in_8bit=True,
        trust_remote_code=True,
        )
    model = PeftModel.from_pretrained(
        base_model, 
        "lora_adapter",
        device_map="auto",
        )
    """
    tokenizer = LlamaTokenizer.from_pretrained(
        "lora_adapter", 
        )
    return model, tokenizer

def get_prompt(user_query, system_prompt, messages="", sep="\n\n### "):
    prompt = system_prompt + "\n以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"
    roles = ["指示", "応答"]
    msgs = [": \n" + user_query, ": "]
    if messages:
        roles.insert(1, "入力")
        msgs.insert(1, ": \n" + "\n".join(message["content"] for message in messages))

    for role, msg in zip(roles, msgs):
        prompt += sep + role + msg
    return prompt

def get_input_token_length(user_query, system_prompt, messages=""):
    prompt = get_prompt(user_query, system_prompt, messages)
    input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
    return input_ids.shape[-1]

def generate():
    pass


st.header(":dna: 遺伝カウンセリング対話AI")


# 初期化
model, tokenizer = load()
if "messages" not in st.session_state:
    st.session_state["messages"] = []
if "options" not in st.session_state:
    st.session_state["options"] = {
        "temperature": 0.0, 
        "top_k": 50, 
        "top_p": 0.95, 
        "repetition_penalty": 1.1,
        "system_prompt": """あなたは誠実かつ優秀な遺伝子カウンセリングのカウンセラーです。
常に安全を考慮し、できる限り有益な回答を心がけてください。
あなたの回答には、有害、非倫理的、人種差別的、性差別的、有害、危険、違法な内容が含まれてはいけません。
社会的に偏りのない、前向きな回答を心がけてください。
質問が意味をなさない場合、または事実に一貫性がない場合は、正しくないことを答えるのではなく、その理由を説明してください。
質問の答えを知らない場合は、誤った情報を共有しないでください。"""}

# サイドバー
clear_chat = st.sidebar.button(":sparkles: 新しくチャットを始める", key="clear_chat")

st.sidebar.header("Options")
st.session_state["options"]["temperature"] = st.sidebar.slider("temperature", min_value=0.0, max_value=2.0, step=0.1, value=st.session_state["options"]["temperature"])
st.session_state["options"]["top_k"] = st.sidebar.slider("top_k", min_value=0, max_value=100, step=1, value=st.session_state["options"]["top_k"])
st.session_state["options"]["top_p"] = st.sidebar.slider("top_p", min_value=0.0, max_value=1.0, step=0.1, value=st.session_state["options"]["top_p"])
st.session_state["options"]["repetition_penalty"] = st.sidebar.slider("repetition_penalty", min_value=1.0, max_value=2.0, step=0.01, value=st.session_state["options"]["repetition_penalty"])
st.session_state["options"]["system_prompt"] = st.sidebar.text_area("System Prompt", value=st.session_state["options"]["system_prompt"])

# リセット
if clear_chat:
    st.session_state["messages"] = [] 

# チャット履歴の表示
for message in st.session_state["messages"]:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# 現在のチャット
if user_prompt := st.chat_input("質問を送信してください"):
    with st.chat_message("user"):
        st.text(user_prompt)
    st.session_state["messages"].append({"role": "user", "content": user_prompt})
    token_kength = get_input_token_length(user_query=user_prompt, system_prompt=st.session_state["options"]["system_prompt"], messages=st.session_state["messages"])
    response = f"{token_kength}: " + get_prompt(user_query=user_prompt, system_prompt=st.session_state["options"]["system_prompt"], messages=st.session_state["messages"])
    with st.chat_message("assistant"):
        st.text(response)
    st.session_state["messages"].append({"role": "assistant", "content": user_prompt})