import os
from typing import Iterator, Iterable
from tempfile import TemporaryDirectory

from dotenv import load_dotenv
import gradio as gr
import uvicorn
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document as LCDocument
from docling.document_converter import DocumentConverter
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_milvus import Milvus
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
class DoclingPDFLoader(BaseLoader):
    """Load files with Docling and expose them as LangChain documents."""

    def __init__(self, file_paths: str | list[str]) -> None:
        self._file_paths = file_paths if isinstance(file_paths, list) else [file_paths]
        self._converter = DocumentConverter()

    def lazy_load(self) -> Iterator[LCDocument]:
        # Convert each source file with Docling and export it as Markdown text.
        for source in self._file_paths:
            dl_doc = self._converter.convert(source).document
            text = dl_doc.export_to_markdown()
            yield LCDocument(page_content=text)

    def load(self) -> list[LCDocument]:
        return list(self.lazy_load())
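# Usage sketch (illustrative only; "paper.pdf" is a placeholder path): the
# loader accepts a single path or a list of paths, and lazy_load() yields
# one LangChain Document per converted source.
#
#     docs = DoclingPDFLoader("paper.pdf").load()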
# Load environment variables
load_dotenv()

# File paths configuration
FILE_PATHS = ["vol1.txt", "vol2.txt", "vol3.txt", "vol4.txt", "vol5.txt"]

# Load and split documents
loader = DoclingPDFLoader(file_paths=FILE_PATHS)
docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200,
)
# Split the loaded documents, not the raw file paths.
splits = text_splitter.split_documents(docs)
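# Illustrative sanity check (not required at runtime): each split carries at
# most ~1000 characters, with 200 characters shared between neighbours.
#
#     print(f"{len(splits)} chunks, first: {len(splits[0].page_content)} chars")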
# Set up embeddings
HF_EMBED_MODEL_ID = "BAAI/bge-small-en-v1.5"
embeddings = HuggingFaceEmbeddings(model_name=HF_EMBED_MODEL_ID)

# Configure Milvus: default to a local database file in a temporary
# directory. The walrus assignment keeps the TemporaryDirectory object
# referenced at module level so the directory is not cleaned up early.
MILVUS_URI = os.environ.get(
    "MILVUS_URI", f"{(tmp_dir := TemporaryDirectory()).name}/milvus_demo.db"
)
# Initialize vector store
vectorstore = Milvus.from_documents(
    splits,
    embeddings,
    connection_args={"uri": MILVUS_URI},
    drop_old=True,  # rebuild the collection from scratch on every start
    index_params={
        "index_type": "IVF_FLAT",
        "metric_type": "L2",
        "params": {"nlist": 100},
    },
)
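# Illustrative only (the query string is a placeholder): the populated store
# can be queried directly, before any chain is built.
#
#     hits = vectorstore.similarity_search("example query", k=4)
#     print(hits[0].page_content[:200])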
# Set up LLM
HF_API_KEY = os.environ.get("HF_TOKEN")
HF_LLM_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
llm = HuggingFaceEndpoint(
    repo_id=HF_LLM_MODEL_ID,
    huggingfacehub_api_token=HF_API_KEY,
)
# FastAPI setup
app = FastAPI()

class QueryRequest(BaseModel):
    question: str

def format_docs(docs: Iterable[LCDocument]) -> str:
    # Join retrieved chunks into a single context string for the prompt.
    return "\n\n".join(doc.page_content for doc in docs)
# Create RAG chain
retriever = vectorstore.as_retriever()
prompt = PromptTemplate.from_template(
    "Context information is below.\n"
    "---------------------\n"
    "{context}\n"
    "---------------------\n"
    "Given the context information and not prior knowledge, answer the query.\n"
    "Query: {question}\n"
    "Answer:\n"
)
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
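# Illustrative invocation (the question is a placeholder): the retriever
# fills {context}, RunnablePassthrough forwards the raw question, and the
# parsed model output comes back as a plain string.
#
#     answer = rag_chain.invoke("What topics does volume 1 cover?")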
# Register the handler on the app; without a route decorator the function is
# never exposed. The "/query" path is a reasonable choice, not mandated by
# the original code.
@app.post("/query")
async def query_documents(request: QueryRequest):
    try:
        response = rag_chain.invoke(request.question)
        return {"answer": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
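# Example request once uvicorn is running (values are placeholders):
#
#     curl -X POST http://localhost:8000/query \
#          -H "Content-Type: application/json" \
#          -d '{"question": "What does volume 2 discuss?"}'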
# Gradio interface
def ask_question(question: str) -> str:
    return rag_chain.invoke(question)
if __name__ == "__main__":
    # Launch Gradio without blocking so the FastAPI server can start too; by
    # default iface.launch() blocks and uvicorn.run() would never execute.
    iface = gr.Interface(fn=ask_question, inputs="text", outputs="text")
    iface.launch(server_name="0.0.0.0", server_port=7860, prevent_thread_lock=True)
    # FastAPI runs on a different port
    uvicorn.run(app, host="0.0.0.0", port=8000)