# Spaces: Running
# -*- coding: utf-8 -*- | |
# 財政部財政資訊中心 江信宗 | |
# pip install langchain transformers langchain-groq chromadb langchain-community langchain-huggingface gradio python-dotenv sentence-transformers beautifulsoup4
import os | |
from dotenv import load_dotenv | |
load_dotenv() | |
os.environ["LANGCHAIN_COMMUNITY__USER_AGENT"] = "Taiwan_Tax_KB (Colab)" | |
from langchain_community.utils import user_agent | |
from langchain_groq import ChatGroq | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from langchain_community.vectorstores import Chroma | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.chains import RetrievalQA | |
from langchain_community.document_loaders import WebBaseLoader, TextLoader | |
from langchain.prompts import PromptTemplate | |
from langchain.schema import Document | |
import gradio as gr | |
def initialize_llm(api_key, model_name="llama-3.1-70b-versatile"):
    """Create a Groq chat-model client for one request.

    Args:
        api_key: Groq API key entered by the user in the UI. Surrounding
            whitespace (common when pasting) is ignored.
        model_name: Groq model identifier. Defaults to the model this app
            was built with.
            NOTE(review): Groq has retired some Llama 3.1 models; confirm
            this one is still served, or pass a current model name.

    Returns:
        A ``ChatGroq`` instance authenticated with *api_key*.

    Raises:
        ValueError: if *api_key* is ``None``, empty, or only whitespace.
    """
    key = (api_key or "").strip()
    if not key:
        # Fail fast with a clear message instead of an opaque auth error from
        # the first LLM call; answer_question shows this text to the user.
        raise ValueError("請先輸入 API Key")
    # The key is handed to the client explicitly and deliberately NOT written
    # to os.environ. This is a shared multi-user web app. A process-wide env
    # var would keep the last user's key around after the request, where it
    # is inherited by subprocesses and visible to other code, and concurrent
    # sessions would overwrite each other's value.
    return ChatGroq(
        groq_api_key=key,
        model_name=model_name,
    )
# Startup progress message: at this point initialize_llm has been defined;
# the actual ChatGroq client is created per request when the function is called.
# NOTE(review): original indentation was lost in this copy; this print is
# assumed to be module level (inside initialize_llm it would follow `return`
# and never run).
print("成功初始化 ChatGroq 模型")
def load_documents(sources):
    """Load knowledge-base sources into LangChain ``Document`` objects.

    Args:
        sources: Iterable whose items are each one of:
            * a string starting with ``http://`` or ``https://``, fetched
              with ``WebBaseLoader``;
            * any other string, treated as a local text-file path and read
              with ``TextLoader`` as UTF-8;
            * a dict with a ``"content"`` key and an optional ``"metadata"``
              dict, wrapped directly in a ``Document``.

    Returns:
        A list of ``Document`` objects, in source order. Best effort: a
        source that fails to load, or has an unsupported type, is reported
        with ``print`` and skipped. This function does not raise for a bad
        source.
    """
    documents = []
    for source in sources:
        try:
            if isinstance(source, str):
                # Match the URL scheme explicitly: a bare "http" prefix would
                # also route local paths such as "httpdocs/faq.txt" to the
                # web loader.
                if source.startswith(("http://", "https://")):
                    loader = WebBaseLoader(source)
                else:
                    # Pin the encoding rather than relying on the platform
                    # locale. Assumes the .txt corpus is UTF-8 (it loads that
                    # way on Linux/Colab today); TODO confirm.
                    loader = TextLoader(source, encoding="utf-8")
                documents.extend(loader.load())
            elif isinstance(source, dict):
                documents.append(
                    Document(
                        page_content=source["content"],
                        metadata=source.get("metadata", {}),
                    )
                )
            else:
                # Previously such items were dropped without any trace.
                print(f"Skipping unsupported source type {type(source).__name__}: {source!r}")
        except Exception as e:
            # Deliberate best effort: one broken source must not abort
            # building the whole knowledge base.
            print(f"Error loading source {source}: {str(e)}")
    return documents
