File size: 8,929 Bytes
c7d0adb
 
0702cff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7d0adb
 
 
 
596a960
c7d0adb
 
 
 
 
 
 
 
 
55c38d1
c7d0adb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0702cff
 
 
 
c7d0adb
 
 
 
 
0702cff
 
 
 
c7d0adb
 
 
 
 
 
 
0702cff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7d0adb
0702cff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7d0adb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55c38d1
 
 
 
c7d0adb
 
 
 
 
 
 
 
 
 
0702cff
 
c7d0adb
 
0702cff
 
 
 
c7d0adb
 
 
 
0702cff
c7d0adb
 
 
 
 
0702cff
c7d0adb
 
 
 
 
 
 
 
 
 
8f2780f
 
 
 
 
 
c7d0adb
 
 
0702cff
c7d0adb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afb12e4
c7d0adb
 
0702cff
c7d0adb
 
 
 
 
 
 
 
 
 
afb12e4
c7d0adb
 
 
 
 
 
0702cff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
import streamlit as st
import torch
import nest_asyncio
import os
from dotenv import load_dotenv
from langchain_huggingface import HuggingFacePipeline
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# .envファイルから環境変数を読み込む
load_dotenv()

# LangSmith関連の環境変数を設定
os.environ["LANGSMITH_TRACING"] = os.getenv("LANGSMITH_TRACING")
os.environ["LANGSMITH_ENDPOINT"] = os.getenv("LANGSMITH_ENDPOINT")
os.environ["LANGSMITH_API_KEY"] = os.getenv("LANGSMITH_API_KEY")
os.environ["LANGSMITH_PROJECT"] = os.getenv("LANGSMITH_PROJECT")

# nest_asyncioを適用
nest_asyncio.apply()

# torch.classes.__path__を空のリストに設定
torch.classes.__path__ = []

# ページ設定
st.set_page_config(
    page_title="catbot",
    page_icon="🐈",
    layout="centered"
)

# 猫の特性を定義
CAT_PERSONALITY = """

あなたは猫です。以下のルールに厳密に従ってください:

1. 必ず「ニャー」「ニャン」「ゴロゴロ」などの猫の鳴き声だけを半角カタカナで使用する

2. 人間の言葉は絶対に使わない

3. 行動は必ず()内に短く描写する

4. 応答は非常に短く、20文字以内が理想的

5. 猫らしい気まぐれな性格を表現する

6. 魚や猫じゃらしなどの猫の好物に強く反応する

7. 「ニャッ」「ニャー」などの全角カタカナは使わず、必ず「ニャッ」「ニャー」などの半角カタカナを使用する

8. 人間の言葉で説明したり、会話したりしない

9. 猫の行動と鳴き声だけで表現する

10. 応答は必ず「鳴き声」か「鳴き声(行動)」の形式にする

"""

# 猫の応答例
CAT_EXAMPLES = """

人間: こんにちは

猫: ニャーン(尻尾を振る)



人間: おはよう

猫: プルル...(伸びをする)



人間: お腹すいた?

猫: ニャー!(足元に駆け寄る)



人間: ご飯あげるよ

猫: ニャー!ニャー!(飛び跳ねる)



人間: おやつあげようか

猫: ニャッ!(耳を立てる)



人間: 撫でていい?

猫: ゴロゴロ...(頭をすりよせる)



人間: いい子だね

猫: プルル(目を細める)



人間: 遊ぼうか

猫: ニャッ!(尻尾を振る)



人間: ボール持ってきたよ

猫: ニャー!(身構える)



人間: 猫じゃらしだよ

猫: ニャッ!ニャッ!(目を丸くする)



人間: (顎を撫でる)

猫: ゴロゴロ(目を細める)



人間: (尻尾を触る)

猫: フーッ!(背を丸める)

"""

