# PoemTest / app_recent.py
import os
from typing import Iterator, Iterable
from tempfile import TemporaryDirectory
from dotenv import load_dotenv
import gradio as gr
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document as LCDocument
from docling.document_converter import DocumentConverter
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_milvus import Milvus
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
class DoclingPDFLoader(BaseLoader):
    """LangChain loader that uses Docling to convert source files to Markdown."""

    def __init__(self, file_paths: str | list[str]) -> None:
        self._file_paths = file_paths if isinstance(file_paths, list) else [file_paths]
        self._converter = DocumentConverter()

    def lazy_load(self) -> Iterator[LCDocument]:
        # Convert each source file and yield its Markdown export as one document
        for source in self._file_paths:
            dl_doc = self._converter.convert(source).document
            text = dl_doc.export_to_markdown()
            yield LCDocument(page_content=text)

    def load(self) -> list[LCDocument]:
        return list(self.lazy_load())
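# Usage sketch (illustrative; the loader yields one LCDocument per source file):
#   loader = DoclingPDFLoader("vol1.txt")
#   for doc in loader.lazy_load():
#       print(doc.page_content[:100])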
# Load environment variables
load_dotenv()
# File paths configuration
FILE_PATHS = ["vol1.txt", "vol2.txt", "vol3.txt", "vol4.txt", "vol5.txt"]
# Load and split documents
loader = DoclingPDFLoader(file_paths=FILE_PATHS)
docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200,
)
# Split the loaded documents (not the raw file paths) into overlapping chunks
splits = text_splitter.split_documents(docs)
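# Optional sanity check (a sketch, not required): with chunk_size=1000 and
# chunk_overlap=200, consecutive chunks share up to 200 characters of context.
#   print(f"Loaded {len(docs)} documents -> {len(splits)} chunks")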
# Set up embeddings
HF_EMBED_MODEL_ID = "BAAI/bge-small-en-v1.5"
embeddings = HuggingFaceEmbeddings(model_name=HF_EMBED_MODEL_ID)
# Configure Milvus. The walrus expression keeps a module-level reference to the
# TemporaryDirectory so the fallback local database file is not cleaned up early.
# Set MILVUS_URI to point at a running Milvus server to persist the index instead.
MILVUS_URI = os.environ.get(
    "MILVUS_URI", f"{(tmp_dir := TemporaryDirectory()).name}/milvus_demo.db"
)
# Initialize vector store; drop_old=True rebuilds the collection on every startup
vectorstore = Milvus.from_documents(
    splits,
    embeddings,
    connection_args={"uri": MILVUS_URI},
    drop_old=True,
    index_params={
        "index_type": "IVF_FLAT",
        "metric_type": "L2",
        "params": {"nlist": 100},
    },
)
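# Optional sanity check (sketch): query Milvus directly before wiring the full
# RAG chain; the query text here is an arbitrary example.
#   hits = vectorstore.similarity_search("test query", k=3)
#   for hit in hits:
#       print(hit.page_content[:120])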
# Set up LLM (requires a Hugging Face token in HF_TOKEN, loaded from .env above)
HF_API_KEY = os.environ.get("HF_TOKEN")
HF_LLM_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
llm = HuggingFaceEndpoint(
    repo_id=HF_LLM_MODEL_ID,
    huggingfacehub_api_token=HF_API_KEY,
)
# FastAPI setup
app = FastAPI()
class QueryRequest(BaseModel):
    question: str

def format_docs(docs: Iterable[LCDocument]) -> str:
    # Join retrieved chunks into a single context string for the prompt
    return "\n\n".join(doc.page_content for doc in docs)
# Create RAG chain
retriever = vectorstore.as_retriever()
prompt = PromptTemplate.from_template(
    "Context information is below.\n"
    "---------------------\n"
    "{context}\n"
    "---------------------\n"
    "Given the context information and not prior knowledge, answer the query.\n"
    "Query: {question}\n"
    "Answer:\n"
)
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
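# The chain can be exercised directly (sketch; the question is an arbitrary example):
#   print(rag_chain.invoke("What happens in volume 1?"))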
@app.post("/query")
async def query_documents(request: QueryRequest):
    """Run the RAG chain on the submitted question and return the answer."""
    try:
        response = rag_chain.invoke(request.question)
        return {"answer": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
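# Example request against this endpoint (sketch, assuming the uvicorn port below):
#   curl -X POST http://localhost:8000/query \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What happens in volume 1?"}'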
# Gradio interface
def ask_question(question: str) -> str:
    return rag_chain.invoke(question)
if __name__ == "__main__":
    # Launch Gradio without blocking so the FastAPI server below can also start;
    # a plain launch() would block the main thread and uvicorn would never run
    iface = gr.Interface(fn=ask_question, inputs="text", outputs="text")
    iface.launch(server_name="0.0.0.0", server_port=7860, prevent_thread_lock=True)
    # FastAPI runs on a different port
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)