shigeru saito commited on
Commit
4d4f912
·
1 Parent(s): 04b2d94

CLIとGradioの共存

Browse files
Files changed (1) hide show
  1. app.py +69 -60
app.py CHANGED
@@ -1,4 +1,5 @@
1
  # 必要なモジュールをインポート
 
2
  import os
3
  import sys
4
  import json
@@ -27,8 +28,8 @@ def fetch_folklore(location):
27
  # CSVファイルからデータを読み取り、地点をキー、伝承を値とする辞書を作成
28
  with open('folklore.csv', 'r') as f:
29
  reader = csv.DictReader(f)
30
- folklore_lookup = {row['location']:row['folklore'] for row in reader}
31
-
32
  # 指定された地点の伝承を返す。存在しない場合は不明を返す。
33
  return folklore_lookup.get(location, "その地域の伝承は不明です。")
34
 
@@ -39,17 +40,18 @@ def get_response_from_lang_chain_agent(query_text):
39
  tools = [
40
  # 民間伝承を取得するToolを作成
41
  Tool(
42
- name = "Folklore",
43
  func=fetch_folklore,
44
- description="伝承を知りたい施設や地名を入力。例: 箱根寄木細工",
45
  )
46
  ]
47
  # エージェントを初期化してから応答を取得
48
- agent = initialize_agent(tools, language_model, agent="zero-shot-react-description", verbose=True, return_intermediate_steps=True)
 
49
  response = agent({"input": query_text})
50
  return response
51
 
52
- # 関数呼び出しからレスポンスを取得する関数
53
  def get_response_from_function_calling(query_text):
54
  function_definitions = [
55
  # 関数の定義を作成
@@ -61,33 +63,36 @@ def get_response_from_function_calling(query_text):
61
  "properties": {
62
  "location": {
63
  "type": "string",
64
- "description": "伝承を知りたい施設や地名。例: 箱根寄木細工",
65
  },
66
  },
67
  "required": ["location"],
68
  },
69
  }
70
  ]
71
- messages=[HumanMessage(content=query_text)]
72
  language_model = ChatOpenAI(model_name='gpt-3.5-turbo-0613')
73
  # 言語モデルを使ってメッセージを予測
74
- message = language_model.predict_messages(messages, functions=function_definitions)
75
-
 
76
  if message.additional_kwargs:
77
  # 関数の名前と引数を取得
78
  function_name = message.additional_kwargs["function_call"]["name"]
79
  arguments = message.additional_kwargs["function_call"]["arguments"]
80
-
81
  # JSON 文字列を辞書に変換
82
  arguments = json.loads(arguments)
83
-
84
  # 関数を実行してレスポンスを取得
85
  function_response = fetch_folklore(location=arguments.get("location"))
86
  # 関数メッセージを作成
87
- function_message = FunctionMessage(name=function_name, content=function_response)
 
88
  # 関数のレスポンスをメッセージに追加して予測
89
  messages.append(function_message)
90
- second_response = language_model.predict_messages(messages=messages, functions=function_definitions)
 
91
  content = "AIの回答: " + second_response.content
92
  else:
93
  content = "AIの回答: " + message.content
@@ -99,63 +104,67 @@ def get_response_from_function_calling_agent(query_text):
99
  tools = [
100
  # 民間伝承情報を提供するツールの追加
101
  Tool(
102
- name = "Folklore",
103
  func=fetch_folklore,
104
- description="伝承を知りたい施設や地名を入力。例: 箱根寄木細工"
105
  )
106
  ]
107
  # エージェントの初期化とレスポンスの取得
108
- agent = initialize_agent(tools, language_model, agent=AgentType.OPENAI_FUNCTIONS, verbose=True, return_intermediate_steps=True)
 
109
  response = agent({"input": query_text})
110
  return response
111
 
112
  # メインの実行部分
113
- def main(query_text):
114
- print(f"Input: {query_text}")
115
-
116
- # LangChainエージェントからのレスポンス
117
- response = get_response_from_lang_chain_agent(query_text)
118
- print(f"Output of LangChain Agent: {response}")
119
-
120
- # 関数呼び出しからのレスポンス
121
- response = get_response_from_function_calling(query_text)
122
- print(f"Output of Function Calling: {response}")
123
-
124
- # Function Callingエージェントからのレスポンス
125
- response = get_response_from_function_calling_agent(query_text)
126
- print(f"Output of Function Calling Agent: {response}")
127
 
128
- import gradio as gr
129
 
130
- def main(query_text):
131
- # LangChainエージェントからのレスポンス
132
- response1 = get_response_from_lang_chain_agent(query_text)
133
 
134
- # 関数呼び出しからのレスポンス
135
- response2 = get_response_from_function_calling(query_text)
136
-
137
- # Function Callingエージェントからのレスポンス
138
- response3 = get_response_from_function_calling_agent(query_text)
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  return response1, response2, response3
141
 
142
- # gr.Interface()を使ってユーザーインターフェースを作成します
143
- # gr.inputs.Text()はテキスト入力ボックスを作成し、
144
- # gr.outputs.Textbox()は出力テキストを表示するためのテキストボックスを作成します。
145
- iface = gr.Interface(fn=main, inputs=gr.inputs.Textbox(lines=5, placeholder="質問を入力してください"),
146
- outputs=[gr.outputs.Textbox(label="LangChain Agentのレスポンス"),
147
- gr.outputs.Textbox(label="関数呼び出しのレスポンス"),
148
- gr.outputs.Textbox(label="Function Calling Agentのレスポンス")])
149
-
150
- # インターフェースを起動します
151
- iface.launch()
152
-
153
- # # スクリプトが直接実行された場合にmain()を実行
154
- # if __name__ == "__main__":
155
- # try:
156
- # query_text = sys.argv[1]
157
- # except IndexError:
158
- # print("使い方: python app.py '質問の文字列'")
159
- # print("  例) python app.py '醍醐寺の伝承を教えて'")
160
- # sys.exit()
161
- # main(query_text=query_text)
 
 
 
 
 
 
 
 
1
  # 必要なモジュールをインポート
