# obsidian-qa-bot / app.py
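"""Gradio QA bot for the Obsidian help and developer documentation.

Pipeline: load markdown from the Obsidian vaults, split it into chunks,
retrieve with a hybrid BM25 + FAISS ensemble, rerank with Cohere, and
answer with Groq Llama 3 (or Gemini, selectable at runtime).
"""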
import os
import platform

import gradio as gr
from langchain_community.document_loaders import ObsidianLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter, Language
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain_cohere import CohereRerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_core.runnables import ConfigurableField, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq
from langchain_google_genai import GoogleGenerativeAI
from prompt_template import PROMPT_TEMPLATE

# Obsidian vault directories to index: the official help docs and developer docs.
DIRECTORIES = ["./docs/obsidian-help", "./docs/obsidian-developer"]
FAISS_DB_INDEX = "db_index"


def load_and_process_documents(directories):
    """Load markdown notes from each Obsidian vault and split them into chunks."""
    md_docs = []
    for directory in directories:
        try:
            loader = ObsidianLoader(directory, encoding="utf-8")
            md_docs.extend(loader.load())
        except Exception:
            # Skip directories that are missing or unreadable rather than aborting.
            pass
    # Markdown-aware splitter: chunks follow heading boundaries where possible.
    md_splitter = RecursiveCharacterTextSplitter.from_language(
        language=Language.MARKDOWN,
        chunk_size=2000,
        chunk_overlap=200,
    )
    return md_splitter.split_documents(md_docs)


def setup_retrieval_system(splitted_docs):
    """Build a hybrid BM25 + FAISS retriever with Cohere reranking on top."""
    # Use Apple's Metal backend on macOS, otherwise fall back to CPU.
    if platform.system() == "Darwin":
        model_kwargs = {"device": "mps"}
    else:
        model_kwargs = {"device": "cpu"}
    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-m3",
        model_kwargs=model_kwargs,
        encode_kwargs={"normalize_embeddings": True},
    )
    # Cache embeddings on disk so repeated runs don't re-embed unchanged chunks.
    store = LocalFileStore("./.cache/")
    cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
        embeddings,
        store,
        namespace=embeddings.model_name,
    )
    # Reuse a previously saved FAISS index if one exists; otherwise build and save it.
    if os.path.exists(FAISS_DB_INDEX):
        db = FAISS.load_local(
            FAISS_DB_INDEX,
            cached_embeddings,
            allow_dangerous_deserialization=True,
        )
    else:
        db = FAISS.from_documents(splitted_docs, cached_embeddings)
        db.save_local(folder_path=FAISS_DB_INDEX)
    # Dense retrieval with maximal marginal relevance for diverse results.
    faiss_retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 10})
    # Sparse keyword retrieval to complement the dense retriever.
    bm25_retriever = BM25Retriever.from_documents(splitted_docs)
    bm25_retriever.k = 10
    # Merge both result lists with equal weight.
    ensemble_retriever = EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever],
        weights=[0.5, 0.5],
    )
    # Rerank the merged candidates and keep only the five most relevant.
    compressor = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)
    return ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=ensemble_retriever,
    )
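

# Example (illustrative only, not called below): the returned retriever is a
# Runnable, so reranked chunks can be fetched directly for a query:
#   docs = setup_retrieval_system(splitted_docs).invoke("How do I build a plugin?")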


def setup_language_model():
    """Return a Groq Llama 3 chat model, with Gemini as a configurable alternative."""
    return ChatGroq(
        model_name="llama3-70b-8192",
        temperature=0,
    ).configurable_alternatives(
        ConfigurableField(id="llm"),
        default_key="llama3",
        gemini=GoogleGenerativeAI(
            model="gemini-pro",
            temperature=0,
        ),
    )
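

# Example (illustrative only, not used below): route a single request to the
# Gemini alternative via the configurable field defined above:
#   setup_language_model().with_config(configurable={"llm": "gemini"}).invoke("Hello")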


def format_docs(docs):
    """Render retrieved documents as plain-text blocks separated by dividers."""
    formatted_docs = []
    for doc in docs:
        formatted_doc = f"Page Content:\n{doc.page_content}\n"
        if doc.metadata.get("source"):
            formatted_doc += f"Source: {doc.metadata['source']}\n"
        formatted_docs.append(formatted_doc)
    return "\n---\n".join(formatted_docs)


def main():
    """Wire up the RAG chain and serve it behind a Gradio chat UI."""
    splitted_docs = load_and_process_documents(DIRECTORIES)
    compression_retriever = setup_retrieval_system(splitted_docs)
    llm = setup_language_model()
    # Retrieve -> format context -> fill prompt -> generate -> parse to string.
    rag_chain = (
        {"context": compression_retriever | format_docs, "question": RunnablePassthrough()}
        | PROMPT_TEMPLATE
        | llm
        | StrOutputParser()
    )

    def predict(message, history=None):
        # Gradio supplies the chat history, but the chain is stateless and ignores it.
        return rag_chain.invoke(message)

    gr.ChatInterface(
        predict,
        title="Ask about the Obsidian note app and plugin development!",
        description=(
            "Hello!\n"
            "I'm an AI QA bot for the Obsidian note app and plugin development. "
            "I have deep knowledge of how to use Obsidian, its advanced features, "
            "and plugin and theme development. Feel free to ask whenever you need "
            "help with working with documents, organizing information, or development!"
        ),
    ).launch()


if __name__ == "__main__":
    main()