Spaces:
Sleeping
Sleeping
File size: 4,439 Bytes
cc4cb18 e87d7b4 cc4cb18 f56a350 e87d7b4 cc4cb18 f56a350 cc4cb18 56dd884 f4695fc cc4cb18 f56a350 cc4cb18 56dd884 f56a350 cc4cb18 e87d7b4 56dd884 cc4cb18 e87d7b4 56dd884 e87d7b4 56dd884 e87d7b4 cc4cb18 e87d7b4 cc4cb18 f56a350 cc4cb18 e87d7b4 56dd884 e87d7b4 56dd884 e87d7b4 cc4cb18 f56a350 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import os
import dotenv
import sys
import gradio as gr
import pandas as pd
from langchain import LLMChain
from langchain.agents import ZeroShotAgent, AgentExecutor, load_tools, Tool
from langchain.chat_models import ChatOpenAI
from langchain.utilities import GoogleSearchAPIWrapper
from langchain.schema.output_parser import OutputParserException
from googleapiclient.errors import HttpError
dotenv.load_dotenv()
OPENAI_API_KEY=os.environ["OPENAI_API_KEY"]
GOOGLE_CSE_ID=os.environ["GOOGLE_CSE_ID"]
GOOGLE_API_KEY=os.environ["GOOGLE_API_KEY"]
def search_and_generate(question, prefix = "次の質問にできる限り答えてください。"):
# ツールの準備
search = GoogleSearchAPIWrapper(k=1)
tools = load_tools(["google-search"], llm=ChatOpenAI())
# プロンプトテンプレートの準備
prefix = f"{prefix} 次のツールにアクセスできます:"
suffix = """始めましょう。
Question: {input}
{agent_scratchpad}"""
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
input_variables=["input", "agent_scratchpad"]
)
# エージェントの準備
llm = ChatOpenAI(model_name="gpt-4")
llm_chain = LLMChain(llm=llm, prompt=prompt)
agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools)
agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
try:
result = agent_executor.run(question)
except OutputParserException as e:
result = '例外が発生しました: OutputParserException ' + str(e.args[0])
except HttpError as e:
print("An HTTP error %d occurred:\n%s" % (e.resp.status, e.content))
# reasonがrateLimitExceededの場合
if "rateLimitExceeded" in e.content.decode():
result = "検索APIリクエストの上限に達しました。"
else:
result = "検索API実行時に通信エラーが発生しました。"
return result
def search_and_generate_csv(csv_file):
output_csv_file = csv_file.replace(".csv", "_output.csv")
questions_df = pd.read_csv(csv_file)
# 結果を格納するDataFrameを作成
results_df = pd.DataFrame(columns=["question", "answer"])
for i, row in questions_df.iterrows():
question = ""
if i == 0 and "question" not in row:
# 1行目はヘッダーなのでquestionがある場合は1行目を無視する
continue
elif "question" not in row:
# questionがない場合は1列目をquestionとして扱う
question = row[0]
else:
# questionがある場合はquestionの列を参照する
question = row["question"]
# 質問に対する回答を取得
answer = search_and_generate(question)
# 結果をDataFrameに追加
results_df.loc[i] = [question, answer]
# 結果をCSVファイルに保存
result_csv = results_df.to_csv(output_csv_file, index=False)
return result_csv
def process_input(str_question, file_csv=None):
# file_csvが空でない場合はCSVファイルとして扱う
if file_csv:
return search_and_generate_csv(file_csv.name)
elif str_question:
return search_and_generate(str_question)
def main():
# パラメータに --csv がある場合、かつ、その次の引数がファイル名の場合
if "--csv" in sys.argv and len(sys.argv) > sys.argv.index("--csv") + 1:
csv_file = sys.argv[sys.argv.index("--csv") + 1]
# 引数のファイル名からCSVファイルを読み込む
result_csv = search_and_generate_csv(csv_file)
# csvファイルの内容をprintで表示する
print(result_csv)
elif len(sys.argv) > 1:
question = sys.argv[1]
result = search_and_generate(question)
print(result)
else:
gr.Interface(fn=process_input,
inputs=[gr.Text()],
outputs="text",
examples=[
"ChatGPTの最新のアップデートは?",
"日下部民藝館に伝わる伝承は?",
"箱根に伝わる伝説は?",
],
).launch()
if __name__ == "__main__":
main()
|