import os
import platform

import gradio as gr

from langchain_community.document_loaders import ObsidianLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter, Language
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain_cohere import CohereRerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_core.runnables import ConfigurableField, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq
from langchain_google_genai import GoogleGenerativeAI

from prompt_template import PROMPT_TEMPLATE

DIRECTORIES = ["./docs/obsidian-help", "./docs/obsidian-developer"]
FAISS_DB_INDEX = "db_index"


def load_and_process_documents(directories):
    """Load Obsidian markdown documents and split them into overlapping chunks."""
    md_docs = []
    for directory in directories:
        try:
            loader = ObsidianLoader(directory, encoding="utf-8")
            md_docs.extend(loader.load())
        except Exception:
            # Skip directories that are missing or unreadable rather than abort.
            pass

    # Split along markdown structure so chunks respect headings and code blocks.
    md_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.MARKDOWN,
        chunk_size=2000,
        chunk_overlap=200,
    )
    return md_splitter.split_documents(md_docs)


def setup_retrieval_system(splitted_docs):
    """Build a hybrid BM25 + FAISS retriever with Cohere reranking on top."""
    # Use Apple's Metal (MPS) backend on macOS; otherwise fall back to CPU.
    if platform.system() == "Darwin":
        model_kwargs = {"device": "mps"}
    else:
        model_kwargs = {"device": "cpu"}

    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-m3",
        model_kwargs=model_kwargs,
        encode_kwargs={"normalize_embeddings": True},
    )

    # Cache embeddings on disk so repeated runs don't re-embed unchanged chunks.
    store = LocalFileStore("./.cache/")
    cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
        embeddings,
        store,
        namespace=embeddings.model_name,
    )

    # Reuse a persisted FAISS index when available; otherwise build and save one.
    if os.path.exists(FAISS_DB_INDEX):
        db = FAISS.load_local(
            FAISS_DB_INDEX,
            cached_embeddings,
            allow_dangerous_deserialization=True,
        )
    else:
        db = FAISS.from_documents(splitted_docs, cached_embeddings)
        db.save_local(folder_path=FAISS_DB_INDEX)

    faiss_retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 10})

    bm25_retriever = BM25Retriever.from_documents(splitted_docs)
    bm25_retriever.k = 10

    # Combine keyword (BM25) and dense (FAISS) retrieval with equal weights.
    # MMR diversification is already applied by the FAISS retriever above;
    # EnsembleRetriever itself takes no search_type parameter.
    ensemble_retriever = EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever],
        weights=[0.5, 0.5],
    )

    # Rerank the merged candidates and keep only the five most relevant.
    compressor = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)
    return ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=ensemble_retriever,
    )


def setup_language_model():
    """Default to Llama 3 on Groq, with Gemini selectable at runtime.

    Switch models per call with:
        chain.with_config(configurable={"llm": "gemini"})
    """
    return ChatGroq(
        model_name="llama3-70b-8192",
        temperature=0,
    ).configurable_alternatives(
        ConfigurableField(id="llm"),
        default_key="llama3",
        gemini=GoogleGenerativeAI(
            model="gemini-pro",
            temperature=0,
        ),
    )


def format_docs(docs):
    """Render retrieved documents as plain text, citing each source path."""
    formatted_docs = []
    for doc in docs:
        formatted_doc = f"Page Content:\n{doc.page_content}\n"
        if doc.metadata.get("source"):
            formatted_doc += f"Source: {doc.metadata['source']}\n"
        formatted_docs.append(formatted_doc)
    return "\n---\n".join(formatted_docs)


def main():
    splitted_docs = load_and_process_documents(DIRECTORIES)
    compression_retriever = setup_retrieval_system(splitted_docs)
    llm = setup_language_model()

    rag_chain = (
        {
            "context": compression_retriever | format_docs,
            "question": RunnablePassthrough(),
        }
        | PROMPT_TEMPLATE
        | llm
        | StrOutputParser()
    )

    def predict(message, history=None):
        return rag_chain.invoke(message)

    gr.ChatInterface(
        predict,
        title="Ask about the Obsidian note app and plugin development!",
        description=(
            "Hello!\n"
            "I am an AI QA bot for the Obsidian note app and plugin development. "
            "I have deep knowledge of how to use Obsidian, its advanced features, "
            "and plugin and theme development. If you need help with writing "
            "documents, organizing information, or development, feel free to ask "
            "anytime!"
        ),
    ).launch()


if __name__ == "__main__":
    main()