Faran Butt committed on
Commit
dbb1f9d
•
2 Parent(s): 80f6b7b f473a07

Merge pull request #1 from Josephrp/main

Browse files

initial commit backend with buffer memory

Files changed (1) hide show
  1. backend/app.py +261 -0
backend/app.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import weaviate
2
+ import langchain
3
+ import apscheduler
4
+ import tempfile
5
+ import gradio as gr
6
+ from langchain.embeddings import CohereEmbeddings
7
+ from langchain.document_loaders import UnstructuredFileLoader
8
+ from langchain.vectorstores import Weaviate
9
+ from langchain.llms import OpenAI
10
+ from langchain.chains import RetrievalQA
11
+ import os
12
+ import urllib.request
13
+ import ssl
14
+ import mimetypes
15
+ from dotenv import load_dotenv
16
+ import cohere
17
+ from apscheduler.schedulers.background import BackgroundScheduler
18
+ import time
19
+
20
# Configuration: read the .env file (if any) into the process environment,
# then look up the service credentials. Each value is None when its variable
# is not set.
load_dotenv()
openai_api_key = os.environ.get('OPENAI')
cohere_api_key = os.environ.get('COHERE')
weaviate_api_key = os.environ.get('WEAVIATE')
weaviate_url = os.environ.get('WEAVIATE_URL')
weaviate_username = os.environ.get('WEAVIATE_USERNAME')
weaviate_password = os.environ.get('WEAVIATE_PASSWORD')
28
+
29
+
30
+ # Function to refresh authentication
31
def refresh_authentication():
    """Rebuild the module-level Weaviate credentials and client.

    Creates a fresh username/password credential object from the configured
    ``weaviate_username`` / ``weaviate_password`` and a new client bound to
    ``weaviate_url``, publishing both as the globals ``my_credentials`` and
    ``client``.
    """
    global my_credentials, client
    my_credentials = weaviate.auth.AuthClientPassword(
        username=weaviate_username,
        password=weaviate_password,
    )
    client = weaviate.Client(
        weaviate_url,
        auth_client_secret=my_credentials,
    )
35
+
36
# Background job: call refresh_authentication() every 30 minutes so the
# module-level client is rebuilt periodically.
scheduler = BackgroundScheduler()
scheduler.add_job(refresh_authentication, trigger='interval', minutes=30)
scheduler.start()

# Authenticate once up front so `client` exists before anything below uses it.
refresh_authentication()
43
+
44
# (name, description, Weaviate data type) for every property of the Article class.
_ARTICLE_FIELDS = (
    ("title", "The title of the article", "text"),
    ("content", "The content of the article", "text"),
    ("author", "The author of the article", "text"),
    ("publishDate", "The date the article was published", "date"),
)

# Weaviate class definition for articles. No vector index type or vectorizer
# is specified here (the original kept "vectorIndexType": "hnsw" and
# "vectorizer": "text2vec-contextionary" commented out).
Article = {
    "class": "Article",
    "description": "A class representing articles in the application",
    "properties": [
        {"name": field_name, "description": field_doc, "dataType": [field_type]}
        for field_name, field_doc, field_type in _ARTICLE_FIELDS
    ],
}
72
+
73
+ # Function to check if a class exists in the schema
74
def class_exists(class_name):
    """Report whether *class_name* is present in the Weaviate schema.

    Fetches the schema through the module-level ``client`` and compares
    *class_name* against the ``"class"`` entry of every listed class.

    Returns:
        True if a class with that name is listed, otherwise False. Any
        exception raised while fetching or reading the schema is printed and
        treated as "does not exist" (False).
    """
    try:
        known_names = [entry["class"] for entry in client.schema.get()["classes"]]
        found = class_name in known_names
    except Exception as e:
        print(f"Error checking if class exists: {e}")
        return False
    return found
82
+
83
# Schema bootstrap.
#
# The schema payload must be defined before the existence check: the original
# attempted `client.schema.create(schema)` once before `schema` was assigned
# (always failing with a swallowed NameError) and then repeated the whole
# check a second time. The sequence now runs exactly once.
schema = {
    "classes": [Article]
}

if not class_exists("Article"):
    # Create the schema because the 'Article' class is not there yet.
    try:
        client.schema.create(schema)
    except Exception as e:
        print(f"Error creating schema: {e}")
else:
    print("Class 'Article' already exists in the schema.")
    # Fetch and log the schema that is already in place.
    try:
        existing_schema = client.schema.get()
        print("Existing schema retrieved:", existing_schema)
    except Exception as e:
        print(f"Error retrieving existing schema: {e}")
112
+
113
+
114
# LangChain vector store over the "HereChat" index; "text" holds the passage body.
vectorstore = Weaviate(client=client, index_name="HereChat", text_key="text")

# Properties requested back with every search hit, plus the hit's distance.
vectorstore._query_attrs = ["text", "title", "url", "views", "lang", "_additional {distance}"]

# NOTE(review): this sets a plain attribute named `embedding` on the store;
# confirm the installed LangChain version actually reads it (the embedder may
# need to be passed to the constructor instead).
vectorstore.embedding = CohereEmbeddings(model="embed-multilingual-v2.0", cohere_api_key=cohere_api_key)

