# secsplorer / simple_script.py
# Uploaded by lagerbaer via huggingface_hub ("Upload folder using huggingface_hub"), commit bcf7c58 (verified).
import cohere
import os
import pinecone
import uuid
from typing import List, Dict
from dotenv import load_dotenv
# Pull variables from a local .env file (if one exists) into the process environment
# before the API keys below are read.
load_dotenv()

# Module-level clients shared by retrieve() and by every Chatbot session created in
# the UI. A missing key fails fast with KeyError at import time.
co = cohere.Client(api_key=os.environ["COHERE_API_KEY"])
pc = pinecone.Pinecone(api_key=os.environ["PINECONE_API_KEY"])
index = pc.Index(name="td-sec-embeddings")
def retrieve(
    index: pinecone.Index,
    query: str,
    top_k: int = 100,
    top_n: int = 5,
) -> List[Dict[str, str]]:
    """
    Retrieve documents for a query: dense vector search in Pinecone, then Cohere rerank.

    Parameters:
    index (pinecone.Index): The Pinecone index to search.
    query (str): The query to retrieve documents for.
    top_k (int): Number of nearest-neighbour matches fetched from Pinecone before
        reranking (default 100).
    top_n (int): Maximum number of reranked documents to return (default 5).

    Returns:
    List[Dict[str, str]]: The metadata dicts of the best-ranked matches, most relevant
    first; an empty list when the vector search finds nothing. The UI reads a 'url'
    key from each; the index schema presumably also provides 'title' and 'snippet'
    (confirm against the index).
    """
    # embed() takes a batch of texts and returns one embedding per text. We send a
    # single query, so take the first (and only) vector: Pinecone's query() expects
    # one flat list of floats, not a list of vectors.
    query_emb = co.embed(
        texts=[query], model="embed-english-v3.0", input_type="search_query"
    ).embeddings[0]
    res = index.query(vector=query_emb, top_k=top_k, include_metadata=True)
    docs_to_rerank = [match["metadata"] for match in res["matches"]]

    # Nothing to rerank (e.g. empty index): skip the rerank call, which presumably
    # rejects an empty document list.
    if not docs_to_rerank:
        return []

    rerank_results = co.rerank(
        query=query,
        documents=docs_to_rerank,
        top_n=top_n,
        model="rerank-english-v2.0",
    )
    # Each rerank hit carries the position of its document in docs_to_rerank.
    return [docs_to_rerank[hit.index] for hit in rerank_results.results]
class Chatbot:
    """
    Retrieval-augmented chatbot over the TD Bank annual reports.

    Every Cohere chat call made by an instance is tagged with the same random
    conversation_id, so all turns of one instance belong to one conversation.

    Attributes:
    index: The Pinecone index searched for supporting documents.
    conversation_id: Random UUID string identifying this conversation to Cohere.
    co: The Cohere client used for chat calls.
    docs: Mapping "doc_0", "doc_1", ... -> document dict for the documents sent with
        the most recent grounded reply; None until the first retrieval happens.
    """

    def __init__(self, co: cohere.Client, index: pinecone.Index):
        self.index = index
        self.conversation_id = str(uuid.uuid4())
        self.co = co
        self.docs = None

    def send_initial_instructions(self):
        """
        Prime the conversation with the assistant's role and ask for a welcome message.

        Returns:
        The event stream from ``co.chat_stream``; iterate it to receive the welcome text.
        """
        response = self.co.chat_stream(
            message="""You are an expert in TD Bank's annual reports and have access to the 2023 and 2022 annual report. Respond with a polite welcome message.""",
            conversation_id=self.conversation_id,
        )
        return response

    def generate_response(self, message: str):
        """
        Stream a reply to the user's message, grounding it in retrieved documents if needed.

        First asks Cohere (``search_queries_only=True``) whether the message needs a
        search. If search queries come back, documents are retrieved for them, stored in
        ``self.docs`` and passed to the streaming chat call. Otherwise the model replies
        without documents and ``self.docs`` is left unchanged.

        Parameters:
        message (str): The user's message.

        Yields:
        The events of ``co.chat_stream`` (e.g. "text-generation", "citation-generation").
        """
        # Ask only for search queries (if any); no reply text is generated here.
        response = self.co.chat(
            message=message,
            search_queries_only=True,
            conversation_id=self.conversation_id,
        )

        if not response.search_queries:
            # No search needed: answer directly, without documents.
            yield from self.co.chat_stream(
                message=message,
                conversation_id=self.conversation_id,
            )
            return

        print("Retrieving information...")
        documents = self.retrieve_docs(response)
        # NOTE(review): assumes Cohere labels documents that have no explicit "id" field
        # as doc_0, doc_1, ... in list order, so citation document_ids map back to these
        # dicts. Confirm against the Cohere chat API.
        self.docs = {f"doc_{i}": document for i, document in enumerate(documents)}
        yield from self.co.chat_stream(
            message=message,
            documents=documents,
            conversation_id=self.conversation_id,
        )

    def retrieve_docs(self, response) -> List[Dict[str, str]]:
        """
        Retrieve documents for every search query in a search-queries-only chat response.

        Parameters:
        response: The Cohere chat response whose ``search_queries`` each carry a ``text``.

        Returns:
        List[Dict[str, str]]: The documents for all queries, concatenated in query order.
        The same document may appear more than once if several queries retrieve it.
        """
        retrieved_docs = []
        for search_query in response.search_queries:
            retrieved_docs.extend(retrieve(self.index, search_query.text))
        return retrieved_docs
