import os
import torch
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.chains import create_retrieval_chain, create_history_aware_retriever
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory


# Check for GPU availability and set the appropriate device for computation.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Global variables
conversation_retrieval_chain = None
llm_hub = None
embeddings = None

# Function to initialize the language model and its embeddings
def init_llm():
    global llm_hub, embeddings

    # Hugging Face API token
    # Setup environment variable HUGGINGFACEHUB_API_TOKEN
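    # e.g. (hypothetical placeholder -- substitute your own token):
    #   os.environ["HUGGINGFACEHUB_API_TOKEN"] = "hf_..."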

    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

    llm_hub = HuggingFaceEndpoint(
        repo_id=model_id,
        task="text-generation",
        max_new_tokens=200,       # cap the length of each generated answer
        do_sample=False,          # greedy decoding; temperature only matters when do_sample=True
        repetition_penalty=1.03,
        return_full_text=False,   # return only the completion, not the prompt
        temperature=0.1,
    )

    # Embedding model for vectorizing document chunks (runs on DEVICE)
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": DEVICE}
    )

# In-memory store of per-session chat histories, keyed by session_id
store = {}

def get_session_history(session_id: str) -> BaseChatMessageHistory:
    # Return the chat history for a session, creating it on first use
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]


# Function to process a PDF document
def process_document(document_path):
    global conversation_retrieval_chain

    # Load the document
    loader = PyPDFLoader(document_path)
    documents = loader.load()
    
    # Split the document into chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=128)
    texts = text_splitter.split_documents(documents)
    
    # Create an embeddings database using FAISS from the split text chunks.
    db = FAISS.from_documents(documents=texts, embedding=embeddings)

    # Prompt for the answer-generation step; {context} receives the retrieved
    # chunks, and the chat_history placeholder lets the model see prior turns.
    # ChatPromptTemplate supplies the role structure, so no raw Llama
    # chat-template tokens are needed here.
    system_prompt = (
        "You are an assistant for answering questions using provided context. "
        "You are given the extracted parts of a long document, the previous "
        "chat history, and a question. Provide a conversational answer. "
        'If you don\'t know the answer, just say "I do not know." '
        "Don't make up an answer.\n\n"
        "Context: {context}"
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    
    # Top-k similarity retriever over the FAISS index (lambda_mult only
    # affects MMR search, so it is omitted here)
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    question_answer_chain = create_stuff_documents_chain(llm_hub, prompt)

    contextualize_q_system_prompt = (
        "Given a chat history and the latest user question "
        "which might reference context in the chat history, "
        "formulate a standalone question which can be understood "
        "without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    history_aware_retriever = create_history_aware_retriever(
        llm_hub, retriever, contextualize_q_prompt
    )

    # Chain retrieval and answer generation, then wrap with message history so
    # each session's prior turns are injected as chat_history on every call
    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

    conversation_retrieval_chain = RunnableWithMessageHistory(
        rag_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )


# Function to process a user prompt
def process_prompt(prompt):
    # Query the chain; RunnableWithMessageHistory looks up (or creates) the
    # "abc123" entry in `store` and injects it as chat_history
    output = conversation_retrieval_chain.invoke(
        {"input": prompt},
        config={"configurable": {"session_id": "abc123"}},
    )
    answer = output["answer"]
    
    # Return the model's response
    return answer

# Initialize the language model
init_llm()
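

# A minimal usage sketch, not part of the module's public surface:
# "sample.pdf" is a hypothetical path, and HUGGINGFACEHUB_API_TOKEN must
# already be set in the environment for the endpoint calls to succeed.
if __name__ == "__main__":
    assert os.environ.get("HUGGINGFACEHUB_API_TOKEN"), "set HUGGINGFACEHUB_API_TOKEN first"
    process_document("sample.pdf")  # build the FAISS index and the RAG chain
    print(process_prompt("What is this document about?"))
    print(process_prompt("Summarize that in one sentence."))  # follow-up relies on stored history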