|
import weaviate |
|
import langchain |
|
import apscheduler |
|
import tempfile |
|
import gradio as gr |
|
from langchain.embeddings import CohereEmbeddings |
|
from langchain.document_loaders import UnstructuredFileLoader |
|
from langchain.vectorstores import Weaviate |
|
from langchain.llms import OpenAI |
|
from langchain.chains import RetrievalQA |
|
import os |
|
import urllib.request |
|
import ssl |
|
import mimetypes |
|
from dotenv import load_dotenv |
|
import cohere |
|
from apscheduler.schedulers.background import BackgroundScheduler |
|
import time |
|
|
|
|
|
# Load credentials and endpoints from the local .env file into the process
# environment, then bind them to module-level names used throughout the app.
load_dotenv()

openai_api_key = os.getenv('OPENAI')        # OpenAI key for the RetrievalQA LLM
cohere_api_key = os.getenv('COHERE')        # Cohere key for embeddings + rerank
weaviate_api_key = os.getenv('WEAVIATE')    # NOTE(review): read but never used below — confirm
weaviate_url = os.getenv('WEAVIATE_URL')
weaviate_username = os.getenv('WEAVIATE_USERNAME')
weaviate_password = os.getenv('WEAVIATE_PASSWORD')
|
|
|
|
|
|
|
def refresh_authentication():
    """Rebuild the module-level Weaviate client with fresh credentials.

    Rebinds the globals ``my_credentials`` and ``client``; called once at
    startup and periodically by the background scheduler.
    """
    global my_credentials, client
    my_credentials = weaviate.auth.AuthClientPassword(
        username=weaviate_username,
        password=weaviate_password,
    )
    client = weaviate.Client(weaviate_url, auth_client_secret=my_credentials)
|
|
|
|
|
# Refresh the Weaviate authentication every 30 minutes in the background so
# the client's credentials do not go stale while the app runs.
scheduler = BackgroundScheduler()
scheduler.add_job(refresh_authentication, 'interval', minutes=30)
scheduler.start()

# Authenticate immediately — this defines the module-level `client` that all
# of the schema and ingestion code below depends on.
refresh_authentication()
|
|
|
# Weaviate class definition for articles. Wrapped into a full schema payload
# (`{"classes": [...]}`) before being sent to `client.schema.create`.
Article = {
    "class": "Article",
    "description": "A class representing articles in the application",
    "properties": [
        {
            "name": "title",
            "description": "The title of the article",
            "dataType": ["text"]
        },
        {
            "name": "content",
            "description": "The content of the article",
            "dataType": ["text"]
        },
        {
            "name": "author",
            "description": "The author of the article",
            "dataType": ["text"]
        },
        {
            "name": "publishDate",
            "description": "The date the article was published",
            "dataType": ["date"]
        }
    ],
}
|
|
|
|
|
def class_exists(class_name):
    """Return True if `class_name` is already present in the Weaviate schema.

    Any failure talking to Weaviate is logged and treated as "not present"
    so callers fall through to the create path.
    """
    try:
        current_schema = client.schema.get()
        return any(
            entry["class"] == class_name for entry in current_schema["classes"]
        )
    except Exception as e:
        print(f"Error checking if class exists: {e}")
        return False
|
|
|
|
|
# Ensure the "Article" class exists before the app uses it.
if not class_exists("Article"):
    try:
        # BUG FIX: the original passed `schema`, which is not defined until
        # later in the file, so this branch raised NameError instead of
        # creating the class. Build the payload from `Article` directly.
        client.schema.create({"classes": [Article]})
    except Exception as e:
        print(f"Error creating schema: {e}")
else:
    print("Class 'Article' already exists in the schema.")
|
|
|
|
|
# Full schema payload wrapping the Article class definition, in the shape
# expected by `client.schema.create`.
schema = {
    "classes": [Article]
}
|
|
|
|
|
# Create the Article class if it is still missing; otherwise dump the
# current schema for visibility at startup.
if class_exists("Article"):
    try:
        existing_schema = client.schema.get()
        print("Existing schema retrieved:", existing_schema)
    except Exception as e:
        print(f"Error retrieving existing schema: {e}")
else:
    try:
        client.schema.create(schema)
    except Exception as e:
        print(f"Error creating schema: {e}")
|
|
|
|
|
|
|
# LangChain vector store over the "HereChat" index; "text" is the property
# holding the document content.
vectorstore = Weaviate(client, index_name="HereChat", text_key="text")

# Request extra attributes (plus the vector distance) on every query.
# NOTE(review): pokes a private attribute of the LangChain wrapper — fragile
# across langchain versions; confirm it is still honored.
vectorstore._query_attrs = ["text", "title", "url", "views", "lang", "_additional {distance}"]
vectorstore.embedding = CohereEmbeddings(model="embed-multilingual-v2.0", cohere_api_key=cohere_api_key)

