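# Project BookWorm: a Streamlit chat app that recommends children's books.
# Documents under ./data/ are embedded with a HuggingFace model, indexed with
# LlamaIndex, and answered by a Llama-2 chat model served through HuggingFaceLLM.
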
from llama_index.core import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    StorageContext,
    Settings,
    set_global_tokenizer,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
from transformers import AutoTokenizer, BitsAndBytesConfig
import torch
import logging
import sys
import streamlit as st

# 4-bit quantization config (only applied if passed via model_kwargs in getLLM)
default_bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

# Count tokens with the Llama-2 tokenizer so context-window limits are respected
set_global_tokenizer(
    AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf").encode
)


def getDocs(doc_path="./data/"):
    # Load every document found under doc_path
    documents = SimpleDirectoryReader(doc_path).load_data()
    return documents


def getVectorIndex(docs):
    Settings.chunk_size = 512

    storage_context = StorageContext.from_defaults()
    cur_index = VectorStoreIndex.from_documents(
        docs, storage_context=storage_context, embed_model=getEmbedModel()
    )
    # Persist the index so it can be reloaded later without re-embedding the documents
    storage_context.persist(persist_dir="./storage/book_data")
    return cur_index
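
# Optional sketch (not used by the app): reload the persisted index instead of
# re-embedding on every cold start. Assumes ./storage/book_data was written by a
# previous run; loadVectorIndex is a hypothetical helper, not part of the original code.
# def loadVectorIndex(persist_dir="./storage/book_data"):
#     from llama_index.core import load_index_from_storage
#     storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
#     return load_index_from_storage(storage_context, embed_model=getEmbedModel())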


def getLLM():
    model_path = "NousResearch/Llama-2-13b-chat-hf"
    # model_path = "meta-llama/Llama-2-13b-chat-hf"

    llm = HuggingFaceLLM(
        context_window=3900,
        max_new_tokens=256,
        # generate_kwargs={"temperature": 0.25, "do_sample": False},
        tokenizer_name=model_path,
        model_name=model_path,
        device_map=0,
        tokenizer_kwargs={"max_length": 2048},
        model_kwargs={
            "torch_dtype": torch.float16,
            # uncomment to load the model in 4-bit and reduce GPU memory usage
            # "quantization_config": default_bnb_config,
        },
    )
    return llm


def getQueryEngine(index):
    # Chat engine over the index; used by the CLI loop at the bottom of this file
    query_engine = index.as_chat_engine(llm=getLLM())
    return query_engine


def getEmbedModel():
    embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
    return embed_model


st.set_page_config(page_title="Project BookWorm: Your own Librarian!", page_icon="🦙", layout="centered", initial_sidebar_state="auto", menu_items=None)
st.title("Project BookWorm: Your own Librarian!")
st.info("Use this app to get recommendations for books that your kids will love!", icon="📃")
         
if "messages" not in st.session_state.keys(): # Initialize the chat messages history
    st.session_state.messages = [
        {"role": "assistant", "content": "Ask me a question about children's books or movies!"}
    ]

@st.cache_resource(show_spinner=False)
def load_data():
    # Build (and cache) the vector index once per Streamlit session
    index = getVectorIndex(getDocs())
    return index

index = load_data()

if "chat_engine" not in st.session_state.keys(): # Initialize the chat engine
        st.session_state.chat_engine = index.as_chat_engine(llm=getLLM(),chat_mode="condense_question", verbose=True)

if prompt := st.chat_input("Your question"): # Prompt for user input and save to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})

for message in st.session_state.messages: # Display the prior chat messages
    with st.chat_message(message["role"]):
        st.write(message["content"])

# If last message is not from assistant, generate a new response
if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            response = st.session_state.chat_engine.chat(prompt)
            st.write(response.response)
            message = {"role": "assistant", "content": response.response}
            st.session_state.messages.append(message) # Add response to message history


# Optional CLI mode (kept for reference): chat with the index from a terminal.
# if __name__ == "__main__":
#     index = getVectorIndex(getDocs())
#     query_engine = getQueryEngine(index)
#     while True:
#         your_request = input("Your comment: ")
#         response = query_engine.chat(your_request)
#         print(response)
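
# To launch the web UI (assuming this file is saved as app.py):
#   streamlit run app.py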