Spaces:
Running
Running
File size: 5,801 Bytes
91aeb7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import os
from dotenv import load_dotenv
# Load key/value pairs from the local .env file into the process environment.
# This runs before the remaining imports further down the file; presumably some
# of those libraries read USER_AGENT / GROQ_API_KEY from the environment.
load_dotenv(".env")


def _require_env(name: str) -> str:
    """Return the value of environment variable *name*.

    Assigning ``None`` into ``os.environ`` fails with an opaque
    ``TypeError: str expected, not NoneType``; checking here turns a missing
    setting into an error message that names the variable.

    Raises:
        RuntimeError: if *name* is not set in .env or the process environment.
    """
    value = os.getenv(name)
    if value is None:
        raise RuntimeError(
            f"Required environment variable {name!r} is not set "
            "(define it in .env or in the process environment)."
        )
    return value


os.environ["USER_AGENT"] = _require_env("USER_AGENT")
os.environ["GROQ_API_KEY"] = _require_env("GROQ_API_KEY")
os.environ["TOKENIZERS_PARALLELISM"] = "true"
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from pinecone import Pinecone
from pinecone_text.sparse import BM25Encoder
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.retrievers import PineconeHybridSearchRetriever
from langchain_groq import ChatGroq
from flask import Flask, request
from flask_cors import CORS
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from flask_socketio import SocketIO, emit
# Web application, wrapped with CORS handling and a Socket.IO server.
app = Flask(__name__)
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*")

# Session-cookie settings and the signing key (SECRET_KEY comes from the
# environment and is None when the variable is unset).
app.config.update(
    SESSION_COOKIE_SECURE=True,  # Use HTTPS
    SESSION_COOKIE_HTTPONLY=True,
    SESSION_COOKIE_SAMESITE="Lax",
    SECRET_KEY=os.getenv("SECRET_KEY"),
)
index_name = "traveler-demo-website-vectorstore"


def _connect_pinecone_index(name):
    """Create a Pinecone client and connect to the index called *name*.

    The API key is read from the PINECONE_API_KEY environment variable.

    Returns:
        A ``(client, index)`` tuple.
    """
    client = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
    # connect to index
    return client, client.Index(name)


# One retry on failure, as before; a second failure propagates to the caller.
# Only ``Exception`` is retried so that KeyboardInterrupt / SystemExit are no
# longer swallowed by the first attempt.
try:
    pc, pinecone_index = _connect_pinecone_index(index_name)
except Exception:
    pc, pinecone_index = _connect_pinecone_index(index_name)
# Sparse encoder, with its state loaded from a JSON file in the working directory.
bm25 = BM25Encoder().load("bm25_traveler_website.json")

# Dense embedding model; trust_remote_code=True is passed through model_kwargs.
embed_model = HuggingFaceEmbeddings(
    model_kwargs={"trust_remote_code": True},
    model_name="Alibaba-NLP/gte-large-en-v1.5",
)

# Hybrid retriever over the Pinecone index, combining the dense embeddings
# with the sparse encoder.
retriever = PineconeHybridSearchRetriever(
    index=pinecone_index,
    sparse_encoder=bm25,
    embeddings=embed_model,
    alpha=0.5,
    top_k=20,
)

