import os
from tempfile import TemporaryDirectory
from typing import Iterable, Iterator

import gradio as gr
from docling.document_converter import DocumentConverter
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document as LCDocument
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface import HuggingFaceEndpoint
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_milvus import Milvus
from pydantic import BaseModel

class DoclingPDFLoader(BaseLoader):
    """Load files via Docling and expose them to LangChain as Markdown documents."""

    def __init__(self, file_paths: str | list[str]) -> None:
        self._file_paths = file_paths if isinstance(file_paths, list) else [file_paths]
        self._converter = DocumentConverter()

    def lazy_load(self) -> Iterator[LCDocument]:
        for source in self._file_paths:
            # Convert each source with Docling and export the result as Markdown
            dl_doc = self._converter.convert(source).document
            text = dl_doc.export_to_markdown()
            yield LCDocument(page_content=text)

    def load(self) -> list[LCDocument]:
        return list(self.lazy_load())
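
# Illustrative standalone use of the loader (the file name is made up):
#   pages = DoclingPDFLoader("annual_report.pdf").load()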

# Load environment variables
load_dotenv()

# File paths configuration
FILE_PATHS = ["vol1.txt", "vol2.txt", "vol3.txt", "vol4.txt", "vol5.txt"]

# Load and split documents
loader = DoclingPDFLoader(file_paths=FILE_PATHS)
docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200,
)
# Split the loaded documents (not the raw file paths) into overlapping chunks
splits = text_splitter.split_documents(docs)
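# Illustrative sanity check that chunking produced output:
#   print(f"{len(docs)} documents -> {len(splits)} chunks")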

# Set up embeddings
HF_EMBED_MODEL_ID = "BAAI/bge-small-en-v1.5"
embeddings = HuggingFaceEmbeddings(model_name=HF_EMBED_MODEL_ID)

# Configure Milvus; by default, use an ephemeral Milvus Lite database file.
# The walrus assignment keeps the TemporaryDirectory object referenced so the
# directory is not cleaned up while the process runs.
MILVUS_URI = os.environ.get(
    "MILVUS_URI", f"{(tmp_dir := TemporaryDirectory()).name}/milvus_demo.db"
)
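# To target a persistent Milvus server instead, set the variable before launch
# (illustrative; adjust host/port to your deployment):
#   export MILVUS_URI="http://localhost:19530"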

# Initialize vector store
vectorstore = Milvus.from_documents(
    splits,
    embeddings,
    connection_args={"uri": MILVUS_URI},
    drop_old=True,
    index_params={
        "index_type": "IVF_FLAT",
        "metric_type": "L2",
        "params": {"nlist": 100},
    },
)
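
# Quick retrieval sanity check (illustrative; the query string is made up):
#   hits = vectorstore.similarity_search("What topics does volume 1 cover?", k=3)
#   print(hits[0].page_content[:200])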

# Set up LLM (requires a valid HF_TOKEN for the Hugging Face Inference API)
HF_API_KEY = os.environ.get("HF_TOKEN")
HF_LLM_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
llm = HuggingFaceEndpoint(
    repo_id=HF_LLM_MODEL_ID,
    huggingfacehub_api_token=HF_API_KEY,
)
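# Generation behavior can be tuned via keyword arguments to HuggingFaceEndpoint,
# e.g. (illustrative values): max_new_tokens=512, temperature=0.1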

# FastAPI setup
app = FastAPI()


class QueryRequest(BaseModel):
    question: str


def format_docs(docs: Iterable[LCDocument]) -> str:
    return "\n\n".join(doc.page_content for doc in docs)

# Create RAG chain
retriever = vectorstore.as_retriever()
prompt = PromptTemplate.from_template(
    "Context information is below.\n"
    "---------------------\n"
    "{context}\n"
    "---------------------\n"
    "Given the context information and not prior knowledge, answer the query.\n"
    "Query: {question}\n"
    "Answer:\n"
)
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
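
# Illustrative direct invocation of the chain (the question is made up):
#   answer = rag_chain.invoke("Summarize volume 2 in one paragraph.")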

# Register the route so FastAPI actually serves this handler; the original
# omitted the decorator, and the "/query" path is an assumed name.
@app.post("/query")
async def query_documents(request: QueryRequest):
    try:
        response = rag_chain.invoke(request.question)
        return {"answer": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
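
# Example request against the endpoint above (assumes the /query route name):
#   curl -X POST http://localhost:8000/query \
#     -H "Content-Type: application/json" \
#     -d '{"question": "What does volume 3 discuss?"}'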

# Gradio interface
def ask_question(question):
    return rag_chain.invoke(question)

if __name__ == "__main__":
    import uvicorn

    # Launch Gradio without blocking the main thread; otherwise the
    # uvicorn.run() call below would never be reached.
    iface = gr.Interface(fn=ask_question, inputs="text", outputs="text")
    iface.launch(server_name="0.0.0.0", server_port=7860, prevent_thread_lock=True)
    # FastAPI runs on a different port
    uvicorn.run(app, host="0.0.0.0", port=8000)