File size: 1,663 Bytes
a3f9c29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import streamlit as st
from app.chat import initialize_session_state, display_chat_history
from app.data_loader import get_data, load_docs
from app.document_processor import process_documents
from app.prompts import llama_prompt
from langchain_community.llms import Replicate
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain_community.document_transformers import (
    LongContextReorder
)


def create_conversational_chain(
    vector_store,
    *,
    model="meta/meta-llama-3-8b-instruct",
    model_kwargs=None,
    k=6,
):
    """Build a retrieval-augmented conversational chain over *vector_store*.

    Args:
        vector_store: Object exposing ``as_retriever(search_kwargs=...)``;
            presumably the LangChain vector store produced by
            ``process_documents`` — confirm against callers.
        model: Replicate model identifier used for the LLM.
        model_kwargs: Generation parameters forwarded to ``Replicate``.
            Defaults to ``{"temperature": 0.5, "top_p": 1,
            "max_new_tokens": 10000}``. The mapping is copied, so the
            caller's dict is never shared with the LLM object.
        k: Number of documents the retriever returns per question.

    Returns:
        A ``ConversationalRetrievalChain`` that combines retrieved documents
        with ``llama_prompt``, returns its source documents, and keeps the
        conversation in a ``ConversationBufferMemory`` under the
        ``"chat_history"`` key.
    """
    if model_kwargs is None:
        # NOTE(review): 10000 new tokens exceeds Llama 3 8B's 8K-token
        # context window — confirm how Replicate handles this value.
        model_kwargs = {"temperature": 0.5, "top_p": 1, "max_new_tokens": 10000}
    llm = Replicate(model=model, model_kwargs=dict(model_kwargs))

    # The chain emits source documents in addition to the answer, so the
    # memory is told explicitly to record the "answer" output.
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True, output_key="answer"
    )
    retriever = vector_store.as_retriever(search_kwargs={"k": k})

    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        combine_docs_chain_kwargs={"prompt": llama_prompt},
        return_source_documents=True,
        memory=memory,
    )

def reorder_embedding(docs):
    """Return *docs* reordered by LangChain's ``LongContextReorder``.

    Despite its name, this works on documents, not embedding vectors: it
    instantiates a ``LongContextReorder`` transformer and returns the result
    of its ``transform_documents`` call on *docs*.

    Args:
        docs: The documents to reorder (in this file, the output of
            ``load_docs()``).

    Returns:
        The transformer's reordered documents.
    """
    return LongContextReorder().transform_documents(docs)

def main():
    """Run one Streamlit pass of the document-chat app.

    Streamlit re-executes this script from the top on every user interaction,
    so anything that must survive between interactions is kept in
    ``st.session_state``.

    While the chat history is empty, the documents are loaded, reordered and
    indexed into a new vector store, and any previously cached chain is
    dropped. The conversational chain is then built once and cached, so its
    ``ConversationBufferMemory`` keeps earlier turns across reruns. Without
    the cache, the memory would be recreated empty on every pass.
    """
    initialize_session_state()
    get_data()

    if not st.session_state['history']:
        docs = load_docs()
        # NOTE(review): LongContextReorder is intended for *retrieved*
        # documents; reordering the corpus before indexing likely does not
        # change what the retriever returns — confirm intent.
        reordered = reorder_embedding(docs)
        st.session_state['vector_store'] = process_documents(reordered)
        # A freshly built index makes any cached chain (and its memory) stale.
        st.session_state.pop('conversational_chain', None)

    # .get() guards against a session where history exists but no store was
    # ever created (e.g. history pre-populated by initialize_session_state).
    vector_store = st.session_state.get('vector_store')
    if vector_store is None:
        return

    chain = st.session_state.get('conversational_chain')
    if chain is None:
        chain = create_conversational_chain(vector_store)
        st.session_state['conversational_chain'] = chain
    display_chat_history(chain)

if __name__ == "__main__":
    # `streamlit run` executes this file as __main__, so the guard fires under
    # Streamlit while a plain import of the module stays side-effect free.
    main()