nickmuchi commited on
Commit
9df8524
1 Parent(s): a0e6621

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -18,6 +18,8 @@ model = SentenceTransformer('all-mpnet-base-v2')
18
  flix_ds = load_dataset("nickmuchi/netflix-shows-mpnet-embeddings", use_auth_token=True)
19
  dataset_embeddings = torch.from_numpy(flix_ds["train"].to_pandas().to_numpy()).to(torch.float)
20
 
 
 
21
  #load cross-encoder for reranking
22
  cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
23
 
@@ -35,7 +37,7 @@ def semantic_search(query,embeddings,top_k):
35
  '''Encode query and check similarity with embeddings'''
36
 
37
  question_embedding = model.encode(query, convert_to_tensor=True).cpu()
38
- hits = util.semantic_search(question_embedding, embeddings, top_k=top_k)
39
  hits = hits[0]
40
 
41
  ##### Re-Ranking #####
@@ -113,7 +115,7 @@ with demo:
113
  sem_but = gr.Button('Search')
114
 
115
 
116
- sem_but.click(semantic_search,inputs=[query,dataset_embeddings,top_k],outputs=[bi_output,cross_output],queue=True)
117
 
118
  gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=nickmuchi-netflix-shows-semantic-search)")
119
 
 
18
  flix_ds = load_dataset("nickmuchi/netflix-shows-mpnet-embeddings", use_auth_token=True)
19
  dataset_embeddings = torch.from_numpy(flix_ds["train"].to_pandas().to_numpy()).to(torch.float)
20
 
21
+ embed = {'embedding': dataset_embeddings}
22
+
23
  #load cross-encoder for reranking
24
  cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
25
 
 
37
  '''Encode query and check similarity with embeddings'''
38
 
39
  question_embedding = model.encode(query, convert_to_tensor=True).cpu()
40
+ hits = util.semantic_search(question_embedding, embeddings['embedding'], top_k=top_k)
41
  hits = hits[0]
42
 
43
  ##### Re-Ranking #####
 
115
  sem_but = gr.Button('Search')
116
 
117
 
118
+ sem_but.click(semantic_search,inputs=[query,embed,top_k],outputs=[bi_output,cross_output],queue=True)
119
 
120
  gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=nickmuchi-netflix-shows-semantic-search)")
121