# witp_poc/app.py
import time
import uuid
from io import BytesIO

import fitz  # PyMuPDF: highlight passages and render PDF pages
import streamlit as st
import torch
from streamlit.logger import get_logger

from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer, TextStreamer, pipeline

from langchain import HuggingFacePipeline, PromptTemplate
from langchain.chains import ConversationalRetrievalChain, RetrievalQA
from langchain.document_loaders import PyPDFDirectoryLoader
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS

from gtts import gTTS
from pydub import AudioSegment
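# Paths used below assume local checkouts inside the Space:
#   /home/user/app/ML/                     - PDF corpus
#   /home/user/app/all-MiniLM-L6-v2/       - sentence-transformers embedding model
#   /home/user/app/Llama-2-13B-chat-GPTQ/  - quantized chat model
#   /home/user/app/pdf2image/, /home/user/app/audio/ - writable output dirs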
user_session_id = uuid.uuid4()
logger = get_logger(__name__)

st.set_page_config(page_title="Document QA by Dono", page_icon="🤖")
st.session_state.disabled = False
st.title("Document QA by Dono")

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
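# Streamlit reruns the whole script on every interaction, so the two loaders
# below are cached: st.cache_data for the pickleable list of documents, and
# st.cache_resource for the QA chain, whose model and tokenizer objects cannot
# be serialized. The leading underscore in load_model's `_docs` parameter
# tells Streamlit not to hash that argument when computing the cache key.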
@st.cache_data
def load_data():
    # Load every PDF under the documents directory into LangChain Documents.
    loader = PyPDFDirectoryLoader("/home/user/app/ML/")
    docs = loader.load()
    return docs
@st.cache_resource
def load_model(_docs):
    # Chunk the documents, embed the chunks, and index them in a FAISS store.
    embeddings = HuggingFaceInstructEmbeddings(
        model_name="/home/user/app/all-MiniLM-L6-v2/",
        model_kwargs={"device": DEVICE},
    )
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=256)
    texts = text_splitter.split_documents(_docs)
    db = FAISS.from_documents(texts, embeddings)

    # Load the GPTQ-quantized Llama-2 13B chat model from a local checkout.
    model_name_or_path = "/home/user/app/Llama-2-13B-chat-GPTQ/"
    model_basename = "model"
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
    model = AutoGPTQForCausalLM.from_quantized(
        model_name_or_path,
        revision="gptq-8bit-128g-actorder_False",
        model_basename=model_basename,
        use_safetensors=True,
        trust_remote_code=True,
        inject_fused_attention=False,
        device=DEVICE,
        quantize_config=None,
    )
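    # Note: `revision` selects the 8-bit, group-size-128 GPTQ branch of the
    # model repo; inject_fused_attention=False is a common workaround for
    # Llama-2 GPTQ checkpoints that do not support auto-gptq's fused attention.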
    # Previous document-QA system prompt, kept for reference:
    # DEFAULT_SYSTEM_PROMPT = """
    # You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.
    # Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
    # Please ensure that your responses are socially unbiased and positive in nature.
    # Always provide the citation for the answer from the text.
    # Try to include any section or subsection present in the text responsible for the answer.
    # Provide reference. Provide page number, section, sub section etc.
    # If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
    # Given a government document that outlines rules and regulations for a specific industry or sector, use your language model to answer questions about the rules and their applicability over time.
    # The document may include provisions that take effect at different times, such as immediately upon publication, after a grace period, or on a specific date in the future.
    # Your task is to identify the relevant rules and determine when they go into effect, taking into account any dependencies or exceptions that may apply.
    # The current date is 14 September, 2023. Try to extract information which is closer to this date.
    # Take a deep breath and work on this problem step-by-step.
    # """.strip()
    DEFAULT_SYSTEM_PROMPT = """
    You are a helpful, respectful and honest assistant with knowledge of machine learning, data science, computer science, Python programming language, mathematics, probability and statistics.
    Take a deep breath and work on the given problem step-by-step.
    """.strip()

    def generate_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
        # Wrap a request in the Llama-2 chat template.
        return f"""[INST] <<SYS>>{system_prompt}<</SYS>>{prompt} [/INST]""".strip()

    # Stream generated tokens to stdout as they are produced.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    text_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=1024,
        temperature=0.5,
        top_p=0.95,
        repetition_penalty=1.15,
        streamer=streamer,
    )
    llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 0.5})

    SYSTEM_PROMPT = (
        "Use the following pieces of context to answer the question at the end. "
        "If you don't know the answer, just say that you don't know, "
        "don't try to make up an answer."
    )
    template = generate_prompt(
        """{context} Question: {question} """,
        system_prompt=SYSTEM_PROMPT,
    )  # TODO: enter conversation memory here
    prompt = PromptTemplate(template=template, input_variables=["context", "question"])  # TODO: add history here

    # "stuff" chain: the k retrieved chunks are concatenated into one prompt.
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={"k": 10}),
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt, "verbose": False},
    )
    logger.info("load done")
    return qa_chain
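# Sketch (commented out, untested): how the memory/history TODOs inside
# load_model might be addressed, swapping RetrievalQA for a conversational
# chain over the same `llm` and `db`. `output_key="answer"` keeps the memory
# compatible with return_source_documents=True.
#
#   memory = ConversationBufferMemory(
#       memory_key="chat_history", return_messages=True, output_key="answer")
#   qa_chain = ConversationalRetrievalChain.from_llm(
#       llm=llm,
#       retriever=db.as_retriever(search_kwargs={"k": 10}),
#       memory=memory,
#       return_source_documents=True,
#   )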
# Recorded for display/bookkeeping only; load_model() resolves its own local path.
model_name_or_path = "Llama-2-13B-chat-GPTQ"
model_basename = "model"
st.session_state["llm_model"] = model_name_or_path

if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the conversation so far on every rerun.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
def on_select():
    st.session_state.disabled = True

def get_message_history():
    # Yield "Role: content" lines for the running transcript.
    for message in st.session_state.messages:
        role, content = message["role"], message["content"]
        yield f"{role.title()}: {content}"
docs = load_data()
qa_chain = load_model(docs)
if prompt := st.chat_input("How can I help you today?"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.spinner(text="Looking for relevant answer"):
            message_placeholder = st.empty()
            full_response = ""
            # Last three turns, collected for future memory support;
            # not yet passed to the chain.
            message_history = "\n".join(list(get_message_history())[-3:])
            result = qa_chain(prompt)
            output = [result['result']]
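            # With return_source_documents=True the chain returns a dict:
            # result['result'] is the generated answer and
            # result['source_documents'] holds the retrieved chunks, whose
            # metadata ('source', 'page') comes from the PDF loader.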
            def generate_pdf():
                # Highlight the retrieved passage in its source PDF and render
                # that page to a PNG for the sidebar.
                page_number = int(result['source_documents'][0].metadata['page'])
                doc = fitz.open(str(result['source_documents'][0].metadata['source']))
                text = str(result['source_documents'][0].page_content)
                if text != '':
                    for page in doc:
                        text_instances = page.search_for(text)
                        for inst in text_instances:
                            highlight = page.add_highlight_annot(inst)
                            highlight.update()
                doc.save("/home/user/app/pdf2image/output.pdf", garbage=4, deflate=True, clean=True)

                def pdf_page_to_image(pdf_file, page_number, output_image):
                    pdf_document = fitz.open(pdf_file)
                    page = pdf_document[page_number]
                    dpi = 300  # adjust as needed
                    # PyMuPDF renders at 72 dpi by default, so scale by dpi / 72.
                    pix = page.get_pixmap(matrix=fitz.Matrix(dpi / 72, dpi / 72))
                    pix.save(output_image, "png")
                    pdf_document.close()

                pdf_page_to_image('/home/user/app/pdf2image/output.pdf', page_number, '/home/user/app/pdf2image/output.png')
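            # Caveat: highlighting depends on page.search_for locating the
            # chunk text verbatim; if the splitter's whitespace differs from
            # the PDF layout, a passage may silently go unhighlighted.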
            # image = Image.open('/home/user/app/pdf2image/output.png')
            # message_placeholder.image(image)
            # st.session_state.reference = True
            def generate_audio():
                # Currently unused: the sidebar TTS option below is commented out.
                with open('/home/user/app/audio/audio.mp3', 'wb') as sound_file:
                    tts = gTTS(result['result'], lang='en', tld='co.in')
                    tts.write_to_fp(sound_file)
                sound = AudioSegment.from_mp3("/home/user/app/audio/audio.mp3")
                sound.export("/home/user/app/audio/audio.wav", format="wav")

            st.session_state['reference'] = '/home/user/app/pdf2image/default_output.png'
            st.session_state['audio'] = ''

            def stream_example():
                # Currently unused: word-by-word streaming helper.
                for word in result['result'].split():
                    yield word + " "
                    time.sleep(0.1)
            # Reveal the answer with a typing cursor, then the final text.
            for item in output:
                full_response += item
                message_placeholder.markdown(full_response + "▌")
            message_placeholder.markdown(full_response)
            # message_placeholder.markdown(result['source_documents'])
            # Alternative word-by-word streaming (see stream_example above):
            # for item in output:
            #     full_response += item
            #     message_placeholder.markdown(write(stream_example))
            # write(stream_example)
            # message_placeholder.markdown(result['result'])

            # Read the answer aloud: synthesize MP3 in memory and play it.
            sound_file = BytesIO()
            tts = gTTS(result['result'], lang='en')
            tts.write_to_fp(sound_file)
            st.audio(sound_file)
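            # st.audio accepts a file-like object, so the MP3 stays in memory;
            # the pydub WAV round-trip in generate_audio is only needed when a
            # .wav file on disk is required.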
            # Defaults (note: the assignments above mean these keys already exist here).
            if "reference" not in st.session_state:
                st.session_state.reference = False
            if "audio" not in st.session_state:
                st.session_state.audio = False

            with st.sidebar:
                choice = st.radio("References", ["Reference"])
                if choice == 'Reference':
                    generate_pdf()
                    st.session_state['reference'] = '/home/user/app/pdf2image/output.png'
                    st.image(st.session_state['reference'])
                # if choice == 'TTS':
                #     with open('/home/user/app/audio/audio.mp3', 'wb') as sound_file:
                #         tts = gTTS(result['result'], lang='en', tld='co.in')
                #         tts.write_to_fp(sound_file)
                #     sound = AudioSegment.from_mp3("/home/user/app/audio/audio.mp3")
                #     sound.export("/home/user/app/audio/audio.wav", format="wav")
                #     st.session_state['audio'] = '/home/user/app/audio/audio.wav'
                #     st.audio(st.session_state['audio'])

    st.session_state.messages.append({"role": "assistant", "content": full_response})