import pymongo import os, textwrap from langchain.chains import RetrievalQA from langchain.prompts import PromptTemplate # from langchain_community.vectorstores import FAISS from langchain_community.llms import HuggingFaceHub # from langchain_community.document_loaders import PyPDFLoader # from langchain_community.document_loaders import DirectoryLoader # from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.vectorstores import MongoDBAtlasVectorSearch from langchain_community.embeddings import HuggingFaceInstructEmbeddings from flask import Flask, request, render_template from flask_cors import CORS from gevent import pywsgi app = Flask(__name__) CORS(app) # load env mongodb_connection_string = os.getenv("MONGODB_CONNECTION_STRING") HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN") # Initialize the models instructor_embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl") llm=HuggingFaceHub(repo_id="mistralai/Mistral-7B-Instruct-v0.2", model_kwargs={"temperature":0.1 ,"max_length":512}) @app.route('/',methods=['GET','POST']) def main(): if request.method == 'POST': # connect to mongodb client = pymongo.MongoClient(mongodb_connection_string) db = client.database collection = db.textbooks query = request.args.get('q') # query = unquote(query) print("==================== query is -",query,"====================") # query = 'What is the price of iphone 13?' # load pdfs from the Documents directory # loader = DirectoryLoader(f'./Documents/', glob="./*.pdf", loader_cls=PyPDFLoader) # documents = loader.load() # split the documents into chunks # text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) # texts = text_splitter.split_documents(documents) # create the retriever # db_instructEmbedd = FAISS.from_documents(texts, instructor_embeddings) # retriever = db_instructEmbedd.as_retriever(search_kwargs={"k": 3}) # retriever search type is similarity search # # create the retriever and do embedding # vector_search = MongoDBAtlasVectorSearch.from_documents( # documents=texts, # embedding=instructor_embeddings, # collection=collection, # index_name="default", # ) vector_search = MongoDBAtlasVectorSearch.from_connection_string( mongodb_connection_string, "database" + "." + "textbooks", instructor_embeddings, index_name="search", ) retriever = vector_search.as_retriever( search_type="similarity", search_kwargs={"k": 3}, ) # prompt template prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. {context} Question: {question} """ PROMPT = PromptTemplate( template=prompt_template, input_variables=["context", "question"] ) # create the chain to answer questions qa_chain_instrucEmbed = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True, chain_type_kwargs={"prompt": PROMPT}) # def wrap_text_preserve_newlines(text, width=110): # # Split the input text into lines based on newline characters # lines = text.split('\n') # # Wrap each line individually # wrapped_lines = [textwrap.fill(line, width=width) for line in lines] # # Join the wrapped lines back together using newline characters # wrapped_text = '\n'.join(wrapped_lines) # return wrapped_text llm_response = qa_chain_instrucEmbed(query) # res = wrap_text_preserve_newlines(llm_response['result']) res = llm_response['result'] source = [[item.metadata.get('source')[10:-4], item.metadata.get('page')+1] for item in llm_response['source_documents']] print(res) index_helpful_answer = res.find("Answer:") if index_helpful_answer != -1: helpful_answer_text = res[index_helpful_answer + len("Answer:"):] # helpful_answer_text.strip().replace("\n"," ") return({"result": helpful_answer_text, "source": source if "I don't know" not in helpful_answer_text else []}) else: return("Error") else: return render_template('index.html') if __name__ == '__main__': server = pywsgi.WSGIServer(('0.0.0.0', 7860), app) server.serve_forever() # app.run(host="0.0.0.0", port=7860)