|
from langchain_community.document_loaders import TextLoader |
|
from langchain_community.docstore.document import Document |
|
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter |
|
from langchain_community.vectorstores import Chroma |
|
from langchain_community.embeddings import HuggingFaceEmbeddings |
|
from langchain_community.retrievers import BM25Retriever |
|
from langchain_community.llms import OpenAI |
|
from langchain_openai import ChatOpenAI |
|
from langchain.chains import RetrievalQA |
|
from langchain.schema import AIMessage, HumanMessage |
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
|
import os |
|
|
|
def split_with_source(text, source):
    """Split *text* into chunk documents and tag each one with *source*.

    The text is cut on newline separators by a ``CharacterTextSplitter``
    configured with ``chunk_size=512``, ``chunk_overlap=128`` and
    ``add_start_index=True``.

    Args:
        text: The raw text to split.
        source: Value stored under the ``"source"`` key of every chunk's
            metadata (callers in this file pass the originating file name).

    Returns:
        The list of documents produced by the splitter.
    """
    chunker = CharacterTextSplitter(
        separator="\n",
        chunk_size=512,
        chunk_overlap=128,
        add_start_index=True,
    )
    chunks = chunker.create_documents([text])

    # Stamp the origin on every chunk so it can be reported alongside answers.
    for chunk in chunks:
        chunk.metadata.update(source=source)

    return chunks
|
|
|
|
|
def count_files_in_folder(folder_path):
    """Return the number of entries directly inside *folder_path*.

    Every name returned by ``os.listdir`` is counted, so sub-directories are
    included in the total.

    Args:
        folder_path: Path of the directory to inspect.

    Returns:
        The entry count, or ``None`` (after printing a message) when
        *folder_path* is not an existing directory.
    """
    if os.path.isdir(folder_path):
        return len(os.listdir(folder_path))

    # Not a directory: report it and signal the failure with None.
    print("Đường dẫn không hợp lệ.")
    return None
|
|
|
def get_document_from_raw_text():
    """Load every file in ``<cwd>/raw_data`` and split it into chunk documents.

    Each regular file is read as UTF-8, doubled newlines are collapsed to a
    single newline, and the text is chunked with ``split_with_source`` using
    the file name as the chunk source.

    Returns:
        A list of documents. The first element is always an empty placeholder
        document whose metadata is ``{'source': 0}``, followed by the chunks
        of each file in ``os.listdir`` order.

    Raises:
        FileNotFoundError: If the ``raw_data`` directory does not exist.
        UnicodeDecodeError: If a file is not valid UTF-8.
    """
    raw_dir = os.path.join(os.getcwd(), "raw_data")

    # NOTE(review): the empty placeholder is kept as-is; presumably it keeps
    # the list non-empty for the index builders -- confirm before removing.
    documents = [Document(page_content="", metadata={'source': 0})]

    for file_name in os.listdir(raw_dir):
        file_path = os.path.join(raw_dir, file_name)

        # os.listdir also yields sub-directories; open() would raise on them.
        if not os.path.isfile(file_path):
            continue

        with open(file_path, 'r', encoding="utf-8") as file:
            content = file.read().replace('\n\n', "\n")

        # extend() appends in place instead of copying the whole list per file.
        documents.extend(split_with_source(content, file_name))

    return documents
|
|
|
def load_the_embedding_retrieve(is_ready=False, k=3, model='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'):
    """Build a Chroma-backed retriever that uses HuggingFace embeddings.

    Args:
        is_ready: When true, open the Chroma store persisted under
            ``<cwd>/Data``. Otherwise build a new store from the documents
            returned by ``get_document_from_raw_text`` (no persist directory
            is passed in that case).
        k: Passed to the retriever as ``search_kwargs={"k": k}``.
        model: Model name handed to ``HuggingFaceEmbeddings``.

    Returns:
        The retriever obtained from the Chroma store.
    """
    embedder = HuggingFaceEmbeddings(model_name=model)

    if not is_ready:
        docs = get_document_from_raw_text()
        print(type(docs))
        store = Chroma.from_documents(docs, embedder)
    else:
        store = Chroma(
            persist_directory=os.path.join(os.getcwd(), "Data"),
            embedding_function=embedder,
        )

    return store.as_retriever(search_kwargs={"k": k})
|
|
|
def load_the_bm25_retrieve(k=3):
    """Build a BM25 retriever over the documents loaded from raw text.

    Args:
        k: Value assigned to the retriever's ``k`` attribute.

    Returns:
        The configured ``BM25Retriever``.
    """
    retriever = BM25Retriever.from_documents(get_document_from_raw_text())
    retriever.k = k
    return retriever
