"""
FastAPI RAG service: accepts PDF uploads, chunks and embeds them into a
PGVector collection, and answers queries via a retrieval chain, persisting
each exchange to a Postgres chat-history table.
"""
from fastapi import FastAPI, UploadFile, File, HTTPException,Form | |
from fastapi.middleware.cors import CORSMiddleware | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from postgres import PostgresChatMessageHistory | |
from langchain_community.document_loaders import PyPDFLoader | |
from langchain_postgres.vectorstores import PGVector | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
from langchain.chains import create_retrieval_chain | |
from langchain.chains.combine_documents import create_stuff_documents_chain | |
from langchain_google_genai import GoogleGenerativeAIEmbeddings | |
from typing import Dict | |
from langchain_openai import ChatOpenAI | |
from prompt import prompt,system_prompt | |
import psycopg | |
import uuid | |
import os | |
from custom_message import CustomMessage | |
from dotenv import load_dotenv | |
import os | |
from io import BytesIO | |
from pypdf import PdfReader | |
from langchain.docstore.document import Document | |
# Module-level cache of the PGVector store: populated by the upload handler
# (or lazily on first query) so subsequent requests reuse the same store.
vector_store: "PGVector | None" = None
# LOADING ENVIRONMENT VARIABLES
load_dotenv()
# INSTANTIATING THE APP
app = FastAPI()
# Chat model used for answer generation. Low temperature keeps answers
# grounded in retrieved context; a single retry bounds request latency.
llm = ChatOpenAI(model="gpt-4o",
                temperature=0.2,
                max_tokens=None,
                timeout=None,
                max_retries=1)
# ALLOWING CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # NOTE(review): wide-open CORS with credentials — confirm this is intended beyond development
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# INITIALIZING THE EMBEDDING MODEL
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
# Chunking configuration: 1000-char chunks with 300-char overlap so text
# spanning a chunk boundary remains retrievable from at least one chunk.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=300,
    length_function=len,
)
def greeting():
    """Health-check helper: report that the service is up."""
    payload = {'response': 'success', 'status code': 200}
    return payload
# PDF UPLOAD ROUTE | |
async def upload_pdf(file: UploadFile = File(...), collection_name: str = Form(...)):
    """
    Upload and process a PDF file, storing its embeddings in the vector database.

    Args:
        file: Uploaded file; must have a ``.pdf`` extension.
        collection_name: PGVector collection the document chunks are stored in.

    Returns:
        dict with a success message and the collection name the chunks went to.

    Raises:
        HTTPException: 400 for non-PDF uploads, 500 for any processing/DB error.
    """
    # NOTE(review): no @app.post decorator is visible here and this name is
    # redefined later in the file — confirm how this handler is registered.
    if not file.filename.endswith('.pdf'):
        raise HTTPException(status_code=400, detail="Only PDF files are allowed")
    try:
        # Read PDF content directly into memory (no temp file on disk).
        pdf_content = await file.read()
        pdf_file = BytesIO(pdf_content)
        pdf_reader = PdfReader(pdf_file)
        # Build one Document per page, tagged with page number and source file.
        documents = []
        for page_num, page in enumerate(pdf_reader.pages):
            # extract_text() can return None for image-only pages; guard it so
            # the splitter never sees a None page_content.
            text = page.extract_text() or ""
            doc = Document(
                page_content=text,
                metadata={"page": page_num + 1, "source": file.filename}
            )
            documents.append(doc)
        # Split documents into overlapping chunks for embedding.
        texts = text_splitter.split_documents(documents)
        try:
            global vector_store
            vector_store = PGVector.from_documents(
                documents=texts,
                embedding=embeddings,
                connection=os.environ['CONNECTION_STRING'],
                collection_name=collection_name,
                use_jsonb=True,
            )
        except Exception as e:
            # BUG FIX: the original raised a plain string (a TypeError at
            # runtime, since exceptions must derive from BaseException) and
            # never interpolated `e` (missing f-prefix). Raise a real
            # exception, chained, so the outer handler reports the cause.
            raise RuntimeError(f"Error in establishing the connection with DB: {e}") from e
        # BUG FIX: report the collection the chunks were actually stored in
        # (the form's collection_name), not a name derived from the filename.
        return {"message": "PDF processed successfully", "collection_name": collection_name}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
async def upload_pdf(query: str = Form(...),collection_name:str = Form(...),username:str = Form(...),table_name:str = Form(...)):
    """
    Answer a query against the stored embeddings and persist the exchange.

    Args:
        query: The user's question.
        collection_name: PGVector collection to search if the store must be
            (re)opened in this process.
        username: Attributed author of the chat-history row.
        table_name: Postgres table the chat history is written to.

    Returns:
        dict with the generated answer (key ``"relevant docs"``) and the
        freshly minted ``session_id``.

    Raises:
        HTTPException: 500 wrapping any retrieval/generation/DB error.
    """
    # NOTE(review): this redefines upload_pdf above, shadowing the upload
    # handler — it looks like a query/chat endpoint and likely needs its own
    # name and route decorator; confirm against the router registration.
    try:
        global vector_store
        # Lazily reconnect to an existing collection when the upload handler
        # has not populated the store in this process.
        if vector_store is None:  # BUG FIX: identity check (`is None`), not `== None`
            vector_store = PGVector(
                embeddings=embeddings,
                connection=os.environ['CONNECTION_STRING'],
                collection_name=collection_name,
                use_jsonb=True,
            )
        # Retrieve the 3 most similar chunks and run them through the RAG chain.
        retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 3})
        question_answer_chain = create_stuff_documents_chain(llm, prompt)
        rag_chain = create_retrieval_chain(retriever, question_answer_chain)
        response = rag_chain.invoke({"input":query})['answer']
        session_id = str(uuid.uuid4())
        # RESOURCE FIX: the original never closed this connection. The context
        # manager commits on clean exit and closes the connection either way.
        with psycopg.connect(os.environ['CONNECTION_STRING']) as sync_connection:
            chat_history = PostgresChatMessageHistory(
                table_name,
                session_id,
                username,
                sync_connection=sync_connection
            )
            try:
                custom_message = CustomMessage(content=f"SYSTEM_PROMPT:{system_prompt}\n\nHUMAN_MESSAGE:{query}\n\nAI_RESPONSE:{response}")
                chat_history.add_message(custom_message)
            except Exception as e:
                # Deliberate best-effort persistence: a history-write failure
                # must not fail the query response.
                print(e)
        print("Ended")
        return {
            "relevant docs":response,
            "session_id":session_id
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# if __name__ == "__main__": | |
# import uvicorn | |
# uvicorn.run(app, host="0.0.0.0", port=8000) | |