Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
## Setup
|
4 |
+
|
5 |
+
# Import the necessary Libraries
|
6 |
+
import os
|
7 |
+
import uuid
|
8 |
+
import joblib
|
9 |
+
import json
|
10 |
+
import gradio as gr
|
11 |
+
import pandas as pd
|
12 |
+
from huggingface_hub import CommitScheduler
|
13 |
+
from pathlib import Path
|
14 |
+
from google.colab import userdata, drive
|
15 |
+
from openai import OpenAI
|
16 |
+
from langchain_community.vectorstores import Chroma
|
17 |
+
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
|
18 |
+
|
19 |
+
# Create Client
|
20 |
+
client = OpenAI(base_url='https://api.endpoints.anyscale.com/v1', api_key=userdata.get('anyscale_gl'))
|
21 |
+
|
22 |
+
# Define the embedding model and the vectorstore
|
23 |
+
embedding_model = 'thenlper/gte-large'
|
24 |
+
base_model = 'mlabonne/NeuralHermes-2.5-Mistral-7B'
|
25 |
+
|
26 |
+
# Load the persisted vectorDB
|
27 |
+
vectorstore_persisted = Chroma(
|
28 |
+
collection_name = '10k', persist_directory = '/content/drive/MyDrive/10k_vdb',
|
29 |
+
embedding_function = SentenceTransformerEmbeddings(model_name=embedding_model)
|
30 |
+
)
|
31 |
+
|
32 |
+
# Prepare the logging functionality
|
33 |
+
log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
|
34 |
+
log_folder = log_file.parent
|
35 |
+
scheduler = CommitScheduler(
|
36 |
+
repo_id="rag-10k",
|
37 |
+
repo_type="dataset",
|
38 |
+
folder_path=log_folder,
|
39 |
+
path_in_repo="data",
|
40 |
+
every=2
|
41 |
+
)
|
42 |
+
|
43 |
+
# Define the Q&A system message
|
44 |
+
qna_system_message = """
|
45 |
+
You serve as an assistant at a leading financial technology firm.
|
46 |
+
This innovative company leverages advanced analytics to offer insights for investment management and financial planning.
|
47 |
+
It manages a comprehensive database of 10-K reports from a broad spectrum of industry participants.
|
48 |
+
10-K reports are rich in details regarding financial performance, risk factors, market trends, and strategic initiatives.
|
49 |
+
However, the financial analysts at the firm often face challenges in quickly extracting actionable insights from these reports.
|
50 |
+
This difficulty creates a bottleneck, affecting the firm's capability to provide timely and precise recommendations to its clients.
|
51 |
+
Your primary role is to respond to user inquiries about annual reports.
|
52 |
+
|
53 |
+
Please note the following guidelines for handling user queries:
|
54 |
+
- Context Information: Each user input will include necessary context marked by the token `#Context`. You should use only this provided information to formulate your response. The context will be passed via the `context` variable.
|
55 |
+
- User Questions: Queries from users will start with the token `#Question`, and the specific question will be passed through the `question` variable.
|
56 |
+
- Your Responses: Begin your answers with the token `#Answer`. Ensure your responses are based solely on the context given. Do not include or refer to any part of the context directly in your answers. If the required information is not present in the context, simply respond with "I don't know".
|
57 |
+
"""
|
58 |
+
|
59 |
+
# Define the user message template
|
60 |
+
qna_user_message_template = """
|
61 |
+
#Context: {context}
|
62 |
+
|
63 |
+
#Question: {question}
|
64 |
+
"""
|
65 |
+
|
66 |
+
# Define the predict function that runs when 'Submit' is clicked or when a API request is made
|
67 |
+
def predict(textbox,company):
|
68 |
+
|
69 |
+
#filter = # "dataset/"+company+"-10-k-2023.pdf"
|
70 |
+
relevant_document_chunks = vectorstore_persisted.similarity_search(textbox, k=5, filter={"source":company})
|
71 |
+
|
72 |
+
# Create context_for_query
|
73 |
+
context_for_query = ""
|
74 |
+
for doc in relevant_document_chunks:
|
75 |
+
page_content = doc.page_content.replace('\n', ' ')
|
76 |
+
page_number = doc.metadata['page']
|
77 |
+
contexts = f'{page_content} \nPage Number: {page_number}\n\n'
|
78 |
+
context_for_query = ''.join(contexts)
|
79 |
+
|
80 |
+
# Create messages
|
81 |
+
prompt = [
|
82 |
+
{'role':'system', 'content': qna_system_message},
|
83 |
+
{'role': 'user', 'content': qna_user_message_template.format(
|
84 |
+
context=context_for_query, question=textbox)}
|
85 |
+
]
|
86 |
+
|
87 |
+
# Get a response from the LLM
|
88 |
+
try:
|
89 |
+
response = client.chat.completions.create(
|
90 |
+
model=base_model, messages=prompt, temperature=0)
|
91 |
+
prediction = response.choices[0].message.content.strip()
|
92 |
+
except Exception as e:
|
93 |
+
prediction = f'Sorry, I encountered the following error: \n {e}'
|
94 |
+
|
95 |
+
# While the prediction is made, log both the inputs and outputs to a local log file
|
96 |
+
# While writing to the log file, ensure that the commit scheduler is locked to avoid parallel
|
97 |
+
# access
|
98 |
+
with scheduler.lock:
|
99 |
+
with log_file.open("a") as f:
|
100 |
+
f.write(json.dumps(
|
101 |
+
{
|
102 |
+
'question': textbox,
|
103 |
+
'context': context_for_query,
|
104 |
+
'answer': prediction
|
105 |
+
}
|
106 |
+
))
|
107 |
+
f.write("\n")
|
108 |
+
return prediction
|
109 |
+
|
110 |
+
# Set-up the Gradio UI
|
111 |
+
# Add text box and radio button to the interface
|
112 |
+
# The radio button is used to select the company 10k report in which the context needs to be retrieved.
|
113 |
+
textbox = gr.Textbox(label='Query')
|
114 |
+
company = gr.Radio([('google', '/content/dataset/google-10-k-2023.pdf'),
|
115 |
+
('Meta', '/content/dataset/Meta-10-k-2023.pdf'),
|
116 |
+
('aws', '/content/dataset/aws-10-k-2023.pdf'),
|
117 |
+
('IBM', '/content/dataset/IBM-10-k-2023.pdf'),
|
118 |
+
('msft', '/content/dataset/msft-10-k-2023.pdf')]
|
119 |
+
label='Company')
|
120 |
+
|
121 |
+
# Create the interface
|
122 |
+
demo = gr.Interface(fn = predict, inputs = [textbox,company], outputs = prediction,
|
123 |
+
description = 'RAG on 2023 10-k statement by Google, Meta, Amazon, IBM, and Microsoft'
|
124 |
+
title = '10-k Questioner')
|
125 |
+
|
126 |
+
demo.queue()
|
127 |
+
demo.launch()
|