rolwinpinto commited on
Commit
cfff27b
1 Parent(s): d226ff4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -39
app.py CHANGED
@@ -3,29 +3,27 @@ import streamlit as st
3
  import PyPDF2
4
  import matplotlib.pyplot as plt
5
  from io import BytesIO
6
- from llama_index.embeddings import HuggingFaceEmbedding
7
- from llama_index.schema import Document
8
- from sklearn.metrics.pairwise import cosine_similarity
9
- import numpy as np
10
- import dotenv
11
  import re
12
  import requests
 
13
 
14
  # Load environment variables
15
  dotenv.load_dotenv()
16
 
17
- # Configure Hugging Face API
18
  API_URL = "https://api-inference.huggingface.co/models/sarvamai/sarvam-2b-v0.5"
19
  headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACE_API_KEY')}"}
20
 
21
  # Configure embedding model
22
- embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
23
 
24
  def query_huggingface_api(payload):
25
  response = requests.post(API_URL, headers=headers, json=payload)
26
  return response.json()
27
 
28
- def write_to_file(content, filename="./files/test.pdf"):
29
  os.makedirs(os.path.dirname(filename), exist_ok=True)
30
  with open(filename, "wb") as f:
31
  f.write(content)
@@ -35,10 +33,9 @@ def extract_financial_data(document_text):
35
  "Revenue": [],
36
  "Date": []
37
  }
38
-
39
  lines = document_text.split("\n")
40
  revenue_pattern = re.compile(r'\$?\d+(?:,\d{3})*(?:\.\d+)?')
41
-
42
  for i, line in enumerate(lines):
43
  if any(keyword in line.lower() for keyword in ["revenue", "total revenue", "sales"]):
44
  for j in range(i + 1, i + 6):
@@ -51,7 +48,7 @@ def extract_financial_data(document_text):
51
  financial_data["Revenue"].append(value)
52
  except ValueError:
53
  continue
54
-
55
  if "Q1" in line or "Q2" in line or "Q3" in line or "Q4" in line or re.search(r'FY\s*\d{4}', line):
56
  financial_data["Date"].append(line.strip())
57
 
@@ -61,7 +58,17 @@ def extract_financial_data(document_text):
61
 
62
  return financial_data
63
 
