EdwardXu commited on
Commit
0e7c08f
1 Parent(s): 8ed8a0a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -101
app.py CHANGED
@@ -196,29 +196,6 @@ class Search:
196
  similarity_scores[choice] = similarity
197
  return similarity_scores
198
 
199
- """## 3. Word Arithmetic
200
-
201
- Let's test your embeddings. Answer the question below through the search functionality you implemented above
202
- """
203
-
204
- embeddings_model = Embeddings()
205
- search_using_cos = Search(embeddings_model)
206
-
207
- word_index_dict, embeddings = embeddings_model.load_glove_embeddings(50)
208
-
209
- current_embedding = embeddings_model.embeddings_preprocess( word_index_dict, ["king", "woman"], ["man"], embeddings)
210
-
211
- closest_word = search_using_cos.find_closest_words(current_embedding, ["girl", "queen", "princess", "daughter", "mother"], word_index_dict, embeddings )
212
-
213
- print("'King - Man + Woman':", closest_word)
214
-
215
-
216
- word_index_dict, embeddings = embeddings_model.load_glove_embeddings(50)
217
-
218
-
219
- closest_word = search_using_cos.find_word_as( ("tesla", "car"), "apple", ["fruit", "vegetable", "gas"], word_index_dict, embeddings)
220
-
221
- print("'Tesla:Car as Apple:?': ", closest_word)
222
 
223
  """## 4. Plots
224
 
@@ -238,43 +215,9 @@ def plot_pie_chart(category_similarity_scores):
238
  ax.axis('equal') # Equal aspect ratio ensures the pie chart is circular.
239
  plt.show()
240
 
241
- word_index_dict, embeddings = embeddings_model.load_glove_embeddings(50)
242
-
243
- # Find the word closest to the vector resulting from "king" - "man" + "woman"
244
- current_embedding = embeddings_model.embeddings_preprocess(word_index_dict, ["king", "woman"], ["man"], embeddings)
245
-
246
- # Calculate similarity scores for a set of words and plot them
247
- sim_scores = search_using_cos.find_similarity_scores(current_embedding, ["girl", "queen", "princess", "daughter", "mother"], word_index_dict, embeddings)
248
- plot_pie_chart(sim_scores)
249
-
250
- """## 5. Test
251
-
252
- Test your pie chart against some of the examples in the demo listed here:
253
-
254
- https://categorysearch.streamlit.app or
255
- https://searchdemo.streamlit.app
256
 
257
- a) Do the results make sense?
258
- b) Which embedding gives more meaningful results?
259
 
