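"""Gradio chatbot that answers questions about the Obsidian note-taking app
and the development of Obsidian plugins and themes.

Pipeline: Obsidian markdown docs -> markdown-aware chunking -> cached BGE-M3
embeddings in a FAISS index -> hybrid BM25 + dense (MMR) retrieval -> Cohere
reranking -> Llama 3 on Groq (or Gemini) for answer generation -> Gradio chat UI.

Expects GROQ_API_KEY and COHERE_API_KEY in the environment, plus
GOOGLE_API_KEY if the Gemini alternative is selected.
"""
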
import os
import gradio as gr
import platform
from langchain_community.document_loaders import ObsidianLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter, Language
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain_cohere import CohereRerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_core.runnables import ConfigurableField, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq
from langchain_google_genai import GoogleGenerativeAI

from prompt_template import PROMPT_TEMPLATE

# Local checkouts of the official Obsidian help and developer docs.
DIRECTORIES = ["./docs/obsidian-help", "./docs/obsidian-developer"]
# Directory where the FAISS index is persisted between runs.
FAISS_DB_INDEX = "db_index"

def load_and_process_documents(directories):
    """Load Obsidian markdown notes and split them into retrieval-sized chunks."""
    md_docs = []
    for directory in directories:
        try:
            loader = ObsidianLoader(directory, encoding="utf-8")
            md_docs.extend(loader.load())
        except Exception as e:
            # Skip directories that are missing or unreadable, but surface
            # the problem instead of failing silently.
            print(f"Warning: could not load documents from {directory}: {e}")

    # Split preferentially at Markdown boundaries (headings, code fences)
    # so chunks align with document structure.
    md_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.MARKDOWN,
        chunk_size=2000,
        chunk_overlap=200,
    )
    return md_splitter.split_documents(md_docs)

def setup_retrieval_system(splitted_docs):
    """Build a hybrid BM25 + FAISS retriever whose results are reranked by Cohere."""
    # Use Apple's Metal (MPS) backend on macOS; fall back to CPU elsewhere.
    if platform.system() == "Darwin":
        model_kwargs = {"device": "mps"}
    else:
        model_kwargs = {"device": "cpu"}

    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-m3",
        model_kwargs=model_kwargs,
        encode_kwargs={"normalize_embeddings": True},
    )

    # Cache embeddings on disk so repeated runs don't re-embed unchanged chunks.
    store = LocalFileStore("./.cache/")
    cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
        embeddings,
        store,
        namespace=embeddings.model_name,
    )

    # Reuse the persisted FAISS index when it exists; otherwise build and save it.
    if os.path.exists(FAISS_DB_INDEX):
        db = FAISS.load_local(
            FAISS_DB_INDEX,
            cached_embeddings,
            allow_dangerous_deserialization=True,
        )
    else:
        db = FAISS.from_documents(splitted_docs, cached_embeddings)
        db.save_local(folder_path=FAISS_DB_INDEX)

    # Dense retrieval with MMR for diverse results, plus sparse keyword
    # matching via BM25.
    faiss_retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 10})
    bm25_retriever = BM25Retriever.from_documents(splitted_docs)
    bm25_retriever.k = 10

    # Fuse the two result lists with equal weight (reciprocal rank fusion).
    ensemble_retriever = EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever],
        weights=[0.5, 0.5],
    )

    # Rerank the fused candidates with Cohere and keep only the top 5.
    compressor = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)
    return ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=ensemble_retriever,
    )

def setup_language_model():
    """Use Llama 3 70B on Groq by default, with Gemini as a configurable alternative."""
    return ChatGroq(
        model_name="llama3-70b-8192",
        temperature=0,
    ).configurable_alternatives(
        ConfigurableField(id="llm"),
        default_key="llama3",
        gemini=GoogleGenerativeAI(
            model="gemini-pro",
            temperature=0,
        ),
    )
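
# The "llm" alternative registered above can be selected per call, e.g.:
#   rag_chain.with_config(configurable={"llm": "gemini"}).invoke(question)
# (the Gradio UI below always uses the default Llama 3 model).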

def format_docs(docs):
    """Concatenate retrieved documents into one context string, citing sources."""
    formatted_docs = []
    for doc in docs:
        formatted_doc = f"Page Content:\n{doc.page_content}\n"
        if doc.metadata.get("source"):
            formatted_doc += f"Source: {doc.metadata['source']}\n"
        formatted_docs.append(formatted_doc)
    return "\n---\n".join(formatted_docs)

def main():
    splitted_docs = load_and_process_documents(DIRECTORIES)
    compression_retriever = setup_retrieval_system(splitted_docs)
    llm = setup_language_model()

    # LCEL RAG chain: retrieve and format context -> prompt -> LLM -> plain text.
    rag_chain = (
        {"context": compression_retriever | format_docs, "question": RunnablePassthrough()}
        | PROMPT_TEMPLATE
        | llm
        | StrOutputParser()
    )

    def predict(message, history=None):
        # gr.ChatInterface passes (message, history); the history is ignored,
        # so each question is answered independently.
        return rag_chain.invoke(message)

    gr.ChatInterface(
        predict,
        title="Ask about the Obsidian note app and plugin development!",
        description=(
            "Hello!\nI'm an AI QA bot for the Obsidian note app and plugin "
            "development. I have in-depth knowledge of how to use Obsidian, "
            "its advanced features, and plugin and theme development. Feel "
            "free to ask whenever you need help with note-taking, organizing "
            "information, or development!"
        ),
    ).launch()

if __name__ == "__main__":
    main()