64
- def generate_summary(document_text, query):
 
 
 
 
 
 
 
 
 
 
65
  prompt = f"""
66
  You are a financial analyst. Your task is to provide a comprehensive analysis of the financial document.
67
  Analyze the following document and respond to the query:
@@ -99,26 +106,6 @@ def generate_comparison_graph(data):
99
  plt.tight_layout()
100
  st.pyplot(fig)
101
 
102
- def search_similar_sections(document_text, query, top_k=3):
103
- # Split the document into sections (you may need to adjust this based on your document structure)
104
- sections = document_text.split('\n\n')
105
-
106
- # Create Document objects for each section
107
- documents = [Document(text=section) for section in sections]
108
-
109
- # Compute embeddings for the query and all sections
110
- query_embedding = embed_model.get_text_embedding(query)
111
- section_embeddings = [embed_model.get_text_embedding(doc.text) for doc in documents]
112
-
113
- # Compute cosine similarities
114
- similarities = cosine_similarity([query_embedding], section_embeddings)[0]
115
-
116
- # Get indices of top-k similar sections
117
- top_indices = np.argsort(similarities)[-top_k:][::-1]
118
-
119
- # Return top-k similar sections
120
- return [sections[i] for i in top_indices]
121
-
122
  # Streamlit app
123
  def main():
124
  st.title("Fortune 500 Financial Document Analyzer")
@@ -142,20 +129,18 @@ def main():
142
  # Extract financial data
143
  financial_data = extract_financial_data(document_text)
144
 
 
 
 
 
145
  # Add a provision for user query input
146
  query = st.text_input("Enter your financial analysis query (e.g., 'What are the revenue trends?')", "")
147
 
148
  if query:
149
- summary = generate_summary(document_text, query)
150
  st.write("## Financial Analysis Result")
151
  st.write(summary)
152
 
153
- st.write("## Relevant Document Sections")
154
- similar_sections = search_similar_sections(document_text, query)
155
- for i, section in enumerate(similar_sections, 1):
156
- st.write(f"### Section {i}")
157
- st.write(section)
158
-
159
  # Display revenue comparison graph
160
  if financial_data["Revenue"] and financial_data["Date"]:
161
  st.write("## Revenue Comparison")
 
3
  import PyPDF2
4
  import matplotlib.pyplot as plt
5
  from io import BytesIO
6
+ from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader
7
+ from llama_index.embeddings.fastembed import FastEmbedEmbedding
 
 
 
8
  import re
9
  import requests
10
+ import dotenv
11
 
12
  # Load environment variables
13
  dotenv.load_dotenv()
14
 
15
+ # Configure Hugging Face API for Sarvam model
16
  API_URL = "https://api-inference.huggingface.co/models/sarvamai/sarvam-2b-v0.5"
17
  headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACE_API_KEY')}"}
18
 
19
  # Configure embedding model
20
+ Settings.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en-v1.5")
21
 
22
  def query_huggingface_api(payload):
23
  response = requests.post(API_URL, headers=headers, json=payload)
24
  return response.json()
25
 
26
+ def write_to_file(content, filename="./files/uploaded.pdf"):
27
  os.makedirs(os.path.dirname(filename), exist_ok=True)
28
  with open(filename, "wb") as f:
29
  f.write(content)
 
33
  "Revenue": [],
34
  "Date": []
35
  }
 
36
  lines = document_text.split("\n")
37
  revenue_pattern = re.compile(r'\$?\d+(?:,\d{3})*(?:\.\d+)?')
38
+
39
  for i, line in enumerate(lines):
40
  if any(keyword in line.lower() for keyword in ["revenue", "total revenue", "sales"]):
41
  for j in range(i + 1, i + 6):
 
48
  financial_data["Revenue"].append(value)
49
  except ValueError:
50
  continue
51
+
52
  if "Q1" in line or "Q2" in line or "Q3" in line or "Q4" in line or re.search(r'FY\s*\d{4}', line):
53
  financial_data["Date"].append(line.strip())
54
 
 
58
 
59
  return financial_data
60
 
61
+ def ingest_documents():
62
+ reader = SimpleDirectoryReader("./files/")
63
+ documents = reader.load_data()
64
+ return documents
65
+
66
+ def load_data(documents):
67
+ index = VectorStoreIndex.from_documents(documents)
68
+ return index
69
+
70
+ def generate_summary(index, document_text, query):
71
+ query_engine = index.as_query_engine()
72
  prompt = f"""
73
  You are a financial analyst. Your task is to provide a comprehensive analysis of the financial document.
74
  Analyze the following document and respond to the query:
 
106
  plt.tight_layout()
107
  st.pyplot(fig)
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  # Streamlit app
110
  def main():
111
  st.title("Fortune 500 Financial Document Analyzer")
 
129
  # Extract financial data
130
  financial_data = extract_financial_data(document_text)
131
 
132
+ # Ingest documents for summarization and query-driven analysis
133
+ documents = ingest_documents()
134
+ index = load_data(documents)
135
+
136
  # Add a provision for user query input
137
  query = st.text_input("Enter your financial analysis query (e.g., 'What are the revenue trends?')", "")
138
 
139
  if query:
140
+ summary = generate_summary(index, document_text, query)
141
  st.write("## Financial Analysis Result")
142
  st.write(summary)
143
 
 
 
 
 
 
 
144
  # Display revenue comparison graph
145
  if financial_data["Revenue"] and financial_data["Date"]:
146
  st.write("## Revenue Comparison")