shigeru saito
requestsの削除
f4695fc
import os
import dotenv
import sys
import gradio as gr
import pandas as pd
from langchain import LLMChain
from langchain.agents import ZeroShotAgent, AgentExecutor, load_tools, Tool
from langchain.chat_models import ChatOpenAI
from langchain.utilities import GoogleSearchAPIWrapper
from langchain.schema.output_parser import OutputParserException
from googleapiclient.errors import HttpError
dotenv.load_dotenv()
OPENAI_API_KEY=os.environ["OPENAI_API_KEY"]
GOOGLE_CSE_ID=os.environ["GOOGLE_CSE_ID"]
GOOGLE_API_KEY=os.environ["GOOGLE_API_KEY"]
def search_and_generate(question, prefix = "次の質問にできる限り答えてください。"):
# ツールの準備
search = GoogleSearchAPIWrapper(k=1)
tools = load_tools(["google-search"], llm=ChatOpenAI())
# プロンプトテンプレートの準備
prefix = f"{prefix} 次のツールにアクセスできます:"
suffix = """始めましょう。
Question: {input}
{agent_scratchpad}"""
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
input_variables=["input", "agent_scratchpad"]
)
# エージェントの準備
llm = ChatOpenAI(model_name="gpt-4")
llm_chain = LLMChain(llm=llm, prompt=prompt)
agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools)
agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
try:
result = agent_executor.run(question)
except OutputParserException as e:
result = '例外が発生しました: OutputParserException ' + str(e.args[0])
except HttpError as e:
print("An HTTP error %d occurred:\n%s" % (e.resp.status, e.content))
# reasonがrateLimitExceededの場合
if "rateLimitExceeded" in e.content.decode():
result = "検索APIリクエストの上限に達しました。"
else:
result = "検索API実行時に通信エラーが発生しました。"
return result
def search_and_generate_csv(csv_file):
output_csv_file = csv_file.replace(".csv", "_output.csv")
questions_df = pd.read_csv(csv_file)
# 結果を格納するDataFrameを作成
results_df = pd.DataFrame(columns=["question", "answer"])
for i, row in questions_df.iterrows():
question = ""
if i == 0 and "question" not in row:
# 1行目はヘッダーなのでquestionがある場合は1行目を無視する
continue
elif "question" not in row:
# questionがない場合は1列目をquestionとして扱う
question = row[0]
else:
# questionがある場合はquestionの列を参照する
question = row["question"]
# 質問に対する回答を取得
answer = search_and_generate(question)
# 結果をDataFrameに追加
results_df.loc[i] = [question, answer]
# 結果をCSVファイルに保存
result_csv = results_df.to_csv(output_csv_file, index=False)
return result_csv
def process_input(str_question, file_csv=None):
# file_csvが空でない場合はCSVファイルとして扱う
if file_csv:
return search_and_generate_csv(file_csv.name)
elif str_question:
return search_and_generate(str_question)
def main():
# パラメータに --csv がある場合、かつ、その次の引数がファイル名の場合
if "--csv" in sys.argv and len(sys.argv) > sys.argv.index("--csv") + 1:
csv_file = sys.argv[sys.argv.index("--csv") + 1]
# 引数のファイル名からCSVファイルを読み込む
result_csv = search_and_generate_csv(csv_file)
# csvファイルの内容をprintで表示する
print(result_csv)
elif len(sys.argv) > 1:
question = sys.argv[1]
result = search_and_generate(question)
print(result)
else:
gr.Interface(fn=process_input,
inputs=[gr.Text()],
outputs="text",
examples=[
"ChatGPTの最新のアップデートは?",
"日下部民藝館に伝わる伝承は?",
"箱根に伝わる伝説は?",
],
).launch()
if __name__ == "__main__":
main()