import os
import google.generativeai as palm
import streamlit as st
import pandas as pd
import faiss
import hdbscan
from sklearn.feature_extraction.text import CountVectorizer
from src.modelling.topics.topic_extractor import (
TopicExtractionConfig, TopicExtractor
)
from src.modelling.topics.class_tf_idf import ClassTfidfTransformer
from src import deploy_utils
# Sidebar copy: the search box's label and placeholder, and the text of the
# button that triggers generation.
(
    semantic_search_header,
    semantic_search_placeholder,
    search_label,
) = (
    "What kind of product are you trying to sell?",
    "Your magic idea goes here ✨",
    "Generate",
)
def setup_palm():
    """Configure the PaLM client with the key stored in ``PALM_TOKEN``.

    When the environment variable is unset, ``None`` is handed to
    ``palm.configure`` unchanged.
    """
    token = os.getenv('PALM_TOKEN')
    palm.configure(api_key=token)
@st.cache_data
def load_data():
    """Read the review and product tables from the ``data/`` directory.

    Wrapped in ``st.cache_data`` so the CSV files are not re-parsed on every
    script rerun.

    Returns:
        tuple: ``(reviews, products)``. ``reviews`` is indexed by its
        ``reviewID`` column; ``products`` keeps the default index produced by
        ``pd.read_csv``.
    """
    review_frame = pd.read_csv("data/filtered_reviews.csv")
    review_frame = review_frame.set_index("reviewID")
    product_frame = pd.read_csv("data/products.csv")
    return review_frame, product_frame
def load_uncached_models():
    """Construct a fresh topic extractor and HDBSCAN clusterer.

    No Streamlit cache decorator is applied, so a new pair of objects is built
    on every call.

    Returns:
        tuple: ``(topic_extractor, clusterer)`` where
        - ``topic_extractor`` is a ``TopicExtractor`` using a 1- to 3-gram
          ``CountVectorizer`` with English stop words removed, a
          ``ClassTfidfTransformer`` with ``reduce_frequent_words=True``,
          ``number_of_representative_documents=5`` and
          ``review_text_key="summary"``;
        - ``clusterer`` is an ``hdbscan.HDBSCAN`` with
          ``min_cluster_size=5``, ``min_samples=5`` and
          ``metric="precomputed"`` (it is fed a precomputed distance matrix
          rather than raw feature vectors).
    """
    vectorizer = CountVectorizer(ngram_range=(1, 3), stop_words="english")
    class_tfidf = ClassTfidfTransformer(reduce_frequent_words=True)
    config = TopicExtractionConfig(
        vectorizer_model=vectorizer,
        ctfidf_model=class_tfidf,
        number_of_representative_documents=5,
        review_text_key="summary",
    )
    extractor = TopicExtractor(config)
    hdbscan_clusterer = hdbscan.HDBSCAN(
        min_cluster_size=5,
        min_samples=5,
        metric="precomputed",
    )
    return extractor, hdbscan_clusterer
@st.cache_resource
def load_models():
    """Load the two embedding models and the FAISS product index.

    Wrapped in ``st.cache_resource`` so these heavyweight objects are loaded
    once and reused across reruns instead of being reloaded each time.

    Returns:
        tuple: ``(reviews_model, product_model, product_indexer)``:
        - the Universal Sentence Encoder (TF Hub) used for reviews;
        - the ``all-MiniLM-L6-v2`` model used for products;
        - the FAISS index read from ``vectordb/populated.index``.
    """
    minilm = deploy_utils.load_model("all-MiniLM-L6-v2")
    universal_encoder = deploy_utils.load_model(
        "https://tfhub.dev/google/universal-sentence-encoder/4"
    )
    faiss_index = faiss.read_index("vectordb/populated.index")
    return universal_encoder, minilm, faiss_index
# Font Awesome 4.x stylesheet; its "fa fa-<name>" class scheme matches icon
# names such as "fa-github".
_FONT_AWESOME_CSS_URL = (
    "https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/"
    "font-awesome.min.css"
)


def _cta_link_html(url, label, font_awesome_icon):
    """Return the HTML anchor for a CTA link, escaping every interpolated value."""
    import html

    return (
        f'<a href="{html.escape(url, quote=True)}" target="_blank" '
        f'rel="noopener noreferrer">'
        f'<i class="fa {html.escape(font_awesome_icon, quote=True)}"></i> '
        f"{html.escape(label)}</a>"
    )


def render_cta_link(url, label, font_awesome_icon):
    """Render a call-to-action link, prefixed by a Font Awesome icon.

    Args:
        url: Link target; it opens in a new tab.
        label: Visible link text. It is HTML-escaped, so it is shown literally.
        font_awesome_icon: Font Awesome 4 icon class, e.g. ``"fa-github"``.

    Returns:
        The value returned by the ``st.markdown`` call that renders the link.
    """
    # Load the icon font first so the <i class="fa ..."> element resolves.
    st.markdown(
        f'<link rel="stylesheet" href="{_FONT_AWESOME_CSS_URL}">',
        unsafe_allow_html=True,
    )
    # unsafe_allow_html is required for the raw <a>/<i> markup; every
    # caller-supplied value is escaped in _cta_link_html.
    button_code = _cta_link_html(url, label, font_awesome_icon)
    return st.markdown(button_code, unsafe_allow_html=True)
