# RAGREPORTS / app2.py
import subprocess
import sys
import os
import uuid
import json
from pathlib import Path
# Install dependencies if not already installed
def install_packages():
    packages = [
        "openai",
        "langchain_community",
        "sentence-transformers",
        "chromadb",
        "huggingface_hub",
        "python-dotenv",
    ]
    subprocess.check_call([sys.executable, "-m", "pip", "install", *packages])

install_packages()
# Import installed modules
from huggingface_hub import login, CommitScheduler
import openai
import gradio as gr
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()

# OpenAI API key (ensure OPENAI_API_KEY is set in your .env file or environment)
openai.api_key = os.getenv("OPENAI_API_KEY")

# Retrieve the Hugging Face token from environment variables
hf_token = os.getenv("hf_token")
if not hf_token:
    raise ValueError("Hugging Face token is missing. Please set 'hf_token' as an environment variable.")

# Log in to Hugging Face with the retrieved token
login(hf_token)
print("Logged in to Hugging Face successfully.")
# Set up embeddings and vector store
embeddings = SentenceTransformerEmbeddings(model_name="thenlper/gte-large")

collection_name = 'report-10k-2024'
vectorstore_persisted = Chroma(
    collection_name=collection_name,
    persist_directory='./report_10kdb',
    embedding_function=embeddings
)

retriever = vectorstore_persisted.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 5}
)
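# Note (sketch, not executed): the retriever defined above could be queried directly,
# e.g. docs = retriever.get_relevant_documents("What was total revenue in 2023?"),
# but predict() below calls vectorstore_persisted.similarity_search() with a
# per-company metadata filter instead, so `retriever` is otherwise unused here.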
# Define Q&A system message
qna_system_message = """
You are an AI assistant for Finsights Grey Inc., helping automate extraction, summarization, and analysis of 10-K reports.
Your responses should be based solely on the context provided.
If an answer is not found in the context, respond with "I don't know."
"""
qna_user_message_template = """
###Context
Here are some documents that are relevant to the question.
{context}
###Question
{question}
"""
# Logging set-up: write each interaction to a local JSON-lines file and
# periodically push the logs folder to a Hugging Face dataset repo
log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
log_folder = log_file.parent
log_folder.mkdir(parents=True, exist_ok=True)

scheduler = CommitScheduler(
    repo_id="RAGREPORTS-log",
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
    every=2
)

# Define the predict function
def predict(user_input, company):
    source_filter = "dataset/" + company + "-10-k-2023.pdf"
    relevant_document_chunks = vectorstore_persisted.similarity_search(
        user_input, k=5, filter={"source": source_filter}
    )

    # Create context for query
    context_list = [d.page_content for d in relevant_document_chunks]
    context_for_query = ".".join(context_list)

    # Create messages
    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(context=context_for_query, question=user_input)}
    ]

    # Get response from the LLM via the chat completions API
    try:
        response = openai.chat.completions.create(
            model='gpt-3.5-turbo',   # Model to use
            messages=prompt,         # System and user messages
            temperature=0            # Deterministic responses
        )
        # Extract the prediction from the response
        prediction = response.choices[0].message.content
    except Exception as e:
        # Surface the error message as the prediction if the API call fails
        prediction = str(e)

    # Print the prediction or error
    print(prediction)

    # Log inputs and outputs to the local log file
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': user_input,
                    'retrieved_context': context_for_query,
                    'model_response': prediction
                }
            ))
            f.write("\n")

    # Return the prediction after logging
    return prediction
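# Each logged interaction is appended as one JSON object per line, roughly of the
# shape (sketch): {"user_input": "...", "retrieved_context": "...", "model_response": "..."}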
def get_predict(question, company):
    # Map user selection to company name
    company_map = {
        "AWS": "aws",
        "IBM": "IBM",
        "Google": "Google",
        "Meta": "meta",
        "Microsoft": "msft"
    }
    selected_company = company_map.get(company)
    if not selected_company:
        return "Invalid company selected"
    return predict(question, selected_company)
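# Quick local sanity check (sketch, commented out): assumes the Chroma store,
# OPENAI_API_KEY, and hf_token above are all available in this environment.
#     print(get_predict("What were the main risk factors disclosed in 2023?", "Google"))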
# Set up the Gradio UI
with gr.Blocks(theme="gradio/seafoam@>=0.0.1,<0.1.0") as demo:
    with gr.Row():
        company = gr.Radio(["AWS", "IBM", "Google", "Meta", "Microsoft"], label="Select a company")
        question = gr.Textbox(label="Enter your question")
    submit = gr.Button("Submit")
    output = gr.Textbox(label="Output")
    submit.click(
        fn=get_predict,
        inputs=[question, company],
        outputs=output
    )

demo.queue()
demo.launch()