import gradio as gr
from sentence_transformers import SentenceTransformer
import chromadb
import pandas as pd
import os
import ast
import json
from pathlib import Path
from llama_index.llms.anyscale import Anyscale
# Load the sentence transformer model for embedding text
model = SentenceTransformer('all-MiniLM-L6-v2')

# Initialize the ChromaDB client for managing the vector database
chroma_client = chromadb.Client()
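# Note: chromadb.Client() keeps the index in memory only, so it is rebuilt on
# every restart. If the index should persist, a persistent client could be used
# instead (a sketch, assuming write access to a local ./chroma_db directory):
# chroma_client = chromadb.PersistentClient(path="./chroma_db")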
# Function to build the vector database from a CSV file
def build_database():
    # Read the CSV file containing document data
    df = pd.read_csv('vector_store.csv')
    print(df.head())

    # Name of the collection to store the data
    collection_name = 'Dataset-10k-companies'

    # Uncomment the line below to delete the existing collection if needed
    # chroma_client.delete_collection(name=collection_name)

    # Create a new collection in ChromaDB
    collection = chroma_client.create_collection(name=collection_name)
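    # Note: create_collection raises an error if the collection already exists.
    # If this function may run more than once per process, chromadb's
    # get_or_create_collection is the safer variant:
    # collection = chroma_client.get_or_create_collection(name=collection_name)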
    # Add data from the DataFrame to the collection. The metadata and embedding
    # columns are stored as string literals in the CSV, so parse them with
    # ast.literal_eval rather than eval (same result for literals, without
    # executing arbitrary code)
    collection.add(
        documents=df['documents'].tolist(),
        ids=df['ids'].tolist(),
        metadatas=df['metadatas'].apply(ast.literal_eval).tolist(),
        embeddings=df['embeddings'].apply(lambda x: ast.literal_eval(x.replace(',,', ','))).tolist()
    )
    return collection
# Build the database when the app starts
collection = build_database()

# Access the Anyscale API key from environment variables
anyscale_api_key = os.environ.get('anyscale_api_key')

# Instantiate the Anyscale client for using the Llama language model
client = Anyscale(api_key=anyscale_api_key, model="meta-llama/Llama-2-70b-chat-hf")
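# Quick sanity check (commented out): llama_index LLMs return a
# CompletionResponse whose .text attribute holds the generated string,
# which is what get_llm_response below relies on.
# print(client.complete("Say hello").text)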
# Function to get relevant chunks from the database based on the query
def get_relevant_chunks(query, collection, top_n=3):
    # Encode the query into an embedding
    query_embedding = model.encode(query).tolist()

    # Query the collection to get the top_n most relevant results
    results = collection.query(query_embeddings=[query_embedding], n_results=top_n)

    relevant_chunks = []
    # Extract each chunk together with its source link and page number
    for i in range(len(results['documents'][0])):
        chunk = results['documents'][0][i]
        source = results['metadatas'][0][i]['source']
        page = results['metadatas'][0][i]['page']
        relevant_chunks.append((chunk, source, page))

    return relevant_chunks
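# Example (hypothetical query and return value, assuming each stored metadata
# dict carries 'source' and 'page' keys as above):
# get_relevant_chunks("What was Microsoft's revenue?", collection)
# -> [("...chunk text...", "https://example.com/msft-10k.pdf", 12), ...]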
# System message template for the LLM to provide structured responses
qna_system_message = """
You are an assistant to Finsights analysts. Your task is to provide relevant information about the financial performance of the companies followed by Finsights.

User input will include the necessary context for you to answer their questions. This context will begin with the token: ###Context.
The context contains references to specific portions of documents relevant to the user's query, along with source links.
The source for a context will begin with the token: ###Source.

When crafting your response:
1. Select only the context relevant to answering the question.
2. Include the source links in your response.
3. User questions will begin with the token: ###Question.
4. If the question is irrelevant to Finsights, respond with: "I am an assistant for Finsight Docs. I can only help you with questions related to Finsights."

Adhere to the following guidelines:
- Your response should only address the question asked and nothing else.
- Answer only using the context provided.
- Do not mention anything about the context in your final answer.
- If the answer is not found in the context, respond with: "I don't know."
- Always quote the source when you use the context. Cite the relevant source at the end of your response under the section "Source:".
- Do not make up sources. Use only the links provided in the sources section of the context. You are prohibited from providing other links/sources.

Here is an example of how to structure your response:

Answer:
[Answer]

Source:
[Source]
"""
# User message template for passing context and question to the LLM
qna_user_message_template = """
###Context
Here are some documents and their source links that are relevant to the question mentioned below.
{context}
###Question
{question}
"""
# Function to get a response from the LLM, retrying or continuing for up to max_attempts
def get_llm_response(prompt, max_attempts=3):
    full_response = ""
    for attempt in range(max_attempts):
        try:
            # Generate a response from the LLM with a generous token budget
            response = client.complete(prompt, max_tokens=1000)
            chunk = response.text.strip()
            full_response += chunk
            if chunk.endswith((".", "!", "?")):  # Treat terminal punctuation as a complete answer
                break
            else:
                # Ask the model to continue, using the last 100 characters as context
                prompt = "Please continue from where you left off:\n" + chunk[-100:]
        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error: {e}")
    return full_response
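# A short pause between attempts could reduce rate-limit failures, e.g.
# time.sleep(2 ** attempt) in the except branch above (requires "import time").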
# Prediction function to handle user queries
def predict(company, user_query):
    try:
        # Scope the retrieval query to the selected company
        modified_query = f"{user_query} for {company}"

        # Get relevant chunks from the database
        relevant_chunks = get_relevant_chunks(modified_query, collection)

        # Prepare the context string from the relevant chunks
        context = ""
        for chunk, source, page in relevant_chunks:
            context += chunk + "\n"
            context += f"###Source {source}, Page {page}\n"

        # Prepare the user message with context and question
        user_message = qna_user_message_template.format(context=context, question=user_query)

        # Craft the full prompt for the Llama model
        prompt = f"{qna_system_message}\n\n{user_message}"

        # Generate the response using the Llama model through Anyscale
        answer = get_llm_response(prompt)

        # Log the interaction for future reference
        log_interaction(company, user_query, context, answer)

        return answer
    except Exception as e:
        return f"An error occurred: {str(e)}"
# Function to log interactions in a JSON Lines file (one JSON object per line)
def log_interaction(company, user_query, context, answer):
    log_file = Path("interaction_log.jsonl")
    with log_file.open("a") as f:
        json.dump({
            'company': company,
            'user_query': user_query,
            'context': context,
            'answer': answer
        }, f)
        f.write("\n")
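# Each line of interaction_log.jsonl is a standalone JSON object, e.g.:
# {"company": "MSFT", "user_query": "...", "context": "...", "answer": "..."}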
# Create the Gradio interface for user interaction
company_list = ["MSFT", "AWS", "Meta", "Google", "IBM"]
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Radio(company_list, label="Select Company"),
        gr.Textbox(lines=2, placeholder="Enter your query here...", label="User Query")
    ],
    outputs=gr.Textbox(label="Generated Answer"),
    title="Company Reports Q&A",
    description="Query the vector database and get an LLM response based on the documents in the collection."
)
# Launch the Gradio interface (share=True is typically unnecessary when the
# app is hosted on Hugging Face Spaces)
iface.launch(share=True)