# app.py
from fastapi import FastAPI, UploadFile, File
from pydantic import BaseModel
from typing import List
import fitz  # PyMuPDF
from transformers import pipeline
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document as LangchainDocument
# --- Init FastAPI ---
app = FastAPI()
# --- Summarizer ---
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
# --- Question Answering ---
qa_pipe = pipeline("question-answering", model="deepset/roberta-base-squad2")
# --- Embedding model ---
embedding_model = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-small-en-v1.5")
# --- Text Splitter ---
text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
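# NOTE: chunk_size / chunk_overlap are measured in characters (the splitter's
# default length function), not tokens.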
# --- Pydantic schemas ---
class Summary(BaseModel):
    summary: str

class KeyPoint(BaseModel):
    point: str

class DocumentAnalysis(BaseModel):
    summary: Summary
    key_points: List[KeyPoint]

class QARequest(BaseModel):
    question: str
    context: str

class QAResponse(BaseModel):
    answer: str
# --- PDF Text Extractor ---
def extract_text_from_pdf(pdf_file: UploadFile) -> str:
    text = ""
    with fitz.open(stream=pdf_file.file.read(), filetype="pdf") as doc:
        for page in doc:
            text += page.get_text()
    return text
# --- Analyze Text (summarization) ---
def analyze_text_structured(text: str) -> DocumentAnalysis:
    chunks = text_splitter.split_text(text)
    summaries = []
    for chunk in chunks:
        result = summarizer(chunk, max_length=200, min_length=50, do_sample=False)
        if result:
            summaries.append(result[0]["summary_text"])
    full_summary = " ".join(summaries)
    # Naive key-point extraction: treat each sentence of the merged summary as one point
    key_points = [KeyPoint(point=line.strip()) for line in full_summary.split(". ") if line.strip()]
    return DocumentAnalysis(summary=Summary(summary=full_summary), key_points=key_points)
# --- Question Answering ---
def answer_question(question: str, context: str) -> str:
    result = qa_pipe(question=question, context=context)
    return result["answer"]
# --- PDF Upload + Analysis Route ---
@app.post("/analyze-pdf", response_model=DocumentAnalysis)
async def analyze_pdf(file: UploadFile = File(...)):
    text = extract_text_from_pdf(file)
    analysis = analyze_text_structured(text)
    return analysis
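# Example request (a sketch: assumes the server runs locally on port 8000 and a
# sample.pdf exists in the working directory):
#   curl -X POST "http://localhost:8000/analyze-pdf" -F "file=@sample.pdf"
# Returns JSON shaped like:
#   {"summary": {"summary": "..."}, "key_points": [{"point": "..."}]}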
# --- Question Answering Route ---
@app.post("/qa", response_model=QAResponse)
async def ask_question(qa_request: QARequest):
    answer = answer_question(qa_request.question, qa_request.context)
    return QAResponse(answer=answer)
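# Example request (a sketch: assumes a local server on port 8000; question and
# context are illustrative):
#   curl -X POST "http://localhost:8000/qa" \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is the deadline?", "context": "The deadline is May 1."}'
# Returns JSON like: {"answer": "May 1"}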
# --- Embedding Search (FAISS) Demo ---
@app.post("/search-chunks")
async def search_chunks(file: UploadFile = File(...), query: str = ""):
    text = extract_text_from_pdf(file)
    chunks = text_splitter.split_text(text)
    documents = [LangchainDocument(page_content=chunk) for chunk in chunks]
    # Create FAISS vector store
    db = FAISS.from_documents(documents, embedding_model)
    # Similarity search
    results = db.similarity_search(query, k=3)
    return {"results": [doc.page_content for doc in results]}