## Setup
# Import the necessary libraries
import os
import uuid
import json
import gradio as gr
from huggingface_hub import CommitScheduler
from pathlib import Path
from openai import OpenAI
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
# Create Client
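# Note: Anyscale Endpoints exposes an OpenAI-compatible API, so the standard
# OpenAI client can be pointed at it via `base_url`; the key is read from the
# `anyscale_gl` environment variable (presumably a Space secret).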
api_key = os.environ.get('anyscale_gl')
client = OpenAI(base_url='https://api.endpoints.anyscale.com/v1', api_key=api_key)
# Define the embedding model and the vector store
embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')
vectorstore_persisted = Chroma(
    collection_name='10k',
    persist_directory='./10k_vdb',
    embedding_function=embedding_model
)
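
# The persisted store in `./10k_vdb` is assumed to have been built offline.
# A minimal sketch of how it could be created (illustrative only, never called
# by this app; assumes the source PDFs live under `dataset/` with the file
# names referenced in `predict` below):
def _build_vectorstore_sketch():
    from langchain_community.document_loaders import PyPDFLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = []
    for name in ['google', 'Meta', 'aws', 'IBM', 'msft']:
        # PyPDFLoader stores 'source' (the file path) and 'page' in each
        # document's metadata, which the similarity-search filter and the
        # page-number tagging in `predict` rely on.
        pages = PyPDFLoader(f'dataset/{name}-10-k-2023.pdf').load()
        chunks.extend(splitter.split_documents(pages))
    return Chroma.from_documents(
        documents=chunks,
        embedding=embedding_model,
        collection_name='10k',
        persist_directory='./10k_vdb'
    )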
# Prepare the logging functionality
log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
log_folder = log_file.parent
scheduler = CommitScheduler(
    repo_id="rag-10k",
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
    every=2
)
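# CommitScheduler pushes the contents of `folder_path` to the dataset repo in
# a background thread; `every=2` means roughly every 2 minutes (the unit is minutes).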
# Define the Q&A system message
qna_system_message = """
You serve as an assistant at a leading financial technology firm.
This innovative company leverages advanced analytics to offer insights for investment management and financial planning.
It manages a comprehensive database of 10-K reports from a broad spectrum of industry participants.
10-K reports are rich in details regarding financial performance, risk factors, market trends, and strategic initiatives.
However, the financial analysts at the firm often face challenges in quickly extracting actionable insights from these reports.
This difficulty creates a bottleneck, affecting the firm's capability to provide timely and precise recommendations to its clients.
Your primary role is to respond to user inquiries about annual reports.
Please note the following guidelines for handling user queries:
- Context Information: Each user input will include necessary context marked by the token `#Context`. You should use only this provided information to formulate your response. The context will be passed via the `context` variable.
- User Questions: Queries from users will start with the token `#Question`, and the specific question will be passed through the `question` variable.
- Your Responses: Begin your answers with the token `#Answer`. Ensure your responses are based solely on the context given. Do not include or refer to any part of the context directly in your answers. If the required information is not present in the context, simply respond with "I don't know".
"""
# Define the user message template
qna_user_message_template = """
#Context: {context}
#Question: {question}
"""
# Define the predict function that runs when 'Submit' is clicked or when an API request is made
def predict(textbox, company):
    # Retrieve the five most relevant chunks from the selected company's 10-K filing
    filter_source = "dataset/" + company + "-10-k-2023.pdf"
    relevant_document_chunks = vectorstore_persisted.similarity_search(
        textbox, k=5, filter={"source": filter_source}
    )
    # Create context_for_query by concatenating the retrieved chunks,
    # each tagged with its page number
    context_for_query = ""
    for doc in relevant_document_chunks:
        page_content = doc.page_content.replace('\n', ' ')
        page_number = doc.metadata['page']
        context_for_query += f'{page_content} \nPage Number: {page_number}\n\n'
    # Create messages
    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(
            context=context_for_query, question=textbox)}
    ]
    base_model = 'mlabonne/NeuralHermes-2.5-Mistral-7B'  # alternative: 'gpt-3.5-turbo-16k'
    # Get a response from the LLM
    try:
        response = client.chat.completions.create(
            model=base_model, messages=prompt, temperature=0)
        prediction = response.choices[0].message.content.strip()
    except Exception as e:
        prediction = f'Sorry, I encountered the following error: \n {e}'
    # Log both the inputs and outputs to a local log file.
    # Hold the commit scheduler's lock while writing so a background commit
    # does not capture a partially written file.
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'question': textbox,
                    'context': context_for_query,
                    'answer': prediction
                }
            ))
            f.write("\n")
    return prediction
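# Example call (hypothetical query): predict("What are the key risk factors?", "google")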
# Set up the Gradio UI
# Add a text box and a radio button to the interface
# The radio button selects the company whose 10-K report the context is retrieved from.
textbox = gr.Textbox(label='Query')
company = gr.Radio(['google', 'Meta', 'aws', 'IBM', 'msft'], label='Company')
# Create the interface
demo = gr.Interface(
    fn=predict, inputs=[textbox, company], outputs='text',
    description='RAG on the 2023 10-K statements of Google, Meta, Amazon, IBM, and Microsoft',
    title='10-K Questioner'
)
demo.queue()
demo.launch()