kgauvin603 committed
Commit 7df12b8
1 Parent(s): adbfa0a

Create app.py

Files changed (1):
  1. app.py +185 -0
app.py ADDED
@@ -0,0 +1,185 @@
+ import gradio as gr
+ from sentence_transformers import SentenceTransformer
+ import chromadb
+ import pandas as pd
+ import ast
+ import os
+ import json
+ from pathlib import Path
+ from llama_index.llms.anyscale import Anyscale
+
+ # Load the sentence transformer model; all-MiniLM-L6-v2 produces 384-dimensional embeddings
+ model = SentenceTransformer('all-MiniLM-L6-v2')
+
+ # Check the dimensionality of the model
+ embedding_dim = model.get_sentence_embedding_dimension()
+ print(f"Embedding Dimension: {embedding_dim}")  # Should print 384
+
+ # Initialize the ChromaDB client
+ chroma_client = chromadb.Client()
+
+ # Function to build the database from CSV
+ def build_database():
+     # Read the CSV file; it is expected to have 'documents', 'ids',
+     # 'metadatas', and 'embeddings' columns
+     df = pd.read_csv('vector_store.csv')
+
+     # Get or create the collection (avoids an error if it already exists)
+     collection_name = 'Dataset-10k-companies'
+     collection = chroma_client.get_or_create_collection(name=collection_name)
+
+     # Parse a stored embedding string and verify it is 384-dimensional
+     def ensure_dimensionality(embedding):
+         # Collapse doubled commas left over from the CSV export before parsing
+         embedding = ast.literal_eval(embedding.replace(',,', ','))
+         if len(embedding) != 384:
+             raise ValueError(f"Expected a 384-dimensional embedding, got {len(embedding)}")
+         return embedding
+
+     # Add the data from the DataFrame to the collection
+     collection.add(
+         documents=df['documents'].tolist(),
+         ids=df['ids'].tolist(),
+         metadatas=df['metadatas'].apply(ast.literal_eval).tolist(),
+         embeddings=df['embeddings'].apply(ensure_dimensionality).tolist()
+     )
+
+     return collection
+
+ # Build the database when the app starts
+ collection = build_database()
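+ # Note: chromadb.Client() is in-memory by default, so the collection is
+ # rebuilt from vector_store.csv each time the app starts.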
+
+ # Access the Anyscale API key from environment variables
+ anyscale_api_key = os.environ.get('anyscale_api_key')
+
+ # Instantiate the Anyscale client
+ client = Anyscale(api_key=anyscale_api_key, model="meta-llama/Llama-2-70b-chat-hf")
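+ # The key must be supplied via an environment variable (e.g. a Hugging Face
+ # Space secret) named exactly 'anyscale_api_key'; otherwise api_key is None
+ # and calls to the model will fail.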
+
+ # Function to retrieve the top-n most relevant chunks for a query
+ def get_relevant_chunks(query, collection, top_n=3):
+     # Embed the query with the same model used for the stored vectors
+     query_embedding = model.encode(query).tolist()
+     results = collection.query(query_embeddings=[query_embedding], n_results=top_n)
+
+     # Collect (chunk, source, page) tuples from the query results
+     relevant_chunks = []
+     for i in range(len(results['documents'][0])):
+         chunk = results['documents'][0][i]
+         source = results['metadatas'][0][i]['source']
+         page = results['metadatas'][0][i]['page']
+         relevant_chunks.append((chunk, source, page))
+
+     return relevant_chunks
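+ # Illustrative usage (values are hypothetical; assumes each stored metadata
+ # dict carries 'source' and 'page' keys, as the CSV rows here do):
+ #   get_relevant_chunks("What was Microsoft's FY2023 revenue?", collection)
+ #   -> [("...chunk text...", "msft_10k.pdf", 12), ...]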
+
+ # Define system message for LLM
+ qna_system_message = """
+ You are an assistant to Finsights analysts. Your task is to provide relevant information about the financial performance of the companies followed by Finsights.
+ User input will include the necessary context for you to answer their questions. This context will begin with the token: ###Context.
+ The context contains references to specific portions of documents relevant to the user's query, along with source links.
+ The source for a context will begin with the token: ###Source.
+ When crafting your response:
+ 1. Select only the context relevant to answering the question.
+ 2. Include the source links in your response.
+ 3. User questions will begin with the token: ###Question.
+ 4. If the question is irrelevant to Finsights, respond with: "I am an assistant for Finsight Docs. I can only help you with questions related to Finsights."
+ Adhere to the following guidelines:
+ - Your response should only address the question asked and nothing else.
+ - Answer only using the context provided.
+ - Do not mention anything about the context in your final answer.
+ - If the answer is not found in the context, respond with: "I don't know."
+ - Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Source:
+ - Do not make up sources. Use only the links provided in the sources section of the context. You are prohibited from providing other links/sources.
+ Here is an example of how to structure your response:
+ Answer:
+ [Answer]
+ Source:
+ [Source]
+ """
+
+ # Create a user message template
+ qna_user_message_template = """
+ ###Context
+ Here are some documents and their source links that are relevant to the question mentioned below.
+ {context}
+ ###Question
+ {question}
+ """
+
+ # Function to get the LLM response, continuing or retrying if the output looks truncated
+ def get_llm_response(prompt, max_attempts=3):
+     full_response = ""
+     for attempt in range(max_attempts):
+         try:
+             response = client.complete(prompt, max_tokens=1000)
+             chunk = response.text.strip()
+             full_response += chunk
+             if chunk.endswith((".", "!", "?")):  # Heuristic: response seems complete
+                 break
+             # Otherwise ask the model to continue from the tail of its last output
+             prompt = "Please continue from where you left off:\n" + chunk[-100:]
+         except Exception as e:
+             print(f"Attempt {attempt + 1} failed with error: {e}")
+     return full_response
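+ # Note: the end-of-punctuation check is only a heuristic; an answer that ends
+ # mid-list or with an abbreviation may trigger an unnecessary continuation.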
+
+ # Prediction function
+ def predict(company, user_query):
+     try:
+         # Scope the query to the selected company
+         modified_query = f"{user_query} for {company}"
+
+         # Get relevant chunks
+         relevant_chunks = get_relevant_chunks(modified_query, collection)
+
+         # Prepare the context string, tagging each chunk with its source
+         context = ""
+         for chunk, source, page in relevant_chunks:
+             context += chunk + "\n"
+             context += f"###Source {source}, Page {page}\n"
+
+         # Prepare the user message and craft the prompt for the Llama model
+         user_message = qna_user_message_template.format(context=context, question=user_query)
+         prompt = f"{qna_system_message}\n\n{user_message}"
+
+         # Generate the response using the Llama model through Anyscale
+         answer = get_llm_response(prompt)
+
+         # Log the interaction
+         log_interaction(company, user_query, context, answer)
+
+         return answer
+     except Exception as e:
+         return f"An error occurred: {str(e)}"
+
+ # Function to log interactions as one JSON object per line (JSONL)
+ def log_interaction(company, user_query, context, answer):
+     log_file = Path("interaction_log.jsonl")
+     with log_file.open("a") as f:
+         json.dump({
+             'company': company,
+             'user_query': user_query,
+             'context': context,
+             'answer': answer
+         }, f)
+         f.write("\n")
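+ # Note: on hosts with ephemeral filesystems (e.g. Hugging Face Spaces without
+ # persistent storage), interaction_log.jsonl will not survive a restart.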
+
+ # Create the Gradio interface
+ company_list = ["MSFT", "AWS", "Meta", "Google", "IBM"]
+ iface = gr.Interface(
+     fn=predict,
+     inputs=[
+         gr.Radio(company_list, label="Select Company"),
+         gr.Textbox(lines=2, placeholder="Enter your query here...", label="User Query")
+     ],
+     outputs=gr.Textbox(label="Generated Answer"),
+     title="Company Reports Q&A",
+     description="Query the vector database and get an LLM response based on the documents in the collection."
+ )
+
+ # Launch the interface (share=True requests a shareable public link)
+ iface.launch(share=True)
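+ # To run locally (assumed setup): install gradio, sentence-transformers,
+ # chromadb, pandas, and the llama-index Anyscale integration, place
+ # vector_store.csv alongside this file, set the anyscale_api_key environment
+ # variable, and run `python app.py`.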