import os

import pinecone
import chainlit as cl
from chainlit import on_chat_start, on_message
from langchain.callbacks import ContextCallbackHandler
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.cohere import CohereEmbeddings
from langchain.memory import ConversationTokenBufferMemory
from langchain.prompts import (
    ChatPromptTemplate,
    PromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CohereRerank
from langchain.vectorstores import Pinecone
from promptwatch import PromptWatch

from prompts import load_query_gen_prompt, load_spark_prompt
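
# The app reads all credentials from the environment; these variables are
# expected to be set before launch:
#   PINECONE_API_KEY, COHERE_API_KEY, OPENAI_API_KEY,
#   CONTEXT_TOKEN, PROMPTWATCH_KEY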

index_name = "spark"

# System prompt for the chat model, plus the prompt used to condense the
# chat history and follow-up into a standalone search query.
spark = load_spark_prompt()
query_gen_prompt = load_query_gen_prompt()
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(query_gen_prompt)

pinecone.init(
    api_key=os.environ.get("PINECONE_API_KEY"),
    environment="us-west1-gcp",
)

|
@on_chat_start
def init():
    token = os.environ["CONTEXT_TOKEN"]
    context_callback = ContextCallbackHandler(token)
    llm = ChatOpenAI(
        temperature=0.7,
        verbose=True,
        openai_api_key=os.environ.get("OPENAI_API_KEY"),
        streaming=True,
        callbacks=[context_callback],
    )
    # Keep roughly the last 1,000 tokens of conversation as chat history.
    memory = ConversationTokenBufferMemory(
        llm=llm,
        memory_key="chat_history",
        return_messages=True,
        input_key="question",
        max_token_limit=1000,
    )
    embeddings = CohereEmbeddings(
        model="embed-english-light-v2.0",
        cohere_api_key=os.environ.get("COHERE_API_KEY"),
    )

    # Connect to the pre-populated Pinecone index and fetch the four most
    # similar chunks for each query.
    docsearch = Pinecone.from_existing_index(index_name=index_name, embedding=embeddings)
    retriever = docsearch.as_retriever(search_kwargs={"k": 4})
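
    # ContextualCompressionRetriever and CohereRerank are imported above but
    # never wired in. A minimal sketch of how they could rerank the retrieved
    # chunks (parameters here are illustrative, not part of the original app):
    #
    #   compressor = CohereRerank(cohere_api_key=os.environ.get("COHERE_API_KEY"))
    #   retriever = ContextualCompressionRetriever(
    #       base_compressor=compressor, base_retriever=retriever
    #   )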
|
    # Assemble the chat prompt: the spark system prompt followed by the
    # user's question.
    messages = [SystemMessagePromptTemplate.from_template(spark)]
    messages.append(HumanMessagePromptTemplate.from_template("{question}"))
    prompt = ChatPromptTemplate.from_messages(messages)

    question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT, verbose=True)
    doc_chain = load_qa_with_sources_chain(llm, chain_type="stuff", verbose=True, prompt=prompt)
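
    # Note: the "stuff" variant of load_qa_with_sources_chain feeds the
    # retrieved documents into a {summaries} prompt variable, so the spark
    # system template is assumed to expose {summaries} alongside {question}.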
|
    # With rephrase_question=False, the condensed standalone question is used
    # only for retrieval; the answer chain still sees the user's original
    # wording.
    chain = ConversationalRetrievalChain(
        retriever=retriever,
        question_generator=question_generator,
        combine_docs_chain=doc_chain,
        verbose=True,
        memory=memory,
        rephrase_question=False,
        callbacks=[context_callback],
    )
    cl.user_session.set("conversation_chain", chain)
|

@on_message
async def main(message: str):
    # Trace the whole run in PromptWatch.
    with PromptWatch(api_key=os.environ.get("PROMPTWATCH_KEY")):
        token = os.environ["CONTEXT_TOKEN"]
        context_callback = ContextCallbackHandler(token)
        chain = cl.user_session.get("conversation_chain")
        res = await chain.arun(
            {"question": message},
            callbacks=[cl.AsyncLangchainCallbackHandler(), context_callback],
        )

        await cl.Message(content=res).send()
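
# To launch locally, a typical invocation (assuming this file is app.py) is:
#   chainlit run app.py -w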