Arko Banik commited on
Commit
33950f7
·
1 Parent(s): 060a952

add app file

Browse files
Files changed (2) hide show
  1. miniproject1_part4-2-1.py +391 -0
  2. requirements.txt +3 -0
miniproject1_part4-2-1.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import numpy.linalg as la
4
+ import pickle
5
+ import os
6
+ import gdown
7
+ from sentence_transformers import SentenceTransformer
8
+ import matplotlib.pyplot as plt
9
+ import math
10
+
11
+
12
+ # Compute Cosine Similarity
13
+ def cosine_similarity(x, y):
14
+ """
15
+ Exponentiated cosine similarity
16
+ 1. Compute cosine similarity
17
+ 2. Exponentiate cosine similarity
18
+ 3. Return exponentiated cosine similarity
19
+ (20 pts)
20
+ """
21
+ ##################################
22
+ ### TODO: Add code here ##########
23
+ ##################################
24
+ pass
25
+
26
+
27
+ # Function to Load Glove Embeddings
28
+ def load_glove_embeddings(glove_path="Data/embeddings.pkl"):
29
+ with open(glove_path, "rb") as f:
30
+ embeddings_dict = pickle.load(f, encoding="latin1")
31
+
32
+ return embeddings_dict
33
+
34
+
35
+ def get_model_id_gdrive(model_type):
36
+ if model_type == "25d":
37
+ word_index_id = "13qMXs3-oB9C6kfSRMwbAtzda9xuAUtt8"
38
+ embeddings_id = "1-RXcfBvWyE-Av3ZHLcyJVsps0RYRRr_2"
39
+ elif model_type == "50d":
40
+ embeddings_id = "1DBaVpJsitQ1qxtUvV1Kz7ThDc3az16kZ"
41
+ word_index_id = "1rB4ksHyHZ9skes-fJHMa2Z8J1Qa7awQ9"
42
+ elif model_type == "100d":
43
+ word_index_id = "1-oWV0LqG3fmrozRZ7WB1jzeTJHRUI3mq"
44
+ embeddings_id = "1SRHfX130_6Znz7zbdfqboKosz-PfNvNp"
45
+
46
+ return word_index_id, embeddings_id
47
+
48
+
49
+ def download_glove_embeddings_gdrive(model_type):
50
+ # Get glove embeddings from google drive
51
+ word_index_id, embeddings_id = get_model_id_gdrive(model_type)
52
+
53
+ # Use gdown to get files from google drive
54
+ embeddings_temp = "embeddings_" + str(model_type) + "_temp.npy"
55
+ word_index_temp = "word_index_dict_" + str(model_type) + "_temp.pkl"
56
+
57
+ # Download word_index pickle file
58
+ print("Downloading word index dictionary....\n")
59
+ gdown.download(id=word_index_id, output=word_index_temp, quiet=False)
60
+
61
+ # Download embeddings numpy file
62
+ print("Donwloading embedings...\n\n")
63
+ gdown.download(id=embeddings_id, output=embeddings_temp, quiet=False)
64
+
65
+
66
+ # @st.cache_data()
67
+ def load_glove_embeddings_gdrive(model_type):
68
+ word_index_temp = "word_index_dict_" + str(model_type) + "_temp.pkl"
69
+ embeddings_temp = "embeddings_" + str(model_type) + "_temp.npy"
70
+
71
+ # Load word index dictionary
72
+ word_index_dict = pickle.load(open(word_index_temp, "rb"), encoding="latin")
73
+
74
+ # Load embeddings numpy
75
+ embeddings = np.load(embeddings_temp)
76
+
77
+ return word_index_dict, embeddings
78
+
79
+
80
+ @st.cache_resource()
81
+ def load_sentence_transformer_model(model_name):
82
+ sentenceTransformer = SentenceTransformer(model_name)
83
+ return sentenceTransformer
84
+
85
+
86
+ def get_sentence_transformer_embeddings(sentence, model_name="all-MiniLM-L6-v2"):
87
+ """
88
+ Get sentence transformer embeddings for a sentence
89
+ """
90
+ # 384 dimensional embedding
91
+ # Default model: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
92
+
93
+ sentenceTransformer = load_sentence_transformer_model(model_name)
94
+
95
+ try:
96
+ return sentenceTransformer.encode(sentence)
97
+ except:
98
+ if model_name == "all-MiniLM-L6-v2":
99
+ return np.zeros(384)
100
+ else:
101
+ return np.zeros(512)
102
+
103
+
104
+ def get_glove_embeddings(word, word_index_dict, embeddings, model_type):
105
+ """
106
+ Get glove embedding for a single word
107
+ """
108
+ if word.lower() in word_index_dict:
109
+ return embeddings[word_index_dict[word.lower()]]
110
+ else:
111
+ return np.zeros(int(model_type.split("d")[0]))
112
+
113
+
114
+ def averaged_glove_embeddings_gdrive(sentence, word_index_dict, embeddings, model_type=50):
115
+ """
116
+ Get averaged glove embeddings for a sentence
117
+ 1. Split sentence into words
118
+ 2. Get embeddings for each word
119
+ 3. Add embeddings for each word
120
+ 4. Divide by number of words
121
+ 5. Return averaged embeddings
122
+ (30 pts)
123
+ """
124
+ embedding = np.zeros(int(model_type.split("d")[0]))
125
+ ##################################
126
+ ##### TODO: Add code here ########
127
+ ##################################
128
+
129
+
130
+ def get_category_embeddings(embeddings_metadata):
131
+ """
132
+ Get embeddings for each category
133
+ 1. Split categories into words
134
+ 2. Get embeddings for each word
135
+ """
136
+ model_name = embeddings_metadata["model_name"]
137
+ st.session_state["cat_embed_" + model_name] = {}
138
+ for category in st.session_state.categories.split(" "):
139
+ if model_name:
140
+ if not category in st.session_state["cat_embed_" + model_name]:
141
+ st.session_state["cat_embed_" + model_name][category] = get_sentence_transformer_embeddings(category, model_name=model_name)
142
+ else:
143
+ if not category in st.session_state["cat_embed_" + model_name]:
144
+ st.session_state["cat_embed_" + model_name][category] = get_sentence_transformer_embeddings(category)
145
+
146
+
147
+ def update_category_embeddings(embedings_metadata):
148
+ """
149
+ Update embeddings for each category
150
+ """
151
+ get_category_embeddings(embeddings_metadata)
152
+
153
+
154
+ def get_sorted_cosine_similarity(embeddings_metadata):
155
+ """
156
+ Get sorted cosine similarity between input sentence and categories
157
+ Steps:
158
+ 1. Get embeddings for input sentence
159
+ 2. Get embeddings for categories (if not found, update category embeddings)
160
+ 3. Compute cosine similarity between input sentence and categories
161
+ 4. Sort cosine similarity
162
+ 5. Return sorted cosine similarity
163
+ (50 pts)
164
+ """
165
+ categories = st.session_state.categories.split(" ")
166
+ cosine_sim = {}
167
+ if embeddings_metadata["embedding_model"] == "glove":
168
+ word_index_dict = embeddings_metadata["word_index_dict"]
169
+ embeddings = embeddings_metadata["embeddings"]
170
+ model_type = embeddings_metadata["model_type"]
171
+
172
+ input_embedding = averaged_glove_embeddings_gdrive(st.session_state.text_search,
173
+ word_index_dict,
174
+ embeddings, model_type)
175
+
176
+ ##########################################
177
+ ## TODO: Get embeddings for categories ###
178
+ ##########################################
179
+
180
+ else:
181
+ model_name = embeddings_metadata["model_name"]
182
+ if not "cat_embed_" + model_name in st.session_state:
183
+ get_category_embeddings(embeddings_metadata)
184
+
185
+ category_embeddings = st.session_state["cat_embed_" + model_name]
186
+
187
+ print("text_search = ", st.session_state.text_search)
188
+ if model_name:
189
+ input_embedding = get_sentence_transformer_embeddings(st.session_state.text_search, model_name=model_name)
190
+ else:
191
+ input_embedding = get_sentence_transformer_embeddings(st.session_state.text_search)
192
+ for index in range(len(categories)):
193
+ pass
194
+ ##########################################
195
+ # TODO: Compute cosine similarity between input sentence and categories
196
+ # TODO: Update category embeddings if category not found
197
+ ##########################################
198
+
199
+ return
200
+
201
+
202
+ def plot_piechart(sorted_cosine_scores_items):
203
+ sorted_cosine_scores = np.array([
204
+ sorted_cosine_scores_items[index][1]
205
+ for index in range(len(sorted_cosine_scores_items))
206
+ ]
207
+ )
208
+ categories = st.session_state.categories.split(" ")
209
+ categories_sorted = [
210
+ categories[sorted_cosine_scores_items[index][0]]
211
+ for index in range(len(sorted_cosine_scores_items))
212
+ ]
213
+ fig, ax = plt.subplots()
214
+ ax.pie(sorted_cosine_scores, labels=categories_sorted, autopct="%1.1f%%")
215
+ st.pyplot(fig) # Figure
216
+
217
+
218
+ def plot_piechart_helper(sorted_cosine_scores_items):
219
+ sorted_cosine_scores = np.array(
220
+ [
221
+ sorted_cosine_scores_items[index][1]
222
+ for index in range(len(sorted_cosine_scores_items))
223
+ ]
224
+ )
225
+ categories = st.session_state.categories.split(" ")
226
+ categories_sorted = [
227
+ categories[sorted_cosine_scores_items[index][0]]
228
+ for index in range(len(sorted_cosine_scores_items))
229
+ ]
230
+ fig, ax = plt.subplots(figsize=(3, 3))
231
+ my_explode = np.zeros(len(categories_sorted))
232
+ my_explode[0] = 0.2
233
+ if len(categories_sorted) == 3:
234
+ my_explode[1] = 0.1 # explode this by 0.2
235
+ elif len(categories_sorted) > 3:
236
+ my_explode[2] = 0.05
237
+ ax.pie(
238
+ sorted_cosine_scores,
239
+ labels=categories_sorted,
240
+ autopct="%1.1f%%",
241
+ explode=my_explode,
242
+ )
243
+
244
+ return fig
245
+
246
+
247
+ def plot_piecharts(sorted_cosine_scores_models):
248
+ scores_list = []
249
+ categories = st.session_state.categories.split(" ")
250
+ index = 0
251
+ for model in sorted_cosine_scores_models:
252
+ scores_list.append(sorted_cosine_scores_models[model])
253
+ # scores_list[index] = np.array([scores_list[index][ind2][1] for ind2 in range(len(scores_list[index]))])
254
+ index += 1
255
+
256
+ if len(sorted_cosine_scores_models) == 2:
257
+ fig, (ax1, ax2) = plt.subplots(2)
258
+
259
+ categories_sorted = [
260
+ categories[scores_list[0][index][0]] for index in range(len(scores_list[0]))
261
+ ]
262
+ sorted_scores = np.array(
263
+ [scores_list[0][index][1] for index in range(len(scores_list[0]))]
264
+ )
265
+ ax1.pie(sorted_scores, labels=categories_sorted, autopct="%1.1f%%")
266
+
267
+ categories_sorted = [
268
+ categories[scores_list[1][index][0]] for index in range(len(scores_list[1]))
269
+ ]
270
+ sorted_scores = np.array(
271
+ [scores_list[1][index][1] for index in range(len(scores_list[1]))]
272
+ )
273
+ ax2.pie(sorted_scores, labels=categories_sorted, autopct="%1.1f%%")
274
+
275
+ st.pyplot(fig)
276
+
277
+
278
+ def plot_alatirchart(sorted_cosine_scores_models):
279
+ models = list(sorted_cosine_scores_models.keys())
280
+ tabs = st.tabs(models)
281
+ figs = {}
282
+ for model in models:
283
+ figs[model] = plot_piechart_helper(sorted_cosine_scores_models[model])
284
+
285
+ for index in range(len(tabs)):
286
+ with tabs[index]:
287
+ st.pyplot(figs[models[index]])
288
+
289
+
290
+ ### Text Search ###
291
+ st.sidebar.title("GloVe Twitter")
292
+ st.sidebar.markdown(
293
+ """
294
+ GloVe is an unsupervised learning algorithm for obtaining vector representations for words. Pretrained on
295
+ 2 billion tweets with vocabulary size of 1.2 million. Download from [Stanford NLP](http://nlp.stanford.edu/data/glove.twitter.27B.zip).
296
+
297
+ Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. *GloVe: Global Vectors for Word Representation*.
298
+ """
299
+ )
300
+
301
+ model_type = st.sidebar.selectbox("Choose the model", ("25d", "50d"), index=1)
302
+
303
+
304
+ st.title("Search Based Retrieval Demo")
305
+ st.subheader(
306
+ "Pass in space separated categories you want this search demo to be about."
307
+ )
308
+ # st.selectbox(label="Pick the categories you want this search demo to be about...",
309
+ # options=("Flowers Colors Cars Weather Food", "Chocolate Milk", "Anger Joy Sad Frustration Worry Happiness", "Positive Negative"),
310
+ # key="categories"
311
+ # )
312
+ st.text_input(
313
+ label="Categories", key="categories", value="Flowers Colors Cars Weather Food"
314
+ )
315
+ print(st.session_state["categories"])
316
+ print(type(st.session_state["categories"]))
317
+ # print("Categories = ", categories)
318
+ # st.session_state.categories = categories
319
+
320
+ st.subheader("Pass in an input word or even a sentence")
321
+ text_search = st.text_input(
322
+ label="Input your sentence",
323
+ key="text_search",
324
+ value="Roses are red, trucks are blue, and Seattle is grey right now",
325
+ )
326
+ # st.session_state.text_search = text_search
327
+
328
+ # Download glove embeddings if it doesn't exist
329
+ embeddings_path = "embeddings_" + str(model_type) + "_temp.npy"
330
+ word_index_dict_path = "word_index_dict_" + str(model_type) + "_temp.pkl"
331
+ if not os.path.isfile(embeddings_path) or not os.path.isfile(word_index_dict_path):
332
+ print("Model type = ", model_type)
333
+ glove_path = "Data/glove_" + str(model_type) + ".pkl"
334
+ print("glove_path = ", glove_path)
335
+
336
+ # Download embeddings from google drive
337
+ with st.spinner("Downloading glove embeddings..."):
338
+ download_glove_embeddings_gdrive(model_type)
339
+
340
+
341
+ # Load glove embeddings
342
+ word_index_dict, embeddings = load_glove_embeddings_gdrive(model_type)
343
+
344
+
345
+ # Find closest word to an input word
346
+ if st.session_state.text_search:
347
+ # Glove embeddings
348
+ print("Glove Embedding")
349
+ embeddings_metadata = {
350
+ "embedding_model": "glove",
351
+ "word_index_dict": word_index_dict,
352
+ "embeddings": embeddings,
353
+ "model_type": model_type,
354
+ }
355
+ with st.spinner("Obtaining Cosine similarity for Glove..."):
356
+ sorted_cosine_sim_glove = get_sorted_cosine_similarity(
357
+ st.session_state.text_search, embeddings_metadata
358
+ )
359
+
360
+ # Sentence transformer embeddings
361
+ print("Sentence Transformer Embedding")
362
+ embeddings_metadata = {"embedding_model": "transformers", "model_name": ""}
363
+ with st.spinner("Obtaining Cosine similarity for 384d sentence transformer..."):
364
+ sorted_cosine_sim_transformer = get_sorted_cosine_similarity(
365
+ st.session_state.text_search, embeddings_metadata
366
+ )
367
+
368
+ # Results and Plot Pie Chart for Glove
369
+ print("Categories are: ", st.session_state.categories)
370
+ st.subheader(
371
+ "Closest word I have between: "
372
+ + st.session_state.categories
373
+ + " as per different Embeddings"
374
+ )
375
+
376
+ print(sorted_cosine_sim_glove)
377
+ print(sorted_cosine_sim_transformer)
378
+ # print(sorted_distilbert)
379
+ # Altair Chart for all models
380
+ plot_alatirchart(
381
+ {
382
+ "glove_" + str(model_type): sorted_cosine_sim_glove,
383
+ "sentence_transformer_384": sorted_cosine_sim_transformer,
384
+ }
385
+ )
386
+ # "distilbert_512": sorted_distilbert})
387
+
388
+ st.write("")
389
+ st.write(
390
+ "Demo developed by [Dr. Karthik Mohan](https://www.linkedin.com/in/karthik-mohan-72a4b323/)"
391
+ )
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ streamlit
2
+ gdown
3
+ sentence_transformers