import streamlit as st
import os
from langchain_community.document_loaders import PDFMinerLoader
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
import torch
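# DocChatAI: chat over an uploaded PDF. The document is chunked and indexed
# in a FAISS vector store, and a local seq2seq model answers questions
# through a RetrievalQA chain.
# Likely dependencies (inferred from the imports above): streamlit,
# langchain, langchain-community, transformers, torch, faiss-cpu,
# sentence-transformers, pdfminer.six, accelerate.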
st.title("DocChatAI | Chat over PDF Doc")
# Custom CSS for chat messages
st.markdown("""
<style>
.user-message {
    text-align: right;
    background-color: #3c8ce7;
    color: white;
    padding: 10px;
    border-radius: 10px;
    margin-bottom: 10px;
    display: inline-block;
    width: fit-content;
    max-width: 70%;
    margin-left: auto;
    box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
}
.assistant-message {
    text-align: left;
    background-color: #d16ba5;
    color: white;
    padding: 10px;
    border-radius: 10px;
    margin-bottom: 10px;
    display: inline-block;
    width: fit-content;
    max-width: 70%;
    margin-right: auto;
    box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
}
</style>
""", unsafe_allow_html=True)
def get_file_size(file):
    # Measure an uploaded file-like object by seeking to its end, then
    # restore the cursor so later reads start from the beginning.
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)
    return file_size
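
# NOTE: get_file_size is not called anywhere in the UI yet; it is kept as a
# helper, presumably for displaying upload details.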
# Add a sidebar for model selection and user details
st.sidebar.write("Settings")
st.sidebar.write("-----------")
model_options = ["MBZUAI/LaMini-T5-738M", "google/flan-t5-base", "google/flan-t5-small"]
selected_model = st.sidebar.radio("Choose Model", model_options)
st.sidebar.write("-----------")
uploaded_file = st.sidebar.file_uploader("Upload file", type=["pdf"])
st.sidebar.write("-----------")
st.sidebar.write("About Me")
st.sidebar.write("Name: Deepak Yadav")
st.sidebar.write("Bio: Passionate about AI and machine learning. Enjoys working on innovative projects and sharing knowledge with the community.")
st.sidebar.write("[GitHub](https://github.com/deepak7376)")
st.sidebar.write("[LinkedIn](https://www.linkedin.com/in/dky7376/)")
st.sidebar.write("-----------")
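# st.cache_resource builds the chain once per (filepath, checkpoint) pair and
# reuses it across Streamlit reruns, so the model and FAISS index are not
# rebuilt on every chat message.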
@st.cache_resource
def initialize_qa_chain(filepath, checkpoint):
    # Load the PDF and split its text into overlapping chunks for retrieval.
    # chunk_overlap must be smaller than chunk_size (the original value of
    # 500 equalled the chunk size and defeated the splitting); 100 keeps
    # some context shared between neighbouring chunks.
    loader = PDFMinerLoader(filepath)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    splits = text_splitter.split_documents(documents)

    # Create embeddings and index the chunks in a FAISS vector store
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    vectordb = FAISS.from_documents(splits, embeddings)

    # Initialize the seq2seq model and wrap it in a generation pipeline
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    base_model = AutoModelForSeq2SeqLM.from_pretrained(
        checkpoint,
        device_map="cpu",
        torch_dtype=torch.float32,
    )
    pipe = pipeline(
        "text2text-generation",
        model=base_model,
        tokenizer=tokenizer,
        max_length=256,
        do_sample=True,
        temperature=0.3,
        top_p=0.95,
    )
    llm = HuggingFacePipeline(pipeline=pipe)

    # Build a QA chain that stuffs the retrieved chunks into the prompt
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vectordb.as_retriever(),
    )
    return qa_chain
def process_answer(instruction, qa_chain):
    # RetrievalQA expects a dict with a "query" key and returns the answer
    # under "result"; invoke() replaces the deprecated Chain.run().
    generated_text = qa_chain.invoke({"query": instruction})["result"]
    return generated_text
if uploaded_file is not None:
    # Persist the upload to disk so the PDF loader can read it by path
    os.makedirs("docs", exist_ok=True)
    filepath = os.path.join("docs", uploaded_file.name)
    with open(filepath, "wb") as temp_file:
        temp_file.write(uploaded_file.read())
    with st.spinner("Creating embeddings..."):
        qa_chain = initialize_qa_chain(filepath, selected_model)
else:
    qa_chain = None
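
# qa_chain stays None until a PDF is uploaded; the chat handler below falls
# back to a help message in that case.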
# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []
# Display chat messages from history on app rerun
for message in st.session_state.messages:
    if message["role"] == "user":
        st.markdown(f"<div class='user-message'>{message['content']}</div>", unsafe_allow_html=True)
    else:
        st.markdown(f"<div class='assistant-message'>{message['content']}</div>", unsafe_allow_html=True)
# React to user input
if prompt := st.chat_input("Ask a question about the uploaded PDF"):
    # Display user message in chat message container
    st.markdown(f"<div class='user-message'>{prompt}</div>", unsafe_allow_html=True)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})

    if qa_chain:
        # Generate a response grounded in the retrieved PDF chunks
        response = process_answer(prompt, qa_chain)
    else:
        # Ask the user to upload a file first
        response = "Please upload a PDF file to enable the chatbot."

    # Display assistant response in chat message container
    st.markdown(f"<div class='assistant-message'>{response}</div>", unsafe_allow_html=True)
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})