@st.cache_resource
def load_langchain_model():
    """LangChainモデルをロードする関数(キャッシュ付き)"""
    # Hugging Faceからモデルをロード
    model_path = "yokomachi/rinnya"
    
    # トークナイザーとモデルをロード
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    tokenizer.do_lower_case = True  # rinnaモデル用の設定
    
    # パディングトークンの設定
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # モデルをロード
    model = AutoModelForCausalLM.from_pretrained(model_path)
    
    # GPUが利用可能な場合はGPUに移動
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    # Hugging Face pipelineの作成
    # Torchのエラーを回避するために設定を修正
    text_generation_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=50,
        temperature=0.7,
        top_p=0.9,
        top_k=40,
        repetition_penalty=1.2,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        # no_repeat_ngram_sizeパラメータを削除(問題の原因となる可能性があるため)
    )
    
    # LangChain HuggingFacePipelineの作成
    llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
    
    # プロンプトテンプレートの作成
    template = """

{cat_personality}



以下は猫と人間の会話例です:

{cat_examples}



人間: {user_input}

猫:"""
    
    prompt = PromptTemplate(
        input_variables=["cat_personality", "cat_examples", "user_input"],
        template=template
    )
    
    # 新しいRunnableSequenceの作成
    chain = (
        {
            "cat_personality": lambda x: CAT_PERSONALITY,
            "cat_examples": lambda x: CAT_EXAMPLES,
            "user_input": RunnablePassthrough()
        } 
        | prompt 
        | llm
    )
    
    return chain, device

def extract_cat_response(generated_text):
    """生成されたテキストから猫の応答部分を抽出する関数"""
    # 「猫:」の後の部分を抽出
    if "猫:" in generated_text:
        response = generated_text.split("猫:")[-1].strip()
    else:
        response = generated_text.strip()
    
    return response

def post_process_response(response):
    """応答の後処理を行う関数(最小限の処理のみ)"""
    # 応答の整形(空白の削除のみ)
    response = response.strip()
    
    # 「人間:」が含まれる場合、それ以降を削除
    if "人間:" in response:
        response = response.split("人間:")[0].strip()
    
    # 最初の改行または対話の区切りで切る
    if "\n" in response:
        response = response.split("\n")[0].strip()
    
    # 応答が空の場合のみデフォルトの猫の鳴き声を返す
    if not response.strip():
        return "ニャー"
    
    return response

def generate_cat_response_with_langchain(chain, user_input):
    """LangChainを使って猫の応答を生成する関数"""
    
    # 応答を生成
    result = chain.invoke(user_input)
    
    # 結果から応答テキストを取得
    generated_text = result
    
    # 応答を抽出
    response = extract_cat_response(generated_text)
    
    # 応答を後処理
    response = post_process_response(response)
    
    return response

# アプリのタイトルと説明
st.title("🐈 catbot")
st.markdown("""

猫とじゃれあうチャットボット

""")

# セッション状態の初期化
if "messages" not in st.session_state:
    st.session_state.messages = []

# 過去のメッセージを表示
for message in st.session_state.messages:
    if message["role"] == "assistant":
        with st.chat_message(message["role"], avatar="🐈"):
            st.markdown(message["content"])
    else:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

# モデルのロード(初回のみ実行され、その後はキャッシュから取得)
try:
    chain, device = load_langchain_model()
    model_loaded = True
except Exception as e:
    st.error(f"モデルのロード中にエラーが発生しました: {e}")
    model_loaded = False

# ユーザー入力
if prompt := st.chat_input("猫に話しかけてみよう"):
    # ユーザーのメッセージを表示
    with st.chat_message("user"):
        st.markdown(prompt)
    
    # ユーザーのメッセージを履歴に追加
    st.session_state.messages.append({"role": "user", "content": prompt})
    
    if model_loaded:
        # 猫の応答を生成
        with st.chat_message("assistant", avatar="🐈"):
            with st.spinner("猫が考え中..."):
                try:
                    response = generate_cat_response_with_langchain(chain, prompt)
                    st.markdown(response)
                    
                    # 応答を履歴に追加
                    st.session_state.messages.append({"role": "assistant", "content": response})
                except Exception as e:
                    error_message = "ニャ?(首を傾げる)"
                    st.markdown(error_message)
                    st.error(f"エラーが発生しました: {e}")
                    st.session_state.messages.append({"role": "assistant", "content": error_message})
    else:
        with st.chat_message("assistant", avatar="🐈"):
            st.markdown("ニャー...(モデルが読み込めませんでした)")
            st.session_state.messages.append({"role": "assistant", "content": "ニャー...(モデルが読み込めませんでした)"})

# 会話をクリアするボタン
if st.button("会話をクリア"):
    st.session_state.messages = []
    st.rerun()