from llama_index.core import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    StorageContext,
    Settings,
    set_global_tokenizer,
    load_index_from_storage,
)
# llama.cpp bindings back the optional local-GGUF alternative sketched below
from llama_index.llms.llama_cpp import LlamaCPP
from llama_index.llms.llama_cpp.llama_utils import (
    messages_to_prompt,
    completion_to_prompt,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from transformers import AutoTokenizer, BitsAndBytesConfig
from llama_index.llms.huggingface import HuggingFaceLLM
import torch
import logging
import sys
import streamlit as st
# 4-bit NF4 weights with double quantization keep the 13B model inside a
# single consumer GPU's memory; compute runs in bfloat16.
default_bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# basicConfig already attaches a stdout handler; adding a second StreamHandler
# would print every record twice, so one is enough.
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# Use the chat model's own tokenizer for LlamaIndex token counting.
set_global_tokenizer(
    AutoTokenizer.from_pretrained("NousResearch/Llama-2-13b-chat-hf").encode
)

def getDocs(doc_path="./data/"):
    """Load every document under doc_path with SimpleDirectoryReader."""
    documents = SimpleDirectoryReader(doc_path).load_data()
    return documents

def getVectorIndex(docs):
    """Build a vector index over docs and persist it to ./storage/book_data."""
    Settings.chunk_size = 512
    storage_context = StorageContext.from_defaults()
    cur_index = VectorStoreIndex.from_documents(
        docs, embed_model=getEmbedModel(), storage_context=storage_context
    )
    storage_context.persist(persist_dir="./storage/book_data")
    return cur_index
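
# Hedged counterpart of the persist() call above, not wired into the app: a
# minimal sketch for reloading the saved index instead of re-embedding the
# corpus on every cold start. It assumes ./storage/book_data was written by a
# previous getVectorIndex run; the function name is ours, not the original's.
def loadVectorIndex(persist_dir="./storage/book_data"):
    storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
    return load_index_from_storage(storage_context, embed_model=getEmbedModel())
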
def getLLM():
    """Wrap Llama-2-13B-chat in a LlamaIndex LLM, quantized to 4 bits."""
    model_path = "NousResearch/Llama-2-13b-chat-hf"
    # model_path = "NousResearch/Llama-2-7b-chat-hf"  # smaller alternative
    llm = HuggingFaceLLM(
        context_window=3900,
        max_new_tokens=256,
        # generate_kwargs={"temperature": 0.25, "do_sample": False},
        tokenizer_name=model_path,
        model_name=model_path,
        device_map=0,  # pin the model to GPU 0
        tokenizer_kwargs={"max_length": 2048},
        # 4-bit quantization (default_bnb_config above) reduces CUDA memory use
        model_kwargs={
            "torch_dtype": torch.float16,
            "quantization_config": default_bnb_config,
        },
    )
    return llm
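
# Hedged alternative, unused by the app: serve a local GGUF build through
# llama.cpp, which is what the LlamaCPP imports above are for. The model_url
# below is an illustrative assumption, not a path the original code names.
def getLlamaCppLLM():
    return LlamaCPP(
        model_url=(
            "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF"
            "/resolve/main/llama-2-13b-chat.Q4_K_M.gguf"
        ),
        temperature=0.25,
        max_new_tokens=256,
        context_window=3900,
        model_kwargs={"n_gpu_layers": -1},  # offload all layers to the GPU
        messages_to_prompt=messages_to_prompt,  # Llama-2 chat prompt format
        completion_to_prompt=completion_to_prompt,
        verbose=True,
    )
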
def getQueryEngine(index):
    # Despite the name, as_chat_engine returns a chat engine that keeps history.
    query_engine = index.as_chat_engine(llm=getLLM())
    return query_engine

def getEmbedModel():
    embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
    return embed_model
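
# Optional sketch (assumption, not in the original flow): registering the
# embedder globally would let callers drop the explicit embed_model= arguments.
# Settings.embed_model = getEmbedModel()
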
st.set_page_config(
    page_title="Project BookWorm: Your own Librarian!",
    page_icon="🦙",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items=None,
)
st.title("Project BookWorm: Your own Librarian!")
st.info("Use this app to get recommendations for books that your kids will love!", icon="📃")
if "messages" not in st.session_state.keys(): # Initialize the chat messages history
st.session_state.messages = [
{"role": "assistant", "content": "Ask me a question about children's books or movies!"}
]
@st.cache_resource(show_spinner=False)
def load_data():
    # Cached so the index is built once per process, not on every Streamlit rerun.
    index = getVectorIndex(getDocs())
    return index

index = load_data()
if "chat_engine" not in st.session_state.keys(): # Initialize the chat engine
st.session_state.chat_engine = index.as_chat_engine(llm=getLLM(),chat_mode="condense_question", verbose=True)
if prompt := st.chat_input("Your question"):  # Prompt for user input and save to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})

for message in st.session_state.messages:  # Display the prior chat messages
    with st.chat_message(message["role"]):
        st.write(message["content"])
# If last message is not from assistant, generate a new response
if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            response = st.session_state.chat_engine.chat(prompt)
            st.write(response.response)
            message = {"role": "assistant", "content": response.response}
            st.session_state.messages.append(message)  # Add response to message history

# CLI smoke test, kept disabled while the script runs under Streamlit:
# if __name__ == "__main__":
#     index = getVectorIndex(getDocs())
#     query_engine = getQueryEngine(index)
#     while True:
#         your_request = input("Your comment: ")
#         response = query_engine.chat(your_request)
#         print(response)