search_demo / app.py
bibliotecadebabel
first commit
37c2a8d
raw
history blame
5.25 kB
import torch
import src.constants.config as configurations
from sentence_transformers import SentenceTransformer
from sentence_transformers import CrossEncoder
from src.constants.credentials import cohere_trial_key
import streamlit as st
from src.reader import Reader
from src.utils_search import UtilsSearch
from copy import deepcopy
import numpy as np
import cohere
# Select the active service configuration. NOTE: this rebinds the name
# `configurations` from the imported module to one of its config objects.
configurations = configurations.service_mxbai_msc_direct_config
semantic_column_names = configurations["semantic_column_names"]

# Cohere client built from the trial key imported above.
api_key = cohere_trial_key
co = cohere.Client(api_key)

# Pin the first GPU when CUDA is present; otherwise tell the user about it.
if not torch.cuda.is_available():
    st.write("CUDA is not available. Using CPU instead.")
else:
    torch.cuda.set_device(0)
@st.cache_resource
def init():
    """Load the data, both models and the search index used by the app.

    The result is cached with ``st.cache_resource`` rather than
    ``st.cache_data``: models and index objects are global resources that
    should be created once and shared, not pickled and copied on every call.
    Because the cached objects are shared, callers should treat them as
    read-only.

    Returns:
        tuple: ``(df, model, cross_encoder, index, search_utils)`` where
        ``df`` is whatever ``Reader.read()`` returns, ``model`` is the
        ``SentenceTransformer``, ``cross_encoder`` is the ``CrossEncoder``,
        ``index`` is built from ``df`` by ``UtilsSearch.dataframe_to_index``
        and ``search_utils`` is the ``UtilsSearch`` instance.
    """
    config = configurations
    # Only target the GPU when one is actually present; hard-coding
    # 'cuda:0' makes model loading fail on CPU-only hosts.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    search_utils = UtilsSearch(config)
    reader = Reader(config=config["reader_config"])
    model = SentenceTransformer(config['sentence_transformer_name'], device=device)
    cross_encoder = CrossEncoder(config['cross_encoder_name'], device=device)
    df = reader.read()
    index = search_utils.dataframe_to_index(df)
    return df, model, cross_encoder, index, search_utils
def get_possible_values_for_column(column_name, search_utils, df):
    """Return the selectable values for a column, memoised in session state.

    The values come from ``search_utils.top_10_common_values(df, column_name)``
    and are stored in ``st.session_state`` under ``column_name`` so they are
    computed only once per session.

    Item access is used instead of ``getattr``/``setattr`` so that a column
    whose name matches an attribute of the session-state object (for example
    ``keys`` or ``items``) still reads back the stored values.

    Args:
        column_name: Name of the column; also used as the session-state key.
        search_utils: Object providing ``top_10_common_values(df, column_name)``.
        df: Data passed through to ``top_10_common_values``.

    Returns:
        The value stored in session state for ``column_name``.
    """
    if column_name not in st.session_state:
        st.session_state[column_name] = search_utils.top_10_common_values(df, column_name)
    return st.session_state[column_name]
# Run the setup once per session and keep its result in session state.
if 'init_results' not in st.session_state:
    st.session_state.init_results = init()
df, model, cross_encoder, index, search_utils = st.session_state.init_results

# Page header and free-text controls.
st.title('Search Demo')
query = st.text_input('Enter your search query here')
use_cohere = st.checkbox('Use Cohere', value=False)  # unchecked by default

# Work on a private copy so the shared configuration is never edited.
programmatic_search_config = deepcopy(configurations['programmatic_search_config'])

# Scalar columns: one "minimum" and one "maximum" number input each, bounded
# by (and defaulting to) the configured range.
scalar_filters = []
for scalar_spec in programmatic_search_config['scalar_columns']:
    col_name = scalar_spec["column_name"]
    lower_bound = float(scalar_spec["min_value"])
    upper_bound = float(scalar_spec["max_value"])
    label = col_name.capitalize()
    user_min = st.number_input(
        f'Minimum {label}', min_value=lower_bound, max_value=upper_bound, value=lower_bound
    )
    user_max = st.number_input(
        f'Maximum {label}', min_value=lower_bound, max_value=upper_bound, value=upper_bound
    )
    scalar_filters.append(
        {"column_name": col_name, "min_value": user_min, "max_value": user_max}
    )

# Discrete columns: one multiselect each, pre-populated with the configured
# defaults and offering the values cached for that column.
discrete_filters = []
for discrete_spec in programmatic_search_config['discrete_columns']:
    col_name = discrete_spec["column_name"]
    default_values = discrete_spec["default_values"]
    possible_values = get_possible_values_for_column(col_name, search_utils, df)
    selected_values = st.multiselect(
        f'Select {col_name.capitalize()}', options=possible_values, default=default_values
    )
    discrete_filters.append({"column_name": col_name, "default_values": selected_values})

# Replace the configured filters with what the user picked in the widgets.
dynamic_programmatic_search_config = {
    "scalar_columns": scalar_filters,
    "discrete_columns": discrete_filters,
}
programmatic_search_config.update(dynamic_programmatic_search_config)
# Search button: filter the data with the user's settings, then either run the
# local retrieve + cross-encoder search or retrieve locally and rerank with
# Cohere.
if st.button('Search'):
    if not query:
        st.write("Please enter a query to search.")
    else:
        df_filtered = search_utils.filter_dataframe(df, programmatic_search_config)
        if len(df_filtered) == 0:
            st.write('No results found')
        else:
            # Index only the rows that survived filtering. A separate name is
            # used so the session-wide `index` is not overwritten.
            filtered_index = search_utils.dataframe_to_index(df_filtered)
            results_df = None
            if not use_cohere:
                # Local search using the sentence-transformer and cross-encoder.
                results_df = search_utils.search(query, df_filtered, model, cross_encoder, filtered_index)
                results_df = search_utils.drop_columns(results_df, programmatic_search_config)
            else:
                # Retrieve candidates locally, then let Cohere rerank them.
                df_retrieved = search_utils.retrieve(query, df_filtered, model, filtered_index)
                df_retrieved = search_utils.drop_columns(df_retrieved, programmatic_search_config)
                # Assign instead of filling in place: the frame may be a slice
                # of another frame, where in-place edits are unreliable.
                df_retrieved = df_retrieved.fillna(value="")
                docs = [
                    {name: str(record[name]) for name in semantic_column_names}
                    for record in df_retrieved.to_dict('records')
                ]
                # With nothing retrieved, docs[0] below would raise IndexError
                # and there would be nothing to rerank, so skip the API call.
                if docs:
                    rank_fields = list(docs[0].keys())
                    results = co.rerank(query=query, documents=docs, top_n=10,
                                        model='rerank-english-v3.0',
                                        rank_fields=rank_fields)
                    top_ids = [hit.index for hit in results.results]
                    # Reorder the retrieved rows by rerank position and
                    # number them starting from 1.
                    results_df = df_retrieved.iloc[top_ids].copy()
                    results_df['rank'] = (np.arange(len(results_df)) + 1)
            if results_df is None:
                st.write('No results found')
            else:
                st.write(results_df)