Spaces:
Sleeping
update
Browse files
- LICENSE +21 -0
- README.md +66 -12
- app.py +173 -31
- app1.py +50 -0
- config.yaml +21 -0
- data/bk-tuyen-sinh-2024.docx +0 -0
- prepare_data.py +100 -0
- processed_chunks.pickle +3 -0
- processed_data/processed_bk-tuyen-sinh-2024.docx +0 -0
- requirements.txt +12 -2
- utils/process_tables.py +141 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Tran Anh Quoc

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,12 +1,66 @@
# RAG-Based Chatbot System

## Overview
This project builds a Retrieval-Augmented Generation (RAG) chatbot system. The chatbot uses context-aware chunking for efficient document processing and leverages open-source models for embeddings and language generation.

## Features
- **Context-Aware Chunking**: Implements a manual chunking strategy that uses a special chunk marker to separate chunks within documents, so chunks can be extracted with Python's split() method (see the sketch after this list). Make sure the documents you upload have been segmented with the chunk marker so they can be processed effectively.

- **Data Processing Tool**: Converts tables in documents into HTML tables to handle challenges such as long tables and merged cells. The tool requires users to bold table headers for clarity and automatically identifies table types, including tables with a single header row and tables with more than two header rows.

- **Open-Source Models**: Uses open-source models instead of proprietary ones like those from OpenAI, providing a cost-effective and flexible solution.
  - **Embedding Model**: `intfloat/multilingual-e5-small`, which is highly efficient and particularly effective for Vietnamese text.
  - **Language Model**: `Viet-Mistral/Vistral-7B-Chat`, a Mistral-based model with continued pretraining on Vietnamese for better generation performance.
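For illustration, a minimal sketch of the marker-based extraction (assuming the default `BK_CHUNK` marker from `config.yaml`; `sample.docx` is a hypothetical input file):

```python
import docx2txt  # already listed in requirements.txt

CHUNK_MARKER = 'BK_CHUNK'  # must match dataset.chunk_marker in config.yaml

# Read the document text and split it wherever the author placed the marker.
text = docx2txt.process('sample.docx')  # hypothetical input file
chunks = [c.strip() for c in text.split(CHUNK_MARKER) if c.strip()]
print(f'Extracted {len(chunks)} chunks')
```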
## Installation
1. Clone the repository:
   ```sh
   git clone https://github.com/quoctata2911/RAG-based-ChatBot-System.git
   ```

2. Navigate to the project directory:
   ```sh
   cd RAG-Based-Chatbot-System
   ```

3. Install the required dependencies:
   ```sh
   pip install -r requirements.txt
   ```

## Usage
Upload your Word .docx documents into the data folder. Ensure that each document has been chunked using the special chunk marker separator specified in the config.yaml file.

1. Configure the chunk marker:
   - Open the `config.yaml` file in the project directory.
   - Locate the parameter that defines the chunk marker and adjust it as needed for your document segmentation requirements.

2. Prepare the data:
   ```sh
   python prepare_data.py
   ```

3. Run the chatbot:
   ```sh
   python app.py
   ```

## Project Structure
- **prepare_data.py**: Preprocesses and chunks documents, converting tables into HTML and segmenting them with chunk markers.
- **app.py**: Main script to run the chatbot system.

## Models
- **Embedding Model**: We use the `intfloat/multilingual-e5-small` model for generating embeddings. This model is particularly effective for Vietnamese text, outperforming other models in our benchmarks.

- **Language Model**: The language model is Vistral, a Mistral variant further pretrained on Vietnamese text for improved performance in language generation tasks.

## Benchmarking and Performance
In our benchmarks, the `intfloat/multilingual-e5-small` model proved the best choice for Vietnamese embeddings, offering a balance of efficiency and performance. The Vistral model strengthens language generation, ensuring the chatbot responds accurately and naturally in Vietnamese.

## Contributions
We welcome contributions to improve the RAG chatbot. Please fork the repository and create a pull request with your changes. For major changes, please open an issue first to discuss what you would like to change.

## License
This project is licensed under the MIT License. See the LICENSE file for more details.

## Contact
For any questions or suggestions, please contact me at quoctrananh2911@gmail.com.
app.py
CHANGED
@@ -1,50 +1,192 @@
import yaml
import torch
import logging
import argparse
import warnings
import pandas as pd
from tqdm.auto import tqdm
from jsonargparse import CLI
from types import SimpleNamespace
from llama_index.core.schema import TextNode
from langchain_huggingface import HuggingFaceEmbeddings
from llama_index.core import Prompt, Settings, VectorStoreIndex
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextStreamer

import gradio as gr


def load_config(config_path='config.yaml'):
    print('-> Loading config file ...')
    with open(config_path) as f:
        cfg = yaml.safe_load(f.read())

    # Expose nested dicts as attribute-style namespaces (cfg.architecture.llm_model, ...).
    for k, v in cfg.items():
        if isinstance(v, dict):
            cfg[k] = SimpleNamespace(**v)
    cfg = SimpleNamespace(**cfg)
    return cfg


def get_prompt_template():
    # Vietnamese QA prompt: "You are a helpful, intelligent assistant trained to
    # answer user questions based on the relevant context provided."
    template = (
        "Bạn là trợ lý ảo hữu ích và thông minh được huấn luyện để trả lời các câu hỏi từ người dùng dựa trên các thông tin ngữ cảnh liên quan được cung cấp\n"
        "Thông tin ngữ cảnh:\n"
        "---------------------\n"
        "{context_str}"
        "\n---------------------\n"
        "Dựa trên những thông tin ngữ cảnh bên trên, hãy trả lời câu hỏi sau: {query_str}\n"
    )
    qa_template = Prompt(template)
    return qa_template


def reset_settings(cfg):
    # Replace the default OpenAI embedder/LLM with the local HuggingFace models.
    embed_model = HuggingFaceEmbeddings(
        model_name=cfg.architecture.embedding_model
    )
    Settings.embed_model = embed_model
    Settings.llm = None


def get_retriever(cfg, prompt_template):
    chunks = pd.read_pickle('processed_chunks.pickle')['chunk'].values.tolist()
    nodes = [TextNode(text=chunk) for chunk in chunks]
    index = VectorStoreIndex(nodes=nodes)
    retriever = index.as_query_engine(
        similarity_top_k=cfg.retrieve.top_k,
        text_qa_template=prompt_template
    )
    return retriever


def load_tokenizer(cfg):
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.architecture.llm_model,
        token=cfg.architecture.hf_token
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def get_llm(cfg):
    if cfg.architecture.llm_quantized:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16
        )
    else:
        bnb_config = None

    llm = AutoModelForCausalLM.from_pretrained(
        cfg.architecture.llm_model,
        torch_dtype=torch.bfloat16,
        device_map=cfg.environment.device,
        token=cfg.architecture.hf_token,
        low_cpu_mem_usage=True,
        quantization_config=bnb_config,
    )

    return llm.eval()


def vistral_chat(cfg, retriever, tokenizer, language_model):
    def run(text, intensity):
        # `intensity` comes from the UI slider; generation length is currently
        # driven by config.yaml rather than the slider value.
        # Retrieve context, build the Vistral instruction prompt, and generate.
        prompt = retriever.query(text).response
        prompt = tokenizer.bos_token + '[INST] ' + prompt + ' [/INST]'
        streamer = TextStreamer(tokenizer, skip_prompt=True)
        input_ids = tokenizer([prompt], return_tensors='pt').to(cfg.environment.device)

        output = language_model.generate(
            **input_ids,
            streamer=streamer,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=cfg.generation.max_new_tokens,
            do_sample=cfg.generation.do_sample,
            temperature=cfg.generation.temperature
        )

        # Decode only the newly generated tokens (skip the prompt).
        response = tokenizer.decode(
            output[0][input_ids['input_ids'].shape[1]:],
            skip_special_tokens=True
        )
        return response

    demo = gr.Interface(
        fn=run,
        inputs=[gr.Textbox(label="Nhập vào nội dung input", value="Con đường xưa em đi"),  # "Enter the input text"
                gr.Slider(label="Độ dài output muốn tạo ra", value=20, minimum=10, maximum=100, step=2)],  # "Desired output length"
        outputs=gr.Textbox(label="Output"),  # <-- Number of output components: 1
    )

    demo.launch()


def main(config_path):
    # Configure logging
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)

    try:
        # Log the start of the process
        logger.info("Starting the process with config file: %s", config_path)

        # Load configuration from the file
        config = load_config(config_path)

        # Load necessary components
        prompt_template = get_prompt_template()

        # Replace the OpenAI embed model and LLM with custom ones
        reset_settings(config)

        # Get retriever
        retriever = get_retriever(config, prompt_template)

        # Load tokenizer and language model
        tokenizer = load_tokenizer(config)
        language_model = get_llm(config)

        # Start the Gradio interface
        vistral_chat(config, retriever, tokenizer, language_model)

        # Log successful completion
        logger.info("Process completed successfully.")

    except FileNotFoundError as e:
        logger.error("Configuration file not found: %s", e)
    except Exception as e:
        logger.exception("An error occurred: %s", e)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Process some configurations.')
    parser.add_argument('--config', type=str, default='config.yaml', help='Path to the configuration file')
    args = parser.parse_args()
    main(args.config)
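Note the `Settings.llm = None` trick above: with no LLM attached to LlamaIndex, the query engine's `.response` appears to carry the QA template filled with the retrieved context, which `run` then hands to Vistral. A minimal retrieval-only sketch under the same assumptions (a populated `processed_chunks.pickle`, models from `config.yaml`; the query string is hypothetical):

```python
import pandas as pd
from llama_index.core.schema import TextNode
from llama_index.core import Settings, VectorStoreIndex
from langchain_huggingface import HuggingFaceEmbeddings

# Same setup as reset_settings()/get_retriever() in app.py.
Settings.embed_model = HuggingFaceEmbeddings(model_name='intfloat/multilingual-e5-small')
Settings.llm = None  # no LLM: .response carries the filled prompt, not an answer

chunks = pd.read_pickle('processed_chunks.pickle')['chunk'].values.tolist()
index = VectorStoreIndex(nodes=[TextNode(text=c) for c in chunks])
engine = index.as_query_engine(similarity_top_k=2)

# Hypothetical query; prints the context-stuffed prompt that Vistral would receive.
print(engine.query('Chỉ tiêu tuyển sinh năm 2024?').response)
```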
app1.py
ADDED
@@ -0,0 +1,50 @@
import gradio as gr
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel


tokenizer = GPT2Tokenizer.from_pretrained('NlpHUST/gpt2-vietnamese')
model = GPT2LMHeadModel.from_pretrained('NlpHUST/gpt2-vietnamese')
# max_length = 100


def run(text, intensity):
    res = "Tham khảo NlpHUST model \n \n \n"  # "Reference: NlpHUST model"
    max_length = intensity

    input_ids = tokenizer.encode(text, return_tensors='pt')
    sample_outputs = model.generate(input_ids,
                                    pad_token_id=tokenizer.eos_token_id,
                                    do_sample=True,
                                    max_length=max_length,
                                    min_length=5,
                                    top_k=40,
                                    num_beams=5,
                                    early_stopping=True,
                                    no_repeat_ngram_size=2,
                                    num_return_sequences=2)

    for i, sample_output in enumerate(sample_outputs):
        res += "Mẫu số {}\n \n{}".format(i + 1, tokenizer.decode(sample_output.tolist()))  # "Sample no. {}"
        res += '\n \n \n \n'
    return res


# demo = gr.Interface(
#     fn=run,
#     inputs=["text", "slider"],
#     outputs=["text"],
# )

demo = gr.Interface(fn=run,
                    inputs=[gr.Textbox(label="Nhập vào nội dung input", value="Con đường xưa em đi"),
                            gr.Slider(label="Độ dài output muốn tạo ra", value=20, minimum=10, maximum=100, step=2)],
                    outputs=gr.Textbox(label="Output"),  # <-- Number of output components: 1
                    )

demo.launch()
config.yaml
ADDED
@@ -0,0 +1,21 @@
architecture:
  llm_model: Viet-Mistral/Vistral-7B-Chat
  embedding_model: intfloat/multilingual-e5-small
  hf_token: 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
  llm_quantized: False
dataset:
  chunk_marker: BK_CHUNK
  required_exts: .docx
  data_dir: ./data
  processed_data_dir: ./processed_data
  signal_type: html_table
  keep_bold: True
retrieve:
  top_k: 2
generation:
  max_new_tokens: 2048
  temperature: 0.6
  do_sample: True
  top_p: 0.9
environment:
  device: 'cuda'
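For reference, a minimal sketch of how `load_config` in app.py turns this file into attribute-style access (printed values match the config above):

```python
import yaml
from types import SimpleNamespace

with open('config.yaml') as f:
    cfg = yaml.safe_load(f.read())

# Each top-level section becomes a namespace, so nested keys read as attributes.
cfg = SimpleNamespace(**{k: SimpleNamespace(**v) if isinstance(v, dict) else v
                         for k, v in cfg.items()})

print(cfg.architecture.llm_model)   # Viet-Mistral/Vistral-7B-Chat
print(cfg.dataset.chunk_marker)     # BK_CHUNK
print(cfg.retrieve.top_k)           # 2
```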
data/bk-tuyen-sinh-2024.docx
ADDED
Binary file (48.1 kB). View file
prepare_data.py
ADDED
@@ -0,0 +1,100 @@
import os
import shutil
import yaml
import logging
import pandas as pd
from pathlib import Path
from jsonargparse import CLI
from docx.api import Document
from types import SimpleNamespace
from llama_index.core import SimpleDirectoryReader
from utils.process_tables import extract_and_replace_docx_tables

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("script.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

def load_config(file_path='config.yaml'):
    logger.info('Loading config file ...')
    try:
        with open(file_path, 'r') as file:
            cfg = yaml.safe_load(file)
        for k, v in cfg.items():
            if isinstance(v, dict):
                cfg[k] = SimpleNamespace(**v)
        logger.info('Config file loaded successfully.')
        return SimpleNamespace(**cfg)
    except Exception as e:
        logger.error(f'Error loading config file: {e}')
        raise

cfg = load_config()

def process_docx_files(data_dir=Path(cfg.dataset.data_dir),
                       processed_data_dir=Path(cfg.dataset.processed_data_dir),
                       chunk_marker=cfg.dataset.chunk_marker):
    try:
        # Start from a clean output directory.
        if os.path.exists(processed_data_dir):
            shutil.rmtree(processed_data_dir)
        os.makedirs(processed_data_dir, exist_ok=True)

        docx_files = [file for file in os.listdir(data_dir) if file.endswith('.docx')]
        logger.info(f'Found {len(docx_files)} DOCX files to process.')

        for fname in docx_files:
            document, html_chunked_tables = extract_and_replace_docx_tables(
                docx_file=data_dir / fname,
                chunk_marker=chunk_marker
            )
            document.save(processed_data_dir / f'processed_{fname}')
            logger.info(f'Processed and saved {fname}')
    except Exception as e:
        logger.error(f'Error processing DOCX files: {e}')
        raise

def load_processed_data(processed_data_dir=Path(cfg.dataset.processed_data_dir)):
    try:
        documents = SimpleDirectoryReader(
            input_dir=processed_data_dir,
            required_exts=[cfg.dataset.required_exts],
        ).load_data()
        logger.info('Processed data loaded successfully.')
        return documents
    except Exception as e:
        logger.error(f'Error loading processed data: {e}')
        raise

def get_chunks(documents, chunk_marker=cfg.dataset.chunk_marker):
    try:
        chunks = [chunk.strip() for doc in documents for chunk in doc.text.split(chunk_marker) if chunk.strip()]
        logger.info(f'Extracted {len(chunks)} chunks from documents.')
        return chunks
    except Exception as e:
        logger.error(f'Error extracting chunks: {e}')
        raise

def main():
    logger.info('Starting document processing ...')
    try:
        process_docx_files()

        documents = load_processed_data()
        chunks = get_chunks(documents)
        num_chunks = len(chunks)
        logger.info(f'Total number of chunks: {num_chunks}')

        df_chunks = pd.DataFrame({'chunk': chunks})
        df_chunks.to_pickle('processed_chunks.pickle')
        logger.info('All chunks saved to processed_chunks.pickle')
    except Exception as e:
        logger.error(f'Error in main processing: {e}')
        raise

if __name__ == '__main__':
    main()
processed_chunks.pickle
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c48413343bfeff0b6cff53fcb2328358c2b8d0e1006f89b4ebd44e679ec96bb2
size 33042
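This is a Git LFS pointer; the actual pickle holds a pandas DataFrame with a single `chunk` column, as written by prepare_data.py. A minimal sketch of consuming it (matching `get_retriever` in app.py; assumes the LFS file has been pulled):

```python
import pandas as pd

# Load the chunks produced by prepare_data.py.
chunks = pd.read_pickle('processed_chunks.pickle')['chunk'].values.tolist()
print(len(chunks), 'chunks')
print(chunks[0][:200])  # preview the first chunk
```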
processed_data/processed_bk-tuyen-sinh-2024.docx
ADDED
Binary file (26.6 kB). View file
requirements.txt
CHANGED
@@ -1,2 +1,12 @@
llama-index==0.10.44
mammoth==1.7.1
python-docx==1.1.2
docx2txt==0.8
langchain-community==0.2.4
sentence-transformers==3.0.1
accelerate==0.31.0
bitsandbytes==0.43.1
langchain-huggingface==0.0.3
jsonargparse==4.29.0
llama-index-embeddings-langchain
gradio
utils/process_tables.py
ADDED
@@ -0,0 +1,141 @@
import time
import mammoth
from bs4 import BeautifulSoup
from docx.api import Document
import re

def extract_and_replace_docx_tables(docx_file, chunk_marker):
    start_time = time.time()  # Record start time

    document = Document(docx_file)
    docx_tables = document.tables
    total_tables = len(document.tables)

    with open(docx_file, "rb") as f:
        result = mammoth.convert_to_html(f)
        html = result.value

    tables = extract_html_tables(html)

    html_chunked_tables = get_html_table_chunks(tables, chunk_marker=chunk_marker)

    # Wrap each HTML chunk string in a paragraph element so it can replace a table.
    html_tables = []
    for table in html_chunked_tables:
        temp_document = Document()
        html_table = temp_document.add_paragraph(table)._element
        html_table.alignment = 0
        html_tables.append(html_table)

    # Replace each DOCX table in place with its HTML representation.
    track = 0
    while len(document.tables) > 0:
        track += 1
        try:
            html_table = html_tables[0]
            document.element.body.replace(document.tables[0]._element, html_table)
            html_tables.remove(html_table)
            end_time = time.time()  # Record end time
            # print(f'{track} of {total_tables} | Success | Time: {end_time - start_time:.2f} seconds')

        except Exception as e:
            print(f'{track} of {total_tables} | Fail: {e}')
            if track >= 200:  # Safety limit to avoid an infinite loop.
                break

    return document, html_chunked_tables

def extract_html_tables(html):
    soup = BeautifulSoup(html, 'html.parser')
    tables = soup.find_all('table')
    return tables

def get_html_table_chunks(tables, chunk_marker):

    html_chunk_marker = '<strong>' + chunk_marker + '</strong>'

    html_table_chunks = []

    for table_soup in tables:

        # Pretty-print the table HTML so each row sits on its own line.
        html_table_string = str(table_soup)
        html_table_string = html_table_string.replace('<table>', '<table>\n')
        html_table_string = html_table_string.replace('<tr>', '\n<tr>\n')
        html_table_string = html_table_string.replace('</tr>', '\n</tr>')
        html_table_string = html_table_string.replace('</thead>', '\n</thead>\n')
        html_table_string = html_table_string.replace('</tbody>', '\n</tbody>\n')

        with open('table_html.txt', mode='w', encoding='utf8') as f:
            f.write(html_table_string)

        with open('table_html.txt', mode='r', encoding='utf8') as f:
            lines = f.readlines()

        start_table = lines[0].strip()
        end_table = lines[-1].strip()

        # Get start and end tags for tbody
        start_tbody = '<tbody>' if '<tbody>' in html_table_string else ''
        end_tbody = '</tbody>' if '</tbody>' in html_table_string else ''

        # Extract and clean headers if present
        headers = str(table_soup.find('thead')) if 'thead' in html_table_string else ''
        headers = re.sub(r'>\n\s*<', '><', headers)

        processed_lines = []
        for line in lines:
            if chunk_marker in line:
                start_index = line.find(html_chunk_marker)
                chunk_start = start_index - len('<p>')
                chunk_end = start_index + len(html_chunk_marker) + len('</p>')

                # Strip the <p> tags wrapping the marker, then move the marker to
                # the end of the row so the table can be split on it.
                chunk_html = line[chunk_start:chunk_end]
                if chunk_html.startswith('<p>') and chunk_html.endswith('</p>'):
                    line = line.replace('<p>', '')
                    line = line.replace('</p>', '')

                line = line.replace(html_chunk_marker, '')
                line = line.replace(' </td>', '</td>').strip()
                line += chunk_marker

            processed_lines.append(line)

        processed_lines = [line.strip() for line in processed_lines]
        html_table = ''.join(processed_lines)
        html_chunks = html_table.split(chunk_marker)

        # Re-attach the table scaffolding (opening tag, header row, tbody) to
        # every chunk so each one is a complete, self-contained HTML table.
        processed_html_chunks = []
        for index, chunk in enumerate(html_chunks):
            if index == 0:
                chunk += (end_tbody + end_table)
                first_chunk = chunk.replace(end_table, '')
                start = first_chunk.find('<tr>')
                end = first_chunk.find('</tr>') + len('</tr>')
                headers = first_chunk[start:end]

            elif chunk == html_chunks[-1]:
                chunk = start_table + headers + start_tbody + chunk

            else:
                chunk = start_table + headers + start_tbody + chunk + end_tbody + end_table

            processed_html_chunks.append(chunk)

        chunks_to_html = ''
        for html_chunk in processed_html_chunks:
            chunks_to_html += wrap_signal(html_chunk, signal_type='html')
            if html_chunk != processed_html_chunks[-1]:
                chunks_to_html += f'\n\n{chunk_marker}\n\n'

        html_table_chunks.append(chunks_to_html)

    return html_table_chunks

def wrap_signal(data, signal_type):
    # Fence the chunk as an ```html block so downstream models see it as a table.
    data = f"```{signal_type}\n{data}\n```"
    return data
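For reference, a minimal usage sketch mirroring `process_docx_files` in prepare_data.py (the input path and marker are taken from config.yaml and the repo's data folder):

```python
from utils.process_tables import extract_and_replace_docx_tables

# Convert every table in the source DOCX into chunk-marked HTML and save the result.
document, table_chunks = extract_and_replace_docx_tables(
    docx_file='data/bk-tuyen-sinh-2024.docx',
    chunk_marker='BK_CHUNK',
)
document.save('processed_data/processed_bk-tuyen-sinh-2024.docx')
print(f'{len(table_chunks)} tables converted to HTML chunks')
```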