File size: 5,801 Bytes
91aeb7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import os
from dotenv import load_dotenv
# Pull settings from the local .env file into the process environment.
load_dotenv(".env")

# Fail fast with a readable message when a required setting is absent.
# Writing os.getenv(name) back into os.environ changes nothing when the
# variable exists and raises an opaque TypeError when it is None, so the
# presence check is done explicitly instead.
_missing_settings = [
    name for name in ("USER_AGENT", "GROQ_API_KEY") if os.getenv(name) is None
]
if _missing_settings:
    raise RuntimeError(
        "Missing required environment variable(s): " + ", ".join(_missing_settings)
    )

os.environ["TOKENIZERS_PARALLELISM"] = "true"

from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory

from pinecone import Pinecone
from pinecone_text.sparse import BM25Encoder

from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.retrievers import PineconeHybridSearchRetriever

from langchain_groq import ChatGroq

from flask import Flask, request
from flask_cors import CORS
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from flask_socketio import SocketIO, emit

# Flask application with CORS enabled and a Socket.IO server layered on top.
app = Flask(__name__)
CORS(app)
socketio = SocketIO(app, cors_allowed_origins="*")

# Session-cookie settings and the signing key, applied in a single call.
app.config.update(
    SESSION_COOKIE_SECURE=True,  # Use HTTPS
    SESSION_COOKIE_HTTPONLY=True,
    SESSION_COOKIE_SAMESITE='Lax',
    SECRET_KEY=os.getenv('SECRET_KEY'),
)

# Connect to the Pinecone index that backs retrieval.
index_name = "traveler-demo-website-vectorstore"
_PINECONE_CONNECT_ATTEMPTS = 2  # one initial try plus a single retry
for _attempt in range(1, _PINECONE_CONNECT_ATTEMPTS + 1):
    try:
        pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
        # connect to index
        pinecone_index = pc.Index(index_name)
        break
    except Exception:
        # Only ordinary errors are retried; KeyboardInterrupt and SystemExit
        # are not caught here. The last failure propagates to the caller.
        if _attempt == _PINECONE_CONNECT_ATTEMPTS:
            raise

# Sparse (BM25) encoder, loaded from its saved JSON file.
bm25 = BM25Encoder().load("bm25_traveler_website.json")

# Dense embedding model. NOTE(review): trust_remote_code=True is passed
# through to the model loader — confirm the model repository is trusted.
embed_model = HuggingFaceEmbeddings(
    model_name="Alibaba-NLP/gte-large-en-v1.5",
    model_kwargs={"trust_remote_code": True},
)

# Hybrid retriever combining the dense embeddings and the sparse encoder
# over the Pinecone index.
retriever = PineconeHybridSearchRetriever(
    embeddings=embed_model,
    sparse_encoder=bm25,
    index=pinecone_index,
    top_k=20,
    alpha=0.5,
)

# Chat model used for both question rewriting and answering.
llm = ChatGroq(
    model="llama-3.1-70b-versatile",
    temperature=0.1,
    max_tokens=1024,
    max_retries=2,
)

### Contextualize question ###
# System instruction asking the model to rewrite the latest user question so
# it can be understood without the chat history (and not to answer it).
contextualize_q_system_prompt = (
    "Given a chat history and the latest user question "
    "which might reference context in the chat history, formulate a standalone question "
    "which can be understood without the chat history. Do NOT answer the question, "
    "just reformulate it if needed and otherwise return it as is.\n"
)

# Message layout: system instruction, prior turns, then the new question.
_contextualize_messages = [
    ("system", contextualize_q_system_prompt),
    MessagesPlaceholder("chat_history"),
    ("human", "{input}"),
]
contextualize_q_prompt = ChatPromptTemplate.from_messages(_contextualize_messages)

# Retriever that is given the LLM and the rewriting prompt alongside the
# base retriever.
history_aware_retriever = create_history_aware_retriever(
    llm,
    retriever,
    contextualize_q_prompt,
)


