EdwardXu committed on
Commit
8ed8a0a
1 Parent(s): 5dfb73d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +202 -0
app.py CHANGED
@@ -276,3 +276,205 @@ Create a simple streamlit or equivalent webapp like the link in 5.
276
  This is also part of your Mini-Project 1!
277
  """
278
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  This is also part of your Mini-Project 1!
277
  """
278
 
279
def plot_piechart(sorted_cosine_scores_items):
    """Draw a single pie chart of category scores and render it with Streamlit.

    Args:
        sorted_cosine_scores_items: Collection indexable by ``0..len-1`` whose
            entries are ``(category_index, score)`` pairs. ``category_index``
            is a position in the category list taken from
            ``st.session_state.categories``.
    """
    stored_categories = st.session_state.categories
    # The session value may hold either the raw space-separated string or an
    # already-split list; calling ``.split`` on a list would raise.
    if isinstance(stored_categories, str):
        categories = stored_categories.split(" ")
    else:
        categories = list(stored_categories)

    # Index positionally (not by iteration) so the lookup behaves the same for
    # any container that supports ``items[0] .. items[len - 1]``.
    pairs = [
        sorted_cosine_scores_items[position]
        for position in range(len(sorted_cosine_scores_items))
    ]
    sorted_cosine_scores = np.array([pair[1] for pair in pairs])
    categories_sorted = [categories[pair[0]] for pair in pairs]

    fig, ax = plt.subplots()
    ax.pie(sorted_cosine_scores, labels=categories_sorted, autopct="%1.1f%%")
    st.pyplot(fig)
293
+
294
+
295
def plot_piechart_helper(sorted_cosine_scores_items):
    """Build (but do not render) a small pie chart of category scores.

    The first wedge is always pulled out of the pie; one further wedge is
    pulled out slightly when there are three or more categories.

    Args:
        sorted_cosine_scores_items: Collection indexable by ``0..len-1`` whose
            entries are ``(category_index, score)`` pairs. ``category_index``
            is a position in the category list taken from
            ``st.session_state.categories``.

    Returns:
        The matplotlib figure holding the pie chart. When no items are given
        the figure is returned without any wedges.
    """
    stored_categories = st.session_state.categories
    # The session value may hold either the raw space-separated string or an
    # already-split list; calling ``.split`` on a list would raise.
    if isinstance(stored_categories, str):
        categories = stored_categories.split(" ")
    else:
        categories = list(stored_categories)

    # Index positionally (not by iteration) so the lookup behaves the same for
    # any container that supports ``items[0] .. items[len - 1]``.
    pairs = [
        sorted_cosine_scores_items[position]
        for position in range(len(sorted_cosine_scores_items))
    ]
    sorted_cosine_scores = np.array([pair[1] for pair in pairs])
    categories_sorted = [categories[pair[0]] for pair in pairs]

    fig, ax = plt.subplots(figsize=(3, 3))
    if not pairs:
        # Nothing to draw; indexing the explode array below would fail.
        return fig

    my_explode = np.zeros(len(categories_sorted))
    my_explode[0] = 0.2  # pull the first wedge out the furthest
    if len(categories_sorted) == 3:
        my_explode[1] = 0.1  # with exactly three wedges, offset the second
    elif len(categories_sorted) > 3:
        my_explode[2] = 0.05  # with more than three, offset the third slightly
    ax.pie(
        sorted_cosine_scores,
        labels=categories_sorted,
        autopct="%1.1f%%",
        explode=my_explode,
    )

    return fig
322
+
323
+
324
def plot_piecharts(sorted_cosine_scores_models):
    """Draw two stacked pie charts, one per model, and render them together.

    Only the two-model case is handled: when the mapping does not contain
    exactly two entries nothing is drawn.

    Args:
        sorted_cosine_scores_models: Mapping from a model name to that model's
            scored items. Each item collection is indexable by ``0..len-1``
            and holds ``(category_index, score)`` pairs, where
            ``category_index`` is a position in the category list taken from
            ``st.session_state.categories``.
    """
    stored_categories = st.session_state.categories
    # The session value may hold either the raw space-separated string or an
    # already-split list; calling ``.split`` on a list would raise.
    if isinstance(stored_categories, str):
        categories = stored_categories.split(" ")
    else:
        categories = list(stored_categories)

    scores_list = [
        sorted_cosine_scores_models[model] for model in sorted_cosine_scores_models
    ]
    if len(scores_list) != 2:
        return

    fig, axes = plt.subplots(2)
    for ax, items in zip(axes, scores_list):
        # Index positionally so any container supporting items[0..len-1] works.
        pairs = [items[position] for position in range(len(items))]
        sorted_scores = np.array([pair[1] for pair in pairs])
        categories_sorted = [categories[pair[0]] for pair in pairs]
        ax.pie(sorted_scores, labels=categories_sorted, autopct="%1.1f%%")

    st.pyplot(fig)
