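"""FastAPI service: upload a PDF and ask a question about it.

The /UploadFile/ endpoint indexes the uploaded PDF into a Chroma vector store
and answers the question with a retrieval-augmented (RAG) LangChain chain
backed by ChatOpenAI (gpt-4o-mini).
"""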
from fastapi import FastAPI, Request, UploadFile, File
import os
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.messages import AIMessage, HumanMessage
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_chroma import Chroma
from sentence_transformers import SentenceTransformer
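# Redirect the Hugging Face and transformers model caches to writable directories.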
os.environ['HF_HOME'] = '/hug/cache/'
os.environ['TRANSFORMERS_CACHE'] = '/blabla/cache/'
app = FastAPI()
def predict(message, db):
    # LLM used both to rewrite follow-up questions and to answer them.
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

    template = """You are a general purpose chatbot. Be friendly and kind. Help people answer their questions. Use the context below to answer the questions
{context}
Question: {question}
Helpful Answer:"""
    QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"], template=template)

    # Note: this memory object is created but never wired into the chain below.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )
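    # History-aware retrieval: when a chat history is present, the latest
    # question is first rewritten into a standalone query before retrieval.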
    # Pass k via search_kwargs so the retriever returns the top 3 chunks.
    retriever = db.as_retriever(search_kwargs={"k": 3})

    contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{question}"),
        ]
    )
    contextualize_q_chain = contextualize_q_prompt | llm | StrOutputParser()

    def contextualized_question(input: dict):
        # With history, rewrite the question; otherwise pass it through unchanged.
        if input.get("chat_history"):
            return contextualize_q_chain
        else:
            return input["question"]
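    # RAG chain: resolve the (possibly rewritten) question, retrieve matching
    # chunks as context, fill the QA prompt, and call the model.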
    rag_chain = (
        RunnablePassthrough.assign(
            context=contextualized_question | retriever
        )
        | QA_CHAIN_PROMPT
        | llm
    )
    # History is local to this call, so every request starts with an empty history.
    history = []
    ai_msg = rag_chain.invoke({"question": message, "chat_history": history})
    print(ai_msg)
    bot_response = ai_msg.content.strip()

    # chat_history must be a flat list of messages for MessagesPlaceholder.
    history.extend([HumanMessage(content=message), AIMessage(content=bot_response)])
    # Append the retrieved chunks and their citations (source file + page) to the answer.
    docs = db.similarity_search(message, k=3)
    extra = "\n" + "*" * 100 + "\n"
    for d in docs:
        citation = d.metadata["source"] + " pg." + str(d.metadata["page"])
        extra += citation + "\n" + d.page_content + "\n" + "*" * 100 + "\n"

    # Return the bot's response followed by the supporting passages.
    return bot_response + extra
def upload_file(file_path):
    # Load the PDF and split it into overlapping chunks for embedding.
    loaders = []
    print(file_path)
    loaders.append(PyPDFLoader(file_path))
    documents = []
    for loader in loaders:
        documents.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=16)
    docs = text_splitter.split_documents(documents)

    # Embed the chunks with a sentence-transformers model and index them in Chroma.
    model = "thenlper/gte-large"
    embedding_function = SentenceTransformerEmbeddings(model_name=model)
    print(f"Model's maximum sequence length: {SentenceTransformer(model).max_seq_length}")
    collection_name = "Autism"
    persist_directory = "./chroma"
    print(len(docs))
    db = Chroma.from_documents(
        docs,
        embedding_function,
        collection_name=collection_name,
        persist_directory=persist_directory,
    )
    print("Done Processing, you can query")
    return db
@app.get("/")
async def root():
    return {"Entvin": "Version 1.0 'First Draft'"}
@app.post("/UploadFile/")
def predicts(question: str, file: UploadFile = File(...)):
    # Save the uploaded PDF to disk, index it, then answer the question against it.
    contents = file.file.read()
    with open(file.filename, 'wb') as f:
        f.write(contents)
    db = upload_file(file.filename)
    result = predict(question, db)
    return {"answer": result}
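# Example usage (a sketch; the port and file name below are assumptions, not pinned by this app):
#   uvicorn main:app --host 0.0.0.0 --port 7860
#   curl -X POST "http://localhost:7860/UploadFile/?question=What%20is%20this%20document%20about" \
#        -F "file=@sample.pdf"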