suzhoum committed on
Commit
ca75f47
1 Parent(s): 5e17fcf
Files changed (1) hide show
  1. app.py +54 -22
app.py CHANGED
@@ -1,29 +1,59 @@
1
  import gradio as gr
2
  import ir_datasets
3
  import pandas as pd
 
4
 
5
  from autogluon.multimodal import MultiModalPredictor
6
 
7
 
8
def text_embedding(query: str):
    """Return the embedding for a single query string.

    A feature-extraction ``MultiModalPredictor`` backed by the MiniLM
    sentence-transformer checkpoint is constructed on every call; the
    returned table is keyed ``"0"`` — the positional key
    ``extract_embedding`` uses when given a plain list.
    """
    checkpoint = "sentence-transformers/all-MiniLM-L6-v2"
    extractor = MultiModalPredictor(
        pipeline="feature_extraction",
        hyperparameters={"model.hf_text.checkpoint_name": checkpoint},
    )
    embeddings = extractor.extract_embedding([query])
    return embeddings["0"]
 
 
 
 
 
 
 
22
 
 
 
 
 
 
23
 
24
  def main():
25
  with gr.Blocks(title="OpenSearch Demo") as demo:
26
- gr.Markdown("# Text Embedding for Search Queries")
27
  gr.Markdown("Ask an open question!")
28
  with gr.Row():
29
  inp_single = gr.Textbox(show_label=False)
@@ -31,23 +61,25 @@ def main():
31
  btn_single = gr.Button("Generate Embedding")
32
  with gr.Row():
33
  out_single = gr.DataFrame(label="Embedding", show_label=True)
34
- gr.Markdown("You can select one of the sample datasets for batch inference")
35
- with gr.Row():
36
- with gr.Column():
37
- btn_fiqa = gr.Button("fiqa")
38
- with gr.Column():
39
- btn_faiss = gr.Button("faiss")
40
  with gr.Row():
41
- out_batch = gr.DataFrame(label="Embedding", show_label=True)
42
- gr.Markdown("You can also try out our batch inference by uploading a file")
43
  with gr.Row():
44
- out_batch = gr.File(interactive=True)
 
45
  with gr.Row():
46
- btn_file = gr.Button("Generate Embedding")
47
-
48
- btn_single.click(fn=text_embedding, inputs=inp_single, outputs=out_single)
49
- btn_file.click(fn=text_embedding, inputs=inp_single, outputs=out_single)
50
-
 
 
 
 
 
 
 
51
  demo.launch()
52
 
53
 
 
1
  import gradio as gr
2
  import ir_datasets
3
  import pandas as pd
4
+ import numpy as np
5
 
6
  from autogluon.multimodal import MultiModalPredictor
7
 
8
 
9
# Module-level state shared between the gradio callbacks.
# NOTE(review): the names appear swapped — the batch handler stores the
# corpus (document) embeddings into `query_embedding`, while the
# single-query handler stores the query embedding into
# `document_embedding`; confirm before renaming.
query_embedding = None      # embeddings produced by text_embedding_batch
document_embedding = None   # embedding produced by text_embedding_single
docs_df = None              # sampled corpus DataFrame (has a "text" column)
13
def text_embedding_batch():
    """Embed a small random sample of the BEIR/FiQA dev corpus.

    Side effects: stores the sampled corpus in the module global
    ``docs_df`` and its embeddings in ``query_embedding`` (the name is
    historical — it actually holds *document* embeddings; rank_document
    reads it as the document side).

    Returns the embedding table (keyed by the "text" column) for display
    in the gradio DataFrame output.
    """
    # BUG FIX: these were plain locals before, so the module globals
    # stayed None and rank_document() had nothing to rank.
    global query_embedding, docs_df

    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    dataset = ir_datasets.load("beir/fiqa/dev")
    # Tiny sample (0.01%) keeps demo latency reasonable.
    docs_df = pd.DataFrame(dataset.docs_iter()).set_index("doc_id").sample(frac=0.0001)
    predictor = MultiModalPredictor(
        pipeline="feature_extraction",
        hyperparameters={
            "model.hf_text.checkpoint_name": model_name
        }
    )
    embedding = predictor.extract_embedding(docs_df)
    query_embedding = embedding["text"]
    return query_embedding
