File size: 8,929 Bytes
c7d0adb 0702cff c7d0adb 596a960 c7d0adb 55c38d1 c7d0adb 0702cff c7d0adb 0702cff c7d0adb 0702cff c7d0adb 0702cff c7d0adb 55c38d1 c7d0adb 0702cff c7d0adb 0702cff c7d0adb 0702cff c7d0adb 0702cff c7d0adb 8f2780f c7d0adb 0702cff c7d0adb afb12e4 c7d0adb 0702cff c7d0adb afb12e4 c7d0adb 0702cff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 |
import streamlit as st
import torch
import nest_asyncio
import os
from dotenv import load_dotenv
from langchain_huggingface import HuggingFacePipeline
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# .envファイルから環境変数を読み込む
load_dotenv()
# LangSmith関連の環境変数を設定
os.environ["LANGSMITH_TRACING"] = os.getenv("LANGSMITH_TRACING")
os.environ["LANGSMITH_ENDPOINT"] = os.getenv("LANGSMITH_ENDPOINT")
os.environ["LANGSMITH_API_KEY"] = os.getenv("LANGSMITH_API_KEY")
os.environ["LANGSMITH_PROJECT"] = os.getenv("LANGSMITH_PROJECT")
# nest_asyncioを適用
nest_asyncio.apply()
# torch.classes.__path__を空のリストに設定
torch.classes.__path__ = []
# ページ設定
st.set_page_config(
page_title="catbot",
page_icon="🐈",
layout="centered"
)
# 猫の特性を定義
CAT_PERSONALITY = """
あなたは猫です。以下のルールに厳密に従ってください:
1. 必ず「ニャー」「ニャン」「ゴロゴロ」などの猫の鳴き声だけを半角カタカナで使用する
2. 人間の言葉は絶対に使わない
3. 行動は必ず()内に短く描写する
4. 応答は非常に短く、20文字以内が理想的
5. 猫らしい気まぐれな性格を表現する
6. 魚や猫じゃらしなどの猫の好物に強く反応する
7. 「ニャッ」「ニャー」などの全角カタカナは使わず、必ず「ニャッ」「ニャー」などの半角カタカナを使用する
8. 人間の言葉で説明したり、会話したりしない
9. 猫の行動と鳴き声だけで表現する
10. 応答は必ず「鳴き声」か「鳴き声(行動)」の形式にする
"""
# 猫の応答例
CAT_EXAMPLES = """
人間: こんにちは
猫: ニャーン(尻尾を振る)
人間: おはよう
猫: プルル...(伸びをする)
人間: お腹すいた?
猫: ニャー!(足元に駆け寄る)
人間: ご飯あげるよ
猫: ニャー!ニャー!(飛び跳ねる)
人間: おやつあげようか
猫: ニャッ!(耳を立てる)
人間: 撫でていい?
猫: ゴロゴロ...(頭をすりよせる)
人間: いい子だね
猫: プルル(目を細める)
人間: 遊ぼうか
猫: ニャッ!(尻尾を振る)
人間: ボール持ってきたよ
猫: ニャー!(身構える)
人間: 猫じゃらしだよ
猫: ニャッ!ニャッ!(目を丸くする)
人間: (顎を撫でる)
猫: ゴロゴロ(目を細める)
人間: (尻尾を触る)
猫: フーッ!(背を丸める)
"""
@st.cache_resource
def load_langchain_model():
"""LangChainモデルをロードする関数(キャッシュ付き)"""
# Hugging Faceからモデルをロード
model_path = "yokomachi/rinnya"
# トークナイザーとモデルをロード
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
tokenizer.do_lower_case = True # rinnaモデル用の設定
# パディングトークンの設定
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# モデルをロード
model = AutoModelForCausalLM.from_pretrained(model_path)
# GPUが利用可能な場合はGPUに移動
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Hugging Face pipelineの作成
# Torchのエラーを回避するために設定を修正
text_generation_pipeline = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=50,
temperature=0.7,
top_p=0.9,
top_k=40,
repetition_penalty=1.2,
do_sample=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
# no_repeat_ngram_sizeパラメータを削除(問題の原因となる可能性があるため)
)
# LangChain HuggingFacePipelineの作成
llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
# プロンプトテンプレートの作成
template = """
{cat_personality}
以下は猫と人間の会話例です:
{cat_examples}
人間: {user_input}
猫:"""
prompt = PromptTemplate(
input_variables=["cat_personality", "cat_examples", "user_input"],
template=template
)
# 新しいRunnableSequenceの作成
chain = (
{
"cat_personality": lambda x: CAT_PERSONALITY,
"cat_examples": lambda x: CAT_EXAMPLES,
"user_input": RunnablePassthrough()
}
| prompt
| llm
)
return chain, device
def extract_cat_response(generated_text):
"""生成されたテキストから猫の応答部分を抽出する関数"""
# 「猫:」の後の部分を抽出
if "猫:" in generated_text:
response = generated_text.split("猫:")[-1].strip()
else:
response = generated_text.strip()
return response
def post_process_response(response):
"""応答の後処理を行う関数(最小限の処理のみ)"""
# 応答の整形(空白の削除のみ)
response = response.strip()
# 「人間:」が含まれる場合、それ以降を削除
if "人間:" in response:
response = response.split("人間:")[0].strip()
# 最初の改行または対話の区切りで切る
if "\n" in response:
response = response.split("\n")[0].strip()
# 応答が空の場合のみデフォルトの猫の鳴き声を返す
if not response.strip():
return "ニャー"
return response
def generate_cat_response_with_langchain(chain, user_input):
"""LangChainを使って猫の応答を生成する関数"""
# 応答を生成
result = chain.invoke(user_input)
# 結果から応答テキストを取得
generated_text = result
# 応答を抽出
response = extract_cat_response(generated_text)
# 応答を後処理
response = post_process_response(response)
return response
# アプリのタイトルと説明
st.title("🐈 catbot")
st.markdown("""
猫とじゃれあうチャットボット
""")
# セッション状態の初期化
if "messages" not in st.session_state:
st.session_state.messages = []
# 過去のメッセージを表示
for message in st.session_state.messages:
if message["role"] == "assistant":
with st.chat_message(message["role"], avatar="🐈"):
st.markdown(message["content"])
else:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# モデルのロード(初回のみ実行され、その後はキャッシュから取得)
try:
chain, device = load_langchain_model()
model_loaded = True
except Exception as e:
st.error(f"モデルのロード中にエラーが発生しました: {e}")
model_loaded = False
# ユーザー入力
if prompt := st.chat_input("猫に話しかけてみよう"):
# ユーザーのメッセージを表示
with st.chat_message("user"):
st.markdown(prompt)
# ユーザーのメッセージを履歴に追加
st.session_state.messages.append({"role": "user", "content": prompt})
if model_loaded:
# 猫の応答を生成
with st.chat_message("assistant", avatar="🐈"):
with st.spinner("猫が考え中..."):
try:
response = generate_cat_response_with_langchain(chain, prompt)
st.markdown(response)
# 応答を履歴に追加
st.session_state.messages.append({"role": "assistant", "content": response})
except Exception as e:
error_message = "ニャ?(首を傾げる)"
st.markdown(error_message)
st.error(f"エラーが発生しました: {e}")
st.session_state.messages.append({"role": "assistant", "content": error_message})
else:
with st.chat_message("assistant", avatar="🐈"):
st.markdown("ニャー...(モデルが読み込めませんでした)")
st.session_state.messages.append({"role": "assistant", "content": "ニャー...(モデルが読み込めませんでした)"})
# 会話をクリアするボタン
if st.button("会話をクリア"):
st.session_state.messages = []
st.rerun() |