Spaces:
Runtime error
Runtime error
shigeru saito
Merge branch 'main' of https://huggingface.co/spaces/shigel/langchain-function-calling
629f973
| # 必要なモジュールをインポート | |
| import gradio as gr | |
| import os | |
| import sys | |
| import json | |
| import csv | |
| import dotenv | |
| import openai | |
| from langchain.chat_models import ChatOpenAI | |
| from langchain.agents import initialize_agent, Tool | |
| from langchain.schema import ( | |
| AIMessage, | |
| AgentAction, | |
| HumanMessage, | |
| FunctionMessage | |
| ) | |
| from langchain.chat_models import ChatOpenAI | |
| from langchain.agents import AgentType | |
| # .envファイルから環境変数をロード | |
| dotenv.load_dotenv(".env") | |
| # OpenAIキーをosモジュールで取得 | |
| openai.api_key = os.environ.get("OPENAI_API_KEY") | |
| # 民間伝承を取得する関数 | |
| def fetch_folklore(location): | |
| folklore_lookup = {} | |
| # CSVファイルからデータを読み取り、地点をキー、伝承を値とする辞書を作成 | |
| with open('folklore.csv', 'r') as f: | |
| reader = csv.DictReader(f) | |
| folklore_lookup = {row['location']: row['folklore'] for row in reader} | |
| type_lookup = {row['type']: row['folklore'] for row in reader} | |
| # 指定された地点の伝承などを返す。存在しない場合は不明を返す。 | |
| folklore = folklore_lookup.get((location), f"その地域の伝承は不明です。") | |
| type = type_lookup.get((location), f"その地域の伝承は不明です。") | |
| print("type:", type) | |
| return folklore | |
| def serialize_agent_action(obj): | |
| if isinstance(obj, AgentAction): | |
| return { "tool": obj.tool, "tool_input": obj.tool_input, "log": obj.log} | |
| if isinstance(obj, _FunctionsAgentAction): | |
| return { "tool": obj.tool, "tool_input": obj.tool_input, "log": obj.log, "message_log": obj.message_log} | |
| if isinstance(obj, AIMessage): | |
| return { "content": obj.content, "additional_kwargs": obj.additional_kwargs, "example": obj.example} | |
| raise TypeError(f"Type {type(obj)} not serializable") | |
| # LangChainエージェントからレスポンスを取得する関数 | |
| def get_response_from_lang_chain_agent(query_text): | |
| # ChatOpenAIを使用して言語モデルを初期化 | |
| language_model = ChatOpenAI(model_name='gpt-3.5-turbo-0613') | |
| tools = [ | |
| # 民間伝承を取得するToolを作成 | |
| Tool( | |
| name="Folklore", | |
| func=fetch_folklore, | |
| description="伝承を知りたい施設や地名を入力。例: 箱根", | |
| ) | |
| ] | |
| # エージェントを初期化してから応答を取得 | |
| agent = initialize_agent(tools, language_model, agent="zero-shot-react-description", | |
| verbose=True, return_intermediate_steps=True) | |
| response = agent({"input": query_text}) | |
| print(type(response)) | |
| response = json.dumps(response, default=serialize_agent_action, indent=2, ensure_ascii=False) | |
| return response | |
| # Function Callingからレスポンスを取得する関数 | |
| def get_response_from_function_calling(query_text): | |
| function_definitions = [ | |
| # 関数の定義を作成 | |
| { | |
| "name": "fetch_folklore", | |
| "description": "伝承を調べる", | |
| "parameters": { | |
| "type": "object", | |
| "properties": { | |
| "location": { | |
| "description": "伝承を知りたい施設や地名。例: 箱根", | |
| }, | |
| }, | |
| "required": ["location"], | |
| }, | |
| } | |
| ] | |
| messages = [HumanMessage(content=query_text)] | |
| language_model = ChatOpenAI(model_name='gpt-4') | |
| # 言語モデルを使ってメッセージを予測 | |
| message = language_model.predict_messages( | |
| messages, functions=function_definitions) | |
| if message.additional_kwargs: | |
| # 関数の名前と引数を取得 | |
| function_name = message.additional_kwargs["function_call"]["name"] | |
| arguments = message.additional_kwargs["function_call"]["arguments"] | |
| # JSON 文字列を辞書に変換 | |
| arguments = json.loads(arguments) | |
| location=arguments.get("location") | |
| # type=arguments.get("type") | |
| # 関数を実行してレスポンスを取得 | |
| function_response = fetch_folklore(location=location) | |
| # 関数メッセージを作成 | |
| function_message = FunctionMessage( | |
| name=function_name, content=function_response) | |
| # 関数のレスポンスをメッセージに追加して予測 | |
| messages.append(function_message) | |
| second_response = language_model.predict_messages( | |
| messages=messages, functions=function_definitions) | |
| content = second_response.content | |
| else: | |
| content = message.content | |
| return content | |
| # Function Call Agentからレスポンスを取得する関数 | |
| def get_response_from_function_calling_agent(query_text): | |
| language_model = ChatOpenAI(model_name='gpt-3.5-turbo-0613') | |
| tools = [ | |
| # 民間伝承情報を提供するツールの追加 | |
| Tool( | |
| name="Folklore", | |
| func=fetch_folklore, | |
| description="伝承を知りたい施設や地名を入力。例: 箱根" | |
| ) | |
| ] | |
| # エージェントの初期化とレスポンスの取得 | |
| agent = initialize_agent(tools, language_model, agent=AgentType.OPENAI_FUNCTIONS, | |
| verbose=True, return_intermediate_steps=True) | |
| response = agent({"input": query_text}) | |
| response = json.dumps(response, default=serialize_agent_action, indent=2, ensure_ascii=False) | |
| return response | |
| # メインの実行部分 | |
| def main(query_text, function_name="all"): | |
| response1 = "" | |
| response2 = "" | |
| response3 = "" | |
| if function_name == "all" or function_name == "langchain": | |
| # LangChainエージェントからのレスポンス | |
| response1 = get_response_from_lang_chain_agent(query_text) | |
| print(response1) | |
| if function_name == "all" or function_name == "functioncalling": | |
| # Function Callingからのレスポンス | |
| response2 = get_response_from_function_calling(query_text) | |
| print(response2) | |
| if function_name == "all" or function_name == "functioncallingagent": | |
| # Function Callingエージェントからのレスポンス | |
| response3 = get_response_from_function_calling_agent(query_text) | |
| print(response3) | |
| return response1, response2, response3 | |
| # スクリプトが直接実行された場合にmain()を実行 | |
| if __name__ == "__main__": | |
| if len(sys.argv) == 2: | |
| query_text = sys.argv[1] | |
| main(query_text=query_text) | |
| elif len(sys.argv) > 2: | |
| query_text = sys.argv[1] | |
| function_name = sys.argv[2] | |
| main(query_text=query_text, function_name=function_name) | |
| else: | |
| import time | |
| # インプット例をクリックした時のコールバック関数 | |
| def click_example(example): | |
| # クリックされたインプット例をテキストボックスに自動入力 | |
| inputs.value = example | |
| time.sleep(0.1) # テキストボックスに文字が表示されるまで待機 | |
| # 自動入力後に実行ボタンをクリックして結果を表示 | |
| execute_button.click() | |
| # gr.Interface()を使ってユーザーインターフェースを作成します | |
| # gr.Text()はテキスト入力ボックスを作成し、 | |
| # gr.Textbox()は出力テキストを表示するためのテキストボックスを作成します。 | |
| iface = gr.Interface( | |
| fn=main, | |
| examples=[ | |
| ["葛飾区の伝承を教えてください。"], | |
| ["千代田区にはどんな伝承がありますか?"], | |
| ["江戸川区で有名な伝承?"], | |
| ], | |
| inputs=gr.Textbox( | |
| lines=5, placeholder="質問を入力してください"), | |
| outputs=[ | |
| gr.Textbox(label="LangChain Agentのレスポンス"), | |
| gr.Textbox(label="Function Callingのレスポンス"), | |
| gr.Textbox(label="Function Calling Agentのレスポンス") | |
| ], | |
| title="日本各地の伝承AI (東京23区版)", | |
| description="最新のGPTモデルを使用し、LangChain, Function Calling, Function Calling + LangChain Agentの対話モデルのAIから回答を取得するシステムです。以下のインプット例をクリックすると入力欄に自動入力されます。", | |
| example_columns=3, | |
| example_callback=click_example | |
| ) | |
| # インターフェースを起動します | |
| iface.launch() | |