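"""Streamlit demo app: retrieves products relevant to a user query via a FAISS
index, clusters their reviews into topics with HDBSCAN, and generates text with
the PaLM API."""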
import os

import google.generativeai as palm
import streamlit as st
import pandas as pd
import faiss
import hdbscan
from sklearn.feature_extraction.text import CountVectorizer

from src.modelling.topics.topic_extractor import (
    TopicExtractionConfig, TopicExtractor
)
from src.modelling.topics.class_tf_idf import ClassTfidfTransformer
from src import deploy_utils

semantic_search_header = "What kind of product are you trying to sell?"
semantic_search_placeholder = "Your magic idea goes here ✨"
search_label = "Generate"

def setup_palm():
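    """Configure the PaLM client with the API key from the PALM_TOKEN env var."""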
    palm.configure(api_key=os.environ.get('PALM_TOKEN'))

@st.cache_data
def load_data():
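    """Load the filtered reviews (indexed by reviewID) and the product catalogue."""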
    reviews = pd.read_csv("data/filtered_reviews.csv").set_index("reviewID")
    products = pd.read_csv("data/products.csv")

    return reviews, products


def load_uncached_models():
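    """Build the topic extractor and HDBSCAN clusterer; rebuilt on every rerun."""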
    topic_extraction_config = TopicExtractionConfig(
        vectorizer_model=CountVectorizer(
            ngram_range=(1, 3), stop_words="english"),
        ctfidf_model=ClassTfidfTransformer(reduce_frequent_words=True),
        number_of_representative_documents=5,
        review_text_key="summary",
    )

    topic_extractor = TopicExtractor(topic_extraction_config)

    clusterer = hdbscan.HDBSCAN(
        min_cluster_size=5, min_samples=5, metric="precomputed")

    return topic_extractor, clusterer


@st.cache_resource
def load_models():
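    """Load and cache the sentence encoders and the populated FAISS product index."""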
    product_model = deploy_utils.load_model("all-MiniLM-L6-v2")
    reviews_model = deploy_utils.load_model(
        "https://tfhub.dev/google/universal-sentence-encoder/4"
    )
    product_indexer = faiss.read_index("vectordb/populated.index")

    return reviews_model, product_model, product_indexer


def render_cta_link(url, label, font_awesome_icon):
    """Render a markdown link with a Font Awesome icon, opening in a new tab."""
    st.markdown(
        '<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css">',
        unsafe_allow_html=True,
    )
    button_code = f"""<a href="{url}" target="_blank"><i class="fa {font_awesome_icon}"></i> {label}</a>"""
    return st.markdown(button_code, unsafe_allow_html=True)


def handler_search():
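    """Find products relevant to the user query, cluster their reviews into
    topics, and store the resulting key reviews in session state."""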
    relevant_products = deploy_utils.query_relevant_documents(
        product_model=product_model,
        indexer=product_indexer,
        products=products,
        query_text=st.session_state.user_search_query,
    )

    # TODO: check if there are relevant products
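    # Minimal guard for the TODO above, assuming query_relevant_documents
    # returns a (possibly empty) DataFrame-like result.
    if relevant_products is None or len(relevant_products) == 0:
        return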

    relevant_reviews = deploy_utils.get_relevant_reviews(
        relevant_products, reviews)

    raw_topic_assignment = deploy_utils.clusterize_reviews(
        relevant_reviews, reviews_model, clusterer)
    relevant_reviews["topic"] = raw_topic_assignment
    # HDBSCAN labels noise points as -1; keep only reviews assigned to a topic.
    reviews_with_topics = relevant_reviews[relevant_reviews["topic"] != -1]

    # TODO: check if there are still topics
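    # Minimal guard for the TODO above, assuming reviews_with_topics is a
    # DataFrame: stop if HDBSCAN marked every review as noise.
    if reviews_with_topics.empty:
        return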

    extracted_topics = topic_extractor(reviews_with_topics)

    key_reviews = deploy_utils.get_key_reviews(
        reviews_with_topics,
        extracted_topics,
    )

    st.session_state.key_reviews = key_reviews
    print('search done')


def palm_handler():
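    """Send the user prompt to PaLM and store the generated text in session state."""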
    response = palm.generate_text(prompt=st.session_state.user_prompt)
    # generate_text returns a Completion; .result holds the generated text.
    st.session_state.palm_output = response.result


def render_search():
    """
    Render the search form in the sidebar.
    """
    with st.sidebar:
        st.text_input(
            label=semantic_search_header,
            placeholder=semantic_search_placeholder,
            key="user_search_query",
        )

        st.text_area(
            label="Prompt",  # st.text_area requires a label
            placeholder="prompt here",
            key="user_prompt"
        )

        st.button(
            label=search_label,
            key="location_search",
            on_click=palm_handler)

        st.write("---")
        render_cta_link(
            url="https://github.com/CamiVasz/factored-datathon-2023-almond",
            label="Check the code",
            font_awesome_icon="fa-github",
        )


def render_results():
    # TODO: temporary
    st.write("# PaLM outputs")
    st.write(st.session_state.palm_output)

# Execution starts here!

st.set_page_config(
    page_title="almond - demo",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="expanded",
)

setup_palm()
reviews, products = load_data()
reviews_model, product_model, product_indexer = load_models()
topic_extractor, clusterer = load_uncached_models()

render_search()
if "palm_output" in st.session_state:
    render_results()