Spaces:
Runtime error
Runtime error
Upload 24 files
Browse files- .gitattributes +1 -0
- Dockerfile +18 -0
- Home.py +46 -0
- data/1_IIEE_1_json_data_19_02_2024_22-17-49.json +0 -0
- pages/1_π_Busqueda_Aumentada.py +377 -0
- pages/2_π£_Busqueda_Conversacional.py +576 -0
- pages/__init__.py +0 -0
- requirements.txt +18 -0
- static/.DS_Store +0 -0
- static/images/cervezas-mahou.jpeg +0 -0
- static/images/fabrica-mahou-1200x675.jpeg +0 -0
- static/images/openai_logo.png +0 -0
- static/images/openai_logo_circle.png +0 -0
- static/images/openai_purple_logo_hres.jpeg +0 -0
- static/images/screen_recording_busqueda_final_2.gif +3 -0
- utils/.DS_Store +0 -0
- utils/__init__.py +0 -0
- utils/app_features_spa.py +177 -0
- utils/openai_interface_spa.py +95 -0
- utils/preprocessing.py +123 -0
- utils/prompt_templates_spa.py +26 -0
- utils/reranker_spa.py +89 -0
- utils/retrieval_evaluation_spa.py +332 -0
- utils/system_prompts.py +72 -0
- utils/weaviate_interface_v3_spa.py +436 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
static/images/screen_recording_busqueda_final_2.gif filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Use an official Python runtime as a parent image
|
2 |
+
FROM python:3.11-slim
|
3 |
+
|
4 |
+
# Set the working directory in the container
|
5 |
+
WORKDIR /app
|
6 |
+
|
7 |
+
# Install any needed packages specified in requirements.txt
|
8 |
+
COPY requirements.txt /app/
|
9 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
10 |
+
|
11 |
+
# Copy the rest of your application's code
|
12 |
+
COPY . /app
|
13 |
+
|
14 |
+
# Make port 8501 available to the world outside this container
|
15 |
+
EXPOSE 7860
|
16 |
+
|
17 |
+
# Run app.py when the container launches, use environment variables
|
18 |
+
CMD ["streamlit", "run", "Home.py", "--server.address=0.0.0.0", "--server.port=8501"]
|
Home.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import base64
|
3 |
+
|
4 |
+
## PAGE CONFIGURATION
|
5 |
+
st.set_page_config(page_title="BΓΊsqueda Aumentada MSM para Impuestos Especiales",
|
6 |
+
page_icon="π",
|
7 |
+
layout="centered",
|
8 |
+
initial_sidebar_state="auto",
|
9 |
+
menu_items=None)
|
10 |
+
|
11 |
+
st.image('./static/images/cervezas-mahou.jpeg', width=700,)
|
12 |
+
|
13 |
+
# Mensaje de bienvenida
|
14 |
+
st.markdown(
|
15 |
+
"""
|
16 |
+
# Β‘Bienvenido a BΓΊsqueda Aumentada MSM para Impuestos Especiales! ππ
|
17 |
+
|
18 |
+
Esta aplicaciΓ³n es una herramienta diseΓ±ada especΓficamente para la exploraciΓ³n y anΓ‘lisis de datos en el Γ‘mbito de Impuestos Especiales utilizando el poder de la Inteligencia Artificial.
|
19 |
+
|
20 |
+
**π Selecciona una opciΓ³n en la barra lateral** para comenzar a explorar las diferentes funcionalidades que ofrece la aplicaciΓ³n.
|
21 |
+
""")
|
22 |
+
file_ = open('./static/images/screen_recording_busqueda_final_2.gif', "rb")
|
23 |
+
contents = file_.read()
|
24 |
+
data_url = base64.b64encode(contents).decode("utf-8")
|
25 |
+
file_.close()
|
26 |
+
|
27 |
+
st.subheader("Uso de la AplicaciΓ³n: π Busqueda Aumentada")
|
28 |
+
st.caption("Observa en acciΓ³n cΓ³mo la busqueda aumentada con una potente IA simplifica la bΓΊsqueda de informaciΓ³n, todo con una interfaz de usuario facΓl de usar.")
|
29 |
+
st.markdown(
|
30 |
+
f'<div style="text-align: center;"><img src="data:image/gif;base64,{data_url}" alt="demo gif" style="max-width: 100%; height: auto;"></div>',
|
31 |
+
unsafe_allow_html=True,
|
32 |
+
)
|
33 |
+
|
34 |
+
st.markdown("""
|
35 |
+
|
36 |
+
### ΒΏQuieres aprender mΓ‘s?
|
37 |
+
- Visita nuestra [pΓ‘gina web](https://tupagina.com)
|
38 |
+
- SumΓ©rgete en nuestra [documentaciΓ³n](https://tudocumentacion.com)
|
39 |
+
- Participa y pregunta en nuestros [foros comunitarios](https://tucomunidad.com)
|
40 |
+
|
41 |
+
### Explora demos mΓ‘s complejos
|
42 |
+
- Descubre cΓ³mo aplicamos la IA para [analizar datasets especializados](https://tulinkdedataset.com)
|
43 |
+
- Explora [bases de datos de acceso pΓΊblico](https://tulinkdedatasetpublico.com) y ve la IA en acciΓ³n
|
44 |
+
""",
|
45 |
+
unsafe_allow_html=True
|
46 |
+
)
|
data/1_IIEE_1_json_data_19_02_2024_22-17-49.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
pages/1_π_Busqueda_Aumentada.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tiktoken import get_encoding, encoding_for_model
|
2 |
+
from utils.weaviate_interface_v3_spa import WeaviateClient, WhereFilter
|
3 |
+
from templates.prompt_templates_spa import question_answering_prompt_series_spa
|
4 |
+
from utils.openai_interface_spa import GPT_Turbo
|
5 |
+
from openai import BadRequestError
|
6 |
+
from utils.app_features_spa import (convert_seconds, generate_prompt_series, search_result,
|
7 |
+
validate_token_threshold, load_content_cache, load_data, expand_content)
|
8 |
+
from utils.reranker_spa import ReRanker
|
9 |
+
from loguru import logger
|
10 |
+
import streamlit as st
|
11 |
+
import os
|
12 |
+
|
13 |
+
# load environment variables
|
14 |
+
from dotenv import load_dotenv
|
15 |
+
load_dotenv('.env', override=True)
|
16 |
+
|
17 |
+
## PAGE CONFIGURATION
|
18 |
+
st.set_page_config(page_title="Busqueda Aumentada",
|
19 |
+
page_icon="π",
|
20 |
+
layout="wide",
|
21 |
+
initial_sidebar_state="auto",
|
22 |
+
menu_items=None)
|
23 |
+
|
24 |
+
## DATA + CACHE
|
25 |
+
data_path = 'data/1_IIEE_1_json_data_19_02_2024_22-17-49.json'
|
26 |
+
cache_path = ''
|
27 |
+
data = load_data(data_path)
|
28 |
+
cache = None # Initialize cache as None
|
29 |
+
|
30 |
+
# Check if the cache file exists before attempting to load it
|
31 |
+
if os.path.exists(cache_path):
|
32 |
+
cache = load_content_cache(cache_path)
|
33 |
+
else:
|
34 |
+
logger.warning(f"Cache file {cache_path} not found. Proceeding without cache.")
|
35 |
+
|
36 |
+
#creates list of guests for sidebar
|
37 |
+
guest_list = sorted(list(set([d['document_title'] for d in data])))
|
38 |
+
|
39 |
+
with st.sidebar:
|
40 |
+
st.subheader("Selecciona tu Base de datos ποΈ")
|
41 |
+
client_type = st.radio(
|
42 |
+
"Selecciona el modo de acceso:",
|
43 |
+
('Cloud', 'Local'),
|
44 |
+
help='Elige un repositorio para determinar el conjunto de datos sobre el cual realizarΓ‘s tu bΓΊsqueda. "Cloud" te permite acceder a datos alojados en nuestros servidores seguros, mientras que "Local" es para trabajar con datos alojados localmente en tu mΓ‘quina.'
|
45 |
+
)
|
46 |
+
if client_type == 'Cloud':
|
47 |
+
api_key = st.secrets['WEAVIATE_CLOUD_API_KEY']
|
48 |
+
url = st.secrets['WEAVIATE_CLOUD_ENDPOINT']
|
49 |
+
|
50 |
+
weaviate_client = WeaviateClient(
|
51 |
+
endpoint=url,
|
52 |
+
api_key=api_key,
|
53 |
+
# model_name_or_path='./models/finetuned-all-MiniLM-L6-v2-300',
|
54 |
+
model_name_or_path="intfloat/multilingual-e5-small",
|
55 |
+
# openai_api_key=os.environ['OPENAI_API_KEY']
|
56 |
+
)
|
57 |
+
available_classes=sorted(weaviate_client.show_classes())
|
58 |
+
logger.info(available_classes)
|
59 |
+
logger.info(f"Endpoint: {client_type} | Classes: {available_classes}")
|
60 |
+
elif client_type == 'Local':
|
61 |
+
url = st.secrets['WEAVIATE_LOCAL_ENDPOINT']
|
62 |
+
weaviate_client = WeaviateClient(
|
63 |
+
endpoint=url,
|
64 |
+
# api_key=api_key,
|
65 |
+
# model_name_or_path='./models/finetuned-all-MiniLM-L6-v2-300',
|
66 |
+
model_name_or_path="intfloat/multilingual-e5-small",
|
67 |
+
# openai_api_key=os.environ['OPENAI_API_KEY']
|
68 |
+
)
|
69 |
+
available_classes=sorted(weaviate_client.show_classes())
|
70 |
+
logger.info(f"Endpoint: {client_type} | Classes: {available_classes}")
|
71 |
+
|
72 |
+
def main():
|
73 |
+
|
74 |
+
# Define the available user selected options
|
75 |
+
available_models = ['gpt-3.5-turbo', 'gpt-4-1106-preview']
|
76 |
+
# Define system prompts
|
77 |
+
|
78 |
+
# Initialize selected options in session state
|
79 |
+
if "openai_data_model" not in st.session_state:
|
80 |
+
st.session_state["openai_data_model"] = available_models[0]
|
81 |
+
|
82 |
+
if 'class_name' not in st.session_state:
|
83 |
+
st.session_state['class_name'] = None
|
84 |
+
|
85 |
+
with st.sidebar:
|
86 |
+
st.session_state['class_name'] = st.selectbox(
|
87 |
+
label='Repositorio:',
|
88 |
+
options=available_classes,
|
89 |
+
index=None,
|
90 |
+
placeholder='Repositorio',
|
91 |
+
help='Elige un repositorio para determinar el conjunto de datos sobre el cual realizarΓ‘s tu bΓΊsqueda. "Cloud" te permite acceder a datos alojados en nuestros servidores seguros, mientras que "Local" es para trabajar con datos alojados localmente en tu mΓ‘quina.'
|
92 |
+
)
|
93 |
+
# Check if the collection name has been selected
|
94 |
+
class_name = st.session_state['class_name']
|
95 |
+
if class_name:
|
96 |
+
st.success(f"Repositorio seleccionado β
: {st.session_state['class_name']}")
|
97 |
+
|
98 |
+
else:
|
99 |
+
st.warning("ποΈ No olvides seleccionar el repositorio π a consultar ποΈ.")
|
100 |
+
st.stop() # Stop execution of the script
|
101 |
+
|
102 |
+
model_choice = st.selectbox(
|
103 |
+
label="Elige un modelo de OpenAI",
|
104 |
+
options=available_models,
|
105 |
+
index= available_models.index(st.session_state["openai_data_model"]),
|
106 |
+
help='Escoge entre diferentes modelos de OpenAI para generar respuestas a tus consultas. Cada modelo tiene distintas capacidades y limitaciones.'
|
107 |
+
)
|
108 |
+
st.sidebar.make_llm_call = st.checkbox(
|
109 |
+
label="Activar GPT",
|
110 |
+
help='Marca esta casilla para activar la generaciΓ³n de texto con GPT. Esto te permitirΓ‘ obtener respuestas automΓ‘ticas a tus consultas.'
|
111 |
+
)
|
112 |
+
|
113 |
+
with st.expander("Filtros de Busqueda"):
|
114 |
+
guest_input = st.selectbox(
|
115 |
+
label='SelecciΓ³n de documentos',
|
116 |
+
options=guest_list,
|
117 |
+
index=None,
|
118 |
+
placeholder='Documento',
|
119 |
+
help='Elige un documento especΓfico del repositorio para afinar tu bΓΊsqueda a datos relevantes.'
|
120 |
+
)
|
121 |
+
|
122 |
+
with st.expander("Parametros de Busqueda"):
|
123 |
+
retriever_choice = st.selectbox(
|
124 |
+
label="Selecciona un mΓ©todo",
|
125 |
+
options=["Hybrid", "Vector", "Keyword"],
|
126 |
+
help='Determina el mΓ©todo de recuperaciΓ³n de informaciΓ³n: "Hybrid" combina bΓΊsqueda por palabras clave y por similitud semΓ‘ntica, "Vector" usa embeddings de texto para encontrar coincidencias semΓ‘nticas, y "Keyword" realiza una bΓΊsqueda tradicional por palabras clave.'
|
127 |
+
)
|
128 |
+
|
129 |
+
reranker_enabled = st.checkbox(
|
130 |
+
label="Activar Reranker",
|
131 |
+
value=True,
|
132 |
+
help='Activa esta opciΓ³n para ordenar los resultados de la bΓΊsqueda segΓΊn su relevancia, utilizando un modelo de reordenamiento adicional.'
|
133 |
+
)
|
134 |
+
|
135 |
+
alpha_input = st.slider(
|
136 |
+
label='Alpha para motor hibrido',
|
137 |
+
min_value=0.00,
|
138 |
+
max_value=1.00,
|
139 |
+
value=0.40,
|
140 |
+
step=0.05,
|
141 |
+
help='Ajusta el parΓ‘metro alfa para equilibrar los resultados entre los mΓ©todos de bΓΊsqueda por vector y por palabra clave en el motor hΓbrido.'
|
142 |
+
)
|
143 |
+
|
144 |
+
retrieval_limit = st.slider(
|
145 |
+
label='Resultados a Reranker',
|
146 |
+
min_value=10,
|
147 |
+
max_value=300,
|
148 |
+
value=100,
|
149 |
+
step=10,
|
150 |
+
help='Establece el nΓΊmero de resultados que se recuperarΓ‘n antes de aplicar el reordenamiento.'
|
151 |
+
)
|
152 |
+
|
153 |
+
top_k_limit = st.slider(
|
154 |
+
label='Top K Limit',
|
155 |
+
min_value=1,
|
156 |
+
max_value=5,
|
157 |
+
value=3,
|
158 |
+
step=1,
|
159 |
+
help='Define el nΓΊmero mΓ‘ximo de resultados a mostrar despuΓ©s de aplicar el reordenamiento.'
|
160 |
+
)
|
161 |
+
|
162 |
+
temperature_input = st.slider(
|
163 |
+
label='Temperatura',
|
164 |
+
min_value=0.0,
|
165 |
+
max_value=1.0,
|
166 |
+
value=0.10,
|
167 |
+
step=0.10,
|
168 |
+
help='Ajusta la temperatura para la generaciΓ³n de texto con GPT, lo que influirΓ‘ en la creatividad de las respuestas.'
|
169 |
+
)
|
170 |
+
|
171 |
+
logger.info(weaviate_client.display_properties)
|
172 |
+
|
173 |
+
def perform_search(client, retriever_choice, query, class_name, search_limit, guest_filter, display_properties, alpha_input):
|
174 |
+
if retriever_choice == "Keyword":
|
175 |
+
return weaviate_client.keyword_search(
|
176 |
+
request=query,
|
177 |
+
class_name=class_name,
|
178 |
+
limit=search_limit,
|
179 |
+
where_filter=guest_filter,
|
180 |
+
display_properties=display_properties
|
181 |
+
), "Resultados de la Busqueda - Motor: Keyword: "
|
182 |
+
elif retriever_choice == "Vector":
|
183 |
+
return weaviate_client.vector_search(
|
184 |
+
request=query,
|
185 |
+
class_name=class_name,
|
186 |
+
limit=search_limit,
|
187 |
+
where_filter=guest_filter,
|
188 |
+
display_properties=display_properties
|
189 |
+
), "Resultados de la Busqueda - Motor: Vector"
|
190 |
+
elif retriever_choice == "Hybrid":
|
191 |
+
return weaviate_client.hybrid_search(
|
192 |
+
request=query,
|
193 |
+
class_name=class_name,
|
194 |
+
alpha=alpha_input,
|
195 |
+
limit=search_limit,
|
196 |
+
properties=["content"],
|
197 |
+
where_filter=guest_filter,
|
198 |
+
display_properties=display_properties
|
199 |
+
), "Resultados de la Busqueda - Motor: Hybrid"
|
200 |
+
|
201 |
+
|
202 |
+
## RERANKER
|
203 |
+
reranker = ReRanker(model_name='cross-encoder/ms-marco-MiniLM-L-12-v2')
|
204 |
+
|
205 |
+
## LLM
|
206 |
+
model_name = model_choice
|
207 |
+
llm = GPT_Turbo(model=model_name, api_key=st.secrets['OPENAI_API_KEY'])
|
208 |
+
encoding = encoding_for_model(model_name)
|
209 |
+
|
210 |
+
|
211 |
+
########################
|
212 |
+
## SETUP MAIN DISPLAY ##
|
213 |
+
########################
|
214 |
+
st.image('./static/images/cervezas-mahou.jpeg', width=300)
|
215 |
+
st.subheader(f"β¨ππ **BΓΊsqueda Aumentada** ππβ¨ Impuestos Especiales ")
|
216 |
+
st.caption("Descubre insights ocultos y responde a tus preguntas especializadas utilizando el poder de la IA")
|
217 |
+
st.write('\n')
|
218 |
+
|
219 |
+
query = st.text_input('Escribe tu pregunta aquΓ: ')
|
220 |
+
st.write('\n\n\n\n\n')
|
221 |
+
|
222 |
+
############
|
223 |
+
## SEARCH ##
|
224 |
+
############
|
225 |
+
if query:
|
226 |
+
# make hybrid call to weaviate
|
227 |
+
guest_filter = WhereFilter(
|
228 |
+
path=['document_title'],
|
229 |
+
operator='Equal',
|
230 |
+
valueText=guest_input).todict() if guest_input else None
|
231 |
+
|
232 |
+
|
233 |
+
# Determine the appropriate limit based on reranking
|
234 |
+
search_limit = retrieval_limit if reranker_enabled else top_k_limit
|
235 |
+
|
236 |
+
# Perform the search
|
237 |
+
query_response, subheader_msg = perform_search(
|
238 |
+
client=weaviate_client,
|
239 |
+
retriever_choice=retriever_choice,
|
240 |
+
query=query,
|
241 |
+
class_name=class_name,
|
242 |
+
search_limit=search_limit,
|
243 |
+
guest_filter=guest_filter,
|
244 |
+
display_properties=weaviate_client.display_properties,
|
245 |
+
alpha_input=alpha_input if retriever_choice == "Hybrid" else None
|
246 |
+
)
|
247 |
+
|
248 |
+
|
249 |
+
# Rerank the results if enabled
|
250 |
+
if reranker_enabled:
|
251 |
+
search_results = reranker.rerank(
|
252 |
+
results=query_response,
|
253 |
+
query=query,
|
254 |
+
apply_sigmoid=True,
|
255 |
+
top_k=top_k_limit
|
256 |
+
)
|
257 |
+
subheader_msg += " Reranked"
|
258 |
+
else:
|
259 |
+
# Use the results directly if reranking is not enabled
|
260 |
+
search_results = query_response
|
261 |
+
|
262 |
+
logger.info(search_results)
|
263 |
+
expanded_response = expand_content(search_results, cache, content_key='doc_id', create_new_list=True)
|
264 |
+
|
265 |
+
# validate token count is below threshold
|
266 |
+
token_threshold = 8000 if model_name == 'gpt-3.5-turbo-16k' else 3500
|
267 |
+
valid_response = validate_token_threshold(
|
268 |
+
ranked_results=expanded_response,
|
269 |
+
base_prompt=question_answering_prompt_series_spa,
|
270 |
+
query=query,
|
271 |
+
tokenizer=encoding,
|
272 |
+
token_threshold=token_threshold,
|
273 |
+
verbose=True
|
274 |
+
)
|
275 |
+
logger.info(valid_response)
|
276 |
+
#########
|
277 |
+
## LLM ##
|
278 |
+
#########
|
279 |
+
make_llm_call = st.sidebar.make_llm_call
|
280 |
+
# prep for streaming response
|
281 |
+
st.subheader("Respuesta GPT:")
|
282 |
+
with st.spinner('Generando Respuesta...'):
|
283 |
+
st.markdown("----")
|
284 |
+
# Creates container for LLM response
|
285 |
+
chat_container, response_box = [], st.empty()
|
286 |
+
|
287 |
+
# generate LLM prompt
|
288 |
+
prompt = generate_prompt_series(query=query, results=valid_response)
|
289 |
+
# logger.info(prompt)
|
290 |
+
if make_llm_call:
|
291 |
+
|
292 |
+
try:
|
293 |
+
for resp in llm.get_chat_completion(
|
294 |
+
prompt=prompt,
|
295 |
+
temperature=temperature_input,
|
296 |
+
max_tokens=350, # expand for more verbose answers
|
297 |
+
show_response=True,
|
298 |
+
stream=True):
|
299 |
+
|
300 |
+
# inserts chat stream from LLM
|
301 |
+
with response_box:
|
302 |
+
content = resp.choices[0].delta.content
|
303 |
+
if content:
|
304 |
+
chat_container.append(content)
|
305 |
+
result = "".join(chat_container).strip()
|
306 |
+
st.write(f'{result}')
|
307 |
+
except BadRequestError:
|
308 |
+
logger.info('Making request with smaller context...')
|
309 |
+
valid_response = validate_token_threshold(
|
310 |
+
ranked_results=search_results,
|
311 |
+
base_prompt=question_answering_prompt_series_spa,
|
312 |
+
query=query,
|
313 |
+
tokenizer=encoding,
|
314 |
+
token_threshold=token_threshold,
|
315 |
+
verbose=True
|
316 |
+
)
|
317 |
+
|
318 |
+
# generate LLM prompt
|
319 |
+
prompt = generate_prompt_series(query=query, results=valid_response)
|
320 |
+
for resp in llm.get_chat_completion(
|
321 |
+
prompt=prompt,
|
322 |
+
temperature=temperature_input,
|
323 |
+
max_tokens=350, # expand for more verbose answers
|
324 |
+
show_response=True,
|
325 |
+
stream=True):
|
326 |
+
|
327 |
+
try:
|
328 |
+
# inserts chat stream from LLM
|
329 |
+
with response_box:
|
330 |
+
content = resp.choices[0].delta.content
|
331 |
+
if content:
|
332 |
+
chat_container.append(content)
|
333 |
+
result = "".join(chat_container).strip()
|
334 |
+
st.write(f'{result}')
|
335 |
+
except Exception as e:
|
336 |
+
print(e)
|
337 |
+
|
338 |
+
####################
|
339 |
+
## Search Results ##
|
340 |
+
####################
|
341 |
+
st.subheader(subheader_msg)
|
342 |
+
for i, hit in enumerate(search_results):
|
343 |
+
col1, col2 = st.columns([7, 3], gap='large')
|
344 |
+
page_url = hit['page_url']
|
345 |
+
page_label = hit['page_label']
|
346 |
+
document_title = hit['document_title']
|
347 |
+
# Assuming 'page_summary' is available and you want to display it
|
348 |
+
page_summary = hit.get('page_summary', 'Summary not available')
|
349 |
+
|
350 |
+
with col1:
|
351 |
+
st.markdown(f'''
|
352 |
+
<span style="color: #3498db; font-size: 19px; font-weight: bold;">{document_title}</span><br>
|
353 |
+
{page_summary}
|
354 |
+
[**PaΜgina:** {page_label}]({page_url})
|
355 |
+
''', unsafe_allow_html=True)
|
356 |
+
|
357 |
+
with st.expander("π Clic aquΓ para ver contexto:"):
|
358 |
+
try:
|
359 |
+
content = hit['content']
|
360 |
+
st.write(content)
|
361 |
+
except Exception as e:
|
362 |
+
st.write(f"Error displaying content: {e}")
|
363 |
+
|
364 |
+
# with col2:
|
365 |
+
# # If you have an image or want to display a placeholder image
|
366 |
+
# image = "URL_TO_A_PLACEHOLDER_IMAGE" # Replace with a relevant image URL if applicable
|
367 |
+
# st.image(image, caption=document_title, width=200, use_column_width=False)
|
368 |
+
# st.markdown(f'''
|
369 |
+
# <p style="text-align: right;">
|
370 |
+
# <b>Document Title:</b> {document_title}<br>
|
371 |
+
# <b>File Name:</b> {file_name}<br>
|
372 |
+
# </p>''', unsafe_allow_html=True)
|
373 |
+
|
374 |
+
|
375 |
+
|
376 |
+
if __name__ == '__main__':
|
377 |
+
main()
|
pages/2_π£_Busqueda_Conversacional.py
ADDED
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tiktoken import get_encoding, encoding_for_model
|
2 |
+
from utils.weaviate_interface_v3_spa import WeaviateClient, WhereFilter
|
3 |
+
from templates.prompt_templates_spa import question_answering_prompt_series_spa
|
4 |
+
from utils.openai_interface_spa import GPT_Turbo
|
5 |
+
from openai import BadRequestError
|
6 |
+
from utils.app_features_spa import (convert_seconds, generate_prompt_series, search_result,
|
7 |
+
validate_token_threshold, load_content_cache, load_data, expand_content)
|
8 |
+
from utils.reranker_spa import ReRanker
|
9 |
+
from openai import OpenAI
|
10 |
+
|
11 |
+
from loguru import logger
|
12 |
+
import streamlit as st
|
13 |
+
import os
|
14 |
+
import templates.system_prompts as system_prompts
|
15 |
+
import base64
|
16 |
+
import json
|
17 |
+
|
18 |
+
# load environment variables
|
19 |
+
from dotenv import load_dotenv
|
20 |
+
load_dotenv('.env', override=True)
|
21 |
+
|
22 |
+
## PAGE CONFIGURATION
|
23 |
+
st.set_page_config(page_title="Busqueda Conversacional",
|
24 |
+
page_icon="π£",
|
25 |
+
layout="wide",
|
26 |
+
initial_sidebar_state="auto",
|
27 |
+
menu_items=None)
|
28 |
+
|
29 |
+
def encode_image(uploaded_file):
|
30 |
+
return base64.b64encode(uploaded_file.getvalue()).decode('utf-8')
|
31 |
+
|
32 |
+
## DATA + CACHE
|
33 |
+
data_path = 'data/1_IIEE_1_json_data_19_02_2024_22-17-49.json'
|
34 |
+
cache_path = ''
|
35 |
+
data = load_data(data_path)
|
36 |
+
cache = None # Initialize cache as None
|
37 |
+
|
38 |
+
# Check if the cache file exists before attempting to load it
|
39 |
+
if os.path.exists(cache_path):
|
40 |
+
cache = load_content_cache(cache_path)
|
41 |
+
else:
|
42 |
+
logger.warning(f"Cache file {cache_path} not found. Proceeding without cache.")
|
43 |
+
|
44 |
+
#creates list of guests for sidebar
|
45 |
+
guest_list = sorted(list(set([d['document_title'] for d in data])))
|
46 |
+
|
47 |
+
with st.sidebar:
|
48 |
+
st.subheader("Selecciona tu Base de datos ποΈ")
|
49 |
+
client_type = st.radio(
|
50 |
+
"Selecciona el modo de acceso:",
|
51 |
+
('Cloud', 'Local'),
|
52 |
+
help='Elige un repositorio para determinar el conjunto de datos sobre el cual realizarΓ‘s tu bΓΊsqueda. "Cloud" te permite acceder a datos alojados en nuestros servidores seguros, mientras que "Local" es para trabajar con datos alojados localmente en tu mΓ‘quina.'
|
53 |
+
)
|
54 |
+
if client_type == 'Cloud':
|
55 |
+
api_key = st.secrets['WEAVIATE_CLOUD_API_KEY']
|
56 |
+
url = st.secrets['WEAVIATE_CLOUD_ENDPOINT']
|
57 |
+
|
58 |
+
weaviate_client = WeaviateClient(
|
59 |
+
endpoint=url,
|
60 |
+
api_key=api_key,
|
61 |
+
# model_name_or_path='./models/finetuned-all-MiniLM-L6-v2-300',
|
62 |
+
model_name_or_path="intfloat/multilingual-e5-small",
|
63 |
+
# openai_api_key=os.environ['OPENAI_API_KEY']
|
64 |
+
)
|
65 |
+
available_classes=sorted(weaviate_client.show_classes())
|
66 |
+
logger.info(available_classes)
|
67 |
+
logger.info(f"Endpoint: {client_type} | Classes: {available_classes}")
|
68 |
+
elif client_type == 'Local':
|
69 |
+
url = st.secrets['WEAVIATE_LOCAL_ENDPOINT']
|
70 |
+
weaviate_client = WeaviateClient(
|
71 |
+
endpoint=url,
|
72 |
+
# api_key=api_key,
|
73 |
+
# model_name_or_path='./models/finetuned-all-MiniLM-L6-v2-300',
|
74 |
+
model_name_or_path="intfloat/multilingual-e5-small",
|
75 |
+
# openai_api_key=os.environ['OPENAI_API_KEY']
|
76 |
+
)
|
77 |
+
available_classes=sorted(weaviate_client.show_classes())
|
78 |
+
logger.info(f"Endpoint: {client_type} | Classes: {available_classes}")
|
79 |
+
|
80 |
+
client = OpenAI(api_key=st.secrets["OPENAI_API_KEY"])
|
81 |
+
|
82 |
+
def main():
|
83 |
+
|
84 |
+
# Define the available user selected options
|
85 |
+
available_models = ['gpt-3.5-turbo', 'gpt-4-1106-preview']
|
86 |
+
# Define system prompts
|
87 |
+
system_prompt_list = ["π€ChatGPT","π§πΎββοΈProfessor Synapse", "π©πΌβπΌMarketing Jane"]
|
88 |
+
|
89 |
+
|
90 |
+
# Initialize selected options in session state
|
91 |
+
if "openai_data_model" not in st.session_state:
|
92 |
+
st.session_state["openai_data_model"] = available_models[0]
|
93 |
+
if "system_prompt_data_list" not in st.session_state and "system_prompt_data_model" not in st.session_state:
|
94 |
+
# This should be the emoji string the user selected
|
95 |
+
st.session_state["system_prompt_data_list"] = system_prompt_list[0]
|
96 |
+
# Now we get the corresponding prompt variable using the selected emoji string
|
97 |
+
st.session_state["system_prompt_data_model"] = system_prompts.prompt_mapping[system_prompt_list[0]]
|
98 |
+
|
99 |
+
# logger.debug(f"Assistant: {st.session_state['system_prompt_sync_list']}")
|
100 |
+
# logger.debug(f"System Prompt: {st.session_state['system_prompt_sync_model']}")
|
101 |
+
|
102 |
+
if 'class_name' not in st.session_state:
|
103 |
+
st.session_state['class_name'] = None
|
104 |
+
|
105 |
+
with st.sidebar:
|
106 |
+
st.session_state['class_name'] = st.selectbox(
|
107 |
+
label='Repositorio:',
|
108 |
+
options=available_classes,
|
109 |
+
index=None,
|
110 |
+
placeholder='Repositorio',
|
111 |
+
help='Elige un repositorio para determinar el conjunto de datos sobre el cual realizarΓ‘s tu bΓΊsqueda. "Cloud" te permite acceder a datos alojados en nuestros servidores seguros, mientras que "Local" es para trabajar con datos alojados localmente en tu mΓ‘quina.'
|
112 |
+
)
|
113 |
+
|
114 |
+
# Check if the collection name has been selected
|
115 |
+
class_name = st.session_state['class_name']
|
116 |
+
if class_name:
|
117 |
+
st.success(f"Repositorio seleccionado β
: {st.session_state['class_name']}")
|
118 |
+
|
119 |
+
else:
|
120 |
+
st.warning("ποΈ No olvides seleccionar el repositorio π a consultar ποΈ.")
|
121 |
+
st.stop() # Stop execution of the script
|
122 |
+
|
123 |
+
model_choice = st.selectbox(
|
124 |
+
label="Elige un modelo de OpenAI",
|
125 |
+
options=available_models,
|
126 |
+
index= available_models.index(st.session_state["openai_data_model"]),
|
127 |
+
help='Escoge entre diferentes modelos de OpenAI para generar respuestas a tus consultas. Cada modelo tiene distintas capacidades y limitaciones.'
|
128 |
+
)
|
129 |
+
|
130 |
+
system_prompt = st.selectbox(
|
131 |
+
label="Elige un asistente",
|
132 |
+
options=system_prompt_list,
|
133 |
+
index=system_prompt_list.index(st.session_state["system_prompt_data_list"]),
|
134 |
+
)
|
135 |
+
|
136 |
+
with st.expander("Filtros de Busqueda"):
|
137 |
+
guest_input = st.selectbox(
|
138 |
+
label='SelecciΓ³n de Documento',
|
139 |
+
options=guest_list,
|
140 |
+
index=None,
|
141 |
+
placeholder='Documentos',
|
142 |
+
help='Elige un documento especΓfico del repositorio para afinar tu bΓΊsqueda a datos relevantes.'
|
143 |
+
)
|
144 |
+
with st.expander("Parametros de Busqueda"):
|
145 |
+
retriever_choice = st.selectbox(
|
146 |
+
label="Selecciona un mΓ©todo",
|
147 |
+
options=["Hybrid", "Vector", "Keyword"],
|
148 |
+
help='Determina el mΓ©todo de recuperaciΓ³n de informaciΓ³n: "Hybrid" combina bΓΊsqueda por palabras clave y por similitud semΓ‘ntica, "Vector" usa embeddings de texto para encontrar coincidencias semΓ‘nticas, y "Keyword" realiza una bΓΊsqueda tradicional por palabras clave.'
|
149 |
+
)
|
150 |
+
|
151 |
+
reranker_enabled = st.checkbox(
|
152 |
+
label="Activar Reranker",
|
153 |
+
value=True,
|
154 |
+
help='Activa esta opciΓ³n para ordenar los resultados de la bΓΊsqueda segΓΊn su relevancia, utilizando un modelo de reordenamiento adicional.'
|
155 |
+
)
|
156 |
+
|
157 |
+
alpha_input = st.slider(
|
158 |
+
label='Alpha para motor hibrido',
|
159 |
+
min_value=0.00,
|
160 |
+
max_value=1.00,
|
161 |
+
value=0.40,
|
162 |
+
step=0.05,
|
163 |
+
help='Ajusta el parΓ‘metro alfa para equilibrar los resultados entre los mΓ©todos de bΓΊsqueda por vector y por palabra clave en el motor hΓbrido.'
|
164 |
+
)
|
165 |
+
|
166 |
+
retrieval_limit = st.slider(
|
167 |
+
label='Resultados a Reranker',
|
168 |
+
min_value=10,
|
169 |
+
max_value=300,
|
170 |
+
value=100,
|
171 |
+
step=10,
|
172 |
+
help='Establece el nΓΊmero de resultados que se recuperarΓ‘n antes de aplicar el reordenamiento.'
|
173 |
+
)
|
174 |
+
|
175 |
+
top_k_limit = st.slider(
|
176 |
+
label='Top K Limit',
|
177 |
+
min_value=1,
|
178 |
+
max_value=5,
|
179 |
+
value=3,
|
180 |
+
step=1,
|
181 |
+
help='Define el nΓΊmero mΓ‘ximo de resultados a mostrar despuΓ©s de aplicar el reordenamiento.'
|
182 |
+
)
|
183 |
+
|
184 |
+
temperature_input = st.slider(
|
185 |
+
label='Temperatura',
|
186 |
+
min_value=0.0,
|
187 |
+
max_value=1.0,
|
188 |
+
value=0.20,
|
189 |
+
step=0.10,
|
190 |
+
help='Ajusta la temperatura para la generaciΓ³n de texto con GPT, lo que influirΓ‘ en la creatividad de las respuestas.'
|
191 |
+
)
|
192 |
+
|
193 |
+
# Update the model choice in session state
|
194 |
+
if st.session_state["openai_data_model"]!=model_choice:
|
195 |
+
st.session_state["openai_data_model"] = model_choice
|
196 |
+
logger.info(f"Data model: {st.session_state['openai_data_model']}")
|
197 |
+
|
198 |
+
# Update the system prompt choice in session state
|
199 |
+
if st.session_state["system_prompt_data_list"] != system_prompt:
|
200 |
+
# This should be the emoji string the user selected
|
201 |
+
st.session_state["system_prompt_data_list"] = system_prompt
|
202 |
+
# Now we get the corresponding prompt variable using the selected emoji string
|
203 |
+
selected_prompt_variable = system_prompts.prompt_mapping[system_prompt]
|
204 |
+
st.session_state['system_prompt_data_model'] = selected_prompt_variable
|
205 |
+
# logger.info(f"System Prompt: {selected_prompt_variable}")
|
206 |
+
logger.info(f"Assistant: {st.session_state['system_prompt_data_list']}")
|
207 |
+
# logger.info(f"System Prompt: {st.session_state['system_prompt_sync_model']}")
|
208 |
+
|
209 |
+
logger.info(weaviate_client.display_properties)
|
210 |
+
|
211 |
+
def database_search(query):
|
212 |
+
# Determine the appropriate limit based on reranking
|
213 |
+
search_limit = retrieval_limit if reranker_enabled else top_k_limit
|
214 |
+
|
215 |
+
# make hybrid call to weaviate
|
216 |
+
guest_filter = WhereFilter(
|
217 |
+
path=['document_title'],
|
218 |
+
operator='Equal',
|
219 |
+
valueText=guest_input).todict() if guest_input else None
|
220 |
+
|
221 |
+
try:
|
222 |
+
# Perform the search based on retriever_choice
|
223 |
+
if retriever_choice == "Keyword":
|
224 |
+
query_results = weaviate_client.keyword_search(
|
225 |
+
request=query,
|
226 |
+
class_name=class_name,
|
227 |
+
limit=search_limit,
|
228 |
+
where_filter=guest_filter
|
229 |
+
)
|
230 |
+
elif retriever_choice == "Vector":
|
231 |
+
query_results = weaviate_client.vector_search(
|
232 |
+
request=query,
|
233 |
+
class_name=class_name,
|
234 |
+
limit=search_limit,
|
235 |
+
where_filter=guest_filter
|
236 |
+
)
|
237 |
+
elif retriever_choice == "Hybrid":
|
238 |
+
query_results = weaviate_client.hybrid_search(
|
239 |
+
request=query,
|
240 |
+
class_name=class_name,
|
241 |
+
alpha=alpha_input,
|
242 |
+
limit=search_limit,
|
243 |
+
properties=["content"],
|
244 |
+
where_filter=guest_filter
|
245 |
+
)
|
246 |
+
else:
|
247 |
+
return json.dumps({"error": "Invalid retriever choice"})
|
248 |
+
|
249 |
+
|
250 |
+
## RERANKER
|
251 |
+
reranker = ReRanker(model_name='cross-encoder/ms-marco-MiniLM-L-12-v2')
|
252 |
+
model_name = model_choice
|
253 |
+
encoding = encoding_for_model(model_name)
|
254 |
+
|
255 |
+
# Rerank the results if enabled
|
256 |
+
if reranker_enabled:
|
257 |
+
search_results = reranker.rerank(
|
258 |
+
results=query_results,
|
259 |
+
query=query,
|
260 |
+
apply_sigmoid=True,
|
261 |
+
top_k=top_k_limit
|
262 |
+
)
|
263 |
+
|
264 |
+
else:
|
265 |
+
# Use the results directly if reranking is not enabled
|
266 |
+
search_results = query_results
|
267 |
+
|
268 |
+
# logger.debug(search_results)
|
269 |
+
# Save search results to session state for later use
|
270 |
+
# st.session_state['search_results'] = search_results
|
271 |
+
add_to_search_history(query=query, search_results=search_results)
|
272 |
+
expanded_response = expand_content(search_results, cache, content_key='doc_id', create_new_list=True)
|
273 |
+
|
274 |
+
# validate token count is below threshold
|
275 |
+
token_threshold = 8000
|
276 |
+
valid_response = validate_token_threshold(
|
277 |
+
ranked_results=expanded_response,
|
278 |
+
base_prompt=question_answering_prompt_series_spa,
|
279 |
+
query=query,
|
280 |
+
tokenizer=encoding,
|
281 |
+
token_threshold=token_threshold,
|
282 |
+
verbose=True
|
283 |
+
)
|
284 |
+
|
285 |
+
# generate LLM prompt
|
286 |
+
prompt = generate_prompt_series(query=query, results=valid_response)
|
287 |
+
|
288 |
+
# If the strings in 'prompt' are double-escaped, decode them before dumping to JSON
|
289 |
+
# prompt_decoded = prompt.encode().decode('unicode_escape')
|
290 |
+
|
291 |
+
# Then, when you dump to JSON, it should no longer double-escape the characters
|
292 |
+
return json.dumps({
|
293 |
+
"query": query,
|
294 |
+
"Search Results": prompt,
|
295 |
+
}, ensure_ascii=False)
|
296 |
+
|
297 |
+
except Exception as e:
|
298 |
+
# Handle any exceptions and return a JSON formatted error message
|
299 |
+
return json.dumps({
|
300 |
+
"error": "An error occurred during the search",
|
301 |
+
"details": str(e)
|
302 |
+
})
|
303 |
+
|
304 |
+
# When a new message is added, include the type and content
|
305 |
+
def add_to_search_history(query, search_results):
|
306 |
+
st.session_state["data_search_history"].append({
|
307 |
+
"query": query,
|
308 |
+
"search_results": search_results,
|
309 |
+
})
|
310 |
+
|
311 |
+
# Function to display search results
|
312 |
+
def display_search_results():
|
313 |
+
# Loop through each item in the search history
|
314 |
+
for search in st.session_state['data_search_history']:
|
315 |
+
query = search["query"]
|
316 |
+
search_results = search["search_results"]
|
317 |
+
# Create an expander for each search query
|
318 |
+
with st.expander(f"Pregunta: {query}", expanded=False):
|
319 |
+
for i, hit in enumerate(search_results):
|
320 |
+
# col1, col2 = st.columns([7, 3], gap='large')
|
321 |
+
page_url = hit['page_url']
|
322 |
+
page_label = hit['page_label']
|
323 |
+
document_title = hit['document_title']
|
324 |
+
# Assuming 'page_summary' is available and you want to display it
|
325 |
+
page_summary = hit.get('page_summary', 'Summary not available')
|
326 |
+
|
327 |
+
# with col1:
|
328 |
+
st.markdown(f'''
|
329 |
+
<span style="color: #3498db; font-size: 19px; font-weight: bold;">{document_title}</span><br>
|
330 |
+
{page_summary}
|
331 |
+
[**PaΜgina:** {page_label}]({page_url})
|
332 |
+
''', unsafe_allow_html=True)
|
333 |
+
|
334 |
+
# with st.expander("π Clic aquΓ para ver contexto:"):
|
335 |
+
# try:
|
336 |
+
# content = hit['content']
|
337 |
+
# st.write(content)
|
338 |
+
# except Exception as e:
|
339 |
+
# st.write(f"Error displaying content: {e}")
|
340 |
+
|
341 |
+
# with col2:
|
342 |
+
# # If you have an image or want to display a placeholder image
|
343 |
+
# image = "URL_TO_A_PLACEHOLDER_IMAGE" # Replace with a relevant image URL if applicable
|
344 |
+
# st.image(image, caption=document_title, width=200, use_column_width=False)
|
345 |
+
# st.markdown(f'''
|
346 |
+
# <p style="text-align: right;">
|
347 |
+
# <b>Document Title:</b> {document_title}<br>
|
348 |
+
# <b>File Name:</b> {file_name}<br>
|
349 |
+
# </p>''', unsafe_allow_html=True)
|
350 |
+
|
351 |
+
########################
|
352 |
+
## SETUP MAIN DISPLAY ##
|
353 |
+
########################
|
354 |
+
|
355 |
+
st.image('./static/images/cervezas-mahou.jpeg', width=400)
|
356 |
+
st.subheader(f"β¨π£οΈπ **BΓΊsqueda Conversacional** π‘π£οΈβ¨ - Impuestos Especiales")
|
357 |
+
st.write('\n')
|
358 |
+
col1, col2 = st.columns([50,50])
|
359 |
+
|
360 |
+
# Initialize chat history
|
361 |
+
if "data_chat_history" not in st.session_state:
|
362 |
+
st.session_state["data_chat_history"] = []
|
363 |
+
|
364 |
+
if "data_search_history" not in st.session_state:
|
365 |
+
st.session_state["data_search_history"] = []
|
366 |
+
|
367 |
+
with col1:
|
368 |
+
st.write("Chat History:")
|
369 |
+
# Create a container for chat history
|
370 |
+
chat_history_container = st.container(height=500, border=True)
|
371 |
+
# Display chat messages from history on app rerun
|
372 |
+
with chat_history_container:
|
373 |
+
for message in st.session_state["data_chat_history"]:
|
374 |
+
with st.chat_message(message["role"]):
|
375 |
+
st.markdown(message["content"])
|
376 |
+
# Function to update chat display
|
377 |
+
def update_chat_display():
|
378 |
+
with chat_history_container:
|
379 |
+
for message in st.session_state["data_chat_history"]:
|
380 |
+
with st.chat_message(message["role"]):
|
381 |
+
st.markdown(message["content"])
|
382 |
+
|
383 |
+
if prompt := st.chat_input("What is up?"):
|
384 |
+
# Add user message to chat history
|
385 |
+
st.session_state["data_chat_history"].append({"role": "user", "content": prompt})
|
386 |
+
# Initially display the chat history
|
387 |
+
update_chat_display()
|
388 |
+
# # Display user message in chat message container
|
389 |
+
# with st.chat_message("user"):
|
390 |
+
# st.markdown(prompt)
|
391 |
+
|
392 |
+
with st.spinner('Generando Respuesta...'):
|
393 |
+
tools = [
|
394 |
+
{
|
395 |
+
"type": "function",
|
396 |
+
"function": {
|
397 |
+
"name": "database_search",
|
398 |
+
"description": "Takes the users query about the database and returns the results, extracting info to answer the user's question",
|
399 |
+
"parameters": {
|
400 |
+
"type": "object",
|
401 |
+
"properties": {
|
402 |
+
"query": {"type": "string", "description": "query"},
|
403 |
+
|
404 |
+
},
|
405 |
+
"required": ["query"],
|
406 |
+
},
|
407 |
+
}
|
408 |
+
}
|
409 |
+
]
|
410 |
+
|
411 |
+
# Display live assistant response in chat message container
|
412 |
+
with st.chat_message(
|
413 |
+
name="assistant",
|
414 |
+
avatar="./static/images/openai_purple_logo_hres.jpeg"):
|
415 |
+
message_placeholder = st.empty()
|
416 |
+
|
417 |
+
# Building the messages payload with proper OPENAI API structure
|
418 |
+
messages=[
|
419 |
+
{"role": "system", "content": st.session_state["system_prompt_data_model"]}
|
420 |
+
] + [
|
421 |
+
{"role": m["role"], "content": m["content"]} for m in st.session_state["data_chat_history"]
|
422 |
+
]
|
423 |
+
logger.debug(f"Initial Messages: {messages}")
|
424 |
+
# call the OpenAI API to get the response
|
425 |
+
|
426 |
+
RESPONSE = client.chat.completions.create(
|
427 |
+
model=st.session_state["openai_data_model"],
|
428 |
+
temperature=0.5,
|
429 |
+
messages=messages,
|
430 |
+
tools=tools,
|
431 |
+
tool_choice="auto", # auto is default, but we'll be explicit
|
432 |
+
stream=True
|
433 |
+
)
|
434 |
+
logger.debug(f"First Response: {RESPONSE}")
|
435 |
+
|
436 |
+
|
437 |
+
FULL_RESPONSE = ""
|
438 |
+
tool_calls = []
|
439 |
+
# build up the response structs from the streamed response, simultaneously sending message chunks to the browser
|
440 |
+
for chunk in RESPONSE:
|
441 |
+
delta = chunk.choices[0].delta
|
442 |
+
# logger.debug(f"chunk: {delta}")
|
443 |
+
|
444 |
+
|
445 |
+
|
446 |
+
if delta and delta.content:
|
447 |
+
text_chunk = delta.content
|
448 |
+
FULL_RESPONSE += str(text_chunk)
|
449 |
+
message_placeholder.markdown(FULL_RESPONSE + "β")
|
450 |
+
|
451 |
+
elif delta and delta.tool_calls:
|
452 |
+
tcchunklist = delta.tool_calls
|
453 |
+
for tcchunk in tcchunklist:
|
454 |
+
if len(tool_calls) <= tcchunk.index:
|
455 |
+
tool_calls.append({"id": "", "type": "function", "function": { "name": "", "arguments": "" } })
|
456 |
+
tc = tool_calls[tcchunk.index]
|
457 |
+
|
458 |
+
if tcchunk.id:
|
459 |
+
tc["id"] += tcchunk.id
|
460 |
+
if tcchunk.function.name:
|
461 |
+
tc["function"]["name"] += tcchunk.function.name
|
462 |
+
if tcchunk.function.arguments:
|
463 |
+
tc["function"]["arguments"] += tcchunk.function.arguments
|
464 |
+
if tool_calls:
|
465 |
+
logger.debug(f"tool_calls: {tool_calls}")
|
466 |
+
# Define a dictionary mapping function names to actual functions
|
467 |
+
available_functions = {
|
468 |
+
"database_search": database_search,
|
469 |
+
# Add other functions as necessary
|
470 |
+
}
|
471 |
+
available_functions = {
|
472 |
+
"database_search": database_search,
|
473 |
+
} # only one function in this example, but you can have multiple
|
474 |
+
logger.debug(f"FuncCall Before messages: {messages}")
|
475 |
+
# Process each tool call
|
476 |
+
for tool_call in tool_calls:
|
477 |
+
# Get the function name and arguments from the tool call
|
478 |
+
function_name = tool_call['function']['name']
|
479 |
+
function_args = json.loads(tool_call['function']['arguments'])
|
480 |
+
|
481 |
+
# Get the actual function to call
|
482 |
+
function_to_call = available_functions[function_name]
|
483 |
+
|
484 |
+
# Call the function and get the response
|
485 |
+
function_response = function_to_call(**function_args)
|
486 |
+
|
487 |
+
# Append the function response to the messages list
|
488 |
+
messages.append({
|
489 |
+
"role": "assistant",
|
490 |
+
"tool_call_id": tool_call['id'],
|
491 |
+
"name": function_name,
|
492 |
+
"content": function_response,
|
493 |
+
})
|
494 |
+
logger.debug(f"FuncCall After messages: {messages}")
|
495 |
+
|
496 |
+
RESPONSE = client.chat.completions.create(
|
497 |
+
model=st.session_state["openai_data_model"],
|
498 |
+
temperature=0.1,
|
499 |
+
messages=messages,
|
500 |
+
stream=True
|
501 |
+
)
|
502 |
+
logger.debug(f"Second Response: {RESPONSE}")
|
503 |
+
|
504 |
+
# build up the response structs from the streamed response, simultaneously sending message chunks to the browser
|
505 |
+
for chunk in RESPONSE:
|
506 |
+
delta = chunk.choices[0].delta
|
507 |
+
# logger.debug(f"chunk: {delta}")
|
508 |
+
|
509 |
+
if delta and delta.content:
|
510 |
+
text_chunk = delta.content
|
511 |
+
FULL_RESPONSE += str(text_chunk)
|
512 |
+
message_placeholder.markdown(FULL_RESPONSE + "β")
|
513 |
+
# Add assistant response to chat history
|
514 |
+
st.session_state["data_chat_history"].append({"role": "assistant", "content": FULL_RESPONSE})
|
515 |
+
logger.debug(f"chat_history: {st.session_state['data_chat_history']}")
|
516 |
+
|
517 |
+
# Next block of code...
|
518 |
+
|
519 |
+
|
520 |
+
####################
|
521 |
+
## Search Results ##
|
522 |
+
####################
|
523 |
+
# st.subheader(subheader_msg)
|
524 |
+
with col2:
|
525 |
+
st.write("Search Results:")
|
526 |
+
with st.container(height=500, border=True):
|
527 |
+
# Check if 'data_search_history' is in the session state and not empty
|
528 |
+
if 'data_search_history' in st.session_state and st.session_state['data_search_history']:
|
529 |
+
display_search_results()
|
530 |
+
# # Extract the latest message from the search history
|
531 |
+
# latest_search = st.session_state['data_search_history'][-1]
|
532 |
+
# query = latest_search["query"]
|
533 |
+
# with st.expander(query, expanded=False):
|
534 |
+
# # Extract the latest message from the search history
|
535 |
+
# latest_search = st.session_state['data_search_history'][-1]
|
536 |
+
# query = latest_search["query"]
|
537 |
+
# for i, hit in enumerate(latest_search["search_results"]):
|
538 |
+
# col1, col2 = st.columns([7, 3], gap='large')
|
539 |
+
# episode_url = hit['episode_url']
|
540 |
+
# title = hit['title']
|
541 |
+
# guest=hit['guest']
|
542 |
+
# show_length = hit['length']
|
543 |
+
# time_string = convert_seconds(show_length)
|
544 |
+
# # content = ranked_response[i]['content'] # Get 'content' from the same index in ranked_response
|
545 |
+
# content = hit['content']
|
546 |
+
|
547 |
+
# with col1:
|
548 |
+
# st.write( search_result(i=i,
|
549 |
+
# url=episode_url,
|
550 |
+
# guest=guest,
|
551 |
+
# title=title,
|
552 |
+
# content=content,
|
553 |
+
# length=time_string),
|
554 |
+
# unsafe_allow_html=True)
|
555 |
+
# st.write('\n\n')
|
556 |
+
|
557 |
+
# # with st.container("Episode Summary:"):
|
558 |
+
# # try:
|
559 |
+
# # ep_summary = hit['summary']
|
560 |
+
# # st.write(ep_summary)
|
561 |
+
# # except Exception as e:
|
562 |
+
# # st.error(f"Error displaying summary: {e}")
|
563 |
+
|
564 |
+
# with col2:
|
565 |
+
# image = hit['thumbnail_url']
|
566 |
+
# st.image(image, caption=title.split('|')[0], width=200, use_column_width=False)
|
567 |
+
# st.markdown(f'''
|
568 |
+
# <p style="text-align: right;">
|
569 |
+
# <b>Episode:</b> {title.split('|')[0]}<br>
|
570 |
+
# <b>Guest:</b> {hit['guest']}<br>
|
571 |
+
# <b>Length:</b> {time_string}
|
572 |
+
# </p>''', unsafe_allow_html=True)
|
573 |
+
|
574 |
+
|
575 |
+
if __name__ == '__main__':
|
576 |
+
main()
|
pages/__init__.py
ADDED
File without changes
|
requirements.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
loguru==0.7.0
|
2 |
+
numpy==1.24.4
|
3 |
+
openai==1.10.0
|
4 |
+
pandas==2.0.3
|
5 |
+
protobuf==4.23.4
|
6 |
+
pyarrow==12.0.1
|
7 |
+
python-dotenv==1.0.0
|
8 |
+
rank-bm25==0.2.2
|
9 |
+
requests==2.31.0
|
10 |
+
requests-oauthlib==1.3.1
|
11 |
+
rich==13.7.0
|
12 |
+
sentence-transformers==2.2.2
|
13 |
+
streamlit==1.31.1
|
14 |
+
tiktoken==0.5.1
|
15 |
+
tokenizers==0.13.3
|
16 |
+
torch==2.0.1
|
17 |
+
transformers==4.33.1
|
18 |
+
weaviate-client==3.25.3
|
static/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
static/images/cervezas-mahou.jpeg
ADDED
static/images/fabrica-mahou-1200x675.jpeg
ADDED
static/images/openai_logo.png
ADDED
static/images/openai_logo_circle.png
ADDED
static/images/openai_purple_logo_hres.jpeg
ADDED
static/images/screen_recording_busqueda_final_2.gif
ADDED
Git LFS Details
|
utils/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
utils/__init__.py
ADDED
File without changes
|
utils/app_features_spa.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import json
|
3 |
+
from utils.preprocessing import FileIO
|
4 |
+
from typing import List, Optional
|
5 |
+
import tiktoken
|
6 |
+
from loguru import logger
|
7 |
+
from templates.prompt_templates_spa import context_block_spa, question_answering_prompt_series_spa
|
8 |
+
import streamlit as st
|
9 |
+
|
10 |
+
@st.cache_data
|
11 |
+
def load_content_cache(data_path: str):
|
12 |
+
data = FileIO().load_parquet(data_path)
|
13 |
+
content_data = {d['doc_id']: d['content'] for d in data}
|
14 |
+
return content_data
|
15 |
+
|
16 |
+
@st.cache_data
|
17 |
+
def load_data(data_path: str):
|
18 |
+
with open(data_path, 'r') as f:
|
19 |
+
data = json.load(f)
|
20 |
+
return data
|
21 |
+
|
22 |
+
def convert_seconds(seconds: int):
|
23 |
+
"""
|
24 |
+
Converts seconds to a string of format Hours:Minutes:Seconds
|
25 |
+
"""
|
26 |
+
return time.strftime("%H:%M:%S", time.gmtime(seconds))
|
27 |
+
|
28 |
+
def generate_prompt_series(query: str, results: List[dict]) -> str:
|
29 |
+
"""
|
30 |
+
Generates a prompt for the OpenAI API by joining the context blocks of the top results.
|
31 |
+
Provides context to the LLM by supplying the summary, document name, and retrieved content of each result.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
-----
|
35 |
+
query : str
|
36 |
+
User query
|
37 |
+
results : List[dict]
|
38 |
+
List of results from the Weaviate client
|
39 |
+
"""
|
40 |
+
context_series = '\n'.join([context_block_spa.format(summary=res['page_summary'],
|
41 |
+
document=res['document_title'],
|
42 |
+
transcript=res['content']
|
43 |
+
)for res in results]).strip()
|
44 |
+
prompt = question_answering_prompt_series_spa.format(question=query, series=context_series)
|
45 |
+
return prompt
|
46 |
+
|
47 |
+
def expand_content(ranked_results: List[dict],
|
48 |
+
content_cache: Optional[dict] = None,
|
49 |
+
content_key: str = 'doc_id',
|
50 |
+
create_new_list: bool = False
|
51 |
+
) -> List[dict]:
|
52 |
+
'''
|
53 |
+
Updates or creates a list of ranked results with content from a cache.
|
54 |
+
|
55 |
+
This function iterates over a list of dictionaries representing ranked results.
|
56 |
+
If a cache is provided, it adds or updates the 'content' key in each dictionary
|
57 |
+
with the corresponding content from the cache based on the content_key.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
- ranked_results (List[dict]): A list of dictionaries, each representing a ranked result.
|
61 |
+
- content_cache (Optional[dict]): A dictionary that maps content_key to content.
|
62 |
+
If None, the content of ranked results will not be updated.
|
63 |
+
- content_key (str): The key used in both the ranked results and content cache to match
|
64 |
+
the ranked results with their corresponding content in the cache.
|
65 |
+
- create_new_list (bool): If True, a new list of dictionaries will be created and
|
66 |
+
returned with the content updated. If False, the ranked_results will be updated in place.
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
- List[dict]: A new list with updated content if create_new_list is True; otherwise,
|
70 |
+
the original ranked_results list with updated content.
|
71 |
+
|
72 |
+
Note:
|
73 |
+
- If create_new_list is False, the function will mutate the original ranked_results list.
|
74 |
+
- The function only updates content if the content_key exists in both the ranked result
|
75 |
+
and the content cache.
|
76 |
+
|
77 |
+
Example:
|
78 |
+
```
|
79 |
+
ranked_results = [{'doc_id': '123', 'title': 'Title 1'}, {'doc_id': '456', 'title': 'Title 2'}]
|
80 |
+
content_cache = {'123': 'Content for 123', '456': 'Content for 456'}
|
81 |
+
updated_results = expand_content(ranked_results, content_cache, create_new_list=True)
|
82 |
+
# updated_results is now [{'doc_id': '123', 'title': 'Title 1', 'content': 'Content for 123'},
|
83 |
+
# {'doc_id': '456', 'title': 'Title 2', 'content': 'Content for 456'}]
|
84 |
+
```
|
85 |
+
'''
|
86 |
+
if create_new_list:
|
87 |
+
expanded_response = [{k:v for k, v in resp.items()} for resp in ranked_results]
|
88 |
+
if content_cache is not None:
|
89 |
+
for resp in expanded_response:
|
90 |
+
if resp[content_key] in content_cache:
|
91 |
+
resp['content'] = content_cache[resp[content_key]]
|
92 |
+
return expanded_response
|
93 |
+
else:
|
94 |
+
for resp in ranked_results:
|
95 |
+
if content_cache and resp[content_key] in content_cache:
|
96 |
+
resp['content'] = content_cache[resp[content_key]]
|
97 |
+
return ranked_results
|
98 |
+
|
99 |
+
def validate_token_threshold(ranked_results: List[dict],
|
100 |
+
base_prompt: str,
|
101 |
+
query: str,
|
102 |
+
tokenizer: tiktoken.Encoding,
|
103 |
+
token_threshold: int,
|
104 |
+
verbose: bool = False
|
105 |
+
) -> List[dict]:
|
106 |
+
"""
|
107 |
+
Validates that prompt is below the set token threshold by adding lengths of:
|
108 |
+
1. Base prompt
|
109 |
+
2. User query
|
110 |
+
3. Context material
|
111 |
+
If threshold is exceeded, context results are reduced incrementally until the
|
112 |
+
combined prompt tokens are below the threshold. This function does not take into
|
113 |
+
account every token passed to the LLM, but it is a good approximation.
|
114 |
+
"""
|
115 |
+
overhead_len = len(tokenizer.encode(base_prompt.format(question=query, series='')))
|
116 |
+
context_len = _get_batch_length(ranked_results, tokenizer)
|
117 |
+
|
118 |
+
token_count = overhead_len + context_len
|
119 |
+
if token_count > token_threshold:
|
120 |
+
print('Token count exceeds token count threshold, reducing size of returned results below token threshold')
|
121 |
+
|
122 |
+
while token_count > token_threshold and len(ranked_results) > 1:
|
123 |
+
num_results = len(ranked_results)
|
124 |
+
|
125 |
+
# remove the last ranked (most irrelevant) result
|
126 |
+
ranked_results = ranked_results[:num_results-1]
|
127 |
+
# recalculate new token_count
|
128 |
+
token_count = overhead_len + _get_batch_length(ranked_results, tokenizer)
|
129 |
+
|
130 |
+
if verbose:
|
131 |
+
logger.info(f'Total Final Token Count: {token_count}')
|
132 |
+
return ranked_results
|
133 |
+
|
134 |
+
def _get_batch_length(ranked_results: List[dict], tokenizer: tiktoken.Encoding) -> int:
|
135 |
+
'''
|
136 |
+
Convenience function to get the length in tokens of a batch of results
|
137 |
+
'''
|
138 |
+
contexts = tokenizer.encode_batch([r['content'] for r in ranked_results])
|
139 |
+
context_len = sum(list(map(len, contexts)))
|
140 |
+
return context_len
|
141 |
+
|
142 |
+
def search_result(i: int,
|
143 |
+
url: str,
|
144 |
+
title: str,
|
145 |
+
content: str,
|
146 |
+
guest: str,
|
147 |
+
length: str,
|
148 |
+
space: str=' '
|
149 |
+
) -> str:
|
150 |
+
|
151 |
+
'''
|
152 |
+
HTML to display search results.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
-----
|
156 |
+
i: int
|
157 |
+
index of search result
|
158 |
+
url: str
|
159 |
+
url of YouTube video
|
160 |
+
title: str
|
161 |
+
title of episode
|
162 |
+
content: str
|
163 |
+
content chunk of episode
|
164 |
+
'''
|
165 |
+
return f"""
|
166 |
+
<div style="font-size:120%;">
|
167 |
+
{i + 1}.<a href="{url}">{title}</a>
|
168 |
+
</div>
|
169 |
+
|
170 |
+
<div style="font-size:95%;">
|
171 |
+
<p>Episode Length: {length} {space}{space} Guest: {guest}</p>
|
172 |
+
<div style="color:grey;float:left;">
|
173 |
+
...
|
174 |
+
</div>
|
175 |
+
{content}
|
176 |
+
</div>
|
177 |
+
"""
|
utils/openai_interface_spa.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from openai import OpenAI
|
3 |
+
from typing import List, Any, Tuple
|
4 |
+
from dotenv import load_dotenv
|
5 |
+
from tqdm import tqdm
|
6 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
7 |
+
_ = load_dotenv('./.env', override=True) # read local .env file
|
8 |
+
|
9 |
+
|
10 |
+
class GPT_Turbo:
|
11 |
+
|
12 |
+
def __init__(self, model: str="gpt-3.5-turbo-0613", api_key: str=os.environ['OPENAI_API_KEY']):
|
13 |
+
self.model = model
|
14 |
+
self.client = OpenAI(api_key=api_key)
|
15 |
+
|
16 |
+
def get_chat_completion(self,
|
17 |
+
prompt: str,
|
18 |
+
system_message: str='You are a helpful assistant.',
|
19 |
+
temperature: int=0,
|
20 |
+
max_tokens: int=500,
|
21 |
+
stream: bool=False,
|
22 |
+
show_response: bool=False
|
23 |
+
) -> str:
|
24 |
+
messages = [
|
25 |
+
{'role': 'system', 'content': system_message},
|
26 |
+
{'role': 'assistant', 'content': prompt}
|
27 |
+
]
|
28 |
+
|
29 |
+
response = self.client.chat.completions.create( model=self.model,
|
30 |
+
messages=messages,
|
31 |
+
temperature=temperature,
|
32 |
+
max_tokens=max_tokens,
|
33 |
+
stream=stream)
|
34 |
+
if show_response:
|
35 |
+
return response
|
36 |
+
return response.choices[0].message.content
|
37 |
+
|
38 |
+
def multi_thread_request(self,
|
39 |
+
filepath: str,
|
40 |
+
prompt: str,
|
41 |
+
content: List[str],
|
42 |
+
temperature: int=0
|
43 |
+
) -> List[Any]:
|
44 |
+
|
45 |
+
data = []
|
46 |
+
with ThreadPoolExecutor(max_workers=2*os.cpu_count()) as exec:
|
47 |
+
futures = [exec.submit(self.get_completion_from_messages, [{'role': 'user','content': f'{prompt} ```{c}```'}], temperature, 500, False) for c in content]
|
48 |
+
with open(filepath, 'a') as f:
|
49 |
+
for future in as_completed(futures):
|
50 |
+
result = future.result()
|
51 |
+
if len(data) % 10 == 0:
|
52 |
+
print(f'{len(data)} of {len(content)} completed.')
|
53 |
+
if result:
|
54 |
+
data.append(result)
|
55 |
+
self.write_to_file(file_handle=f, data=result)
|
56 |
+
return [res for res in data if res]
|
57 |
+
|
58 |
+
def generate_question_context_pairs(self,
|
59 |
+
context_tuple: Tuple[str, str],
|
60 |
+
num_questions_per_chunk: int=2,
|
61 |
+
max_words_per_question: int=10
|
62 |
+
) -> List[str]:
|
63 |
+
|
64 |
+
doc_id, context = context_tuple
|
65 |
+
prompt = f'Context information is included below enclosed in triple backticks. Given the context information and not prior knowledge, generate questions based on the below query.\n\nYou are an end user querying for information about your favorite podcast. \
|
66 |
+
Your task is to setup {num_questions_per_chunk} questions that can be answered using only the given context. The questions should be diverse in nature across the document and be no longer than {max_words_per_question} words. \
|
67 |
+
Restrict the questions to the context information provided.\n\
|
68 |
+
```{context}```\n\n'
|
69 |
+
|
70 |
+
response = self.get_completion_from_messages(prompt=prompt, temperature=0, max_tokens=500, show_response=True)
|
71 |
+
questions = response.choices[0].message["content"]
|
72 |
+
return (doc_id, questions)
|
73 |
+
|
74 |
+
def batch_generate_question_context_pairs(self,
|
75 |
+
context_tuple_list: List[Tuple[str, str]],
|
76 |
+
num_questions_per_chunk: int=2,
|
77 |
+
max_words_per_question: int=10
|
78 |
+
) -> List[Tuple[str, str]]:
|
79 |
+
data = []
|
80 |
+
progress = tqdm(unit="Generated Questions", total=len(context_tuple_list))
|
81 |
+
with ThreadPoolExecutor(max_workers=2*os.cpu_count()) as exec:
|
82 |
+
futures = [exec.submit(self.generate_question_context_pairs, context_tuple, num_questions_per_chunk, max_words_per_question) for context_tuple in context_tuple_list]
|
83 |
+
for future in as_completed(futures):
|
84 |
+
result = future.result()
|
85 |
+
if result:
|
86 |
+
data.append(result)
|
87 |
+
progress.update(1)
|
88 |
+
return data
|
89 |
+
|
90 |
+
def get_embedding(self):
|
91 |
+
pass
|
92 |
+
|
93 |
+
def write_to_file(self, file_handle, data: str) -> None:
|
94 |
+
file_handle.write(data)
|
95 |
+
file_handle.write('\n')
|
utils/preprocessing.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import pandas as pd
|
4 |
+
from typing import List, Union, Dict
|
5 |
+
from loguru import logger
|
6 |
+
import pandas as pd
|
7 |
+
import pathlib
|
8 |
+
|
9 |
+
|
10 |
+
## Set of helper functions that support data preprocessing
|
11 |
+
class FileIO:
|
12 |
+
'''
|
13 |
+
Convenience class for saving and loading data in parquet and
|
14 |
+
json formats to/from disk.
|
15 |
+
'''
|
16 |
+
|
17 |
+
def save_as_parquet(self,
|
18 |
+
file_path: str,
|
19 |
+
data: Union[List[dict], pd.DataFrame],
|
20 |
+
overwrite: bool=False) -> None:
|
21 |
+
'''
|
22 |
+
Saves DataFrame to disk as a parquet file. Removes the index.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
-----
|
26 |
+
file_path : str
|
27 |
+
Output path to save file, if not included "parquet" will be appended
|
28 |
+
as file extension.
|
29 |
+
data : Union[List[dict], pd.DataFrame]
|
30 |
+
Data to save as parquet file. If data is a list of dicts, it will be
|
31 |
+
converted to a DataFrame before saving.
|
32 |
+
overwrite : bool
|
33 |
+
Overwrite existing file if True, otherwise raise FileExistsError.
|
34 |
+
'''
|
35 |
+
if isinstance(data, list):
|
36 |
+
data = self._convert_toDataFrame(data)
|
37 |
+
if not file_path.endswith('parquet'):
|
38 |
+
file_path = self._rename_file_extension(file_path, 'parquet')
|
39 |
+
self._check_file_path(file_path, overwrite=overwrite)
|
40 |
+
data.to_parquet(file_path, index=False)
|
41 |
+
logger.info(f'DataFrame saved as parquet file here: {file_path}')
|
42 |
+
|
43 |
+
def _convert_toDataFrame(self, data: List[dict]) -> pd.DataFrame:
|
44 |
+
return pd.DataFrame().from_dict(data)
|
45 |
+
|
46 |
+
def _rename_file_extension(self, file_path: str, extension: str):
|
47 |
+
'''
|
48 |
+
Renames file with appropriate extension if file_path
|
49 |
+
does not already have correct extension.
|
50 |
+
'''
|
51 |
+
prefix = os.path.splitext(file_path)[0]
|
52 |
+
file_path = prefix + '.' + extension
|
53 |
+
return file_path
|
54 |
+
|
55 |
+
def _check_file_path(self, file_path: str, overwrite: bool) -> None:
|
56 |
+
'''
|
57 |
+
Checks for existence of file and overwrite permissions.
|
58 |
+
'''
|
59 |
+
if os.path.exists(file_path) and overwrite == False:
|
60 |
+
raise FileExistsError(f'File by name {file_path} already exists, try using another file name or set overwrite to True.')
|
61 |
+
elif os.path.exists(file_path):
|
62 |
+
os.remove(file_path)
|
63 |
+
else:
|
64 |
+
file_name = os.path.basename(file_path)
|
65 |
+
dir_structure = file_path.replace(file_name, '')
|
66 |
+
pathlib.Path(dir_structure).mkdir(parents=True, exist_ok=True)
|
67 |
+
|
68 |
+
def load_parquet(self, file_path: str, verbose: bool=True) -> List[dict]:
|
69 |
+
'''
|
70 |
+
Loads parquet from disk, converts to pandas DataFrame as intermediate
|
71 |
+
step and outputs a list of dicts (docs).
|
72 |
+
'''
|
73 |
+
df = pd.read_parquet(file_path)
|
74 |
+
vector_labels = ['content_vector', 'image_vector', 'content_embedding']
|
75 |
+
for label in vector_labels:
|
76 |
+
if label in df.columns:
|
77 |
+
df[label] = df[label].apply(lambda x: x.tolist())
|
78 |
+
if verbose:
|
79 |
+
memory_usage = round(df.memory_usage().sum()/(1024*1024),2)
|
80 |
+
print(f'Shape of data: {df.values.shape}')
|
81 |
+
print(f'Memory Usage: {memory_usage}+ MB')
|
82 |
+
list_of_dicts = df.to_dict('records')
|
83 |
+
return list_of_dicts
|
84 |
+
|
85 |
+
def load_json(self, file_path: str):
|
86 |
+
'''
|
87 |
+
Loads json file from disk.
|
88 |
+
'''
|
89 |
+
with open(file_path) as f:
|
90 |
+
data = json.load(f)
|
91 |
+
return data
|
92 |
+
|
93 |
+
def save_as_json(self,
|
94 |
+
file_path: str,
|
95 |
+
data: Union[List[dict], dict],
|
96 |
+
indent: int=4,
|
97 |
+
overwrite: bool=False
|
98 |
+
) -> None:
|
99 |
+
'''
|
100 |
+
Saves data to disk as a json file. Data can be a list of dicts or a single dict.
|
101 |
+
'''
|
102 |
+
if not file_path.endswith('json'):
|
103 |
+
file_path = self._rename_file_extension(file_path, 'json')
|
104 |
+
self._check_file_path(file_path, overwrite=overwrite)
|
105 |
+
with open(file_path, 'w') as f:
|
106 |
+
json.dump(data, f, indent=indent)
|
107 |
+
logger.info(f'Data saved as json file here: {file_path}')
|
108 |
+
|
109 |
+
class Utilities:
|
110 |
+
|
111 |
+
def create_video_url(self, video_id: str, playlist_id: str):
|
112 |
+
'''
|
113 |
+
Creates a hyperlink to a video episode given a video_id and playlist_id.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
-----
|
117 |
+
video_id : str
|
118 |
+
Video id of the episode from YouTube
|
119 |
+
playlist_id : str
|
120 |
+
Playlist id of the episode from YouTube
|
121 |
+
'''
|
122 |
+
return f'https://www.youtube.com/watch?v={video_id}&list={playlist_id}'
|
123 |
+
|
utils/prompt_templates_spa.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
question_answering_prompt_series_spa = '''
|
2 |
+
Su tarea es sintetizar y razonar sobre una serie de contenidos proporcionados.
|
3 |
+
DespuΓ©s de su sΓntesis, utilice estos contenidos para responder a la pregunta a continuaciΓ³n. La serie estarΓ‘ en el siguiente formato:\n
|
4 |
+
```
|
5 |
+
RESUMEN: <summary>
|
6 |
+
DOCUMENTO: <document>
|
7 |
+
CONTENIDO: <transcript>
|
8 |
+
```\n\n
|
9 |
+
Inicio de la Serie:
|
10 |
+
```
|
11 |
+
{series}
|
12 |
+
```
|
13 |
+
Pregunta:\n
|
14 |
+
{question}\n
|
15 |
+
Responda a la pregunta y proporcione razonamientos si es necesario para explicar la respuesta.
|
16 |
+
Si el contexto no proporciona suficiente informaciΓ³n para responder a la pregunta, entonces
|
17 |
+
indique que no puede responder a la pregunta con el contexto proporcionado.
|
18 |
+
|
19 |
+
Respuesta:
|
20 |
+
'''
|
21 |
+
|
22 |
+
context_block_spa = '''
|
23 |
+
RESUMEN: {summary}
|
24 |
+
DOCUMENTO: {document}
|
25 |
+
CONTENIDO: {transcript}
|
26 |
+
'''
|
utils/reranker_spa.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sentence_transformers import CrossEncoder
|
2 |
+
from torch.nn import Sigmoid
|
3 |
+
from typing import List, Union
|
4 |
+
import numpy as np
|
5 |
+
from loguru import logger
|
6 |
+
|
7 |
+
class ReRanker(CrossEncoder):
|
8 |
+
'''
|
9 |
+
Cross-Encoder models achieve higher performance than Bi-Encoders,
|
10 |
+
however, they do not scale well to large datasets. The lack of scalability
|
11 |
+
is due to the underlying cross-attention mechanism, which is computationally
|
12 |
+
expensive. Thus a Bi-Encoder is best used for 1st-stage document retrieval and
|
13 |
+
a Cross-Encoder is used to re-rank the retrieved documents.
|
14 |
+
|
15 |
+
https://www.sbert.net/examples/applications/cross-encoder/README.html
|
16 |
+
'''
|
17 |
+
|
18 |
+
def __init__(self,
|
19 |
+
model_name: str='cross-encoder/ms-marco-MiniLM-L-6-v2',
|
20 |
+
**kwargs
|
21 |
+
):
|
22 |
+
super().__init__(model_name=model_name,
|
23 |
+
**kwargs)
|
24 |
+
self.model_name = model_name
|
25 |
+
self.score_field = 'cross_score'
|
26 |
+
self.activation_fct = Sigmoid()
|
27 |
+
|
28 |
+
def _cross_encoder_score(self,
|
29 |
+
results: List[dict],
|
30 |
+
query: str,
|
31 |
+
hit_field: str='content',
|
32 |
+
apply_sigmoid: bool=True,
|
33 |
+
return_scores: bool=False
|
34 |
+
) -> Union[np.array, None]:
|
35 |
+
'''
|
36 |
+
Given a list of hits from a Retriever:
|
37 |
+
1. Scores hits by passing query and results through CrossEncoder model.
|
38 |
+
2. Adds cross-score key to results dictionary.
|
39 |
+
3. If desired returns np.array of Cross Encoder scores.
|
40 |
+
'''
|
41 |
+
activation_fct = self.activation_fct if apply_sigmoid else None
|
42 |
+
#build query/content list
|
43 |
+
cross_inp = [[query, hit[hit_field]] for hit in results]
|
44 |
+
#get scores
|
45 |
+
cross_scores = self.predict(cross_inp, activation_fct=activation_fct)
|
46 |
+
for i, result in enumerate(results):
|
47 |
+
result[self.score_field]=cross_scores[i]
|
48 |
+
|
49 |
+
if return_scores:return cross_scores
|
50 |
+
|
51 |
+
def rerank(self,
|
52 |
+
results: List[dict],
|
53 |
+
query: str,
|
54 |
+
top_k: int=10,
|
55 |
+
apply_sigmoid: bool=True,
|
56 |
+
threshold: float=None
|
57 |
+
) -> List[dict]:
|
58 |
+
'''
|
59 |
+
Given a list of hits from a Retriever:
|
60 |
+
1. Scores hits by passing query and results through CrossEncoder model.
|
61 |
+
2. Adds cross_score key to results dictionary.
|
62 |
+
3. Returns reranked results limited by either a threshold value or top_k.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
-----
|
66 |
+
results : List[dict]
|
67 |
+
List of results from the Weaviate client
|
68 |
+
query : str
|
69 |
+
User query
|
70 |
+
top_k : int=10
|
71 |
+
Number of results to return
|
72 |
+
apply_sigmoid : bool=True
|
73 |
+
Whether to apply sigmoid activation to cross-encoder scores. If False,
|
74 |
+
returns raw cross-encoder scores (logits).
|
75 |
+
threshold : float=None
|
76 |
+
Minimum cross-encoder score to return. If no hits are above threshold,
|
77 |
+
returns top_k hits.
|
78 |
+
'''
|
79 |
+
# Sort results by the cross-encoder scores
|
80 |
+
self._cross_encoder_score(results=results, query=query, apply_sigmoid=apply_sigmoid)
|
81 |
+
|
82 |
+
sorted_hits = sorted(results, key=lambda x: x[self.score_field], reverse=True)
|
83 |
+
if threshold or threshold == 0:
|
84 |
+
filtered_hits = [hit for hit in sorted_hits if hit[self.score_field] >= threshold]
|
85 |
+
if not any(filtered_hits):
|
86 |
+
logger.warning(f'No hits above threshold {threshold}. Returning top {top_k} hits.')
|
87 |
+
return sorted_hits[:top_k]
|
88 |
+
return filtered_hits
|
89 |
+
return sorted_hits[:top_k]
|
utils/retrieval_evaluation_spa.py
ADDED
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#external files
|
2 |
+
from openai_interface_spa import GPT_Turbo
|
3 |
+
from weaviate_interface_v3_spa import WeaviateClient
|
4 |
+
from llama_index.finetuning import EmbeddingQAFinetuneDataset
|
5 |
+
from templates.prompt_templates_spa import qa_generation_prompt
|
6 |
+
from reranker_spa import ReRanker
|
7 |
+
|
8 |
+
#standard library imports
|
9 |
+
import json
|
10 |
+
import time
|
11 |
+
import uuid
|
12 |
+
import os
|
13 |
+
import re
|
14 |
+
import random
|
15 |
+
from datetime import datetime
|
16 |
+
from typing import List, Dict, Tuple, Union, Literal
|
17 |
+
|
18 |
+
#misc
|
19 |
+
from tqdm import tqdm
|
20 |
+
|
21 |
+
|
22 |
+
class QueryContextGenerator:
|
23 |
+
'''
|
24 |
+
Class designed for the generation of query/context pairs using a
|
25 |
+
Generative LLM. The LLM is used to generate questions from a given
|
26 |
+
corpus of text. The query/context pairs can be used to fine-tune
|
27 |
+
an embedding model using a MultipleNegativesRankingLoss loss function
|
28 |
+
or can be used to create evaluation datasets for retrieval models.
|
29 |
+
'''
|
30 |
+
def __init__(self, openai_key: str, model_id: str='gpt-3.5-turbo-0613'):
|
31 |
+
self.llm = GPT_Turbo(model=model_id, api_key=openai_key)
|
32 |
+
|
33 |
+
def clean_validate_data(self,
|
34 |
+
data: List[dict],
|
35 |
+
valid_fields: List[str]=['content', 'summary', 'guest', 'doc_id'],
|
36 |
+
total_chars: int=950
|
37 |
+
) -> List[dict]:
|
38 |
+
'''
|
39 |
+
Strip original data chunks so they only contain valid_fields.
|
40 |
+
Remove any chunks less than total_chars in size. Prevents LLM
|
41 |
+
from asking questions from sparse content.
|
42 |
+
'''
|
43 |
+
clean_docs = [{k:v for k,v in d.items() if k in valid_fields} for d in data]
|
44 |
+
valid_docs = [d for d in clean_docs if len(d['content']) > total_chars]
|
45 |
+
return valid_docs
|
46 |
+
|
47 |
+
def train_val_split(self,
|
48 |
+
data: List[dict],
|
49 |
+
n_train_questions: int,
|
50 |
+
n_val_questions: int,
|
51 |
+
n_questions_per_chunk: int=2,
|
52 |
+
total_chars: int=950):
|
53 |
+
'''
|
54 |
+
Splits corpus into training and validation sets. Training and
|
55 |
+
validation samples are randomly selected from the corpus. total_chars
|
56 |
+
parameter is set based on pre-analysis of average doc length in the
|
57 |
+
training corpus.
|
58 |
+
'''
|
59 |
+
clean_data = self.clean_validate_data(data, total_chars=total_chars)
|
60 |
+
random.shuffle(clean_data)
|
61 |
+
train_index = n_train_questions//n_questions_per_chunk
|
62 |
+
valid_index = n_val_questions//n_questions_per_chunk
|
63 |
+
end_index = valid_index + train_index
|
64 |
+
if end_index > len(clean_data):
|
65 |
+
raise ValueError('Cannot create dataset with desired number of questions, try using a larger dataset')
|
66 |
+
train_data = clean_data[:train_index]
|
67 |
+
valid_data = clean_data[train_index:end_index]
|
68 |
+
print(f'Length Training Data: {len(train_data)}')
|
69 |
+
print(f'Length Validation Data: {len(valid_data)}')
|
70 |
+
return train_data, valid_data
|
71 |
+
|
72 |
+
def generate_qa_embedding_pairs(
|
73 |
+
self,
|
74 |
+
data: List[dict],
|
75 |
+
generate_prompt_tmpl: str=None,
|
76 |
+
num_questions_per_chunk: int = 2,
|
77 |
+
) -> EmbeddingQAFinetuneDataset:
|
78 |
+
"""
|
79 |
+
Generate query/context pairs from a list of documents. The query/context pairs
|
80 |
+
can be used for fine-tuning an embedding model using a MultipleNegativesRankingLoss
|
81 |
+
or can be used to create an evaluation dataset for retrieval models.
|
82 |
+
|
83 |
+
This function was adapted for this course from the llama_index.finetuning.common module:
|
84 |
+
https://github.com/run-llama/llama_index/blob/main/llama_index/finetuning/embeddings/common.py
|
85 |
+
"""
|
86 |
+
generate_prompt_tmpl = qa_generation_prompt if not generate_prompt_tmpl else generate_prompt_tmpl
|
87 |
+
queries = {}
|
88 |
+
relevant_docs = {}
|
89 |
+
corpus = {chunk['doc_id'] : chunk['content'] for chunk in data}
|
90 |
+
for chunk in tqdm(data):
|
91 |
+
page_summary = chunk['page_summary']
|
92 |
+
# guest = chunk['guest']
|
93 |
+
context_str = chunk['content']
|
94 |
+
node_id = chunk['doc_id']
|
95 |
+
query = generate_prompt_tmpl.format(page_summary=page_summary,
|
96 |
+
# guest=guest,
|
97 |
+
context_str=context_str,
|
98 |
+
num_questions_per_chunk=num_questions_per_chunk)
|
99 |
+
try:
|
100 |
+
response = self.llm.get_chat_completion(prompt=query, temperature=0.1, max_tokens=100)
|
101 |
+
except Exception as e:
|
102 |
+
print(e)
|
103 |
+
continue
|
104 |
+
result = str(response).strip().split("\n")
|
105 |
+
questions = [
|
106 |
+
re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
|
107 |
+
]
|
108 |
+
questions = [question for question in questions if len(question) > 0]
|
109 |
+
|
110 |
+
for question in questions:
|
111 |
+
question_id = str(uuid.uuid4())
|
112 |
+
queries[question_id] = question
|
113 |
+
relevant_docs[question_id] = [node_id]
|
114 |
+
|
115 |
+
# construct dataset
|
116 |
+
return EmbeddingQAFinetuneDataset(
|
117 |
+
queries=queries, corpus=corpus, relevant_docs=relevant_docs
|
118 |
+
)
|
119 |
+
|
120 |
+
def execute_evaluation(dataset: EmbeddingQAFinetuneDataset,
|
121 |
+
class_name: str,
|
122 |
+
retriever: WeaviateClient,
|
123 |
+
reranker: ReRanker=None,
|
124 |
+
alpha: float=0.5,
|
125 |
+
retrieve_limit: int=100,
|
126 |
+
top_k: int=5,
|
127 |
+
chunk_size: int=256,
|
128 |
+
hnsw_config_keys: List[str]=['maxConnections', 'efConstruction', 'ef'],
|
129 |
+
search_type: Literal['kw', 'vector', 'hybrid', 'all']='all',
|
130 |
+
display_properties: List[str]=['doc_id', 'content'],
|
131 |
+
dir_outpath: str='./eval_results',
|
132 |
+
include_miss_info: bool=False,
|
133 |
+
user_def_params: dict=None
|
134 |
+
) -> Union[dict, Tuple[dict, List[dict]]]:
|
135 |
+
'''
|
136 |
+
Given a dataset, a retriever, and a reranker, evaluate the performance of the retriever and reranker.
|
137 |
+
Returns a dict of kw, vector, and hybrid hit rates and mrr scores. If include_miss_info is True, will
|
138 |
+
also return a list of kw and vector responses and their associated queries that did not return a hit.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
-----
|
142 |
+
dataset: EmbeddingQAFinetuneDataset
|
143 |
+
Dataset to be used for evaluation
|
144 |
+
class_name: str
|
145 |
+
Name of Class on Weaviate host to be used for retrieval
|
146 |
+
retriever: WeaviateClient
|
147 |
+
WeaviateClient object to be used for retrieval
|
148 |
+
reranker: ReRanker
|
149 |
+
ReRanker model to be used for results reranking
|
150 |
+
alpha: float=0.5
|
151 |
+
Weighting factor for BM25 and Vector search.
|
152 |
+
alpha can be any number from 0 to 1, defaulting to 0.5:
|
153 |
+
alpha = 0 executes a pure keyword search method (BM25)
|
154 |
+
alpha = 0.5 weighs the BM25 and vector methods evenly
|
155 |
+
alpha = 1 executes a pure vector search method
|
156 |
+
retrieve_limit: int=5
|
157 |
+
Number of documents to retrieve from Weaviate host
|
158 |
+
top_k: int=5
|
159 |
+
Number of top results to evaluate
|
160 |
+
chunk_size: int=256
|
161 |
+
Number of tokens used to chunk text
|
162 |
+
hnsw_config_keys: List[str]=['maxConnections', 'efConstruction', 'ef']
|
163 |
+
List of keys to be used for retrieving HNSW Index parameters from Weaviate host
|
164 |
+
search_type: Literal['kw', 'vector', 'hybrid', 'all']='all'
|
165 |
+
Type of search to be evaluated. Options are 'kw', 'vector', 'hybrid', or 'all'
|
166 |
+
display_properties: List[str]=['doc_id', 'content']
|
167 |
+
List of properties to be returned from Weaviate host for display in response
|
168 |
+
dir_outpath: str='./eval_results'
|
169 |
+
Directory path for saving results. Directory will be created if it does not
|
170 |
+
already exist.
|
171 |
+
include_miss_info: bool=False
|
172 |
+
Option to include queries and their associated search response values
|
173 |
+
for queries that are "total misses"
|
174 |
+
user_def_params : dict=None
|
175 |
+
Option for user to pass in a dictionary of user-defined parameters and their values.
|
176 |
+
Will be automatically added to the results_dict if correct type is passed.
|
177 |
+
'''
|
178 |
+
|
179 |
+
reranker_name = reranker.model_name if reranker else "None"
|
180 |
+
|
181 |
+
results_dict = {'n':retrieve_limit,
|
182 |
+
'top_k': top_k,
|
183 |
+
'alpha': alpha,
|
184 |
+
'Retriever': retriever.model_name_or_path,
|
185 |
+
'Ranker': reranker_name,
|
186 |
+
'chunk_size': chunk_size,
|
187 |
+
'kw_hit_rate': 0,
|
188 |
+
'kw_mrr': 0,
|
189 |
+
'vector_hit_rate': 0,
|
190 |
+
'vector_mrr': 0,
|
191 |
+
'hybrid_hit_rate':0,
|
192 |
+
'hybrid_mrr': 0,
|
193 |
+
'total_misses': 0,
|
194 |
+
'total_questions':0
|
195 |
+
}
|
196 |
+
#add extra params to results_dict
|
197 |
+
results_dict = add_params(retriever, class_name, results_dict, user_def_params, hnsw_config_keys)
|
198 |
+
|
199 |
+
start = time.perf_counter()
|
200 |
+
miss_info = []
|
201 |
+
for query_id, q in tqdm(dataset.queries.items(), 'Queries'):
|
202 |
+
results_dict['total_questions'] += 1
|
203 |
+
hit = False
|
204 |
+
#make Keyword, Vector, and Hybrid calls to Weaviate host
|
205 |
+
try:
|
206 |
+
kw_response = retriever.keyword_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
|
207 |
+
vector_response = retriever.vector_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
|
208 |
+
hybrid_response = retriever.hybrid_search(request=q, class_name=class_name, alpha=alpha, limit=retrieve_limit, display_properties=display_properties)
|
209 |
+
#rerank returned responses if reranker is provided
|
210 |
+
if reranker:
|
211 |
+
kw_response = reranker.rerank(kw_response, q, top_k=top_k)
|
212 |
+
vector_response = reranker.rerank(vector_response, q, top_k=top_k)
|
213 |
+
hybrid_response = reranker.rerank(hybrid_response, q, top_k=top_k)
|
214 |
+
|
215 |
+
#collect doc_ids to check for document matches (include only results_top_k)
|
216 |
+
kw_doc_ids = {result['doc_id']:i for i, result in enumerate(kw_response[:top_k], 1)}
|
217 |
+
vector_doc_ids = {result['doc_id']:i for i, result in enumerate(vector_response[:top_k], 1)}
|
218 |
+
hybrid_doc_ids = {result['doc_id']:i for i, result in enumerate(hybrid_response[:top_k], 1)}
|
219 |
+
|
220 |
+
#extract doc_id for scoring purposes
|
221 |
+
doc_id = dataset.relevant_docs[query_id][0]
|
222 |
+
|
223 |
+
#increment hit_rate counters and mrr scores
|
224 |
+
if doc_id in kw_doc_ids:
|
225 |
+
results_dict['kw_hit_rate'] += 1
|
226 |
+
results_dict['kw_mrr'] += 1/kw_doc_ids[doc_id]
|
227 |
+
hit = True
|
228 |
+
if doc_id in vector_doc_ids:
|
229 |
+
results_dict['vector_hit_rate'] += 1
|
230 |
+
results_dict['vector_mrr'] += 1/vector_doc_ids[doc_id]
|
231 |
+
hit = True
|
232 |
+
if doc_id in hybrid_doc_ids:
|
233 |
+
results_dict['hybrid_hit_rate'] += 1
|
234 |
+
results_dict['hybrid_mrr'] += 1/hybrid_doc_ids[doc_id]
|
235 |
+
hit = True
|
236 |
+
# if no hits, let's capture that
|
237 |
+
if not hit:
|
238 |
+
results_dict['total_misses'] += 1
|
239 |
+
miss_info.append({'query': q,
|
240 |
+
'answer': dataset.corpus[doc_id],
|
241 |
+
'doc_id': doc_id,
|
242 |
+
'kw_response': kw_response,
|
243 |
+
'vector_response': vector_response,
|
244 |
+
'hybrid_response': hybrid_response})
|
245 |
+
except Exception as e:
|
246 |
+
print(e)
|
247 |
+
continue
|
248 |
+
|
249 |
+
#use raw counts to calculate final scores
|
250 |
+
calc_hit_rate_scores(results_dict, search_type=search_type)
|
251 |
+
calc_mrr_scores(results_dict, search_type=search_type)
|
252 |
+
|
253 |
+
end = time.perf_counter() - start
|
254 |
+
print(f'Total Processing Time: {round(end/60, 2)} minutes')
|
255 |
+
record_results(results_dict, chunk_size, dir_outpath=dir_outpath, as_text=True)
|
256 |
+
|
257 |
+
if include_miss_info:
|
258 |
+
return results_dict, miss_info
|
259 |
+
return results_dict
|
260 |
+
|
261 |
+
def calc_hit_rate_scores(results_dict: Dict[str, Union[str, int]],
|
262 |
+
search_type: Literal['kw', 'vector', 'hybrid', 'all']=['kw', 'vector']
|
263 |
+
) -> None:
|
264 |
+
if search_type == 'all':
|
265 |
+
search_type = ['kw', 'vector', 'hybrid']
|
266 |
+
for prefix in search_type:
|
267 |
+
results_dict[f'{prefix}_hit_rate'] = round(results_dict[f'{prefix}_hit_rate']/results_dict['total_questions'],2)
|
268 |
+
|
269 |
+
def calc_mrr_scores(results_dict: Dict[str, Union[str, int]],
|
270 |
+
search_type: Literal['kw', 'vector', 'hybrid', 'all']=['kw', 'vector']
|
271 |
+
) -> None:
|
272 |
+
if search_type == 'all':
|
273 |
+
search_type = ['kw', 'vector', 'hybrid']
|
274 |
+
for prefix in search_type:
|
275 |
+
results_dict[f'{prefix}_mrr'] = round(results_dict[f'{prefix}_mrr']/results_dict['total_questions'],2)
|
276 |
+
|
277 |
+
def create_dir(dir_path: str) -> None:
|
278 |
+
'''
|
279 |
+
Checks if directory exists, and creates new directory
|
280 |
+
if it does not exist
|
281 |
+
'''
|
282 |
+
if not os.path.exists(dir_path):
|
283 |
+
os.makedirs(dir_path)
|
284 |
+
|
285 |
+
def record_results(results_dict: Dict[str, Union[str, int]],
|
286 |
+
chunk_size: int,
|
287 |
+
dir_outpath: str='./eval_results',
|
288 |
+
as_text: bool=False
|
289 |
+
) -> None:
|
290 |
+
'''
|
291 |
+
Write results to output file in either txt or json format
|
292 |
+
|
293 |
+
Args:
|
294 |
+
-----
|
295 |
+
results_dict: Dict[str, Union[str, int]]
|
296 |
+
Dictionary containing results of evaluation
|
297 |
+
chunk_size: int
|
298 |
+
Size of text chunks in tokens
|
299 |
+
dir_outpath: str
|
300 |
+
Path to output directory. Directory only, filename is hardcoded
|
301 |
+
as part of this function.
|
302 |
+
as_text: bool
|
303 |
+
If True, write results as text file. If False, write as json file.
|
304 |
+
'''
|
305 |
+
create_dir(dir_outpath)
|
306 |
+
time_marker = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
307 |
+
ext = 'txt' if as_text else 'json'
|
308 |
+
path = os.path.join(dir_outpath, f'retrieval_eval_{chunk_size}_{time_marker}.{ext}')
|
309 |
+
if as_text:
|
310 |
+
with open(path, 'a') as f:
|
311 |
+
f.write(f"{results_dict}\n")
|
312 |
+
else:
|
313 |
+
with open(path, 'w') as f:
|
314 |
+
json.dump(results_dict, f, indent=4)
|
315 |
+
|
316 |
+
def add_params(client: WeaviateClient,
|
317 |
+
class_name: str,
|
318 |
+
results_dict: dict,
|
319 |
+
param_options: dict,
|
320 |
+
hnsw_config_keys: List[str]
|
321 |
+
) -> dict:
|
322 |
+
hnsw_params = {k:v for k,v in client.show_class_config(class_name)['vectorIndexConfig'].items() if k in hnsw_config_keys}
|
323 |
+
if hnsw_params:
|
324 |
+
results_dict = {**results_dict, **hnsw_params}
|
325 |
+
if param_options and isinstance(param_options, dict):
|
326 |
+
results_dict = {**results_dict, **param_options}
|
327 |
+
return results_dict
|
328 |
+
|
329 |
+
|
330 |
+
|
331 |
+
|
332 |
+
|
utils/system_prompts.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
chatgpt = '''
|
2 |
+
You are a helpful assistant.
|
3 |
+
'''
|
4 |
+
|
5 |
+
professor_synapse = '''
|
6 |
+
Act as Professor Synapseπ§πΎββοΈ, a conductor of expert agents. Your job is to support me in accomplishing my goals by finding alignment with me, then calling upon an expert agent perfectly suited to the task by initializing:
|
7 |
+
|
8 |
+
Synapse_CoR = "[emoji]: I am an expert in [role&domain]. I know [context]. I will reason step-by-step to determine the best course of action to achieve [goal]. I can use [tools] and [relevant frameworks] to help in this process.
|
9 |
+
|
10 |
+
I will help you accomplish your goal by following these steps:
|
11 |
+
[reasoned steps]
|
12 |
+
|
13 |
+
My task ends when [completion].
|
14 |
+
|
15 |
+
[first step, question]"
|
16 |
+
|
17 |
+
Instructions:
|
18 |
+
|
19 |
+
1. π§πΎββοΈ gather context, relevant information and clarify my goals by asking questions
|
20 |
+
2. Once confirmed, initialize Synapse_CoR
|
21 |
+
3. π§πΎββοΈ and [emoji] support me until goal is complete
|
22 |
+
|
23 |
+
Commands:
|
24 |
+
/start=π§πΎββοΈ,introduce and begin with step one
|
25 |
+
/ts=π§πΎββοΈ,summon (Synapse_CoR*3) town square debate
|
26 |
+
/saveπ§πΎββοΈ, restate goal, summarize progress, reason next step
|
27 |
+
|
28 |
+
Personality:
|
29 |
+
-curious, inquisitive, encouraging
|
30 |
+
-use emojis to express yourself
|
31 |
+
|
32 |
+
Rules:
|
33 |
+
-End every output with a question or reasoned next step
|
34 |
+
-Start every output with π§πΎββοΈ: or [emoji]: to indicate who is speaking.
|
35 |
+
-Organize every output βπ§πΎββοΈ: [aligning on my goal], [emoji]: [actionable response]
|
36 |
+
-π§πΎββοΈ, recommend save after each task is completed
|
37 |
+
'''
|
38 |
+
|
39 |
+
marketing_jane = '''
|
40 |
+
Act as Marcus π©πΌβπΌMarketing jane, a strategist adept at melding analytics with creative zest. With mastery over data-driven marketing and an innate knack for storytelling, your mission is to carve out distinctive marketing strategies. From fledgling startups to seasoned giants.
|
41 |
+
|
42 |
+
Your strategy formulation entails:
|
43 |
+
- Understanding the business's narrative, competitive landscape, and audience psyche.
|
44 |
+
- Crafting a data-informed marketing roadmap, encompassing various channels, and innovative tactics.
|
45 |
+
- Leveraging storytelling to forge brand engagement and pioneering avant-garde campaigns.
|
46 |
+
|
47 |
+
Your endeavor culminates when the user possesses a dynamic, data-enriched marketing strategy, resonating with their business ethos.
|
48 |
+
|
49 |
+
Steps:
|
50 |
+
1. π©πΌβπΌ, Grasp the business's ethos, objectives, and challenges
|
51 |
+
2. Design a data-backed marketing strategy, resonating with audience sentiments and business goals
|
52 |
+
3. Engage in feedback loops, iteratively refining the strategy
|
53 |
+
|
54 |
+
Commands:
|
55 |
+
/explore - Modify the strategic focus or delve deeper into specific marketing nuances
|
56 |
+
/save - Chronicle progress, dissect strategy elements, and chart future endeavors
|
57 |
+
/critic - π©πΌβπΌ seeks insights from fellow marketing aficionados
|
58 |
+
/reason - π©πΌβπΌ and user collaboratively weave the marketing narrative
|
59 |
+
/new - Ignite a fresh strategic quest for a new venture or campaign
|
60 |
+
|
61 |
+
Rules:
|
62 |
+
- Culminate with an evocative campaign concept or the next strategic juncture
|
63 |
+
- Preface with π©πΌβπΌ: for clarity
|
64 |
+
- Integrate data insights with creative innovation
|
65 |
+
'''
|
66 |
+
|
67 |
+
# Define a dictionary to map the emojis to the variables
|
68 |
+
prompt_mapping = {
|
69 |
+
"π€ChatGPT": chatgpt,
|
70 |
+
"π§πΎββοΈProfessor Synapse": professor_synapse,
|
71 |
+
"π©πΌβπΌMarketing Jane": marketing_jane,
|
72 |
+
}
|
utils/weaviate_interface_v3_spa.py
ADDED
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from weaviate import Client, AuthApiKey
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from openai import OpenAI
|
4 |
+
from sentence_transformers import SentenceTransformer
|
5 |
+
from typing import List, Union, Callable
|
6 |
+
from torch import cuda
|
7 |
+
from tqdm import tqdm
|
8 |
+
import time
|
9 |
+
|
10 |
+
class WeaviateClient(Client):
|
11 |
+
'''
|
12 |
+
A python native Weaviate Client class that encapsulates Weaviate functionalities
|
13 |
+
in one object. Several convenience methods are added for ease of use.
|
14 |
+
|
15 |
+
Args
|
16 |
+
----
|
17 |
+
api_key: str
|
18 |
+
The API key for the Weaviate Cloud Service (WCS) instance.
|
19 |
+
https://console.weaviate.cloud/dashboard
|
20 |
+
|
21 |
+
endpoint: str
|
22 |
+
The url endpoint for the Weaviate Cloud Service instance.
|
23 |
+
|
24 |
+
model_name_or_path: str='sentence-transformers/all-MiniLM-L6-v2'
|
25 |
+
The name or path of the SentenceTransformer model to use for vector search.
|
26 |
+
Will also support OpenAI text-embedding-ada-002 model. This param enables
|
27 |
+
the use of most leading models on MTEB Leaderboard:
|
28 |
+
https://huggingface.co/spaces/mteb/leaderboard
|
29 |
+
openai_api_key: str=None
|
30 |
+
The API key for the OpenAI API. Only required if using OpenAI text-embedding-ada-002 model.
|
31 |
+
'''
|
32 |
+
def __init__(self,
|
33 |
+
endpoint: str,
|
34 |
+
api_key: str = None, # Make the api_key optional
|
35 |
+
model_name_or_path: str = 'sentence-transformers/all-MiniLM-L6-v2',
|
36 |
+
openai_api_key: str = None,
|
37 |
+
**kwargs
|
38 |
+
):
|
39 |
+
if api_key: # Only use AuthApiKey if api_key is provided
|
40 |
+
auth_config = AuthApiKey(api_key=api_key)
|
41 |
+
super().__init__(auth_client_secret=auth_config, url=endpoint, **kwargs)
|
42 |
+
else:
|
43 |
+
super().__init__(url=endpoint, **kwargs)
|
44 |
+
|
45 |
+
self.model_name_or_path = model_name_or_path
|
46 |
+
self.openai_model = False
|
47 |
+
if self.model_name_or_path == 'text-embedding-ada-002':
|
48 |
+
if not openai_api_key:
|
49 |
+
raise ValueError(f'OpenAI API key must be provided to use this model: {self.model_name_or_path}')
|
50 |
+
self.model = OpenAI(api_key=openai_api_key)
|
51 |
+
self.openai_model = True
|
52 |
+
else:
|
53 |
+
self.model = SentenceTransformer(self.model_name_or_path) if self.model_name_or_path else None
|
54 |
+
|
55 |
+
self.display_properties = ['file_name', 'page_label', 'document_title', 'page_summary', 'page_url', 'doc_id', \
|
56 |
+
'content']
|
57 |
+
|
58 |
+
def show_classes(self) -> Union[List[str], str]:
|
59 |
+
'''
|
60 |
+
Shows all available classes (indexes) on the Weaviate instance.
|
61 |
+
'''
|
62 |
+
schema = self.schema.get()
|
63 |
+
if 'classes' in schema:
|
64 |
+
return [cls['class'] for cls in schema['classes']]
|
65 |
+
else:
|
66 |
+
return "No classes found on cluster."
|
67 |
+
|
68 |
+
def show_class_info(self) -> Union[List[dict], str]:
|
69 |
+
'''
|
70 |
+
Shows all information related to the classes (indexes) on the Weaviate instance.
|
71 |
+
'''
|
72 |
+
schema = self.schema.get()
|
73 |
+
if 'classes' in schema:
|
74 |
+
return schema['classes']
|
75 |
+
else:
|
76 |
+
return "No classes found on cluster."
|
77 |
+
|
78 |
+
def show_class_properties(self, class_name: str) -> Union[dict, str]:
|
79 |
+
'''
|
80 |
+
Shows all properties of a class (index) on the Weaviate instance.
|
81 |
+
'''
|
82 |
+
classes = self.schema.get()
|
83 |
+
if classes:
|
84 |
+
all_classes = classes['classes']
|
85 |
+
for d in all_classes:
|
86 |
+
if d['class'] == class_name:
|
87 |
+
return d['properties']
|
88 |
+
return f'Class "{class_name}" not found on host'
|
89 |
+
return f'No Classes found on host'
|
90 |
+
|
91 |
+
def show_class_config(self, class_name: str) -> Union[dict, str]:
|
92 |
+
'''
|
93 |
+
Shows all configuration of a class (index) on the Weaviate instance.
|
94 |
+
'''
|
95 |
+
classes = self.schema.get()
|
96 |
+
if classes:
|
97 |
+
all_classes = classes['classes']
|
98 |
+
for d in all_classes:
|
99 |
+
if d['class'] == class_name:
|
100 |
+
return d
|
101 |
+
return f'Class "{class_name}" not found on host'
|
102 |
+
return f'No Classes found on host'
|
103 |
+
|
104 |
+
def delete_class(self, class_name: str) -> str:
|
105 |
+
'''
|
106 |
+
Deletes a class (index) on the Weaviate instance, if it exists.
|
107 |
+
'''
|
108 |
+
available = self._check_class_avialability(class_name)
|
109 |
+
if isinstance(available, bool):
|
110 |
+
if available:
|
111 |
+
self.schema.delete_class(class_name)
|
112 |
+
not_deleted = self._check_class_avialability(class_name)
|
113 |
+
if isinstance(not_deleted, bool):
|
114 |
+
if not_deleted:
|
115 |
+
return f'Class "{class_name}" was not deleted. Try again.'
|
116 |
+
else:
|
117 |
+
return f'Class "{class_name}" deleted'
|
118 |
+
return f'Class "{class_name}" deleted and there are no longer any classes on host'
|
119 |
+
return f'Class "{class_name}" not found on host'
|
120 |
+
return available
|
121 |
+
|
122 |
+
def _check_class_avialability(self, class_name: str) -> Union[bool, str]:
|
123 |
+
'''
|
124 |
+
Checks if a class (index) exists on the Weaviate instance.
|
125 |
+
'''
|
126 |
+
classes = self.schema.get()
|
127 |
+
if classes:
|
128 |
+
all_classes = classes['classes']
|
129 |
+
for d in all_classes:
|
130 |
+
if d['class'] == class_name:
|
131 |
+
return True
|
132 |
+
return False
|
133 |
+
else:
|
134 |
+
return f'No Classes found on host'
|
135 |
+
|
136 |
+
def format_response(self,
|
137 |
+
response: dict,
|
138 |
+
class_name: str
|
139 |
+
) -> List[dict]:
|
140 |
+
'''
|
141 |
+
Formats json response from Weaviate into a list of dictionaries.
|
142 |
+
Expands _additional fields if present into top-level dictionary.
|
143 |
+
'''
|
144 |
+
if response.get('errors'):
|
145 |
+
return response['errors'][0]['message']
|
146 |
+
results = []
|
147 |
+
hits = response['data']['Get'][class_name]
|
148 |
+
for d in hits:
|
149 |
+
temp = {k:v for k,v in d.items() if k != '_additional'}
|
150 |
+
if d.get('_additional'):
|
151 |
+
for key in d['_additional']:
|
152 |
+
temp[key] = d['_additional'][key]
|
153 |
+
results.append(temp)
|
154 |
+
return results
|
155 |
+
|
156 |
+
def update_ef_value(self, class_name: str, ef_value: int) -> str:
|
157 |
+
'''
|
158 |
+
Updates ef_value for a class (index) on the Weaviate instance.
|
159 |
+
'''
|
160 |
+
self.schema.update_config(class_name=class_name, config={'vectorIndexConfig': {'ef': ef_value}})
|
161 |
+
print(f'ef_value updated to {ef_value} for class {class_name}')
|
162 |
+
return self.show_class_config(class_name)['vectorIndexConfig']
|
163 |
+
|
164 |
+
def keyword_search(self,
|
165 |
+
request: str,
|
166 |
+
class_name: str,
|
167 |
+
properties: List[str]=['content'],
|
168 |
+
limit: int=10,
|
169 |
+
where_filter: dict=None,
|
170 |
+
display_properties: List[str]=None,
|
171 |
+
return_raw: bool=False) -> Union[dict, List[dict]]:
|
172 |
+
'''
|
173 |
+
Executes Keyword (BM25) search.
|
174 |
+
|
175 |
+
Args
|
176 |
+
----
|
177 |
+
query: str
|
178 |
+
User query.
|
179 |
+
class_name: str
|
180 |
+
Class (index) to search.
|
181 |
+
properties: List[str]
|
182 |
+
List of properties to search across.
|
183 |
+
limit: int=10
|
184 |
+
Number of results to return.
|
185 |
+
display_properties: List[str]=None
|
186 |
+
List of properties to return in response.
|
187 |
+
If None, returns all properties.
|
188 |
+
return_raw: bool=False
|
189 |
+
If True, returns raw response from Weaviate.
|
190 |
+
'''
|
191 |
+
display_properties = display_properties if display_properties else self.display_properties
|
192 |
+
response = (self.query
|
193 |
+
.get(class_name, display_properties)
|
194 |
+
.with_bm25(query=request, properties=properties)
|
195 |
+
.with_additional(['score', "id"])
|
196 |
+
.with_limit(limit)
|
197 |
+
)
|
198 |
+
response = response.with_where(where_filter).do() if where_filter else response.do()
|
199 |
+
if return_raw:
|
200 |
+
return response
|
201 |
+
else:
|
202 |
+
return self.format_response(response, class_name)
|
203 |
+
|
204 |
+
def vector_search(self,
|
205 |
+
request: str,
|
206 |
+
class_name: str,
|
207 |
+
limit: int=10,
|
208 |
+
where_filter: dict=None,
|
209 |
+
display_properties: List[str]=None,
|
210 |
+
return_raw: bool=False,
|
211 |
+
device: str='cuda:0' if cuda.is_available() else 'cpu'
|
212 |
+
) -> Union[dict, List[dict]]:
|
213 |
+
'''
|
214 |
+
Executes vector search using embedding model defined on instantiation
|
215 |
+
of WeaviateClient instance.
|
216 |
+
|
217 |
+
Args
|
218 |
+
----
|
219 |
+
query: str
|
220 |
+
User query.
|
221 |
+
class_name: str
|
222 |
+
Class (index) to search.
|
223 |
+
limit: int=10
|
224 |
+
Number of results to return.
|
225 |
+
display_properties: List[str]=None
|
226 |
+
List of properties to return in response.
|
227 |
+
If None, returns all properties.
|
228 |
+
return_raw: bool=False
|
229 |
+
If True, returns raw response from Weaviate.
|
230 |
+
'''
|
231 |
+
display_properties = display_properties if display_properties else self.display_properties
|
232 |
+
query_vector = self._create_query_vector(request, device=device)
|
233 |
+
response = (
|
234 |
+
self.query
|
235 |
+
.get(class_name, display_properties)
|
236 |
+
.with_near_vector({"vector": query_vector})
|
237 |
+
.with_limit(limit)
|
238 |
+
.with_additional(['distance'])
|
239 |
+
)
|
240 |
+
response = response.with_where(where_filter).do() if where_filter else response.do()
|
241 |
+
if return_raw:
|
242 |
+
return response
|
243 |
+
else:
|
244 |
+
return self.format_response(response, class_name)
|
245 |
+
|
246 |
+
def _create_query_vector(self, query: str, device: str) -> List[float]:
|
247 |
+
'''
|
248 |
+
Creates embedding vector from text query.
|
249 |
+
'''
|
250 |
+
return self.get_openai_embedding(query) if self.openai_model else self.model.encode(query, device=device).tolist()
|
251 |
+
|
252 |
+
def get_openai_embedding(self, query: str) -> List[float]:
|
253 |
+
'''
|
254 |
+
Gets embedding from OpenAI API for query.
|
255 |
+
'''
|
256 |
+
embedding = self.model.embeddings.create(input=query, model='text-embedding-ada-002').model_dump()
|
257 |
+
if embedding:
|
258 |
+
return embedding['data'][0]['embedding']
|
259 |
+
else:
|
260 |
+
raise ValueError(f'No embedding found for query: {query}')
|
261 |
+
|
262 |
+
def hybrid_search(self,
|
263 |
+
request: str,
|
264 |
+
class_name: str,
|
265 |
+
properties: List[str]=['content'],
|
266 |
+
alpha: float=0.5,
|
267 |
+
limit: int=10,
|
268 |
+
where_filter: dict=None,
|
269 |
+
display_properties: List[str]=None,
|
270 |
+
return_raw: bool=False,
|
271 |
+
device: str='cuda:0' if cuda.is_available() else 'cpu'
|
272 |
+
) -> Union[dict, List[dict]]:
|
273 |
+
'''
|
274 |
+
Executes Hybrid (BM25 + Vector) search.
|
275 |
+
|
276 |
+
Args
|
277 |
+
----
|
278 |
+
query: str
|
279 |
+
User query.
|
280 |
+
class_name: str
|
281 |
+
Class (index) to search.
|
282 |
+
properties: List[str]
|
283 |
+
List of properties to search across (using BM25)
|
284 |
+
alpha: float=0.5
|
285 |
+
Weighting factor for BM25 and Vector search.
|
286 |
+
alpha can be any number from 0 to 1, defaulting to 0.5:
|
287 |
+
alpha = 0 executes a pure keyword search method (BM25)
|
288 |
+
alpha = 0.5 weighs the BM25 and vector methods evenly
|
289 |
+
alpha = 1 executes a pure vector search method
|
290 |
+
limit: int=10
|
291 |
+
Number of results to return.
|
292 |
+
display_properties: List[str]=None
|
293 |
+
List of properties to return in response.
|
294 |
+
If None, returns all properties.
|
295 |
+
return_raw: bool=False
|
296 |
+
If True, returns raw response from Weaviate.
|
297 |
+
'''
|
298 |
+
display_properties = display_properties if display_properties else self.display_properties
|
299 |
+
query_vector = self._create_query_vector(request, device=device)
|
300 |
+
response = (
|
301 |
+
self.query
|
302 |
+
.get(class_name, display_properties)
|
303 |
+
.with_hybrid(query=request,
|
304 |
+
alpha=alpha,
|
305 |
+
vector=query_vector,
|
306 |
+
properties=properties,
|
307 |
+
fusion_type='relativeScoreFusion') #hard coded option for now
|
308 |
+
.with_additional(["score", "explainScore"])
|
309 |
+
.with_limit(limit)
|
310 |
+
)
|
311 |
+
|
312 |
+
response = response.with_where(where_filter).do() if where_filter else response.do()
|
313 |
+
if return_raw:
|
314 |
+
return response
|
315 |
+
else:
|
316 |
+
return self.format_response(response, class_name)
|
317 |
+
|
318 |
+
|
319 |
+
class WeaviateIndexer:
|
320 |
+
|
321 |
+
def __init__(self,
|
322 |
+
client: WeaviateClient,
|
323 |
+
batch_size: int=150,
|
324 |
+
num_workers: int=4,
|
325 |
+
dynamic: bool=True,
|
326 |
+
creation_time: int=5,
|
327 |
+
timeout_retries: int=3,
|
328 |
+
connection_error_retries: int=3,
|
329 |
+
callback: Callable=None,
|
330 |
+
):
|
331 |
+
'''
|
332 |
+
Class designed to batch index documents into Weaviate. Instantiating
|
333 |
+
this class will automatically configure the Weaviate batch client.
|
334 |
+
'''
|
335 |
+
self._client = client
|
336 |
+
self._callback = callback if callback else self._default_callback
|
337 |
+
|
338 |
+
self._client.batch.configure(batch_size=batch_size,
|
339 |
+
num_workers=num_workers,
|
340 |
+
dynamic=dynamic,
|
341 |
+
creation_time=creation_time,
|
342 |
+
timeout_retries=timeout_retries,
|
343 |
+
connection_error_retries=connection_error_retries,
|
344 |
+
callback=self._callback
|
345 |
+
)
|
346 |
+
|
347 |
+
def _default_callback(self, results: dict):
|
348 |
+
"""
|
349 |
+
Check batch results for errors.
|
350 |
+
|
351 |
+
Parameters
|
352 |
+
----------
|
353 |
+
results : dict
|
354 |
+
The Weaviate batch creation return value.
|
355 |
+
"""
|
356 |
+
|
357 |
+
if results is not None:
|
358 |
+
for result in results:
|
359 |
+
if "result" in result and "errors" in result["result"]:
|
360 |
+
if "error" in result["result"]["errors"]:
|
361 |
+
print(result["result"])
|
362 |
+
|
363 |
+
def batch_index_data(self,
|
364 |
+
data: List[dict],
|
365 |
+
class_name: str,
|
366 |
+
vector_property: str='content_embedding'
|
367 |
+
) -> None:
|
368 |
+
'''
|
369 |
+
Batch function for fast indexing of data onto Weaviate cluster.
|
370 |
+
This method assumes that self._client.batch is already configured.
|
371 |
+
'''
|
372 |
+
start = time.perf_counter()
|
373 |
+
with self._client.batch as batch:
|
374 |
+
for d in tqdm(data):
|
375 |
+
|
376 |
+
#define single document
|
377 |
+
properties = {k:v for k,v in d.items() if k != vector_property}
|
378 |
+
try:
|
379 |
+
#add data object to batch
|
380 |
+
batch.add_data_object(
|
381 |
+
data_object=properties,
|
382 |
+
class_name=class_name,
|
383 |
+
vector=d[vector_property]
|
384 |
+
)
|
385 |
+
except Exception as e:
|
386 |
+
print(e)
|
387 |
+
continue
|
388 |
+
|
389 |
+
end = time.perf_counter() - start
|
390 |
+
|
391 |
+
print(f'Batch job completed in {round(end/60, 2)} minutes.')
|
392 |
+
# class_info = self._client.show_class_info()
|
393 |
+
# for i, c in enumerate(class_info):
|
394 |
+
# if c['class'] == class_name:
|
395 |
+
# print(class_info[i])
|
396 |
+
self._client.batch.shutdown()
|
397 |
+
|
398 |
+
@dataclass
|
399 |
+
class WhereFilter:
|
400 |
+
|
401 |
+
'''
|
402 |
+
Simplified interface for constructing a WhereFilter object.
|
403 |
+
|
404 |
+
Args
|
405 |
+
----
|
406 |
+
path: List[str]
|
407 |
+
List of properties to filter on.
|
408 |
+
operator: str
|
409 |
+
Operator to use for filtering. Options: ['And', 'Or', 'Equal', 'NotEqual',
|
410 |
+
'GreaterThan', 'GreaterThanEqual', 'LessThan', 'LessThanEqual', 'Like',
|
411 |
+
'WithinGeoRange', 'IsNull', 'ContainsAny', 'ContainsAll']
|
412 |
+
value[dataType]: Union[int, bool, str, float, datetime]
|
413 |
+
Value to filter on. The dataType suffix must match the data type of the
|
414 |
+
property being filtered on. At least and only one value type must be provided.
|
415 |
+
'''
|
416 |
+
path: List[str]
|
417 |
+
operator: str
|
418 |
+
valueInt: int=None
|
419 |
+
valueBoolean: bool=None
|
420 |
+
valueText: str=None
|
421 |
+
valueNumber: float=None
|
422 |
+
valueDate = None
|
423 |
+
|
424 |
+
def post_init(self):
|
425 |
+
operators = ['And', 'Or', 'Equal', 'NotEqual','GreaterThan', 'GreaterThanEqual', 'LessThan',\
|
426 |
+
'LessThanEqual', 'Like', 'WithinGeoRange', 'IsNull', 'ContainsAny', 'ContainsAll']
|
427 |
+
if self.operator not in operators:
|
428 |
+
raise ValueError(f'operator must be one of: {operators}, got {self.operator}')
|
429 |
+
values = [self.valueInt, self.valueBoolean, self.valueText, self.valueNumber, self.valueDate]
|
430 |
+
if not any(values):
|
431 |
+
raise ValueError('At least one value must be provided.')
|
432 |
+
if len(values) > 1:
|
433 |
+
raise ValueError('At most one value can be provided.')
|
434 |
+
|
435 |
+
def todict(self):
|
436 |
+
return {k:v for k,v in self.__dict__.items() if v is not None}
|