File size: 4,999 Bytes
e3ccde2
 
 
 
 
67a183b
cb3404b
e3ccde2
 
 
 
459c35c
 
 
e3ccde2
 
8b71b30
67a183b
e3ccde2
 
b6bd4b8
e3ccde2
b6bd4b8
 
 
 
e3ccde2
d2db699
 
e3ccde2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b776260
 
e3ccde2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2db699
 
e3ccde2
 
 
 
d2db699
e3ccde2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b776260
 
e3ccde2
cb60060
a64f84e
e3ccde2
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123


## Setup

# Import the necessary libraries
import os
import uuid
import json
import gradio as gr
from huggingface_hub import CommitScheduler
from pathlib import Path
from openai import OpenAI
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings

# Build the OpenAI-compatible client that talks to the Anyscale endpoint.
# The API key is read from the `anyscale_gl` environment variable.
api_key = os.getenv('anyscale_gl')
client = OpenAI(
    api_key=api_key,
    base_url='https://api.endpoints.anyscale.com/v1',
)

# Embedding model, and the persisted Chroma collection `10k` that is queried with it
embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')
vectorstore_persisted = Chroma(
    embedding_function=embedding_model,
    persist_directory='./10k_vdb',
    collection_name='10k',
)



# Logging set-up: each process writes to its own uniquely named file under logs/,
# and the CommitScheduler periodically uploads that folder to the dataset repo.
log_folder = Path("logs/")
log_file = log_folder / f"data_{uuid.uuid4()}.json"
scheduler = CommitScheduler(
    folder_path=log_folder,
    path_in_repo="data",
    repo_id="rag-10k",
    repo_type="dataset",
    every=2,  # commit interval (minutes, per the huggingface_hub docs)
)

# Define the Q&A system message.
# It sets the assistant persona and the reply contract: answer only from the text
# supplied after `#Context`, start the reply with `#Answer`, and reply
# "I don't know" when the context does not contain the answer.
qna_system_message = """
You serve as an assistant at a leading financial technology firm.
This innovative company leverages advanced analytics to offer insights for investment management and financial planning.
It manages a comprehensive database of 10-K reports from a broad spectrum of industry participants.
10-K reports are rich in details regarding financial performance, risk factors, market trends, and strategic initiatives.
However, the financial analysts at the firm often face challenges in quickly extracting actionable insights from these reports.
This difficulty creates a bottleneck, affecting the firm's capability to provide timely and precise recommendations to its clients.
Your primary role is to respond to user inquiries about annual reports.

Please note the following guidelines for handling user queries:
- Context Information: Each user input will include necessary context marked by the token `#Context`. You should use only this provided information to formulate your response. The context will be passed via the `context` variable.
- User Questions: Queries from users will start with the token `#Question`, and the specific question will be passed through the `question` variable.
- Your Responses: Begin your answers with the token `#Answer`. Ensure your responses are based solely on the context given. Do not include or refer to any part of the context directly in your answers. If the required information is not present in the context, simply respond with "I don't know".
"""

# Define the user message template.
# The `{context}` and `{question}` placeholders are filled with `str.format`;
# the `#Context` / `#Question` tokens match the ones named in the system message.
qna_user_message_template = """
#Context: {context}

#Question: {question}
"""

# Helper for predict: turn the retrieved chunks into a single context string
def _build_context(document_chunks) -> str:
    """Join every retrieved chunk into one context string.

    Newlines inside a chunk are flattened to spaces, and each chunk is
    followed by its page number so the passages stay distinguishable.

    Args:
        document_chunks: Iterable of retrieved documents exposing
            ``page_content`` and a ``metadata`` mapping with a ``'page'`` key.

    Returns:
        The concatenated context; an empty string when nothing was retrieved.
    """
    sections = []
    for doc in document_chunks:
        page_content = doc.page_content.replace('\n', ' ')
        page_number = doc.metadata['page']
        sections.append(f'{page_content} \nPage Number: {page_number}\n\n')
    return ''.join(sections)


# Define the predict function that runs when 'Submit' is clicked or when an API request is made
def predict(textbox: str, company: str) -> str:
    """Answer a question about the selected company's 2023 10-K report.

    Retrieves the five most similar chunks from the vector store (restricted
    to the selected company's PDF), sends them with the question to the LLM,
    appends the question/context/answer to the log file, and returns the answer.

    Args:
        textbox: The user's question.
        company: The company key chosen in the UI; it is inserted into the
            source path ``dataset/<company>-10-k-2023.pdf``. May be ``None``
            or empty when nothing was selected.

    Returns:
        The model's answer, an error message if the LLM call failed, or a
        prompt to select a company when ``company`` is missing.
    """
    # Without a company there is no source document to filter on; answer
    # early instead of failing on the path construction below. Such requests
    # are not logged.
    if not company:
        return 'Please select a company before submitting your question.'

    source_filter = f'dataset/{company}-10-k-2023.pdf'
    relevant_document_chunks = vectorstore_persisted.similarity_search(
        textbox, k=5, filter={'source': source_filter})

    # Create context_for_query from ALL retrieved chunks (not just the last one)
    context_for_query = _build_context(relevant_document_chunks)

    # Create messages
    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(
            context=context_for_query, question=textbox)}
    ]

    base_model = 'mlabonne/NeuralHermes-2.5-Mistral-7B'  # 'gpt-3.5-turbo-16k'

    # Get a response from the LLM; any failure is reported back to the user
    # as the answer text rather than crashing the request
    try:
        response = client.chat.completions.create(
            model=base_model, messages=prompt, temperature=0)
        prediction = response.choices[0].message.content.strip()
    except Exception as e:
        prediction = f'Sorry, I encountered the following error: \n {e}'

    # Log both the inputs and outputs to the local log file, one JSON object
    # per line. The commit scheduler's lock is held while writing so the
    # upload does not run in parallel with the write.
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'question': textbox,
                    'context': context_for_query,
                    'answer': prediction
                }
            ))
            f.write("\n")
    return prediction

# Gradio front end.
# Inputs: a free-text query box and a radio group that selects which company's
# 10-K report the context is retrieved from.
textbox = gr.Textbox(label='Query')
company = gr.Radio(['google', 'Meta', 'aws', 'IBM', 'msft'])

# Wire both inputs to predict and show its return value as plain text
demo = gr.Interface(
    fn=predict,
    inputs=[textbox, company],
    outputs='text',
    title='10-k Questioner',
    description='RAG on 2023 10-k statement by Google, Meta, Amazon, IBM, and Microsoft',
)

# Enable request queuing, then start the app
demo.queue()
demo.launch()