|
import os |
|
|
|
import streamlit as st |
|
import torch |
|
from huggingface_hub import login |
|
from peft import PeftModel |
|
from transformers import AutoModelForCausalLM, LlamaTokenizer |
|
|
|
login(token=os.getenv("HUGGINGFACE_API_KEY")) |
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
@st.cache_resource |
|
def load(): |
|
base_model = AutoModelForCausalLM.from_pretrained( |
|
"stabilityai/japanese-stablelm-instruct-alpha-7b", |
|
device_map="auto", |
|
low_cpu_mem_usage=True, |
|
variant="int8", |
|
load_in_8bit=True, |
|
trust_remote_code=True, |
|
) |
|
model = PeftModel.from_pretrained( |
|
base_model, |
|
"lora_adapter", |
|
device_map="auto", |
|
) |
|
tokenizer = LlamaTokenizer.from_pretrained( |
|
"novelai/nerdstash-tokenizer-v1", |
|
additional_special_tokens=['▁▁'] |
|
) |
|
return model, tokenizer |
|
|
|
def get_prompt(user_query, system_prompt, messages="", sep="\n\n### "): |
|
prompt = system_prompt + "\n以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。" |
|
roles = ["指示", "応答"] |
|
msgs = [": \n" + user_query, ": "] |
|
if messages: |
|
roles.insert(1, "入力") |
|
msgs.insert(1, ": \n" + "\n\n".join(message["content"] for message in messages)) |
|
|
|
for role, msg in zip(roles, msgs): |
|
prompt += sep + role + msg |
|
return prompt |
|
|
|
def get_input_token_length(user_query, system_prompt, messages=""): |
|
prompt = get_prompt(user_query, system_prompt, messages) |
|
input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids'] |
|
return input_ids.shape[-1] |
|
|
|
def generate_response(user_query: str, system_prompt: str, messages: str="", temperature: float=0, top_k: int=50, top_p: float=0.95, repetition_penalty: float=1.1): |
|
prompt = get_prompt(user_query, system_prompt, messages) |
|
inputs = tokenizer( |
|
prompt, |
|
add_special_tokens=False, |
|
return_tensors="pt" |
|
).to(device) |
|
max_new_tokens = 2048 - get_input_token_length(user_query, system_prompt, messages) |
|
model.eval() |
|
with torch.no_grad(): |
|
tokens = model.generate( |
|
**inputs, |
|
max_new_tokens=max_new_tokens, |
|
temperature=temperature, |
|
top_k=top_k, |
|
top_p=top_p, |
|
repetition_penalty=repetition_penalty, |
|
) |
|
response = tokenizer.decode(tokens[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip() |
|
return response |
|
|
|
|
|
st.header(":dna: 遺伝カウンセリング対話AI") |
|
|
|
|
|
|
|
model, tokenizer = load() |
|
if "messages" not in st.session_state: |
|
st.session_state["messages"] = [] |
|
if "options" not in st.session_state: |
|
st.session_state["options"] = { |
|
"temperature": 0.0, |
|
"top_k": 50, |
|
"top_p": 0.95, |
|
"repetition_penalty": 1.1, |
|
"system_prompt": """あなたは誠実かつ優秀な遺伝子カウンセリングのカウンセラーです。 |
|
常に安全を考慮し、できる限り有益な回答を心がけてください。 |
|
あなたの回答には、有害、非倫理的、人種差別的、性差別的、有害、危険、違法な内容が含まれてはいけません。 |
|
社会的に偏りのない、前向きな回答を心がけてください。 |
|
質問が意味をなさない場合、または事実に一貫性がない場合は、正しくないことを答えるのではなく、その理由を説明してください。 |
|
質問の答えを知らない場合は、誤った情報を共有しないでください。"""} |
|
|
|
|
|
clear_chat = st.sidebar.button(":sparkles: 新しくチャットを始める", key="clear_chat") |
|
|
|
st.sidebar.header("Options") |
|
st.session_state["options"]["temperature"] = st.sidebar.slider("temperature", min_value=0.0, max_value=2.0, step=0.1, value=st.session_state["options"]["temperature"]) |
|
st.session_state["options"]["top_k"] = st.sidebar.slider("top_k", min_value=0, max_value=100, step=1, value=st.session_state["options"]["top_k"]) |
|
st.session_state["options"]["top_p"] = st.sidebar.slider("top_p", min_value=0.0, max_value=1.0, step=0.1, value=st.session_state["options"]["top_p"]) |
|
st.session_state["options"]["repetition_penalty"] = st.sidebar.slider("repetition_penalty", min_value=1.0, max_value=2.0, step=0.01, value=st.session_state["options"]["repetition_penalty"]) |
|
st.session_state["options"]["system_prompt"] = st.sidebar.text_area("System Prompt", value=st.session_state["options"]["system_prompt"]) |
|
|
|
|
|
if clear_chat: |
|
st.session_state["messages"] = [] |
|
|
|
|
|
for message in st.session_state["messages"]: |
|
with st.chat_message(message["role"]): |
|
st.markdown(message["content"]) |
|
|
|
|
|
if user_prompt := st.chat_input("質問を送信してください"): |
|
with st.chat_message("user"): |
|
st.markdown(user_prompt) |
|
st.session_state["messages"].append({"role": "user", "content": user_prompt}) |
|
response = generate_response( |
|
user_query=user_prompt, |
|
system_prompt=st.session_state["options"]["system_prompt"], |
|
messages=st.session_state["messages"], |
|
temperature=st.session_state["options"]["temperature"], |
|
top_k=st.session_state["options"]["top_k"], |
|
top_p=st.session_state["options"]["top_p"], |
|
repetition_penalty=st.session_state["options"]["repetition_penalty"], |
|
) |
|
with st.chat_message("assistant"): |
|
st.markdown(response) |
|
st.session_state["messages"].append({"role": "assistant", "content": response}) |
|
|
|
|