File size: 4,802 Bytes
2b5265f
ae244b4
43a5958
e74be72
ecdef15
48f76d5
ecdef15
 
 
2b5265f
43a5958
ecdef15
cabf339
ecdef15
f7f0473
 
cabf339
f7f0473
f9a0a33
 
 
ecdef15
cabf339
 
392a7b7
cabf339
ecdef15
48f76d5
43a5958
ecdef15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392a7b7
 
 
 
 
 
 
 
ecdef15
 
392a7b7
 
ecdef15
 
 
 
 
 
392a7b7
 
ecdef15
 
82ac073
ecdef15
43a5958
 
ecdef15
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import pymongo
import os, textwrap
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
# from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFaceHub
# from langchain_community.document_loaders import PyPDFLoader
# from langchain_community.document_loaders import DirectoryLoader
# from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from flask import Flask, request, render_template
from flask_cors import CORS
from gevent import pywsgi

app = Flask(__name__)
CORS(app)  # allow cross-origin requests so a browser frontend on another host can call this API

# Load credentials from the environment (must be set before startup).
mongodb_connection_string = os.getenv("MONGODB_CONNECTION_STRING")
HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
   
# Initialize the models once at import time (module-level singletons):
# - instructor_embeddings loads hkunlp/instructor-xl locally (slow on first run),
#   and must match the model used when the Atlas index was built.
# - llm calls Mistral-7B-Instruct remotely via the HuggingFace Hub inference API.
instructor_embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl")
llm=HuggingFaceHub(repo_id="mistralai/Mistral-7B-Instruct-v0.2", model_kwargs={"temperature":0.1 ,"max_length":512})
 
@app.route('/',methods=['GET','POST'])
def main():
    """Answer a question against the textbook vector index.

    GET  -> renders the search page (index.html).
    POST -> reads the question from the 'q' query parameter, retrieves the
            top-3 most similar chunks from the MongoDB Atlas vector index,
            and asks the Mistral LLM to answer from that context only.

    Returns a JSON-serializable dict {"result": str, "source": list} on
    success, or the string "Error" when the model reply contains no
    "Answer:" marker (kept as-is for backward compatibility with callers).
    """
    if request.method != 'POST':
        return render_template('index.html')

    query = request.args.get('q')
    print("==================== query is -",query,"====================")
    if not query:
        # Guard: without a question the chain call below would crash.
        return {"result": "No query provided.", "source": []}

    # Build a retriever over the pre-populated Atlas vector index; query
    # embeddings use the same instructor model that built the index.
    # (The previous per-request pymongo.MongoClient was unused and leaked
    # a connection per request — removed.)
    vector_search = MongoDBAtlasVectorSearch.from_connection_string(
        mongodb_connection_string,
        "database" + "." + "textbooks",
        instructor_embeddings,
        index_name="search",
    )
    retriever = vector_search.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 3},
    )

    # Prompt instructing the model to answer only from retrieved context.
    prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

        {context}

        Question: {question}
        """
    PROMPT = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )

    # "stuff" chain: all retrieved chunks are concatenated into one prompt.
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": PROMPT},
    )

    llm_response = qa_chain(query)
    res = llm_response['result']
    print(res)

    # Collect [title, page] pairs from the retrieved documents.
    # NOTE(review): the [10:-4] slice presumably strips a directory prefix
    # and the ".pdf" extension from the stored path — confirm against the
    # indexing pipeline. Missing metadata no longer raises here.
    source = []
    for doc in llm_response['source_documents']:
        path = doc.metadata.get('source') or ''
        page = doc.metadata.get('page')
        source.append([path[10:-4], page + 1 if page is not None else None])

    index_helpful_answer = res.find("Answer:")
    if index_helpful_answer == -1:
        return("Error")
    helpful_answer_text = res[index_helpful_answer + len("Answer:"):]
    # Hide sources when the model admits it doesn't know the answer.
    return({"result": helpful_answer_text,
            "source": source if "I don't know" not in helpful_answer_text else []})

if __name__ == '__main__':
    # Serve the Flask app on all interfaces, port 7860, via gevent's WSGI
    # server (production-style; avoids Flask's single-threaded dev server).
    pywsgi.WSGIServer(('0.0.0.0', 7860), app).serve_forever()
    # app.run(host="0.0.0.0", port=7860)