# Knowledge-base sources: local text files (tax Q&A datasets and statute
# texts), followed by the New Taipei City tax-office web pages
# lp-2158-1-1-20.html … lp-2158-1-12-20.html.
#
# The file directory used to be hard-coded to the Colab mount point
# "/content". Outside Colab (for example on a Hugging Face Space) that
# directory typically does not exist. Every local file then fails inside
# load_documents and is skipped with only a printed error, leaving just the
# web pages. The directory can now be set with the TAX_KB_DATA_DIR
# environment variable (also read from .env via load_dotenv). Otherwise
# "/content" is still used when present, falling back to the working
# directory.
DATA_DIR = os.environ.get("TAX_KB_DATA_DIR") or (
    "/content" if os.path.isdir("/content") else "."
)
LOCAL_SOURCE_FILES = [
    "TaxQADataSet_kctax.txt",
    "TaxQADataSet_chutax.txt",
    "HouseTaxAct1130103.txt",
    "VehicleLicenseTaxAct1101230.txt",
    "TaxCollectionAct1101217.txt",
    "LandTaxAct1100623.txt",
    "AmusementTaxAct960523.txt",
    "StampTaxAct910515.txt",
    "DeedTaxAct990505.txt",
    "ProgressiveHouseTaxRates1130701.txt",
]
sources = [os.path.join(DATA_DIR, name) for name in LOCAL_SOURCE_FILES] + [
    f"https://www.tax.ntpc.gov.tw/lp-2158-1-{page}-20.html" for page in range(1, 13)
]
documents = load_documents(sources)
print(f"成功載入 {len(documents)} 個網址或檔案")
# Split the loaded documents into chunks of at most 512 characters with a
# 50-character overlap. Separators are tried in order: blank-line section
# breaks, paragraphs, lines, then the Chinese full stop "。".
# The trailing "" is the splitter's character-level fallback. Without it, a
# passage containing none of the other separators (e.g. a long table row or
# list) was emitted as a single chunk of unbounded length instead of being
# cut down to chunk_size.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=512,
    chunk_overlap=50,
    length_function=len,
    separators=["\n\n\n", "\n\n", "\n", "。", ""],
)
split_docs = text_splitter.split_documents(documents)
print(f"分割後的文檔數量:{len(split_docs)}")
# Embedding model for the Chinese corpus.
# NOTE(review): the import block imports HuggingFaceEmbeddings first from
# langchain_huggingface and then again from langchain.embeddings; the later
# import wins, so this uses the legacy class. Confirm which is intended.
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-zh-v1.5")
print(f"\n成功初始化嵌入模型")

# Build the Chroma store persisted under ./Knowledge-base.
# When no ids are supplied, Chroma generates fresh random ids on every call.
# Re-running this code against an existing ./Knowledge-base (e.g. re-running
# the notebook cell) therefore appended a second copy of every chunk, and
# retrieval started returning duplicates. Chroma upserts by id, so a
# deterministic id per chunk makes a rebuild overwrite the previous entries.
# Caveat: if the corpus shrinks, entries with ids beyond the new chunk count
# remain. Delete the directory to rebuild from scratch.
vectorstore = Chroma.from_documents(
    split_docs,
    embeddings,
    ids=[f"chunk-{i}" for i in range(len(split_docs))],
    persist_directory="./Knowledge-base",
)
print(f"成功建立 Chroma 向量資料庫")

# Retriever with LangChain's default search settings, used by create_chain.
retriever = vectorstore.as_retriever()
# Prompt for the "stuff" RetrievalQA chain. It asks the model to:
#   - reason step by step and reply in Traditional Chinese (Taiwan);
#   - answer from the retrieved context rather than prior knowledge, without
#     saying that it is doing so;
#   - combine related passages without over-inferring;
#   - say honestly when the context is insufficient instead of inventing an
#     answer.
# {context} is filled by the chain with the retrieved chunks, {question} with
# the user's query.
template = "\n".join([
    "Let's work this out in a step by step way to be sure we have the right answer. Must reply to me in Taiwanese Traditional Chinese.",
    "在回答之前,請仔細分析檢索到的上下文,確保你的回答準確完整反映了上下文中的訊息,而不是依賴先前的知識,但在回應答案中不要提到是根據提供的上下文回答。",
    "如果檢索到的多個上下文之間存在聯繫,請整合這些訊息以提供全面的回答,但要避免過度推斷。",
    "如果檢索到的上下文不包含足夠回答問題的訊息,請誠實的說明,不要試圖編造答案。",
    "上下文: {context}",
    "問題: {question}",
    "答案:",
])
PROMPT = PromptTemplate(input_variables=["context", "question"], template=template)
print("成功定義 Prompt Template")
def create_chain(llm):
    """Return a RetrievalQA chain that answers with *llm*.

    Chunks fetched by the module-level ``retriever`` are "stuffed" into
    ``PROMPT``. The chain's output also includes the source documents.
    """
    stuff_chain_options = {"prompt": PROMPT}
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type="stuff",
        chain_type_kwargs=stuff_chain_options,
        return_source_documents=True,
    )
    return qa_chain
