Debyez commited on
Commit
1552f02
·
verified ·
1 Parent(s): fc2c5cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +179 -0
app.py CHANGED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from langchain import HuggingFacePipeline, PromptTemplate
4
+ from langchain.chains import RetrievalQA
5
+ from langchain.embeddings import HuggingFaceEmbeddings
6
+ from langchain.vectorstores import Chroma
7
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
9
+ import os
10
+ import re
11
+ import pickle
12
+ import fitz # PyMuPDF
13
+ from langchain.schema import Document
14
+ import langdetect
15
+
16
+ def clean_output(output: str) -> str:
17
+ print("Raw output:", output) # Debugging line
18
+ start_index = output.find('[/INST]') + len('[/INST]')
19
+ cleaned_output = output[start_index:].strip()
20
+ print("Cleaned output:", cleaned_output) # Debugging line
21
+ return cleaned_output
22
+
23
+ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
24
+
25
+ def split_text_into_paragraphs(text_content):
26
+ paragraphs = text_content.split('#')
27
+ return [paragraph.strip() for paragraph in paragraphs if paragraph.strip()]
28
+
29
+ def sanitize_filename(filename):
30
+ sanitized_name = re.sub(r'[^a-zA-Z0-9_-]', '_', filename)
31
+ return sanitized_name[:63]
32
+
33
+ def extract_text_from_pdf(pdf_path):
34
+ text_content = ''
35
+ with fitz.open(pdf_path) as pdf_document:
36
+ for page_num in range(len(pdf_document)):
37
+ page = pdf_document[page_num]
38
+ text_content += page.get_text()
39
+ return text_content
40
+
41
+ def detect_language(text):
42
+ try:
43
+ return langdetect.detect(text)
44
+ except:
45
+ return "en" # Default to English if detection fails
46
+
47
+ def process_pdf_file(filename, pdf_path, embeddings, llm, prompt):
48
+ print(f'\nProcessing: {pdf_path}')
49
+ text_content = extract_text_from_pdf(pdf_path)
50
+
51
+ language = detect_language(text_content)
52
+ print(f"Detected language: {language}")
53
+
54
+ paragraphs = split_text_into_paragraphs(text_content)
55
+ documents = [Document(page_content=paragraph, metadata={"language": language, "source": filename}) for paragraph in paragraphs]
56
+
57
+ print(f"Number of documents created: {len(documents)}")
58
+
59
+ collection_name = sanitize_filename(os.path.basename(filename))
60
+ db = Chroma.from_documents(documents, embeddings, collection_name=collection_name)
61
+ retriever = db.as_retriever(search_kwargs={"k": 2})
62
+ qa_chain = RetrievalQA.from_chain_type(
63
+ llm=llm,
64
+ chain_type="stuff",
65
+ retriever=retriever,
66
+ return_source_documents=True,
67
+ chain_type_kwargs={"prompt": prompt},
68
+ )
69
+
70
+ print(f"QA chain created for {filename}")
71
+ return qa_chain, language
72
+
73
+ SYSTEM_PROMPT = """
74
+ Use the provided context to answer the question clearly and concisely. Do not repeat the context in your answer.
75
+ """
76
+
77
+ def generate_prompt(prompt: str, system_prompt: str = SYSTEM_PROMPT) -> str:
78
+ return f"""
79
+ [INST] <>
80
+ {system_prompt}
81
+ <>
82
+
83
+ {prompt} [/INST]
84
+ """.strip()
85
+
86
+ def main():
87
+ # Streamlit UI
88
+ st.title("PDF-Powered Chatbot")
89
+
90
+ # File Uploader
91
+ uploaded_files = st.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True)
92
+
93
+ # Model Loading
94
+ model_pickle_path = '/kaggle/working/model.pkl'
95
+
96
+ if os.path.exists(model_pickle_path):
97
+ with open(model_pickle_path, 'rb') as f:
98
+ model, tokenizer = pickle.load(f)
99
+ else:
100
+ MODEL_NAME = "sarvamai/sarvam-2b-v0.5"
101
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
102
+ tokenizer.pad_token = tokenizer.eos_token
103
+
104
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
105
+ with open(model_pickle_path, 'wb') as f:
106
+ pickle.dump((model, tokenizer), f)
107
+
108
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
109
+
110
+ text_pipeline = pipeline(
111
+ "text-generation",
112
+ model=model,
113
+ tokenizer=tokenizer,
114
+ max_new_tokens=1024,
115
+ temperature=0.1,
116
+ top_p=0.95,
117
+ repetition_penalty=1.15,
118
+ device=DEVICE
119
+ )
120
+
121
+ llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 0})
122
+
123
+ template = generate_prompt(
124
+ """
125
+ {context}
126
+
127
+ Question: {question}
128
+ """,
129
+ system_prompt=SYSTEM_PROMPT,
130
+ )
131
+ prompt = PromptTemplate(template=template, input_variables=["context", "question"])
132
+
133
+ # Initialize QA chains dictionary
134
+ qa_chains = {}
135
+
136
+ # Process uploaded files
137
+ if uploaded_files:
138
+ with st.spinner("Processing PDFs..."):
139
+ for uploaded_file in uploaded_files:
140
+ file_path = uploaded_file.name # Use the filename directly
141
+ qa_chain, doc_language = process_pdf_file(uploaded_file.name, file_path, embeddings, llm, prompt)
142
+ qa_chains[doc_language] = (qa_chain, uploaded_file.name)
143
+
144
+ st.success("PDFs processed! You can now ask questions.")
145
+
146
+ # Chat interface
147
+ if st.button("Clear Chat History"):
148
+ st.session_state.chat_history = []
149
+
150
+ if "chat_history" not in st.session_state:
151
+ st.session_state.chat_history = []
152
+
153
+ for message in st.session_state.chat_history:
154
+ with st.chat_message(message["role"]):
155
+ st.markdown(message["content"])
156
+
157
+ if prompt := st.chat_input("Ask your question here"):
158
+ st.session_state.chat_history.append({"role": "user", "content": prompt})
159
+ with st.chat_message("user"):
160
+ st.markdown(prompt)
161
+
162
+ with st.spinner("Generating response..."):
163
+ query_language = detect_language(prompt)
164
+
165
+ if query_language in qa_chains:
166
+ qa_chain, _ = qa_chains[query_language]
167
+ result = qa_chain({"query": prompt})
168
+ cleaned_answer = clean_output(result['result'])
169
+
170
+ with st.chat_message("assistant"):
171
+ st.markdown(cleaned_answer)
172
+ st.session_state.chat_history.append({"role": "assistant", "content": cleaned_answer})
173
+ else:
174
+ with st.chat_message("assistant"):
175
+ st.markdown(f"No document available for the detected language: {query_language}")
176
+ st.session_state.chat_history.append({"role": "assistant", "content": f"No document available for the detected language: {query_language}"})
177
+
178
+ if __name__ == "__main__":
179
+ main()