Prashant Kumar committed on
Commit
e4b7696
1 Parent(s): 232b6a0

added models and dockerfile

Browse files
Files changed (4) hide show
  1. Dockerfile +11 -0
  2. ingest.py +28 -0
  3. model.py +92 -0
  4. requirements.txt +8 -0
Dockerfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# Image for the PDF question-answering bot: installs the Python dependencies,
# copies the project in, builds the FAISS index and starts the Chainlit UI.
FROM python:3.9

WORKDIR /app

# Copy the dependency list on its own first so the pip layer stays cached
# until requirements.txt itself changes.
COPY ./requirements.txt /app/
RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY . /app/

# Only the last CMD in a Dockerfile takes effect, so the ingestion step and the
# app launch are chained in a single command. `exec` hands PID 1 to chainlit so
# it receives container stop signals directly.
CMD ["sh", "-c", "python ingest.py && exec chainlit run model.py -w"]
ingest.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
2
+ from langchain.document_loaders import PyPDFLoader, DirectoryLoader
3
+ from langchain.embeddings import HuggingFaceEmbeddings
4
+ from langchain.vectorstores import FAISS
5
+
6
# Directory scanned for source PDFs and the folder the FAISS index is saved to.
DATA_PATH = "data/"
DB_FAISS_PATH = "vectorstores/db_faiss"

# model path:
# https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/blob/main/llama-2-7b-chat.ggmlv3.q8_0.bin


def create_vector_db():
    """Build a FAISS index from the PDFs in DATA_PATH and save it to DB_FAISS_PATH.

    Every ``*.pdf`` file directly under DATA_PATH is loaded with PyPDFLoader,
    split into 500-character chunks with a 50-character overlap, embedded on
    the CPU and written to disk with ``FAISS.save_local``.
    """
    loader = DirectoryLoader(
        DATA_PATH,
        glob='*.pdf',
        loader_cls=PyPDFLoader,
    )
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)

    # The model id must be spelled "all-MiniLM-L6-v2" (with the hyphen): the
    # index has to be built with the same embedding model that is later used
    # to embed queries against it.
    embeddings = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        model_kwargs={'device': 'cpu'},
    )
    db = FAISS.from_documents(texts, embeddings)
    db.save_local(DB_FAISS_PATH)


if __name__ == '__main__':
    create_vector_db()
model.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain import PromptTemplate
2
+ from langchain.embeddings import HuggingFaceEmbeddings
3
+ from langchain.vectorstores import FAISS
4
+ from langchain.llms import CTransformers
5
+ from langchain.chains import RetrievalQA
6
+ import chainlit as cl
7
+
8
# Folder the FAISS index is loaded from.
DB_FAISS_PATH = "vectorstores/db_faiss"

# Prompt text with two named placeholders, {context} and {question}. Both must
# be named: an anonymous "{}" cannot be filled by keyword and does not match
# the 'context' variable declared for the prompt template.
custom_prompt_template = """Use the following pieces of information to answer the user's questions.
If you don't know the answer, don't try to make up an answer, just say that you do not know it.

Context: {context}
Question: {question}

Only returns the helpful answer below and nothing else.
Helpful answer:
"""
19
+
20
def set_custom_prompt():
    """
    Prompt template for QA retrieval for each vector stores

    Returns a PromptTemplate built from ``custom_prompt_template`` that
    declares ``context`` and ``question`` as its input variables.
    """
    # The keyword is `input_variables` (plural); the singular spelling is not
    # a PromptTemplate field.
    prompt = PromptTemplate(
        template=custom_prompt_template,
        input_variables=['context', 'question'],
    )
    return prompt
27
+
28
def load_llm():
    """Create the CTransformers LLM wrapper for the local Llama-2 chat model.

    Returns the configured ``CTransformers`` instance. The model file name is
    relative, so it is resolved against the process working directory.
    """
    llm = CTransformers(
        model="llama-2-7b-chat.ggmlv3.q8_0.bin",
        model_type="llama",
        # Generation settings go through `config`; the wrapper forwards that
        # dict to the underlying ctransformers model, whereas top-level
        # keyword arguments are not passed on.
        config={'max_new_tokens': 512, 'temperature': 0.5},
    )
    return llm
36
+
37
def retrieval_qa_chain(llm, prompt, db):
    """Build a "stuff" RetrievalQA chain over the given vector store.

    Args:
        llm: Language model used to generate the answer.
        prompt: Prompt template handed to the combine-documents step.
        db: Vector store; its retriever is asked for the top 2 matches.

    Returns:
        The RetrievalQA chain, configured to also return source documents.
    """
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={'k': 2}),
        # These two options belong to the chain, not to the retriever: passed
        # to as_retriever() they never reach RetrievalQA, so the custom prompt
        # is not applied and no source documents are returned.
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )

    return qa_chain
49
+
50
def qa_bot():
    """Assemble the QA chain: embeddings, saved FAISS index, LLM and prompt.

    Loads the FAISS index from DB_FAISS_PATH with CPU sentence-transformer
    embeddings and wires it, the LLM and the custom prompt into a retrieval
    chain, which is returned.
    """
    embedder = HuggingFaceEmbeddings(
        model_name='sentence-transformers/all-MiniLM-L6-v2',
        model_kwargs={'device': 'cpu'},
    )
    vector_store = FAISS.load_local(DB_FAISS_PATH, embedder)
    language_model = load_llm()
    prompt = set_custom_prompt()
    return retrieval_qa_chain(language_model, prompt, vector_store)
57
+
58
def final_result(query):
    """Run a single query through a freshly built QA chain.

    A new chain is created via ``qa_bot()`` on every call; the chain's
    response for ``{'query': query}`` is returned unchanged.
    """
    chain = qa_bot()
    return chain({'query': query})
62
+
63
+
64
+ ## Chainlit ##
65
+
66
@cl.on_chat_start
async def start():
    """Chat-start hook: build the QA chain, greet the user, store the chain.

    The chain is built first, then a placeholder message is sent and updated
    with the greeting, and finally the chain is saved in the user session
    under the key "chain".
    """
    qa_chain = qa_bot()

    greeting = cl.Message(content="Starting the bot...")
    await greeting.send()

    # Replace the placeholder text once start-up has finished.
    greeting.content = "Hi! I am Jarvis, what's your query?"
    await greeting.update()

    cl.user_session.set("chain", qa_chain)
74
+
75
@cl.on_message
async def main(message):
    """Message hook: answer the user's message with the session's QA chain.

    Args:
        message: The incoming user message, either a plain string or an
            object exposing the text as ``.content``.

    The chain stored under "chain" in the user session is called
    asynchronously, and the answer (followed by the source documents, or a
    note that none were found) is sent back as a new message.
    """
    chain = cl.user_session.get("chain")

    # NOTE(review): newer Chainlit releases appear to pass a cl.Message object
    # here while older ones passed the raw string; accept both so the chain
    # always receives plain text.
    query = getattr(message, "content", message)

    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True,
        stream_prefix_tokens=["FINAL", "ANSWER"],
    )
    cb.answer_reached = True
    res = await chain.acall(query, callbacks=[cb])
    answer = res['result']
    # .get() keeps the handler working when the chain returns no
    # 'source_documents' key; that case falls into the "no sources" branch.
    sources = res.get('source_documents')

    if sources:
        answer += "\nSources: " + str(sources)
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer).send()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
pypdf
langchain
chainlit
torch
accelerate
transformers
sentence_transformers
faiss_cpu
# Backend required by langchain.llms.CTransformers, which model.py uses.
ctransformers