# Cohere client used for reranking.
co = cohere.Client(api_key=cohere_api_key)
121
+
122
def embed_pdf(file, filename, collection_name, file_type):
    """Store a document's text and Cohere embeddings in a Weaviate class.

    The document is first written to a local file, parsed with
    ``UnstructuredFileLoader``, embedded with Cohere, and each parsed chunk is
    created as an object of class *collection_name*. The local copy is removed
    afterwards in every case.

    Args:
        file: For ``file_type == "URL"`` the URL to download. For
            ``file_type == "Binary"`` the raw content as ``bytes`` (a ``str``
            is encoded to bytes first).
        filename: Name of the local file the content is written to.
        collection_name: Weaviate class name the objects are created under.
        file_type: Either ``"URL"`` or ``"Binary"``.

    Returns:
        ``{"message": ...}`` on success, or ``{"error": ...}`` when the file
        type is unknown or the download fails.
    """
    if file_type == "URL":
        file_path = filename
    elif file_type == "Binary":
        file_path = os.path.join('./', filename)
    else:
        return {"error": "Invalid file type"}

    try:
        if file_type == "URL":
            try:
                # Use a verifying TLS context: the previous unverified context
                # accepted any certificate for user-supplied URLs.
                context = ssl.create_default_context()
                with urllib.request.urlopen(file, context=context) as response, \
                        open(file_path, 'wb') as out_file:
                    out_file.write(response.read())
            except Exception as e:
                return {"error": f"Error downloading file from URL: {e}"}
        else:
            payload = file.encode() if isinstance(file, str) else file
            with open(file_path, 'wb') as f:
                f.write(payload)

        # The loader returns Document objects; their text lives in
        # `.page_content` (they are not subscriptable).
        docs = UnstructuredFileLoader(file_path).load()
        texts = [doc.page_content for doc in docs]

        # Embed all chunks in one call, then store each text with its vector.
        embeddings = CohereEmbeddings(model="embed-multilingual-v2.0", cohere_api_key=cohere_api_key)
        vectors = embeddings.embed_documents(texts) if texts else []
        for text, vector in zip(texts, vectors):
            client.data_object.create(
                data_object={"text": text},
                class_name=collection_name,
                # Pass the embedding as the object's vector rather than as an
                # ordinary property.
                vector=vector,
            )
    finally:
        # Remove the local copy for both URL and binary inputs, including on
        # failure; a file that was never created is simply skipped.
        try:
            os.remove(file_path)
        except OSError:
            pass

    return {"message": f"Documents embedded in Weaviate collection '{collection_name}'"}
166
+
167
def retrieve_info(query):
    """Fetch candidate passages for *query* and return the top reranked hits.

    Up to 25 candidates are requested from the vector store's retriever, their
    text is reranked with Cohere, and the best three are formatted.

    The previous implementation sliced the dict returned by a RetrievalQA
    chain (which raises), indexed rerank hits with Article keys they do not
    carry, and had an unreachable second half after its first ``return`` that
    called a chat-completions API on the Weaviate client. Candidates are now
    taken directly from the retriever and the dead code is removed.

    Args:
        query: The user's search string.

    Returns:
        ``{"results": [...]}`` where each entry has the keys "Document Rank",
        "Title", "Content", "Author", "Publish Date" and "Relevance Score".
        Metadata fields that a hit does not carry are None. The list is empty
        when nothing is retrieved.
    """
    retriever = vectorstore.as_retriever(search_kwargs={"k": 25})
    candidates = retriever.get_relevant_documents(query)
    if not candidates:
        return {"results": []}

    reranked = co.rerank(
        query=query,
        documents=[doc.page_content for doc in candidates],
        top_n=3,
        model='rerank-english-v2.0',
    )
    # Some SDK versions return an iterable response, others expose `.results`.
    hits = getattr(reranked, "results", reranked)

    formatted_results = []
    for rank, hit in enumerate(hits, start=1):
        # `hit.index` points back into the list passed to rerank, which lets us
        # recover the original document and its metadata.
        source = candidates[hit.index]
        metadata = source.metadata or {}
        formatted_results.append({
            "Document Rank": rank,
            "Title": metadata.get("title"),
            "Content": source.page_content,
            "Author": metadata.get("author"),
            "Publish Date": metadata.get("publishDate"),
            "Relevance Score": f"{hit.relevance_score:.2f}",
        })

    return {"results": formatted_results}
220
+
221
def combined_interface(query, file, collection_name):
    """Dispatch a UI request to retrieval or to document embedding.

    A non-empty *query* takes priority and is answered by ``retrieve_info``.
    Otherwise, when both *file* and *collection_name* are given, the file is
    handed to ``embed_pdf``. With neither, a usage hint is returned.

    The previous version called ``embed_pdf`` with two arguments although it
    takes four, and referenced a temp-file path that was never created in the
    URL branch. ``embed_pdf`` writes the content to disk itself, so the content
    and detected type are now passed straight through.

    Args:
        query: Search text; may be empty or None.
        file: Uploaded file payload, or None.
        collection_name: Target Weaviate class for embedding.

    Returns:
        The result of ``retrieve_info`` or ``embed_pdf``, or a usage message
        string when there is nothing to do.
    """
    if query:
        return retrieve_info(query)

    if file is not None and collection_name:
        # NOTE(review): assumes `file` is an indexable (content, filename)
        # pair — confirm against what the File input actually delivers.
        file_content = file[0]
        filename = file[1]

        # A string starting with "http" is treated as a URL to download;
        # anything else is treated as raw file content.
        if isinstance(file_content, str) and file_content.startswith("http"):
            file_type = "URL"
        else:
            file_type = "Binary"

        return embed_pdf(file_content, filename, collection_name, file_type)

    return "Please enter a query or upload a PDF file and specify a collection name."
249
+
250
+
251
# Gradio front end: one form feeding combined_interface with a query box,
# a file upload and a collection-name box; the result is shown as text.
iface = gr.Interface(
    fn=combined_interface,
    inputs=[gr.Textbox(label="Query"), gr.File(label="PDF File"), gr.Textbox(label="Collection Name")],
    outputs="text",
)

# Start serving as soon as the module is executed.
iface.launch()