|
|
|
def get_qachain(llm_name="gpt-3.5-turbo-0125", chain_type="stuff", retriever=None, return_source_documents=True):
    """Create a ``RetrievalQA`` chain driven by a ``ChatOpenAI`` model.

    Args:
        llm_name: Model name passed to ``ChatOpenAI``.
        chain_type: Chain type forwarded to ``RetrievalQA.from_chain_type``.
        retriever: Retriever the chain uses to fetch context.
        return_source_documents: Forwarded unchanged to the chain factory.

    Returns:
        The chain built by ``RetrievalQA.from_chain_type``.
    """
    # Temperature is pinned to 0 for this chain.
    chat_model = ChatOpenAI(temperature=0, model_name=llm_name)

    qa_chain = RetrievalQA.from_chain_type(
        llm=chat_model,
        chain_type=chain_type,
        retriever=retriever,
        return_source_documents=return_source_documents,
    )
    return qa_chain
|
|
|
|
|
def summarize_messages(demo_ephemeral_chat_history, llm):
    """Replace the stored chat history with a single LLM-written summary.

    The summarization instruction is read from the ``SUMARY_MESSAGE_PROMPT``
    environment variable and sent as the final user message after the
    existing history.

    Args:
        demo_ephemeral_chat_history: History object exposing ``messages``,
            ``clear()`` and ``add_message()``; it is mutated in place.
        llm: Model piped after the prompt to produce the summary.

    Returns:
        ``False`` when the history holds no messages; otherwise the same
        history object, now containing only the summary message.

    Raises:
        KeyError: If ``SUMARY_MESSAGE_PROMPT`` is not set (only reached when
            the history is non-empty).
    """
    history = demo_ephemeral_chat_history
    past_messages = history.messages

    # Nothing to condense: leave the history untouched.
    if len(past_messages) == 0:
        return False

    prompt = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", os.environ['SUMARY_MESSAGE_PROMPT']),
    ])
    chain = prompt | llm
    summary = chain.invoke({"chat_history": past_messages})

    # Swap the full transcript for the single summary message.
    history.clear()
    history.add_message(summary)

    return history
|
|
|
def get_question_from_summarize(summary, question, llm):
    """Ask the LLM to produce a question from a summary plus a new question.

    The system instruction is read from the ``NEW_QUESTION_PROMPT``
    environment variable; the summary and question are filled into the human
    message template.

    Args:
        summary: Text substituted for ``{summary}`` in the template.
        question: Text substituted for ``{question}`` in the template.
        llm: Model piped after the prompt.

    Returns:
        The ``content`` attribute of the model's response.

    Raises:
        KeyError: If ``NEW_QUESTION_PROMPT`` is not set.
    """
    # NOTE(review): the source this was restored from had lost its
    # indentation, so the leading whitespace inside this literal is a best
    # guess -- confirm against the deployed prompt.
    human_template = '''
    Sumary: {summary}
    Question: {question}
    Output:
    '''

    rewrite_prompt = ChatPromptTemplate.from_messages([
        ("system", os.environ['NEW_QUESTION_PROMPT']),
        ("human", human_template),
    ])

    rewrite_chain = rewrite_prompt | llm
    response = rewrite_chain.invoke({'summary': summary, 'question': question})
    return response.content
|
|
|
def get_final_answer(question, context, prompt, llm):
    """Answer *question* from *context* using *prompt* as system instruction.

    Args:
        question: Text substituted for ``{question}`` in the human message.
        context: Text substituted for ``{context}`` in the human message.
        prompt: System message template for the chat prompt.
        llm: Model piped after the prompt.

    Returns:
        The ``content`` attribute of the model's response.
    """
    # NOTE(review): the source this was restored from had lost its
    # indentation, so the leading whitespace inside this literal is a best
    # guess -- confirm against the deployed prompt.
    human_template = '''
    Context: {context}
    Question: {question}
    Output:
    '''

    messages = [
        ("system", prompt),
        ("human", human_template),
    ]
    chain = ChatPromptTemplate.from_messages(messages) | llm

    response = chain.invoke({'question': question, 'context': context})
    return response.content
|
|
|
def process_llm_response(llm_response):
    """Print a QA result followed by the source of each supporting document.

    Args:
        llm_response: Mapping with a ``'result'`` entry and a
            ``"source_documents"`` iterable whose items expose a ``metadata``
            mapping containing ``'source'``.

    Raises:
        KeyError: If one of the expected keys is missing.
    """
    answer = llm_response['result']
    print(answer)

    print('\n\nSources:')
    for document in llm_response["source_documents"]:
        origin = document.metadata['source']
        print(origin)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|