Trent committed on
Commit
5cd1ac6
1 Parent(s): 31f3439

Asymmetric QA

Browse files
Files changed (4) hide show
  1. app.py +38 -3
  2. backend/config.py +2 -4
  3. backend/inference.py +3 -2
  4. backend/utils.py +3 -4
app.py CHANGED
@@ -2,12 +2,12 @@ import streamlit as st
2
  import pandas as pd
3
 
4
  from backend import inference
5
- from backend.config import MODELS_ID
6
 
7
  st.title('Demo using Flax-Sentence-Tranformers')
8
 
9
  st.sidebar.title('Tasks')
10
- menu = st.sidebar.radio("", options=["Sentence Similarity", "Search", "Clustering"], index=0)
11
 
12
  st.markdown('''
13
 
@@ -42,7 +42,7 @@ if menu == "Sentence Similarity":
42
  inputs.append(input)
43
 
44
  if st.button('Tell me the similarity.'):
45
- results = {model: inference.text_similarity(anchor, inputs, model) for model in select_models}
46
  df_results = {model: results[model] for model in results}
47
 
48
  index = [f"{idx}:{input[:min(15, len(input))]}..." for idx, input in enumerate(inputs)]
@@ -54,6 +54,41 @@ if menu == "Sentence Similarity":
54
  st.write(df_total)
55
  st.write('Visualize the results of each model:')
56
  st.line_chart(df_total)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  elif menu == "Search":
58
  select_models = st.multiselect("Choose models", options=list(MODELS_ID), default=list(MODELS_ID)[0])
59
 
 
2
  import pandas as pd
3
 
4
  from backend import inference
5
+ from backend.config import MODELS_ID, QA_MODELS_ID
6
 
7
  st.title('Demo using Flax-Sentence-Tranformers')
8
 
9
  st.sidebar.title('Tasks')
10
+ menu = st.sidebar.radio("", options=["Sentence Similarity", "Asymmetric QA", "Search", "Clustering"], index=0)
11
 
12
  st.markdown('''
13
 
 
42
  inputs.append(input)
43
 
44
  if st.button('Tell me the similarity.'):
45
+ results = {model: inference.text_similarity(anchor, inputs, model, MODELS_ID) for model in select_models}
46
  df_results = {model: results[model] for model in results}
47
 
48
  index = [f"{idx}:{input[:min(15, len(input))]}..." for idx, input in enumerate(inputs)]
 
54
  st.write(df_total)
55
  st.write('Visualize the results of each model:')
56
  st.line_chart(df_total)
57
+ elif menu == "Asymmetric QA":
58
+ select_models = st.multiselect("Choose models", options=list(QA_MODELS_ID), default=list(QA_MODELS_ID)[0])
59
+
60
+ anchor = st.text_input(
61
+ 'Please enter here the query you want to compare with given answers:',
62
+ value="How many close friends do you have?"
63
+ )
64
+
65
+ n_texts = st.number_input(
66
+ f'''How many answers you want to compare with: '{anchor}'?''',
67
+ value=3,
68
+ min_value=2)
69
+
70
+ inputs = []
71
+
72
+ defaults = ["I have 10.", "How many children do you have?", "I have 3 brothers."]
73
+ for i in range(int(n_texts)):
74
+ input = st.text_input(f'Answer {i + 1}:', value=defaults[i] if i < len(defaults) else "")
75
+
76
+ inputs.append(input)
77
+
78
+ if st.button('Tell me Answer likeliness.'):
79
+ results = {model: inference.text_similarity(anchor, inputs, model, QA_MODELS_ID) for model in select_models}
80
+ df_results = {model: results[model] for model in results}
81
+
82
+ index = [f"{idx}:{input[:min(15, len(input))]}..." for idx, input in enumerate(inputs)]
83
+ df_total = pd.DataFrame(index=index)
84
+ for key, value in df_results.items():
85
+ df_total[key] = list(value['score'].values)
86
+
87
+ st.write('Here are the results for selected models:')
88
+ st.write(df_total)
89
+ st.write('Visualize the results of each model:')
90
+ st.line_chart(df_total)
91
+
92
  elif menu == "Search":
93
  select_models = st.multiselect("Choose models", options=list(MODELS_ID), default=list(MODELS_ID)[0])
94
 
backend/config.py CHANGED
@@ -1,12 +1,10 @@
1
  MODELS_ID = dict(distilroberta = 'flax-sentence-embeddings/st-codesearch-distilroberta-base',
2
  mpnet = 'flax-sentence-embeddings/all_datasets_v3_mpnet-base',
3
- mpnet_qa = 'flax-sentence-embeddings/mpnet_stackexchange_v1',
4
- mpnet_asymmetric_qa = ['flax-sentence-embeddings/multi-QA_v1-mpnet-asymmetric-Q',
5
- 'flax-sentence-embeddings/multi-QA_v1-mpnet-asymmetric-A'],
6
  minilm_l6 = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L6')
7
 
8
  QA_MODELS_ID = dict(
9
  mpnet_qa = 'flax-sentence-embeddings/mpnet_stackexchange_v1',
10
  mpnet_asymmetric_qa = ['flax-sentence-embeddings/multi-QA_v1-mpnet-asymmetric-Q',
11
- 'flax-sentence-embeddings/multi-QA_v1-mpnet-asymmetric-A']
 
12
  )
 
1
  MODELS_ID = dict(distilroberta = 'flax-sentence-embeddings/st-codesearch-distilroberta-base',
2
  mpnet = 'flax-sentence-embeddings/all_datasets_v3_mpnet-base',
 
 
 
3
  minilm_l6 = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L6')
4
 
5
  QA_MODELS_ID = dict(
6
  mpnet_qa = 'flax-sentence-embeddings/mpnet_stackexchange_v1',
7
  mpnet_asymmetric_qa = ['flax-sentence-embeddings/multi-QA_v1-mpnet-asymmetric-Q',
8
+ 'flax-sentence-embeddings/multi-QA_v1-mpnet-asymmetric-A'],
9
+ distilbert_qa = 'flax-sentence-embeddings/multi-qa_v1-distilbert-cls_dot'
10
  )
backend/inference.py CHANGED
@@ -4,6 +4,7 @@ import jax.numpy as jnp
4
  from typing import List, Union
5
 
6
  # Defining cosine similarity using flax.
 
7
  from backend.utils import load_model
8
 
9
 
@@ -12,9 +13,9 @@ def cos_sim(a, b):
12
 
13
 
14
  # We get similarity between embeddings.
15
- def text_similarity(anchor: str, inputs: List[str], model_name: str):
16
  print(model_name)
17
- model = load_model(model_name)
18
 
19
  # Creating embeddings
20
  if hasattr(model, 'encode'):
 
4
  from typing import List, Union
5
 
6
  # Defining cosine similarity using flax.
7
+ from backend.config import MODELS_ID
8
  from backend.utils import load_model
9
 
10
 
 
13
 
14
 
15
  # We get similarity between embeddings.
16
+ def text_similarity(anchor: str, inputs: List[str], model_name: str, model_dict: dict):
17
  print(model_name)
18
+ model = load_model(model_name, model_dict)
19
 
20
  # Creating embeddings
21
  if hasattr(model, 'encode'):
backend/utils.py CHANGED
@@ -1,13 +1,12 @@
1
  import streamlit as st
2
  from sentence_transformers import SentenceTransformer
3
- from .config import MODELS_ID
4
 
5
 
6
  @st.cache(allow_output_mutation=True)
7
- def load_model(model_name):
8
- assert model_name in MODELS_ID.keys()
9
  # Lazy downloading
10
- model_ids = MODELS_ID[model_name]
11
  if type(model_ids) == str:
12
  output = SentenceTransformer(model_ids)
13
  elif hasattr(model_ids, '__iter__'):
 
1
  import streamlit as st
2
  from sentence_transformers import SentenceTransformer
 
3
 
4
 
5
  @st.cache(allow_output_mutation=True)
6
+ def load_model(model_name, model_dict):
7
+ assert model_name in model_dict.keys()
8
  # Lazy downloading
9
+ model_ids = model_dict[model_name]
10
  if type(model_ids) == str:
11
  output = SentenceTransformer(model_ids)
12
  elif hasattr(model_ids, '__iter__'):