Spaces:
Runtime error
Runtime error
shigeru saito
Merge branch 'main' of https://huggingface.co/spaces/shigel/langchain-function-calling
629f973
# 必要なモジュールをインポート | |
import gradio as gr | |
import os | |
import sys | |
import json | |
import csv | |
import dotenv | |
import openai | |
from langchain.chat_models import ChatOpenAI | |
from langchain.agents import initialize_agent, Tool | |
from langchain.schema import ( | |
AIMessage, | |
AgentAction, | |
HumanMessage, | |
FunctionMessage | |
) | |
from langchain.chat_models import ChatOpenAI | |
from langchain.agents import AgentType | |
# .envファイルから環境変数をロード | |
dotenv.load_dotenv(".env") | |
# OpenAIキーをosモジュールで取得 | |
openai.api_key = os.environ.get("OPENAI_API_KEY") | |
# 民間伝承を取得する関数 | |
def fetch_folklore(location): | |
folklore_lookup = {} | |
# CSVファイルからデータを読み取り、地点をキー、伝承を値とする辞書を作成 | |
with open('folklore.csv', 'r') as f: | |
reader = csv.DictReader(f) | |
folklore_lookup = {row['location']: row['folklore'] for row in reader} | |
type_lookup = {row['type']: row['folklore'] for row in reader} | |
# 指定された地点の伝承などを返す。存在しない場合は不明を返す。 | |
folklore = folklore_lookup.get((location), f"その地域の伝承は不明です。") | |
type = type_lookup.get((location), f"その地域の伝承は不明です。") | |
print("type:", type) | |
return folklore | |
def serialize_agent_action(obj): | |
if isinstance(obj, AgentAction): | |
return { "tool": obj.tool, "tool_input": obj.tool_input, "log": obj.log} | |
if isinstance(obj, _FunctionsAgentAction): | |
return { "tool": obj.tool, "tool_input": obj.tool_input, "log": obj.log, "message_log": obj.message_log} | |
if isinstance(obj, AIMessage): | |
return { "content": obj.content, "additional_kwargs": obj.additional_kwargs, "example": obj.example} | |
raise TypeError(f"Type {type(obj)} not serializable") | |
# LangChainエージェントからレスポンスを取得する関数 | |
def get_response_from_lang_chain_agent(query_text): | |
# ChatOpenAIを使用して言語モデルを初期化 | |
language_model = ChatOpenAI(model_name='gpt-3.5-turbo-0613') | |
tools = [ | |
# 民間伝承を取得するToolを作成 | |
Tool( | |
name="Folklore", | |
func=fetch_folklore, | |
description="伝承を知りたい施設や地名を入力。例: 箱根", | |
) | |
] | |
# エージェントを初期化してから応答を取得 | |
agent = initialize_agent(tools, language_model, agent="zero-shot-react-description", | |
verbose=True, return_intermediate_steps=True) | |
response = agent({"input": query_text}) | |
print(type(response)) | |
response = json.dumps(response, default=serialize_agent_action, indent=2, ensure_ascii=False) | |
return response | |
# Function Callingからレスポンスを取得する関数 | |
def get_response_from_function_calling(query_text): | |
function_definitions = [ | |
# 関数の定義を作成 | |
{ | |
"name": "fetch_folklore", | |
"description": "伝承を調べる", | |
"parameters": { | |
"type": "object", | |
"properties": { | |
"location": { | |
"description": "伝承を知りたい施設や地名。例: 箱根", | |
}, | |
}, | |
"required": ["location"], | |
}, | |
} | |
] | |
messages = [HumanMessage(content=query_text)] | |
language_model = ChatOpenAI(model_name='gpt-4') | |
# 言語モデルを使ってメッセージを予測 | |
message = language_model.predict_messages( | |
messages, functions=function_definitions) | |
if message.additional_kwargs: | |
# 関数の名前と引数を取得 | |
function_name = message.additional_kwargs["function_call"]["name"] | |
arguments = message.additional_kwargs["function_call"]["arguments"] | |
# JSON 文字列を辞書に変換 | |
arguments = json.loads(arguments) | |
location=arguments.get("location") | |
# type=arguments.get("type") | |
# 関数を実行してレスポンスを取得 | |
function_response = fetch_folklore(location=location) | |
# 関数メッセージを作成 | |
function_message = FunctionMessage( | |
name=function_name, content=function_response) | |
# 関数のレスポンスをメッセージに追加して予測 | |
messages.append(function_message) | |
second_response = language_model.predict_messages( | |
messages=messages, functions=function_definitions) | |
content = second_response.content | |
else: | |
content = message.content | |
return content | |
# Function Call Agentからレスポンスを取得する関数 | |
def get_response_from_function_calling_agent(query_text): | |
language_model = ChatOpenAI(model_name='gpt-3.5-turbo-0613') | |
tools = [ | |
# 民間伝承情報を提供するツールの追加 | |
Tool( | |
name="Folklore", | |
func=fetch_folklore, | |
description="伝承を知りたい施設や地名を入力。例: 箱根" | |
) | |
] | |
# エージェントの初期化とレスポンスの取得 | |
agent = initialize_agent(tools, language_model, agent=AgentType.OPENAI_FUNCTIONS, | |
verbose=True, return_intermediate_steps=True) | |
response = agent({"input": query_text}) | |
response = json.dumps(response, default=serialize_agent_action, indent=2, ensure_ascii=False) | |
return response | |
# メインの実行部分 | |
def main(query_text, function_name="all"): | |
response1 = "" | |
response2 = "" | |
response3 = "" | |
if function_name == "all" or function_name == "langchain": | |
# LangChainエージェントからのレスポンス | |
response1 = get_response_from_lang_chain_agent(query_text) | |
print(response1) | |
if function_name == "all" or function_name == "functioncalling": | |
# Function Callingからのレスポンス | |
response2 = get_response_from_function_calling(query_text) | |
print(response2) | |
if function_name == "all" or function_name == "functioncallingagent": | |
# Function Callingエージェントからのレスポンス | |
response3 = get_response_from_function_calling_agent(query_text) | |
print(response3) | |
return response1, response2, response3 | |
# スクリプトが直接実行された場合にmain()を実行 | |
if __name__ == "__main__": | |
if len(sys.argv) == 2: | |
query_text = sys.argv[1] | |
main(query_text=query_text) | |
elif len(sys.argv) > 2: | |
query_text = sys.argv[1] | |
function_name = sys.argv[2] | |
main(query_text=query_text, function_name=function_name) | |
else: | |
import time | |
# インプット例をクリックした時のコールバック関数 | |
def click_example(example): | |
# クリックされたインプット例をテキストボックスに自動入力 | |
inputs.value = example | |
time.sleep(0.1) # テキストボックスに文字が表示されるまで待機 | |
# 自動入力後に実行ボタンをクリックして結果を表示 | |
execute_button.click() | |
# gr.Interface()を使ってユーザーインターフェースを作成します | |
# gr.Text()はテキスト入力ボックスを作成し、 | |
# gr.Textbox()は出力テキストを表示するためのテキストボックスを作成します。 | |
iface = gr.Interface( | |
fn=main, | |
examples=[ | |
["葛飾区の伝承を教えてください。"], | |
["千代田区にはどんな伝承がありますか?"], | |
["江戸川区で有名な伝承?"], | |
], | |
inputs=gr.Textbox( | |
lines=5, placeholder="質問を入力してください"), | |
outputs=[ | |
gr.Textbox(label="LangChain Agentのレスポンス"), | |
gr.Textbox(label="Function Callingのレスポンス"), | |
gr.Textbox(label="Function Calling Agentのレスポンス") | |
], | |
title="日本各地の伝承AI (東京23区版)", | |
description="最新のGPTモデルを使用し、LangChain, Function Calling, Function Calling + LangChain Agentの対話モデルのAIから回答を取得するシステムです。以下のインプット例をクリックすると入力欄に自動入力されます。", | |
example_columns=3, | |
example_callback=click_example | |
) | |
# インターフェースを起動します | |
iface.launch() | |