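# FastAPI service for a small RAG chatbot: an uploaded PDF is split into chunks,
# embedded into a Chroma vector store, and queried through a history-aware
# LangChain retrieval chain backed by gpt-4o-mini.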
from fastapi import FastAPI, Request, UploadFile, File
import os
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.messages import AIMessage, HumanMessage
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from sentence_transformers import SentenceTransformer
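# HF_HOME and TRANSFORMERS_CACHE control where Hugging Face models and tokenizers
# are cached; point them at the app's own cache directories.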
os.environ['HF_HOME'] = '/hug/cache/'
os.environ['TRANSFORMERS_CACHE'] = '/blabla/cache/'
app = FastAPI()
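
# predict: answer one question against the supplied Chroma store, then append the
# retrieved source passages (with page citations) to the reply.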
def predict(message, db):
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    template = """You are a general purpose chatbot. Be friendly and kind. Help people answer their questions. Use the context below to answer the questions.
{context}
Question: {question}
Helpful Answer:"""
    QA_CHAIN_PROMPT = PromptTemplate(
        input_variables=["context", "question"],
        template=template,
    )
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )
    # Retrieve the 3 most similar chunks; k has to be passed through search_kwargs.
    retriever = db.as_retriever(search_kwargs={"k": 3})
    contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{question}"),
        ]
    )
    contextualize_q_chain = contextualize_q_prompt | llm | StrOutputParser()
    def contextualized_question(input: dict):
        # With prior chat history, rewrite the question into a standalone one;
        # otherwise pass the original question straight through.
        if input.get("chat_history"):
            return contextualize_q_chain
        else:
            return input["question"]
    rag_chain = (
        RunnablePassthrough.assign(
            context=contextualized_question | retriever
        )
        | QA_CHAIN_PROMPT
        | llm
    )
    history = []
    ai_msg = rag_chain.invoke({"question": message, "chat_history": history})
    print(ai_msg)
    bot_response = ai_msg.content.strip()
    # Record the exchange as message objects, the format MessagesPlaceholder expects.
    history.extend([HumanMessage(content=message), AIMessage(content=bot_response)])
    # Attach the retrieved source passages and their page citations to the reply.
    docs = db.similarity_search(message, k=3)
    extra = "\n" + "*" * 100 + "\n"
    for d in docs:
        citations = d.metadata["source"] + " pg." + str(d.metadata["page"])
        additional_info = d.page_content
        extra += citations + "\n" + additional_info + "\n" + "*" * 100 + "\n"
    # Return the bot's response together with the supporting excerpts.
    return bot_response + extra
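
# upload_file: build a queryable Chroma vector store from a PDF on disk.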
def upload_file(file_path):
    # Load the PDF and split it into overlapping chunks for retrieval.
    loaders = []
    print(file_path)
    loaders.append(PyPDFLoader(file_path))
    documents = []
    for loader in loaders:
        documents.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=16)
    docs = text_splitter.split_documents(documents)
    # Embed the chunks and index them in an in-memory Chroma vector store.
    model = "thenlper/gte-large"
    embedding_function = SentenceTransformerEmbeddings(model_name=model)
    print(f"Model's maximum sequence length: {SentenceTransformer(model).max_seq_length}")
    collection_name = "Autism"
    persist_directory = "./chroma"
    print(len(docs))
    db = Chroma.from_documents(docs, embedding_function)
    print("Done Processing, you can query")
    return db
@app.get("/")  # landing endpoint; route path assumed
async def root():
    return {"Entvin": "Version 1.0 'First Draft'"}
@app.post("/predict")  # route path assumed; question arrives as a query parameter
def predicts(question: str, file: UploadFile = File(...)):
    # Persist the uploaded PDF to disk, index it, and answer the question against it.
    contents = file.file.read()
    with open(file.filename, 'wb') as f:
        f.write(contents)
    db = upload_file(file.filename)
    result = predict(question, db)
    return {"answer": result}
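
# Minimal local-run sketch, assuming uvicorn is installed and the /predict route
# declared above; the question travels as a query parameter, e.g.:
#   curl -X POST "http://127.0.0.1:8000/predict?question=What+is+this+about" \
#        -F "file=@sample.pdf"
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)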