import streamlit as st
import random
from langchain_community.llms import HuggingFaceHub
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import FAISS
from datasets import load_dataset
from opencc import OpenCC
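# Note: HuggingFaceHub calls the Hugging Face Inference API remotely and needs an API token,
# typically supplied via the HUGGINGFACEHUB_API_TOKEN environment variable (or the
# huggingfacehub_api_token argument), e.g.:
#   export HUGGINGFACEHUB_API_TOKEN=hf_xxx   # placeholder token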
# Use the Attack on Titan (進擊的巨人) dataset.
# The original dataset is in Traditional Chinese; it is converted to Simplified Chinese here to make debugging easier.
if "dataset_loaded" not in st.session_state:
    st.session_state.dataset_loaded = False
if not st.session_state.dataset_loaded:
    try:
        with st.spinner("正在读取数据库..."):
            converter = OpenCC('tw2s')  # 'tw2s' converts Traditional Chinese (Taiwan) to Simplified Chinese
            dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
            answer_list = [converter.convert(example["Answer"]) for example in dataset["train"]]
            # Keep these in session_state so they survive Streamlit reruns.
            st.session_state.converter = converter
            st.session_state.dataset = dataset
            st.session_state.answer_list = answer_list
            st.success("数据库读取完成!")
    except Exception as e:
        st.error(f"读取数据集失败:{e}")
        st.stop()
    st.session_state.dataset_loaded = True
converter = st.session_state.converter
dataset = st.session_state.dataset
answer_list = st.session_state.answer_list
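# Each record in dataset["train"] provides "Question" and "Answer" fields; only the answers
# are indexed below, while the questions are reused for the random-question demo further down.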
# Build the vector store (only once per session).
if "vector_created" not in st.session_state:
    st.session_state.vector_created = False
if not st.session_state.vector_created:
    try:
        with st.spinner("正在构建向量数据库..."):
            embeddings = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
            db = FAISS.from_texts(answer_list, embeddings)
            # Keep the embedding model and the FAISS index in session_state so they survive reruns.
            st.session_state.embeddings = embeddings
            st.session_state.db = db
            st.success("向量数据库构建完成!")
    except Exception as e:
        st.error(f"向量数据库构建失败:{e}")
        st.stop()
    st.session_state.vector_created = True
embeddings = st.session_state.embeddings
db = st.session_state.db
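# Note: "all-mpnet-base-v2" is a sentence-transformers model trained mainly on English text;
# for a Chinese corpus, a multilingual model such as "paraphrase-multilingual-mpnet-base-v2"
# may give better retrieval quality (untested assumption).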
# Remember the model configuration so the LLM is only re-initialized when it changes.
if "repo_id" not in st.session_state:
    st.session_state.repo_id = ''
if "temperature" not in st.session_state:
    st.session_state.temperature = ''
if "max_length" not in st.session_state:
    st.session_state.max_length = ''
# Q&A function
def answer_question(repo_id, temperature, max_length, question):
    # Initialize the Gemma model (only when the configuration has changed).
    if (repo_id != st.session_state.repo_id
            or temperature != st.session_state.temperature
            or max_length != st.session_state.max_length
            or "llm" not in st.session_state):
        try:
            with st.spinner("正在初始化 Gemma 模型..."):
                llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
                # Keep the LLM in session_state; a plain local variable would be lost on the next rerun.
                st.session_state.llm = llm
                st.success("Gemma 模型初始化完成!")
                st.session_state.repo_id = repo_id
                st.session_state.temperature = temperature
                st.session_state.max_length = max_length
        except Exception as e:
            st.error(f"Gemma 模型加载失败:{e}")
            st.stop()
    llm = st.session_state.llm
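    # temperature and max_length are forwarded to the Hugging Face Inference API as generation
    # parameters via model_kwargs: temperature controls sampling randomness, max_length caps output length.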
    # Retrieve relevant context and generate the answer.
    try:
        with st.spinner("正在筛选本地数据集..."):
            # Embed the question and query the FAISS index by vector.
            question_embedding = embeddings.embed_query(question)
            # print('question_embedding: ' + str(question_embedding))
            docs_and_scores = db.similarity_search_with_score_by_vector(question_embedding)
            context = "\n".join([doc.page_content for doc, _ in docs_and_scores])
            print('context: ' + context)
            prompt = f"请根据以下知识库回答问题:\n{context}\n问题:{question}"
            print('prompt: ' + prompt)
            st.success("本地数据集筛选完成!")
        with st.spinner("正在生成答案..."):
            answer = llm.invoke(prompt)
            # The model may echo the prompt back; strip it from the generated text.
            answer = answer.replace(prompt, "").strip()
            st.success("答案已经生成!")
        return {"prompt": prompt, "answer": answer}
    except Exception as e:
        st.error(f"问答过程出错:{e}")
        return {"prompt": "", "answer": "An error occurred during the answering process."}
# Streamlit UI
st.title("進擊的巨人 知识库问答系统")
col1, col2 = st.columns(2)
with col1:
    gemma = st.selectbox("repo-id", ("google/gemma-2-9b-it", "google/gemma-2-2b-it", "google/recurrentgemma-2b-it"), 2)
with col2:
    temperature = st.number_input("temperature", value=1.0)
    max_length = st.number_input("max_length", value=1024)
st.divider()
col3, col4 = st.columns(2)
with col3:
    if st.button("使用原数据集中的随机问题"):
        dataset_size = len(dataset["train"])
        random_index = random.randint(0, dataset_size - 1)
        # Read a random question and its reference answer from the dataset.
        random_question = dataset["train"][random_index]["Question"]
        random_question = converter.convert(random_question)
        origin_answer = dataset["train"][random_index]["Answer"]
        origin_answer = converter.convert(origin_answer)
        print('[' + str(random_index) + '/' + str(dataset_size) + '] random_question: ' + random_question)
        print('origin_answer: ' + origin_answer)
        st.write("随机问题:")
        st.write(random_question)
        st.write("原始答案:")
        st.write(origin_answer)
        result = answer_question(gemma, float(temperature), int(max_length), random_question)
        print('prompt: ' + result["prompt"])
        print('answer: ' + result["answer"])
        st.write("生成答案:")
        st.write(result["answer"])
with col4:
    question = st.text_area("请输入问题", "Gemma 有哪些特点?")
    if st.button("提交输入的问题"):
        if not question:
            st.warning("请输入问题!")
        else:
            result = answer_question(gemma, float(temperature), int(max_length), question)
            print('prompt: ' + result["prompt"])
            print('answer: ' + result["answer"])
            st.write("生成答案:")
            st.write(result["answer"])