JBHF committed on
Commit
ac715c8
1 Parent(s): 2ca8a6c

Create rag_BACKUP.py

Files changed (1)
  1. rag_BACKUP.py +63 -0
rag_BACKUP.py ADDED
@@ -0,0 +1,63 @@
+ # rag_BACKUP.py
+ # rag.py
+ # https://github.com/vndee/local-rag-example/blob/main/rag.py
+
+ from langchain.vectorstores import Chroma
+ from langchain.chat_models import ChatOllama
+ from langchain.embeddings import FastEmbedEmbeddings
+ from langchain.schema.output_parser import StrOutputParser
+ from langchain.document_loaders import PyPDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.schema.runnable import RunnablePassthrough
+ from langchain.prompts import PromptTemplate
+ from langchain.vectorstores.utils import filter_complex_metadata
+
+
+ class ChatPDF:
+     vector_store = None
+     retriever = None
+     chain = None
+
+     def __init__(self):
+         self.model = ChatOllama(model="mistral")
+         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=100)
+         self.prompt = PromptTemplate.from_template(
+             """
+             <s> [INST] You are an assistant for question-answering tasks. Use the following pieces of retrieved context
+             to answer the question. If you don't know the answer, just say that you don't know. Use three sentences
+             maximum and keep the answer concise. [/INST] </s>
+             [INST] Question: {question}
+             Context: {context}
+             Answer: [/INST]
+             """
+         )
+
+     def ingest(self, pdf_file_path: str):
+         docs = PyPDFLoader(file_path=pdf_file_path).load()
+         chunks = self.text_splitter.split_documents(docs)
+         chunks = filter_complex_metadata(chunks)
+
+         self.vector_store = Chroma.from_documents(documents=chunks, embedding=FastEmbedEmbeddings())
+         self.retriever = self.vector_store.as_retriever(
+             search_type="similarity_score_threshold",
+             search_kwargs={
+                 "k": 3,
+                 "score_threshold": 0.5,
+             },
+         )
+
+         self.chain = ({"context": self.retriever, "question": RunnablePassthrough()}
+                       | self.prompt
+                       | self.model
+                       | StrOutputParser())
+
+     def ask(self, query: str):
+         if not self.chain:
+             return "Please, add a PDF document first."
+
+         return self.chain.invoke(query)
+
+     def clear(self):
+         self.vector_store = None
+         self.retriever = None
+         self.chain = None
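
Below is a minimal usage sketch of the ChatPDF class added by this commit. It is not part of the commit itself: it assumes a local Ollama server is running with the mistral model already pulled, and the module name, script name, and PDF path (example.pdf) are illustrative assumptions.

    # usage_sketch.py -- illustrative only; model setup and PDF path are assumptions
    from rag_BACKUP import ChatPDF

    assistant = ChatPDF()                    # requires a local Ollama server serving the "mistral" model
    assistant.ingest("example.pdf")          # load the PDF, split it into chunks, embed them, build retriever and chain
    print(assistant.ask("What is this document about?"))   # run the RAG chain; returns a concise answer string
    assistant.clear()                        # drop the vector store, retriever, and chain before ingesting a new PDF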