# cack / app.py
# ujalaarshad17's picture
# Upload 8 files
# a16eb78 verified
'''
Necessary Imports
'''
from fastapi import FastAPI, UploadFile, File, HTTPException,Form
from fastapi.middleware.cors import CORSMiddleware
from langchain.text_splitter import RecursiveCharacterTextSplitter
from postgres import PostgresChatMessageHistory
from langchain_community.document_loaders import PyPDFLoader
from langchain_postgres.vectorstores import PGVector
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from typing import Dict
from langchain_openai import ChatOpenAI
from prompt import prompt,system_prompt
import psycopg
import uuid
import os
from custom_message import CustomMessage
from dotenv import load_dotenv
import os
from io import BytesIO
from pypdf import PdfReader
from langchain.docstore.document import Document
# Module-level handle to a PGVector store, shared by the /upload and /query
# routes through `global vector_store`. Starts out unset.
vector_store = None
# LOADING ENVIRONMENT VARIABLES
# Must run before the clients below are built; the routes in this file read
# CONNECTION_STRING from os.environ.
load_dotenv()
# INSTANTIATING THE APP
app = FastAPI()
# Chat model passed to create_stuff_documents_chain in the /query route.
llm = ChatOpenAI(model="gpt-4o",
    temperature=0.2,
    max_tokens=None,
    timeout=None,
    max_retries=1)
# ALLOWING CORS
# NOTE(review): every origin, method and header is allowed while credentials
# are also enabled — confirm this permissive policy is intended outside of
# development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# INITIALIZING THE EMBEDDING MODEL
# The same embedding model is handed to PGVector by both routes.
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
# Splitter used by /upload on the extracted page documents: chunk size 1000
# with an overlap of 300, both measured with len().
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=300,
    length_function=len,
)
@app.get("/")
def greeting():
    """Root route: return a fixed success payload."""
    payload = {'response': 'success', 'status code': 200}
    return payload
# PDF UPLOAD ROUTE
@app.post("/upload")
async def upload_pdf(file: UploadFile = File(...), collection_name: str = Form(...)):
    """
    Upload and process a PDF file, storing its embeddings in the vector database.

    The upload is read fully into memory and parsed with ``PdfReader``. Each
    page becomes one ``Document`` carrying its 1-based page number and the
    file name as metadata. The pages are split with the module-level
    ``text_splitter`` and the chunks are stored through
    ``PGVector.from_documents`` in the collection ``collection_name``. The
    resulting store is kept in the module-level ``vector_store``.

    Args:
        file: The uploaded file; its name must end in ``.pdf``.
        collection_name: Name of the PGVector collection to store chunks in.

    Returns:
        A dict with a ``message`` and a ``collection_name`` entry (the file
        name with ``.pdf`` removed).

    Raises:
        HTTPException: 400 when the file name is missing or does not end in
            ``.pdf``; 500 when reading, parsing, splitting or storing fails.
    """
    global vector_store

    # filename can be absent on a malformed upload; treat that as "not a PDF"
    # instead of failing with an AttributeError.
    if not file.filename or not file.filename.endswith('.pdf'):
        raise HTTPException(status_code=400, detail="Only PDF files are allowed")

    try:
        # Read PDF content directly into memory
        pdf_content = await file.read()
        pdf_reader = PdfReader(BytesIO(pdf_content))

        # One Document per page, with metadata for later attribution.
        documents = [
            Document(
                # Guard against a page yielding no text object at all.
                page_content=page.extract_text() or "",
                metadata={"page": page_num, "source": file.filename},
            )
            for page_num, page in enumerate(pdf_reader.pages, start=1)
        ]

        # Split documents into chunks
        texts = text_splitter.split_documents(documents)

        try:
            vector_store = PGVector.from_documents(
                documents=texts,
                embedding=embeddings,
                connection=os.environ['CONNECTION_STRING'],
                collection_name=collection_name,
                use_jsonb=True,
            )
        except Exception as e:
            # Raise a real exception (a bare string cannot be raised) and
            # interpolate the cause so the 500 detail below is meaningful.
            raise RuntimeError(
                f"Error in establishing the connection with DB: {e}"
            ) from e

        # NOTE(review): the reported collection_name is derived from the file
        # name, not from the `collection_name` form field the chunks were
        # stored under — confirm which one clients expect.
        return {
            "message": "PDF processed successfully",
            "collection_name": file.filename.replace('.pdf', ''),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/query")
async def upload_pdf(
    query: str = Form(...),
    collection_name: str = Form(...),
    username: str = Form(...),
    table_name: str = Form(...),
):
    """
    Answer a question against a stored collection and record the exchange.

    Retrieves the 3 most similar chunks from the PGVector collection
    ``collection_name``, runs them with ``query`` through a retrieval chain
    built from the module-level ``llm`` and ``prompt``, then stores the system
    prompt, question and answer as one message in the chat-history table
    ``table_name`` under a freshly generated session id.

    NOTE(review): this handler shares the function name ``upload_pdf`` with
    the /upload handler. Both routes are registered by their decorators, but
    the module-level name ends up bound to this one; consider renaming.

    Args:
        query: The user's question.
        collection_name: PGVector collection to search.
        username: Passed through to ``PostgresChatMessageHistory``.
        table_name: Chat-history table passed to ``PostgresChatMessageHistory``.

    Returns:
        A dict with the chain's answer under ``"relevant docs"`` and the new
        ``session_id``.

    Raises:
        HTTPException: 500 when retrieval, generation or the database
            connection fails. A failure while writing the history message is
            only printed, not raised.
    """
    global vector_store
    try:
        # Rebuild the cached store when it is unset OR bound to a different
        # collection; otherwise the requested collection_name would be
        # silently ignored after the first upload/query.
        if (
            vector_store is None
            or getattr(vector_store, "collection_name", None) != collection_name
        ):
            vector_store = PGVector(
                embeddings=embeddings,
                connection=os.environ['CONNECTION_STRING'],
                collection_name=collection_name,
                use_jsonb=True,
            )

        retriever = vector_store.as_retriever(
            search_type="similarity", search_kwargs={"k": 3}
        )
        question_answer_chain = create_stuff_documents_chain(llm, prompt)
        rag_chain = create_retrieval_chain(retriever, question_answer_chain)
        response = rag_chain.invoke({"input": query})['answer']

        sync_connection = psycopg.connect(os.environ['CONNECTION_STRING'])
        try:
            session_id = str(uuid.uuid4())
            chat_history = PostgresChatMessageHistory(
                table_name,
                session_id,
                username,
                sync_connection=sync_connection
            )
            # Persisting history is best-effort: a failure here must not
            # prevent the answer from being returned.
            try:
                custom_message = CustomMessage(content=f"SYSTEM_PROMPT:{system_prompt}\n\nHUMAN_MESSAGE:{query}\n\nAI_RESPONSE:{response}")
                chat_history.add_message(custom_message)
            except Exception as e:
                print(e)
        finally:
            # One connection is opened per request; always release it.
            sync_connection.close()

        print("Ended")
        return {
            "relevant docs": response,
            "session_id": session_id
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# if __name__ == "__main__":
# import uvicorn
# uvicorn.run(app, host="0.0.0.0", port=8000)