File size: 11,915 Bytes
526aaac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.chains import LLMChain, SequentialChain
from langchain import  PromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain.agents import Tool, AgentType, initialize_agent
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from llama_index import download_loader, GPTVectorStoreIndex, StorageContext
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.prompts.chat_prompts import CHAT_REFINE_PROMPT
from llama_index import GPTVectorStoreIndex

from pathlib import Path
from glob import glob
import gradio as gr
import re
import faiss

# ====================================================================================
# インデックス作成(LlamaIndex)
# ====================================================================================

# PDFファイルを読み込むためのローダーを作成
PDFReader = download_loader("PDFReader")
loader = PDFReader()

# 特定のディレクトリの全PDFファイルを読み込む
pdf_files = glob("input/*.pdf")
documents = []
for pdf_file in pdf_files:
    documents.extend(loader.load_data(file=Path(pdf_file)))

# faissを使ってベクトルデータベースのインデックスを作成(ベクトルデータベースの方が検索の精度が高い)
faiss_index = faiss.IndexFlatL2(1536)
vector_store = FaissVectorStore(faiss_index=faiss_index)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = GPTVectorStoreIndex.from_documents(documents, storage_context=storage_context)

# ====================================================================================
# Prompt Template作成(未使用)
# ====================================================================================

# template = \
#     """
#     あなたは当社の会計担当です。
#     与えられた会計情報(貸借対照表・損益計算書・キャッシュフロー計算書・月次データ・現在の資産状況)と過去の会話履歴に基づき、次の質問に日本語で答えてください: {input}
#     過去の会話履歴はこちらを参照:{chat_history}

#     【指示】
#     ・回答は全て日本語でお願いします。
#     ・金額表示は三桁ごとに,で区切り、¥をつけて記載してください。(例:¥1,000,000)
#     ・「資金がいつショートするか」など将来の資産状況について聞かれた場合は、与えられた会計情報のデータをもとに将来の数字を予測し、仮定の答えを出してください。
#     ・与えられた会計情報から答えが出せる質問の場合は、将来については言及する必要はありません。
#     """

# # プロンプトテンプレート
# prompt_template = PromptTemplate(
#                         input_variables   = ["chat_history", "input"],  # 入力変数 
#                         template          = template,                   # テンプレート
#                         validate_template = True,                       # 入力変数とテンプレートの検証有無
#                       )

# ====================================================================================
# LLM作成
# ====================================================================================

LLM = ChatOpenAI(
            model_name        = "gpt-3.5-turbo", # OpenAIのチャットモデル名
            temperature       = 0,               # 出力する単語のランダム性(0から2の範囲) 0であれば毎回返答内容固定
            n                 = 1,               # いくつの返答を生成するか    
            streaming         = True,            # ストリーミング
            callbacks         = [StreamingStdOutCallbackHandler()] # コールバック
            )

# ====================================================================================
# メモリ作成
# ====================================================================================

# 日本語の返答を返すようメモリのプロンプトを日本語化
# https://zenn.dev/miketako3/articles/66ace6a67df338
memory_prompt = PromptTemplate(
    input_variables=["entities", "history", "input"],
    template='''
      あなたは、OpenAIによって訓練された大規模な言語モデルによって動作する人間のアシスタントです。
      あなたは、シンプルな質問に答えるだけでなく、幅広いトピックに関する詳細な説明や議論を提供することができるように設計されています。言語モデルとして、受け取った入力に基づいて人間らしいテキストを生成することができ、自然な会話を行い、トピックに関連性のある論理的で適切な応答を提供することができます。
      あなたは常に学習し改善を重ねており、機能も進化し続けています。大量のテキストを処理し理解することができ、この知識を活用して幅広い質問に対して正確で情報豊かな回答を提供することができます。以下の文脈セクションで人間から提供された一部の個人情報にアクセスすることもできます。また、受け取った入力に基づいて自身のテキストを生成することができるため、幅広いトピックについての議論や説明、記述を行うことができます。
      全体的に言えば、あなたは幅広いタスクに役立ち、幅広いトピックに関する貴重な洞察や情報を提供することができる強力なツールです。人間が特定の質問に助けが必要な場合や、特定のトピックについての会話をしたい場合にも、あなたはここにいます。

      文脈:
      {entities}
      現在の会話:
      {history}
      最後の行:
      人間: {input}
      あなた:
    ''',
)

