import logging
import os
import shutil
import sys
import zipfile

import chromadb
import gradio as gr
import requests
import torch
from llama_index.core import (
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
)
from llama_index.core.postprocessor import (
    MetadataReplacementPostProcessor,
    SentenceTransformerRerank,
    SimilarityPostprocessor,
)
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.vector_stores.chroma import ChromaVectorStore


enable_rerank = True
retrieval_strategy = "sentence_window"  # options: sentence_window, naive, recursive_retrieval
base_embedding_source = "hf"  # options: local, openai, hf
# alternatives: intfloat/multilingual-e5-small, local:BAAI/bge-small-en-v1.5,
# text-embedding-3-small, nvidia/NV-Embed-v2
base_embedding_model = "Alibaba-NLP/gte-large-en-v1.5"
# alternatives: meta-llama/Llama-3.1-8B, meta-llama/Llama-3.2-3B-Instruct,
# meta-llama/Llama-2-7b-chat-hf, google/gemma-2-9b, CohereForAI/c4ai-command-r-plus,
# CohereForAI/aya-23-8B, AdaptLLM/finance-chat
base_llm_model = "mistralai/Mistral-7B-Instruct-v0.3"
base_llm_source = "hf"  # options: cohere, hf, anthropic
base_similarity_top_k = 20


# ChromaDB
env_extension = "_large"  # _large _dev_window _large_window
db_collection = f"gte{env_extension}"  # intfloat gte
read_db = True
active_chroma = True
root_path = "."
chroma_db_path = f"{root_path}/chroma_db"  # ./chroma_db
# ./processed_files.json
processed_files_log = f"{root_path}/processed_files{env_extension}.json"


# check hyperparameters
if retrieval_strategy not in ["sentence_window", "naive"]:  # recursive_retrieval is not wired up yet
    raise Exception(f"retrieval_strategy '{retrieval_strategy}' is not supported")


os.environ["OPENAI_API_KEY"] = 'sk-xxxxxxxxxx'
hf_api_key = os.getenv("HF_API_KEY")

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))


torch.cuda.empty_cache()

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
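# (expandable_segments lets the CUDA caching allocator grow existing memory segments
# instead of reserving new fixed-size blocks, which helps avoid fragmentation-related OOMs.)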

print(f"loading embedding ..{base_embedding_model}")
if base_embedding_source == 'hf':
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding
    Settings.embed_model = HuggingFaceEmbedding(
        model_name=base_embedding_model, trust_remote_code=True)  # ,
else:
    raise Exception("embedding model is invalid")

# set up the query-wrapper prompt for the instruction-tuned LLM
if base_llm_source == 'hf':
    from llama_index.core import PromptTemplate

    # This will wrap the default prompts that are internal to llama-index
    # taken from https://huggingface.co/Writer/camel-5b-hf
    query_wrapper_prompt = PromptTemplate(
        "Below is an instruction that describes a task. "
        "You need to make sure that the user's question and the retrieved context "
        "mention the same stock symbol; if they do not, give the user no answer. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{query_str}\n\n### Response:"
    )
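    # Rough sketch of the final prompt sent to the LLM for one user question
    # (illustrative only; {query_str} is filled in by the query engine):
    #
    #   Below is an instruction that describes a task. ... appropriately completes the request.
    #
    #   ### Instruction:
    #   What is the revenue of PTTOR in 2022?
    #
    #   ### Response: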

if base_llm_source == 'hf':
    llm = HuggingFaceLLM(
        context_window=2048,
        max_new_tokens=512,
        generate_kwargs={"temperature": 0.1, "do_sample": False},
        query_wrapper_prompt=query_wrapper_prompt,
        tokenizer_name=base_llm_model,
        model_name=base_llm_model,
        device_map="auto",
        tokenizer_kwargs={"max_length": 2048},
        # load weights in float16 on CUDA to reduce memory usage
        model_kwargs={"torch_dtype": torch.float16}
    )

    Settings.chunk_size = 512
    Settings.llm = llm

"""#### Load documents, build the VectorStoreIndex"""


def download_and_extract_chroma_db(url, destination):
    """Download and extract the ChromaDB archive from Hugging Face Datasets."""
    # Create the destination folder if it doesn't exist
    if not os.path.exists(destination):
        os.makedirs(destination)
    else:
        # Remove only a previously extracted DB so we get a fresh copy, rather than
        # wiping everything under `destination` (which may be the working directory).
        # This assumes the archive contains a top-level `chroma_db/` folder.
        extracted_db = os.path.join(destination, "chroma_db")
        if os.path.exists(extracted_db):
            print("Previous ChromaDB extract found. Removing it...")
            shutil.rmtree(extracted_db)
            print("Previous extract cleared.")

    db_zip_path = os.path.join(destination, "chroma_db.zip")
    if not os.path.exists(db_zip_path):
        # Download the ChromaDB zip file
        print("Downloading ChromaDB from Hugging Face Datasets...")
        headers = {
            "Authorization": f"Bearer {hf_api_key}"
        }
        response = requests.get(url, headers=headers, stream=True)
        response.raise_for_status()
        with open(db_zip_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print("Download completed.")
    else:
        print("Zip file already exists, skipping download.")

    # Extract the zip file
    print("Extracting ChromaDB...")
    with zipfile.ZipFile(db_zip_path, 'r') as zip_ref:
        zip_ref.extractall(destination)
    print("Extraction completed. Zip file retained.")


# URL to your dataset hosted on Hugging Face
chroma_db_url = "https://huggingface.co/datasets/iamboolean/set50-db/resolve/main/chroma_db.zip"

# Local destination for the ChromaDB
chroma_db_path_extract = "./"  # You can change this to your desired path

# Download and extract the ChromaDB
download_and_extract_chroma_db(chroma_db_url, chroma_db_path_extract)

# Define ChromaDB client (persistent mode)
db = chromadb.PersistentClient(path=chroma_db_path)
print(f"db path:{chroma_db_path}")
chroma_collection = db.get_or_create_collection(db_collection)
print(f"db collection:{db_collection}")


# Set up ChromaVectorStore and embeddings
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)

document_count = chroma_collection.count()
print(f"Total documents in the collection: {document_count}")

index = VectorStoreIndex.from_vector_store(
    vector_store=vector_store,
    # embed_model=embed_model,
)
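# Note: no embed_model is passed here, so llama_index falls back to Settings.embed_model
# (set above) for embedding queries; it should match the model that was used when the
# Chroma collection was built.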

"""#### Query Index"""


rerank = SentenceTransformerRerank(
    model="cross-encoder/ms-marco-MiniLM-L-2-v2", top_n=10
)
node_postprocessors = []
# node_postprocessors.append(SimilarityPostprocessor(similarity_cutoff=0.6))

if retrieval_strategy == 'sentence_window':
    # replace each retrieved sentence with its surrounding window; the target key
    # `window` matches the sentence-window node parser's default
    node_postprocessors.append(
        MetadataReplacementPostProcessor(target_metadata_key="window"))


if enable_rerank:
    node_postprocessors.append(rerank)
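# Resulting pipeline: retrieve the top `base_similarity_top_k` (20) nodes, optionally
# swap each sentence for its surrounding window, then (if reranking is enabled) rerank
# with the cross-encoder and keep the top 10.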


query_engine = index.as_query_engine(
    similarity_top_k=base_similarity_top_k,
    node_postprocessors=node_postprocessors,
)


def metadata_formatter(metadata):
    # File names are expected to look like "<SYMBOL>-<YEAR>.<ext>": the stock symbol,
    # then the report year, separated by '-'
    company_symbol = metadata['file_name'].split('-')[0]
    year = metadata['file_name'].split('-')[1].split('.')[0]
    page_number = metadata['page_label']

    return f"Company File: {company_symbol}, Year: {year}, Page Number: {page_number}"


def query_journal(question):

    response = query_engine.query(question)  # Query the index
    matched_nodes = response.source_nodes  # Extract matched nodes

    # Prepare the matched nodes details
    retrieved_context = "\n".join([
        # f"Node ID: {node.node_id}\n"
        # f"Matched Content: {node.node.text}\n"
        # f"Metadata: {node.node.metadata if node.node.metadata else 'None'}"
        f"Metadata: {metadata_formatter(node.node.metadata) if node.node.metadata else 'None'}"
        for node in matched_nodes
    ])

    generated_answer = str(response)

    # Return both retrieved context and detailed matched nodes
    return retrieved_context, generated_answer
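# Example usage outside the Gradio UI (left commented out so it does not run at import time):
# context, answer = query_journal("What is the revenue of PTTOR in 2022?")
# print(answer)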


# Define the Gradio interface
with gr.Blocks() as app:
    # Title
    gr.Markdown(
        """
        <div style="text-align: center;">
            <h1>SET50RAG: Retrieval-Augmented Generation for Thai Public Companies Question Answering</h1>
        </div>
        """
    )

    # Description
    gr.Markdown(
        """
        The **SET50RAG** tool provides an interactive way to analyze and extract insights from **243 annual reports** of Thai public companies spanning **5 years**.
        It combines **Retrieval-Augmented Generation**, a **GTE-Large embedding model**, **sentence-window retrieval with reranking**, and a capable **Large Language Model (LLM)**, **Mistral-7B**, to retrieve and answer complex financial queries.
        This scalable and cost-effective approach reduces reliance on the LLM's parametric knowledge and keeps answers grounded in the retrieved context.
        """
    )

    # How to Use Section
    gr.Markdown(
        """
        ### How to Use
        1. Type your question in the box or select an example question below.
        2. Click **Submit** to retrieve the context and get an AI-generated answer.
        3. Review the retrieved context and the generated answer to gain insights.
        ---
        """
    )

    # Example Questions Section
    gr.Markdown(
        """
        ### Example Questions
        - What is the revenue of PTTOR in 2022?
        - What is the effect of COVID-19 on BDMS? Show it in a timeline format from 2019 to 2023.
        - How does CPALL plan for electric vehicles?
        """
    )

    # Interactive Section (RAG Box)
    with gr.Row():
        with gr.Column():
            user_question = gr.Textbox(
                label="Ask a Question",
                placeholder="Type your question here, e.g., 'What is the revenue of PTTOR in 2022?'",
            )
            example_question_button = gr.Button("Use Example Question")
        with gr.Column():
            generated_answer = gr.Textbox(
                label="Generated Answer",
                placeholder="The AI-generated answer will appear here.",
                interactive=False,
            )
            retrieved_context = gr.Textbox(
                label="Retrieved Context",
                placeholder="Relevant context will appear here.",
                interactive=False,
            )

    # Button for user interaction
    submit_button = gr.Button("Submit")

    # Example question logic
    def use_example_question():
        return "What is the revenue of PTTOR in 2022?"

    example_question_button.click(
        use_example_question, inputs=[], outputs=[user_question]
    )

    # Interaction logic for submitting user queries
    submit_button.click(
        query_journal, inputs=[user_question], outputs=[
            retrieved_context, generated_answer]
    )

    # Footer
    gr.Markdown(
        """
        ---
        ### Limitations and Bias:
        - Optimized for Thai financial reports from SET50 companies. Results may vary for other domains.
        - Retrieval and accuracy depend on data quality and embedding models.
        """
    )

# Launch the app
# app.launch()
app.launch(server_name="0.0.0.0")  # , server_port=7860