# Raw Cohere client, used for reranking in retrieve_info().
co = cohere.Client(api_key=cohere_api_key)
|
|
|
def embed_pdf(file, filename, collection_name, file_type):
    """Persist a document locally, embed it with Cohere, and store each
    chunk in the given Weaviate collection.

    Args:
        file: A URL string (when ``file_type == "URL"``) or the raw file
            content as bytes/str (when ``file_type == "Binary"``).
        filename: Name for the local copy written to disk.
        collection_name: Target Weaviate class for the embedded chunks.
        file_type: ``"URL"`` or ``"Binary"``; anything else is rejected.

    Returns:
        ``{"message": ...}`` on success, ``{"error": ...}`` on failure.
    """
    if file_type == "URL":
        try:
            # SECURITY: certificate verification is disabled, leaving the
            # download open to MITM. Preserved from the original behavior —
            # prefer ssl.create_default_context() if the endpoint allows it.
            context = ssl._create_unverified_context()
            with urllib.request.urlopen(file, context=context) as response, \
                    open(filename, 'wb') as out_file:
                out_file.write(response.read())
            file_path = filename
        except Exception as e:
            return {"error": f"Error downloading file from URL: {e}"}
    elif file_type == "Binary":
        if isinstance(file, str):
            file = file.encode()
        file_path = os.path.join('./', filename)
        with open(file_path, 'wb') as f:
            f.write(file)
    else:
        return {"error": "Invalid file type"}

    loader = UnstructuredFileLoader(file_path)
    docs = loader.load()

    embeddings = CohereEmbeddings(model="embed-multilingual-v2.0",
                                  cohere_api_key=cohere_api_key)
    for doc in docs:
        # BUG FIX: CohereEmbeddings has no `embed` method, and loader
        # results are Document objects, not dicts. The original
        # `embeddings.embed([doc['text']])` raised AttributeError/TypeError.
        embedding = embeddings.embed_documents([doc.page_content])[0]
        weaviate_document = {
            "text": doc.page_content,
            "embedding": embedding,
        }
        client.data_object.create(data_object=weaviate_document,
                                  class_name=collection_name)

    # BUG FIX: remove the local copy in every branch. The original only
    # removed it when `file` was bytes, leaking the URL download and any
    # str-typed binary upload.
    if os.path.exists(file_path):
        os.remove(file_path)
    return {"message": f"Documents embedded in Weaviate collection '{collection_name}'"}
|
|
|
def retrieve_info(query):
    """Answer `query` against the vector store, rerank the candidate
    documents with Cohere, and return the top 3 as structured dicts.

    Args:
        query: Free-text user question.

    Returns:
        ``{"results": [...]}`` where each entry carries the rank, document
        fields, and the Cohere relevance score.
    """
    llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
    qa = RetrievalQA.from_chain_type(llm, retriever=vectorstore.as_retriever())

    # Run the chain once. Its dict result carries the generated answer;
    # reranking below needs the raw candidate documents instead.
    initial_results = qa({"query": query})

    # BUG FIX: the original sliced the chain's dict result
    # (`initial_results[:25]`), which raises TypeError. Fetch candidate
    # documents from the retriever and cap at 25 for reranking.
    candidates = vectorstore.as_retriever().get_relevant_documents(query)[:25]
    top_docs = [
        {
            "text": doc.page_content,
            "content": doc.page_content,
            # NOTE(review): these fields assume the metadata stored at
            # ingest time includes them — confirm against the ingestion path.
            "title": doc.metadata.get("title", ""),
            "author": doc.metadata.get("author", ""),
            "publishDate": doc.metadata.get("publishDate", ""),
        }
        for doc in candidates
    ]

    reranked_results = co.rerank(query=query, documents=top_docs, top_n=3,
                                 model='rerank-english-v2.0')

    formatted_results = []
    for idx, r in enumerate(reranked_results):
        formatted_results.append({
            "Document Rank": idx + 1,
            "Title": r.document['title'],
            "Content": r.document['content'],
            "Author": r.document['author'],
            "Publish Date": r.document['publishDate'],
            "Relevance Score": f"{r.relevance_score:.2f}",
        })

    # BUG FIX: everything after the original early return was unreachable
    # dead code — including a GPT-4 chat call issued on the *Weaviate*
    # client object (which has no `chat.completions`). Removed.
    return {"results": formatted_results}
|
|
|
def combined_interface(query, file, collection_name):
    """Gradio entry point: answer a query when one is given, otherwise
    ingest the uploaded file into the named Weaviate collection.

    Args:
        query: Free-text question; takes priority when non-empty.
        file: Uploaded file payload. NOTE(review): indexed here as
            (content, name) — confirm this matches what gr.File delivers.
        collection_name: Target Weaviate class for ingestion.

    Returns:
        The retrieval result dict, the ingestion status dict, or a usage
        hint string when neither input is usable.
    """
    if query:
        return retrieve_info(query)
    elif file is not None and collection_name:
        filename = file[1]
        file_content = file[0]

        # URLs are fetched by embed_pdf itself; anything else is treated
        # as raw binary content.
        if isinstance(file_content, str) and file_content.startswith("http"):
            file_type = "URL"
        else:
            file_type = "Binary"

        # BUG FIX: embed_pdf takes (file, filename, collection_name,
        # file_type); the original called it with only two arguments,
        # which raised TypeError. The temp-file round trip was also
        # redundant — embed_pdf writes its own local copy.
        return embed_pdf(file_content, filename, collection_name, file_type)
    else:
        return "Please enter a query or upload a PDF file and specify a collection name."
|
|
|
|
|
# Gradio UI: a single entry point that either answers a query or ingests an
# uploaded file into the named collection (see combined_interface).
iface = gr.Interface(
    fn=combined_interface,
    inputs=[
        gr.Textbox(label="Query"),
        gr.File(label="PDF File"),
        gr.Textbox(label="Collection Name")
    ],
    outputs="text"
)

# Blocks here serving the app; launch() runs until interrupted.
iface.launch()