import os
import torch
import uuid
import requests
import streamlit as st
from streamlit.logger import get_logger
from auto_gptq import AutoGPTQForCausalLM
from langchain import HuggingFacePipeline, PromptTemplate
from langchain.chains import RetrievalQA
from langchain.document_loaders import PyPDFDirectoryLoader
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pdf2image import convert_from_path
from transformers import AutoTokenizer, TextStreamer, pipeline
from langchain.memory import ConversationBufferMemory
from gtts import gTTS
from io import BytesIO
from langchain.chains import ConversationalRetrievalChain
import streamlit.components.v1 as components
from langchain.document_loaders import UnstructuredMarkdownLoader
from langchain.vectorstores.utils import filter_complex_metadata
import fitz
from PIL import Image
from langchain.vectorstores import FAISS
import transformers
from pydub import AudioSegment
from streamlit_extras.streaming_write import write
import time

user_session_id = uuid.uuid4()
logger = get_logger(__name__)

st.set_page_config(page_title="Document QA by Dono", page_icon="🤖")
st.session_state.disabled = False
st.title("Document QA by Dono")

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"


@st.cache_data
def load_data():
    """Load every PDF in the document directory."""
    loader = PyPDFDirectoryLoader("/home/user/app/ML/")
    docs = loader.load()
    return docs


@st.cache_resource
def load_model(_docs):
    """Build the FAISS index and the quantized Llama-2 QA chain (cached across reruns)."""
    embeddings = HuggingFaceInstructEmbeddings(
        model_name="/home/user/app/all-MiniLM-L6-v2/",
        model_kwargs={"device": DEVICE},
    )
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=256)
    texts = text_splitter.split_documents(_docs)
    db = FAISS.from_documents(texts, embeddings)

    model_name_or_path = "/home/user/app/Llama-2-13B-chat-GPTQ/"
    model_basename = "model"
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
    model = AutoGPTQForCausalLM.from_quantized(
        model_name_or_path,
        revision="gptq-8bit-128g-actorder_False",
        model_basename=model_basename,
        use_safetensors=True,
        trust_remote_code=True,
        inject_fused_attention=False,
        device=DEVICE,
        quantize_config=None,
    )

    # DEFAULT_SYSTEM_PROMPT = """
    # You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.
    # Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
    # Please ensure that your responses are socially unbiased and positive in nature.
    # Always provide the citation for the answer from the text.
    # Try to include any section or subsection present in the text responsible for the answer.
    # Provide reference. Provide page number, section, sub section etc.
    # If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
    # Given a government document that outlines rules and regulations for a specific industry or sector, use your language model to answer questions about the rules and their applicability over time.
    # The document may include provisions that take effect at different times, such as immediately upon publication, after a grace period, or on a specific date in the future.
    # Your task is to identify the relevant rules and determine when they go into effect, taking into account any dependencies or exceptions that may apply.
    # The current date is 14 September, 2023.
    # Try to extract information which is closer to this date.
    # Take a deep breath and work on this problem step-by-step.
    # """.strip()

    DEFAULT_SYSTEM_PROMPT = """
You are a helpful, respectful and honest assistant with knowledge of machine learning, data science, computer science, Python programming language, mathematics, probability and statistics.
Take a deep breath and work on the given problem step-by-step.
""".strip()

    def generate_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
        # Llama-2 chat format: the system prompt sits inside <<SYS>> tags within [INST].
        return f"""[INST] <<SYS>>
{system_prompt}
<</SYS>>

{prompt} [/INST]""".strip()

    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    text_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=1024,
        temperature=0.5,
        top_p=0.95,
        repetition_penalty=1.15,
        streamer=streamer,
    )
    llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 0.5})

    SYSTEM_PROMPT = (
        "Use the following pieces of context to answer the question at the end. "
        "If you don't know the answer, just say that you don't know, "
        "don't try to make up an answer."
    )
    template = generate_prompt(
        """{context}

Question: {question}
""",
        system_prompt=SYSTEM_PROMPT,
    )
    # Enter memory here! (see the commented sketch at the end of this file)
    prompt = PromptTemplate(template=template, input_variables=["context", "question"])
    # Add history here
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={"k": 10}),
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt, "verbose": False},
    )
    print('load done')
    return qa_chain


model_name_or_path = "Llama-2-13B-chat-GPTQ"
model_basename = "model"
st.session_state["llm_model"] = model_name_or_path

if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the conversation so far.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])


def on_select():
    st.session_state.disabled = True


def get_message_history():
    for message in st.session_state.messages:
        role, content = message["role"], message["content"]
        yield f"{role.title()}: {content}"


docs = load_data()
qa_chain = load_model(docs)

if prompt := st.chat_input("How can I help you today?"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.spinner(text="Looking for relevant answer"):
            message_placeholder = st.empty()
            full_response = ""
            message_history = "\n".join(list(get_message_history())[-3:])  # last three turns (not yet fed to the chain)
            result = qa_chain(prompt)
            output = [result['result']]

            def generate_pdf():
                """Highlight the top source passage in its PDF and render that page to PNG."""
                page_number = int(result['source_documents'][0].metadata['page'])
                doc = fitz.open(str(result['source_documents'][0].metadata['source']))
                text = str(result['source_documents'][0].page_content)
                if text != '':
                    for page in doc:
                        text_instances = page.search_for(text)
                        for inst in text_instances:
                            highlight = page.add_highlight_annot(inst)
                            highlight.update()
                doc.save("/home/user/app/pdf2image/output.pdf", garbage=4, deflate=True, clean=True)

                def pdf_page_to_image(pdf_file, page_number, output_image):
                    pdf_document = fitz.open(pdf_file)
                    page = pdf_document[page_number]
                    dpi = 300  # You can adjust this as needed
                    pix = page.get_pixmap(matrix=fitz.Matrix(dpi / 100, dpi / 100))
                    pix.save(output_image, "png")
                    pdf_document.close()

                pdf_page_to_image('/home/user/app/pdf2image/output.pdf', page_number,
                                  '/home/user/app/pdf2image/output.png')
                # image = Image.open('/home/user/app/pdf2image/output.png')
                # message_placeholder.image(image)
                # st.session_state.reference = True

            def generate_audio():
                """Synthesize the answer to MP3 with gTTS and convert it to WAV."""
                with open('/home/user/app/audio/audio.mp3', 'wb') as sound_file:
                    tts = gTTS(result['result'], lang='en', tld='co.in')
                    tts.write_to_fp(sound_file)
                sound = AudioSegment.from_mp3("/home/user/app/audio/audio.mp3")
                sound.export("/home/user/app/audio/audio.wav", format="wav")

            st.session_state['reference'] = '/home/user/app/pdf2image/default_output.png'
            st.session_state['audio'] = ''

            def stream_example():
                for word in result['result'].split():
                    yield word + " "
                    time.sleep(0.1)

            for item in output:
                full_response += item
                message_placeholder.markdown(full_response + "▌")
            message_placeholder.markdown(full_response)
            # message_placeholder.markdown(result['source_documents'])

            # for item in output:
            #     full_response += item
            #     message_placeholder.markdown(write(stream_example))
            # write(stream_example)
            # message_placeholder.markdown(result['result'])

            sound_file = BytesIO()
            tts = gTTS(result['result'], lang='en')
            tts.write_to_fp(sound_file)
            st.audio(sound_file)

    if "reference" not in st.session_state:
        st.session_state.reference = False
    if "audio" not in st.session_state:
        st.session_state.audio = False

    with st.sidebar:
        choice = st.radio("References", ["Reference"])
        if choice == 'Reference':
            generate_pdf()
            st.session_state['reference'] = '/home/user/app/pdf2image/output.png'
            st.image(st.session_state['reference'])
        # if choice == 'TTS':
        #     with open('/home/user/app/audio/audio.mp3', 'wb') as sound_file:
        #         tts = gTTS(result['result'], lang='en', tld='co.in')
        #         tts.write_to_fp(sound_file)
        #     sound = AudioSegment.from_mp3("/home/user/app/audio/audio.mp3")
        #     sound.export("/home/user/app/audio/audio.wav", format="wav")
        #     st.session_state['audio'] = '/home/user/app/audio/audio.wav'
        #     st.audio(st.session_state['audio'])

    st.session_state.messages.append({"role": "assistant", "content": full_response})
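
# ---------------------------------------------------------------------------
# Optional sketch (not part of the running app): the "# Enter memory here!" /
# "# Add history here" markers in load_model() could be filled by swapping
# RetrievalQA for ConversationalRetrievalChain with a ConversationBufferMemory,
# both of which are already imported above but unused. This is an untested
# sketch against the legacy LangChain API, not the app's actual behaviour:
#
#   memory = ConversationBufferMemory(
#       memory_key="chat_history",   # key the chain reads prior turns from
#       return_messages=True,
#       output_key="answer",         # required when source documents are returned
#   )
#   qa_chain = ConversationalRetrievalChain.from_llm(
#       llm=llm,
#       retriever=db.as_retriever(search_kwargs={"k": 10}),
#       memory=memory,
#       return_source_documents=True,
#       combine_docs_chain_kwargs={"prompt": prompt},
#   )
#   # The chain would then be invoked as qa_chain({"question": prompt}) and the
#   # reply read from result["answer"] instead of result["result"].
# ---------------------------------------------------------------------------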