260
- """
261
-
262
- input_sentence = "Roses are red, trucks are blue, and Seattle is grey right now"
263
- category_names = ["Flowers", "Colors", "Cars", "Weather", "Food"]
264
-
265
- embeddings_model = Embeddings()
266
- word_index_dict, embeddings = embeddings_model.load_glove_embeddings(50)
267
- categories_embedding = {category: embeddings_model.get_sentence_transformer_embedding(category) for category in category_names}
268
-
269
- search_instance = Search(embeddings_model)
270
- category_similarity_scores = search_instance.get_topK_similar_categories(input_sentence, categories_embedding)
271
 
272
- plot_pie_chart(category_similarity_scores) # Plot and see
273
-
274
- """## 6. Bonus (if time permits)!
275
- Create a simple streamlit or equivalent webapp like the link in 5.
276
- This is also part of your Mini-Project 1!
277
- """
278
 
279
  def plot_piechart(sorted_cosine_scores_items):
280
  sorted_cosine_scores = np.array([
@@ -365,67 +308,41 @@ def plot_alatirchart(sorted_cosine_scores_models):
365
 
366
 
367
  ### Text Search ###
368
- st.sidebar.title("GloVe Twitter")
369
- st.sidebar.markdown(
370
- """
371
- GloVe is an unsupervised learning algorithm for obtaining vector representations for words. Pretrained on
372
- 2 billion tweets with vocabulary size of 1.2 million. Download from [Stanford NLP](http://nlp.stanford.edu/data/glove.twitter.27B.zip).
373
 
374
- Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. *GloVe: Global Vectors for Word Representation*.
375
- """
376
- )
377
 
378
- # initialize Session State variable
379
  if 'categories' not in st.session_state:
380
  st.session_state['categories'] = "Flowers Colors Cars Weather Food"
381
  if 'text_search' not in st.session_state:
382
  st.session_state['text_search'] = "Roses are red, trucks are blue, and Seattle is grey right now"
383
 
384
- model_type = st.sidebar.selectbox("Choose the model", ("25d", "50d"), index=1)
385
 
 
386
 
387
- st.title("In Class practice 1 demo")
388
  st.subheader(
389
  "Pass in space separated categories you want this search demo to be about."
390
  )
391
- # st.selectbox(label="Pick the categories you want this search demo to be about...",
392
- # options=("Flowers Colors Cars Weather Food", "Chocolate Milk", "Anger Joy Sad Frustration Worry Happiness", "Positive Negative"),
393
- # key="categories"
394
- # )
395
-
396
 
397
  # categories of user input
398
- categories = st.text_input(
399
  label="Categories", value=st.session_state.categories
400
  )
401
 
402
- st.session_state.categories = categories.split(" ")
403
 
404
  print(st.session_state.get("categories"))
 
405
  print(type(st.session_state.get("categories")))
406
- # print("Categories = ", categories)
407
- # st.session_state.categories = categories
408
 
409
  st.subheader("Pass in an input word or even a sentence")
410
- text_search = st.text_input(
411
  label="Input your sentence",
412
- st.session_state.text_search,
413
  )
414
 
415
- st.session_state.text_search = text_search
416
-
417
- # Download glove embeddings if it doesn't exist
418
- embeddings_path = "embeddings_" + str(model_type) + "_temp.npy"
419
- word_index_dict_path = "word_index_dict_" + str(model_type) + "_temp.pkl"
420
- if not os.path.isfile(embeddings_path) or not os.path.isfile(word_index_dict_path):
421
- print("Model type = ", model_type)
422
- glove_path = "Data/glove_" + str(model_type) + ".pkl"
423
- print("glove_path = ", glove_path)
424
-
425
- # Download embeddings from google drive
426
- with st.spinner("Downloading glove embeddings..."):
427
- download_glove_embeddings_gdrive(model_type)
428
-
429
 
430
  # Load glove embeddings
431
  word_index_dict, embeddings = embeddings_model.load_glove_embeddings(model_type)
@@ -436,8 +353,8 @@ category_embeddings = {category: embeddings_model.get_sentence_transformer_embed
436
  search_using_cos = Search(embeddings_model)
437
 
438
  # Find closest word to an input word
439
- if st.session_state.get("text_search"):
440
- # sentence transformer Embedding
441
  print("sentence transformer Embedding")
442
  embeddings_metadata = {
443
  "word_index_dict": word_index_dict,
@@ -445,18 +362,16 @@ if st.session_state.get("text_search"):
445
  "model_type": model_type,
446
  "text_search": st.session_state.text_search
447
  }
448
- with st.spinner("Obtaining Cosine similarity ..."):
449
  sorted_cosine_sim_transformer = search_using_cos.get_topK_similar_categories(
450
  st.session_state.text_search, category_embeddings
451
  )
452
 
453
-
454
-
455
  # Results and Plot Pie Chart for Glove
456
  print("Categories are: ", st.session_state.categories)
457
  st.subheader(
458
  "Closest word I have between: "
459
- + st.session_state.categories
460
  + " as per different Embeddings"
461
  )
462
 
@@ -466,13 +381,13 @@ if st.session_state.get("text_search"):
466
 
467
  st.write(
468
  f"Closest category using sentence transformer embeddings : {list(sorted_cosine_sim_transformer.keys())[0]}")
469
-
470
  plot_alatirchart(
471
  {
472
  "sentence_transformer_384": sorted_cosine_sim_transformer,
473
  }
474
  )
475
-
476
 
477
  st.write("")
478
  st.write(
 
196
  similarity_scores[choice] = similarity
197
  return similarity_scores
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
  """## 4. Plots
201
 
 
215
  ax.axis('equal') # Equal aspect ratio ensures the pie chart is circular.
216
  plt.show()
217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
 
 
219
 
 
 
 
 
 
 
 
 
 
 
 
220
 
 
 
 
 
 
 
221
 
222
  def plot_piechart(sorted_cosine_scores_items):
223
  sorted_cosine_scores = np.array([
 
308
 
309
 
310
  ### Text Search ###
311
+ st.sidebar.title("sentence transformer")
 
 
 
 
312
 
 
 
 
313
 
 
314
  if 'categories' not in st.session_state:
315
  st.session_state['categories'] = "Flowers Colors Cars Weather Food"
316
  if 'text_search' not in st.session_state:
317
  st.session_state['text_search'] = "Roses are red, trucks are blue, and Seattle is grey right now"
318
 
319
+ embeddings_model = Embeddings()
320
 
321
+ model_type = st.sidebar(("50d"), index=1)
322
 
323
+ st.title("in in-class coding practice1 Demo")
324
  st.subheader(
325
  "Pass in space separated categories you want this search demo to be about."
326
  )
 
 
 
 
 
327
 
328
  # categories of user input
329
+ user_categories = st.text_input(
330
  label="Categories", value=st.session_state.categories
331
  )
332
 
333
+ st.session_state.categories = user_categories.split(" ")
334
 
335
  print(st.session_state.get("categories"))
336
+
337
  print(type(st.session_state.get("categories")))
 
 
338
 
339
  st.subheader("Pass in an input word or even a sentence")
340
+ user_text_search = st.text_input(
341
  label="Input your sentence",
342
+ value=st.session_state.text_search,
343
  )
344
 
345
+ st.session_state.text_search = user_text_search
 
 
 
 
 
 
 
 
 
 
 
 
 
346
 
347
  # Load glove embeddings
348
  word_index_dict, embeddings = embeddings_model.load_glove_embeddings(model_type)
 
353
  search_using_cos = Search(embeddings_model)
354
 
355
  # Find closest word to an input word
356
+ if st.session_state.text_search:
357
+ # sentence transformer embeddings
358
  print("sentence transformer Embedding")
359
  embeddings_metadata = {
360
  "word_index_dict": word_index_dict,
 
362
  "model_type": model_type,
363
  "text_search": st.session_state.text_search
364
  }
365
+ with st.spinner("Obtaining Cosine similarity for Glove..."):
366
  sorted_cosine_sim_transformer = search_using_cos.get_topK_similar_categories(
367
  st.session_state.text_search, category_embeddings
368
  )
369
 
 
 
370
  # Results and Plot Pie Chart for Glove
371
  print("Categories are: ", st.session_state.categories)
372
  st.subheader(
373
  "Closest word I have between: "
374
+ + " ".join(st.session_state.categories)
375
  + " as per different Embeddings"
376
  )
377
 
 
381
 
382
  st.write(
383
  f"Closest category using sentence transformer embeddings : {list(sorted_cosine_sim_transformer.keys())[0]}")
384
+
385
  plot_alatirchart(
386
  {
387
  "sentence_transformer_384": sorted_cosine_sim_transformer,
388
  }
389
  )
390
+
391
 
392
  st.write("")
393
  st.write(