devtrent commited on
Commit
f7a5664
1 Parent(s): 5cd1ac6

Asymmetric QA

Browse files
Files changed (2) hide show
  1. app.py +17 -10
  2. backend/config.py +1 -1
app.py CHANGED
@@ -13,16 +13,15 @@ st.markdown('''
13
 
14
  Hi! This is the demo for the [flax sentence embeddings](https://huggingface.co/flax-sentence-embeddings) created for the **Flax/JAX community week 🤗**. We are going to use three flax-sentence-embeddings models: a **distilroberta base**, a **mpnet base** and a **minilm-l6**. All were trained on all the dataset of the 1B+ train corpus with the v3 setup.
15
 
16
- ---
17
 
18
- **Instructions**: You can compare the similarity of a main text with other texts of your choice (in the sidebar). In the background, we'll create an embedding for each text, and then we'll use the cosine similarity function to calculate a similarity metric between our main sentence and the others.
 
 
 
19
 
20
  For more cool information on sentence embeddings, see the [sBert project](https://www.sbert.net/examples/applications/computing-embeddings/README.html).
21
-
22
- Please enjoy!!
23
  ''')
24
-
25
- if menu == "Sentence Similarity":
26
  select_models = st.multiselect("Choose models", options=list(MODELS_ID), default=list(MODELS_ID)[0])
27
 
