ShynBui commited on
Commit
d15d46e
1 Parent(s): a2604b2

first commit

Browse files
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import gradio as gr
3
+ import os
4
+ from langchain.retrievers import EnsembleRetriever
5
+ from utils import *
6
+ import requests
7
+ from pyvi import ViTokenizer, ViPosTagger
8
+ import time
9
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
10
+ import torch
11
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
12
+ from langchain.chains import create_history_aware_retriever, create_retrieval_chain
13
+ from langchain_community.chat_message_histories import ChatMessageHistory
14
+
15
+ retriever = load_the_embedding_retrieve(is_ready=False, k=10)
16
+ bm25_retriever = load_the_bm25_retrieve(k=1)
17
+
18
+ ensemble_retriever = EnsembleRetriever(
19
+ retrievers=[bm25_retriever, retriever], weights=[0.1, 0.9]
20
+ )
21
+
22
+ tokenizer = AutoTokenizer.from_pretrained("ShynBui/vie_qa", token=os.environ.get("HF_TOKEN"))
23
+ model = AutoModelForQuestionAnswering.from_pretrained("ShynBui/vie_qa", token=os.environ.get("HF_TOKEN"))
24
+
25
+
26
+ llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=1, openai_api_key=os.environ["OPENAI_API_KEY"])
27
+
28
+
29
+
30
+
31
+ def greet3(quote, history):
32
+ # print(history)
33
+ demo_ephemeral_chat_history = ChatMessageHistory()
34
+ if history == '':
35
+ history = [("Bạn có thể giải thích về quy chế và quyền của sinh viên tại trường này không?",
36
+ '''Quy chế và quyền của sinh viên tại trường Đại học Mở TP.HCM được quy định rõ trong các điều khoản sau:
37
+ 1. Hiệu trưởng Trường có quyền ra quyết định thành lập và quy định cụ thể về chức năng, nhiệm vụ, tổ chức và hoạt động của Hội đồng khen thưởng và kỷ luật sinh viên.
38
+ 2. Sinh viên có quyền khiếu nại về khen thưởng, kỷ luật. Khi có vi phạm kỷ luật, sinh viên có quyền được phân tích và đề nghị hình thức kỷ luật thông qua việc họp với các tổ chức sinh viên và gửi biên bản họp đến phòng Công tác sinh viên để trình Hội đồng.
39
+ 3. Sinh viên có quyền đề đạt nguyện vọng và khiếu nại lên Hiệu trưởng Trường để giải quyết các vấn đề có liên quan đến quyền, lợi ích chính đáng của sinh viên.
40
+ 4. Sinh viên được hỗ trợ giới thiệu nhà trọ theo quy định của trường.
41
+
42
+ Các chủ đề liên quan mà bạn có thể muốn tìm hiểu thêm:
43
+ - Quy chế và quyền của sinh viên tại các trường đại học khác.
44
+ - Hệ thống hỗ trợ sinh viên tại trường Đại học Mở TP.HCM.
45
+ - Quy trình khiếu nại và giải quyết tranh chấp sinh viên tại trường Đại học Mở TP.HCM.
46
+ '''),
47
+ ("Chào.",
48
+ "Chào. Chúng ta vừa bắt đầu câu chuyện thôi.")]
49
+ for user, assistant in history[-1:]:
50
+ demo_ephemeral_chat_history.add_user_message(user)
51
+ demo_ephemeral_chat_history.add_ai_message(assistant)
52
+ else:
53
+ for user, assistant in eval(history)[-1:]:
54
+ demo_ephemeral_chat_history.add_user_message(user)
55
+ demo_ephemeral_chat_history.add_ai_message(assistant)
56
+
57
+ # Summary the message
58
+
59
+ chat_history = summarize_messages(demo_ephemeral_chat_history=demo_ephemeral_chat_history, llm=llm).messages
60
+ # print("Chat history:", chat_history)
61
+
62
+ # Get the new question
63
+ new_question = get_question_from_summarize(chat_history[0].content, quote, llm)
64
+
65
+ # Retrieve
66
+ documents_query = ensemble_retriever.invoke(new_question)
67
+
68
+ # print(documents_query)
69
+
70
+ context = ''
71
+ for i in documents_query:
72
+ context += i.page_content + '\n'
73
+ # print(context)
74
+ # Get answer
75
+
76
+ answer = get_final_answer(question=new_question, context=context,
77
+ prompt=os.environ['PROMPT'], llm=llm)
78
+
79
+ return new_question, answer
80
+
81
+
82
+ if __name__ == "__main__":
83
+ quote = "Địa chỉ nhà trường?"
84
+
85
+ iface = gr.Interface(fn=greet3, inputs=["text", "text"], outputs=["text", "text"])
86
+ iface.launch(share=True)
87
+
88
+
89
+ #Những cái đã làm tốt hơn những gì - Đóng góp gì
90
+ # 1. Dataset - Xu lý
91
+ # 2. Tăng ngữ cảnh
92
+ # 3. Tăng khả năng truy vết
93
+ # 4.
raw_data/data_dang_bang.txt ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohttp==3.9.3
3
+ aiosignal==1.3.1
4
+ altair==5.2.0
5
+ annotated-types==0.6.0
6
+ anyio==4.3.0
7
+ asgiref==3.7.2
8
+ async-timeout==4.0.3
9
+ attrs==23.2.0
10
+ backoff==2.2.1
11
+ bcrypt==4.1.2
12
+ build==1.0.3
13
+ cachetools==5.3.3
14
+ certifi==2024.2.2
15
+ charset-normalizer==3.3.2
16
+ chroma-hnswlib==0.7.3
17
+ chromadb==0.4.24
18
+ click==8.1.7
19
+ colorama==0.4.6
20
+ coloredlogs==15.0.1
21
+ contourpy==1.2.0
22
+ cycler==0.12.1
23
+ dataclasses-json==0.6.4
24
+ Deprecated==1.2.14
25
+ exceptiongroup==1.2.0
26
+ fastapi==0.110.0
27
+ ffmpy==0.3.2
28
+ filelock==3.13.1
29
+ flatbuffers==23.5.26
30
+ fonttools==4.49.0
31
+ frozenlist==1.4.1
32
+ fsspec==2024.2.0
33
+ google-auth==2.28.1
34
+ googleapis-common-protos==1.62.0
35
+ gradio==4.19.2
36
+ gradio_client==0.10.1
37
+ greenlet==3.0.3
38
+ grpcio==1.62.0
39
+ h11==0.14.0
40
+ httpcore==1.0.4
41
+ httptools==0.6.1
42
+ httpx==0.27.0
43
+ huggingface-hub==0.21.1
44
+ humanfriendly==10.0
45
+ idna==3.6
46
+ importlib-metadata==6.11.0
47
+ importlib_resources==6.1.2
48
+ Jinja2==3.1.3
49
+ joblib==1.3.2
50
+ jsonpatch==1.33
51
+ jsonpointer==2.4
52
+ jsonschema==4.21.1
53
+ jsonschema-specifications==2023.12.1
54
+ kiwisolver==1.4.5
55
+ kubernetes==29.0.0
56
+ langchain==0.1.9
57
+ langchain-community==0.0.24
58
+ langchain-core==0.1.27
59
+ langsmith==0.1.10
60
+ markdown-it-py==3.0.0
61
+ MarkupSafe==2.1.5
62
+ marshmallow==3.21.0
63
+ matplotlib==3.8.3
64
+ mdurl==0.1.2
65
+ mmh3==4.1.0
66
+ monotonic==1.6
67
+ mpmath==1.3.0
68
+ multidict==6.0.5
69
+ mypy-extensions==1.0.0
70
+ networkx==3.2.1
71
+ numpy==1.26.4
72
+ oauthlib==3.2.2
73
+ onnxruntime==1.17.1
74
+ opentelemetry-api==1.23.0
75
+ opentelemetry-exporter-otlp-proto-common==1.23.0
76
+ opentelemetry-exporter-otlp-proto-grpc==1.23.0
77
+ opentelemetry-instrumentation==0.44b0
78
+ opentelemetry-instrumentation-asgi==0.44b0
79
+ opentelemetry-instrumentation-fastapi==0.44b0
80
+ opentelemetry-proto==1.23.0
81
+ opentelemetry-sdk==1.23.0
82
+ opentelemetry-semantic-conventions==0.44b0
83
+ opentelemetry-util-http==0.44b0
84
+ orjson==3.9.15
85
+ overrides==7.7.0
86
+ packaging==23.2
87
+ pandas==2.2.1
88
+ pillow==10.2.0
89
+ posthog==3.4.2
90
+ protobuf==4.25.3
91
+ pulsar-client==3.4.0
92
+ pyasn1==0.5.1
93
+ pyasn1-modules==0.3.0
94
+ pydantic==2.6.3
95
+ pydantic_core==2.16.3
96
+ pydub==0.25.1
97
+ Pygments==2.17.2
98
+ pyparsing==3.1.1
99
+ PyPika==0.48.9
100
+ pyproject_hooks==1.0.0
101
+ pyreadline3==3.4.1
102
+ python-crfsuite==0.9.10
103
+ python-dateutil==2.8.2
104
+ python-dotenv==1.0.1
105
+ python-multipart==0.0.9
106
+ pytz==2024.1
107
+ pyvi==0.1.1
108
+ PyYAML==6.0.1
109
+ rank-bm25==0.2.2
110
+ referencing==0.33.0
111
+ regex==2023.12.25
112
+ requests==2.31.0
113
+ requests-oauthlib==1.3.1
114
+ rich==13.7.0
115
+ rpds-py==0.18.0
116
+ rsa==4.9
117
+ ruff==0.2.2
118
+ safetensors==0.4.2
119
+ scikit-learn==1.4.1.post1
120
+ scipy==1.12.0
121
+ semantic-version==2.10.0
122
+ sentence-transformers==2.4.0
123
+ shellingham==1.5.4
124
+ six==1.16.0
125
+ sklearn-crfsuite==0.3.6
126
+ sniffio==1.3.1
127
+ SQLAlchemy==2.0.27
128
+ starlette==0.36.3
129
+ sympy==1.12
130
+ tabulate==0.9.0
131
+ tenacity==8.2.3
132
+ threadpoolctl==3.3.0
133
+ tokenizers==0.15.2
134
+ tomli==2.0.1
135
+ tomlkit==0.12.0
136
+ toolz==0.12.1
137
+ torch==2.2.1
138
+ tqdm==4.66.2
139
+ transformers==4.38.1
140
+ typer==0.9.0
141
+ typing-inspect==0.9.0
142
+ typing_extensions==4.10.0
143
+ tzdata==2024.1
144
+ urllib3==2.2.1
145
+ uvicorn==0.27.1
146
+ watchfiles==0.21.0
147
+ websocket-client==1.7.0
148
+ websockets==11.0.3
149
+ wrapt==1.16.0
150
+ yarl==1.9.4
151
+ zipp==3.17.0
152
+ ## The following requirements were added by pip freeze:
153
+ distro==1.9.0
154
+ langchain-openai==0.0.8
155
+ openai==1.13.3
156
+ tiktoken==0.6.0
table_data/data_dang_bang.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ data
2
+ NOne
utils.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from langchain_community.document_loaders import TextLoader
3
+ from langchain_community.docstore.document import Document
4
+ from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
5
+ from langchain_community.vectorstores import Chroma
6
+ from langchain_community.embeddings import HuggingFaceEmbeddings
7
+ from langchain_community.retrievers import BM25Retriever
8
+ from langchain_community.llms import OpenAI
9
+ from langchain_openai import ChatOpenAI
10
+ from langchain.chains import RetrievalQA
11
+ from langchain.schema import AIMessage, HumanMessage
12
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
13
+ import os
14
+
15
+
16
+ def split_with_source(text, source):
17
+ splitter = CharacterTextSplitter(
18
+ separator = "\n",
19
+ chunk_size = 400,
20
+ chunk_overlap = 0,
21
+ length_function = len,
22
+ add_start_index = True,
23
+ )
24
+ documents = splitter.create_documents([text])
25
+ # print(documents)
26
+ for doc in documents:
27
+ doc.metadata["source"] = source
28
+ # print(doc.metadata)
29
+
30
+ return documents
31
+
32
+ def get_document_from_raw_text_each_line():
33
+ documents = [Document(page_content="", metadata={'source': 0})]
34
+ files = os.listdir(os.path.join(os.getcwd(), "raw_data"))
35
+ # print(files)
36
+ for i in files:
37
+ file_path = i
38
+ with open(os.path.join(os.path.join(os.getcwd(), "raw_data"),file_path), 'r', encoding="utf-8") as file:
39
+ # Xử lý bằng text_spliter
40
+ # Tiền xử lý văn bản
41
+ content = file.readlines()
42
+ text = []
43
+ #Split
44
+ for line in content:
45
+ line = line.strip()
46
+ documents.append(Document(page_content=line, metadata={"source": i}))
47
+
48
+ return documents
49
+
50
+ def count_files_in_folder(folder_path):
51
+ # Kiểm tra xem đường dẫn thư mục có tồn tại không
52
+ if not os.path.isdir(folder_path):
53
+ print("Đường dẫn không hợp lệ.")
54
+ return None
55
+
56
+ # Sử dụng os.listdir() để lấy danh sách các tập tin và thư mục trong thư mục
57
+ files = os.listdir(folder_path)
58
+
59
+ # Đếm số lượng tập tin trong danh sách
60
+ file_count = len(files)
61
+
62
+ return file_count
63
+
64
+ def get_document_from_raw_text():
65
+ documents = [Document(page_content="", metadata={'source': 0})]
66
+ files = os.listdir(os.path.join(os.getcwd(), "raw_data"))
67
+ # print(files)
68
+ for i in files:
69
+ file_path = i
70
+ with open(os.path.join(os.path.join(os.getcwd(), "raw_data"),file_path), 'r', encoding="utf-8") as file:
71
+ # Xử lý bằng text_spliter
72
+ # Tiền xử lý văn bản
73
+ content = file.read().replace('\n\n', "\n")
74
+ # content = ''.join(content.split('.'))
75
+ new_doc = content
76
+ texts = split_with_source(new_doc, i)
77
+ # texts = get_document_from_raw_text_each_line()
78
+ documents = documents + texts
79
+
80
+ ##Xử lý mỗi khi xuống dòng
81
+ # for line in file:
82
+ # # Loại bỏ khoảng trắng thừa và ký tự xuống dòng ở đầu và cuối mỗi dòng
83
+ # line = line.strip()
84
+ # documents.append(Document(page_content=line, metadata={"source": i}))
85
+ # print(documents)
86
+ return documents
87
+
88
+ def get_document_from_table():
89
+ documents = [Document(page_content="", metadata={'source': 0})]
90
+ files = os.listdir(os.path.join(os.getcwd(), "table_data"))
91
+ # print(files)
92
+ for i in files:
93
+ file_path = i
94
+ data = pd.read_csv(os.path.join(os.path.join(os.getcwd(), "table_data"),file_path))
95
+ for j, row in data.iterrows():
96
+ documents.append(Document(page_content=row['data'], metadata={"source": file_path}))
97
+ return documents
98
+
99
+ def load_the_embedding_retrieve(is_ready = False, k = 3, model= 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'):
100
+ embeddings = HuggingFaceEmbeddings(model_name=model)
101
+ if is_ready:
102
+ retriever = Chroma(persist_directory=os.path.join(os.getcwd(), "Data"), embedding_function=embeddings).as_retriever(
103
+ search_kwargs={"k": k}
104
+ )
105
+ else:
106
+ documents = get_document_from_raw_text() + get_document_from_table()
107
+ # print(type(documents))
108
+ retriever = Chroma.from_documents(documents, embeddings).as_retriever(
109
+ search_kwargs={"k": k}
110
+ )
111
+
112
+
113
+ return retriever
114
+
115
+ def load_the_bm25_retrieve(k = 3):
116
+ documents = get_document_from_raw_text() + get_document_from_table()
117
+ bm25_retriever = BM25Retriever.from_documents(documents)
118
+ bm25_retriever.k = k
119
+
120
+ return bm25_retriever
121
+
122
+ def get_qachain(llm_name = "gpt-3.5-turbo-0125", chain_type = "stuff", retriever = None, return_source_documents = True):
123
+ llm = ChatOpenAI(temperature=0,
124
+ model_name=llm_name)
125
+ return RetrievalQA.from_chain_type(llm=llm,
126
+ chain_type=chain_type,
127
+ retriever=retriever,
128
+ return_source_documents=return_source_documents)
129
+
130
+
131
+
132
+ def summarize_messages(demo_ephemeral_chat_history, llm):
133
+ stored_messages = demo_ephemeral_chat_history.messages
134
+ human_chat = stored_messages[0].content
135
+ ai_chat = stored_messages[1].content
136
+ if len(stored_messages) == 0:
137
+ return False
138
+ summarization_prompt = ChatPromptTemplate.from_messages(
139
+ [
140
+ (
141
+ "system", os.environ['SUMARY_MESSAGE_PROMPT'],
142
+ ),
143
+ (
144
+ "human",
145
+ '''
146
+ History:
147
+ Human: {human}
148
+ AI: {AI}
149
+ Output:'''
150
+ )
151
+ ,
152
+ ]
153
+ )
154
+ summarization_chain = summarization_prompt | llm
155
+
156
+ summary_message = summarization_chain.invoke({"AI": ai_chat, "human": human_chat})
157
+
158
+ demo_ephemeral_chat_history.clear()
159
+
160
+ demo_ephemeral_chat_history.add_message(summary_message)
161
+
162
+ return demo_ephemeral_chat_history
163
+
164
+ def get_question_from_summarize(summary, question, llm):
165
+ new_qa_prompt = ChatPromptTemplate.from_messages([
166
+ ("system", os.environ['NEW_QUESTION_PROMPT']),
167
+ ("human",
168
+ '''
169
+ Summary: {summary}
170
+ Question: {question}
171
+ Output:'''
172
+ )
173
+ ]
174
+ )
175
+
176
+ new_qa_chain = new_qa_prompt | llm
177
+ return new_qa_chain.invoke({'summary': summary, 'question': question}).content
178
+
179
+ def get_final_answer(question, context, prompt, llm):
180
+ qa_prompt = ChatPromptTemplate.from_messages(
181
+ [
182
+ ("system", prompt),
183
+ ("human", '''
184
+ Context: {context}
185
+ Question: {question}
186
+ Output:'''),
187
+ ]
188
+ )
189
+
190
+ answer_chain = qa_prompt | llm
191
+
192
+ answer = answer_chain.invoke({'question': question, 'context': context})
193
+
194
+ return answer.content
195
+
196
+
197
+
198
+ def process_llm_response(llm_response):
199
+ print(llm_response['result'])
200
+ print('\n\nSources:')
201
+ for source in llm_response["source_documents"]:
202
+ print(source.metadata['source'])
203
+
204
+
205
+
206
+
207
+
208
+
209
+