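"""Streamlit demo: Korean QA retrieval over a Pinecone serverless index.

User questions are embedded with a Korean RoBERTa model (KorRobertaEmbeddings
from the local embeddings.py, assumed to implement LangChain's Embeddings
interface and produce 768-dim vectors) and matched against pre-indexed QA
pairs; the stored answer and its source passage are shown in a chat UI.

Requires PINECONE_API_KEY in .streamlit/secrets.toml.
"""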
from embeddings import KorRobertaEmbeddings

import streamlit as st
from streamlit import session_state as sst

from langchain_core.runnables import (
    RunnablePassthrough,
    RunnableParallel,
)

PINECONE_API_KEY = st.secrets["PINECONE_API_KEY"]


def create_or_get_pinecone_index(index_name: str, dimension: int = 768):
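    """Return a handle to `index_name`, creating a serverless Pinecone index
    (cosine metric) if it does not already exist."""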
    import time

    from pinecone import Pinecone, ServerlessSpec

    client = Pinecone(api_key=PINECONE_API_KEY)
    if index_name in client.list_indexes().names():
        pc_index = client.Index(index_name)
        print("☑️ Got the existing Pinecone index")
    else:
        client.create_index(
            name=index_name,
            dimension=dimension,
            metric="cosine",
            spec=ServerlessSpec(cloud="aws", region="us-west-2"),
        )
        # A newly created serverless index can take a few seconds to become
        # queryable; poll its status before handing it back.
        while not client.describe_index(index_name).status["ready"]:
            time.sleep(1)
        pc_index = client.Index(index_name)
        print("☑️ Created a new Pinecone index")

    print(pc_index.describe_index_stats())
    return pc_index


def get_pinecone_vectorstore(
    index_name: str,
    embedding_fn=None,
    dimension: int = 768,
    namespace: str | None = None,
):
    """Build a LangChain vectorstore over the Pinecone index."""
    from langchain_pinecone import PineconeVectorStore

    # Instantiate the embedding model lazily rather than in the default
    # argument, which would load it once at import time even if unused.
    if embedding_fn is None:
        embedding_fn = KorRobertaEmbeddings()

    index = create_or_get_pinecone_index(
        index_name,
        dimension,
    )
    # The index handle is passed directly, so the index_name / API-key
    # kwargs the old `Pinecone` alias accepted are redundant here.
    vs = PineconeVectorStore(
        index=index,
        embedding=embedding_fn,
        namespace=namespace,
    )
    print(vs)
    return vs
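
# The chat UI below reads "answer", "page", and "source_passage" from each
# retrieved document's metadata, so the index is assumed to be pre-populated
# with QA pairs carrying those fields. A minimal ingestion sketch (texts and
# metadata here are hypothetical, not part of this app):
#
#   vs = get_pinecone_vectorstore(index_name="innocean", namespace="0221")
#   vs.add_texts(
#       texts=["What are the warranty terms?"],
#       metadatas=[{"answer": "...", "page": 12, "source_passage": "..."}],
#   )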


def build_pinecone_retrieval_chain(vectorstore):
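    """Wrap the vectorstore's retriever so that invoking the chain with a
    query string yields {"context": [Document, ...], "question": "<query>"},
    since RunnableParallel fans the input out to both branches."""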
    retriever = vectorstore.as_retriever()
    rag_chain_with_source = RunnableParallel(
        {"context": retriever, "question": RunnablePassthrough()}
    )

    return rag_chain_with_source


@st.cache_resource
def get_pinecone_retrieval_chain(collection_name):
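    """Build the retrieval chain once and cache it via st.cache_resource."""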
    print("☑️ Building a new pinecone retrieval chain...")
    embed_fn = KorRobertaEmbeddings()
    pinecone_vectorstore = get_pinecone_vectorstore(
        index_name=collection_name,
        embedding_fn=embed_fn,
        dimension=768,
        namespace="0221",
    )

    chain = build_pinecone_retrieval_chain(pinecone_vectorstore)
    return chain


def rerun():
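    # Convenience wrapper around st.rerun(); currently unused in this script.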
    st.rerun()


st.title("이노션 데모")

with st.spinner("환경 설정 중"):
    sst.retrieval_chain = get_pinecone_retrieval_chain(
        collection_name="innocean",
    )

if prompt := st.chat_input("Search for information"):

    # Display user message in chat message container
    with st.chat_message("human"):
        st.markdown(prompt)

    # Retrieve the most similar pre-indexed QA pairs (no LLM call; the chain
    # is retrieval-only)
    outputs = sst.retrieval_chain.invoke(prompt)
    print(outputs)
    retrieval_docs = outputs["context"]

    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        # Guard against an empty result set before indexing into the docs
        if not retrieval_docs:
            st.warning("No matching documents were found.")
            st.stop()
        st.markdown(retrieval_docs[0].metadata["answer"])

        with st.expander("출처 보기", expanded=True):
            st.info(f"출처 페이지: {retrieval_docs[0].metadata['page']}")
            st.markdown(retrieval_docs[0].metadata["source_passage"])
            # tabs = st.tabs([f"doc{i}" for i in range(len(retrieval_docs))])
            # for i in range(len(retrieval_docs)):
            #     tabs[i].write(retrieval_docs[i].page_content)
            #     tabs[i].write(retrieval_docs[i].metadata)