Update app.py
Browse files
app.py
CHANGED
|
@@ -11,8 +11,6 @@ import chardet
|
|
| 11 |
import pandas as pd
|
| 12 |
import plotly.graph_objs as go
|
| 13 |
|
| 14 |
-
os.environ["GROQ_API_KEY"] = "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"
|
| 15 |
-
|
| 16 |
try:
|
| 17 |
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 18 |
except NameError:
|
|
@@ -68,10 +66,10 @@ def setup(file_path):
|
|
| 68 |
|
| 69 |
return retriever, chain_type_kwargs
|
| 70 |
|
| 71 |
-
def setup_qa(file_path, model_name):
|
| 72 |
retriever, chain_type_kwargs = setup(file_path)
|
| 73 |
llm = ChatGroq(
|
| 74 |
-
groq_api_key=
|
| 75 |
model_name=model_name,
|
| 76 |
)
|
| 77 |
qa = RetrievalQA.from_chain_type(
|
|
@@ -83,20 +81,23 @@ def setup_qa(file_path, model_name):
|
|
| 83 |
)
|
| 84 |
return qa
|
| 85 |
|
| 86 |
-
def chat_with_models(file, model_a, model_b, history_a, history_b, question):
|
| 87 |
if file is None:
|
| 88 |
-
return history_a + [("請上傳文件。", None)], history_b + [("請上傳文件。", None)]
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
file_path = file.name
|
| 91 |
_, extension = os.path.splitext(file_path)
|
| 92 |
|
| 93 |
if extension.lower() not in ['.pdf', '.txt']:
|
| 94 |
error_message = "只能上傳PDF或TXT檔案,不接受其他的格式。"
|
| 95 |
-
return history_a + [(error_message, None)], history_b + [(error_message, None)]
|
| 96 |
|
| 97 |
try:
|
| 98 |
-
qa_a = setup_qa(file_path, model_a)
|
| 99 |
-
qa_b = setup_qa(file_path, model_b)
|
| 100 |
|
| 101 |
response_a = qa_a.invoke(question)
|
| 102 |
response_b = qa_b.invoke(question)
|
|
@@ -104,10 +105,10 @@ def chat_with_models(file, model_a, model_b, history_a, history_b, question):
|
|
| 104 |
history_a.append((question, response_a["result"]))
|
| 105 |
history_b.append((question, response_b["result"]))
|
| 106 |
|
| 107 |
-
return history_a, history_b
|
| 108 |
except Exception as e:
|
| 109 |
error_message = f"遇到錯誤:{str(e)}"
|
| 110 |
-
return history_a + [(error_message, None)], history_b + [(error_message, None)]
|
| 111 |
|
| 112 |
def load_or_create_df():
|
| 113 |
if os.path.exists(CSV_PATH):
|
|
@@ -154,6 +155,7 @@ def create_demo():
|
|
| 154 |
model_a = gr.Dropdown(choices=models, label="Model A", value=models[0])
|
| 155 |
model_b = gr.Dropdown(choices=models, label="Model B", value=models[-1])
|
| 156 |
|
|
|
|
| 157 |
file_input = gr.File(label="Upload PDF or TXT file")
|
| 158 |
|
| 159 |
with gr.Row():
|
|
@@ -176,8 +178,8 @@ def create_demo():
|
|
| 176 |
|
| 177 |
send_btn.click(
|
| 178 |
fn=chat_with_models,
|
| 179 |
-
inputs=[file_input, model_a, model_b, chat_a, chat_b, question],
|
| 180 |
-
outputs=[chat_a, chat_b],
|
| 181 |
)
|
| 182 |
|
| 183 |
def create_evaluate_fn(eval_type):
|
|
|
|
| 11 |
import pandas as pd
|
| 12 |
import plotly.graph_objs as go
|
| 13 |
|
|
|
|
|
|
|
| 14 |
try:
|
| 15 |
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 16 |
except NameError:
|
|
|
|
| 66 |
|
| 67 |
return retriever, chain_type_kwargs
|
| 68 |
|
| 69 |
+
def setup_qa(file_path, model_name, api_key):
|
| 70 |
retriever, chain_type_kwargs = setup(file_path)
|
| 71 |
llm = ChatGroq(
|
| 72 |
+
groq_api_key=api_key,
|
| 73 |
model_name=model_name,
|
| 74 |
)
|
| 75 |
qa = RetrievalQA.from_chain_type(
|
|
|
|
| 81 |
)
|
| 82 |
return qa
|
| 83 |
|
| 84 |
+
def chat_with_models(file, model_a, model_b, api_key, history_a, history_b, question):
|
| 85 |
if file is None:
|
| 86 |
+
return history_a + [("請上傳文件。", None)], history_b + [("請上傳文件。", None)], ""
|
| 87 |
+
|
| 88 |
+
if not api_key:
|
| 89 |
+
return history_a + [("請輸入 API 金鑰。", None)], history_b + [("請輸入 API 金鑰。", None)], ""
|
| 90 |
|
| 91 |
file_path = file.name
|
| 92 |
_, extension = os.path.splitext(file_path)
|
| 93 |
|
| 94 |
if extension.lower() not in ['.pdf', '.txt']:
|
| 95 |
error_message = "只能上傳PDF或TXT檔案,不接受其他的格式。"
|
| 96 |
+
return history_a + [(error_message, None)], history_b + [(error_message, None)], ""
|
| 97 |
|
| 98 |
try:
|
| 99 |
+
qa_a = setup_qa(file_path, model_a, api_key)
|
| 100 |
+
qa_b = setup_qa(file_path, model_b, api_key)
|
| 101 |
|
| 102 |
response_a = qa_a.invoke(question)
|
| 103 |
response_b = qa_b.invoke(question)
|
|
|
|
| 105 |
history_a.append((question, response_a["result"]))
|
| 106 |
history_b.append((question, response_b["result"]))
|
| 107 |
|
| 108 |
+
return history_a, history_b, ""
|
| 109 |
except Exception as e:
|
| 110 |
error_message = f"遇到錯誤:{str(e)}"
|
| 111 |
+
return history_a + [(error_message, None)], history_b + [(error_message, None)], ""
|
| 112 |
|
| 113 |
def load_or_create_df():
|
| 114 |
if os.path.exists(CSV_PATH):
|
|
|
|
| 155 |
model_a = gr.Dropdown(choices=models, label="Model A", value=models[0])
|
| 156 |
model_b = gr.Dropdown(choices=models, label="Model B", value=models[-1])
|
| 157 |
|
| 158 |
+
api_key = gr.Textbox(label="Enter your Groq API Key", type="password")
|
| 159 |
file_input = gr.File(label="Upload PDF or TXT file")
|
| 160 |
|
| 161 |
with gr.Row():
|
|
|
|
| 178 |
|
| 179 |
send_btn.click(
|
| 180 |
fn=chat_with_models,
|
| 181 |
+
inputs=[file_input, model_a, model_b, api_key, chat_a, chat_b, question],
|
| 182 |
+
outputs=[chat_a, chat_b, question],
|
| 183 |
)
|
| 184 |
|
| 185 |
def create_evaluate_fn(eval_type):
|