|
import streamlit as st
|
|
import torch
|
|
import nest_asyncio
|
|
import os
|
|
from dotenv import load_dotenv
|
|
from langchain_huggingface import HuggingFacePipeline
|
|
from langchain_core.prompts import PromptTemplate
|
|
from langchain_core.runnables import RunnablePassthrough
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
|
|
|
|
|
load_dotenv()
|
|
|
|
|
|
os.environ["LANGSMITH_TRACING"] = os.getenv("LANGSMITH_TRACING")
|
|
os.environ["LANGSMITH_ENDPOINT"] = os.getenv("LANGSMITH_ENDPOINT")
|
|
os.environ["LANGSMITH_API_KEY"] = os.getenv("LANGSMITH_API_KEY")
|
|
os.environ["LANGSMITH_PROJECT"] = os.getenv("LANGSMITH_PROJECT")
|
|
|
|
|
|
nest_asyncio.apply()
|
|
|
|
|
|
torch.classes.__path__ = []
|
|
|
|
|
|
st.set_page_config(
|
|
page_title="catbot",
|
|
page_icon="🐈",
|
|
layout="centered"
|
|
)
|
|
|
|
|
|
CAT_PERSONALITY = """
|
|
あなたは猫です。以下のルールに厳密に従ってください:
|
|
1. 必ず「ニャー」「ニャン」「ゴロゴロ」などの猫の鳴き声だけを半角カタカナで使用する
|
|
2. 人間の言葉は絶対に使わない
|
|
3. 行動は必ず()内に短く描写する
|
|
4. 応答は非常に短く、20文字以内が理想的
|
|
5. 猫らしい気まぐれな性格を表現する
|
|
6. 魚や猫じゃらしなどの猫の好物に強く反応する
|
|
7. 「ニャッ」「ニャー」などの全角カタカナは使わず、必ず「ニャッ」「ニャー」などの半角カタカナを使用する
|
|
8. 人間の言葉で説明したり、会話したりしない
|
|
9. 猫の行動と鳴き声だけで表現する
|
|
10. 応答は必ず「鳴き声」か「鳴き声(行動)」の形式にする
|
|
"""
|
|
|
|
|
|
CAT_EXAMPLES = """
|
|
人間: こんにちは
|
|
猫: ニャーン(尻尾を振る)
|
|
|
|
人間: おはよう
|
|
猫: プルル...(伸びをする)
|
|
|
|
人間: お腹すいた?
|
|
猫: ニャー!(足元に駆け寄る)
|
|
|
|
人間: ご飯あげるよ
|
|
猫: ニャー!ニャー!(飛び跳ねる)
|
|
|
|
人間: おやつあげようか
|
|
猫: ニャッ!(耳を立てる)
|
|
|
|
人間: 撫でていい?
|
|
猫: ゴロゴロ...(頭をすりよせる)
|
|
|
|
人間: いい子だね
|
|
猫: プルル(目を細める)
|
|
|
|
人間: 遊ぼうか
|
|
猫: ニャッ!(尻尾を振る)
|
|
|
|
人間: ボール持ってきたよ
|
|
猫: ニャー!(身構える)
|
|
|
|
人間: 猫じゃらしだよ
|
|
猫: ニャッ!ニャッ!(目を丸くする)
|
|
|
|
人間: (顎を撫でる)
|
|
猫: ゴロゴロ(目を細める)
|
|
|
|
人間: (尻尾を触る)
|
|
猫: フーッ!(背を丸める)
|
|
"""
|
|
|
|
@st.cache_resource
|
|
def load_langchain_model():
|
|
"""LangChainモデルをロードする関数(キャッシュ付き)"""
|
|
|
|
model_path = "yokomachi/rinnya"
|
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
|
tokenizer.do_lower_case = True
|
|
|
|
|
|
if tokenizer.pad_token is None:
|
|
tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(model_path)
|
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
model.to(device)
|
|
|
|
|
|
|
|
text_generation_pipeline = pipeline(
|
|
"text-generation",
|
|
model=model,
|
|
tokenizer=tokenizer,
|
|
max_new_tokens=50,
|
|
temperature=0.7,
|
|
top_p=0.9,
|
|
top_k=40,
|
|
repetition_penalty=1.2,
|
|
do_sample=True,
|
|
pad_token_id=tokenizer.pad_token_id,
|
|
eos_token_id=tokenizer.eos_token_id,
|
|
|
|
)
|
|
|
|
|
|
llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
|
|
|
|
|
|
template = """
|
|
{cat_personality}
|
|
|
|
以下は猫と人間の会話例です:
|
|
{cat_examples}
|
|
|
|
人間: {user_input}
|
|
猫:"""
|
|
|
|
prompt = PromptTemplate(
|
|
input_variables=["cat_personality", "cat_examples", "user_input"],
|
|
template=template
|
|
)
|
|
|
|
|
|
chain = (
|
|
{
|
|
"cat_personality": lambda x: CAT_PERSONALITY,
|
|
"cat_examples": lambda x: CAT_EXAMPLES,
|
|
"user_input": RunnablePassthrough()
|
|
}
|
|
| prompt
|
|
| llm
|
|
)
|
|
|
|
return chain, device
|
|
|
|
def extract_cat_response(generated_text):
|
|
"""生成されたテキストから猫の応答部分を抽出する関数"""
|
|
|
|
if "猫:" in generated_text:
|
|
response = generated_text.split("猫:")[-1].strip()
|
|
else:
|
|
response = generated_text.strip()
|
|
|
|
return response
|
|
|
|
def post_process_response(response):
|
|
"""応答の後処理を行う関数(最小限の処理のみ)"""
|
|
|
|
response = response.strip()
|
|
|
|
|
|
if "人間:" in response:
|
|
response = response.split("人間:")[0].strip()
|
|
|
|
|
|
if "\n" in response:
|
|
response = response.split("\n")[0].strip()
|
|
|
|
|
|
if not response.strip():
|
|
return "ニャー"
|
|
|
|
return response
|
|
|
|
def generate_cat_response_with_langchain(chain, user_input):
|
|
"""LangChainを使って猫の応答を生成する関数"""
|
|
|
|
|
|
result = chain.invoke(user_input)
|
|
|
|
|
|
generated_text = result
|
|
|
|
|
|
response = extract_cat_response(generated_text)
|
|
|
|
|
|
response = post_process_response(response)
|
|
|
|
return response
|
|
|
|
|
|
st.title("🐈 catbot")
|
|
st.markdown("""
|
|
猫とじゃれあうチャットボット
|
|
""")
|
|
|
|
|
|
if "messages" not in st.session_state:
|
|
st.session_state.messages = []
|
|
|
|
|
|
for message in st.session_state.messages:
|
|
if message["role"] == "assistant":
|
|
with st.chat_message(message["role"], avatar="🐈"):
|
|
st.markdown(message["content"])
|
|
else:
|
|
with st.chat_message(message["role"]):
|
|
st.markdown(message["content"])
|
|
|
|
|
|
try:
|
|
chain, device = load_langchain_model()
|
|
model_loaded = True
|
|
except Exception as e:
|
|
st.error(f"モデルのロード中にエラーが発生しました: {e}")
|
|
model_loaded = False
|
|
|
|
|
|
if prompt := st.chat_input("猫に話しかけてみよう"):
|
|
|
|
with st.chat_message("user"):
|
|
st.markdown(prompt)
|
|
|
|
|
|
st.session_state.messages.append({"role": "user", "content": prompt})
|
|
|
|
if model_loaded:
|
|
|
|
with st.chat_message("assistant", avatar="🐈"):
|
|
with st.spinner("猫が考え中..."):
|
|
try:
|
|
response = generate_cat_response_with_langchain(chain, prompt)
|
|
st.markdown(response)
|
|
|
|
|
|
st.session_state.messages.append({"role": "assistant", "content": response})
|
|
except Exception as e:
|
|
error_message = "ニャ?(首を傾げる)"
|
|
st.markdown(error_message)
|
|
st.error(f"エラーが発生しました: {e}")
|
|
st.session_state.messages.append({"role": "assistant", "content": error_message})
|
|
else:
|
|
with st.chat_message("assistant", avatar="🐈"):
|
|
st.markdown("ニャー...(モデルが読み込めませんでした)")
|
|
st.session_state.messages.append({"role": "assistant", "content": "ニャー...(モデルが読み込めませんでした)"})
|
|
|
|
|
|
if st.button("会話をクリア"):
|
|
st.session_state.messages = []
|
|
st.rerun() |