|
from embeddings import KorRobertaEmbeddings |
|
|
|
import streamlit as st |
|
from streamlit import session_state as sst |
|
|
|
from langchain_core.runnables import ( |
|
RunnablePassthrough, |
|
RunnableParallel, |
|
) |
|
|
|
PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"] |
|
|
|
|
|
def create_or_get_pinecone_index(index_name: str, dimension: int = 768):
    """Return a handle to the Pinecone index `index_name`, creating it if needed.

    Args:
        index_name: Name of the Pinecone index.
        dimension: Embedding dimension used when the index must be created.

    Returns:
        A `pinecone.Index` handle for the (possibly newly created) index.
    """
    import time

    from pinecone import Pinecone, ServerlessSpec

    client = Pinecone(api_key=PINECONE_API_KEY)

    existing_names = [index["name"] for index in client.list_indexes()]
    if index_name in existing_names:
        pc_index = client.Index(index_name)
        print("☑️ Got the existing Pinecone index")
    else:
        client.create_index(
            name=index_name,
            dimension=dimension,
            metric="cosine",
            # NOTE(review): region was hard-coded upstream; confirm it matches
            # the project's Pinecone deployment.
            spec=ServerlessSpec(cloud="aws", region="us-west-2"),
        )
        # A freshly created serverless index is not immediately queryable;
        # poll until Pinecone reports it ready before handing it out.
        while not client.describe_index(index_name).status["ready"]:
            time.sleep(1)
        pc_index = client.Index(index_name)
        print("☑️ Created a new Pinecone index")

    print(pc_index.describe_index_stats())
    return pc_index
|
|
|
|
|
def get_pinecone_vectorstore(
    index_name: str,
    embedding_fn=None,
    dimension: int = 768,
    namespace=None,
):
    """Build a LangChain `Pinecone` vector store over the given index.

    Args:
        index_name: Pinecone index to connect to (created if missing).
        embedding_fn: Embedding callable; defaults to a fresh
            `KorRobertaEmbeddings()` instance. (Previously the default was
            evaluated once at import time and shared across all calls —
            a mutable-default bug.)
        dimension: Embedding dimension, forwarded to index creation.
        namespace: Optional Pinecone namespace to scope reads/writes.

    Returns:
        A `langchain_pinecone.Pinecone` vector store.
    """
    from langchain_pinecone import Pinecone

    # Instantiate the default embedder per call, not at import time.
    if embedding_fn is None:
        embedding_fn = KorRobertaEmbeddings()

    index = create_or_get_pinecone_index(
        index_name,
        dimension,
    )
    vs = Pinecone(
        index,
        embedding_fn,
        pinecone_api_key=PINECONE_API_KEY,
        index_name=index_name,
        namespace=namespace,
    )
    print(vs)
    return vs
|
|
|
|
|
def build_pinecone_retrieval_chain(vectorstore):
    """Wrap `vectorstore` in a parallel runnable that returns both the
    retrieved context documents and the original question.

    Args:
        vectorstore: Any vector store exposing `as_retriever()`.

    Returns:
        A `RunnableParallel` mapping "context" to retrieval results and
        "question" to the untouched input.
    """
    return RunnableParallel(
        {
            "context": vectorstore.as_retriever(),
            "question": RunnablePassthrough(),
        }
    )
|
|
|
|
|
@st.cache_resource
def get_pinecone_retrieval_chain(collection_name):
    """Construct (and cache, via Streamlit) a retrieval chain over the
    Pinecone index named `collection_name`.

    Args:
        collection_name: Pinecone index name backing the chain.

    Returns:
        A runnable retrieval chain producing {"context", "question"}.
    """
    print("☑️ Building a new pinecone retrieval chain...")
    vectorstore = get_pinecone_vectorstore(
        index_name=collection_name,
        embedding_fn=KorRobertaEmbeddings(),
        dimension=768,
        namespace="0221",
    )
    return build_pinecone_retrieval_chain(vectorstore)
|
|
|
|
|
def rerun():
    """Trigger an immediate Streamlit script rerun (thin wrapper)."""
    st.rerun()
|
|
|
|
|
# Page header for the demo app ("Innocean demo").
st.title("이노션 데모")

# Build (or fetch from st.cache_resource) the retrieval chain before
# accepting any user input; the spinner covers first-run index setup.
with st.spinner("환경 설정 중"):
    sst.retrieval_chain = get_pinecone_retrieval_chain(
        collection_name="innocean",
    )
|
|
|
# Chat loop: echo the user's query, run retrieval, and render the top hit.
if prompt := st.chat_input("정보 검색"):

    with st.chat_message("human"):
        st.markdown(prompt)

    # Chain returns {"context": [Document, ...], "question": prompt}.
    outputs = sst.retrieval_chain.invoke(prompt)
    print(outputs)
    retrieval_docs = outputs["context"]

    if retrieval_docs:
        top_doc = retrieval_docs[0]

        with st.chat_message("assistant"):
            # assumes each Document's metadata carries "answer" — TODO confirm
            st.markdown(top_doc.metadata["answer"])

        with st.expander("출처 보기", expanded=True):
            st.info(f"출처 페이지: {top_doc.metadata['page']}")
            st.markdown(top_doc.metadata["source_passage"])
    else:
        # Previously this path raised IndexError on retrieval_docs[0];
        # show an explicit "no results" message instead.
        with st.chat_message("assistant"):
            st.markdown("검색 결과가 없습니다.")
|
|
|
|
|
|
|
|
|
|