File size: 2,924 Bytes
2631838 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import os
from langchain.embeddings.cohere import CohereEmbeddings
from langchain.vectorstores import Pinecone
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chat_models import ChatOpenAI
import pinecone
import chainlit as cl
pinecone.init(
api_key=os.environ.get("PINECONE_API_KEY"),
environment=os.environ.get("PINECONE_ENV"),
)
index_name = "spark"
# Optional
namespace = None
embeddings = CohereEmbeddings(model='embed-english-light-v2.0',cohere_api_key=os.environ.get("COHERE_API_KEY"))
welcome_message = "Welcome to the Chainlit Pinecone demo! Ask anything about documents you vectorized and stored in your Pinecone DB."
@cl.langchain_factory(use_async=True)
async def langchain_factory():
await cl.Message(content=welcome_message).send()
docsearch = Pinecone.from_existing_index(
index_name=index_name, embedding=embeddings, namespace=namespace
)
chain = RetrievalQAWithSourcesChain.from_chain_type(
ChatOpenAI(temperature=0, streaming=True, verbose=True),
chain_type="stuff",
retriever=docsearch.as_retriever(max_tokens_limit=4097),
return_source_documents=True,
verbose=True
)
return chain
@cl.langchain_postprocess
async def process_response(res):
answer = res["answer"]
sources = res.get("sources", "").strip() # Use the get method with a default value
source_elements = []
docs = res.get("source_documents", None)
print('sources', sources)
if docs:
metadatas = [doc.metadata for doc in docs]
# Get the source names from the metadata
all_sources = [m["source"] for m in metadatas]
if sources:
found_sources = []
# For each source mentioned by the LLM
for source_index, source in enumerate(sources.split(",")):
# Remove the period and any whitespace
orig_source_name = source.strip().replace(".", "")
# The name that will be displayed in the UI
clean_source_name = f"source {source_index}"
try:
# Find the mentioned source in the list of all sources
found_index = all_sources.index(orig_source_name)
except ValueError:
continue
# Get the text from the source document
text = docs[found_index].page_content
found_sources.append(clean_source_name)
source_elements.append(cl.Text(content=text, name=clean_source_name))
if found_sources:
# Add the sources to the answer, referencing the text elements
answer += f"\nSources: {', '.join(found_sources)}"
else:
answer += "\nNo sources found"
# Send the answer and the text elements to the UI
await cl.Message(content=answer, elements=source_elements).send() |