Srikesh committed on
Commit dd765b2 · verified · 1 Parent(s): ec0c183

Update app.py

Files changed (1)
  app.py  +178 -62
app.py CHANGED
@@ -1,9 +1,13 @@
+import os
+os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '0'
+
 import gradio as gr
 from sentence_transformers import SentenceTransformer
 import numpy as np
 from pypdf import PdfReader
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import re

 # Global variables
 chunks = []
@@ -11,30 +15,72 @@ embeddings = []
 model = None
 tokenizer = None
 embed_model = None
+text_cache = ""

 def initialize_models():
-    """Initialize models on startup"""
+    """Initialize models on startup with optimizations"""
     global model, tokenizer, embed_model

     print("Loading models...")

-    # Load embedding model
-    embed_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+    # Use smaller, faster embedding model
+    embed_model = SentenceTransformer(
+        'sentence-transformers/paraphrase-MiniLM-L3-v2',  # Faster, smaller model
+        device='cpu'
+    )
+
+    # Use smaller, faster language model
+    model_name = "microsoft/phi-1_5"  # Much faster than TinyLlama, better quality
+    # Alternative: "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name,
+        trust_remote_code=True
+    )

-    # Load language model
-    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         torch_dtype=torch.float32,
-        low_cpu_mem_usage=True
+        low_cpu_mem_usage=True,
+        trust_remote_code=True
     )

+    # Set padding token
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
     print("Models loaded successfully!")

+def smart_chunk_text(text, chunk_size=500, overlap=100):
+    """Smarter chunking that respects sentence boundaries"""
+    # Split into sentences
+    sentences = re.split(r'[.!?]+', text)
+    chunks = []
+    current_chunk = ""
+
+    for sentence in sentences:
+        sentence = sentence.strip()
+        if not sentence:
+            continue
+
+        # If adding this sentence exceeds chunk size, save current chunk
+        if len(current_chunk) + len(sentence) > chunk_size and current_chunk:
+            chunks.append(current_chunk)
+            # Start new chunk with overlap
+            words = current_chunk.split()
+            current_chunk = " ".join(words[-20:]) + " " + sentence
+        else:
+            current_chunk += " " + sentence
+
+    # Add the last chunk
+    if current_chunk:
+        chunks.append(current_chunk.strip())
+
+    return chunks
+
 def process_pdf(pdf_file):
-    """Process PDF and create embeddings"""
-    global chunks, embeddings, embed_model
+    """Process PDF and create embeddings - OPTIMIZED"""
+    global chunks, embeddings, embed_model, text_cache

     if pdf_file is None:
         return "❌ Please upload a PDF file!", None
@@ -49,37 +95,44 @@ def process_pdf(pdf_file):
         if not text.strip():
             return "❌ Could not extract text from PDF!", None

-        # Split into chunks
-        chunk_size = 1000
-        overlap = 200
-        chunks = []
+        text_cache = text  # Cache for faster reprocessing

-        for i in range(0, len(text), chunk_size - overlap):
-            chunk = text[i:i + chunk_size]
-            if chunk.strip():
-                chunks.append(chunk)
+        # Smart chunking (smaller chunks = faster embedding)
+        chunks = smart_chunk_text(text, chunk_size=500, overlap=100)

-        # Create embeddings
-        embeddings = embed_model.encode(chunks, show_progress_bar=False)
+        # Batch encode for speed
+        print(f"Creating embeddings for {len(chunks)} chunks...")
+        embeddings = embed_model.encode(
+            chunks,
+            batch_size=32,  # Process multiple chunks at once
+            show_progress_bar=False,
+            convert_to_numpy=True
+        )

         return f"✅ PDF processed! Created {len(chunks)} chunks. You can now ask questions!", None

     except Exception as e:
+        print(f"Error processing PDF: {str(e)}")
         return f"❌ Error: {str(e)}", None

-def find_relevant_chunks(query, top_k=3):
-    """Find most relevant chunks using cosine similarity"""
+def find_relevant_chunks(query, top_k=2):  # Reduced from 3 to 2 for speed
+    """Find most relevant chunks - OPTIMIZED"""
     global chunks, embeddings, embed_model

-    if not chunks:
+    if not chunks or len(embeddings) == 0:
         return []

-    query_embedding = embed_model.encode([query])[0]
+    # Encode query
+    query_embedding = embed_model.encode(
+        [query],
+        convert_to_numpy=True,
+        show_progress_bar=False
+    )[0]

-    # Calculate cosine similarity
-    similarities = np.dot(embeddings, query_embedding) / (
-        np.linalg.norm(embeddings, axis=1) * np.linalg.norm(query_embedding)
-    )
+    # Fast cosine similarity using numpy
+    embeddings_norm = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
+    query_norm = query_embedding / np.linalg.norm(query_embedding)
+    similarities = np.dot(embeddings_norm, query_norm)

     # Get top k indices
     top_indices = np.argsort(similarities)[-top_k:][::-1]
@@ -87,42 +140,49 @@ def find_relevant_chunks(query, top_k=3):
     return [chunks[i] for i in top_indices]

 def generate_response(question, context):
-    """Generate response using the language model"""
+    """Generate response - OPTIMIZED"""
     global model, tokenizer

-    prompt = f"""<|system|>
-You are a helpful assistant. Answer the question based on the provided context. Be concise and accurate.
-</s>
-<|user|>
-Context: {context}
+    # Shorter, more efficient prompt
+    prompt = f"""Context: {context[:800]}

 Question: {question}
-</s>
-<|assistant|>
-"""
+
+Answer:"""

-    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
+    inputs = tokenizer(
+        prompt,
+        return_tensors="pt",
+        truncation=True,
+        max_length=1024  # Reduced from 2048
+    )

+    # Faster generation settings
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
-            max_new_tokens=300,
+            max_new_tokens=150,  # Reduced from 300
             temperature=0.7,
             top_p=0.9,
             do_sample=True,
-            pad_token_id=tokenizer.eos_token_id
+            pad_token_id=tokenizer.eos_token_id,
+            num_beams=1,  # Greedy search for speed
+            early_stopping=True
         )

     response = tokenizer.decode(outputs[0], skip_special_tokens=True)

-    # Extract only the assistant's response
-    if "<|assistant|>" in response:
-        response = response.split("<|assistant|>")[-1].strip()
+    # Extract answer
+    if "Answer:" in response:
+        response = response.split("Answer:")[-1].strip()
+
+    # Clean up response
+    response = response.split("\n")[0].strip()  # Take first line

     return response

 def chat(message, history):
-    """Handle chat"""
+    """Handle chat - OPTIMIZED"""
     global chunks

     if not chunks:
@@ -132,48 +192,103 @@ def chat(message, history):
         return history

     try:
-        # Find relevant context
-        relevant_chunks = find_relevant_chunks(message)
-        context = "\n\n".join(relevant_chunks)
+        # Find relevant context (reduced chunks)
+        relevant_chunks = find_relevant_chunks(message, top_k=2)
+        context = " ".join(relevant_chunks)

         # Generate response
         response = generate_response(message, context)

+        # Ensure response is not empty
+        if not response or len(response) < 10:
+            response = "I found relevant information but couldn't generate a clear answer. Please try rephrasing your question."
+
         return history + [[message, response]]

     except Exception as e:
+        print(f"Error in chat: {str(e)}")
         return history + [[message, f"❌ Error: {str(e)}"]]

 def clear_all():
     """Clear everything"""
-    global chunks, embeddings
+    global chunks, embeddings, text_cache
     chunks = []
     embeddings = []
+    text_cache = ""
     return None, "Ready to process a new PDF"

-# Create UI
-with gr.Blocks(title="Chat with PDF") as demo:
-    gr.Markdown("# 📄 Chat with PDF - Simple Version")
+# Create UI with better styling
+with gr.Blocks(title="Chat with PDF - Fast", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# ⚡ Chat with PDF - Optimized Fast Version")
+    gr.Markdown("*Using lightweight models for faster responses*")

     with gr.Row():
         with gr.Column(scale=1):
-            pdf_input = gr.File(label="📎 Upload PDF", file_types=[".pdf"])
-            process_btn = gr.Button("🔄 Process PDF", variant="primary")
-            status = gr.Textbox(label="Status", lines=3)
-            clear_all_btn = gr.Button("🗑️ Clear All")
+            pdf_input = gr.File(
+                label="📎 Upload PDF",
+                file_types=[".pdf"]
+            )
+            process_btn = gr.Button(
+                "🔄 Process PDF",
+                variant="primary",
+                size="lg"
+            )
+            status = gr.Textbox(
+                label="Status",
+                lines=2,
+                interactive=False
+            )
+
+            gr.Markdown("### Tips:")
+            gr.Markdown("""
+            - Processing is much faster now!
+            - Ask specific questions
+            - Keep questions concise
+            """)
+
+            clear_all_btn = gr.Button("🗑️ Clear All", variant="stop")

         with gr.Column(scale=2):
-            chatbot = gr.Chatbot(label="💬 Chat", height=400)
-            msg = gr.Textbox(label="Question", placeholder="Ask about the PDF...")
+            chatbot = gr.Chatbot(
+                label="💬 Chat",
+                height=450,
+                bubble_full_width=False
+            )
+            msg = gr.Textbox(
+                label="Question",
+                placeholder="Ask a question about the PDF...",
+                lines=2
+            )
             with gr.Row():
-                send_btn = gr.Button("Send", variant="primary")
+                send_btn = gr.Button("📤 Send", variant="primary")
                 clear_btn = gr.Button("Clear Chat")

     # Events
-    process_btn.click(process_pdf, [pdf_input], [status, chatbot])
+    process_btn.click(
+        process_pdf,
+        inputs=[pdf_input],
+        outputs=[status, chatbot]
+    )

-    msg.submit(chat, [msg, chatbot], [chatbot]).then(lambda: "", None, [msg])
-    send_btn.click(chat, [msg, chatbot], [chatbot]).then(lambda: "", None, [msg])
+    msg.submit(
+        chat,
+        inputs=[msg, chatbot],
+        outputs=[chatbot]
+    ).then(
+        lambda: "",
+        None,
+        [msg]
+    )
+
+    send_btn.click(
+        chat,
+        inputs=[msg, chatbot],
+        outputs=[chatbot]
+    ).then(
+        lambda: "",
+        None,
+        [msg]
+    )

     clear_btn.click(lambda: None, None, [chatbot])
     clear_all_btn.click(clear_all, None, [chatbot, status])
@@ -182,4 +297,5 @@ with gr.Blocks(title="Chat with PDF") as demo:
 initialize_models()

 if __name__ == "__main__":
-    demo.launch()
+    demo.queue()  # Enable queuing for better performance
+    demo.launch(share=False)
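
For reference, the retrieval step this commit introduces in find_relevant_chunks (normalize the stored embeddings once, take a single dot product against the normalized query, then pick the top-k indices) can be exercised in isolation. The sketch below is only an illustration, not part of the commit: it substitutes random NumPy vectors for real SentenceTransformer output, and the 384-dimensional embedding size is an assumption.

import numpy as np

# Stand-in data: four chunk texts with random 384-d embeddings
# (assumed dimension; any value works for the math shown here).
chunks = ["chunk a", "chunk b", "chunk c", "chunk d"]
embeddings = np.random.rand(len(chunks), 384)
query_embedding = np.random.rand(384)

# Normalize the rows and the query, then one matrix-vector product
# yields all cosine similarities at once, as in the updated diff.
embeddings_norm = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
query_norm = query_embedding / np.linalg.norm(query_embedding)
similarities = np.dot(embeddings_norm, query_norm)

# Top-k selection, mirroring the updated find_relevant_chunks
top_k = 2
top_indices = np.argsort(similarities)[-top_k:][::-1]
print([chunks[i] for i in top_indices])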