Update app.py
app.py CHANGED

@@ -8,7 +8,7 @@ import faiss
 import numpy as np
 import gradio as gr
 from sentence_transformers import SentenceTransformer
-from transformers import AutoTokenizer, pipeline
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 from PyPDF2 import PdfReader
 import docx
 
@@ -40,7 +40,7 @@ def get_default_config():
         "models": {
             # Embedding model for FAISS
             "embedding": "sentence-transformers/all-MiniLM-L6-v2",
-            # Abstractive generation model
+            # Abstractive generation model
             "qa": "google/flan-t5-small",
         },
         "chunking": {
@@ -181,7 +181,8 @@ def load_kb_documents(kb_dir: str) -> List[Tuple[str, str]]:
 class RAGIndex:
     def __init__(self):
         self.embedder = None
-        self.qa_pipeline = None
+        self.qa_tokenizer = None
+        self.qa_model = None
         self.chunks: List[str] = []
         self.chunk_sources: List[str] = []
         self.index = None
@@ -203,13 +204,9 @@ class RAGIndex:
             print(f"Loading embedding model: {EMBEDDING_MODEL_NAME}")
             self.embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
 
-            print(f"Loading QA (text2text-generation) pipeline: {QA_MODEL_NAME}")
-
-            self.qa_pipeline = pipeline(
-                "text2text-generation",
-                model=QA_MODEL_NAME,
-                tokenizer=QA_MODEL_NAME,
-            )
+            print(f"Loading QA (seq2seq) model: {QA_MODEL_NAME}")
+            self.qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
+            self.qa_model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_NAME)
         except Exception as e:
             print(f"Error loading models: {e}")
             raise
@@ -329,6 +326,31 @@ class RAGIndex:
             print(f"Retrieval error: {e}")
             return []
 
+    def _generate_from_context(self, prompt: str) -> str:
+        """Run Flan-T5 on the given prompt and return the decoded answer."""
+        if self.qa_model is None or self.qa_tokenizer is None:
+            return "Model not loaded."
+
+        inputs = self.qa_tokenizer(
+            prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=768,
+        )
+
+        output_ids = self.qa_model.generate(
+            **inputs,
+            max_new_tokens=256,
+            do_sample=False,
+        )
+
+        answer = self.qa_tokenizer.decode(
+            output_ids[0],
+            skip_special_tokens=True,
+        ).strip()
+
+        return answer
+
     def answer(self, question: str) -> str:
         """Answer a question using RAG + abstractive generation"""
         if not self.initialized:
@@ -363,7 +385,7 @@ class RAGIndex:
 
         combined_text = "\n\n".join(combined_context)
 
-        # Limit context length to keep it manageable
+        # Limit context length to keep it manageable
         max_context_chars = 4000
         if len(combined_text) > max_context_chars:
             combined_text = combined_text[:max_context_chars]
@@ -379,13 +401,7 @@ class RAGIndex:
         )
 
         try:
-            result = self.qa_pipeline(
-                prompt,
-                max_new_tokens=256,
-                do_sample=False,
-            )
-            # text2text-generation returns list of dicts with 'generated_text'
-            answer_text = result[0]["generated_text"].strip()
+            answer_text = self._generate_from_context(prompt)
         except Exception as e:
             print(f"Generation error: {e}")
             return (
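
For reference, the generation path this commit switches to can be exercised on its own. Below is a minimal standalone sketch of the same tokenize → generate → decode flow; the prompt template and example text are illustrative assumptions, since the app's actual prompt construction is not part of this diff.

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

QA_MODEL_NAME = "google/flan-t5-small"

tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_NAME)

# Illustrative retrieved context and question; the app's real prompt
# template is not shown in this diff.
context = "FAISS is a library for efficient similarity search over dense embedding vectors."
question = "What is FAISS used for?"
prompt = f"Answer the question using only the context.\n\nContext: {context}\n\nQuestion: {question}"

# Tokenize with truncation so an overlong retrieved context cannot
# exceed the encoder's input budget (the diff caps it at 768 tokens).
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=768)

# Greedy decoding capped at 256 new tokens, matching the diff's settings.
output_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)

answer = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
print(answer)

A likely motivation for the change: compared with the removed pipeline("text2text-generation", ...) call, the explicit tokenizer step gives the app direct control over input truncation, clipping overlong retrieved context to a fixed token budget before generation rather than relying on the pipeline's defaults.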