kshitijk committed on
Commit
8a078c3
2 Parent(s): da65768 7d5df89

Switch to OpenAI APIs

Browse files
.gitattributes CHANGED
@@ -33,4 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- *json filter=lfs diff=lfs merge=lfs -text
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
app.py CHANGED
@@ -1,13 +1,19 @@
1
  from pathlib import Path
 
 
 
 
 
 
 
 
 
 
2
 
3
  from llama_index.core import(SimpleDirectoryReader,
4
  VectorStoreIndex, StorageContext,
5
  Settings,set_global_tokenizer)
6
- from llama_index.llms.llama_cpp import LlamaCPP
7
- from llama_index.llms.llama_cpp.llama_utils import (
8
- messages_to_prompt,
9
- completion_to_prompt,
10
- )
11
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
12
  from transformers import AutoTokenizer, BitsAndBytesConfig
13
  from llama_index.llms.huggingface import HuggingFaceLLM
@@ -17,17 +23,15 @@ import sys
17
  import streamlit as st
18
  import os
19
  from llama_index.core import load_index_from_storage
20
- default_bnb_config = BitsAndBytesConfig(
21
- load_in_4bit=True,
22
- bnb_4bit_quant_type='nf4',
23
- bnb_4bit_use_double_quant=True,
24
- bnb_4bit_compute_dtype=torch.bfloat16
25
- )
 
26
  logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
27
  logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
28
- set_global_tokenizer(
29
- AutoTokenizer.from_pretrained("NousResearch/Llama-2-13b-chat-hf").encode
30
- )
31
 
32
 
33
  def getDocs(doc_path="./data/"):
@@ -35,52 +39,29 @@ def getDocs(doc_path="./data/"):
35
  return documents
36
 
37
 
38
- def getVectorIndex(docs):
39
  Settings.chunk_size = 512
40
  index_set = {}
41
-
42
- if(os.path.isdir(f"./storage/book_data")):
43
- storage_context = StorageContext.from_defaults(persist_dir=f"./storage/book_data")
 
 
44
  cur_index = load_index_from_storage(
45
- storage_context,embed_model = getEmbedModel()
46
  )
47
  else:
 
 
48
  storage_context = StorageContext.from_defaults()
49
- cur_index = VectorStoreIndex.from_documents(docs, embed_model = getEmbedModel(), storage_context=storage_context)
50
- storage_context.persist(persist_dir=f"./storage/book_data")
51
  return cur_index
52
 
53
-
54
- def getLLM():
55
-
56
- model_path = "NousResearch/Llama-2-13b-chat-hf"
57
- # model_path = "NousResearch/Llama-2-7b-chat-hf"
58
-
59
- llm = HuggingFaceLLM(
60
- context_window=3900,
61
- max_new_tokens=256,
62
- # generate_kwargs={"temperature": 0.25, "do_sample": False},
63
- tokenizer_name=model_path,
64
- model_name=model_path,
65
- device_map=0,
66
- tokenizer_kwargs={"max_length": 2048},
67
- # uncomment this if using CUDA to reduce memory usage
68
- model_kwargs={"torch_dtype": torch.float16,
69
- "quantization_config": default_bnb_config,
70
- }
71
- )
72
- return llm
73
-
74
-
75
  def getQueryEngine(index):
76
- query_engine = index.as_chat_engine(llm=getLLM())
77
  return query_engine
78
 
79
- def getEmbedModel():
80
- embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
81
- return embed_model
82
-
83
-
84
 
85
 
86
 
@@ -101,14 +82,17 @@ if "messages" not in st.session_state.keys(): # Initialize the chat messages his
101
 
102
  @st.cache_resource(show_spinner=False)
103
  def load_data():
104
- index = getVectorIndex(getDocs())
105
  return index
106
- query_engine = getQueryEngine(index)
107
-
108
  index = load_data()
 
 
 
109
 
110
  if "chat_engine" not in st.session_state.keys(): # Initialize the chat engine
111
- st.session_state.chat_engine = index.as_chat_engine(llm=getLLM(),chat_mode="condense_question", verbose=True)
112
 
113
  if prompt := st.chat_input("Your question"): # Prompt for user input and save to chat history
114
  st.session_state.messages.append({"role": "user", "content": prompt})
 
1
  from pathlib import Path
2
+ import os
3
+ import openai
4
+ openai.api_key = os.getenv("OAI_KEY")
5
+ from llama_index.llms.openai import OpenAI
6
+ from llama_index.embeddings.openai import OpenAIEmbedding
7
+ import nest_asyncio
8
+
9
+ nest_asyncio.apply()
10
+
11
+
12
 
13
  from llama_index.core import(SimpleDirectoryReader,
14
  VectorStoreIndex, StorageContext,
15
  Settings,set_global_tokenizer)
16
+
 
 
 
 
17
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
18
  from transformers import AutoTokenizer, BitsAndBytesConfig
