"""Gradio app for RAG-based Q&A over 10-K reports.

Retrieves relevant chunks from a persisted Chroma vector store, answers
questions with a Fireworks-hosted LLM through its OpenAI-compatible API, and
logs every interaction to a Hugging Face dataset repository.
"""

import json
import os
import uuid
from pathlib import Path

import gradio as gr
from dotenv import load_dotenv
from huggingface_hub import CommitScheduler
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from openai import OpenAI

# Load the Fireworks API key from .env.
load_dotenv()

# Fireworks serves an OpenAI-compatible API, so the stock OpenAI client
# works unchanged.
client = OpenAI(
    base_url="https://api.fireworks.ai/inference/v1",
    api_key=os.getenv("Project_3_api_fireworks"),
)

# Query embeddings must come from the same model the store was built with.
embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')

collection_name = 'report-10k-2024'

# Re-open the persisted Chroma collection from disk.
vectorstore_persisted = Chroma(
    collection_name=collection_name,
    persist_directory='./report_10kdb',
    embedding_function=embedding_model
)

# General-purpose retriever over the whole collection; note that predict()
# below bypasses it in favour of a per-company filtered search.
retriever = vectorstore_persisted.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 5}
)
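

# A minimal sketch, never called by the app, of how './report_10kdb' could
# have been built. The loader and chunking parameters are assumptions, not
# part of the original pipeline; PyPDFLoader records each file's path in the
# 'source' metadata field that predict() later filters on.
def build_vectorstore_sketch(pdf_paths):
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain_community.document_loaders import PyPDFLoader

    docs = []
    for path in pdf_paths:  # e.g. 'dataset/aws-10-k-2023.pdf'
        docs.extend(PyPDFLoader(path).load())
    chunks = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=100,
    ).split_documents(docs)
    return Chroma.from_documents(
        chunks,
        embedding_model,
        collection_name=collection_name,
        persist_directory='./report_10kdb',
    )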


# Each app instance logs to its own JSON Lines file under logs/.
log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
log_folder = log_file.parent
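# CommitScheduler is expected to create folder_path when missing; creating
# it here as well is a defensive assumption so the append in predict()
# cannot fail on a missing directory.
log_folder.mkdir(parents=True, exist_ok=True)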

# Push the logs folder to a Hugging Face dataset repo every 2 minutes.
scheduler = CommitScheduler(
    repo_id="RAG-investment-recommendation-log",
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
    every=2
)
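
# predict() below appends one JSON object per line to log_file, e.g.:
#   {"user_input": "...", "retrieved_context": "...", "model_response": "..."}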

qna_system_message = """
You are an assistant to a researcher. Your task is to provide relevant information from the 10-K reports repository.

User input will include the necessary context for you to answer their questions. This context will begin with the token: ###Context.
The context contains references to specific portions of documents relevant to the user's query, along with source links.
The source for a context will begin with the token: ###Source.
User questions will begin with the token: ###Question.

When crafting your response:
1. Select the most relevant context or contexts to answer the question.
2. Include the source links in your response.
3. If the question is unrelated to the 10-K reports, respond with: "Apologies, I can only help you with questions related to the 10-K reports."

Please adhere to the following guidelines:
- Answer only using the context provided.
- Do not mention anything about the context in your final answer.
- If the answer is not found in the context, respond with: "I don't know. Please check the docs found in the report repository."
- Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Sources:
- Do not make up sources. Use only the links provided in the sources section of the context; you are prohibited from citing other links/sources or drawing on general knowledge not found in the 10-K reports.

Here is an example of how to structure your response:

Answer:
[Answer]

Sources:
[Source]
"""

qna_user_message_template = """
###Context
Here are some documents that are relevant to the question.
{context}

###Question
{question}
"""
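
# For illustration, after .format() the user message looks like this (chunk
# text and question are placeholders):
#
#   ###Context
#   Here are some documents that are relevant to the question.
#   <chunk text>
#   ###Source: dataset/aws-10-k-2023.pdf
#
#   ###Question
#   <user question>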


def predict(user_input, company):
    """Answer a question against the selected company's 10-K filing."""
    # Restrict retrieval to chunks whose 'source' metadata matches the
    # selected filing.
    source_filter = "dataset/" + company + "-10-k-2023.pdf"
    relevant_document_chunks = vectorstore_persisted.similarity_search(
        user_input,
        k=5,
        filter={"source": source_filter},
    )

    # Keep each chunk's source alongside its text so the model can cite it,
    # as the system message requires.
    context_list = [
        f"{d.page_content}\n###Source: {d.metadata.get('source', source_filter)}"
        for d in relevant_document_chunks
    ]
    context_for_query = "\n\n".join(context_list)

    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(
            context=context_for_query,
            question=user_input
        )}
    ]

    try:
        response = client.chat.completions.create(
            model='accounts/fireworks/models/mixtral-8x7b-instruct',
            messages=prompt,
            temperature=0
        )
        prediction = response.choices[0].message.content
    except Exception as e:
        # Return a string (not the exception object) so the log entry stays
        # JSON-serializable.
        prediction = f"Sorry, I could not generate a response: {e}"

    # Log the interaction as one JSON line; the lock prevents writes from
    # racing with the scheduler's periodic upload.
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': user_input,
                    'retrieved_context': context_for_query,
                    'model_response': prediction
                }
            ))
            f.write("\n")

    return prediction

examples = [
    ["What are the AWS policies and frameworks regarding AI ethics, governance, and responsible AI use as detailed in their 10-K reports?", "AWS"],
    ["What are the primary business segments of AWS, and how does each segment contribute to the overall revenue and profitability?", "AWS"],
    ["What are the key risk factors identified in the 10-K report that could potentially impact the AWS business operations and financial performance?", "AWS"],
    ["Has the company made any significant acquisitions in the AI space, and how are these acquisitions being integrated into the company's strategy?", "Microsoft"],
    ["How much capital has been allocated towards AI research and development?", "Google"],
    ["What initiatives has IBM implemented to address ethical concerns surrounding AI, such as fairness, accountability, and privacy?", "IBM"],
    ["How does Meta plan to differentiate itself in the AI space relative to competitors?", "Meta"]
]

# Map each UI label to the file-name stem used in the chunks' 'source'
# metadata (e.g. dataset/aws-10-k-2023.pdf).
company_to_stem = {
    "AWS": "aws",
    "IBM": "IBM",
    "Google": "Google",
    "Meta": "meta",
    "Microsoft": "msft",
}


def get_predict(question, company):
    stem = company_to_stem.get(company)
    if stem is None:
        return "Invalid company selected"
    # Pass the mapped stem, not the UI label, so the source filter in
    # predict() matches the stored metadata.
    return predict(question, stem)


# Build the UI: a company selector, a question box, and clickable examples.
with gr.Blocks(theme="Taithrah/Minimal@>=0.0.1,<0.1.0") as demo:
    with gr.Row():
        company = gr.Radio(["AWS", "IBM", "Google", "Meta", "Microsoft"], label="Select a company")
    with gr.Row():
        question = gr.Textbox(label="Enter your question")

    submit = gr.Button("Submit")
    output = gr.Textbox(label="Output")

    submit.click(
        fn=get_predict,
        inputs=[question, company],
        outputs=output
    )
    # Clicking an example fills both inputs above.
    examples_component = gr.Examples(examples=examples, inputs=[question, company])

# Queue requests so concurrent users are handled in order, then launch.
demo.queue()
demo.launch()