import subprocess
import streamlit as st
from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import FastEmbedEmbeddings  # General embeddings from HuggingFace models.
from langchain.memory import ConversationBufferMemory
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from htmlTemplates import css, bot_template, user_template
from langchain.llms import LlamaCpp  # For loading transformer models.
from langchain.document_loaders import PyPDFLoader, TextLoader, JSONLoader, CSVLoader
import tempfile
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain import hub
import os
import glob
import gc
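# NOTE: the loaders below all follow the same pattern: Streamlit's
# file_uploader returns in-memory UploadedFile objects, while LangChain's
# document loaders expect a filesystem path, so each uploaded file is first
# written out to a temporary file before loading.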
# TEXT LOADERS
def get_pdf_text(pdf_docs):
    """
    Purpose: Loads an uploaded PDF file and extracts its text.
    Usage: Used to extract text and metadata from PDF documents.
    Load Function: PyPDFLoader.load() reads and extracts data from the PDF file.
    input   : uploaded PDF file (Streamlit UploadedFile)
    returns : list of extracted Document objects, one per page
    """
    temp_dir = tempfile.TemporaryDirectory()
    temp_filepath = os.path.join(temp_dir.name, pdf_docs.name)
    with open(temp_filepath, "wb") as f:
        f.write(pdf_docs.getvalue())
    pdf_loader = PyPDFLoader(temp_filepath)
    pdf_doc = pdf_loader.load()
    return pdf_doc
def get_text_file(text_docs):
    """
    Loads an uploaded plain-text file.
    input   : uploaded .txt file (Streamlit UploadedFile)
    returns : list containing the extracted Document
    """
    temp_dir = tempfile.TemporaryDirectory()
    temp_filepath = os.path.join(temp_dir.name, text_docs.name)
    with open(temp_filepath, "wb") as f:
        f.write(text_docs.getvalue())
    text_loader = TextLoader(temp_filepath)
    text_doc = text_loader.load()
    return text_doc
def get_csv_file(csv_docs):
    """
    Loads an uploaded CSV file; CSVLoader yields one Document per row.
    """
    temp_dir = tempfile.TemporaryDirectory()
    temp_filepath = os.path.join(temp_dir.name, csv_docs.name)
    with open(temp_filepath, "wb") as f:
        f.write(csv_docs.getvalue())
    csv_loader = CSVLoader(temp_filepath)
    csv_doc = csv_loader.load()
    return csv_doc
def get_json_file(json_docs):
    """
    Loads an uploaded JSON file.
    Note: the jq_schema below assumes the JSON has the shape
    {"messages": [{"content": ...}, ...]}; adjust it for other layouts.
    """
    temp_dir = tempfile.TemporaryDirectory()
    temp_filepath = os.path.join(temp_dir.name, json_docs.name)
    with open(temp_filepath, "wb") as f:
        f.write(json_docs.getvalue())
    json_loader = JSONLoader(
        file_path=temp_filepath,
        jq_schema='.messages[].content',
        text_content=False
    )
    json_doc = json_loader.load()
    return json_doc
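# Chunking: 512 characters with 50 characters of overlap is a reasonable
# default for the small embedding model used below; tune both for your corpus.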
def get_text_chunks(documents):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=512,
        chunk_overlap=50,
        length_function=len
    )
    documents = text_splitter.split_documents(documents)
    return documents
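# The vector store is persisted to ./vectordb/ so that main() can reload it
# on restart without re-embedding the documents.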
def get_vectorstore(text_chunks, embeddings):
    vectorstore = Chroma.from_documents(documents=text_chunks,
                                        embedding=embeddings,
                                        persist_directory="./vectordb/")
    # Documents embedded and stored
    return vectorstore
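# LlamaCpp settings (assumptions for a 13B GGUF chat model): n_ctx=4000 must
# fit the prompt plus retrieved context plus the answer; n_gpu_layers=50
# offloads most layers to the GPU (set it to 0 for CPU-only inference).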
def get_conversation_chain(vectorstore):
    model_path = "models/llama-2-13b-chat.Q4_K_S.gguf"
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
    llm = LlamaCpp(model_path=model_path,
                   n_ctx=4000,
                   max_tokens=500,
                   n_gpu_layers=50,
                   n_batch=512,
                   callback_manager=callback_manager,
                   verbose=True)
    # prompt template
    template = """
    You are an experienced Human Resources Manager. When an employee asks you a question, refer to the company policy and respond in a professional way. Make sure to sound empathetic while being professional, and sound like a human!
    Try to summarise the content and keep the answer to the point.
    If you don't know the answer, just say that you don't know; don't try to make up an answer.
    Follow the template below.
    Example:
    Question : how many paid leaves do i have ?
    Answer : The number of paid leaves varies depending on the type of leave; for privilege leave you're entitled to a maximum of 21 days in a calendar year. Other leaves might have different entitlements. thanks for asking!
    Make sure to add "thanks for asking!" after every answer.
    {context}
    Question: {question}
    Answer:
    Just answer to the point!
    """
    rag_prompt_custom = PromptTemplate.from_template(template)
    # prompt = hub.pull("rlm/rag-prompt")
    conversation_chain = RetrievalQA.from_chain_type(
        llm,
        retriever=vectorstore.as_retriever(),
        chain_type_kwargs={"prompt": rag_prompt_custom},
    )
    conversation_chain.callback_manager = callback_manager
    conversation_chain.memory = ConversationBufferMemory(
        memory_key='chat_history', return_messages=True)
    return conversation_chain
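# Chat UI: messages are kept in st.session_state so the transcript survives
# Streamlit's script reruns between interactions.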
def handle_userinput():
    clear = False
    # Add clear chat button
    if st.button("Clear Chat history"):
        clear = True
        st.session_state.messages = []
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "How can I help you?"}]
    for msg in st.session_state.messages:
        st.chat_message(msg["role"]).write(msg["content"])
    if prompt := st.chat_input():
        st.session_state.messages.append({"role": "user", "content": prompt})
        st.chat_message("user").write(prompt)
        if clear:
            st.session_state.conversation.memory.clear()
        msg = st.session_state.conversation.run(prompt)
        print(msg)
        st.session_state.messages.append({"role": "assistant", "content": msg})
        st.chat_message("assistant").write(msg)
# Function to apply rounded edges using CSS
def add_rounded_edges(image_path="./randstad_featuredimage.png", radius=30):
    st.markdown(
        f'<style>.rounded-img{{border-radius: {radius}px; overflow: hidden;}}</style>',
        unsafe_allow_html=True,)
    st.image(image_path, use_column_width=True, output_format='auto')
def main():
    load_dotenv()
    gc.collect()
    st.set_page_config(page_title="Chat with multiple Files",
                       page_icon=":books:")
    st.write(css, unsafe_allow_html=True)
    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None
    st.title("💬 Randstad HR Chatbot")
    st.subheader("📚 An HR assistant powered by Generative AI")
    # user_question = st.text_input("Ask a question about your documents:")
    st.session_state.embeddings = FastEmbedEmbeddings(model_name="BAAI/bge-small-en-v1.5",
                                                      cache_dir="./embedding_model/")
    # Reload a previously persisted vector store, if one exists
    if len(glob.glob("./vectordb/*.sqlite3")) > 0:
        vectorstore = Chroma(persist_directory="./vectordb/", embedding_function=st.session_state.embeddings)
        st.session_state.conversation = get_conversation_chain(vectorstore)
    handle_userinput()
    with st.sidebar:
        add_rounded_edges()
        st.subheader("Your documents")
        docs = st.file_uploader(
            "Upload File (pdf,text,csv...) and click 'Process'", accept_multiple_files=True)
        if st.button("Process"):
            with st.spinner("Processing"):
                # route each uploaded file to the matching loader by MIME type
                doc_list = []
                for file in docs:
                    print('file - type : ', file.type)
                    if file.type == 'text/plain':
                        # file is .txt
                        doc_list.extend(get_text_file(file))
                    elif file.type in ['application/octet-stream', 'application/pdf']:
                        # file is .pdf (some browsers report PDFs as octet-stream)
                        doc_list.extend(get_pdf_text(file))
                    elif file.type == 'text/csv':
                        # file is .csv
                        doc_list.extend(get_csv_file(file))
                    elif file.type == 'application/json':
                        # file is .json
                        doc_list.extend(get_json_file(file))
                # get the text chunks
                text_chunks = get_text_chunks(doc_list)
                # create vector store
                vectorstore = get_vectorstore(text_chunks, st.session_state.embeddings)
                # create conversation chain
                st.session_state.conversation = get_conversation_chain(vectorstore)
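# Building llama-cpp-python with cuBLAS at startup is a workaround for getting
# a GPU-enabled wheel on hosted Spaces; in a normal deployment this would be
# pinned in requirements.txt instead of installed at runtime.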
if __name__ == '__main__':
    command = 'CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install llama-cpp-python --no-cache-dir'
    # Run the command using subprocess
    try:
        subprocess.run(command, shell=True, check=True)
        print("Command executed successfully.")
    except subprocess.CalledProcessError as e:
        print(f"Error: {e}")
    main()