Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| from langchain.embeddings import HuggingFaceEmbeddings | |
| from langchain.vectorstores import Chroma | |
| from langchain.document_loaders import TextLoader, DirectoryLoader | |
| import os | |
| import re | |
| from sentence_transformers.cross_encoder import CrossEncoder | |
| import numpy as np | |
| from langchain.schema.retriever import BaseRetriever, Document | |
| from typing import List | |
| from langchain.callbacks.manager import CallbackManagerForRetrieverRun | |
| from langchain.vectorstores import VectorStore | |
| from llm import URALLM | |
| from langchain.prompts import PromptTemplate | |
| # Get role for passage document | |
def get_role(document):
    """Return the learner roles mentioned in a document's source path.

    The text stored under ``document.metadata['source']`` is searched,
    case-insensitively, for a fixed list of Vietnamese role keywords
    (student, undergraduate, trainee, master, PhD candidate, doctorate).

    Args:
        document: Object exposing a ``metadata`` mapping. A missing or empty
            ``'source'`` entry simply produces no roles.

    Returns:
        The matched keywords joined with ``", "`` in keyword-list order, or
        an empty string when nothing matches.
    """
    import unicodedata  # local import keeps this block self-contained

    def _canonical(text):
        # Vietnamese diacritics may be stored composed (NFC) or decomposed
        # (NFD); without a common form a visually identical keyword does
        # not compare equal and the role would be silently missed.
        return unicodedata.normalize("NFC", str(text)).lower()

    keywords = (
        "sinh viên",
        "đại học",
        "học viên",
        "thạc sĩ",
        "nghiên cứu sinh",
        "tiến sĩ",
    )
    source = _canonical(document.metadata.get("source") or "")
    return ", ".join(
        keyword for keyword in keywords if _canonical(keyword) in source
    )
def processing_data(data_path):
    """Load the text documents found in the sub-folders of ``data_path``.

    Each immediate sub-directory is read with a ``DirectoryLoader`` that
    uses ``TextLoader`` for the individual files. Every loaded document then
    receives a ``'role'`` metadata entry computed by ``get_role``.

    Args:
        data_path: Directory whose sub-folders hold the raw text files.

    Returns:
        A flat list of the loaded documents, each with ``metadata['role']``
        set. Empty when ``data_path`` contains no sub-directories.
    """
    documents = []
    # sorted(): os.listdir returns entries in arbitrary order; sorting keeps
    # the resulting document order reproducible between runs.
    for entry in sorted(os.listdir(data_path)):
        folder_path = os.path.join(data_path, entry)
        if not os.path.isdir(folder_path):
            # Stray plain files next to the folders are not valid input for
            # DirectoryLoader, which expects a directory.
            continue
        loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
        documents.extend(loader.load())

    # Final data preparation for the vector database.
    for document in documents:
        document.metadata['role'] = get_role(document)
    return documents
# Locations of the raw text corpus and of the persisted Chroma index.
data_path = 'raw_data'
persist_directory = 'vector_db'

# Embedding model, run on CPU.
# (An earlier alternative was "sentence-transformers/all-MiniLM-L6-v2".)
embedding = HuggingFaceEmbeddings(
    model_kwargs={"device": "cpu"},
    model_name="VoVanPhuc/sup-SimCSE-VietNamese-phobert-base",
)

# Vector database: built from the processed corpus when this module is
# imported, and stored under ``persist_directory``.
vectordb = Chroma.from_documents(
    persist_directory=persist_directory,
    embedding=embedding,
    documents=processing_data(data_path),
)
# Process-wide cache so the reranking model is loaded once, not per query.
_CROSS_ENCODER_CACHE = {}


def _get_cross_encoder(model_name):
    """Return the cached CrossEncoder for ``model_name``, loading it on first use."""
    if model_name not in _CROSS_ENCODER_CACHE:
        _CROSS_ENCODER_CACHE[model_name] = CrossEncoder(model_name)
    return _CROSS_ENCODER_CACHE[model_name]


