# Hugging Face Spaces page header (status: Sleeping) - not part of the program.
import os | |
import streamlit as st | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
import re | |
import pathlib | |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
from langchain.llms import HuggingFacePipeline | |
from langchain.llms import LlamaCpp | |
from langchain import PromptTemplate, LLMChain | |
from langchain.callbacks.manager import CallbackManager | |
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.chains import RetrievalQA | |
from langchain.vectorstores import FAISS | |
from PyPDF2 import PdfReader | |
import os | |
import time | |
from langchain.chains.question_answering import load_qa_chain | |
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT | |
from langchain.document_loaders import TextLoader | |
from langchain.document_loaders import PyPDFLoader | |
from langchain.document_loaders import Docx2txtLoader | |
from langchain.document_loaders.image import UnstructuredImageLoader | |
from langchain.document_loaders import UnstructuredHTMLLoader | |
from langchain.document_loaders import UnstructuredPowerPointLoader | |
from langchain.document_loaders import TextLoader | |
from langchain.memory import ConversationBufferWindowMemory | |
from langchain.memory import ConversationBufferMemory | |
from langchain.chains import ConversationalRetrievalChain | |
from langchain.memory.chat_message_histories.streamlit import StreamlitChatMessageHistory | |
# Sidebar: application title plus a short "About" panel describing the model
# and the hardware this demo runs on.
st.sidebar.title('DOC-QA DEMO ')
st.sidebar.markdown('''
## About
Detail this application:
- LLM model: llama2-7b-chat-4bit
- Hardware resource : Huggingface space 8 vCPU 32 GB
''')
class UploadDoc:
    """Discover supported documents under a directory and load them.

    Files are found recursively under ``path_data``, grouped by extension, and
    each group is loaded with the matching LangChain document loader.
    """

    # Extensions this class can load, in the order their documents are emitted.
    _SUPPORTED_EXTENSIONS = (".docx", ".pdf", ".html", ".png", ".pptx", ".txt")

    def __init__(self, path_data):
        """Remember the root directory to scan.

        Args:
            path_data: Directory walked recursively for input files.
        """
        self.path_data = path_data

    def prepare_filetype(self):
        """Group every supported file under ``path_data`` by extension.

        Matching is case-insensitive (``Report.PDF`` lands under ``".pdf"``),
        and paths are sorted so the loading order is deterministic.

        Returns:
            dict mapping each supported extension (lower-case, leading dot) to
            a list of file paths; extensions with no files map to ``[]``.
        """
        extension_lists = {ext: [] for ext in self._SUPPORTED_EXTENSIONS}
        # os.walk yields entries in filesystem order, which is not stable.
        all_paths = sorted(
            os.path.join(root, name)
            for root, _subdirs, files in os.walk(self.path_data)
            for name in files
        )
        for filename in all_paths:
            file_extension = pathlib.Path(filename).suffix.lower()
            if file_extension in extension_lists:
                extension_lists[file_extension].append(filename)
        return extension_lists

    @staticmethod
    def _load_files(paths, loader_cls, split=False):
        """Load each path with ``loader_cls`` and concatenate the documents.

        Args:
            paths: File paths to load.
            loader_cls: Loader class constructed with a single path.
            split: Use ``load_and_split()`` instead of ``load()``.

        Returns:
            list of documents from all files, in ``paths`` order.
        """
        documents = []
        for path in paths:
            loader = loader_cls(path)
            documents.extend(loader.load_and_split() if split else loader.load())
        return documents

    def upload_docx(self, extension_lists):
        """Load the ``.docx`` files with ``Docx2txtLoader``."""
        return self._load_files(extension_lists[".docx"], Docx2txtLoader)

    def upload_pdf(self, extension_lists):
        """Load the ``.pdf`` files with ``PyPDFLoader`` via ``load_and_split``."""
        return self._load_files(extension_lists[".pdf"], PyPDFLoader, split=True)

    def upload_html(self, extension_lists):
        """Load the ``.html`` files with ``UnstructuredHTMLLoader``."""
        return self._load_files(extension_lists[".html"], UnstructuredHTMLLoader)

    def upload_png_ocr(self, extension_lists):
        """Load the ``.png`` files with ``UnstructuredImageLoader``."""
        return self._load_files(extension_lists[".png"], UnstructuredImageLoader)

    def upload_pptx(self, extension_lists):
        """Load the ``.pptx`` files with ``UnstructuredPowerPointLoader``."""
        return self._load_files(extension_lists[".pptx"], UnstructuredPowerPointLoader)

    def upload_txt(self, extension_lists):
        """Load the ``.txt`` files with ``TextLoader``."""
        return self._load_files(extension_lists[".txt"], TextLoader)

    def count_files(self, extension_lists):
        """Print and return the number of files found per extension.

        Args:
            extension_lists: Output of ``prepare_filetype``.

        Returns:
            dict mapping each extension to its file count.
        """
        file_extension_counts = {
            ext: len(file_list) for ext, file_list in extension_lists.items()
        }
        print(f"number of file:{file_extension_counts}")
        return file_extension_counts

    def create_document(self, dataframe=True):
        """Discover, count and load every supported file under ``path_data``.

        Args:
            dataframe: Unused; accepted for backward compatibility (none of
                the supported extensions has a dataframe loading mode).

        Returns:
            list of loaded documents, grouped by extension in the order
            .docx, .pdf, .html, .png, .pptx, .txt.
        """
        extension_lists = self.prepare_filetype()
        self.count_files(extension_lists)
        upload_functions = {
            ".docx": self.upload_docx,
            ".pdf": self.upload_pdf,
            ".html": self.upload_html,
            ".png": self.upload_png_ocr,
            ".pptx": self.upload_pptx,
            ".txt": self.upload_txt,
        }
        documents = []
        for extension, upload_function in upload_functions.items():
            if extension_lists[extension]:
                documents.extend(upload_function(extension_lists))
        return documents
