import os
import logging
import time
import asyncio
from fastapi import status
from langchain_groq import ChatGroq
from langchain.schema import Document
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from app.core.template import prompt_template_description
# Embedding model is loaded once at import time and reused for every request
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)
# Async PDF loader
async def pdf_loader(url: str):
    pages = []
    loader = PyPDFLoader(url)
    async for page in loader.alazy_load():
        pages.append(page)
    return pages
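# A minimal usage sketch (assumption: any PDF URL reachable by PyPDFLoader works;
# the URL below is a placeholder, not part of this project):
#
#     pages = await pdf_loader("https://example.com/sample.pdf")
#     print(f"Loaded {len(pages)} pages")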
# Main function to create/load vectorstore
async def load_and_create_vector_store(url: str):
    """
    Loads a PDF document from a URL and either reuses or builds a FAISS vectorstore.
    Returns a retriever object.
    """
    vectorstore_path = "/tmp/database/faiss_index"
    if os.path.exists(f"{vectorstore_path}/index.faiss"):
        logging.info("Vector store already exists, loading it.")
        vectorstore = FAISS.load_local(
            vectorstore_path, embeddings, allow_dangerous_deserialization=True
        )
    else:
        logging.info("Vector store not found. Creating new one from document.")
        pages = await pdf_loader(url)
        if not pages:
            raise ValueError("No pages loaded from the document.")
        full_text = "\n\n".join([page.page_content for page in pages])
        documents = [Document(page_content=full_text, metadata={"source": url})]
        # Split on paragraph breaks; large chunks with generous overlap keep
        # related clauses together in a single chunk for retrieval.
        text_splitter = CharacterTextSplitter(
            separator="\n\n",
            chunk_size=2500,
            chunk_overlap=300,
            length_function=len,
        )
        split_docs = text_splitter.split_documents(documents)
        logging.info(f"Document split into {len(split_docs)} chunks")
        vectorstore = FAISS.from_documents(split_docs, embeddings)
        vectorstore.save_local(vectorstore_path)
    return vectorstore.as_retriever(
        search_kwargs={"k": 2, "score_threshold": 0.5}
    )
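# A minimal sketch of exercising the retriever on its own, assuming a placeholder
# PDF URL and the standard LangChain retriever interface (retrievers are Runnables,
# so invoke() returns the matching documents); not part of the request flow above:
#
#     retriever = await load_and_create_vector_store("https://example.com/sample.pdf")
#     docs = retriever.invoke("What is the grace period for premium payment?")
#     for doc in docs:
#         print(doc.metadata.get("source"), doc.page_content[:200])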
async def llm_setup(config, url):
    """
    Set up the LLM for question answering.
    This function initializes the LLM with the necessary configuration
    for processing questions and generating answers based on the retrieved context.
    Args:
        config: Configuration dictionary with LLM settings
        url: URL of the document to process
    Returns:
        object: The configured RetrievalQA chain.
    """
    llm = ChatGroq(
        model=config.get("MODEL_NAME"),
        temperature=config.get("TEMPERATURE", 0),
        max_tokens=config.get("MAX_TOKENS", 300),  # generous token limit for JSON responses
        max_retries=config.get("MAX_RETRIES", 3),
        api_key=os.getenv("GROQ_KEY"),
    )
    logging.info(f"LLM initialized with model: {config.get('MODEL_NAME')}")
    # Prompt template for answering from the retrieved context
    prompt_template = prompt_template_description()
    retriever = await load_and_create_vector_store(url=url)
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs={"prompt": prompt_template},
    )
    return qa_chain
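# A sketch of the config dict llm_setup expects, inferred from the config.get()
# calls above; the model name is illustrative only, not a value this project pins:
#
#     config = {
#         "MODEL_NAME": "llama-3.1-8b-instant",
#         "TEMPERATURE": 0,
#         "MAX_TOKENS": 300,
#         "MAX_RETRIES": 3,
#     }
#     qa_chain = await llm_setup(config, url="https://example.com/sample.pdf")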
async def llm_response_generator(config, url, questions):
    """
    Generate answers from the LLM within an overall 30-second budget.
    Args:
        config: Configuration dictionary with LLM settings
        url: URL of the document to process
        questions: List of questions to answer
    Returns:
        Tuple of (response dict, status code)
    """
    try:
        start = time.time()
        qa_chain = await llm_setup(config, url)
        answers = []
        for question in questions:
            elapsed = time.time() - start
            if elapsed > 28:  # leave a safety margin below the 30-second budget
                logging.warning("Time limit reached, skipping remaining questions.")
                break
            try:
                answer = await qa_chain.arun(question)
                answers.append(answer)
            except Exception as e:
                logging.error(f"Error answering: {question} | {e}")
                answers.append("Error processing question.")
        return {"answers": answers}, status.HTTP_200_OK
    except Exception as e:
        logging.error(f"Error in llm_response_generator: {e}")
        return {"answers": []}, status.HTTP_500_INTERNAL_SERVER_ERROR
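if __name__ == "__main__":
    # Minimal smoke test for running this module directly. Assumptions: GROQ_KEY is
    # set in the environment, the PDF URL is a reachable placeholder, and the config
    # values below are illustrative rather than the deployed app's settings.
    demo_config = {
        "MODEL_NAME": "llama-3.1-8b-instant",
        "TEMPERATURE": 0,
        "MAX_TOKENS": 300,
        "MAX_RETRIES": 3,
    }
    demo_questions = ["What does this document cover?"]
    result, code = asyncio.run(
        llm_response_generator(demo_config, "https://example.com/sample.pdf", demo_questions)
    )
    print(code, result)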