# Chat model used both for question rewriting and for answering.
llm = ChatGroq(
    model="llama-3.1-70b-versatile",
    max_tokens=1024,
    max_retries=2,
    temperature=0.1,
)
### Contextualize question ###
# System instruction asking the model to rewrite the latest user question as a
# standalone question. The trailing backslashes join the source lines, so the
# prompt is a single line of text followed by one newline.
contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is.
"""

# Message layout: system instruction, prior turns, then the new user input.
_contextualize_messages = [
    ("system", contextualize_q_system_prompt),
    MessagesPlaceholder(variable_name="chat_history"),
    ("human", "{input}"),
]
contextualize_q_prompt = ChatPromptTemplate.from_messages(_contextualize_messages)

# Retriever that uses the prompt above together with the chat history.
history_aware_retriever = create_history_aware_retriever(
    llm=llm,
    retriever=retriever,
    prompt=contextualize_q_prompt,
)
# System prompt for the answering step. Lines ending in a backslash are joined
# with the next line; only the "- Include links ..." line and the final
# {context} line end in a real newline. {context} is a template placeholder
# (presumably filled with the retrieved documents by the stuff-documents chain).
qa_system_prompt = """You are a highly skilled information retrieval assistant. Use the following pieces of retrieved context to answer the question. \
Provide links to sources provided in the answer. \
If you don't know the answer, just say that you don't know. \
Do not give extra long answers. \
When responding to queries, your responses should be comprehensive and well-organized. For each response: \
1. Provide Clear Answers \
2. Include Detailed References: \
- Include links to sources and any links or sites where there is a mentioned in the answer.
- Links to Sources: Provide URLs to credible sources where users can verify the information or explore further. \
- Downloadable Materials: Include links to any relevant downloadable resources if applicable. \
- Reference Sites: Mention specific websites or platforms that offer additional information. \
3. Formatting for Readability: \
- Bullet Points or Lists: Where applicable, use bullet points or numbered lists to present information clearly. \
- Emphasize Important Information: Use bold or italics to highlight key details. \
4. Organize Content Logically \
Do not include anything about context in the answer. \
{context}
"""

# Message layout: system instruction, prior turns, then the new user input.
_qa_messages = [
    ("system", qa_system_prompt),
    MessagesPlaceholder(variable_name="chat_history"),
    ("human", "{input}"),
]
qa_prompt = ChatPromptTemplate.from_messages(_qa_messages)

# Answering chain, then the full retrieve-and-answer chain built on top of it.
question_answer_chain = create_stuff_documents_chain(llm=llm, prompt=qa_prompt)
rag_chain = create_retrieval_chain(
    retriever=history_aware_retriever,
    combine_docs_chain=question_answer_chain,
)
### Statefully manage chat history ###
# In-memory mapping of session_id -> chat history for that session.
store = {}


def clean_temporary_data():
    """Remove every stored chat history.

    The dict is emptied in place. Rebinding ``store = {}`` inside the function
    would only create a local variable and leave the module-level mapping
    (and every other reference to it) untouched.
    """
    store.clear()
def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the chat history for *session_id*, creating it on first use.

    Histories are kept in the module-level ``store`` dict, so the same object
    is returned on every later call with the same id.
    """
    try:
        return store[session_id]
    except KeyError:
        history = store[session_id] = ChatMessageHistory()
        return history
# RAG chain wrapped with per-session message history: "input" carries the new
# user message, "chat_history" the prior turns, and "answer" the model output.
conversational_rag_chain = RunnableWithMessageHistory(
    runnable=rag_chain,
    get_session_history=get_session_history,
    history_messages_key="chat_history",
    input_messages_key="input",
    output_messages_key="answer",
)
def _stream_answer(question, session_id):
    """Run the conversational chain once, emitting each streamed chunk to the caller.

    Chunks are sent as 'response' events to the Socket.IO room of the current
    request (``request.sid``).
    """
    for chunk in conversational_rag_chain.stream(
        {"input": question},
        config={"configurable": {"session_id": session_id}},
    ):
        emit('response', chunk, room=request.sid)


# Stream response to client
@socketio.on('message')
def handle_message(data):
    """Handle a Socket.IO 'message' event by streaming the chain's output back.

    Args:
        data: mapping sent by the client; 'question' holds the user input and
            the optional 'session_id' selects which chat history to use
            (falls back to 'abc123').
    """
    question = data.get('question')
    # The client-supplied id is now actually passed to the chain. It used to be
    # read and then ignored in favour of a hard-coded "abc123", so every client
    # shared a single chat history.
    # NOTE(review): clients that omit session_id still share the "abc123"
    # history — confirm whether a per-connection fallback is wanted.
    session_id = data.get('session_id', 'abc123')

    # One retry on failure, as before; a second failure propagates. Only
    # ``Exception`` is retried so KeyboardInterrupt / SystemExit pass through.
    # NOTE(review): the retry restarts the stream from the beginning, so a
    # failure after some chunks were emitted makes the client see them again.
    try:
        _stream_answer(question, session_id)
    except Exception:
        _stream_answer(question, session_id)
def _main():
    """Serve the Flask app through ``socketio.run`` with debug mode enabled."""
    socketio.run(app, debug=True)


# Only start the server when this file is executed as a script.
if __name__ == "__main__":
    _main()
|