### Answer question ###
# System prompt for the answering step. It ends with the `{context}` template
# variable; per the LangChain docs, create_stuff_documents_chain fills it with
# the retrieved documents.
# NOTE(review): the text asks for answers that are both not "extra long" and
# "comprehensive", and one bullet reads "where there is a mentioned" — confirm
# the intended wording before editing, since this string is sent to the model.
qa_system_prompt = """You are a highly skilled information retrieval assistant. Use the following pieces of retrieved context to answer the question. \
Provide links to sources provided in the answer. \
If you don't know the answer, just say that you don't know. \
Do not give extra long answers. \
When responding to queries, your responses should be comprehensive and well-organized. For each response: \

    1. Provide Clear Answers \

    2. Include Detailed References: \
        - Include links to sources and any links or sites where there is a mentioned in the answer.
        - Links to Sources: Provide URLs to credible sources where users can verify the information or explore further. \
        - Downloadable Materials: Include links to any relevant downloadable resources if applicable. \
        - Reference Sites: Mention specific websites or platforms that offer additional information. \

    3. Formatting for Readability: \
        - Bullet Points or Lists: Where applicable, use bullet points or numbered lists to present information clearly. \
        - Emphasize Important Information: Use bold or italics to highlight key details. \

    4. Organize Content Logically \

Do not include anything about context in the answer. \

{context}
"""
# Message layout: system prompt, prior turns (`chat_history`), then the
# user's question (`{input}`).
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}")
    ]
)
# Answering chain built from the LLM and the QA prompt.
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

# Full RAG chain: history-aware retrieval feeding the answering chain.
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

### Statefully manage chat history ###
# In-memory map of session id -> chat history object; lost on process restart.
store = {}


def clean_temporary_data() -> None:
    """Remove every stored chat history from the module-level ``store``.

    The dict is emptied in place rather than rebound: assigning ``store = {}``
    inside the function would only create a local variable and leave the
    shared dict (and every reference already held to it) untouched.
    """
    store.clear()

def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the chat history for ``session_id``, creating it on first use.

    Histories live in the module-level ``store`` dict; a new
    ``ChatMessageHistory`` is registered the first time an id is seen.
    """
    try:
        return store[session_id]
    except KeyError:
        history = ChatMessageHistory()
        store[session_id] = history
        return history


# Wrap the RAG chain so each call uses a per-session message history.
# The history object is obtained through get_session_history; the three key
# arguments name the fields of the chain's input/output that carry the user
# message ("input"), the prior turns ("chat_history") and the reply ("answer").
conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)

# Stream response to client
@socketio.on('message')
def handle_message(data):
    """Stream the conversational RAG chain's output for one client question.

    Args:
        data: Payload of the ``message`` event. Read as a mapping with a
            ``question`` key and an optional ``session_id`` key, which
            defaults to ``'abc123'`` when absent.

    Every chunk yielded by ``conversational_rag_chain.stream`` is emitted as
    a ``response`` event addressed to the requesting socket (``request.sid``).
    A payload without a question is logged and ignored.

    Raises:
        Exception: whatever the chain or ``emit`` raises, if streaming fails
            after output has already been sent or fails again on the retry.
    """
    if not isinstance(data, dict) or not data.get('question'):
        app.logger.warning("Ignoring 'message' event without a question")
        return

    question = data['question']
    # Use the session id supplied by the client so each session keeps its own
    # chat history instead of all clients sharing one hard-coded id.
    session_id = data.get('session_id', 'abc123')
    run_config = {"configurable": {"session_id": session_id}}

    # NOTE(review): chunks are emitted exactly as the chain yields them. If
    # clients only need the answer text, streaming
    # conversational_rag_chain.pick("answer") would narrow the payload —
    # confirm against the client before changing it.
    chunk_sent = False
    for attempt in (1, 2):
        try:
            for chunk in conversational_rag_chain.stream(
                {"input": question},
                config=run_config,
            ):
                emit('response', chunk, room=request.sid)
                chunk_sent = True
            return
        except Exception:
            # Retry once, but only while nothing has reached the client;
            # restarting after partial output would send duplicate chunks.
            if chunk_sent or attempt == 2:
                raise
            app.logger.warning(
                "Streaming failed before any output; retrying once",
                exc_info=True,
            )

if __name__ == '__main__':
    # Debug mode stays on by default to keep the existing behaviour.
    # NOTE(review): Flask's documentation warns against running debug mode in
    # production; set FLASK_DEBUG to 0/false/no/off in deployed environments.
    debug_enabled = os.getenv("FLASK_DEBUG", "1").strip().lower() not in (
        "0", "false", "no", "off",
    )
    socketio.run(app, debug=debug_enabled)