19
  from llama_index.llms.huggingface import HuggingFaceLLM
 
23
  import streamlit as st
24
  import os
25
  from llama_index.core import load_index_from_storage
26
+
27
+
28
+ Settings.llm = OpenAI(model="gpt-3.5-turbo-instruct", temperature=0.2)
29
+ Settings.embed_model = OpenAIEmbedding(
30
+ model="text-embedding-3-large", embed_batch_size=100
31
+ )
32
+
33
  logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
34
  logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
 
 
 
35
 
36
 
37
  def getDocs(doc_path="./data/"):
 
39
  return documents
40
 
41
 
42
+ def getVectorIndex():
43
  Settings.chunk_size = 512
44
  index_set = {}
45
+ if os.path.isdir(f"./storage/open_ai_embedding_data_large"):
46
+ print("Index already exists")
47
+ storage_context = StorageContext.from_defaults(
48
+ persist_dir=f"./storage/open_ai_embedding_data_large"
49
+ )
50
  cur_index = load_index_from_storage(
51
+ storage_context,
52
  )
53
  else:
54
+ print("Index does not exist, creating new index")
55
+ docs = getDocs()
56
  storage_context = StorageContext.from_defaults()
57
+ cur_index = VectorStoreIndex.from_documents(docs, storage_context=storage_context)
58
+ storage_context.persist(persist_dir=f"./storage/open_ai_embedding_data_large")
59
  return cur_index
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def getQueryEngine(index):
62
+ query_engine = index.as_chat_engine()
63
  return query_engine
64
 
 
 
 
 
 
65
 
66
 
67
 
 
82
 
83
  @st.cache_resource(show_spinner=False)
84
  def load_data():
85
+ index = getVectorIndex()
86
  return index
87
+ import time
88
+ s_time = time.time()
89
  index = load_data()
90
+ e_time = time.time()
91
+
92
+ print(f"It took {e_time - s_time} to load index")
93
 
94
  if "chat_engine" not in st.session_state.keys(): # Initialize the chat engine
95
+ st.session_state.chat_engine = index.as_chat_engine(chat_mode="condense_plus_context", verbose=True)
96
 
97
  if prompt := st.chat_input("Your question"): # Prompt for user input and save to chat history
98
  st.session_state.messages.append({"role": "user", "content": prompt})