def split_docs(documents, chunk_size=500, chunk_overlap=100):
    """Split documents into overlapping chunks suitable for embedding.

    Args:
        documents: Documents to split (e.g. ``UploadDoc.create_document()``).
        chunk_size: Maximum chunk length handed to
            ``RecursiveCharacterTextSplitter``.
        chunk_overlap: Overlap between consecutive chunks. Must not exceed
            ``chunk_size``; the splitter rejects a larger overlap.

    Returns:
        list of split documents.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    return text_splitter.split_documents(documents)
def load_llama2_llamaCpp(model_path="llama-2-7b-chat.Q4_0.gguf", n_batch=32,
                         n_ctx=1024, temperature=0.1, max_tokens=256,
                         **llama_kwargs):
    """Create the llama.cpp-backed Llama-2 chat LLM (``LlamaCpp``).

    Generated tokens are echoed to stdout through
    ``StreamingStdOutCallbackHandler``. The defaults are the settings this app
    runs with. Any extra keyword argument (for example ``n_gpu_layers`` to
    offload layers to a GPU) is forwarded to ``LlamaCpp`` unchanged.

    Args:
        model_path: Path to the GGUF model file; a relative path depends on
            the current working directory.
        n_batch: llama.cpp batch size (tokens processed per batch).
        n_ctx: Context window size in tokens.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens generated per answer.
        **llama_kwargs: Additional ``LlamaCpp`` keyword arguments.

    Returns:
        A configured ``LlamaCpp`` instance.
    """
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
    return LlamaCpp(
        model_path=model_path,
        n_batch=n_batch,
        callback_manager=callback_manager,
        verbose=True,
        n_ctx=n_ctx,
        temperature=temperature,
        max_tokens=max_tokens,
        **llama_kwargs,
    )
def set_custom_prompt():
    """Return the QA prompt whose template is filled with {context} and {question}."""
    template = """ Use the following pieces of information from context to answer the user's question.
