# simpleRAG/rag.py
import os

from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import FastEmbedEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.llms import HuggingFaceHub
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain.vectorstores.utils import filter_complex_metadata

prompt_text = """
<s> [INST] You are an assistant for question-answering tasks. Use the following pieces of retrieved context
to answer the question. If you don't know the answer, just say that you don't know. Use three sentences
maximum and keep the answer concise. [/INST] </s>
[INST] Question: {question}
Context: {context}
Answer: [/INST]
"""
HUGGING_FACE_API_TOKEN = os.getenv("HUGGING_FACE_API_TOKEN")


class ChatPDF:
    """Retrieval-augmented QA over a single PDF using Mistral-7B via the Hugging Face Hub."""

    vector_store = None
    retriever = None
    chain = None

    def __init__(self):
        # Hosted instruct model queried through the Hugging Face Inference API.
        repo_id = "mistralai/Mistral-7B-Instruct-v0.2"
        self.model = HuggingFaceHub(
            huggingfacehub_api_token=HUGGING_FACE_API_TOKEN,
            repo_id=repo_id,
            model_kwargs={"temperature": 0.8, "max_new_tokens": 100},
        )
        # Overlapping chunks keep retrieved context coherent across chunk boundaries.
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=100)
        self.prompt = PromptTemplate.from_template(prompt_text)

    def ingest(self, pdf_file_path: str):
        # Load the PDF, split it into chunks, and strip metadata types Chroma cannot store.
        docs = PyPDFLoader(file_path=pdf_file_path).load()
        chunks = self.text_splitter.split_documents(docs)
        chunks = filter_complex_metadata(chunks)

        # Keep the vector store on the instance so clear() can release it later.
        self.vector_store = Chroma.from_documents(documents=chunks, embedding=FastEmbedEmbeddings())
        self.retriever = self.vector_store.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "k": 3,
                "score_threshold": 0.5,
            },
        )

        # The retriever fills {context}; the raw query passes through as {question}.
        self.chain = (
            {"context": self.retriever, "question": RunnablePassthrough()}
            | self.prompt
            | self.model
            | StrOutputParser()
        )

    def ask(self, query: str):
        if not self.chain:
            return "Please, add a PDF document first."
        return self.chain.invoke(query)

    def clear(self):
        self.vector_store = None
        self.retriever = None
        self.chain = None
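

# Minimal usage sketch (an assumption, not part of the original file): it requires a local
# PDF such as "example.pdf" (hypothetical path) and the HUGGING_FACE_API_TOKEN environment
# variable to be set before the script is run.
if __name__ == "__main__":
    chat = ChatPDF()
    chat.ingest("example.pdf")  # hypothetical path to the PDF to index
    print(chat.ask("What is this document about?"))
    chat.clear()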