ffeew committed
Commit
f68c440
1 Parent(s): d9a5eeb
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.gguf filter=lfs diff=lfs merge=lfs -text
+ chromadb/**/* filter=lfs diff=lfs merge=lfs -text
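Both new patterns route the large binaries (the GGUF model and the Chroma index) through Git LFS, so the repository stores small pointer files instead of multi-gigabyte blobs. A minimal Python sketch, not part of this commit, for checking whether a checked-out file is still an LFS pointer rather than the real payload (the path is the model file added below):

# Minimal sketch: detect whether a file on disk is a Git LFS pointer.
# Pointer files begin with the spec line visible in the blobs added below.
from pathlib import Path

def is_lfs_pointer(path: Path) -> bool:
    with path.open("rb") as f:
        head = f.read(48)  # pointer files are ~130 bytes; 48 covers the spec line
    return head.startswith(b"version https://git-lfs.github.com/spec/v1")

# True if the model has not been fetched with `git lfs pull`
print(is_lfs_pointer(Path("mistral-7b-openorca.Q5_K_M.gguf")))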
app.py ADDED
@@ -0,0 +1,40 @@
+ import streamlit as st
+ from utils import st_load_retriever, st_load_llm, StreamHandler
+ from langchain.chains import RetrievalQAWithSourcesChain
+
+ st.title("AIxplorer - A Smarter Google Scholar 🌐📚")
+ st.write(
+     "AIxplorer aims to revolutionize academic research by combining the capabilities of traditional search engines like Google Scholar with an advanced retrieval augmented generation (RAG) system. Built on Python and Langchain, this application provides highly relevant and context-aware academic papers, journals, and articles, elevating the standard of academic research."
+ )
+
+
+ st.divider()
+ st.subheader("Settings")
+ col1, col2, col3 = st.columns(3)
+
+ with col1:
+     use_google = st.checkbox(
+         "Use Google Search",
+         value=True,
+         help="Use Google Search to retrieve papers. If unchecked, will use the vector database.",
+     )
+ st.divider()
+
+ llm = st_load_llm()
+ retriever = st_load_retriever(llm, "vectordb" if not use_google else "google search")
+
+ qa_chain = RetrievalQAWithSourcesChain.from_chain_type(llm, retriever=retriever)
+
+ user_input = st.text_area(
+     "Enter your query here",
+     help="Query should be on computer science as the RAG system is tuned to that domain.",
+ )
+
+
+ if st.button("Generate"):
+     st.divider()
+     st.subheader("Answer:")
+     with st.spinner("Generating..."):
+         container = st.empty()
+         stream_handler = StreamHandler(container)
+         response = qa_chain({"question": user_input}, callbacks=[stream_handler])
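When "Use Google Search" is checked, the retriever relies on LangChain's GoogleSearchAPIWrapper, which reads GOOGLE_API_KEY and GOOGLE_CSE_ID from the environment (loaded via load_dotenv() in utils.py). A minimal pre-flight check, not part of the commit, that fails fast when those credentials are missing:

# Hypothetical pre-flight check before enabling the Google Search mode.
import os
from dotenv import load_dotenv

load_dotenv()
missing = [k for k in ("GOOGLE_API_KEY", "GOOGLE_CSE_ID") if not os.getenv(k)]
if missing:
    raise RuntimeError(f"Missing in .env: {', '.join(missing)}")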
chromadb/ccdaf353-2f96-4472-a625-909323352d4d/data_level0.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e2fccfc026bd403c900c6d79e1a155e0d5c63c8c755d6f3ca371a34d8cfd03c7
+ size 946940000
chromadb/ccdaf353-2f96-4472-a625-909323352d4d/header.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:74b87ee33bb338a0109cb825f43f7141e3feed200cbbc373fc7668e7795ae669
+ size 100
chromadb/ccdaf353-2f96-4472-a625-909323352d4d/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:405a464b08d974ed544ca62f07d73f3cd22c605c746bd196ea7914068a24271a
+ size 35773057
chromadb/ccdaf353-2f96-4472-a625-909323352d4d/length.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b957a02ed53e8b1db4e5bd8a91a0a2c0fbf08c8988ffa5aac911b2081a81b11e
+ size 2260000
chromadb/ccdaf353-2f96-4472-a625-909323352d4d/link_lists.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5b483c64357659a9b0a518dac20566ea4a10212e59b666033df8115e144f405d
+ size 4833528
chromadb/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ac92ff6eeb03659dbceca7402c8107a3f44084bc753b28e6bac39c817c0712f6
+ size 5545873408
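The chromadb/ directory is a persisted Chroma collection: the UUID folder holds the HNSW index segments (data_level0.bin, header.bin, length.bin, link_lists.bin) and chroma.sqlite3 holds documents and metadata. A hedged sketch for inspecting it directly with the pinned chromadb==0.4.20 API, assuming the LFS blobs have been pulled:

# Minimal sketch: peek at the persisted collection shipped in this commit.
import chromadb

client = chromadb.PersistentClient(path="./chromadb")
for col in client.list_collections():
    print(col.name, col.count())  # expect the "cs_paper_store" collection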
config.py ADDED
@@ -0,0 +1,19 @@
+ vector_db_path = "./chromadb"  # path to the vector database
+
+ embeddings_model = "BAAI/bge-small-en"  # embeddings model used to generate vectors
+
+ llm_path = "./mistral-7b-openorca.Q5_K_M.gguf"  # path to the LLM model
+
+ device = "cpu"  # device to use for the LLM model, "cuda" or "cpu"
+
+ n_gpu_layers = 0  # change this value based on your model and GPU VRAM pool; keep at 0 when running on CPU
+
+ n_batch = 256  # should be between 1 and n_ctx; consider your GPU's VRAM
+
+ context_length = 8000  # context length to use for the LLM model
+
+ temperature = 0.0  # temperature to use for the LLM model
+
+ top_p = 1.0  # top_p to use for the LLM model
+
+ max_tokens = 2000  # maximum number of tokens to generate from the LLM model
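The defaults above assume CPU-only inference. If a CUDA build of llama_cpp_python is installed, offloading is a matter of editing these same values; the numbers below are illustrative and depend on available VRAM, not something this commit ships:

# Hypothetical GPU overrides for config.py (tune to your VRAM; not in this commit).
device = "cuda"    # run the embeddings model on GPU
n_gpu_layers = 35  # enough to offload all of Mistral-7B's 32 transformer blocks plus extras
n_batch = 512      # larger batches speed up prompt ingestion when memory allows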
mistral-7b-openorca.Q5_K_M.gguf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:12a7c2d08be7c718a28c73115c321d91918a3fdef27de1da9f38b4079056773e
+ size 5131421440
requirements.txt ADDED
@@ -0,0 +1,161 @@
+ aiohttp==3.9.1
+ aiosignal==1.3.1
+ altair==5.2.0
+ annotated-types==0.6.0
+ anyio==3.7.1
+ asgiref==3.7.2
+ attrs==23.1.0
+ backoff==2.2.1
+ bcrypt==4.1.2
+ beautifulsoup4==4.12.2
+ blinker==1.7.0
+ cachetools==5.3.2
+ certifi==2023.11.17
+ charset-normalizer==3.3.2
+ chroma-hnswlib==0.7.3
+ chromadb==0.4.20
+ click==8.1.7
+ coloredlogs==15.0.1
+ dataclasses-json==0.6.3
+ Deprecated==1.2.14
+ diskcache==5.6.3
+ fastapi==0.105.0
+ filelock==3.13.1
+ flatbuffers==23.5.26
+ frozenlist==1.4.1
+ fsspec==2023.12.2
+ gitdb==4.0.11
+ GitPython==3.1.40
+ google-api-core==2.15.0
+ google-api-python-client==2.111.0
+ google-auth==2.25.2
+ google-auth-httplib2==0.2.0
+ googleapis-common-protos==1.62.0
+ greenlet==3.0.2
+ grpcio==1.60.0
+ h11==0.14.0
+ html2text==2020.1.16
+ httplib2==0.22.0
+ httptools==0.6.1
+ huggingface-hub==0.20.1
+ humanfriendly==10.0
+ idna==3.6
+ importlib-metadata==6.11.0
+ importlib-resources==6.1.1
+ Jinja2==3.1.2
+ joblib==1.3.2
+ jsonpatch==1.33
+ jsonpointer==2.4
+ jsonschema==4.20.0
+ jsonschema-specifications==2023.11.2
+ kubernetes==28.1.0
+ langchain==0.0.352
+ langchain-community==0.0.5
+ langchain-core==0.1.2
+ langsmith==0.0.72
+ llama_cpp_python==0.2.24
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.3
+ marshmallow==3.20.1
+ mdurl==0.1.2
+ mmh3==4.0.1
+ monotonic==1.6
+ mpmath==1.3.0
+ multidict==6.0.4
+ mypy-extensions==1.0.0
+ networkx==3.2.1
+ nltk==3.8.1
+ numpy==1.26.2
+ nvidia-cublas-cu12==12.1.3.1
+ nvidia-cuda-cupti-cu12==12.1.105
+ nvidia-cuda-nvrtc-cu12==12.1.105
+ nvidia-cuda-runtime-cu12==12.1.105
+ nvidia-cudnn-cu12==8.9.2.26
+ nvidia-cufft-cu12==11.0.2.54
+ nvidia-curand-cu12==10.3.2.106
+ nvidia-cusolver-cu12==11.4.5.107
+ nvidia-cusparse-cu12==12.1.0.106
+ nvidia-nccl-cu12==2.18.1
+ nvidia-nvjitlink-cu12==12.3.101
+ nvidia-nvtx-cu12==12.1.105
+ oauthlib==3.2.2
+ onnxruntime==1.16.3
+ opentelemetry-api==1.22.0
+ opentelemetry-exporter-otlp-proto-common==1.22.0
+ opentelemetry-exporter-otlp-proto-grpc==1.22.0
+ opentelemetry-instrumentation==0.43b0
+ opentelemetry-instrumentation-asgi==0.43b0
+ opentelemetry-instrumentation-fastapi==0.43b0
+ opentelemetry-proto==1.22.0
+ opentelemetry-sdk==1.22.0
+ opentelemetry-semantic-conventions==0.43b0
+ opentelemetry-util-http==0.43b0
+ overrides==7.4.0
+ packaging==23.2
+ pandas==2.1.4
+ Pillow==10.1.0
+ posthog==3.1.0
+ protobuf==4.25.1
+ pulsar-client==3.3.0
+ pyarrow==14.0.2
+ pyasn1==0.5.1
+ pyasn1-modules==0.3.0
+ pydantic==2.5.2
+ pydantic_core==2.14.5
+ pydeck==0.8.1b0
+ Pygments==2.17.2
+ pyparsing==3.1.1
+ PyPika==0.48.9
+ python-dateutil==2.8.2
+ python-dotenv==1.0.0
+ pytz==2023.3.post1
+ PyYAML==6.0.1
+ referencing==0.32.0
+ regex==2023.10.3
+ requests==2.31.0
+ requests-oauthlib==1.3.1
+ rich==13.7.0
+ rpds-py==0.15.2
+ rsa==4.9
+ safetensors==0.4.1
+ scikit-learn==1.3.2
+ scipy==1.11.4
+ sentence-transformers==2.2.2
+ sentencepiece==0.1.99
+ six==1.16.0
+ smmap==5.0.1
+ sniffio==1.3.0
+ soupsieve==2.5
+ SQLAlchemy==2.0.23
+ starlette==0.27.0
+ streamlit==1.29.0
+ sympy==1.12
+ tenacity==8.2.3
+ threadpoolctl==3.2.0
+ tiktoken==0.5.2
+ tokenizers==0.15.0
+ toml==0.10.2
+ toolz==0.12.0
+ torch==2.1.2
+ torchvision==0.16.2
+ tornado==6.4
+ tqdm==4.66.1
+ transformers==4.36.2
+ triton==2.1.0
+ typer==0.9.0
+ typing-inspect==0.9.0
+ typing_extensions==4.9.0
+ tzdata==2023.3
+ tzlocal==5.2
+ uritemplate==4.1.1
+ urllib3==1.26.18
+ uvicorn==0.24.0.post1
+ uvloop==0.19.0
+ validators==0.22.0
+ watchdog==3.0.0
+ watchfiles==0.21.0
+ websocket-client==1.7.0
+ websockets==12.0
+ wrapt==1.16.0
+ yarl==1.9.4
+ zipp==3.17.0
utils.py ADDED
@@ -0,0 +1,76 @@
+ import streamlit as st
+ from langchain.llms import LlamaCpp
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import Chroma
+ from langchain.retrievers.web_research import WebResearchRetriever
+ from langchain.utilities import GoogleSearchAPIWrapper
+ from dotenv import load_dotenv
+ import config
+ from langchain.callbacks.base import BaseCallbackHandler
+
+
+ class StreamHandler(BaseCallbackHandler):
+     """Streams LLM tokens into a Streamlit container as they are generated."""
+
+     def __init__(self, container, initial_text=""):
+         self.container = container
+         self.text = initial_text
+
+     def on_llm_new_token(self, token: str, **kwargs) -> None:
+         self.text += token
+         self.container.markdown(self.text)
+
+
+ @st.cache_resource
+ def st_load_retriever(_llm, mode):
+     # device for the embeddings model, taken from config.py ("cuda" or "cpu")
+     model_kwargs = {"device": config.device}
+     embeddings_model = HuggingFaceEmbeddings(
+         model_name=config.embeddings_model,
+         model_kwargs=model_kwargs,
+     )
+
+     vector_store = Chroma(
+         "cs_paper_store",
+         embeddings_model,
+         persist_directory=config.vector_db_path,
+     )
+
+     if mode == "vectordb":
+         # query the local vector store directly
+         return vector_store.as_retriever()
+
+     elif mode == "google search":
+         load_dotenv()
+         search = GoogleSearchAPIWrapper()
+         web_research_retriever = WebResearchRetriever.from_llm(
+             vectorstore=vector_store, llm=_llm, search=search
+         )
+         return web_research_retriever
+
+     else:
+         raise ValueError(f"Unknown retrieval mode: {mode}")
+
+
+ @st.cache_resource
+ def st_load_llm(
+     temperature=config.temperature,
+     max_tokens=config.max_tokens,
+     top_p=config.top_p,
+     llm_path=config.llm_path,
+     context_length=config.context_length,
+     n_gpu_layers=config.n_gpu_layers,
+     n_batch=config.n_batch,
+ ):
+     llm = LlamaCpp(
+         model_path=llm_path,
+         temperature=temperature,
+         max_tokens=max_tokens,
+         n_ctx=context_length,
+         n_gpu_layers=n_gpu_layers,
+         n_batch=n_batch,
+         top_p=top_p,
+         verbose=False,
+     )
+
+     return llm
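For quick testing outside the Streamlit page, the same loaders can be driven from a plain script; st.cache_resource only warns when no Streamlit runtime is present. A minimal sketch, assuming it runs from the repo root with the model and chromadb blobs pulled; FakeContainer is a hypothetical stand-in for the st.empty() container that StreamHandler expects:

# Minimal sketch: exercise the loaders and the QA chain without the UI.
from utils import StreamHandler, st_load_llm, st_load_retriever
from langchain.chains import RetrievalQAWithSourcesChain

class FakeContainer:
    """Hypothetical stand-in for st.empty(): echoes the accumulating answer."""
    def markdown(self, text):
        print(text[-120:], end="\r")

llm = st_load_llm()
retriever = st_load_retriever(llm, "vectordb")
chain = RetrievalQAWithSourcesChain.from_chain_type(llm, retriever=retriever)
result = chain({"question": "What is retrieval augmented generation?"},
               callbacks=[StreamHandler(FakeContainer())])
print("\n", result["answer"], "\nSources:", result["sources"])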