# Import the necessary libraries
import subprocess
import sys

# Function to install a package using pip
def install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

# Install required packages
try:
    install("gradio")
    install("openai==1.23.2")
    install("tiktoken==0.6.0")
    install("pypdf==4.0.1")
    install("langchain==0.1.1")
    install("langchain-community==0.0.13")
    install("chromadb==0.4.22")
    install("sentence-transformers==2.3.1")
except subprocess.CalledProcessError as e:
    print(f"An error occurred: {e}")
import gradio as gr
import os
import uuid
import json
import pandas as pd
from openai import OpenAI
from huggingface_hub import HfApi
from huggingface_hub import CommitScheduler
from huggingface_hub import hf_hub_download
import zipfile

# Define your repository and file path
repo_id = "kgauvin603/rag-10k"
file_path = "dataset.zip"

# Download the file
downloaded_file = hf_hub_download(repo_id, file_path)

# Print the path to the downloaded file
print(f"Downloaded file is located at: {downloaded_file}")
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings
)
from langchain_community.vectorstores import Chroma
#from google.colab import userdata, drive
from pathlib import Path
from langchain.document_loaders import PyPDFDirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import tiktoken

# Define the embedding model for the vectorstore
embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')
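
# Optional sanity check (illustrative, not part of the original pipeline):
# confirm the embedding model loads and returns a fixed-length vector.
# gte-large is expected to produce 1024-dimensional embeddings; adjust if
# you swap in a different model.
sample_vector = embedding_model.embed_query("annual report revenue")
print(f"Embedding dimension: {len(sample_vector)}")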
# If the dataset directory exists, remove it and all of its contents
#if os.path.exists('dataset'):
#    !rm -rf dataset

# If collection_db exists, remove it and all of its contents
#if os.path.exists('collection_db'):
#    !rm -rf collection_db

# Mount the Google Drive (Colab only)
#drive.mount('/content/drive')

# Upload Dataset-10k.zip and unzip it into the dataset folder using the -d option
#!unzip Dataset-10k.zip -d dataset
# Install additional packages (the commented command below shows the original
# unzip-from-repo approach)
#command = "unzip kgauvin603/rag-10k-analysis/Dataset-10k.zip -d dataset"
command = "pip install transformers huggingface_hub requests"

# Execute the command
try:
    subprocess.run(command, check=True, shell=True)
except subprocess.CalledProcessError as e:
    print(f"An error occurred: {e}")
import io
import requests

# Provide pdf_folder_location
repo_id = "kgauvin603/rag-10k"
file_path = "dataset.zip"

# Get the URL for the file in the repository
file_url = f"https://huggingface.co/{repo_id}/resolve/main/{file_path}"

# Download the file into memory
response = requests.get(file_url)
response.raise_for_status()  # Ensure the request was successful

# Open the zip file in memory and extract its contents to disk so that
# PyPDFDirectoryLoader can read the PDFs from a local folder
with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
    # List the files in the zip archive
    zip_file_list = zip_ref.namelist()
    print(f"Files in the zip archive: {zip_file_list}")
    # Extract everything; the archive is expected to contain a top-level
    # 'dataset/' folder with one 10-K PDF per company
    zip_ref.extractall(".")

# Point the loader at the extracted folder of PDFs
pdf_folder_location = "dataset"

# Load the directory into pdf_loader
pdf_loader = PyPDFDirectoryLoader(pdf_folder_location)
# Create text_splitter using a recursive splitter
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    encoding_name='cl100k_base',
    chunk_size=512,
    chunk_overlap=16
)

# Create chunks
report_chunks = pdf_loader.load_and_split(text_splitter)
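
# Illustrative check (not part of the original pipeline): use tiktoken to
# verify that the chunks roughly respect the 512-token budget set above.
encoding = tiktoken.get_encoding('cl100k_base')
print(f"Number of chunks: {len(report_chunks)}")
if report_chunks:
    token_counts = [len(encoding.encode(chunk.page_content)) for chunk in report_chunks[:5]]
    print(f"Token counts of first 5 chunks: {token_counts}")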
# Create a collection name
collection_name = 'collection'

# Create the vector database
vectorstore = Chroma.from_documents(
    report_chunks,
    embedding_model,
    collection_name=collection_name,
    persist_directory='./collection_db'
)

# Persist the DB
vectorstore.persist()

vectorstore_persisted = Chroma(
    collection_name=collection_name,
    persist_directory='./collection_db',
    embedding_function=embedding_model
)

retriever = vectorstore_persisted.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 5}
)
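
# Illustrative usage of the retriever (the query here is hypothetical):
# fetch the 5 most similar chunks and show where each came from.
sample_docs = retriever.get_relevant_documents("What was the total revenue?")
for doc in sample_docs:
    print(doc.metadata.get('source'), '-', doc.page_content[:80])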
# Mount the Google Drive (Colab only)
#drive.mount('/content/drive')

# Copy the persisted database to your drive (note: no leading '!' is needed
# when the command is run through subprocess)
#command = "cp -r collection_db /content/drive/MyDrive/"

# Execute the command
#try:
#    subprocess.run(command, check=True, shell=True)
#except subprocess.CalledProcessError as e:
#    print(f"An error occurred: {e}")
# Get the Anyscale API key. `userdata` is only available in Colab; in a
# Hugging Face Space the secret is exposed as an environment variable instead.
anyscale_api_key = os.environ.get('dev-work')

# Initialise the client
client = OpenAI(
    base_url="https://api.endpoints.anyscale.com/v1",
    api_key=anyscale_api_key
)

# Provide the model name
model_name = 'mlabonne/NeuralHermes-2.5-Mistral-7B'
# Initialise the embedding model
embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')

# Location of the persisted DB
persisted_vectordb_location = './collection_db'

# Create a collection name
collection_name = 'collection'

# Load the persisted DB
vectorstore_persisted = Chroma(
    collection_name=collection_name,
    persist_directory=persisted_vectordb_location,
    embedding_function=embedding_model
)
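
# Illustrative smoke test (assumes IBM's report was extracted to
# dataset/IBM-10-k-2023.pdf): query the persisted store with the same
# metadata-filter format used in predict() below.
test_hits = vectorstore_persisted.similarity_search(
    "total revenue", k=2, filter={"source": "dataset/IBM-10-k-2023.pdf"}
)
print(f"Filtered test search returned {len(test_hits)} chunks")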
# Prepare the logging functionality
log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
log_folder = log_file.parent
log_folder.mkdir(parents=True, exist_ok=True)

# The HF token is read from the environment (e.g. a Space secret named HF_TOKEN)
hf_token = os.environ.get('HF_TOKEN')

scheduler = CommitScheduler(
    repo_id="kgauvin603/rag-10k-analysis",
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
    every=2,
    token=hf_token
)
# Define the Q&A system message
qna_system_message = """You are an assistant to a financial services firm who answers user queries on annual reports.
User input will have the context required by you to answer user questions.
This context will begin with the token: ###Context.
The context contains references to specific portions of a document relevant to the user query.
User questions will begin with the token: ###Question.
Please answer only using the context provided in the input. Do not mention anything about the context in your final answer.
If the answer is not found in the context, respond "I don't know".
"""

# Create a message template
qna_user_message_template = """
###Context
Here are some documents that are relevant to the question mentioned below.
{context}
###Question
{question}
"""
# Define the predict function that runs when 'Submit' is clicked or when an API request is made
def predict(user_input, company):
    # Restrict retrieval to the selected company's 10-K report
    filter = "dataset/" + company + "-10-k-2023.pdf"
    relevant_document_chunks = vectorstore_persisted.similarity_search(
        user_input, k=5, filter={"source": filter}
    )

    # Create context_for_query
    context_list = [d.page_content for d in relevant_document_chunks]
    context_for_query = ". ".join(context_list)

    # Create messages
    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(
            context=context_for_query,
            question=user_input
        )}
    ]

    try:
        response = client.chat.completions.create(
            model=model_name,
            messages=prompt,
            temperature=0
        )
        prediction = response.choices[0].message.content.strip()
    except Exception as e:
        prediction = f'Sorry, I encountered the following error: \n{e}'

    # Log both the inputs and outputs to a local log file
    # Ensure that the commit scheduler is locked to avoid parallel access
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': user_input,
                    'retrieved_context': context_for_query,
                    'model_response': prediction
                }
            ))
            f.write("\n")

    return prediction
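
# Illustrative direct invocation (assumes the vector store contains
# IBM-10-k-2023.pdf and valid Anyscale credentials); left commented so the
# script has no side effects before the UI launches:
# answer = predict("What was the total revenue in 2023?", "IBM")
# print(answer)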
# Set up the Gradio UI
# Add a text box and a dropdown to the interface
# The dropdown selects the company whose 10-K report the context is retrieved from
textbox = gr.Textbox(label="User Input")
#company = gr.List(label="Select Company", choices=["IBM", "Meta", "aws", "google", "msft"])
company = gr.Dropdown(label="Select Company", choices=["IBM", "Meta", "aws", "google", "msft"])

# Create the interface
# For the inputs parameter of Interface, provide [textbox, company]
demo = gr.Interface(fn=predict, inputs=[textbox, company], outputs="text")

demo.queue()
demo.launch(share=True)