# FakeQA / app.py
# Author: Charles Chan · commit cb8213b ("coding") · 3.36 kB
import streamlit as st
from langchain_community.llms import HuggingFaceHub
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import FAISS
from datasets import load_dataset
import random
# 1. Prepare knowledge-base data (example).
# NOTE(review): `knowledge_base` is not referenced anywhere else in this file;
# the FAISS index below is built from the dataset's answers instead, so these
# Gemma facts can never be retrieved. The UI's default question asks about
# Gemma, which suggests this list was meant to be indexed — confirm intent and
# either add it to the index or remove it.
knowledge_base = [
"Gemma 是 Google 开发的大型语言模型。",  # "Gemma is a large language model developed by Google."
"Gemma 具有强大的自然语言处理能力。",  # "Gemma has strong natural-language-processing abilities."
"Gemma 可以用于问答、对话、文本生成等任务。",  # "Gemma can be used for QA, dialogue, text generation, etc."
"Gemma 基于 Transformer 架构。",  # "Gemma is based on the Transformer architecture."
"Gemma 支持多种语言。"  # "Gemma supports multiple languages."
]
# Load the Attack on Titan Chinese QA dataset whose answers are indexed below.
# Streamlit re-executes this whole script on every widget interaction, so the
# load is wrapped in st.cache_resource to run once per server process instead
# of on every click.
DATASET_NAME = "rorubyy/attack_on_titan_wiki_chinese"


@st.cache_resource
def _load_qa_dataset(name):
    """Return the Hugging Face dataset *name*, loaded once and cached."""
    return load_dataset(name)


try:
    dataset = _load_qa_dataset(DATASET_NAME)
    # Every "Answer" in the train split becomes one document in the vector store.
    answer_list = [example["Answer"] for example in dataset["train"]]
except Exception as e:
    # Without the dataset nothing else on the page can work: report and halt.
    st.error(f"读取数据集失败:{e}")
    st.stop()
# 2. Build the vector database once.
# Streamlit re-executes this script on every widget interaction; without
# caching, the embedding model would be reloaded and every answer re-embedded
# on each click. st.cache_resource keeps a single instance per server process.
# NOTE(review): all-mpnet-base-v2 is trained mainly on English text while the
# indexed answers are Chinese; a multilingual sentence-transformers model may
# retrieve noticeably better — confirm before switching (it changes the index).
EMBEDDING_MODEL_NAME = "all-mpnet-base-v2"


@st.cache_resource
def _build_vector_store(model_name, _texts):
    """Load the embedding model and index *_texts* into an in-memory FAISS store.

    Args:
        model_name: sentence-transformers model name; this is the cache key.
        _texts: Documents to index. The leading underscore tells Streamlit not
            to hash this (large) argument; that is safe here because the texts
            come from a fixed dataset.

    Returns:
        ``(embeddings, db)`` — the embedding wrapper and the FAISS store.
    """
    model = SentenceTransformerEmbeddings(model_name=model_name)
    return model, FAISS.from_texts(_texts, model)


try:
    embeddings, db = _build_vector_store(EMBEDDING_MODEL_NAME, answer_list)
except Exception as e:
    st.error(f"向量数据库构建失败:{e}")
    st.stop()
# 3. Question-answering function
def answer_question(repo_id, temperature, max_length, question, top_k=4):
    """Answer *question* with a Hugging Face Hub model, grounded on retrieved passages.

    Retrieves the ``top_k`` passages from the module-level FAISS store ``db``
    most similar to ``question``, places them in a prompt, and sends that
    prompt to the model identified by ``repo_id``.

    Args:
        repo_id: Hugging Face Hub model id, e.g. ``"google/gemma-2-2b-it"``.
        temperature: Forwarded as ``temperature`` in the model's ``model_kwargs``.
        max_length: Forwarded as ``max_length`` in the model's ``model_kwargs``.
        question: The user's question text.
        top_k: Number of passages to retrieve (4 matches the previous,
            implicit FAISS default).

    Returns:
        The model's answer text, or a fixed English error message if retrieval
        or generation raised (the error is also shown via ``st.error``).
        If the model client cannot be constructed, the error is shown and the
        Streamlit script run is halted with ``st.stop()``.
    """
    # 4. Initialise the model client.
    try:
        llm = HuggingFaceHub(
            repo_id=repo_id,
            model_kwargs={"temperature": temperature, "max_length": max_length},
        )
    except Exception as e:
        st.error(f"Gemma 模型加载失败:{e}")
        st.stop()

    # 5. Retrieve context and generate the answer.
    try:
        # Pass the raw question text: the store embeds it with the same model
        # that embedded the indexed answers, so similarity is meaningful.
        docs_and_scores = db.similarity_search_with_score(question, k=top_k)
        context = "\n".join(doc.page_content for doc, _score in docs_and_scores)
        print('context: ' + context)
        prompt = f"请根据以下知识库回答问题:\n{context}\n问题:{question}"
        print('prompt: ' + prompt)
        return llm.invoke(prompt)
    except Exception as e:
        st.error(f"问答过程出错:{e}")
        return "An error occurred during the answering process."
# 6. Streamlit UI
st.title("Gemma 知识库问答系统")
gemma = st.selectbox(
    "repo-id",
    ("google/gemma-2-9b-it", "google/gemma-2-2b-it", "google/recurrentgemma-2b-it"),
    index=2,  # default selection: google/recurrentgemma-2b-it
)
# Lower bounds stop obviously invalid generation parameters at the widget.
temperature = st.number_input("temperature", value=1.0, min_value=0.0)
max_length = st.number_input("max_length", value=1024, min_value=1)
question = st.text_area("请输入问题", "Gemma 有哪些特点?")

if st.button("随机"):
    # Pick a random QA pair from the train split and show the reference answer
    # next to the model's generated answer.
    train_split = dataset["train"]
    row = train_split[random.randrange(len(train_split))]  # fetch the row once
    random_question = row["Question"]
    origin_answer = row["Answer"]
    st.write("随机问题:")
    st.write(random_question)
    st.write("原始答案:")
    st.write(origin_answer)
    # Same spinner as the submit path: generation can take a while.
    with st.spinner("正在查询..."):
        answer = answer_question(gemma, float(temperature), int(max_length), random_question)
    st.write("生成答案:")
    st.write(answer)

if st.button("提交"):
    # Treat whitespace-only input as empty rather than sending it to the model.
    if not question.strip():
        st.warning("请输入问题!")
    else:
        with st.spinner("正在查询..."):
            answer = answer_question(gemma, float(temperature), int(max_length), question)
        st.write("答案:")
        st.write(answer)