shpotes committed on
Commit
e602a6e
•
1 Parent(s): b230b07
Files changed (1) hide show
  1. app.py +141 -0
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import faiss
4
+ import hdbscan
5
+ from sklearn.feature_extraction.text import CountVectorizer
6
+
7
+ from src.modelling.topics.topic_extractor import (
8
+ TopicExtractionConfig, TopicExtractor
9
+ )
10
+ from src.modelling.topics.class_tf_idf import ClassTfidfTransformer
11
+ from src import deploy_utils
12
+
13
# User-facing copy for the sidebar search form (consumed by render_search).
semantic_search_header: str = "What kind of product are you trying to sell?"
semantic_search_placeholder: str = "Your magic idea goes here ✨"
search_label: str = "Search for similar products"
16
+
17
+
18
@st.cache_data
def load_data():
    """Read the review and product tables from their CSV files.

    The function is wrapped in ``st.cache_data``, so Streamlit may serve the
    result from its data cache instead of re-reading the files.

    Returns:
        tuple: ``(reviews, products)`` where ``reviews`` comes from
        ``data/filtered_reviews.csv`` re-indexed by its ``reviewID`` column
        and ``products`` comes from ``data/products.csv`` as read.
    """
    review_table = pd.read_csv("data/filtered_reviews.csv")
    review_table = review_table.set_index("reviewID")

    product_table = pd.read_csv("data/products.csv")

    return review_table, product_table
24
+
25
+
26
def load_uncached_models():
    """Construct the topic extractor and the clustering model.

    This function carries no Streamlit cache decorator, so both objects are
    built again on every call.

    Returns:
        tuple: ``(topic_extractor, clusterer)`` -- a ``TopicExtractor``
        configured to read the ``"summary"`` review field, and an
        ``hdbscan.HDBSCAN`` instance that expects a precomputed distance
        matrix (``metric="precomputed"``).
    """
    # Bag-of-words over unigrams up to trigrams, English stop words removed.
    vectorizer = CountVectorizer(ngram_range=(1, 3), stop_words="english")
    ctfidf = ClassTfidfTransformer(reduce_frequent_words=True)

    config = TopicExtractionConfig(
        vectorizer_model=vectorizer,
        ctfidf_model=ctfidf,
        number_of_representative_documents=5,
        review_text_key="summary",
    )
    extractor = TopicExtractor(config)

    hdbscan_model = hdbscan.HDBSCAN(
        min_cluster_size=5,
        min_samples=5,
        metric="precomputed",
    )

    return extractor, hdbscan_model
41
+
42
+
43
@st.cache_resource
def load_models():
    """Load the two encoders and the FAISS product index.

    Wrapped in ``st.cache_resource`` so Streamlit can keep the loaded objects
    around rather than loading them again.

    Returns:
        tuple: ``(reviews_model, product_model, product_indexer)``. Note the
        order: the reviews encoder comes first even though the product
        encoder is loaded first.
    """
    product_encoder = deploy_utils.load_model("all-MiniLM-L6-v2")

    reviews_encoder_url = "https://tfhub.dev/google/universal-sentence-encoder/4"
    reviews_encoder = deploy_utils.load_model(reviews_encoder_url)

    index = faiss.read_index("vectordb/populated.index")

    return reviews_encoder, product_encoder, index
52
+
53
+
54
def render_cta_link(url, label, font_awesome_icon):
    """Render a call-to-action link preceded by a Font Awesome icon.

    Two pieces of raw HTML are emitted through ``st.markdown``: a ``<link>``
    tag pulling in the Font Awesome 4.7.0 stylesheet, then the anchor itself.

    Args:
        url: Target of the link; interpolated verbatim into ``href``.
        label: Visible link text; interpolated verbatim (not HTML-escaped).
        font_awesome_icon: Font Awesome class name, e.g. ``"fa-github"``.

    Returns:
        Whatever ``st.markdown`` returns for the anchor element.
    """
    # The stylesheet URL must be a bare URL: the previous value was wrapped in
    # literal "<...>" (markdown autolink syntax), which is not a valid href and
    # kept the icon font from loading.
    st.markdown(
        '<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css">',
        unsafe_allow_html=True,
    )
    # Quote the target attribute and add rel="noopener noreferrer" so the
    # newly opened tab does not get a handle on this page.
    button_code = (
        f'<a href="{url}" target="_blank" rel="noopener noreferrer">'
        f'<i class="fa {font_awesome_icon}"></i> {label}</a>'
    )
    return st.markdown(button_code, unsafe_allow_html=True)
