File size: 4,439 Bytes
cc4cb18
 
 
 
e87d7b4
cc4cb18
f56a350
 
 
 
e87d7b4
 
cc4cb18
 
 
 
 
 
 
f56a350
cc4cb18
56dd884
f4695fc
cc4cb18
 
f56a350
 
cc4cb18
 
 
 
 
 
 
 
 
 
 
 
56dd884
f56a350
cc4cb18
 
 
e87d7b4
 
 
 
56dd884
 
 
 
 
 
 
cc4cb18
 
e87d7b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56dd884
e87d7b4
56dd884
 
 
 
e87d7b4
cc4cb18
e87d7b4
 
 
 
 
 
 
 
 
 
 
cc4cb18
f56a350
cc4cb18
 
e87d7b4
56dd884
e87d7b4
 
56dd884
 
 
e87d7b4
 
cc4cb18
 
f56a350
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import dotenv
import sys
import gradio as gr
import pandas as pd

from langchain import LLMChain
from langchain.agents import ZeroShotAgent, AgentExecutor, load_tools, Tool
from langchain.chat_models import ChatOpenAI
from langchain.utilities import GoogleSearchAPIWrapper
from langchain.schema.output_parser import OutputParserException
from googleapiclient.errors import HttpError

dotenv.load_dotenv()

OPENAI_API_KEY=os.environ["OPENAI_API_KEY"]
GOOGLE_CSE_ID=os.environ["GOOGLE_CSE_ID"]
GOOGLE_API_KEY=os.environ["GOOGLE_API_KEY"]

def search_and_generate(question, prefix = "次の質問にできる限り答えてください。"):
    # ツールの準備
    search = GoogleSearchAPIWrapper(k=1)
    tools = load_tools(["google-search"], llm=ChatOpenAI())

    # プロンプトテンプレートの準備
    prefix = f"{prefix} 次のツールにアクセスできます:"
    suffix = """始めましょう。

    Question: {input}
    {agent_scratchpad}"""

    prompt = ZeroShotAgent.create_prompt(
        tools,
        prefix=prefix,
        suffix=suffix,
        input_variables=["input", "agent_scratchpad"]
    )

    # エージェントの準備
    llm = ChatOpenAI(model_name="gpt-4")
    llm_chain = LLMChain(llm=llm, prompt=prompt)
    agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools)
    agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)

    try:
        result = agent_executor.run(question)
    except OutputParserException as e:
        result = '例外が発生しました: OutputParserException ' + str(e.args[0])
    except HttpError as e:
        print("An HTTP error %d occurred:\n%s" % (e.resp.status, e.content))
        # reasonがrateLimitExceededの場合
        if "rateLimitExceeded" in e.content.decode():
            result = "検索APIリクエストの上限に達しました。"
        else:
            result = "検索API実行時に通信エラーが発生しました。"
    return result

def search_and_generate_csv(csv_file):
    
    output_csv_file = csv_file.replace(".csv", "_output.csv")
    questions_df = pd.read_csv(csv_file)
    
    # 結果を格納するDataFrameを作成
    results_df = pd.DataFrame(columns=["question", "answer"])
    
    for i, row in questions_df.iterrows():
        
        question = ""
        if i == 0 and "question" not in row:
            # 1行目はヘッダーなのでquestionがある場合は1行目を無視する
            continue
        elif "question" not in row:
            # questionがない場合は1列目をquestionとして扱う
            question = row[0]
        else:
            # questionがある場合はquestionの列を参照する
            question = row["question"]
        
        # 質問に対する回答を取得
        answer = search_and_generate(question)
        
        # 結果をDataFrameに追加
        results_df.loc[i] = [question, answer]
        
    # 結果をCSVファイルに保存
    result_csv = results_df.to_csv(output_csv_file, index=False)
    
    return result_csv

def process_input(str_question, file_csv=None):
    # file_csvが空でない場合はCSVファイルとして扱う
    if file_csv:
        return search_and_generate_csv(file_csv.name)
    elif str_question:
        return search_and_generate(str_question)

def main():
    # パラメータに --csv がある場合、かつ、その次の引数がファイル名の場合
    if "--csv" in sys.argv and len(sys.argv) > sys.argv.index("--csv") + 1:
        csv_file = sys.argv[sys.argv.index("--csv") + 1]
        
        # 引数のファイル名からCSVファイルを読み込む
        result_csv = search_and_generate_csv(csv_file)
        
        # csvファイルの内容をprintで表示する
        print(result_csv)
        
    elif len(sys.argv) > 1:
        question = sys.argv[1]
        result = search_and_generate(question)
        print(result)
    else:
        gr.Interface(fn=process_input, 
                 inputs=[gr.Text()], 
                 outputs="text",
                 examples=[
                            "ChatGPTの最新のアップデートは?",
                            "日下部民藝館に伝わる伝承は?",
                            "箱根に伝わる伝説は?",
                        ],
                 ).launch()

if __name__ == "__main__":
    main()