def handler_search():
    """Cluster the reviews relevant to the search query and store the key ones.

    Steps:
    1. Query ``product_indexer`` (via ``deploy_utils.query_relevant_documents``)
       with ``st.session_state.user_search_query``.
    2. Gather the reviews of the matching products.
    3. Cluster those reviews.
    4. Drop the reviews labelled ``-1``.
    5. Extract topics and store the result of
       ``deploy_utils.get_key_reviews`` in ``st.session_state.key_reviews``.

    Relies on the module-level ``product_model``, ``product_indexer``,
    ``products``, ``reviews``, ``reviews_model``, ``clusterer`` and
    ``topic_extractor`` created when the script starts.

    The function shows a warning and returns early when no products match,
    no reviews are found, or every review ends up labelled ``-1``. In that
    case any previous ``key_reviews`` is removed from session state, so stale
    results from an earlier query are not displayed and clustering / topic
    extraction never run on empty data.
    """
    import logging

    def _abort(message):
        # Forget results from an earlier query so they are not shown for this one.
        if "key_reviews" in st.session_state:
            del st.session_state["key_reviews"]
        st.warning(message)

    relevant_products = deploy_utils.query_relevant_documents(
        product_model=product_model,
        indexer=product_indexer,
        products=products,
        query_text=st.session_state.user_search_query,
    )
    # NOTE(review): assumes the result supports len() (DataFrame, list or
    # array) — confirm against deploy_utils.query_relevant_documents.
    if len(relevant_products) == 0:
        _abort("No products matched your idea. Try describing it differently.")
        return

    relevant_reviews = deploy_utils.get_relevant_reviews(
        relevant_products, reviews)
    if relevant_reviews.empty:
        _abort("No reviews were found for the matching products.")
        return

    raw_topic_assignment = deploy_utils.clusterize_reviews(
        relevant_reviews, reviews_model, clusterer)
    # assign() returns a new frame. This avoids writing into what may be a
    # slice of the cached `reviews` table (SettingWithCopyWarning).
    relevant_reviews = relevant_reviews.assign(topic=raw_topic_assignment)
    # HDBSCAN labels points it could not assign to any cluster as -1.
    reviews_with_topics = relevant_reviews[relevant_reviews["topic"] != -1]
    if reviews_with_topics.empty:
        _abort("No clear topics emerged from the reviews of the matching products.")
        return

    extracted_topics = topic_extractor(reviews_with_topics)
    st.session_state.key_reviews = deploy_utils.get_key_reviews(
        reviews_with_topics,
        extracted_topics,
    )
    logging.getLogger(__name__).debug("search done")
def palm_handler():
    """Send the sidebar prompt to PaLM and store the response in session state.

    This is the ``on_click`` callback of the sidebar's generate button. The
    response returned by ``palm.generate_text`` is stored unchanged in
    ``st.session_state.palm_output``.

    Blank or whitespace-only prompts are not sent. Instead a warning is shown
    and any previous ``palm_output`` is left untouched.
    """
    prompt = st.session_state.get("user_prompt", "")
    if not prompt or not prompt.strip():
        st.warning("Please enter a prompt before generating.")
        return
    response = palm.generate_text(prompt=prompt)
    st.session_state.palm_output = response
def render_search():
    """
    Render the search form in the sidebar.

    The sidebar contains, in order:
    - a text input for the product idea (session key ``user_search_query``);
    - a text area for the PaLM prompt (session key ``user_prompt``);
    - the generate button, which runs ``palm_handler`` on click;
    - a divider and a link to the project repository.
    """
    with st.sidebar:
        st.text_input(
            label=semantic_search_header,
            placeholder=semantic_search_placeholder,
            key="user_search_query",
        )
        # st.text_area requires a label; omitting it raised TypeError on every
        # render. Collapsing it keeps the placeholder-only look while still
        # giving the widget an accessible name.
        st.text_area(
            label="PaLM prompt",
            placeholder="prompt here",
            key="user_prompt",
            label_visibility="collapsed",
        )
        st.button(
            label=search_label,
            key="location_search",
            on_click=palm_handler,
        )
        st.write("---")
        render_cta_link(
            url="https://github.com/CamiVasz/factored-datathon-2023-almond",
            label="Check the code",
            font_awesome_icon="fa-github",
        )
def render_results():
    """Render the latest PaLM output held in ``st.session_state.palm_output``.

    If the stored value has a ``result`` attribute, that is what gets written.
    Otherwise the value itself is written. A ``None`` result produces a
    warning instead of empty output.
    """
    # TODO: temporary view; replace with the final results layout.
    st.write("# PaLM outputs")
    output = st.session_state.palm_output
    # NOTE(review): palm.generate_text returns a Completion whose `.result` is
    # the top candidate's text, or None when no candidate came back (e.g. the
    # prompt was blocked). Writing the Completion object itself does not show
    # the generated text.
    text = getattr(output, "result", output)
    if text is None:
        st.warning("PaLM returned no output for this prompt.")
    else:
        st.write(text)
# Script execution starts here. Streamlit reruns this whole section on every
# user interaction, so the order below matters.
# st.set_page_config must be the first Streamlit command the script issues.
st.set_page_config(
    page_title="almond - demo",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="expanded",
)
setup_palm()
# The names bound here are module-level globals; handler_search reads
# reviews, products, the models, the index, topic_extractor and clusterer directly.
reviews, products = load_data()
reviews_model, product_model, product_indexer = load_models()
topic_extractor, clusterer = load_uncached_models()
render_search()
# palm_output only exists after palm_handler has run at least once.
if "palm_output" in st.session_state:
    render_results()