from llama_index.core import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    StorageContext,
    Settings,
    set_global_tokenizer,
    load_index_from_storage,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
from transformers import AutoTokenizer, BitsAndBytesConfig
import torch
import logging
import sys
import os
import streamlit as st

# 4-bit NF4 quantization config (bitsandbytes) so the 13B model fits in GPU memory
default_bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Verbose logging to stdout, useful while debugging retrieval and prompting
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

# Count tokens with the Llama-2 tokenizer so chunking matches the LLM's context window
set_global_tokenizer(
    AutoTokenizer.from_pretrained("NousResearch/Llama-2-13b-chat-hf").encode
)


def getDocs(doc_path="./data/"):
    """Load all documents found under doc_path."""
    documents = SimpleDirectoryReader(doc_path).load_data()
    return documents


def getVectorIndex(docs):
    """Build the vector index, or load it from ./storage/book_data if it already exists."""
    Settings.chunk_size = 512
    if os.path.isdir("./storage/book_data"):
        storage_context = StorageContext.from_defaults(persist_dir="./storage/book_data")
        cur_index = load_index_from_storage(storage_context, embed_model=getEmbedModel())
    else:
        storage_context = StorageContext.from_defaults()
        cur_index = VectorStoreIndex.from_documents(
            docs, embed_model=getEmbedModel(), storage_context=storage_context
        )
        storage_context.persist(persist_dir="./storage/book_data")
    return cur_index


def getLLM():
    """Load Llama-2-13b-chat on the GPU in 4-bit via HuggingFaceLLM."""
    model_path = "NousResearch/Llama-2-13b-chat-hf"
    # model_path = "NousResearch/Llama-2-7b-chat-hf"
    llm = HuggingFaceLLM(
        context_window=3900,
        max_new_tokens=256,
        # generate_kwargs={"temperature": 0.25, "do_sample": False},
        tokenizer_name=model_path,
        model_name=model_path,
        device_map=0,
        tokenizer_kwargs={"max_length": 2048},
        # fp16 weights plus 4-bit quantization reduce CUDA memory usage
        model_kwargs={
            "torch_dtype": torch.float16,
            "quantization_config": default_bnb_config,
        },
    )
    return llm


def getQueryEngine(index):
    """Wrap the index in a chat engine (used by the commented-out CLI loop at the bottom)."""
    query_engine = index.as_chat_engine(llm=getLLM())
    return query_engine


def getEmbedModel():
    """Embedding model used for both indexing and querying."""
    embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
    return embed_model


st.set_page_config(
    page_title="Project BookWorm: Your own Librarian!",
    page_icon="🦙",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items=None,
)
st.title("Project BookWorm: Your own Librarian!")
st.info("Use this app to get recommendations for books that your kids will love!", icon="📃")

# Initialize the chat messages history
if "messages" not in st.session_state.keys():
    st.session_state.messages = [
        {"role": "assistant", "content": "Ask me a question about children's books or movies!"}
    ]


# Cache the index so it is built or loaded only once per Streamlit session
@st.cache_resource(show_spinner=False)
def load_data():
    index = getVectorIndex(getDocs())
    return index


index = load_data()

# Initialize the chat engine
if "chat_engine" not in st.session_state.keys():
    st.session_state.chat_engine = index.as_chat_engine(
        llm=getLLM(), chat_mode="condense_question", verbose=True
    )

# Prompt for user input and save it to the chat history
if prompt := st.chat_input("Your question"):
    st.session_state.messages.append({"role": "user", "content": prompt})

# Display the prior chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])

# If the last message is not from the assistant, generate a new response
st.session_state.messages[-1]["role"] != "assistant": with st.chat_message("assistant"): with st.spinner("Thinking..."): response = st.session_state.chat_engine.chat(prompt) st.write(response.response) message = {"role": "assistant", "content": response.response} st.session_state.messages.append(message) # Add response to message history # if __name__ == "__main__": # index = getVectorIndex(getDocs()) # query_engine = getQueryEngine(index) # while(True): # your_request = input("Your comment: ") # response = query_engine.chat(your_request) # print(response)