26
+
27
+
28
def text_embedding_single(query: str):
    """Embed the single query string typed into the UI.

    Side effect: stores the result in the module global
    ``document_embedding`` (historically misnamed — it holds the *query*
    embedding, and rank_document reads it as the query side).

    Returns the 1 x dim embedding table keyed ``"0"`` (the positional
    key ``extract_embedding`` uses for plain-list input).
    """
    # BUG FIX: was a plain local before, so the module global stayed
    # None and rank_document() always crashed / had no query.
    global document_embedding

    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    predictor = MultiModalPredictor(
        pipeline="feature_extraction",
        hyperparameters={
            "model.hf_text.checkpoint_name": model_name
        }
    )
    embedding = predictor.extract_embedding([query])
    document_embedding = embedding["0"]
    return document_embedding
+
40
+
41
def rank_document(doc_embeddings=None, query_emb=None, documents=None):
    """Rank the sampled corpus by cosine similarity to the query.

    All parameters are optional; gradio invokes this callback with no
    arguments, in which case the module globals populated by the two
    embedding callbacks are used:
        doc_embeddings: (n, dim) array of document embeddings
                        (global ``query_embedding`` — misnamed, see note
                        at its definition).
        query_emb:      (1, dim) array holding the query embedding
                        (global ``document_embedding`` — misnamed).
        documents:      DataFrame with a "text" column aligned row-wise
                        with ``doc_embeddings``.

    Returns a list of the top-3 document texts, most similar first; an
    empty list if any input is still missing.
    """
    if doc_embeddings is None:
        doc_embeddings = query_embedding   # set by text_embedding_batch
    if query_emb is None:
        query_emb = document_embedding     # set by text_embedding_single
    if documents is None:
        documents = docs_df
    # Guard: the old code crashed with a TypeError when the embedding
    # buttons had not been clicked yet.
    if doc_embeddings is None or query_emb is None or documents is None:
        return []

    doc_embeddings = np.asarray(doc_embeddings, dtype=float)
    query_emb = np.asarray(query_emb, dtype=float)

    # Cosine similarity: L2-normalize both sides, then dot each document
    # against the (single) query vector.
    d_norm = doc_embeddings / np.linalg.norm(doc_embeddings, axis=-1, keepdims=True)
    q_norm = query_emb / np.linalg.norm(query_emb, axis=-1, keepdims=True)
    # BUG FIX: the old code computed d_norm.dot(q_norm[0]) with the
    # operands swapped, producing a single score (query vs. first
    # document) instead of one score per document — argsort over a
    # length-1 array made the "ranking" meaningless.
    scores = d_norm.dot(q_norm[0])

    # Top 3, to match the UI text ("pick the top 3"); the old code
    # sliced [:2]. Debug print() calls removed.
    return [documents["text"].iloc[idx] for idx in np.argsort(-scores)[:3]]
52
+
53
 
54
  def main():
55
  with gr.Blocks(title="OpenSearch Demo") as demo:
56
+ gr.Markdown("# Semantic Search with Autogluon")
57
  gr.Markdown("Ask an open question!")
58
  with gr.Row():
59
  inp_single = gr.Textbox(show_label=False)
 
61
  btn_single = gr.Button("Generate Embedding")
62
  with gr.Row():
63
  out_single = gr.DataFrame(label="Embedding", show_label=True)
64
+ gr.Markdown("You can select one of the sample datasets for document embedding")
 
 
 
 
 
65
  with gr.Row():
66
+ btn_fiqa = gr.Button("fiqa")
 
67
  with gr.Row():
68
+ out_batch = gr.DataFrame(label="Sample Embeddings", show_label=True, row_count=5)
69
+ gr.Markdown("Now rank the documents and pick the top 3 most relevant from the dataset")
70
  with gr.Row():
71
+ btn_rank = gr.Button("Rank documents")
72
+ with gr.Row():
73
+ out_rank = gr.DataFrame(label="Top ranked documents", show_label=True, row_count=5)
74
+ # with gr.Row():
75
+ # out_batch = gr.File(interactive=True)
76
+ # with gr.Row():
77
+ # btn_file = gr.Button("Generate Embedding")
78
+
79
+ btn_single.click(fn=text_embedding_single, inputs=inp_single, outputs=out_single)
80
+ btn_fiqa.click(fn=text_embedding_batch, inputs=None, outputs=out_batch)
81
+ btn_rank.click(fn=rank_document, inputs=None, outputs=out_rank)
82
+ # btn_file.click(fn=text_embedding_batch, inputs=inp_single, outputs=out_single)
83
  demo.launch()
84
 
85