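"""FastAPI backend for a retrieval-augmented question-answering service.

Loads a FAISS vector store and a local Llama-2 chat model via CTransformers,
answers questions through a /query endpoint with source citations, and
launches the Streamlit frontend on startup.
"""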
import re
import logging
import asyncio
import subprocess

from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.llms import CTransformers
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from dotenv import load_dotenv

# Load environment variables
load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
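
# Sentence-transformers MiniLM embeddings on CPU; this must be the same model
# that was used to build the FAISS index, or retrieval results will be garbage.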
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
)
try:
    db = FAISS.load_local("vectorstore/db_faiss", embeddings, allow_dangerous_deserialization=True)
    logger.info("Vector database loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load vector database: {e}")
    raise
try:
    # CTransformers pulls the GGML weights from the Hugging Face Hub on first
    # run and executes them on CPU.
    llm = CTransformers(
        model="TheBloke/Llama-2-7B-Chat-GGML",
        model_type="llama",
        max_new_tokens=128,
        temperature=0.5,
    )
    logger.info("LLM model loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load LLM model: {e}")
    raise
custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Question: {question}
Only return the helpful answer below and nothing else.
Helpful answer:
"""
qa_prompt = PromptTemplate(template=custom_prompt_template, input_variables=["context", "question"])
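
# Retrieval-augmented QA chain: the "stuff" chain type pastes the retrieved
# chunks directly into the prompt's {context} slot; k=2 keeps the stuffed
# context small enough for the 7B model's context window.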
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=db.as_retriever(search_kwargs={"k": 2}),
    return_source_documents=True,
    chain_type_kwargs={"prompt": qa_prompt},
)

class QuestionRequest(BaseModel):
    question: str


class AnswerResponse(BaseModel):
    answer: str

def clean_answer(answer):
    # Unescape literal "\n" sequences first, before any character filtering
    # strips the backslashes.
    cleaned_answer = answer.replace("\\n", "\n")
    # Drop characters other than word characters, whitespace, and the
    # punctuation/markers the rules below rely on (. , # * -).
    cleaned_answer = re.sub(r'[^\w\s.,#*-]', '', cleaned_answer)
    # Collapse immediate word repetitions ("the the" -> "the").
    cleaned_answer = re.sub(r'\b(\w+)( \1\b)+', r'\1', cleaned_answer)
    # Normalize runs of spaces and tabs, but keep newlines so the
    # line-oriented rules below still see line starts.
    cleaned_answer = re.sub(r'[ \t]+', ' ', cleaned_answer)
    cleaned_answer = cleaned_answer.strip()
    # Normalize Markdown-style bullets, numbered items, and headings.
    cleaned_answer = re.sub(r'^\s*-\s+(.*)$', r'* \1', cleaned_answer, flags=re.MULTILINE)
    cleaned_answer = re.sub(r'^\s*\d+\.\s+(.*)$', r'1. \1', cleaned_answer, flags=re.MULTILINE)
    cleaned_answer = re.sub(r'^\s*(#+)\s+(.*)$', r'\1 \2', cleaned_answer, flags=re.MULTILINE)
    return cleaned_answer
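
# For example, clean_answer("the the answer is: 42!") returns
# "the answer is 42".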

def format_sources(sources):
    formatted_sources = []
    for source in sources:
        metadata = source.metadata
        page = metadata.get('page', 'Unknown page')
        source_str = f"{metadata.get('source', 'Unknown source')}, page {page}"
        formatted_sources.append(source_str)
    return "\n".join(formatted_sources)

@app.post("/query", response_model=AnswerResponse)
async def query(question_request: QuestionRequest):
    try:
        question = question_request.question
        if not question:
            raise HTTPException(status_code=400, detail="Question is required")
        # The chain is synchronous and CPU-bound, so run it in a worker thread
        # to keep the event loop responsive.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, qa_chain.invoke, {"query": question})
        answer = result.get("result")
        sources = result.get("source_documents")
        if sources:
            formatted_sources = format_sources(sources)
            answer += "\nSources:\n" + formatted_sources
        else:
            answer += "\nNo sources found"
        cleaned_answer = clean_answer(answer)
        return {"answer": cleaned_answer}
    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. the 400 above) instead of
        # masking them as 500s.
        raise
    except Exception as e:
        logger.error(f"Error processing query: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error")

@app.get("/")
async def root():
    return RedirectResponse(url="/docs")

@app.on_event("startup")
async def startup_event():
    # Launch the Streamlit frontend alongside the API.
    subprocess.Popen(["streamlit", "run", "frontend.py", "--server.port", "8501"])

if __name__ == '__main__':
    import uvicorn

    uvicorn.run(app, host='0.0.0.0', port=7860)
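
# Example request against a running instance (the question text here is just
# a placeholder):
#   curl -X POST http://localhost:7860/query \
#     -H "Content-Type: application/json" \
#     -d '{"question": "What is the document about?"}'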