app.py.bkp ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ from llama_index.core import(SimpleDirectoryReader,
4
+ VectorStoreIndex, StorageContext,
5
+ Settings,set_global_tokenizer)
6
+ from llama_index.llms.llama_cpp import LlamaCPP
7
+ from llama_index.llms.llama_cpp.llama_utils import (
8
+ messages_to_prompt,
9
+ completion_to_prompt,
10
+ )
11
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
12
+ from transformers import AutoTokenizer, BitsAndBytesConfig
13
+ from llama_index.llms.huggingface import HuggingFaceLLM
14
+ import torch
15
+ import logging
16
+ import sys
17
+ import streamlit as st
18
+ import os
19
+ from llama_index.core import load_index_from_storage
20
+ default_bnb_config = BitsAndBytesConfig(
21
+ load_in_4bit=True,
22
+ bnb_4bit_quant_type='nf4',
23
+ bnb_4bit_use_double_quant=True,
24
+ bnb_4bit_compute_dtype=torch.bfloat16
25
+ )
26
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
27
+ logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
28
+ set_global_tokenizer(
29
+ AutoTokenizer.from_pretrained("NousResearch/Llama-2-13b-chat-hf").encode
30
+ )
31
+
32
+
33
+ def getDocs(doc_path="./data/"):
34
+ documents = SimpleDirectoryReader(doc_path).load_data()
35
+ return documents
36
+
37
+
38
+ def getVectorIndex(docs):
39
+ Settings.chunk_size = 512
40
+ index_set = {}
41
+
42
+ if(os.path.isdir(f"./storage/book_data")):
43
+ storage_context = StorageContext.from_defaults(persist_dir=f"./storage/book_data")
44
+ cur_index = load_index_from_storage(
45
+ storage_context,embed_model = getEmbedModel()
46
+ )
47
+ else:
48
+ storage_context = StorageContext.from_defaults()
49
+ cur_index = VectorStoreIndex.from_documents(docs, embed_model = getEmbedModel(), storage_context=storage_context)
50
+ storage_context.persist(persist_dir=f"./storage/book_data")
51
+ return cur_index
52
+
53
+
54
+ def getLLM():
55
+
56
+ model_path = "NousResearch/Llama-2-13b-chat-hf"
57
+ # model_path = "NousResearch/Llama-2-7b-chat-hf"
58
+
59
+ llm = HuggingFaceLLM(
60
+ context_window=3900,
61
+ max_new_tokens=256,
62
+ # generate_kwargs={"temperature": 0.25, "do_sample": False},
63
+ tokenizer_name=model_path,
64
+ model_name=model_path,
65
+ device_map=0,
66
+ tokenizer_kwargs={"max_length": 2048},
67
+ # uncomment this if using CUDA to reduce memory usage
68
+ model_kwargs={"torch_dtype": torch.float16,
69
+ "quantization_config": default_bnb_config,
70
+ }
71
+ )
72
+ return llm
73
+
74
+
75
+ def getQueryEngine(index):
76
+ query_engine = index.as_chat_engine(llm=getLLM())
77
+ return query_engine
78
+
79
+ def getEmbedModel():
80
+ embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
81
+ return embed_model
82
+
83
+
84
+
85
+
86
+
87
+
88
+
89
+
90
+
91
+
92
+
93
+ st.set_page_config(page_title="Project BookWorm: Your own Librarian!", page_icon="🦙", layout="centered", initial_sidebar_state="auto", menu_items=None)
94
+ st.title("Project BookWorm: Your own Librarian!")
95
+ st.info("Use this app to get recommendations for books that your kids will love!", icon="📃")
96
+
97
+ if "messages" not in st.session_state.keys(): # Initialize the chat messages history
98
+ st.session_state.messages = [
99
+ {"role": "assistant", "content": "Ask me a question about children's books or movies!"}
100
+ ]
101
+
102
+ @st.cache_resource(show_spinner=False)
103
+ def load_data():
104
+ index = getVectorIndex(getDocs())
105
+ return index
106
+ query_engine = getQueryEngine(index)
107
+
108
+ index = load_data()
109
+
110
+ if "chat_engine" not in st.session_state.keys(): # Initialize the chat engine
111
+ st.session_state.chat_engine = index.as_chat_engine(llm=getLLM(),chat_mode="condense_question", verbose=True)
112
+
113
+ if prompt := st.chat_input("Your question"): # Prompt for user input and save to chat history
114
+ st.session_state.messages.append({"role": "user", "content": prompt})
115
+
116
+ for message in st.session_state.messages: # Display the prior chat messages
117
+ with st.chat_message(message["role"]):
118
+ st.write(message["content"])
119
+
120
+ # If last message is not from assistant, generate a new response
121
+ if st.session_state.messages[-1]["role"] != "assistant":
122
+ with st.chat_message("assistant"):
123
+ with st.spinner("Thinking..."):
124
+ response = st.session_state.chat_engine.chat(prompt)
125
+ st.write(response.response)
126
+ message = {"role": "assistant", "content": response.response}
127
+ st.session_state.messages.append(message) # Add response to message history
128
+
129
+
130
+
131
+
132
+
133
+
134
+
135
+
136
+
137
+
138
+
139
+
140
+
141
+
142
+
143
+
144
+
145
+ # if __name__ == "__main__":
146
+
147
+ # index = getVectorIndex(getDocs())
148
+ # query_engine = getQueryEngine(index)
149
+ # while(True):
150
+ # your_request = input("Your comment: ")
151
+ # response = query_engine.chat(your_request)
152
+ # print(response)
153
+
154
+
155
+
156
+
157
+
158
+
159
+
160
+
161
+
162
+
storage/.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *json filter=lfs diff=lfs merge=lfs -text
storage/open_ai_embedding_data/default__vector_store.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:87956943cb8e0633d9df6a98d31a12c9528901114a79b39c179734999cee7163
3
+ size 244449202
storage/open_ai_embedding_data/docstore.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a7d9222ba808d2bf326098e84b7b959ba0104c923ce7f87c782ecfd93404325
3
+ size 29962555
storage/open_ai_embedding_data/graph_store.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8e0a77744010862225c69da83c585f4f8a42fd551b044ce530dbb1eb6e16742c
3
+ size 18
storage/open_ai_embedding_data/image__vector_store.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d17ed74c1649a438e518a8dc56a7772913dfe1ea7a7605bce069c63872431455
3
+ size 72
storage/open_ai_embedding_data/index_store.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd52cfd6aba4fb5c0774b7e3d38ddcda21e0cf5344a86c8eaf7c8690bb451bcd
3
+ size 589927
storage/open_ai_embedding_data_large/default__vector_store.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ad49abac9bff5c529bb2446b985e7ae14a74328d6d2293f6d421326b3851538
3
+ size 487945734
storage/open_ai_embedding_data_large/docstore.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3983b645d0aefc3d92158d573cb9bc4d3f79077a77066e140d8e725dd7e085b5
3
+ size 29962555
storage/open_ai_embedding_data_large/graph_store.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8e0a77744010862225c69da83c585f4f8a42fd551b044ce530dbb1eb6e16742c
3
+ size 18
storage/open_ai_embedding_data_large/image__vector_store.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d17ed74c1649a438e518a8dc56a7772913dfe1ea7a7605bce069c63872431455
3
+ size 72
storage/open_ai_embedding_data_large/index_store.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d6a7b6dad9b8f418dfd26132e54203b8dca1374dc8e8c3199d5e9d001816f3cf
3
+ size 589927