Spaces:
Sleeping
Sleeping
import os | |
import dotenv | |
import sys | |
import gradio as gr | |
import pandas as pd | |
from langchain import LLMChain | |
from langchain.agents import ZeroShotAgent, AgentExecutor, load_tools, Tool | |
from langchain.chat_models import ChatOpenAI | |
from langchain.utilities import GoogleSearchAPIWrapper | |
from langchain.schema.output_parser import OutputParserException | |
from googleapiclient.errors import HttpError | |
dotenv.load_dotenv() | |
OPENAI_API_KEY=os.environ["OPENAI_API_KEY"] | |
GOOGLE_CSE_ID=os.environ["GOOGLE_CSE_ID"] | |
GOOGLE_API_KEY=os.environ["GOOGLE_API_KEY"] | |
def search_and_generate(question, prefix = "次の質問にできる限り答えてください。"): | |
# ツールの準備 | |
search = GoogleSearchAPIWrapper(k=1) | |
tools = load_tools(["google-search"], llm=ChatOpenAI()) | |
# プロンプトテンプレートの準備 | |
prefix = f"{prefix} 次のツールにアクセスできます:" | |
suffix = """始めましょう。 | |
Question: {input} | |
{agent_scratchpad}""" | |
prompt = ZeroShotAgent.create_prompt( | |
tools, | |
prefix=prefix, | |
suffix=suffix, | |
input_variables=["input", "agent_scratchpad"] | |
) | |
# エージェントの準備 | |
llm = ChatOpenAI(model_name="gpt-4") | |
llm_chain = LLMChain(llm=llm, prompt=prompt) | |
agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools) | |
agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True) | |
try: | |
result = agent_executor.run(question) | |
except OutputParserException as e: | |
result = '例外が発生しました: OutputParserException ' + str(e.args[0]) | |
except HttpError as e: | |
print("An HTTP error %d occurred:\n%s" % (e.resp.status, e.content)) | |
# reasonがrateLimitExceededの場合 | |
if "rateLimitExceeded" in e.content.decode(): | |
result = "検索APIリクエストの上限に達しました。" | |
else: | |
result = "検索API実行時に通信エラーが発生しました。" | |
return result | |
def search_and_generate_csv(csv_file): | |
output_csv_file = csv_file.replace(".csv", "_output.csv") | |
questions_df = pd.read_csv(csv_file) | |
# 結果を格納するDataFrameを作成 | |
results_df = pd.DataFrame(columns=["question", "answer"]) | |
for i, row in questions_df.iterrows(): | |
question = "" | |
if i == 0 and "question" not in row: | |
# 1行目はヘッダーなのでquestionがある場合は1行目を無視する | |
continue | |
elif "question" not in row: | |
# questionがない場合は1列目をquestionとして扱う | |
question = row[0] | |
else: | |
# questionがある場合はquestionの列を参照する | |
question = row["question"] | |
# 質問に対する回答を取得 | |
answer = search_and_generate(question) | |
# 結果をDataFrameに追加 | |
results_df.loc[i] = [question, answer] | |
# 結果をCSVファイルに保存 | |
result_csv = results_df.to_csv(output_csv_file, index=False) | |
return result_csv | |
def process_input(str_question, file_csv=None): | |
# file_csvが空でない場合はCSVファイルとして扱う | |
if file_csv: | |
return search_and_generate_csv(file_csv.name) | |
elif str_question: | |
return search_and_generate(str_question) | |
def main(): | |
# パラメータに --csv がある場合、かつ、その次の引数がファイル名の場合 | |
if "--csv" in sys.argv and len(sys.argv) > sys.argv.index("--csv") + 1: | |
csv_file = sys.argv[sys.argv.index("--csv") + 1] | |
# 引数のファイル名からCSVファイルを読み込む | |
result_csv = search_and_generate_csv(csv_file) | |
# csvファイルの内容をprintで表示する | |
print(result_csv) | |
elif len(sys.argv) > 1: | |
question = sys.argv[1] | |
result = search_and_generate(question) | |
print(result) | |
else: | |
gr.Interface(fn=process_input, | |
inputs=[gr.Text()], | |
outputs="text", | |
examples=[ | |
"ChatGPTの最新のアップデートは?", | |
"日下部民藝館に伝わる伝承は?", | |
"箱根に伝わる伝説は?", | |
], | |
).launch() | |
if __name__ == "__main__": | |
main() | |