# SPARK: app/spark.py
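"""Chainlit app for SPARK: conversational retrieval over a Pinecone index.

Cohere embeddings retrieve context, ChatOpenAI generates answers, and runs
are traced with Weights & Biases, Context, and PromptWatch.
"""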
import os

import pinecone
import chainlit as cl
from chainlit import on_message, on_chat_start
from langchain.embeddings.cohere import CohereEmbeddings
from langchain.vectorstores import Pinecone
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationTokenBufferMemory
from langchain.prompts import (
    ChatPromptTemplate,
    PromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.callbacks import ContextCallbackHandler

# Used only by the optional reranker block, currently commented out below.
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CohereRerank
from promptwatch import PromptWatch

from prompts import load_query_gen_prompt, load_spark_prompt
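
# Spark's system persona and the query-condensing prompt (rewrites a
# follow-up question plus chat history into a standalone search query).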
index_name = "spark"
spark = load_spark_prompt()
query_gen_prompt = load_query_gen_prompt()
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(query_gen_prompt)

pinecone.init(
    api_key=os.environ.get("PINECONE_API_KEY"),
    environment="us-west1-gcp",
)
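
# Required environment variables: PINECONE_API_KEY, OPENAI_API_KEY,
# COHERE_API_KEY, CONTEXT_TOKEN, and PROMPTWATCH_KEY.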

@on_chat_start
def init():
    # Context callback for analytics; wandb tracing for all LangChain runs.
    token = os.environ["CONTEXT_TOKEN"]
    context_callback = ContextCallbackHandler(token)
    os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
    os.environ["WANDB_PROJECT"] = "spark"
    llm = ChatOpenAI(
        temperature=0.7, verbose=True, streaming=True,
        openai_api_key=os.environ.get("OPENAI_API_KEY"),
        callbacks=[context_callback],
    )
    # Token-window memory: keeps roughly the last 1,000 tokens of chat history.
    memory = ConversationTokenBufferMemory(
        llm=llm, memory_key="chat_history", return_messages=True,
        input_key="question", max_token_limit=1000,
    )
    # Cohere embeddings over the existing Pinecone index; retrieve the top 4 chunks.
    embeddings = CohereEmbeddings(
        model="embed-english-light-v2.0",
        cohere_api_key=os.environ.get("COHERE_API_KEY"),
    )
    docsearch = Pinecone.from_existing_index(index_name=index_name, embedding=embeddings)
    retriever = docsearch.as_retriever(search_kwargs={"k": 4})
    # compressor = CohereRerank()
    # reranker = ContextualCompressionRetriever(
    #     base_compressor=compressor, base_retriever=retriever
    # )
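    # The block above would rerank the retrieved chunks with Cohere; to enable
    # it, build `reranker` and pass it to the chain in place of `retriever`.
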
    # Chat prompt: Spark's system persona followed by the user's question.
    messages = [
        SystemMessagePromptTemplate.from_template(spark),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
    prompt = ChatPromptTemplate.from_messages(messages)
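
    # Pipeline: condense chat history + follow-up into a standalone query,
    # retrieve matching chunks, then "stuff" them into Spark's prompt to answer.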
    question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT, verbose=True)
    doc_chain = load_qa_with_sources_chain(llm, chain_type="stuff", verbose=True, prompt=prompt)
    chain = ConversationalRetrievalChain(
        retriever=retriever,
        question_generator=question_generator,
        combine_docs_chain=doc_chain,
        verbose=True,
        memory=memory,
        rephrase_question=False,  # answer the original question; use the rewrite only for retrieval
        callbacks=[context_callback],
    )
    cl.user_session.set("conversation_chain", chain)

@on_message
async def main(message: str):
    # Wrap each turn in a PromptWatch session for end-to-end tracing.
    with PromptWatch(api_key=os.environ.get("PROMPTWATCH_KEY")):
        token = os.environ["CONTEXT_TOKEN"]
        context_callback = ContextCallbackHandler(token)
        chain = cl.user_session.get("conversation_chain")
        res = await chain.arun(
            {"question": message},
            callbacks=[cl.AsyncLangchainCallbackHandler(), context_callback],
        )
        # Send the answer back to the Chainlit UI.
        await cl.Message(content=res).send()
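
# Run locally with the Chainlit CLI (file location assumed from the repo path):
#   chainlit run app/spark.py -w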