28
  anchor = st.text_input(
@@ -45,7 +44,7 @@ if menu == "Sentence Similarity":
45
  results = {model: inference.text_similarity(anchor, inputs, model, MODELS_ID) for model in select_models}
46
  df_results = {model: results[model] for model in results}
47
 
48
- index = [f"{idx}:{input[:min(15, len(input))]}..." for idx, input in enumerate(inputs)]
49
  df_total = pd.DataFrame(index=index)
50
  for key, value in df_results.items():
51
  df_total[key] = list(value['score'].values)
@@ -55,11 +54,19 @@ if menu == "Sentence Similarity":
55
  st.write('Visualize the results of each model:')
56
  st.line_chart(df_total)
57
  elif menu == "Asymmetric QA":
 
 
 
 
 
 
 
 
58
  select_models = st.multiselect("Choose models", options=list(QA_MODELS_ID), default=list(QA_MODELS_ID)[0])
59
 
60
  anchor = st.text_input(
61
  'Please enter here the query you want to compare with given answers:',
62
- value="How many close friends do you have?"
63
  )
64
 
65
  n_texts = st.number_input(
@@ -69,7 +76,7 @@ elif menu == "Asymmetric QA":
69
 
70
  inputs = []
71
 
72
- defaults = ["I have 10.", "How many children do you have?", "I have 3 brothers."]
73
  for i in range(int(n_texts)):
74
  input = st.text_input(f'Answer {i + 1}:', value=defaults[i] if i < len(defaults) else "")
75
 
@@ -79,7 +86,7 @@ elif menu == "Asymmetric QA":
79
  results = {model: inference.text_similarity(anchor, inputs, model, QA_MODELS_ID) for model in select_models}
80
  df_results = {model: results[model] for model in results}
81
 
82
- index = [f"{idx}:{input[:min(15, len(input))]}..." for idx, input in enumerate(inputs)]
83
  df_total = pd.DataFrame(index=index)
84
  for key, value in df_results.items():
85
  df_total[key] = list(value['score'].values)
 
13
 
14
  Hi! This is the demo for the [flax sentence embeddings](https://huggingface.co/flax-sentence-embeddings) created for the **Flax/JAX community week 🤗**. We are going to use three flax-sentence-embeddings models: a **distilroberta base**, a **mpnet base** and a **minilm-l6**. All were trained on all the dataset of the 1B+ train corpus with the v3 setup.
15
 
16
+ ''')
17
 
18
+ if menu == "Sentence Similarity":
19
+ st.header('Sentence Similarity')
20
+ st.markdown('''
21
+ **Instructions**: You can compare the similarity of a main text with other texts of your choice. In the background, we'll create an embedding for each text, and then we'll use the cosine similarity function to calculate a similarity metric between our main sentence and the others.
22
 
23
  For more cool information on sentence embeddings, see the [sBert project](https://www.sbert.net/examples/applications/computing-embeddings/README.html).
 
 
24
  ''')
 
 
25
  select_models = st.multiselect("Choose models", options=list(MODELS_ID), default=list(MODELS_ID)[0])
26
 
27
  anchor = st.text_input(
 
44
  results = {model: inference.text_similarity(anchor, inputs, model, MODELS_ID) for model in select_models}
45
  df_results = {model: results[model] for model in results}
46
 
47
+ index = [f"{idx + 1}:{input[:min(15, len(input))]}..." for idx, input in enumerate(inputs)]
48
  df_total = pd.DataFrame(index=index)
49
  for key, value in df_results.items():
50
  df_total[key] = list(value['score'].values)
 
54
  st.write('Visualize the results of each model:')
55
  st.line_chart(df_total)
56
  elif menu == "Asymmetric QA":
57
+ st.header('Asymmetric QA')
58
+ st.markdown('''
59
+ **Instructions**: You can compare the Answer likeliness of a given Query with answer candidates of your choice. In the background, we'll create an embedding for each answers, and then we'll use the cosine similarity function to calculate a similarity metric between our query sentence and the others.
60
+ `mpnet_asymmetric_qa` model works best for hard negative answers or distinguishing similar queries due to separate models applied for encoding questions and answers.
61
+
62
+ For more cool information on sentence embeddings, see the [sBert project](https://www.sbert.net/examples/applications/computing-embeddings/README.html).
63
+ ''')
64
+
65
  select_models = st.multiselect("Choose models", options=list(QA_MODELS_ID), default=list(QA_MODELS_ID)[0])
66
 
67
  anchor = st.text_input(
68
  'Please enter here the query you want to compare with given answers:',
69
+ value="What is the weather in Paris?"
70
  )
71
 
72
  n_texts = st.number_input(
 
76
 
77
  inputs = []
78
 
79
+ defaults = ["It is raining in Paris right now with 70 F temperature.", "What is the weather in Berlin?", "I have 3 brothers."]
80
  for i in range(int(n_texts)):
81
  input = st.text_input(f'Answer {i + 1}:', value=defaults[i] if i < len(defaults) else "")
82
 
 
86
  results = {model: inference.text_similarity(anchor, inputs, model, QA_MODELS_ID) for model in select_models}
87
  df_results = {model: results[model] for model in results}
88
 
89
+ index = [f"{idx + 1}:{input[:min(15, len(input))]}..." for idx, input in enumerate(inputs)]
90
  df_total = pd.DataFrame(index=index)
91
  for key, value in df_results.items():
92
  df_total[key] = list(value['score'].values)
backend/config.py CHANGED
@@ -3,8 +3,8 @@ MODELS_ID = dict(distilroberta = 'flax-sentence-embeddings/st-codesearch-distilr
3
  minilm_l6 = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L6')
4
 
5
  QA_MODELS_ID = dict(
6
- mpnet_qa = 'flax-sentence-embeddings/mpnet_stackexchange_v1',
7
  mpnet_asymmetric_qa = ['flax-sentence-embeddings/multi-QA_v1-mpnet-asymmetric-Q',
8
  'flax-sentence-embeddings/multi-QA_v1-mpnet-asymmetric-A'],
 
9
  distilbert_qa = 'flax-sentence-embeddings/multi-qa_v1-distilbert-cls_dot'
10
  )
 
3
  minilm_l6 = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L6')
4
 
5
  QA_MODELS_ID = dict(
 
6
  mpnet_asymmetric_qa = ['flax-sentence-embeddings/multi-QA_v1-mpnet-asymmetric-Q',
7
  'flax-sentence-embeddings/multi-QA_v1-mpnet-asymmetric-A'],
8
+ mpnet_qa='flax-sentence-embeddings/mpnet_stackexchange_v1',
9
  distilbert_qa = 'flax-sentence-embeddings/multi-qa_v1-distilbert-cls_dot'
10
  )