# nvdajp-book-qa / app.py
# Last commit: 60434a8 by terapyon — "added streamlit cache" (file size 6.91 kB)
from time import time
from typing import Iterable
# import gradio as gr
import streamlit as st
from langchain.chains import RetrievalQA
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings
# from langchain.prompts import PromptTemplate
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain.llms import HuggingFacePipeline
# from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.vectorstores import Qdrant
from openai.error import InvalidRequestError
from qdrant_client import QdrantClient
from config import DB_CONFIG, DB_E5_CONFIG
@st.cache_resource
def load_e5_embeddings():
    """Build the multilingual E5 embedding model.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the same
    instance. The encoder runs on the first CUDA device when one is available
    and on the CPU otherwise; output vectors are not normalized by the encoder.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    return HuggingFaceEmbeddings(
        model_name="intfloat/multilingual-e5-large",
        model_kwargs={"device": device},
        encode_kwargs={"normalize_embeddings": False},
    )
@st.cache_resource
def load_rinna_model():
    """Load the rinna bilingual instruction model, but only on a CUDA host.

    Returns:
        A ``(tokenizer, model)`` pair when CUDA is available, otherwise
        ``(None, None)``; a ``None`` model is how the rest of the app detects
        that rinna cannot be offered.
    """
    if not torch.cuda.is_available():
        # No GPU: skip loading entirely.
        return None, None
    repo_id = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
    tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=False)
    # 8-bit weights, fp16 dtype; device_map="auto" lets the library pick placement.
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return tokenizer, model
# Load the models when the script runs; st.cache_resource on the loaders lets
# Streamlit reruns reuse the already-loaded objects.
E5_EMBEDDINGS = load_e5_embeddings()
# Both names are None on hosts without CUDA (load_rinna_model returns (None, None)).
RINNA_TOKENIZER, RINNA_MODEL = load_rinna_model()
def _get_config_and_embeddings(collection_name: str | None) -> tuple:
if collection_name is None or collection_name == "E5":
db_config = DB_E5_CONFIG
embeddings = E5_EMBEDDINGS
elif collection_name == "OpenAI":
db_config = DB_CONFIG
embeddings = OpenAIEmbeddings()
else:
raise ValueError("Unknow collection name")
return db_config, embeddings
@st.cache_resource
def _get_rinna_llm(temperature: float) -> HuggingFacePipeline | None:
    """Wrap the preloaded rinna model in a LangChain text-generation pipeline.

    Cached per ``temperature`` by ``st.cache_resource``.

    Args:
        temperature: ``0`` (or below) selects greedy decoding; a positive
            value enables sampling at that temperature.

    Returns:
        The pipeline-backed LLM, or ``None`` when the rinna model was not
        loaded (no CUDA device).
    """
    if RINNA_MODEL is None:
        return None
    # transformers only applies ``temperature`` when ``do_sample=True`` and
    # rejects a non-positive temperature while sampling, so pass it only when
    # it can actually take effect.
    if temperature > 0:
        generation_kwargs = {"do_sample": True, "temperature": temperature}
    else:
        generation_kwargs = {"do_sample": False}
    pipe = pipeline(
        "text-generation",
        model=RINNA_MODEL,
        tokenizer=RINNA_TOKENIZER,
        max_new_tokens=1024,
        **generation_kwargs,
    )
    return HuggingFacePipeline(pipeline=pipe)
def _get_llm_model(
model_name: str | None,
temperature: float,
):
if model_name is None:
model = "gpt-3.5-turbo"
elif model_name == "rinna":
model = "rinna"
elif model_name == "GPT-3.5":
model = "gpt-3.5-turbo"
elif model_name == "GPT-4":
model = "gpt-4"
else:
raise ValueError("Unknow model name")
if model.startswith("gpt"):
llm = ChatOpenAI(model=model, temperature=temperature)
elif model == "rinna":
llm = _get_rinna_llm(temperature)
return llm
def get_retrieval_qa(
    collection_name: str | None,
    model_name: str | None,
    temperature: float,
    option: str | None,
):
    """Assemble a "stuff"-type RetrievalQA chain over a Qdrant collection.

    Args:
        collection_name: Embedding/collection selector (``"E5"`` or ``"OpenAI"``).
        model_name: LLM selector passed to ``_get_llm_model``.
        temperature: Sampling temperature for the LLM.
        option: Category to restrict retrieval to; ``None`` or ``"All"``
            searches every document.

    Returns:
        A ``RetrievalQA`` chain that also returns its source documents.
    """
    db_config, embeddings = _get_config_and_embeddings(collection_name)
    url, api_key, qdrant_collection = db_config
    vectorstore = Qdrant(
        client=QdrantClient(url=url, api_key=api_key),
        collection_name=qdrant_collection,
        embeddings=embeddings,
    )
    search_everything = option is None or option == "All"
    if search_everything:
        retriever = vectorstore.as_retriever()
    else:
        # Metadata filter restricting hits to the chosen document category.
        category_filter = {"category": option}
        retriever = vectorstore.as_retriever(
            search_kwargs={"filter": category_filter}
        )
    return RetrievalQA.from_chain_type(
        llm=_get_llm_model(model_name, temperature),
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
def get_related_url(metadata) -> Iterable[str]:
    """Yield one HTML paragraph per distinct source URL, in first-seen order.

    Args:
        metadata: Iterable of source-document metadata mappings, each with
            ``"url"`` and ``"category"`` keys.

    Yields:
        ``<p>`` snippets linking to each URL and naming its category. Both
        values are HTML-escaped because the caller's output is rendered as HTML.
    """
    # Imported locally so a module-level variable named ``html`` cannot shadow it.
    from html import escape

    seen = set()
    for m in metadata:
        url = m["url"]
        if url in seen:
            continue
        seen.add(url)
        safe_url = escape(str(url))
        category = escape(str(m["category"]))
        yield f'<p>URL: <a href="{safe_url}">{safe_url}</a> (category: {category})</p>'
def run_qa(query: str, qa: RetrievalQA) -> tuple[str, str]:
    """Run one question through the QA chain.

    Args:
        query: Question text (``main`` may prepend an E5 prefix).
        qa: Chain built by ``get_retrieval_qa``.

    Returns:
        ``(answer, html)``: the generated answer and an HTML fragment with the
        elapsed time and related source URLs. If OpenAI rejects the request
        (``InvalidRequestError``), a fixed Japanese "no answer found, try a
        different question" message and the error text instead.
    """
    # perf_counter is monotonic, so the measured duration is unaffected by
    # wall-clock adjustments.
    from time import perf_counter

    started = perf_counter()
    try:
        result = qa(query)
    except InvalidRequestError as e:
        return "回答が見つかりませんでした。別な質問をしてみてください", str(e)
    metadata = [doc.metadata for doc in result["source_documents"]]
    elapsed_html = f"<p>実行時間: {(perf_counter() - started):.2f}秒</p>"
    related_html = "\n".join(get_related_url(metadata))
    return result["result"], "<div>" + elapsed_html + related_html + "</div>"
def main(
query: str,
collection_name: str | None,
model_name: str | None,
option: str | None,
temperature: float,
e5_option: list[str],
) -> Iterable[tuple[str, tuple[str, str]]]:
qa = get_retrieval_qa(collection_name, model_name, temperature, option)
if collection_name == "E5":
for option in e5_option:
if option == "No":
yield "E5 No", run_qa(query, qa)
elif option == "Query":
yield "E5 Query", run_qa("query: " + query, qa)
elif option == "Passage":
yield "E5 Passage", run_qa("passage: " + query, qa)
else:
raise ValueError("Unknow option")
else:
yield "OpenAI", run_qa(query, qa)
# Models offered in the UI; rinna is listed only when it was loaded (CUDA host).
AVAILABLE_LLMS = ["GPT-3.5", "GPT-4"]
if RINNA_MODEL is not None:
    AVAILABLE_LLMS.append("rinna")

with st.form("my_form"):
    query = st.text_input(label="query")
    collection_name = st.radio(options=["E5", "OpenAI"], label="Embedding")
    # TODO: offer the E5 options only when the E5 embedding is selected
    #       (was: if collection_name == "E5").
    e5_option = st.multiselect("E5 option", ["No", "Query", "Passage"], default="No")
    model_name = st.radio(
        options=AVAILABLE_LLMS,
        label="Model",
        help="GPU環境だとrinnaが選択可能",
    )
    option = st.radio(
        options=["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"],
        label="絞り込み",
        help="ドキュメント制限する?",
    )
    # Float bounds make this a continuous slider; int bounds would allow only 0, 1, 2.
    temperature = st.slider(
        label="temperature", min_value=0.0, max_value=2.0, value=0.0, step=0.1
    )
    submitted = st.form_submit_button("Submit")

if submitted:
    # main() is a generator: retrieval and LLM calls run while it is iterated,
    # so consume it inside the spinner to keep the spinner up during the work.
    with st.spinner("Searching..."):
        results = main(
            query, collection_name, model_name, option, temperature, e5_option
        )
        for type_, (answer, related_html) in results:
            with st.container():
                st.header(type_)
                st.write(answer)
                st.markdown(related_html, unsafe_allow_html=True)
                st.divider()