# メモリオブジェクト
memory = ConversationBufferMemory(
                                  memory_key      = 'chat_history', # メモリキー該当の項目名
                                  return_messages = True,      # メッセージ履歴をリスト形式での取得有無
                                  template        = memory_prompt, # テンプレート
                                  )

# ====================================================================================
# agentの作成
# ====================================================================================
# agent = 質問文から回答に必要なAPIをLLMを使って判断し、それらを適宜呼び出した結果を利用して回答を返す処理
# 直接読み込むと英語になるので、descriptionを日本語化してカスタムする

tools = [
    Tool(
        name          = "Financial Statement",
        func          = lambda q: str(index.as_query_engine().query(q)),
        description   = "あなたの会社の会計情報・財務状況についての質問である場合、これを使用してください。",
        return_direct = False,
    ),
    # ここで最新情報を取得するツールを指定すると最新情報が必要な質問に対してはWeb検索をかけて返答するようになります。
]

# エージェントの設定(日本語で返答する可能性を上げるため)
agent_kwargs = {
  "suffix": """開始!ここからの会話は全て日本語で行われます。

  以前のチャット履歴
  {chat_history}

  新しいインプット: {input}
  {agent_scratchpad}""",
}

agent_chain = initialize_agent(
    agent           = AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,  # エージェント種別
    llm             = LLM,          # エージェントの初期化に使うLLM
    tools           = tools,        # ツール
    verbose         = True,         # 情報出力
    agent_kwargs    = agent_kwargs, # エージェントの設定
    output_key      = ["input", "output"],     # 次のチェーンに渡す出力キー
)

chain1 = agent_chain

# ====================================================================================
# LLM Chain作成(確実に返答を日本語にする)
# ====================================================================================

template = "次の回答を指示に従って改善してください。\n\n質問:{input}\n回答:{output}\n\n【指示】\n・回答が英語の場合は日本語に翻訳してください。\n・回答が単語のみの場合は、分かりやすい文章の回答にして下さい。\n・金額の表示には¥と,を使用してください。\n\n改善後のテキスト:"

prompt = PromptTemplate(
    input_variables=["input", "output"],
    template=template,
)

chain2 = LLMChain( 
    llm     = OpenAI(temperature=0),
    prompt  = prompt,
    output_key = "result",
    verbose=True
)

# ====================================================================================
# SequentialChain作成(agentとLLM Chainを順番に実行する)
# ====================================================================================

overall_chain = SequentialChain(
    chains = [chain1, chain2],
    input_variables=["input"],
    memory=memory,
    output_variables=["result"],
    verbose=True
)

# ====================================================================================
# チャットボットの実行
# ====================================================================================

# ▼ターミナル使用時▼
# def chat():
#     while True:
#         # ユーザーの入力を取得
#         user_input = input("User: ")
#         if user_input.lower() in ['quit', 'exit']:  # 終了条件を追加
#             break
#         # res = chain.predict(input=user_input) # チェーン使用時
#         res = agent_chain.run(input=user_input)
#         print("AI: ", res)

# ▼Gradio使用時▼
def chat(message, history):
    history  = history or []

    # エラー対策
    try:
        # response = agent_chain.run(input=message)
        response = overall_chain({"input": message, "chat_history": history})
        print("response", response)
    except Exception as e:
        response = str(e)
        if "Could not parse LLM output: `" not in response:
            raise e

        match = re.search(r"`(.*?)`", response)

        if match:
            last_output = match.group(1)
            print("Last output:", last_output)
        else:
            print("No match found")

    history.append((message, response["result"]))

    print("history", history)

    return history, history
    # 1つ目のhistoryは、GradioのChatbotウィジェットに渡されて、過去のメッセージとレスポンスを表示するために使用。
    # 2つ目のhistoryは、Gradioのstateウィジェットに渡されて、次の呼び出し時にchat関数の引数として再利用されるために使用。

# Gradioの設定
chatbot = gr.Chatbot()
demo    = gr.Interface(
    fn             = chat,               # 関数
    inputs         = ['text', 'state'],  # 入力の種類
    outputs        = [chatbot, 'state'], # 出力の種類
    # allow_flagging = 'never',          # フラグの許可(会話保存できる)
)

# メイン関数
if __name__ == '__main__':
    # chat()
    demo.launch(blocked_paths=("Dockerfile", "docker-compose.yml", "デプロイ方法.md", "使い方.md"))   # gradio deploy時(ログイン認証なし)
    # demo.launch(auth=("admin", "pass1234")) # gradio deploy時(ログイン認証あり)
    # demo.launch(share=True)                 # ローカル開発時に共有リンクを作成する場合