kshitijkumbar commited on
Commit
3107845
1 Parent(s): f937bd0

Init commit

Browse files
Files changed (3) hide show
  1. app.py +152 -0
  2. data/books_summary.txt +0 -0
  3. requirements.txt +180 -0
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ from llama_index.core import(SimpleDirectoryReader,
4
+ VectorStoreIndex, StorageContext,
5
+ Settings,set_global_tokenizer)
6
+ from llama_index.llms.llama_cpp import LlamaCPP
7
+ from llama_index.llms.llama_cpp.llama_utils import (
8
+ messages_to_prompt,
9
+ completion_to_prompt,
10
+ )
11
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
12
+ from transformers import AutoTokenizer, BitsAndBytesConfig
13
+ from llama_index.llms.huggingface import HuggingFaceLLM
14
+ import torch
15
+ import logging
16
+ import sys
17
+ import streamlit as st
18
+
19
+ default_bnb_config = BitsAndBytesConfig(
20
+ load_in_4bit=True,
21
+ bnb_4bit_quant_type='nf4',
22
+ bnb_4bit_use_double_quant=True,
23
+ bnb_4bit_compute_dtype=torch.bfloat16
24
+ )
25
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
26
+ logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
27
+ set_global_tokenizer(
28
+ AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf").encode
29
+ )
30
+
31
+
32
+ def getDocs(doc_path="./data/"):
33
+ documents = SimpleDirectoryReader(doc_path).load_data()
34
+ return documents
35
+
36
+
37
+ def getVectorIndex(docs):
38
+ Settings.chunk_size = 512
39
+ index_set = {}
40
+
41
+ storage_context = StorageContext.from_defaults()
42
+ cur_index = VectorStoreIndex.from_documents(docs, embed_model = getEmbedModel())
43
+ storage_context.persist(persist_dir=f"./storage/book_data")
44
+ return cur_index
45
+
46
+
47
+ def getLLM():
48
+
49
+
50
+ llm = HuggingFaceLLM(
51
+ context_window=3900,
52
+ max_new_tokens=256,
53
+ # generate_kwargs={"temperature": 0.25, "do_sample": False},
54
+ tokenizer_name="meta-llama/Llama-2-13b-chat-hf",
55
+ model_name="meta-llama/Llama-2-13b-chat-hf",
56
+ device_map=0,
57
+ tokenizer_kwargs={"max_length": 2048},
58
+ # uncomment this if using CUDA to reduce memory usage
59
+ model_kwargs={"torch_dtype": torch.float16,
60
+ "quantization_config": default_bnb_config,
61
+ }
62
+ )
63
+ return llm
64
+
65
+ def getQueryEngine(index):
66
+ query_engine = index.as_chat_engine(llm=getLLM())
67
+ return query_engine
68
+
69
+ def getEmbedModel():
70
+ embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
71
+ return embed_model
72
+
73
+
74
+
75
+
76
+
77
+
78
+
79
+
80
+
81
+
82
+
83
+ st.set_page_config(page_title="Chat with the Streamlit docs, powered by LlamaIndex", page_icon="🦙", layout="centered", initial_sidebar_state="auto", menu_items=None)
84
+ st.title("Chat with the Streamlit docs, powered by LlamaIndex 💬🦙")
85
+ st.info("Check out the full tutorial to build this app in our [blog post](https://blog.streamlit.io/build-a-chatbot-with-custom-data-sources-powered-by-llamaindex/)", icon="📃")
86
+
87
+ if "messages" not in st.session_state.keys(): # Initialize the chat messages history
88
+ st.session_state.messages = [
89
+ {"role": "assistant", "content": "Ask me a question about children's books or movies!"}
90
+ ]
91
+
92
+ @st.cache_resource(show_spinner=False)
93
+ def load_data():
94
+ index = getVectorIndex(getDocs())
95
+ return index
96
+ query_engine = getQueryEngine(index)
97
+
98
+ index = load_data()
99
+
100
+ if "chat_engine" not in st.session_state.keys(): # Initialize the chat engine
101
+ st.session_state.chat_engine = index.as_chat_engine(llm=getLLM(),chat_mode="condense_question", verbose=True)
102
+
103
+ if prompt := st.chat_input("Your question"): # Prompt for user input and save to chat history
104
+ st.session_state.messages.append({"role": "user", "content": prompt})
105
+
106
+ for message in st.session_state.messages: # Display the prior chat messages
107
+ with st.chat_message(message["role"]):
108
+ st.write(message["content"])
109
+
110
+ # If last message is not from assistant, generate a new response
111
+ if st.session_state.messages[-1]["role"] != "assistant":
112
+ with st.chat_message("assistant"):
113
+ with st.spinner("Thinking..."):
114
+ response = st.session_state.chat_engine.chat(prompt)
115
+ st.write(response.response)
116
+ message = {"role": "assistant", "content": response.response}
117
+ st.session_state.messages.append(message) # Add response to message history
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+
130
+
131
+
132
+
133
+
134
+
135
+ # if __name__ == "__main__":
136
+
137
+ # index = getVectorIndex(getDocs())
138
+ # query_engine = getQueryEngine(index)
139
+ # while(True):
140
+ # your_request = input("Your comment: ")
141
+ # response = query_engine.chat(your_request)
142
+ # print(response)
143
+
144
+
145
+
146
+
147
+
148
+
149
+
150
+
151
+
152
+
data/books_summary.txt ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.27.2
2
+ aiohttp==3.9.3
3
+ aiosignal==1.3.1
4
+ altair==5.2.0
5
+ annotated-types==0.6.0
6
+ anyio==4.3.0
7
+ asgiref==3.7.2
8
+ attrs==23.2.0
9
+ backoff==2.2.1
10
+ bcrypt==4.1.2
11
+ beautifulsoup4==4.12.3
12
+ bitsandbytes==0.42.0
13
+ blinker==1.7.0
14
+ bs4==0.0.2
15
+ build==1.1.1
16
+ cachetools==5.3.3
17
+ certifi==2024.2.2
18
+ charset-normalizer==3.3.2
19
+ chroma-hnswlib==0.7.3
20
+ chromadb==0.4.24
21
+ click==8.1.7
22
+ coloredlogs==15.0.1
23
+ dataclasses-json==0.6.4
24
+ Deprecated==1.2.14
25
+ dirtyjson==1.0.8
26
+ diskcache==5.6.3
27
+ distro==1.9.0
28
+ fastapi==0.110.0
29
+ filelock==3.13.1
30
+ flatbuffers==23.5.26
31
+ frozenlist==1.4.1
32
+ fsspec==2024.2.0
33
+ gitdb==4.0.11
34
+ GitPython==3.1.42
35
+ google-auth==2.28.1
36
+ googleapis-common-protos==1.62.0
37
+ greenlet==3.0.3
38
+ grpcio==1.62.0
39
+ h11==0.14.0
40
+ httpcore==1.0.4
41
+ httptools==0.6.1
42
+ httpx==0.27.0
43
+ huggingface-hub==0.20.3
44
+ humanfriendly==10.0
45
+ idna==3.6
46
+ importlib-metadata==6.11.0
47
+ importlib_resources==6.1.2
48
+ install==1.3.5
49
+ Jinja2==3.1.3
50
+ joblib==1.3.2
51
+ jsonschema==4.21.1
52
+ jsonschema-specifications==2023.12.1
53
+ kubernetes==29.0.0
54
+ llama-index==0.10.15
55
+ llama-index-agent-openai==0.1.5
56
+ llama-index-cli==0.1.7
57
+ llama-index-core==0.10.15
58
+ llama-index-embeddings-huggingface==0.1.4
59
+ llama-index-embeddings-openai==0.1.6
60
+ llama-index-indices-managed-llama-cloud==0.1.3
61
+ llama-index-legacy==0.9.48
62
+ llama-index-llms-huggingface==0.1.3
63
+ llama-index-llms-llama-cpp==0.1.3
64
+ llama-index-llms-openai==0.1.7
65
+ llama-index-multi-modal-llms-openai==0.1.4
66
+ llama-index-program-openai==0.1.4
67
+ llama-index-question-gen-openai==0.1.3
68
+ llama-index-readers-file==0.1.6
69
+ llama-index-readers-llama-parse==0.1.3
70
+ llama-index-vector-stores-chroma==0.1.5
71
+ llama-parse==0.3.5
72
+ llama_cpp_python==0.2.55
73
+ llamaindex-py-client==0.1.13
74
+ markdown-it-py==3.0.0
75
+ MarkupSafe==2.1.5
76
+ marshmallow==3.21.0
77
+ mdurl==0.1.2
78
+ mmh3==4.1.0
79
+ monotonic==1.6
80
+ mpmath==1.3.0
81
+ multidict==6.0.5
82
+ mypy-extensions==1.0.0
83
+ nest-asyncio==1.6.0
84
+ networkx==3.2.1
85
+ nltk==3.8.1
86
+ numpy==1.26.4
87
+ nvidia-cublas-cu12==12.1.3.1
88
+ nvidia-cuda-cupti-cu12==12.1.105
89
+ nvidia-cuda-nvrtc-cu12==12.1.105
90
+ nvidia-cuda-runtime-cu12==12.1.105
91
+ nvidia-cudnn-cu12==8.9.2.26
92
+ nvidia-cufft-cu12==11.0.2.54
93
+ nvidia-curand-cu12==10.3.2.106
94
+ nvidia-cusolver-cu12==11.4.5.107
95
+ nvidia-cusparse-cu12==12.1.0.106
96
+ nvidia-nccl-cu12==2.19.3
97
+ nvidia-nvjitlink-cu12==12.3.101
98
+ nvidia-nvtx-cu12==12.1.105
99
+ oauthlib==3.2.2
100
+ onnxruntime==1.17.1
101
+ openai==1.13.3
102
+ opentelemetry-api==1.23.0
103
+ opentelemetry-exporter-otlp-proto-common==1.23.0
104
+ opentelemetry-exporter-otlp-proto-grpc==1.23.0
105
+ opentelemetry-instrumentation==0.44b0
106
+ opentelemetry-instrumentation-asgi==0.44b0
107
+ opentelemetry-instrumentation-fastapi==0.44b0
108
+ opentelemetry-proto==1.23.0
109
+ opentelemetry-sdk==1.23.0
110
+ opentelemetry-semantic-conventions==0.44b0
111
+ opentelemetry-util-http==0.44b0
112
+ orjson==3.9.15
113
+ overrides==7.7.0
114
+ packaging==23.2
115
+ pandas==2.2.1
116
+ pillow==10.2.0
117
+ posthog==3.4.2
118
+ protobuf==4.25.3
119
+ psutil==5.9.8
120
+ pulsar-client==3.4.0
121
+ pyarrow==15.0.0
122
+ pyasn1==0.5.1
123
+ pyasn1-modules==0.3.0
124
+ pydantic==2.6.3
125
+ pydantic_core==2.16.3
126
+ pydeck==0.8.1b0
127
+ Pygments==2.17.2
128
+ PyMuPDF==1.23.26
129
+ PyMuPDFb==1.23.22
130
+ pypdf==4.1.0
131
+ PyPika==0.48.9
132
+ pyproject_hooks==1.0.0
133
+ python-dateutil==2.9.0.post0
134
+ python-dotenv==1.0.1
135
+ pytz==2024.1
136
+ PyYAML==6.0.1
137
+ referencing==0.33.0
138
+ regex==2023.12.25
139
+ requests==2.31.0
140
+ requests-oauthlib==1.3.1
141
+ rich==13.7.1
142
+ rpds-py==0.18.0
143
+ rsa==4.9
144
+ safetensors==0.4.2
145
+ scipy==1.12.0
146
+ setuptools==68.2.2
147
+ six==1.16.0
148
+ smmap==5.0.1
149
+ sniffio==1.3.1
150
+ soupsieve==2.5
151
+ SQLAlchemy==2.0.27
152
+ starlette==0.36.3
153
+ streamlit==1.31.1
154
+ sympy==1.12
155
+ tenacity==8.2.3
156
+ tiktoken==0.6.0
157
+ tokenizers==0.15.2
158
+ toml==0.10.2
159
+ toolz==0.12.1
160
+ torch==2.2.1
161
+ tornado==6.4
162
+ tqdm==4.66.2
163
+ transformers==4.38.2
164
+ typer==0.9.0
165
+ typing-inspect==0.9.0
166
+ typing_extensions==4.10.0
167
+ tzdata==2024.1
168
+ tzlocal==5.2
169
+ urllib3==2.2.1
170
+ uvicorn==0.27.1
171
+ uvloop==0.19.0
172
+ validators==0.22.0
173
+ watchdog==4.0.0
174
+ watchfiles==0.21.0
175
+ websocket-client==1.7.0
176
+ websockets==12.0
177
+ wheel==0.41.2
178
+ wrapt==1.16.0
179
+ yarl==1.9.4
180
+ zipp==3.17.0