pragneshbarik commited on
Commit
831e906
·
1 Parent(s): c80f6c0

inital commit

Browse files
.gitignore ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Win executables
10
+ *.exe
11
+ rag-env/
12
+ mixtral-playground/
13
+
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+ cover/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+ db.sqlite3-journal
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ .pybuilder/
81
+ target/
82
+
83
+ # Jupyter Notebook
84
+ .ipynb_checkpoints
85
+
86
+ # IPython
87
+ profile_default/
88
+ ipython_config.py
89
+
90
+ # pyenv
91
+ # For a library or package, you might want to ignore these files since the code is
92
+ # intended to run in multiple environments; otherwise, check them in:
93
+ # .python-version
94
+
95
+ # pipenv
96
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
98
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
99
+ # install all needed dependencies.
100
+ #Pipfile.lock
101
+
102
+ # poetry
103
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
105
+ # commonly ignored for libraries.
106
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107
+ #poetry.lock
108
+
109
+ # pdm
110
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111
+ #pdm.lock
112
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113
+ # in version control.
114
+ # https://pdm.fming.dev/#use-with-ide
115
+ .pdm.toml
116
+
117
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118
+ __pypackages__/
119
+
120
+ # Celery stuff
121
+ celerybeat-schedule
122
+ celerybeat.pid
123
+
124
+ # SageMath parsed files
125
+ *.sage.py
126
+
127
+ # Environments
128
+ .env
129
+ .venv
130
+ env/
131
+ venv/
132
+ ENV/
133
+ env.bak/
134
+ venv.bak/
135
+
136
+ # Spyder project settings
137
+ .spyderproject
138
+ .spyproject
139
+
140
+ # Rope project settings
141
+ .ropeproject
142
+
143
+ # mkdocs documentation
144
+ /site
145
+
146
+ # mypy
147
+ .mypy_cache/
148
+ .dmypy.json
149
+ dmypy.json
150
+
151
+ # Pyre type checker
152
+ .pyre/
153
+
154
+ # pytype static type analyzer
155
+ .pytype/
156
+
157
+ # Cython debug symbols
158
+ cython_debug/
159
+
160
+ # PyCharm
161
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
164
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165
+ #.idea/
Dockerfile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9-slim
2
+
3
+ WORKDIR /app
4
+
5
+ RUN apt-get update && apt-get install -y \
6
+ build-essential \
7
+ curl \
8
+ software-properties-common \
9
+ git \
10
+ && rm -rf /var/lib/apt/lists/*
11
+
12
+ COPY components /app/
13
+ COPY middlewares /app/
14
+ COPY app.py /app/
15
+ COPY requirements.txt /app/
16
+ COPY config.yaml /app/
17
+
18
+ RUN pip3 install -r requirements.txt
19
+
20
+ EXPOSE 8501
21
+
22
+ ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
Manage ADDED
File without changes
README.md CHANGED
@@ -1,13 +1,62 @@
1
  ---
2
- title: Mixtral Search
3
- emoji: 🏢
4
- colorFrom: yellow
5
- colorTo: purple
6
  sdk: streamlit
7
- sdk_version: 1.30.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
 
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Ebook Gen
3
+ emoji: 🐢
4
+ colorFrom: pink
5
+ colorTo: gray
6
  sdk: streamlit
7
+ sdk_version: 1.29.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
12
+ # Mixtral Playground
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
15
+
16
+
17
+ ## Python Setup
18
+
19
+ Follow these steps to set up the environment:
20
+
21
+ 1. Clone this repository.
22
+ 2. Create and activate a new virtual environment using `conda` or `venv`. If you need guidance, refer to this [tutorial](https://www.perplexity.ai/search/how-to-make-GnFc09yGTvSyka0ZWqhSQg?s=c).
23
+ 3. Create a `.env` file to store API credentials. You'll need these four credentials:
24
+
25
+ ```
26
+ HF_TOKEN = ...
27
+ GOOGLE_SEARCH_ENGINE_ID = ...
28
+ GOOGLE_SEARCH_API_KEY = ...
29
+ BING_SEARCH_API_KEY = ...
30
+ ```
31
+
32
+ 4. Install the necessary requirements:
33
+ ```
34
+ pip install -r requirements.txt --user
35
+ ```
36
+ 5. Start the Streamlit server using either command:
37
+ ```
38
+ streamlit run app.py
39
+ ```
40
+ or
41
+ ```
42
+ python -m streamlit run app.py
43
+ ```
44
+
45
+
46
+ ## Docker Setup
47
+
48
+ If you prefer using Docker, follow these steps:
49
+
50
+ 1. Clone the repository.
51
+ 2. Create a `.env` file to store API credentials, similar to the Python setup.
52
+
53
+ 3. Build docker image using
54
+
55
+ ```
56
+ docker build -t mixtral-playground .
57
+ ```
58
+ 4. Run the image using
59
+
60
+ ```
61
+ docker run --env-file .env -p 8501:8501 mixtral-playground
62
+ ```
app.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ import streamlit as st
3
+ from components.sidebar import sidebar
4
+ from components.chat_box import chat_box
5
+ from components.chat_loop import chat_loop
6
+ from components.init_state import init_state
7
+ from components.prompt_engineering_dashboard import prompt_engineering_dashboard
8
+
9
+
10
+
11
+ with open("config.yaml", "r") as file:
12
+ config = yaml.safe_load(file)
13
+
14
+ st.set_page_config(
15
+ page_title="Mixtral Search Engine",
16
+ page_icon="📚",
17
+ )
18
+
19
+
20
+ init_state(st.session_state, config)
21
+
22
+ st.write("# Mixtral Search Engine")
23
+
24
+ # Prompt Engineering Dashboard is working but not for production, works great for testing.
25
+ # prompt_engineering_dashboard(st.session_state, config)
26
+
27
+
28
+ sidebar(st.session_state, config)
29
+
30
+ chat_box(st.session_state, config)
31
+
32
+ chat_loop(st.session_state, config)
components/__init__.py ADDED
File without changes
components/chat_box.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+
4
+ def chat_box(session_state, config):
5
+ for message in session_state.messages:
6
+ with st.chat_message(message["role"]):
7
+ st.markdown(message["content"])
components/chat_loop.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from components.generate_chat_stream import generate_chat_stream
3
+ from components.stream_handler import stream_handler
4
+ from components.show_source import show_source
5
+
6
+
7
+ def chat_loop(session_state, config):
8
+ if prompt := st.chat_input("Search the web..."):
9
+ st.chat_message("user").markdown(prompt)
10
+ session_state.messages.append({"role": "user", "content": prompt})
11
+
12
+ chat_stream, links = generate_chat_stream(session_state, prompt, config)
13
+
14
+ with st.chat_message("assistant"):
15
+ placeholder = st.empty()
16
+ full_response = stream_handler(
17
+ session_state, chat_stream, prompt, placeholder
18
+ )
19
+ if session_state.rag_enabled:
20
+ show_source(links)
21
+
22
+ session_state.history.append([prompt, full_response])
23
+ session_state.messages.append({"role": "assistant", "content": full_response})
components/generate_chat_stream.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from middlewares.utils import gen_augmented_prompt_via_websearch
3
+ from middlewares.chat_client import chat
4
+
5
+
6
+ def generate_chat_stream(session_state, prompt, config):
7
+ # 1. augments prompt according to the template
8
+ # 2. returns chat_stream and source links
9
+ # 3. chat_stream and source links are used by stream_handler and show_source
10
+ chat_bot_dict = config["CHAT_BOTS"]
11
+ links = []
12
+ if session_state.rag_enabled:
13
+ with st.spinner("Fetching relevent documents from Web...."):
14
+ prompt, links = gen_augmented_prompt_via_websearch(
15
+ prompt=prompt,
16
+ pre_context=session_state.pre_context,
17
+ post_context=session_state.post_context,
18
+ pre_prompt=session_state.pre_prompt,
19
+ post_prompt=session_state.post_prompt,
20
+ search_vendor=session_state.search_vendor,
21
+ top_k=session_state.top_k,
22
+ n_crawl=session_state.n_crawl,
23
+ pass_prev=session_state.pass_prev,
24
+ prev_output=session_state.history[-1][1],
25
+ )
26
+
27
+ with st.spinner("Generating response..."):
28
+ chat_stream = chat(
29
+ prompt,
30
+ session_state.history,
31
+ chat_client=chat_bot_dict[session_state.chat_bot],
32
+ temperature=session_state.temp,
33
+ max_new_tokens=session_state.max_tokens,
34
+ )
35
+
36
+ return chat_stream, links
components/init_state.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def init_state(session_state, config):
2
+ initial_prompt_engineering_dict = config["PROMPT_ENGINEERING_DICT"]
3
+ if "messages" not in session_state:
4
+ session_state.messages = []
5
+
6
+ if "tokens_used" not in session_state:
7
+ session_state.tokens_used = 0
8
+
9
+ if "tps" not in session_state:
10
+ session_state.tps = 0
11
+
12
+ if "temp" not in session_state:
13
+ session_state.temp = 0.8
14
+
15
+ if "history" not in session_state:
16
+ session_state.history = [
17
+ [
18
+ initial_prompt_engineering_dict["SYSTEM_INSTRUCTION"],
19
+ initial_prompt_engineering_dict["SYSTEM_RESPONSE"],
20
+ ]
21
+ ]
22
+
23
+ if "n_crawl" not in session_state:
24
+ session_state.n_crawl = 5
25
+
26
+ if "repetion_penalty" not in session_state:
27
+ session_state.repetion_penalty = 1
28
+
29
+ if "rag_enabled" not in session_state:
30
+ session_state.rag_enabled = True
31
+
32
+ if "chat_bot" not in session_state:
33
+ session_state.chat_bot = "Mixtral 8x7B v0.1"
34
+
35
+ if "search_vendor" not in session_state:
36
+ session_state.search_vendor = "Bing"
37
+
38
+ if "system_instruction" not in session_state:
39
+ session_state.system_instruction = initial_prompt_engineering_dict[
40
+ "SYSTEM_INSTRUCTION"
41
+ ]
42
+
43
+ if "system_response" not in session_state:
44
+ session_state.system_instruction = initial_prompt_engineering_dict[
45
+ "SYSTEM_RESPONSE"
46
+ ]
47
+
48
+ if "pre_context" not in session_state:
49
+ session_state.pre_context = initial_prompt_engineering_dict["PRE_CONTEXT"]
50
+
51
+ if "post_context" not in session_state:
52
+ session_state.post_context = initial_prompt_engineering_dict["POST_CONTEXT"]
53
+
54
+ if "pre_prompt" not in session_state:
55
+ session_state.pre_prompt = initial_prompt_engineering_dict["PRE_PROMPT"]
56
+
57
+ if "post_prompt" not in session_state:
58
+ session_state.post_prompt = initial_prompt_engineering_dict["POST_PROMPT"]
59
+
60
+ if "pass_prev" not in session_state:
61
+ session_state.pass_prev = False
62
+
63
+ if "chunk_size" not in session_state:
64
+ session_state.chunk_size = 512
components/prompt_engineering_dashboard.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+
4
+ def prompt_engineering_dashboard(session_state, config):
5
+ inital_prompt_engineering_dict = config["PROMPT_ENGINEERING_DICT"]
6
+
7
+ def engineer_prompt():
8
+ session_state.history[0] = [
9
+ session_state.system_instruction,
10
+ session_state.system_response,
11
+ ]
12
+
13
+ with st.expander("Prompt Engineering Dashboard"):
14
+ st.info(
15
+ "**The input to the model follows this below template**",
16
+ )
17
+ st.code(
18
+ """
19
+ [SYSTEM INSTRUCTION]
20
+ [SYSTEM RESPONSE]
21
+
22
+ [... LIST OF PREV INPUTS]
23
+
24
+ [PRE CONTEXT]
25
+ [CONTEXT RETRIEVED FROM THE WEB]
26
+ [POST CONTEXT]
27
+
28
+ [PRE PROMPT]
29
+ [PROMPT]
30
+ [POST PROMPT]
31
+ [PREV GENERATED INPUT] # Only if Pass previous prompt set True
32
+
33
+ """
34
+ )
35
+ session_state.system_instruction = st.text_area(
36
+ label="SYSTEM INSTRUCTION",
37
+ value=inital_prompt_engineering_dict["SYSTEM_INSTRUCTION"],
38
+ )
39
+ session_state.system_response = st.text_area(
40
+ "SYSTEM RESPONSE", value=inital_prompt_engineering_dict["SYSTEM_RESPONSE"]
41
+ )
42
+
43
+ col1, col2 = st.columns(2)
44
+ with col1:
45
+ session_state.pre_context = st.text_input(
46
+ "PRE CONTEXT",
47
+ value=inital_prompt_engineering_dict["PRE_CONTEXT"],
48
+ disabled=not session_state.rag_enabled,
49
+ )
50
+ session_state.post_context = st.text_input(
51
+ "POST CONTEXT",
52
+ value=inital_prompt_engineering_dict["POST_CONTEXT"],
53
+ disabled=not session_state.rag_enabled,
54
+ )
55
+
56
+ with col2:
57
+ session_state.pre_prompt = st.text_input(
58
+ "PRE PROMPT", value=inital_prompt_engineering_dict["PRE_PROMPT"]
59
+ )
60
+ session_state.post_prompt = st.text_input(
61
+ "POST PROMPT", value=inital_prompt_engineering_dict["POST_PROMPT"]
62
+ )
63
+
64
+ col3, col4 = st.columns(2)
65
+ with col3:
66
+ session_state.pass_prev = st.toggle("Pass previous Output")
67
+ with col4:
68
+ st.button("Engineer Prompts", on_click=engineer_prompt)
components/show_source.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+
4
+ def show_source(links):
5
+ # Expander component to show source
6
+ with st.expander("Show source"):
7
+ for i, link in enumerate(links):
8
+ st.info(f"{link}")
components/sidebar.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from components.sidebar_components.model_analytics import model_analytics
3
+ from components.sidebar_components.retrieval_settings import retrieval_settings
4
+ from components.sidebar_components.model_settings import model_settings
5
+
6
+
7
+ def sidebar(session_state, config):
8
+ with st.sidebar:
9
+ retrieval_settings(session_state, config)
10
+ model_analytics(session_state, config)
11
+ model_settings(session_state, config)
components/sidebar_components/__init__.py ADDED
File without changes
components/sidebar_components/model_analytics.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ def model_analytics(session_state, config):
3
+ COST_PER_1000_TOKENS_USD = config["COST_PER_1000_TOKENS_USD"]
4
+
5
+ st.markdown("# Model Analytics")
6
+
7
+ st.write("Total tokens used :", session_state["tokens_used"])
8
+ st.write("Speed :", session_state["tps"], " tokens/sec")
9
+ st.write(
10
+ "Total cost incurred :",
11
+ round(
12
+ COST_PER_1000_TOKENS_USD * session_state["tokens_used"] / 1000,
13
+ 3,
14
+ ),
15
+ "USD",
16
+ )
17
+
18
+ st.markdown("---")
components/sidebar_components/model_settings.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ def model_settings(session_state,config):
4
+ CHAT_BOTS = config["CHAT_BOTS"]
5
+
6
+ st.markdown("# Model Settings")
7
+
8
+ session_state.chat_bot = st.sidebar.radio(
9
+ "Select one:", [key for key, _ in CHAT_BOTS.items()]
10
+ )
11
+ session_state.temp = st.slider(
12
+ label="Temperature", min_value=0.0, max_value=1.0, step=0.1, value=0.9
13
+ )
14
+
15
+ session_state.max_tokens = st.slider(
16
+ label="New tokens to generate",
17
+ min_value=64,
18
+ max_value=2048,
19
+ step=32,
20
+ value=512,
21
+ )
22
+
23
+ session_state.repetion_penalty = st.slider(
24
+ label="Repetion Penalty", min_value=0.0, max_value=1.0, step=0.1, value=1.0
25
+ )
components/sidebar_components/retrieval_settings.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ def retrieval_settings(session_state, config):
4
+ st.markdown("# Web Retrieval")
5
+ session_state.rag_enabled = st.toggle("Activate Web Retrieval", value=True)
6
+ session_state.search_vendor = st.radio(
7
+ "Select Search Vendor",
8
+ ["Bing", "Google"],
9
+ disabled=not session_state.rag_enabled,
10
+ )
11
+ session_state.n_crawl = st.slider(
12
+ label="Links to Crawl",
13
+ key=1,
14
+ min_value=1,
15
+ max_value=10,
16
+ value=4,
17
+ disabled=not session_state.rag_enabled,
18
+ )
19
+ session_state.top_k = st.slider(
20
+ label="Chunks to Retrieve via Reranker",
21
+ key=2,
22
+ min_value=1,
23
+ max_value=20,
24
+ value=5,
25
+ disabled=not session_state.rag_enabled,
26
+ )
27
+
28
+ session_state.chunk_size = st.slider(
29
+ label="Chunk Size",
30
+ value=512,
31
+ min_value=128,
32
+ max_value=1024,
33
+ step=8,
34
+ disabled=not session_state.rag_enabled,
35
+ )
36
+
37
+ st.markdown("---")
components/stream_handler.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import streamlit as st
3
+
4
+ COST_PER_1000_TOKENS_USD = 0.139 / 80
5
+
6
+
7
+ def stream_handler(session_state, chat_stream, prompt, placeholder):
8
+ # 1. Uses the chat_stream and streams message on placeholder
9
+ # 2. returns full_response for token calculation
10
+ start_time = time.time()
11
+ full_response = ""
12
+
13
+ for chunk in chat_stream:
14
+ if chunk.token.text != "</s>":
15
+ full_response += chunk.token.text
16
+ placeholder.markdown(full_response + "▌")
17
+ placeholder.markdown(full_response)
18
+
19
+ end_time = time.time()
20
+ elapsed_time = end_time - start_time
21
+ total_tokens_processed = len(full_response.split())
22
+ tokens_per_second = total_tokens_processed // elapsed_time
23
+ len_response = (len(prompt.split()) + len(full_response.split())) * 1.25
24
+ col1, col2, col3 = st.columns(3)
25
+
26
+ with col1:
27
+ st.write(f"**{tokens_per_second} tokens/second**")
28
+
29
+ with col2:
30
+ st.write(f"**{int(len_response)} tokens generated**")
31
+
32
+ with col3:
33
+ st.write(
34
+ f"**$ {round(len_response * COST_PER_1000_TOKENS_USD / 1000, 5)} cost incurred**"
35
+ )
36
+
37
+ session_state["tps"] = tokens_per_second
38
+ session_state["tokens_used"] = len_response + session_state["tokens_used"]
39
+
40
+ return full_response
config.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PROMPT_ENGINEERING_DICT:
2
+ SYSTEM_INSTRUCTION: |
3
+ YOU ARE A SEARCH ENGINE HAVING FULL ACCESS TO WEB PAGES, YOU GIVE EXTREMELY DETAILED AND ACCURATE INFORMATION ACCORDING TO USER PROMPTS.
4
+ SYSTEM_RESPONSE: |
5
+ Certainly! I'm here to help. What information are you looking for?
6
+ Please provide me with a specific topic or question, and I'll do my
7
+ best to provide you with detailed and accurate information.
8
+
9
+ PRE_CONTEXT: NOW YOU ARE SEARCHING THE WEB, AND HERE ARE THE CHUNKS RETRIEVED FROM THE WEB.
10
+ POST_CONTEXT: ""
11
+ PRE_PROMPT: NOW ACCORDING TO THE CONTEXT RETRIEVED FROM THE GENERATE THE CONTENT FOR THE FOLLOWING SUBJECT
12
+ POST_PROMPT: PRIORITIZE DATA, FACTS AND STATISTICS OVER PERSONAL EXPERIENCES AND OPINIONS, FOCUS MORE ON STATISTICS AND DATA.
13
+
14
+ CHAT_BOTS:
15
+ Mixtral 8x7B v0.1: mistralai/Mixtral-8x7B-Instruct-v0.1
16
+ Mistral 7B v0.1: mistralai/Mistral-7B-Instruct-v0.1
17
+ Mistral 7B v0.2: mistralai/Mistral-7B-Instruct-v0.2
18
+
19
+ COST_PER_1000_TOKENS_USD: 0.001737375
middlewares/chat_client.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import InferenceClient
2
+ import os
3
+ from dotenv import load_dotenv
4
+
5
+ load_dotenv()
6
+
7
+ API_TOKEN = os.getenv("HF_TOKEN")
8
+
9
+
10
+ def format_prompt(message, history):
11
+ prompt = "<s>"
12
+ for user_prompt, bot_response in history:
13
+ prompt += f"[INST] {user_prompt} [/INST]"
14
+ prompt += f" {bot_response}</s> "
15
+ prompt += f"[INST] {message} [/INST]"
16
+ return prompt
17
+
18
+
19
+ def chat(
20
+ prompt,
21
+ history,
22
+ chat_client="mistralai/Mistral-7B-Instruct-v0.1",
23
+ temperature=0.9,
24
+ max_new_tokens=256,
25
+ top_p=0.95,
26
+ repetition_penalty=1.0,
27
+ truncate = False
28
+ ):
29
+
30
+ client = InferenceClient(chat_client, token=API_TOKEN)
31
+ temperature = float(temperature)
32
+ if temperature < 1e-2:
33
+ temperature = 1e-2
34
+ top_p = float(top_p)
35
+
36
+ generate_kwargs = dict(
37
+ temperature=temperature,
38
+ max_new_tokens=max_new_tokens,
39
+ top_p=top_p,
40
+ repetition_penalty=repetition_penalty,
41
+ do_sample=True,
42
+ seed=42,
43
+ )
44
+
45
+ formatted_prompt = format_prompt(prompt, history)
46
+
47
+ stream = client.text_generation(
48
+ formatted_prompt,
49
+ **generate_kwargs,
50
+ stream=True,
51
+ details=True,
52
+ return_full_text=False,
53
+ )
54
+
55
+ return stream
middlewares/search_client.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from bs4 import BeautifulSoup
3
+ import re
4
+ import concurrent.futures
5
+
6
+
7
+ class SearchClient:
8
+ def __init__(self, vendor, engine_id=None, api_key=None):
9
+ self.vendor = vendor
10
+ if vendor == "google":
11
+ self.endpoint = f"https://www.googleapis.com/customsearch/v1?key={api_key}&cx={engine_id}"
12
+ elif vendor == "bing":
13
+ self.endpoint = "https://api.bing.microsoft.com/v7.0/search"
14
+ self.headers = {
15
+ "Ocp-Apim-Subscription-Key": api_key,
16
+ }
17
+
18
+ @staticmethod
19
+ def _extract_text_from_link(link):
20
+ page = requests.get(link)
21
+ if page.status_code == 200:
22
+ soup = BeautifulSoup(page.content, "html.parser")
23
+ text = soup.get_text()
24
+ cleaned_text = re.sub(r"\s+", " ", text)
25
+ return cleaned_text
26
+ return None
27
+
28
+ def _fetch_text_from_links(self, links):
29
+ results = []
30
+ with concurrent.futures.ThreadPoolExecutor() as executor:
31
+ future_to_link = {
32
+ executor.submit(self._extract_text_from_link, link): link
33
+ for link in links
34
+ }
35
+ for future in concurrent.futures.as_completed(future_to_link):
36
+ link = future_to_link[future]
37
+ try:
38
+ cleaned_text = future.result()
39
+ if cleaned_text:
40
+ results.append({"text": cleaned_text, "link": link})
41
+ except Exception as e:
42
+ print(f"Error fetching data from {link}: {e}")
43
+ return results
44
+
45
+ def _google_search(self, query, n_crawl):
46
+ response = requests.get(self.endpoint, params={"q": query})
47
+ search_results = response.json()
48
+
49
+ results = []
50
+ count = 0
51
+ for item in search_results.get("items", []):
52
+ if count >= n_crawl:
53
+ break
54
+
55
+ link = item["link"]
56
+ results.append(link)
57
+ count += 1
58
+
59
+ text_results = self._fetch_text_from_links(results)
60
+ return text_results
61
+
62
+ def _bing_search(self, query, n_crawl):
63
+ params = {
64
+ "q": query,
65
+ "count": n_crawl, # You might need to adjust this based on Bing API requirements
66
+ "mkt": "en-US",
67
+ }
68
+ response = requests.get(self.endpoint, headers=self.headers, params=params)
69
+ search_results = response.json()
70
+
71
+ results = []
72
+ for item in search_results.get("webPages", {}).get("value", []):
73
+ link = item["url"]
74
+ results.append(link)
75
+
76
+ text_results = self._fetch_text_from_links(results)
77
+ return text_results
78
+
79
+ def search(self, query, n_crawl):
80
+ if self.vendor == "google":
81
+ return self._google_search(query, n_crawl)
82
+ elif self.vendor == "bing":
83
+ return self._bing_search(query, n_crawl)
84
+ else:
85
+ return "Invalid vendor"
middlewares/utils.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import CrossEncoder
2
+ import math
3
+ import numpy as np
4
+ from middlewares.search_client import SearchClient
5
+ import os
6
+ from dotenv import load_dotenv
7
+
8
+
9
+ load_dotenv()
10
+
11
+
12
+ GOOGLE_SEARCH_ENGINE_ID = os.getenv("GOOGLE_SEARCH_ENGINE_ID")
13
+ GOOGLE_SEARCH_API_KEY = os.getenv("GOOGLE_SEARCH_API_KEY")
14
+ BING_SEARCH_API_KEY = os.getenv("BING_SEARCH_API_KEY")
15
+
16
+ reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
17
+
18
+ googleSearchClient = SearchClient(
19
+ "google", api_key=GOOGLE_SEARCH_API_KEY, engine_id=GOOGLE_SEARCH_ENGINE_ID
20
+ )
21
+ bingSearchClient = SearchClient("bing", api_key=BING_SEARCH_API_KEY, engine_id=None)
22
+
23
+
24
+ def rerank(query, top_k, search_results, chunk_size=512):
25
+ chunks = []
26
+ for result in search_results:
27
+ text = result["text"]
28
+ words = text.split()
29
+ num_chunks = math.ceil(len(words) / chunk_size)
30
+ for i in range(num_chunks):
31
+ start = i * chunk_size
32
+ end = (i + 1) * chunk_size
33
+ chunk = " ".join(words[start:end])
34
+ chunks.append((result["link"], chunk))
35
+
36
+ # Create sentence combinations with the query
37
+ sentence_combinations = [[query, chunk[1]] for chunk in chunks]
38
+
39
+ # Compute similarity scores for these combinations
40
+ similarity_scores = reranker.predict(sentence_combinations)
41
+
42
+ # Sort scores indexes in decreasing order
43
+ sim_scores_argsort = reversed(np.argsort(similarity_scores))
44
+
45
+ # Rearrange search_results based on the reranked scores
46
+ reranked_results = []
47
+ for idx in sim_scores_argsort:
48
+ link = chunks[idx][0]
49
+ chunk = chunks[idx][1]
50
+ reranked_results.append({"link": link, "text": chunk})
51
+
52
+ # Return the top K ranks
53
+ return reranked_results[:top_k]
54
+
55
+
56
+ def gen_augmented_prompt_via_websearch(
57
+ prompt,
58
+ search_vendor,
59
+ n_crawl,
60
+ top_k,
61
+ pre_context="",
62
+ post_context="",
63
+ pre_prompt="",
64
+ post_prompt="",
65
+ pass_prev=False,
66
+ prev_output="",
67
+ chunk_size=512,
68
+ ):
69
+
70
+
71
+ search_results = []
72
+ reranked_results = []
73
+ if search_vendor == "Google":
74
+ search_results = googleSearchClient.search(prompt, n_crawl)
75
+ elif search_vendor == "Bing":
76
+ search_results = bingSearchClient.search(prompt, n_crawl)
77
+
78
+ if len(search_results) > 0:
79
+ reranked_results = rerank(prompt, top_k, search_results, chunk_size)
80
+
81
+ links = []
82
+ context = ""
83
+ for res in reranked_results:
84
+ context += res["text"] + "\n\n"
85
+ link = res["link"]
86
+ links.append(link)
87
+
88
+ # remove duplicate links
89
+ links = list(set(links))
90
+
91
+ prev_output = prev_output if pass_prev else ""
92
+
93
+ augmented_prompt = f"""
94
+ {pre_context}
95
+
96
+ {context}
97
+
98
+ {post_context}
99
+
100
+ {pre_prompt}
101
+
102
+ {prompt}
103
+
104
+ {post_prompt}
105
+
106
+ {prev_output}
107
+
108
+ """
109
+ return augmented_prompt, links
requirements.txt ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ altair==5.1.2
2
+ asttokens==2.2.1
3
+ attrs==23.1.0
4
+ backcall==0.2.0
5
+ beautifulsoup4==4.12.2
6
+ blinker==1.6.3
7
+ cachetools==5.3.1
8
+ certifi==2023.7.22
9
+ charset-normalizer==3.3.0
10
+ click==8.1.7
11
+ colorama==0.4.6
12
+ comm==0.1.3
13
+ debugpy==1.6.7
14
+ decorator==5.1.1
15
+ dnspython==2.4.2
16
+ executing==1.2.0
17
+ filelock==3.12.4
18
+ fsspec==2023.9.2
19
+ gitdb==4.0.10
20
+ GitPython==3.1.37
21
+ huggingface-hub==0.18.0
22
+ idna==3.4
23
+ importlib-metadata==6.8.0
24
+ ipykernel==6.23.3
25
+ ipython==8.14.0
26
+ jedi==0.18.2
27
+ Jinja2==3.1.2
28
+ joblib==1.3.2
29
+ jsonschema==4.19.1
30
+ jsonschema-specifications==2023.7.1
31
+ jupyter_client==8.3.0
32
+ jupyter_core==5.3.1
33
+ loguru==0.7.2
34
+ markdown-it-py==3.0.0
35
+ MarkupSafe==2.1.3
36
+ matplotlib-inline==0.1.6
37
+ mdurl==0.1.2
38
+ mpmath==1.3.0
39
+ nest-asyncio==1.5.6
40
+ networkx==3.2.1
41
+ nltk==3.8.1
42
+ numpy==1.26.0
43
+ packaging==23.1
44
+ pandas==2.1.1
45
+ parso==0.8.3
46
+ pickleshare==0.7.5
47
+ Pillow==10.0.1
48
+ platformdirs==3.8.0
49
+ prompt-toolkit==3.0.38
50
+ protobuf==4.24.4
51
+ psutil==5.9.5
52
+ pure-eval==0.2.2
53
+ pyarrow==13.0.0
54
+ pydeck==0.8.1b0
55
+ Pygments==2.15.1
56
+ python-dateutil==2.8.2
57
+ python-dotenv==1.0.0
58
+ pytz==2023.3.post1
59
+ PyYAML==6.0.1
60
+ pyzmq==25.1.0
61
+ referencing==0.30.2
62
+ regex==2023.10.3
63
+ requests==2.31.0
64
+ rich==13.6.0
65
+ rpds-py==0.10.4
66
+ safetensors==0.4.1
67
+ scikit-learn==1.3.2
68
+ scipy==1.11.4
69
+ sentence-transformers==2.2.2
70
+ sentencepiece==0.1.99
71
+ six==1.16.0
72
+ smmap==5.0.1
73
+ soupsieve==2.5
74
+ stack-data==0.6.2
75
+ streamlit==1.27.2
76
+ sympy==1.12
77
+ tenacity==8.2.3
78
+ threadpoolctl==3.2.0
79
+ tokenizers==0.15.0
80
+ toml==0.10.2
81
+ toolz==0.12.0
82
+ torch==2.1.2
83
+ torchvision==0.16.2
84
+ tornado==6.3.2
85
+ tqdm==4.66.1
86
+ traitlets==5.9.0
87
+ transformers==4.35.2
88
+ typing_extensions==4.8.0
89
+ tzdata==2023.3
90
+ tzlocal==5.1
91
+ urllib3==2.0.6
92
+ validators==0.22.0
93
+ watchdog==3.0.0
94
+ wcwidth==0.2.6
95
+ win32-setctime==1.1.0
96
+ zipp==3.17.0