2
+ import gradio as gr
3
  import os
4
  import sys
5
  import json
 
28
  # CSVファイルからデータを読み取り、地点をキー、伝承を値とする辞書を作成
29
  with open('folklore.csv', 'r') as f:
30
  reader = csv.DictReader(f)
31
+ folklore_lookup = {row['location']: row['folklore'] for row in reader}
32
+
33
  # 指定された地点の伝承を返す。存在しない場合は不明を返す。
34
  return folklore_lookup.get(location, "その地域の伝承は不明です。")
35
 
 
40
  tools = [
41
  # 民間伝承を取得するToolを作成
42
  Tool(
43
+ name="Folklore",
44
  func=fetch_folklore,
45
+ description="伝承を知りたい施設や地名を入力。例: 箱根",
46
  )
47
  ]
48
  # エージェントを初期化してから応答を取得
49
+ agent = initialize_agent(tools, language_model, agent="zero-shot-react-description",
50
+ verbose=True, return_intermediate_steps=True)
51
  response = agent({"input": query_text})
52
  return response
53
 
54
+ # Function Callingからレスポンスを取得する関数
55
  def get_response_from_function_calling(query_text):
56
  function_definitions = [
57
  # 関数の定義を作成
 
63
  "properties": {
64
  "location": {
65
  "type": "string",
66
+ "description": "伝承を知りたい施設や地名。例: 箱根",
67
  },
68
  },
69
  "required": ["location"],
70
  },
71
  }
72
  ]
73
+ messages = [HumanMessage(content=query_text)]
74
  language_model = ChatOpenAI(model_name='gpt-3.5-turbo-0613')
75
  # 言語モデルを使ってメッセージを予測
76
+ message = language_model.predict_messages(
77
+ messages, functions=function_definitions)
78
+
79
  if message.additional_kwargs:
80
  # 関数の名前と引数を取得
81
  function_name = message.additional_kwargs["function_call"]["name"]
82
  arguments = message.additional_kwargs["function_call"]["arguments"]
83
+
84
  # JSON 文字列を辞書に変換
85
  arguments = json.loads(arguments)
86
+
87
  # 関数を実行してレスポンスを取得
88
  function_response = fetch_folklore(location=arguments.get("location"))
89
  # 関数メッセージを作成
90
+ function_message = FunctionMessage(
91
+ name=function_name, content=function_response)
92
  # 関数のレスポンスをメッセージに追加して予測
93
  messages.append(function_message)
94
+ second_response = language_model.predict_messages(
95
+ messages=messages, functions=function_definitions)
96
  content = "AIの回答: " + second_response.content
97
  else:
98
  content = "AIの回答: " + message.content
 
104
  tools = [
105
  # 民間伝承情報を提供するツールの追加
106
  Tool(
107
+ name="Folklore",
108
  func=fetch_folklore,
109
+ description="伝承を知りたい施設や地名を入力。例: 箱根"
110
  )
111
  ]
112
  # エージェントの初期化とレスポンスの取得
113
+ agent = initialize_agent(tools, language_model, agent=AgentType.OPENAI_FUNCTIONS,
114
+ verbose=True, return_intermediate_steps=True)
115
  response = agent({"input": query_text})
116
  return response
117
 
118
  # メインの実行部分
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
 
120
 
121
+ def main(query_text, function_name="all"):
 
 
122
 
123
+ response1 = ""
124
+ response2 = ""
125
+ response3 = ""
 
 
126
 
127
+ if function_name == "all" or function_name == "langchain":
128
+ # LangChainエージェントからのレスポンス
129
+ response1 = get_response_from_lang_chain_agent(query_text)
130
+ print(response1)
131
+
132
+ if function_name == "all" or function_name == "functioncalling":
133
+ # Function Callingからのレスポンス
134
+ response2 = get_response_from_function_calling(query_text)
135
+ print(response2)
136
+
137
+ if function_name == "all" or function_name == "functioncallingagent":
138
+ # Function Callingエージェントからのレスポンス
139
+ response3 = get_response_from_function_calling_agent(query_text)
140
+ print(response3)
141
+
142
  return response1, response2, response3
143
 
144
+
145
+ # スクリプトが直接実行された場合にmain()を実行
146
+ if __name__ == "__main__":
147
+ if len(sys.argv) == 2:
148
+ query_text = sys.argv[1]
149
+ main(query_text=query_text)
150
+ elif len(sys.argv) > 2:
151
+ query_text = sys.argv[1]
152
+ function_name = sys.argv[2]
153
+ main(query_text=query_text, function_name=function_name)
154
+ else:
155
+ # gr.Interface()を使ってユーザーインターフェースを作成します
156
+ # gr.Text()はテキスト入力ボックスを作成し、
157
+ # gr.Textbox()は出力テキストを表示するためのテキストボックスを作成します。
158
+ iface = gr.Interface(
159
+ fn=main,
160
+ inputs=gr.Textbox(
161
+ lines=5, placeholder="質問を入力してください"),
162
+ outputs=[
163
+ gr.Textbox(label="LangChain Agentのレスポンス"),
164
+ gr.Textbox(label="Function Callingのレスポンス"),
165
+ gr.Textbox(label="Function Calling Agentのレスポンス")
166
+ ]
167
+ )
168
+
169
+ # インターフェースを起動します
170
+ iface.launch()