353
+
354
+
355
def plot_alatirchart(sorted_cosine_scores_models):
    """Render one pie chart per model, each inside its own Streamlit tab.

    Args:
        sorted_cosine_scores_models: Mapping from a model name (used as the
            tab label) to the scored items passed on to
            ``plot_piechart_helper``.
    """
    model_names = list(sorted_cosine_scores_models.keys())
    tab_containers = st.tabs(model_names)

    # Build every figure up front, then place each one in its matching tab.
    charts = {
        name: plot_piechart_helper(sorted_cosine_scores_models[name])
        for name in model_names
    }

    for container, name in zip(tab_containers, model_names):
        with container:
            st.pyplot(charts[name])
365
+
366
+
367
### Text Search ###
st.sidebar.title("GloVe Twitter")
st.sidebar.markdown(
    """
GloVe is an unsupervised learning algorithm for obtaining vector representations for words. Pretrained on
2 billion tweets with vocabulary size of 1.2 million. Download from [Stanford NLP](http://nlp.stanford.edu/data/glove.twitter.27B.zip).

Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. *GloVe: Global Vectors for Word Representation*.
"""
)

# Seed the session state with defaults the first time the page runs.
if "categories" not in st.session_state:
    st.session_state["categories"] = "Flowers Colors Cars Weather Food"
if "text_search" not in st.session_state:
    st.session_state["text_search"] = "Roses are red, trucks are blue, and Seattle is grey right now"

model_type = st.sidebar.selectbox("Choose the model", ("25d", "50d"), index=1)


st.title("In Class practice 1 demo")
st.subheader(
    "Pass in space separated categories you want this search demo to be about."
)

# Categories typed by the user, as one space-separated string.
categories = st.text_input(
    label="Categories", value=st.session_state.categories
)

# Keep the session value a *string*: it is fed back into the text input on the
# next rerun, concatenated into the results heading below, and split with
# ``.split(" ")`` by the plotting functions. Storing a list here breaks all
# three, so the split form lives in a local variable instead.
st.session_state.categories = categories
category_names = categories.split(" ")

print(st.session_state.get("categories"))

st.subheader("Pass in an input word or even a sentence")
# ``value`` must be passed by keyword: a positional argument after
# ``label=...`` is a SyntaxError.
text_search = st.text_input(
    label="Input your sentence",
    value=st.session_state.text_search,
)

st.session_state.text_search = text_search

# Download glove embeddings if they don't exist locally
embeddings_path = f"embeddings_{model_type}_temp.npy"
word_index_dict_path = f"word_index_dict_{model_type}_temp.pkl"
if not os.path.isfile(embeddings_path) or not os.path.isfile(word_index_dict_path):
    print("Model type = ", model_type)
    glove_path = f"Data/glove_{model_type}.pkl"
    print("glove_path = ", glove_path)

    # Download embeddings from google drive
    with st.spinner("Downloading glove embeddings..."):
        download_glove_embeddings_gdrive(model_type)


# Load glove embeddings
word_index_dict, embeddings = embeddings_model.load_glove_embeddings(model_type)

# One sentence-transformer embedding per category name.
category_embeddings = {
    category: embeddings_model.get_sentence_transformer_embedding(category)
    for category in category_names
}

search_using_cos = Search(embeddings_model)

# Find the closest category to the input text
if st.session_state.get("text_search"):
    # sentence transformer Embedding
    print("sentence transformer Embedding")
    # NOTE(review): this dict is built but not read anywhere in this section.
    embeddings_metadata = {
        "word_index_dict": word_index_dict,
        "embeddings": embeddings,
        "model_type": model_type,
        "text_search": st.session_state.text_search,
    }
    with st.spinner("Obtaining Cosine similarity ..."):
        sorted_cosine_sim_transformer = search_using_cos.get_topK_similar_categories(
            st.session_state.text_search, category_embeddings
        )

    # Results and pie chart
    print("Categories are: ", st.session_state.categories)
    st.subheader(
        "Closest word I have between: "
        + st.session_state.categories
        + " as per different Embeddings"
    )

    # NOTE(review): the result is read as a mapping here (``.keys()``), while
    # plot_piechart_helper indexes it by position and expects
    # (category_index, score) pairs -- confirm the return shape of
    # get_topK_similar_categories.
    closest_category = list(sorted_cosine_sim_transformer.keys())[0]
    print(sorted_cosine_sim_transformer)
    print(closest_category)

    st.write(
        f"Closest category using sentence transformer embeddings : {closest_category}")

    plot_alatirchart(
        {
            "sentence_transformer_384": sorted_cosine_sim_transformer,
        }
    )


# Footer, shown whether or not a sentence was entered.
st.write("")
st.write(
    "Demo developed by Edward Xu"
)