# Startup progress message: create_chain is now defined; a chain is built per
# request by calling it.
# NOTE(review): original indentation was lost; assumed to be module level
# (inside create_chain it would follow `return` and never run).
print("成功建立 RAG Chain")
def generate_insight_questions(answer, api_key):
    """Ask the LLM for three follow-up ("insight") questions about *answer*.

    Args:
        answer: Answer text the follow-up questions should build on.
        api_key: Groq API key, forwarded to ``initialize_llm``.

    Returns:
        A list of exactly three question strings. Blank lines in the model
        output are ignored, and missing entries are padded with a generic
        question. On any failure, including LLM initialisation, three generic
        fallback questions are returned instead of raising.
    """
    filler = "需要更多資訊嗎?"
    prompt = (
        "\n"
        "根據以下回答,生成3個相關的洞見問題:\n"
        f"回答: {answer}\n"
        "請提供3個簡短但有深度的問題,這些問題應該:\n"
        "1. 與原始回答緊密相關\n"
        "2. 能夠引導更深入的討論\n"
        "3. 涵蓋不同的方面或角度\n"
        "請直接列出這3個問題,每個問題一行,不要添加編號或其他文字。\n"
    )
    try:
        # Inside the try: a bad key or client-construction error must fall
        # back to the default questions like every other failure here, rather
        # than escaping to the caller.
        llm = initialize_llm(api_key)
        response = llm.invoke(prompt)
        text = response.content if hasattr(response, "content") else str(response)
        # Models often emit leading, trailing or in-between blank lines. Keep
        # only non-empty lines so blanks are not returned as "questions".
        questions = [line.strip() for line in text.splitlines() if line.strip()]
        # Ensure at least three entries (multiplying by a negative count
        # yields an empty list).
        questions.extend([filler] * (3 - len(questions)))
        return questions[:3]
    except Exception as e:
        print(f"Error generating insight questions: {str(e)}")
        return ["需要更多資訊嗎?", "有其他問題嗎?", "還有什麼想了解的嗎?"]
def answer_question(query, api_key):
    """Answer *query* with the RAG chain and suggest follow-up questions.

    Args:
        query: The user's question.
        api_key: Groq API key used for both the answer and the follow-ups.

    Returns:
        ``(answer, questions)``. On success *questions* has exactly three
        entries. On any error, *answer* is an apology message containing the
        exception text and *questions* is an empty list.
    """
    try:
        llm = initialize_llm(api_key)
        chain = create_chain(llm)
        # `invoke` replaces the deprecated `Chain.__call__` and returns the
        # same output dict ("result", plus "source_documents").
        result = chain.invoke({"query": query})
        answer = result["result"]
        insight_questions = list(generate_insight_questions(answer, api_key))
        # Defensive padding in case the helper ever returns fewer than three.
        insight_questions.extend(["需要更多資訊嗎?"] * (3 - len(insight_questions)))
        # Return the answer and the follow-up questions separately.
        return answer, insight_questions[:3]
    except Exception as e:
        # Also log server-side; the user only sees the returned message.
        print(f"Error answering question: {e}")
        return f"抱歉,處理您的問題時發生錯誤:{str(e)}", []
def handle_interaction(query, api_key, state):
    """Gradio submit handler: answer *query* and produce the UI outputs.

    Args:
        query: Question text from the input box.
        api_key: Groq API key from the password box.
        state: Per-session dict with a ``"history"`` list of
            ``(query, answer)`` pairs, or ``None`` on the first call.

    Returns:
        ``(answer, q1, q2, q3, state, query)``. *q1*..*q3* are follow-up
        question labels. They are ``""`` when no follow-ups are available
        (e.g. after an error), which update_ui turns into hidden buttons.
    """
    if state is None:
        state = {"history": []}
    answer, insight_questions = answer_question(query, api_key)
    state["history"].append((query, answer))
    # Give blank questions a generic label so their buttons are not empty.
    insight_questions = [q if q.strip() else "需要更多資訊" for q in insight_questions]
    # answer_question returns [] on error. Pad with "" so the indexing below
    # cannot raise IndexError; empty labels leave the buttons hidden.
    insight_questions += [""] * (3 - len(insight_questions))
    return answer, insight_questions[0], insight_questions[1], insight_questions[2], state, query