If you don't know the answer, don't try to make up an answer.
Context : {context}
Question : {question}
Only returns the helpful answer below and nothing else.
Helpful answer:
"""
    variables = ['context', 'question']
    return PromptTemplate(template=template, input_variables=variables)
def load_embeddings():
    """Return the ``thenlper/gte-base`` HuggingFace embedding model, run on CPU."""
    device_kwargs = {'device': 'cpu'}
    return HuggingFaceEmbeddings(
        model_name="thenlper/gte-base",
        model_kwargs=device_kwargs,
    )
def _extract_pdf_text(uploaded_file):
    """Return the text of every page of ``uploaded_file`` joined into one string.

    Pages whose extraction yields nothing contribute an empty string, so a
    page without a text layer cannot break the concatenation.
    """
    reader = PdfReader(uploaded_file)
    return "".join(page.extract_text() or "" for page in reader.pages)


def _build_qa_chain(text, llm, embeddings, qa_prompt):
    """Chunk ``text``, index it in FAISS and wrap it in a ``RetrievalQA`` chain.

    Returns:
        The chain, or ``None`` when ``text`` produces no chunks to index.
    """
    # FAISS.from_texts expects a list of strings; handing it the raw string
    # would embed every single character as its own document. Chunk with the
    # same sizes as split_docs (500 chars, 100 overlap).
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    chunks = splitter.split_text(text)
    if not chunks:
        return None
    db = FAISS.from_texts(chunks, embeddings)
    # NOTE(review): the prompt built by set_custom_prompt() only declares
    # {context} and {question}, so this history is recorded but presumably
    # not fed back into answers - confirm whether recall is wanted.
    memory = ConversationBufferMemory(memory_key="chat_history",
                                      return_messages=True,
                                      input_key="query",
                                      output_key="result")
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={'k': 3}),
        return_source_documents=True,
        memory=memory,
        chain_type_kwargs={"prompt": qa_prompt},
    )


def main():
    """Render the PDF question-answering page.

    Streamlit re-executes the script on every widget interaction, so the
    expensive objects are reused across reruns instead of being rebuilt:

    * the LLM and the embedding model are cached process-wide via
      ``st.cache_resource``;
    * the vector index and ``RetrievalQA`` chain (with its memory) live in
      ``st.session_state`` and are rebuilt only when a different file is
      uploaded.
    """
    st.header("DOCUMENT QUESTION ANSWERING IS2")

    # Wrapping at call time behaves like decorating with @st.cache_resource:
    # the cache key is derived from the wrapped function, so each model is
    # loaded once per process rather than on every rerun.
    llm = st.cache_resource(load_llama2_llamaCpp)()
    embeddings = st.cache_resource(load_embeddings)()
    qa_prompt = set_custom_prompt()

    uploaded_file = st.file_uploader('Choose your .pdf file', type="pdf")
    if uploaded_file is None:
        return

    # Identify the upload so the index is built once per file, not per query.
    file_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get("qa_file_key") != file_key:
        with st.spinner("Indexing document..."):
            qa_chain = _build_qa_chain(_extract_pdf_text(uploaded_file),
                                       llm, embeddings, qa_prompt)
        if qa_chain is None:
            st.warning("No extractable text was found in this PDF.")
            return
        st.session_state["qa_chain"] = qa_chain
        st.session_state["qa_file_key"] = file_key

    query = st.text_input("ASK ABOUT THE DOCS:")
    if query:
        start = time.perf_counter()
        response = st.session_state["qa_chain"]({'query': query})
        elapsed = time.perf_counter() - start
        st.write(response["result"])
        st.write("Response time:", int(elapsed), "sec")
if __name__ == "__main__":
    # Launch the app only when this file is executed directly, not on import.
    main()