class CustomRetriever(BaseRetriever):
    """Retriever that reranks first-stage results with a cross-encoder.

    Candidates are fetched from ``retriever``; each (query, passage) pair is
    scored by the ``cross-encoder/ms-marco-MiniLM-L-6-v2`` model and the
    ``top_k`` highest-scoring documents are returned.
    """

    # Vector store backing the first-stage retriever.
    vectorstores: Chroma
    # First-stage retriever. Annotated with the base type rather than with a
    # retriever instance, so the class body no longer depends on a module
    # global being evaluated at class-definition time.
    retriever: BaseRetriever
    # Number of reranked documents to return (the original fixed value was 2).
    top_k: int = 2

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Fetch candidates for ``query`` and return the best ``top_k`` after reranking.

        Args:
            query: The user question.
            run_manager: Callback manager; a child manager is forwarded to
                the wrapped retriever.

        Returns:
            Up to ``top_k`` documents ordered by decreasing cross-encoder
            score; an empty list when the first stage finds nothing.
        """
        candidates = self.retriever.get_relevant_documents(
            query, callbacks=run_manager.get_child()
        )
        if not candidates:
            # Nothing to rerank; also avoids scoring an empty batch.
            return []

        reranker = _get_cross_encoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
        # One [query, passage] pair per candidate document.
        pairs = [[query, doc.page_content] for doc in candidates]
        scores = reranker.predict(pairs)

        # argsort is ascending, so reverse it to get best-first indices.
        best_first = np.argsort(scores)[::-1]
        return [candidates[idx] for idx in best_first[: self.top_k]]
# Language model wrapper imported from the local ``llm`` module.
llm = URALLM()

# First stage fetches 50 candidates from the vector store; CustomRetriever
# then reranks them.
custom_retriever = CustomRetriever(
    retriever=vectordb.as_retriever(search_kwargs={"k": 50}),
    vectorstores=vectordb,
)
# Prompt for the QA chain, using [INST] / <<SYS>> markers.
# The Vietnamese system text tells the model that it is a chatbot answering
# questions about the academic regulations of Bach Khoa University (VNU-HCM),
# that it must answer from the supplied text, and that it must reply
# "Sorry, I have no information for this question!" when no answer is found.
# ``{context}`` receives the retrieved passages, ``{question}`` the user query.
template = """[INST] <<SYS>>
Bạn là một chatbot hỗ trợ trả lời các quy định học vụ của trường Đại học Bách Khoa - ĐHQG TP.HCM.
Trả lời câu hỏi dựa trên văn bản được cung cấp.
Nếu không tìm thấy câu trả lời, vui lòng trả lời: "Xin lỗi, tôi không có thông tin cho câu hỏi này!"
<</SYS>>
Văn bản: {context}
Câu hỏi: {question}
Câu trả lời:
[/INST]"""

QA_CHAIN_PROMPT = PromptTemplate(
    template=template,
    input_variables=["context", "question"],
)
# Assemble the retrieval QA chain: reranking retriever + custom prompt.
from langchain.chains import RetrievalQA

qa_chain = RetrievalQA.from_chain_type(
    llm,
    retriever=custom_retriever,  # instead of the plain vectordb.as_retriever()
    chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
    return_source_documents=True,
    verbose=False,
)
def remove_special_characters(text):
    """Delete a fixed set of unwanted character sequences from ``text``.

    The sequences ``'].'``, ``'/.'``, ``'/.-'`` and ``'-'`` are removed one
    after another, each pass working on the output of the previous one.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string.
    """
    # Order is significant: later patterns see the result of earlier removals.
    for unwanted in ('].', '/.', '/.-', '-'):
        text = text.replace(unwanted, '')
    return text
def rag(question: str) -> str:
    """Answer ``question`` with the QA chain and return the cleaned text.

    Args:
        question: The user's question, passed to the chain as ``"query"``.

    Returns:
        The chain's ``"result"`` entry after ``remove_special_characters``
        has been applied to it.
    """
    chain_output = qa_chain({"query": question})
    answer = chain_output["result"]
    return remove_special_characters(answer)