File size: 4,337 Bytes
7938cd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a007ed6
7938cd4
 
22545a0
7938cd4
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from fastapi import FastAPI
import os
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
import os
from langchain_community.document_loaders import PyPDFLoader
import os
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_chroma import Chroma
from sentence_transformers import SentenceTransformer
from langchain_core.messages import AIMessage, HumanMessage
from fastapi import FastAPI, Request, UploadFile, File

# Point the Hugging Face cache locations at fixed directories.
# NOTE(review): huggingface_hub/transformers read these variables when they are
# first imported, and sentence_transformers is already imported above, so these
# assignments may come too late to take effect — consider moving them before
# the imports. '/blabla/cache/' also looks like a placeholder path; confirm.
os.environ.update(
    {
        'HF_HOME': '/hug/cache/',
        'TRANSFORMERS_CACHE': '/blabla/cache/',
    }
)

app = FastAPI()

def predict(message, db, chat_history=None):
    """Answer *message* from the documents indexed in *db* and cite the sources.

    Pipeline: optionally rewrite the question into a standalone form using
    *chat_history*, fetch the 3 most similar chunks from *db*, ask
    ``gpt-4o-mini`` to answer with those chunks as context, then append the
    very same chunks to the answer as citations.

    Args:
        message: The user's question.
        db: Vector store to search. ``upload_file`` passes the ``Chroma``
            store it builds; the object must provide
            ``similarity_search(query, k=...)`` returning documents with
            ``page_content`` and ``metadata``.
        chat_history: Optional list of earlier ``HumanMessage``/``AIMessage``
            objects. When non-empty it is used to contextualize *message*.
            When a list is passed, the new question/answer pair is appended to
            it in place so the caller can carry the conversation forward.
            ``None`` (the default) means a single-turn question.

    Returns:
        The stripped model answer followed by a block delimited by lines of
        100 ``*`` characters, listing each retrieved chunk as
        ``"<source> pg.<page>"`` and its text.
    """
    llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
    template = """You are a general purpose chatbot. Be friendly and kind. Help people answer their questions. Use the context below to answer the questions
    {context}
    Question: {question}
    Helpful Answer:"""
    qa_prompt = PromptTemplate(input_variables=["context", "question"], template=template)

    history = chat_history if chat_history is not None else []
    question = _contextualize_question(llm, message, history)

    # Retrieve once and reuse the result for both the prompt context and the
    # citations, so the cited chunks are exactly the ones the answer used.
    # (`db.as_retriever(k=3)` would ignore `k`; only `search_kwargs` is honoured.)
    docs = db.similarity_search(question, k=3)
    context = "\n\n".join(doc.page_content for doc in docs)

    qa_chain = qa_prompt | llm | StrOutputParser()
    bot_response = qa_chain.invoke({"context": context, "question": question}).strip()

    if chat_history is not None:
        # MessagesPlaceholder expects a flat list of messages; 2-tuples would be
        # interpreted as (role, content) pairs.
        chat_history.extend(
            [HumanMessage(content=message), AIMessage(content=bot_response)]
        )

    return bot_response + _format_citations(docs)


def _contextualize_question(llm, question, chat_history):
    """Return *question* rewritten as a standalone question; unchanged when there is no history."""
    if not chat_history:
        return question

    contextualize_q_system_prompt = """Given a chat history and the latest user question \
    which might reference context in the chat history, formulate a standalone question \
    which can be understood without the chat history. Do NOT answer the question, \
    just reformulate it if needed and otherwise return it as is."""
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{question}"),
        ]
    )
    chain = contextualize_q_prompt | llm | StrOutputParser()
    return chain.invoke({"question": question, "chat_history": chat_history})


def _format_citations(docs):
    """Render *docs* as the ``*``-ruled citation block appended to an answer."""
    rule = "*" * 100
    entries = "".join(
        f'{doc.metadata.get("source", "unknown")} pg.{doc.metadata.get("page", "?")}\n'
        f"{doc.page_content}\n{rule}\n"
        for doc in docs
    )
    return f"\n{rule}\n{entries}"

def upload_file(file_path, embedding_model="thenlper/gte-large", chunk_size=512, chunk_overlap=16):
    """Load a PDF, split it into chunks, and index the chunks in a Chroma store.

    Args:
        file_path: Path of the PDF to load with ``PyPDFLoader``.
        embedding_model: Sentence-transformers model used to embed the chunks.
        chunk_size: Maximum chunk size passed to ``RecursiveCharacterTextSplitter``.
        chunk_overlap: Overlap between consecutive chunks.

    Returns:
        The ``Chroma`` vector store built from the chunks.

    Raises:
        ValueError: If no text chunks could be extracted from the PDF.
    """
    print(file_path)
    documents = PyPDFLoader(file_path).load()

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    docs = text_splitter.split_documents(documents)
    print(len(docs))
    if not docs:
        # Fail with a clear message instead of an opaque error from the vector store.
        raise ValueError(f"No text could be extracted from {file_path!r}")

    embedding_function = SentenceTransformerEmbeddings(model_name=embedding_model)
    # Read the limit from the model the wrapper already loaded rather than
    # instantiating a second SentenceTransformer just for this print.
    print(f"Model's maximum sequence length: {embedding_function.client.max_seq_length}")

    # NOTE(review): no persist_directory/collection_name is passed, so the store
    # is not written to disk. Confirm whether persistence (e.g. "./chroma") was intended.
    db = Chroma.from_documents(docs, embedding_function)
    print("Done Processing, you can query")

    return db

@app.get("/")
async def root():
    """Report the service name and version."""
    version_info = {"Entvin": "Version 1.0 'First Draft'"}
    return version_info

@app.post("/UploadFile/")
def predicts(question: str, file: UploadFile = File(...)):
    """Save an uploaded PDF, index it, and answer *question* against it.

    Returns:
        ``{"answer": <answer text followed by its citations>}``.
    """
    import shutil

    # `filename` is client-controlled: keep only its final path component so a
    # name like "../../etc/x" cannot write outside the working directory.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        safe_name = "upload.pdf"

    # Stream the upload to disk instead of reading it all into memory.
    # NOTE(review): concurrent uploads with the same name overwrite each other
    # and the saved file is never removed — consider a temporary directory.
    with open(safe_name, "wb") as out:
        shutil.copyfileobj(file.file, out)

    db = upload_file(safe_name)
    result = predict(question, db)
    return {"answer": result}