import gradio as gr
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")
    # Per-session Chatbot instance; None until chat_function first runs (on page load).
    cohere_chatbot_var = gr.State()

    def user(user_message, history):
        """Append the user's message (with no reply yet) to the history and clear the textbox."""
        return "", history + [[user_message, None]]

    def chat_function(history, cohere_chatbot):
        """
        Stream the assistant's output into the chat history.

        On the first call for a session (state is None) this creates the Chatbot and
        streams its welcome message into a fresh history. Otherwise it answers the last
        user message in ``history`` and, when the reply cites documents, appends a
        "DOCUMENTS CONSULTED" section listing each cited document's URL once.

        Yields:
        (history, cohere_chatbot) after each update, for the chatbot and state outputs.
        """
        if cohere_chatbot is None:
            cohere_chatbot = Chatbot(co, index)
            response = cohere_chatbot.send_initial_instructions()
            history = [[None, ""]]
            for event in response:
                if event.event_type == "text-generation":
                    history[0][1] += str(event.text)
                    yield history, cohere_chatbot
            return

        message = history[-1][0]
        history[-1][1] = ""
        urls_listed = set()  # URLs already written, so later citation events don't repeat them
        header_written = False
        for event in cohere_chatbot.generate_response(message):
            if event.event_type == "text-generation":
                history[-1][1] += str(event.text)
                yield history, cohere_chatbot
            elif event.event_type == "citation-generation":
                if not header_written:
                    history[-1][1] += "\n\n**DOCUMENTS CONSULTED:**\n\n"
                    yield history, cohere_chatbot
                    header_written = True
                # Ignore ids we have no document for rather than crashing mid-stream.
                docs = cohere_chatbot.docs or {}
                cited_urls = {
                    docs[doc_id]["url"]
                    for citation in event.citations
                    for doc_id in citation.document_ids
                    if doc_id in docs
                }
                for url in sorted(cited_urls - urls_listed):
                    history[-1][1] += f"* {url}\n"
                    yield history, cohere_chatbot
                urls_listed |= cited_urls

    # Make sure we run the thing once to initialize: creates the session's Chatbot
    # and streams its welcome message on page load.
    demo.load(
        chat_function, [chatbot, cohere_chatbot_var], [chatbot, cohere_chatbot_var]
    )
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        chat_function, [chatbot, cohere_chatbot_var], [chatbot, cohere_chatbot_var]
    )
    # Clears only the displayed history; the session's Chatbot (and its Cohere
    # conversation_id) is kept in state.
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue()
demo.launch()