"""Streamlit demo app: product semantic search and PaLM text generation."""
import html
import os

import google.generativeai as palm
import streamlit as st
import pandas as pd
import faiss
import hdbscan
from sklearn.feature_extraction.text import CountVectorizer

from src.modelling.topics.topic_extractor import (
    TopicExtractionConfig,
    TopicExtractor,
)
from src.modelling.topics.class_tf_idf import ClassTfidfTransformer
from src import deploy_utils

# User-facing strings for the sidebar widgets.
semantic_search_header = "What kind of product are you trying to sell?"
semantic_search_placeholder = "Your magic idea goes here ✨"
prompt_label = "Prompt"
search_label = "Generate"

# Stylesheet that provides the `fa fa-*` icon classes used by the CTA link.
_FONT_AWESOME_CSS = (
    "https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/"
    "css/font-awesome.min.css"
)


def setup_palm():
    """Configure the PaLM client with the key in the PALM_TOKEN env var.

    `os.environ.get` yields None when the variable is unset; in that case
    None is passed through to `palm.configure` unchanged.
    """
    palm.configure(api_key=os.environ.get('PALM_TOKEN'))


@st.cache_data
def load_data():
    """Read the reviews and products CSV files.

    Returns:
        A `(reviews, products)` tuple of DataFrames; `reviews` is indexed by
        its "reviewID" column.
    """
    reviews = pd.read_csv("data/filtered_reviews.csv").set_index("reviewID")
    products = pd.read_csv("data/products.csv")
    return reviews, products


def load_uncached_models():
    """Build the topic extractor and the HDBSCAN clusterer.

    Unlike `load_models`, this function is not wrapped in a Streamlit cache,
    so both objects are rebuilt on every script run.

    Returns:
        A `(topic_extractor, clusterer)` tuple.
    """
    topic_extraction_config = TopicExtractionConfig(
        vectorizer_model=CountVectorizer(
            ngram_range=(1, 3), stop_words="english"),
        ctfidf_model=ClassTfidfTransformer(reduce_frequent_words=True),
        number_of_representative_documents=5,
        review_text_key="summary",
    )
    topic_extractor = TopicExtractor(topic_extraction_config)
    clusterer = hdbscan.HDBSCAN(
        min_cluster_size=5, min_samples=5, metric="precomputed")
    return topic_extractor, clusterer


@st.cache_resource
def load_models():
    """Load the embedding models and the FAISS product index.

    Returns:
        A `(reviews_model, product_model, product_indexer)` tuple. Note the
        order: the reviews model comes first.
    """
    product_model = deploy_utils.load_model("all-MiniLM-L6-v2")
    reviews_model = deploy_utils.load_model(
        "https://tfhub.dev/google/universal-sentence-encoder/4"
    )
    product_indexer = faiss.read_index("vectordb/populated.index")
    return reviews_model, product_model, product_indexer


def render_cta_link(url, label, font_awesome_icon):
    """Render a call-to-action hyperlink with a Font Awesome icon.

    Args:
        url: Target address of the link.
        label: Visible link text.
        font_awesome_icon: Font Awesome icon class, e.g. "fa-github".

    Returns:
        Whatever `st.markdown` returns for the rendered anchor.
    """
    # The icon classes only render once the Font Awesome stylesheet is loaded.
    st.markdown(
        f'<link rel="stylesheet" href="{_FONT_AWESOME_CSS}">',
        unsafe_allow_html=True,
    )
    # Escape every interpolated value: the markup is injected as raw HTML.
    safe_url = html.escape(url, quote=True)
    safe_icon = html.escape(font_awesome_icon, quote=True)
    safe_label = html.escape(label)
    button_code = (
        f'<a href="{safe_url}" target="_blank" rel="noopener noreferrer">'
        f'<i class="fa {safe_icon}"></i> {safe_label}</a>'
    )
    return st.markdown(button_code, unsafe_allow_html=True)


def handler_search():
    """Run the semantic search pipeline and store the key reviews.

    Reads the query from `st.session_state.user_search_query` and the
    module-level `products`, `reviews`, models and clusterer, then writes the
    result to `st.session_state.key_reviews`.

    NOTE(review): no widget in this file uses this function as a callback;
    the sidebar button is wired to `palm_handler` — confirm this is intended.
    """
    relevant_products = deploy_utils.query_relevant_documents(
        product_model=product_model,
        indexer=product_indexer,
        products=products,
        query_text=st.session_state.user_search_query,
    )
    # TODO: check if there are relevant products
    relevant_reviews = deploy_utils.get_relevant_reviews(
        relevant_products, reviews)
    raw_topic_assignment = deploy_utils.clusterize_reviews(
        relevant_reviews, reviews_model, clusterer)
    relevant_reviews["topic"] = raw_topic_assignment
    # Drop reviews labelled -1 (presumably the clusterer's noise label).
    reviews_with_topics = relevant_reviews[relevant_reviews["topic"] != -1]
    # TODO: check if there are still topics
    extracted_topics = topic_extractor(reviews_with_topics)
    key_reviews = deploy_utils.get_key_reviews(
        reviews_with_topics,
        extracted_topics,
    )
    st.session_state.key_reviews = key_reviews
    print('search done')


def palm_handler():
    """Send the user's prompt to PaLM and keep the response in session state.

    Reads `st.session_state.user_prompt` and writes the object returned by
    `palm.generate_text` to `st.session_state.palm_output`.
    """
    # NOTE(review): the whole response object is stored, not just its text —
    # confirm this is what `render_results` should display.
    response = palm.generate_text(prompt=st.session_state.user_prompt)
    st.session_state.palm_output = response


def render_search():
    """Render the search form in the sidebar."""
    with st.sidebar:
        st.text_input(
            label=semantic_search_header,
            placeholder=semantic_search_placeholder,
            key="user_search_query",
        )
        # `label` is a required argument of st.text_area; omitting it raises
        # TypeError before anything is rendered.
        st.text_area(
            label=prompt_label,
            placeholder="prompt here",
            key="user_prompt",
        )
        st.button(
            label=search_label,
            key="location_search",
            on_click=palm_handler,
        )
        st.write("---")
        render_cta_link(
            url="https://github.com/CamiVasz/factored-datathon-2023-almond",
            label="Check the code",
            font_awesome_icon="fa-github",
        )


def render_results():
    """Display the stored PaLM output in the main page area."""
    # TODO: temporary
    st.write("# PaLM outputs")
    st.write(st.session_state.palm_output)


# Execution starts here!
st.set_page_config(
    page_title="almond - demo",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="expanded",
)
setup_palm()
reviews, products = load_data()
reviews_model, product_model, product_indexer = load_models()
topic_extractor, clusterer = load_uncached_models()
render_search()
# Only show results once a PaLM response has been stored by `palm_handler`.
if "palm_output" in st.session_state:
    render_results()