bia-anto committed on
Commit
e3ccde2
1 Parent(s): 62fdf1c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -0
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ ## Setup
4
+
5
+ # Import the necessary Libraries
6
+ import os
7
+ import uuid
8
+ import joblib
9
+ import json
10
+ import gradio as gr
11
+ import pandas as pd
12
+ from huggingface_hub import CommitScheduler
13
+ from pathlib import Path
14
+ from google.colab import userdata, drive
15
+ from openai import OpenAI
16
+ from langchain_community.vectorstores import Chroma
17
+ from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
18
+
19
# --- External services -------------------------------------------------------
# OpenAI-compatible client pointed at the Anyscale endpoint; the API key is
# read from Colab secrets via `userdata`.
client = OpenAI(
    base_url='https://api.endpoints.anyscale.com/v1',
    api_key=userdata.get('anyscale_gl'),
)

# Embedding model used to index the 10-K chunks, and the chat model that
# generates the answers.
embedding_model = 'thenlper/gte-large'
base_model = 'mlabonne/NeuralHermes-2.5-Mistral-7B'

# Re-open the Chroma vector store previously persisted to Google Drive,
# using the same embedding model it was built with.
vectorstore_persisted = Chroma(
    collection_name='10k',
    persist_directory='/content/drive/MyDrive/10k_vdb',
    embedding_function=SentenceTransformerEmbeddings(model_name=embedding_model),
)

# --- Logging -----------------------------------------------------------------
# Each app session appends Q&A records to a uniquely named JSON-lines file;
# CommitScheduler periodically pushes the folder to a HF dataset repo
# (every=2 — minutes, per huggingface_hub's CommitScheduler docs).
log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
log_folder = log_file.parent
scheduler = CommitScheduler(
    repo_id="rag-10k",
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
    every=2,
)
42
+
43
# Define the Q&A system message.
# Persona + protocol prompt: the model acts as a financial-analyst assistant
# over 10-K filings and must follow the #Context / #Question / #Answer token
# convention, answering ONLY from the supplied context (or "I don't know").
qna_system_message = """
You serve as an assistant at a leading financial technology firm.
This innovative company leverages advanced analytics to offer insights for investment management and financial planning.
It manages a comprehensive database of 10-K reports from a broad spectrum of industry participants.
10-K reports are rich in details regarding financial performance, risk factors, market trends, and strategic initiatives.
However, the financial analysts at the firm often face challenges in quickly extracting actionable insights from these reports.
This difficulty creates a bottleneck, affecting the firm's capability to provide timely and precise recommendations to its clients.
Your primary role is to respond to user inquiries about annual reports.

Please note the following guidelines for handling user queries:
- Context Information: Each user input will include necessary context marked by the token `#Context`. You should use only this provided information to formulate your response. The context will be passed via the `context` variable.
- User Questions: Queries from users will start with the token `#Question`, and the specific question will be passed through the `question` variable.
- Your Responses: Begin your answers with the token `#Answer`. Ensure your responses are based solely on the context given. Do not include or refer to any part of the context directly in your answers. If the required information is not present in the context, simply respond with "I don't know".
"""

# Define the user message template.
# Filled per request with str.format(context=..., question=...) — the
# placeholder names must match the variables named in the system message.
qna_user_message_template = """
#Context: {context}

#Question: {question}
"""
65
+
66
# Define the predict function that runs when 'Submit' is clicked or when an API request is made
def predict(textbox, company):
    """Answer a user question using RAG over the selected company's 10-K report.

    Parameters
    ----------
    textbox : str
        The user's question.
    company : str
        Value selected in the Radio component; must match the `source`
        metadata stored with the chunks in the vector store
        (presumably the PDF path, e.g. '/content/dataset/google-10-k-2023.pdf'
        — confirm against how the store was built).

    Returns
    -------
    str
        The model's answer, or an apologetic error message if the LLM call fails.
    """
    # Retrieve the 5 chunks most similar to the query, restricted to the
    # selected company's report via the metadata filter.
    relevant_document_chunks = vectorstore_persisted.similarity_search(
        textbox, k=5, filter={"source": company}
    )

    # Build the context string from ALL retrieved chunks.
    # Bug fix: the original did `context_for_query = ''.join(contexts)` inside
    # the loop, which both joins a single string (a no-op) and overwrites the
    # accumulator each iteration — so only the LAST chunk reached the LLM.
    context_parts = []
    for doc in relevant_document_chunks:
        page_content = doc.page_content.replace('\n', ' ')
        page_number = doc.metadata['page']
        context_parts.append(f'{page_content} \nPage Number: {page_number}\n\n')
    context_for_query = ''.join(context_parts)

    # Compose the chat messages from the system prompt and the user template.
    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(
            context=context_for_query, question=textbox)}
    ]

    # Get a response from the LLM; surface any API error to the user as the
    # answer text instead of crashing the UI.
    try:
        response = client.chat.completions.create(
            model=base_model, messages=prompt, temperature=0)
        prediction = response.choices[0].message.content.strip()
    except Exception as e:
        prediction = f'Sorry, I encountered the following error: \n {e}'

    # While the prediction is made, log both the inputs and outputs to a local
    # log file as one JSON line. Hold the commit-scheduler lock so a background
    # commit does not read a half-written file.
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'question': textbox,
                    'context': context_for_query,
                    'answer': prediction
                }
            ))
            f.write("\n")
    return prediction
109
+
110
# Set-up the Gradio UI.
# A text box for the query and a radio button to select which company's 10-K
# report the context should be retrieved from. Each Radio choice is a
# (display label, value) pair; the value is the PDF path used as the
# vector-store `source` filter inside predict().
textbox = gr.Textbox(label='Query')
company = gr.Radio(
    [('google', '/content/dataset/google-10-k-2023.pdf'),
     ('Meta', '/content/dataset/Meta-10-k-2023.pdf'),
     ('aws', '/content/dataset/aws-10-k-2023.pdf'),
     ('IBM', '/content/dataset/IBM-10-k-2023.pdf'),
     ('msft', '/content/dataset/msft-10-k-2023.pdf')],
    label='Company')

# Create the interface.
# Fixes vs. the original:
#   - `outputs=prediction` referenced a name that only exists as a local
#     variable inside predict() (NameError at import time); the output is a
#     plain text component.
#   - commas were missing between the Radio choices list and `label=`, and
#     between `description=` and `title=` (SyntaxError).
demo = gr.Interface(
    fn=predict,
    inputs=[textbox, company],
    outputs=gr.Textbox(label='Answer'),
    description='RAG on 2023 10-k statement by Google, Meta, Amazon, IBM, and Microsoft',
    title='10-k Questioner')

demo.queue()
demo.launch()