Spaces: Runtime error
| from modal import Stub, Image, Secret, asgi_app, method | |
| from urllib.request import urlretrieve | |
| from fastapi import FastAPI | |
| from typing import List, Dict | |
# Container image for the Modal app. Third-party versions are pinned to the
# major releases whose APIs `fastapi_app` is written against:
#   * cohere < 5: the code relies on `co.chat(..., stream=True)` yielding
#     events with `.event_type`, on dict-shaped `search_queries`, and on
#     iterating the `co.rerank` result directly; the 5.x SDK changed these.
#   * pinecone-client < 3: the code relies on `pinecone.init(...)` and
#     `pinecone.Index(index_name=...)`, which were removed in 3.x.
image = Image.debian_slim("3.11").pip_install(
    "cohere<5",
    "gradio==3.50.2",
    "pinecone-client<3",
)
stub = Stub("secsplorer", image=image)
# FastAPI app that the Gradio chat UI is mounted onto inside `fastapi_app`.
web_app = FastAPI()
def fastapi_app():
    """Build the document-grounded chatbot and mount its Gradio UI on ``web_app``.

    Connects to Cohere and to the Pinecone index ``td-sec-embeddings`` using
    credentials from the environment. It then wires an embed -> vector search
    -> rerank -> chat pipeline into a streaming ``gr.ChatInterface`` and
    returns ``web_app`` with that interface mounted at ``/``.

    Environment:
        COHERE_API_KEY: Cohere API key (required).
        PINECONE_API_KEY: Pinecone API key (required).
        PINECONE_ENVIRONMENT: Pinecone environment (optional; defaults to
            ``"us-west1-gcp"``).

    Returns:
        The app returned by ``gradio.routes.mount_gradio_app``.

    Raises:
        KeyError: if ``COHERE_API_KEY`` or ``PINECONE_API_KEY`` is unset.
    """
    # NOTE(review): no Modal decorators are visible on this function. To be
    # served by Modal it would normally carry `@stub.function(secrets=[...])`
    # and `@asgi_app()` (`Secret` and `asgi_app` are imported but otherwise
    # unused in this file) -- confirm against the deployed source.
    import os
    import uuid

    import cohere
    import gradio as gr
    import pinecone
    from gradio.routes import mount_gradio_app

    print("Connecting to cohere client")
    co = cohere.Client(os.environ["COHERE_API_KEY"])
    print("Done")
    pinecone.init(
        api_key=os.environ["PINECONE_API_KEY"],
        environment=os.environ.get("PINECONE_ENVIRONMENT", "us-west1-gcp"),
    )
    index = pinecone.Index(index_name="td-sec-embeddings")

    def retrieve(
        index: pinecone.Index,
        query: str,
        co: cohere.Client,
        top_k: int = 10,
        top_n: int = 3,
    ) -> List[Dict[str, str]]:
        """Retrieve the most relevant documents for ``query``.

        Embeds the query, fetches the ``top_k`` nearest vectors from Pinecone
        and reranks their metadata with Cohere, keeping the best ``top_n``.

        Parameters:
            index: Pinecone index to search.
            query: The query to retrieve documents for.
            co: Cohere client used for embedding and reranking.
            top_k: Number of Pinecone matches to fetch before reranking.
            top_n: Number of documents to keep after reranking.

        Returns:
            The metadata dicts of the reranked matches, best first. The list
            is empty when Pinecone returns no matches.
        """
        print(f"Calling retrieve for '{query}'")
        print("Embedding the query")
        # `embed` takes a batch of texts and returns one embedding per text;
        # only one text was sent, so take the first (and only) vector.
        query_vector = co.embed(
            texts=[query], model="embed-english-v3.0", input_type="search_query"
        ).embeddings[0]
        print("Querying pinecone")
        # Pass the vector by keyword: `query` expects a single flat vector,
        # not the batch list that `embed` returns.
        res = index.query(vector=query_vector, top_k=top_k, include_metadata=True)
        # NOTE(review): Cohere rerank reads the "text" field of dict documents;
        # this assumes the stored Pinecone metadata has one -- confirm.
        docs_to_rerank = [match["metadata"] for match in res["matches"]]
        if not docs_to_rerank:
            # Nothing matched: skip the rerank call rather than sending it
            # an empty document list.
            print("No matches found")
            return []
        print("Preparing to rerank")
        rerank_results = co.rerank(
            query=query,
            documents=docs_to_rerank,
            top_n=top_n,
            model="rerank-english-v2.0",
        )
        print("Returning retrieved docs")
        # Each hit's `.index` is the position of that document in the list
        # that was sent to `rerank`.
        return [docs_to_rerank[hit.index] for hit in rerank_results]

    class Chatbot:
        """Cohere chat session that can ground its answers in retrieved documents."""

        def __init__(self, co: cohere.Client, index: pinecone.Index):
            self.index = index
            # Sent as `conversation_id` on every streamed chat call so the
            # turns belong to the same Cohere conversation.
            self.conversation_id = str(uuid.uuid4())
            self.co = co

        def generate_response(self, message: str):
            """Stream the model's response to ``message``.

            First asks Cohere which search queries (if any) the message needs.
            If there are any, documents are retrieved for them and passed to
            the streamed chat call; otherwise the model answers directly.

            Parameters:
                message: The user's message.

            Yields:
                The streaming events produced by ``co.chat(..., stream=True)``.
            """
            # Ask the model only for the search queries it would run.
            response = self.co.chat(message=message, search_queries_only=True)
            chat_kwargs = {
                "message": message,
                "conversation_id": self.conversation_id,
                "stream": True,
            }
            if response.search_queries:
                print("Retrieving information")
                chat_kwargs["documents"] = self.retrieve_docs(response)
            yield from self.co.chat(**chat_kwargs)

        def retrieve_docs(self, response) -> List[Dict[str, str]]:
            """Retrieve documents for every search query in ``response``.

            Parameters:
                response: A ``search_queries_only`` chat response. Each entry
                    of its ``search_queries`` is indexed with ``["text"]``.

            Returns:
                The documents for all queries, concatenated in query order.
            """
            queries = [search_query["text"] for search_query in response.search_queries]
            retrieved_docs: List[Dict[str, str]] = []
            for query in queries:
                retrieved_docs.extend(retrieve(self.index, query, self.co))
            return retrieved_docs

    # NOTE(review): one Chatbot -- and therefore one conversation_id -- is
    # shared by every Gradio session, so concurrent users continue the same
    # Cohere conversation. Consider creating one per session.
    chatbot = Chatbot(co, index)

    def chat_function(message, history):
        """Stream the reply to Gradio, then append any citations.

        ``history`` is supplied by ``gr.ChatInterface`` but is not used here;
        the conversation is tracked through ``chatbot.conversation_id``.
        Each yield is the full reply accumulated so far.
        """
        citations_started = False
        reply = ""
        for event in chatbot.generate_response(message):
            if event.event_type == "text-generation":
                reply += str(event.text)
                yield reply
            elif event.event_type == "citation-generation":
                # Emit the citations header once, before the first citation batch.
                if not citations_started:
                    reply += "\n\nCITATIONS:\n\n"
                    yield reply
                    citations_started = True
                reply += str(event.citations) + "\n"
                yield reply

    # `.queue()` enables the queue that streaming (generator) handlers use.
    interface = gr.ChatInterface(chat_function).queue()
    print("All ready!")
    return mount_gradio_app(app=web_app, blocks=interface, path="/")