# Stylesheet passed to gr.Blocks(css=custom_css). Selectors target the
# elem_id / elem_classes values set on the UI components:
#   #query-input, #answer-box -> larger font, white background, green border
#   #submit-btn               -> pink submit button
#   .insight-btn              -> cyan follow-up-question buttons
#   .center-text              -> centred green Markdown headings
# NOTE(review): .gr-button / .gr-form look like Gradio 3.x internal class
# names; confirm they still match elements in the installed Gradio version.
custom_css = """
body {
background-color: #e8f5e9;
}
#answer-box textarea, #query-input textarea {
font-size: 18px !important;
background-color: #ffffff;
border: 1px solid #81c784;
border-radius: 8px;
}
.center-text {
text-align: center !important;
color: #2e7d32 !important;
}
.gradio-container {
background-color: #c8e6c9 !important;
border-radius: 15px !important;
padding: 20px !important;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1) !important;
}
.gr-button {
color: white !important;
border: none !important;
border-radius: 20px !important;
transition: all 0.3s ease !important;
font-weight: bold !important;
}
.gr-button:hover {
transform: translateY(-2px) !important;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2) !important;
}
#submit-btn {
background-color: #ff4081 !important;
}
#submit-btn:hover {
background-color: #f50057 !important;
}
.insight-btn {
background-color: #00bcd4 !important;
}
.insight-btn:hover {
background-color: #00acc1 !important;
}
.gr-form {
background-color: #e8f5e9 !important;
padding: 15px !important;
border-radius: 10px !important;
}
"""
# Gradio UI: title and credits, question + API-key inputs, answer box,
# three initially hidden follow-up-question buttons, and a submit button.
# NOTE(review): the original indentation was lost in this copy. The nesting
# of components under the two gr.Row() blocks is reconstructed; confirm it
# matches the intended layout.
with gr.Blocks(theme=gr.themes.Monochrome(), css=custom_css) as iface:
    gr.Markdown("# 地方稅知識庫系統 - 財政部財政資訊中心", elem_classes=["center-text"])
    gr.Markdown("※ RAG-based Q&A Web系統,建置:江信宗,LLM:Llama-3.1-70B,目前僅示範地方稅各稅目問答。", elem_classes=["center-text"])
    with gr.Row():
        query_input = gr.Textbox(lines=2, placeholder="請輸入您的問題...", label="輸入您的問題,系統將基於學習到的知識資料提供相關答案。", elem_id="query-input")
        # Masked input; the key is passed to handle_interaction on each submit.
        api_key_input = gr.Textbox(type="password", placeholder="請輸入您的 API Key", label="API authentication key for large language models")
    answer_output = gr.Textbox(lines=6, label="答案:", elem_id="answer-box")
    # Follow-up question buttons: hidden until update_ui gives them a
    # non-empty label.
    with gr.Row():
        insight_q1 = gr.Button("洞見問題 1", visible=False, elem_classes=["insight-btn"])
        insight_q2 = gr.Button("洞見問題 2", visible=False, elem_classes=["insight-btn"])
        insight_q3 = gr.Button("洞見問題 3", visible=False, elem_classes=["insight-btn"])
    # Per-session state; handle_interaction stores the Q&A history in it.
    state = gr.State()
    # Hidden field that receives the most recently submitted question.
    current_question = gr.Textbox(lines=2, label="當前問題", visible=False)
    submit_btn = gr.Button("提交", elem_id="submit-btn")

    def update_ui(answer, q1, q2, q3, state, current_q):
        # Second step after handle_interaction: pass the values through
        # unchanged and show each follow-up button only when its label is
        # non-empty.
        return [
            answer,
            gr.update(value=q1, visible=bool(q1)),
            gr.update(value=q2, visible=bool(q2)),
            gr.update(value=q3, visible=bool(q3)),
            state,
            current_q
        ]

    # Submit: run the RAG pipeline, then toggle follow-up button visibility.
    submit_btn.click(
        fn=handle_interaction,
        inputs=[query_input, api_key_input, state],
        outputs=[answer_output, insight_q1, insight_q2, insight_q3, state, current_question]
    ).then(
        fn=update_ui,
        inputs=[answer_output, insight_q1, insight_q2, insight_q3, state, current_question],
        outputs=[answer_output, insight_q1, insight_q2, insight_q3, state, current_question]
    )
    # Clicking a follow-up button copies its label into the question box; the
    # user still has to press 提交 to ask it. `btn` is read when each click()
    # call is made, so every button is wired to itself.
    for btn in [insight_q1, insight_q2, insight_q3]:
        btn.click(
            lambda x: x,
            inputs=[btn],
            outputs=[query_input]
        )
if __name__ == "__main__":
    # Launch options: share=True requests a public Gradio share link;
    # debug=True keeps the call blocking so errors are printed to the console.
    launch_options = {"share": True, "debug": True}
    iface.launch(**launch_options)