demo / app.py
WJL's picture
feat: vectorsearch-based QA
9e88bc1
raw
history blame
3.22 kB
from embeddings import KorRobertaEmbeddings
import streamlit as st
from streamlit import session_state as sst
from langchain_core.runnables import (
RunnablePassthrough,
RunnableParallel,
)
# Pinecone API key, read once at import time from Streamlit's secrets store;
# a missing "PINECONE_API_KEY" entry makes this lookup fail at startup.
PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"]
def create_or_get_pinecone_index(index_name: str, dimension: int = 768):
    """Return a handle to the Pinecone index ``index_name``, creating it if absent.

    Args:
        index_name: Name of the index to look up or create.
        dimension: Vector dimension; only used when a new index is created.

    Returns:
        The object returned by ``client.Index(index_name)``.
    """
    # Imported lazily so the pinecone package is only needed when this runs.
    from pinecone import Pinecone, ServerlessSpec

    client = Pinecone(api_key=PINECONE_API_KEY)
    known_names = [entry["name"] for entry in client.list_indexes()]

    if index_name in known_names:
        status_message = "☑️ Got the existing Pinecone index"
    else:
        # New indexes are created as serverless, cosine-metric indexes with
        # the ("aws", "us-west-2") spec.
        client.create_index(
            name=index_name,
            dimension=dimension,
            metric="cosine",
            spec=ServerlessSpec("aws", "us-west-2"),
        )
        status_message = "☑️ Created a new Pinecone index"

    # Either way, open the index first and only then report what happened.
    pc_index = client.Index(index_name)
    print(status_message)
    print(pc_index.describe_index_stats())
    return pc_index
def get_pinecone_vectorstore(
    index_name: str,
    embedding_fn=None,
    dimension: int = 768,
    namespace: str = None,
):
    """Build a LangChain Pinecone vector store on top of ``index_name``.

    Args:
        index_name: Name of the Pinecone index (created if it does not exist).
        embedding_fn: Embedding object handed to the vector store. When
            ``None`` (the default) a fresh ``KorRobertaEmbeddings()`` is built
            on demand. The previous default instantiated the embedding model
            once at function-definition time, i.e. on every import of this
            module, even when callers supplied their own instance.
        dimension: Vector dimension forwarded to index creation.
        namespace: Optional Pinecone namespace forwarded to the vector store.

    Returns:
        The constructed ``langchain_pinecone.Pinecone`` vector store.
    """
    from langchain_pinecone import Pinecone

    # Lazily construct the default embedder only when the caller gave none.
    if embedding_fn is None:
        embedding_fn = KorRobertaEmbeddings()

    index = create_or_get_pinecone_index(
        index_name,
        dimension,
    )
    vs = Pinecone(
        index,
        embedding_fn,
        pinecone_api_key=PINECONE_API_KEY,
        index_name=index_name,
        namespace=namespace,
    )
    print(vs)
    return vs
def build_pinecone_retrieval_chain(vectorstore):
    """Wrap ``vectorstore`` in a parallel runnable for retrieval.

    The returned ``RunnableParallel`` has two branches: ``"context"`` is the
    retriever obtained from ``vectorstore.as_retriever()`` and ``"question"``
    is a ``RunnablePassthrough``.
    """
    branches = {
        "context": vectorstore.as_retriever(),
        "question": RunnablePassthrough(),
    }
    return RunnableParallel(branches)
@st.cache_resource
def get_pinecone_retrieval_chain(collection_name):
    """Build the retrieval chain for ``collection_name``.

    Wrapped in ``st.cache_resource``. The vector store is opened with a new
    ``KorRobertaEmbeddings`` instance, dimension 768 and namespace ``"0221"``.
    """
    print("☑️ Building a new pinecone retrieval chain...")
    store = get_pinecone_vectorstore(
        index_name=collection_name,
        embedding_fn=KorRobertaEmbeddings(),
        dimension=768,
        namespace="0221",
    )
    return build_pinecone_retrieval_chain(store)
def rerun():
    """Trigger a Streamlit rerun by delegating to ``st.rerun``."""
    st.rerun()
# --- Page body: runs top to bottom each time the script executes. ---
st.title("이노션 데모")

# Store the retrieval chain for the "innocean" collection in session state.
with st.spinner("환경 설정 중"):
    sst.retrieval_chain = get_pinecone_retrieval_chain(
        collection_name="innocean",
    )

if prompt := st.chat_input("정보 검색"):
    # Display user message in chat message container
    with st.chat_message("human"):
        st.markdown(prompt)

    # Get assistant response
    outputs = sst.retrieval_chain.invoke(prompt)
    print(outputs)
    retrieval_docs = outputs["context"]

    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        if not retrieval_docs:
            # The retriever may return no documents; indexing [0] would then
            # raise IndexError, so tell the user instead.
            st.warning("검색 결과가 없습니다.")
        else:
            # Only the first retrieved document is shown.
            top_metadata = retrieval_docs[0].metadata
            st.markdown(top_metadata["answer"])
            with st.expander("출처 보기", expanded=True):
                st.info(f"출처 페이지: {top_metadata['page']}")
                st.markdown(top_metadata["source_passage"])

    # Disabled debug view: one tab per retrieved document.
    # tabs = st.tabs([f"doc{i}" for i in range(len(retrieval_docs))])
    # for i in range(len(retrieval_docs)):
    #     tabs[i].write(retrieval_docs[i].page_content)
    #     tabs[i].write(retrieval_docs[i].metadata)