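"""Gradio app for question answering over 2023 10-K filings (RAG).

Relevant chunks are retrieved from a persisted Chroma vector store, an
Anyscale-hosted chat model generates the answer, and every interaction is
logged to a Hugging Face dataset repository via CommitScheduler.
"""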
import json
import os
import uuid
from pathlib import Path

import gradio as gr
from huggingface_hub import CommitScheduler
from openai import OpenAI
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings

# Anyscale Endpoints exposes an OpenAI-compatible API; the key is read from the environment.
api_key = os.environ.get('anyscale_gl')
client = OpenAI(base_url='https://api.endpoints.anyscale.com/v1', api_key=api_key)
# The embedding model must match the one used when the vector store was built.
embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')

vectorstore_persisted = Chroma(
    collection_name='10k',
    persist_directory='./10k_vdb',
    embedding_function=embedding_model
)

# Log each interaction to a Hugging Face dataset repo.
# The CommitScheduler pushes the contents of log_folder every 2 minutes.
log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
log_folder = log_file.parent

scheduler = CommitScheduler(
    repo_id="rag-10k",
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
    every=2
)
qna_system_message = """
You serve as an assistant at a leading financial technology firm.
The firm leverages advanced analytics to offer insights for investment management and financial planning, and it maintains a comprehensive database of 10-K reports from a broad spectrum of industry participants.
10-K reports are rich in detail regarding financial performance, risk factors, market trends, and strategic initiatives.
However, the firm's financial analysts often struggle to quickly extract actionable insights from these reports, creating a bottleneck that affects the firm's ability to provide timely and precise recommendations to its clients.
Your primary role is to answer user questions about these annual reports.

Please note the following guidelines for handling user queries:
- Context information: Each user input will include the necessary context, marked by the token `#Context` and passed via the `context` variable. Use only this provided information to formulate your response.
- User questions: Queries from users will start with the token `#Question`, and the specific question will be passed through the `question` variable.
- Your responses: Begin your answers with the token `#Answer`. Base your answers solely on the given context, but do not quote or refer to the context text verbatim. If the required information is not present in the context, simply respond with "I don't know".
"""
qna_user_message_template = """
#Context: {context}

#Question: {question}
"""
def predict(textbox, company):
    # Restrict retrieval to the selected company's 10-K filing.
    source_file = f"dataset/{company}-10-k-2023.pdf"
    relevant_document_chunks = vectorstore_persisted.similarity_search(
        textbox, k=5, filter={"source": source_file}
    )

    # Concatenate the retrieved chunks, tagging each with its source page number.
    context_for_query = ""
    for doc in relevant_document_chunks:
        page_content = doc.page_content.replace('\n', ' ')
        page_number = doc.metadata['page']
        context_for_query += f'{page_content} \nPage Number: {page_number}\n\n'
    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(
            context=context_for_query, question=textbox
        )}
    ]

    # Model served through Anyscale Endpoints.
    base_model = 'mlabonne/NeuralHermes-2.5-Mistral-7B'
    try:
        response = client.chat.completions.create(
            model=base_model, messages=prompt, temperature=0
        )
        prediction = response.choices[0].message.content.strip()
    except Exception as e:
        prediction = f'Sorry, I encountered the following error: \n {e}'
    # Write the interaction to the log file; scheduler.lock avoids writing
    # while the scheduler is committing the folder to the Hub.
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'question': textbox,
                    'context': context_for_query,
                    'answer': prediction
                }
            ))
            f.write("\n")

    return prediction
textbox = gr.Textbox(label='Query')
company = gr.Radio(['google', 'Meta', 'aws', 'IBM', 'msft'], label='Company')

demo = gr.Interface(
    fn=predict,
    inputs=[textbox, company],
    outputs='text',
    description='RAG over the 2023 10-K filings of Google, Meta, Amazon, IBM, and Microsoft',
    title='10-K Questioner'
)
demo.queue()
demo.launch()