# testing / model.py
# Last change by samim2024: "Update model.py" (commit 2cfdb3c, verified)
import functools
import hashlib
import logging
import os
import tempfile

import PyPDF2
from langchain.chains import RetrievalQA
from langchain.docstore.document import Document
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFaceEndpoint
# Send the Hugging Face / transformers caches to the system temp directory
# (tempfile.gettempdir(): TMPDIR/TEMP/TMP if set, else a platform default
# such as /tmp).
# NOTE(review): these variables are assigned after the imports above; any
# library that reads HF_HOME / TRANSFORMERS_CACHE at import time will not
# see them — confirm whether they need to be set before importing.
CACHE_DIR = tempfile.gettempdir()
os.environ.update({
    "TRANSFORMERS_CACHE": CACHE_DIR,
    "HF_HOME": CACHE_DIR,
})

# Fixed application paths.
DATA_PATH = "/app/data"
VECTORSTORE_PATH = "/app/vectorstore"

# Sentence-transformers model used to embed the PDF chunks.
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-MiniLM-L6-v2"
@functools.lru_cache(maxsize=4)
def _cached_embeddings(model_name):
    """Build the HuggingFaceEmbeddings for ``model_name`` once and reuse it."""
    return HuggingFaceEmbeddings(model_name=model_name)


def load_embedding_model(model_name=None):
    """Return the sentence-transformer embeddings used for indexing and search.

    ``ask_question`` calls ``load_vectorstore``, which calls this function, on
    every request. Constructing ``HuggingFaceEmbeddings`` loads the underlying
    model, so one instance per model name is cached and reused. All callers
    share the returned object and should not mutate it.

    Args:
        model_name: Hugging Face model id. Defaults to ``EMBEDDING_MODEL_NAME``.

    Returns:
        A cached ``HuggingFaceEmbeddings`` instance.
    """
    # Resolve the default before hitting the cache so that calls with and
    # without an explicit default name share a single cache entry.
    if model_name is None:
        model_name = EMBEDDING_MODEL_NAME
    return _cached_embeddings(model_name)
def load_documents(pdf_path, *, chunk_size=1000, chunk_overlap=100):
    """Extract the text of a PDF and split it into LangChain documents.

    Args:
        pdf_path: Path of the PDF file to read.
        chunk_size: ``chunk_size`` passed to ``CharacterTextSplitter``.
        chunk_overlap: ``chunk_overlap`` passed to ``CharacterTextSplitter``.

    Returns:
        The documents produced by ``CharacterTextSplitter.create_documents``.

    Raises:
        ValueError: ``"Failed to process PDF: ..."`` if the file cannot be
            opened or parsed, or if no text could be extracted. The original
            exception is chained as ``__cause__``.
    """
    try:
        with open(pdf_path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            # Separate pages with a blank line. Joining with "" glued the last
            # word of one page to the first word of the next. It also left the
            # splitter (which splits on "\n\n" by default) without a split
            # point at page boundaries.
            text = "\n\n".join(page.extract_text() or "" for page in reader.pages)
        if not text.strip():
            raise ValueError("No text extracted from PDF")
        splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        return splitter.create_documents([text])
    except Exception as e:
        # Normalize every failure to ValueError for callers, keeping the cause.
        raise ValueError(f"Failed to process PDF: {str(e)}") from e
def _pdf_fingerprint(pdf_path):
    """Return the SHA-256 hex digest of the file's bytes, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(pdf_path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()


def load_vectorstore(pdf_path):
    """Load the FAISS vector store for ``pdf_path``, building it if necessary.

    The index is saved under ``VECTORSTORE_PATH`` in a directory named after
    the SHA-256 of the PDF's bytes. A different or edited PDF therefore gets
    its own index instead of silently reusing the index of another file.
    NOTE(review): one index directory is kept per distinct PDF; nothing
    prunes old ones.

    Args:
        pdf_path: Path of the PDF to index or look up.

    Returns:
        A FAISS vector store over the PDF's chunks.

    Raises:
        ValueError: ``"Failed to process PDF: ..."`` if the PDF cannot be read.
            Also raised when ``load_documents`` fails while building the index.
    """
    # Hash first so an unreadable file fails fast, before the embedding model
    # is loaded. The message format matches load_documents' ValueError.
    try:
        fingerprint = _pdf_fingerprint(pdf_path)
    except (OSError, TypeError, ValueError) as e:
        raise ValueError(f"Failed to process PDF: {e}") from e

    vectorstore_dir = os.path.join(VECTORSTORE_PATH, f"faiss_index_{fingerprint}")
    embedding_model = load_embedding_model()

    if os.path.exists(vectorstore_dir):
        try:
            # SECURITY: load_local deserializes a pickle stored next to the
            # index. This is acceptable only because this module wrote the
            # directory itself; never point it at an index from an untrusted
            # source.
            return FAISS.load_local(
                vectorstore_dir, embedding_model, allow_dangerous_deserialization=True
            )
        except Exception:
            # Best-effort cache: a corrupt or incompatible index is rebuilt
            # below, but the failure is logged rather than silently swallowed.
            logging.getLogger(__name__).warning(
                "Failed to load FAISS index from %s; rebuilding",
                vectorstore_dir,
                exc_info=True,
            )

    docs = load_documents(pdf_path)
    vectorstore = FAISS.from_documents(docs, embedding_model)
    vectorstore.save_local(vectorstore_dir)
    return vectorstore
def ask_question(query, pdf_path, *, top_k=3):
    """Answer ``query`` with retrieval-augmented generation over ``pdf_path``.

    Retrieves the ``top_k`` most similar chunks from the PDF's FAISS index and
    passes them to the Mistral-7B-Instruct endpoint through a RetrievalQA
    chain.

    Args:
        query: The question to ask. Must be a non-blank string.
        pdf_path: PDF whose vector store (see ``load_vectorstore``) is searched.
        top_k: Number of chunks retrieved as context (default 3).

    Returns:
        dict with ``"answer"`` (the chain's ``"result"``) and ``"contexts"``
        (the ``page_content`` of each retrieved source document).

    Raises:
        ValueError: If ``HUGGINGFACEHUB_API_TOKEN`` is not set, if ``query``
            is blank or not a string, or if the PDF cannot be processed.
    """
    api_key = os.getenv("HUGGINGFACEHUB_API_TOKEN")
    if not api_key:
        raise ValueError("HUGGINGFACEHUB_API_TOKEN not set")
    # Reject blank questions before paying for index loading and an LLM call.
    if not isinstance(query, str) or not query.strip():
        raise ValueError("query must be a non-empty string")

    vectorstore = load_vectorstore(pdf_path)
    llm = HuggingFaceEndpoint(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
        huggingfacehub_api_token=api_key,
        temperature=0.5,
        max_new_tokens=256,
    )
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=vectorstore.as_retriever(search_kwargs={"k": top_k}),
        return_source_documents=True,
    )
    # Chain.invoke replaces the deprecated Chain.__call__ (qa({...})).
    result = qa.invoke({"query": query})
    return {
        "answer": result["result"],
        "contexts": [doc.page_content for doc in result["source_documents"]],
    }