61
+
62
+
63
def _is_empty(collection):
    """Return True when *collection* is ``None`` or has zero length."""
    return collection is None or len(collection) == 0


def handler_search():
    """Run the search pipeline for the current query and store the key reviews.

    Reads the query from ``st.session_state.user_search_query`` and uses the
    module-level ``products``, ``reviews``, ``product_model``,
    ``product_indexer``, ``reviews_model``, ``clusterer`` and
    ``topic_extractor``. The outcome is written to
    ``st.session_state.key_reviews``.

    Early exits:
        * Blank query: nothing is searched and the session state is left
          untouched.
        * No relevant products, no relevant reviews, or no review left after
          discarding topic ``-1``: ``key_reviews`` is set to an empty list so
          the results view renders with nothing in it instead of the pipeline
          failing on empty input.
    """
    query_text = st.session_state.user_search_query
    if not query_text or not query_text.strip():
        # Nothing meaningful to embed; keep whatever results are on screen.
        return

    relevant_products = deploy_utils.query_relevant_documents(
        product_model=product_model,
        indexer=product_indexer,
        products=products,
        query_text=query_text,
    )
    if _is_empty(relevant_products):
        st.session_state.key_reviews = []
        return

    relevant_reviews = deploy_utils.get_relevant_reviews(
        relevant_products, reviews)
    if _is_empty(relevant_reviews):
        st.session_state.key_reviews = []
        return

    raw_topic_assigment = deploy_utils.clusterize_reviews(
        relevant_reviews, reviews_model, clusterer)
    relevant_reviews["topic"] = raw_topic_assigment
    # Topic -1 is dropped -- presumably the clusterer's "noise" label; confirm
    # against deploy_utils.clusterize_reviews.
    reviews_with_topics = relevant_reviews[relevant_reviews["topic"] != -1]
    if _is_empty(reviews_with_topics):
        st.session_state.key_reviews = []
        return

    extracted_topics = topic_extractor(reviews_with_topics)

    key_reviews = deploy_utils.get_key_reviews(
        reviews_with_topics,
        extracted_topics,
    )

    st.session_state.key_reviews = key_reviews
    print('search done')
92
+
93
+
94
def render_search():
    """Draw the sidebar: query box, search button, divider and repo link.

    The text input stores its value under the ``"user_search_query"`` session
    key; clicking the button invokes ``handler_search`` as its callback.
    """
    repo_url = "https://github.com/CamiVasz/factored-datathon-2023-almond"

    with st.sidebar:
        # Free-text query; the value is read back from session state.
        st.text_input(
            key="user_search_query",
            label=semantic_search_header,
            placeholder=semantic_search_placeholder,
        )

        st.button(
            key="location_search",
            label=search_label,
            on_click=handler_search,
        )

        st.write("---")
        render_cta_link(
            url=repo_url,
            label="Check the code",
            font_awesome_icon="fa-github",
        )
116
+
117
+
118
def render_results():
    """Write a heading followed by the first line of every key review.

    Reads ``st.session_state.key_reviews`` and expects each entry to be a
    string; only the text before the first newline is shown.
    """
    # TODO: temporary rendering
    st.write("# Relevant reviews")
    first_lines = (
        review.split("\n", 1)[0] for review in st.session_state.key_reviews
    )
    for headline in first_lines:
        st.write(" *", headline)
123
+
124
+
125
# Execution starts here!

# The page icon was a mis-encoded byte sequence rather than an emoji; it is
# restored to the magnifying glass (best reconstruction of the damaged
# character -- confirm against the intended icon).
st.set_page_config(
    page_title="almond - demo",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="expanded",
)


# Module-level objects below are read as globals by handler_search.
reviews, products = load_data()
reviews_model, product_model, product_indexer = load_models()
topic_extractor, clusterer = load_uncached_models()

render_search()
# Results only exist once handler_search has stored them in session state.
if "key_reviews" in st.session_state:
    render_results()