sofzcc committed on
Commit 72f5bc1 · verified · 1 Parent(s): 27759ba

Update app.py

Files changed (1):
  1. app.py +34 -18
app.py CHANGED
@@ -8,7 +8,7 @@ import faiss
 import numpy as np
 import gradio as gr
 from sentence_transformers import SentenceTransformer
-from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 from PyPDF2 import PdfReader
 import docx
 
@@ -40,7 +40,7 @@ def get_default_config():
         "models": {
             # Embedding model for FAISS
             "embedding": "sentence-transformers/all-MiniLM-L6-v2",
-            # Abstractive generation model (can upgrade to flan-t5-base if resources allow)
+            # Abstractive generation model
             "qa": "google/flan-t5-small",
         },
         "chunking": {
@@ -181,7 +181,8 @@ def load_kb_documents(kb_dir: str) -> List[Tuple[str, str]]:
 class RAGIndex:
     def __init__(self):
         self.embedder = None
-        self.qa_pipeline = None  # now a generative pipeline
+        self.qa_tokenizer = None
+        self.qa_model = None
         self.chunks: List[str] = []
         self.chunk_sources: List[str] = []
         self.index = None
@@ -203,13 +204,9 @@ class RAGIndex:
             print(f"Loading embedding model: {EMBEDDING_MODEL_NAME}")
             self.embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
 
-            print(f"Loading QA (generation) model: {QA_MODEL_NAME}")
-            # Abstractive generation pipeline (Flan-T5)
-            self.qa_pipeline = pipeline(
-                "text2text-generation",
-                model=QA_MODEL_NAME,
-                tokenizer=QA_MODEL_NAME,
-            )
+            print(f"Loading QA (seq2seq) model: {QA_MODEL_NAME}")
+            self.qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
+            self.qa_model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_NAME)
         except Exception as e:
             print(f"Error loading models: {e}")
             raise
@@ -329,6 +326,31 @@ class RAGIndex:
             print(f"Retrieval error: {e}")
             return []
 
+    def _generate_from_context(self, prompt: str) -> str:
+        """Run Flan-T5 on the given prompt and return the decoded answer."""
+        if self.qa_model is None or self.qa_tokenizer is None:
+            return "Model not loaded."
+
+        inputs = self.qa_tokenizer(
+            prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=768,
+        )
+
+        output_ids = self.qa_model.generate(
+            **inputs,
+            max_new_tokens=256,
+            do_sample=False,
+        )
+
+        answer = self.qa_tokenizer.decode(
+            output_ids[0],
+            skip_special_tokens=True,
+        ).strip()
+
+        return answer
+
     def answer(self, question: str) -> str:
         """Answer a question using RAG + abstractive generation"""
         if not self.initialized:
@@ -363,7 +385,7 @@ class RAGIndex:
 
         combined_text = "\n\n".join(combined_context)
 
-        # Limit context length to keep it manageable for the model
+        # Limit context length to keep it manageable
         max_context_chars = 4000
         if len(combined_text) > max_context_chars:
             combined_text = combined_text[:max_context_chars]
@@ -379,13 +401,7 @@ class RAGIndex:
         )
 
         try:
-            result = self.qa_pipeline(
-                prompt,
-                max_new_tokens=256,
-                do_sample=False,
-            )
-            # text2text-generation returns list of dicts with 'generated_text'
-            answer_text = result[0]["generated_text"].strip()
+            answer_text = self._generate_from_context(prompt)
         except Exception as e:
             print(f"Generation error: {e}")
             return (
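
For reference, a minimal standalone sketch of the generation path this commit introduces. It mirrors _generate_from_context from the diff; the prompt string below is a placeholder (the app builds the real prompt from retrieved context chunks), and it assumes transformers plus a PyTorch backend are installed so google/flan-t5-small can be loaded.

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

# Placeholder prompt; the app assembles this from retrieved chunks.
prompt = "Answer using the context.\n\nContext: ...\n\nQuestion: ..."

# Truncate the input to 768 tokens, as _generate_from_context does.
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=768)

# Deterministic decoding, capped at 256 new tokens.
output_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True).strip())

Compared with the previous pipeline("text2text-generation", ...) call, holding the tokenizer and model directly makes input-side truncation explicit (truncation=True, max_length=768), so oversized retrieval contexts are cut at tokenization time instead of relying on the pipeline's defaults.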