Michaeldavidstein committed
Commit 3a38016 · verified · 1 Parent(s): 2b40263

Upload 2 files

Files changed (2)
  1. app.py +177 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,177 @@
+ import gradio as gr
+ from sentence_transformers import SentenceTransformer
+ import chromadb
+ import pandas as pd
+ import os
+ import json
+ from pathlib import Path
+ from llama_index.llms.anyscale import Anyscale
+
+ # Load the sentence transformer model
+ model = SentenceTransformer('all-MiniLM-L6-v2')
+
+ # Initialize the ChromaDB client
+ chroma_client = chromadb.Client()
+
+ # Function to build the database from CSV
+ def build_database():
+     # Read the CSV file
+     df = pd.read_csv('collection_data.csv')
+
+     # Create a collection
+     collection_name = 'Dataset-10k-companies'
+
+     # Delete the existing collection if it exists
+     try:
+         chroma_client.delete_collection(name=collection_name)
+     except Exception:
+         pass  # no existing collection to delete
+
+     # Create a new collection
+     collection = chroma_client.create_collection(name=collection_name)
+
+     # Add the data from the DataFrame to the collection
+     collection.add(
+         documents=df['documents'].tolist(),
+         ids=df['ids'].tolist(),
+         metadatas=df['metadatas'].apply(eval).tolist(),
+         embeddings=df['embeddings'].apply(lambda x: eval(x.replace(',,', ','))).tolist()
+     )
+
+     return collection
+
+ # Build the database when the app starts
+ collection = build_database()
+
+ # Read the Anyscale API key from the environment (e.g. a Space secret)
+ # and instantiate the Anyscale client
+ anyscale_api_key = os.environ["ANYSCALE_API_KEY"]
+ client = Anyscale(api_key=anyscale_api_key, model="meta-llama/Llama-2-70b-chat-hf")
+
+ # Function to get relevant chunks
+ def get_relevant_chunks(query, collection, top_n=3):
+     query_embedding = model.encode(query).tolist()
+     results = collection.query(query_embeddings=[query_embedding], n_results=top_n)
+
+     relevant_chunks = []
+     for i in range(len(results['documents'][0])):
+         chunk = results['documents'][0][i]
+         source = results['metadatas'][0][i]['source']
+         page = results['metadatas'][0][i]['page']
+         relevant_chunks.append((chunk, source, page))
+
+     return relevant_chunks
+
+ # Define system message for LLM
+ qna_system_message = """
+ You are an assistant to Finsights analysts. Your task is to provide relevant information about the financial performance of the companies followed by Finsights.
+
+ User input will include the necessary context for you to answer their questions. This context will begin with the token: ###Context.
+ The context contains references to specific portions of documents relevant to the user's query, along with source links.
+ The source for a context will begin with the token: ###Source.
+
+ When crafting your response:
+ 1. Select only the context relevant to answering the question.
+ 2. Include the source links in your response.
+ 3. User questions will begin with the token: ###Question.
+ 4. If the question is irrelevant to Finsights, respond with: "I am an assistant for Finsight Docs. I can only help you with questions related to Finsights."
+
+ Adhere to the following guidelines:
+ - Your response should only address the question asked and nothing else.
+ - Answer only using the context provided.
+ - Do not mention anything about the context in your final answer.
+ - If the answer is not found in the context, respond with: "I don't know."
+ - Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Source:
+ - Do not make up sources. Use only the links provided in the sources section of the context. You are prohibited from providing other links/sources.
+
+ Here is an example of how to structure your response:
+
+ Answer:
+ [Answer]
+
+ Source:
+ [Source]
+ """
+
+ # Create a user message template
+ qna_user_message_template = """
+ ###Context
+ Here are some documents and their source links that are relevant to the question mentioned below.
+ {context}
+
+ ###Question
+ {question}
+ """
+
+ # Function to get LLM response
+ def get_llm_response(prompt, max_attempts=3):
+     full_response = ""
+     for attempt in range(max_attempts):
+         try:
+             response = client.complete(prompt, max_tokens=1000)  # increase max_tokens if the model allows
+             chunk = response.text.strip()
+             full_response += chunk
+             if chunk.endswith((".", "!", "?")):  # response looks complete
+                 break
+             else:
+                 # Ask the model to continue, using the last 100 chars as context
+                 prompt = "Please continue from where you left off:\n" + chunk[-100:]
+         except Exception as e:
+             print(f"Attempt {attempt + 1} failed with error: {e}")
+     return full_response
+
+ # Prediction function
+ def predict(company, user_query):
+     try:
+         # Modify the query to include the company name
+         modified_query = f"{user_query} for {company}"
+
+         # Get relevant chunks
+         relevant_chunks = get_relevant_chunks(modified_query, collection)
+
+         # Prepare the context string
+         context = ""
+         for chunk, source, page in relevant_chunks:
+             context += chunk + "\n"
+             context += f"###Source {source}, Page {page}\n"
+
+         # Prepare the user message
+         user_message = qna_user_message_template.format(context=context, question=user_query)
+
+         # Craft the prompt to pass to the Llama model
+         prompt = f"{qna_system_message}\n\n{user_message}"
+
+         # Generate the response using the Llama model through Anyscale
+         answer = get_llm_response(prompt)
+
+         # Log the interaction
+         log_interaction(company, user_query, context, answer)
+
+         return answer
+     except Exception as e:
+         return f"An error occurred: {str(e)}"
+
+ # Function to log interactions
+ def log_interaction(company, user_query, context, answer):
+     log_file = Path("interaction_log.jsonl")
+     with log_file.open("a") as f:
+         json.dump({
+             'company': company,
+             'user_query': user_query,
+             'context': context,
+             'answer': answer
+         }, f)
+         f.write("\n")
+
+ # Create Gradio interface
+ company_list = ["MSFT", "AWS", "Meta", "Google", "IBM"]
+ iface = gr.Interface(
+     fn=predict,
+     inputs=[
+         gr.Radio(company_list, label="Select Company"),
+         gr.Textbox(lines=2, placeholder="Enter your query here...", label="User Query")
+     ],
+     outputs=gr.Textbox(label="Generated Answer"),
+     title="Company Reports Q&A",
+     description="Query the vector database and get an LLM response based on the documents in the collection."
+ )
+
+ # Launch the interface
+ iface.launch()
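
For context, build_database() expects collection_data.csv to carry pre-computed embeddings and stringified Python literals in the metadatas and embeddings columns (hence the eval calls), and get_relevant_chunks() assumes each metadata dict has 'source' and 'page' keys. A minimal sketch of producing a compatible CSV; the document text, id, and source link below are placeholders, not real data:

# Sketch: build a minimal collection_data.csv that build_database() can load.
# All values here are placeholders for illustration.
import pandas as pd
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-MiniLM-L6-v2')
doc = "Sample 10-K excerpt: total revenue grew 12% year over year."
df = pd.DataFrame({
    'documents': [doc],
    'ids': ['chunk-0'],
    # each metadata entry is a stringified dict with 'source' and 'page' keys
    'metadatas': [str({'source': 'https://example.com/sample-10k.pdf', 'page': 12})],
    # each embedding is a stringified list of floats
    'embeddings': [str(model.encode(doc).tolist())],
})
df.to_csv('collection_data.csv', index=False)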
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ gradio
+ sentence-transformers
+ chromadb
+ pandas
+ llama-index-llms-anyscale  # provides llama_index.llms.anyscale used in app.py
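
app.py reads the Anyscale key from the ANYSCALE_API_KEY environment variable, so it must be set before the interface starts. A minimal local-run sketch, assuming the dependencies above are installed and collection_data.csv sits next to app.py; the key value is a placeholder (on a Hugging Face Space it would come from a repository secret instead):

# Sketch: launch the app locally with the Anyscale key injected via the environment.
import os
import subprocess

os.environ["ANYSCALE_API_KEY"] = "esecret_placeholder"  # placeholder, not a real key
subprocess.run(["python", "app.py"], check=True)  # Gradio prints the local URL on startup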