Trent committed on
Commit
75c3a89
1 Parent(s): fa5d8a4

Search function

Browse files
.gitattributes CHANGED
@@ -14,3 +14,5 @@
14
  *.pb filter=lfs diff=lfs merge=lfs -text
15
  *.pt filter=lfs diff=lfs merge=lfs -text
16
  *.pth filter=lfs diff=lfs merge=lfs -text
 
 
 
14
  *.pb filter=lfs diff=lfs merge=lfs -text
15
  *.pt filter=lfs diff=lfs merge=lfs -text
16
  *.pth filter=lfs diff=lfs merge=lfs -text
17
+ *.jsonl.gz filter=lfs diff=lfs merge=lfs -text
18
+ *.csv filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -2,12 +2,12 @@ import streamlit as st
2
  import pandas as pd
3
 
4
  from backend import inference
5
- from backend.config import MODELS_ID, QA_MODELS_ID
6
 
7
  st.title('Demo using Flax-Sentence-Tranformers')
8
 
9
  st.sidebar.title('Tasks')
10
- menu = st.sidebar.radio("", options=["Sentence Similarity", "Asymmetric QA", "Search", "Clustering"], index=0)
11
 
12
  st.markdown('''
13
 
@@ -71,7 +71,7 @@ For more cool information on sentence embeddings, see the [sBert project](https:
71
 
72
  n_texts = st.number_input(
73
  f'''How many answers you want to compare with: '{anchor}'?''',
74
- value=3,
75
  min_value=2)
76
 
77
  inputs = []
@@ -97,7 +97,25 @@ For more cool information on sentence embeddings, see the [sBert project](https:
97
  st.line_chart(df_total)
98
 
99
  elif menu == "Search":
100
- select_models = st.multiselect("Choose models", options=list(MODELS_ID), default=list(MODELS_ID)[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
- elif menu == "Clustering":
103
- select_models = st.multiselect("Choose models", options=list(MODELS_ID), default=list(MODELS_ID)[0])
 
 
2
  import pandas as pd
3
 
4
  from backend import inference
5
+ from backend.config import MODELS_ID, QA_MODELS_ID, SEARCH_MODELS_ID
6
 
7
  st.title('Demo using Flax-Sentence-Tranformers')
8
 
9
  st.sidebar.title('Tasks')
10
+ menu = st.sidebar.radio("", options=["Sentence Similarity", "Asymmetric QA", "Search"], index=0)
11
 
12
  st.markdown('''
13
 
 
71
 
72
  n_texts = st.number_input(
73
  f'''How many answers you want to compare with: '{anchor}'?''',
74
+ value=10,
75
  min_value=2)
76
 
77
  inputs = []
 
97
  st.line_chart(df_total)
98
 
99
  elif menu == "Search":
100
+ st.header('SEARCH')
101
+ st.markdown('''
102
+ **Instructions**: Make a query for anything related to "Python" and the model you choose will return you similar queries.
103
+
104
+ For more cool information on sentence embeddings, see the [sBert project](https://www.sbert.net/examples/applications/computing-embeddings/README.html).
105
+ ''')
106
+
107
+ select_models = st.multiselect("Choose models", options=list(SEARCH_MODELS_ID), default=list(SEARCH_MODELS_ID)[0])
108
+
109
+ anchor = st.text_input(
110
+ 'Please enter here your query about "Python", we will look for similar ones:',
111
+ value="How do I sort a dataframe by column"
112
+ )
113
+
114
+ n_texts = st.number_input(
115
+ f'''How many similar queries you want?''',
116
+ value=3,
117
+ min_value=2)
118
 
119
+ if st.button('Give me my search.'):
120
+ results = {model: inference.text_search(anchor, n_texts, model, QA_MODELS_ID) for model in select_models}
121
+ st.table(pd.DataFrame(results[select_models[0]]).T)
backend/config.py CHANGED
@@ -7,4 +7,8 @@ QA_MODELS_ID = dict(
7
  'flax-sentence-embeddings/multi-QA_v1-mpnet-asymmetric-A'],
8
  mpnet_qa='flax-sentence-embeddings/mpnet_stackexchange_v1',
9
  distilbert_qa = 'flax-sentence-embeddings/multi-qa_v1-distilbert-cls_dot'
 
 
 
 
10
  )
 
7
  'flax-sentence-embeddings/multi-QA_v1-mpnet-asymmetric-A'],
8
  mpnet_qa='flax-sentence-embeddings/mpnet_stackexchange_v1',
9
  distilbert_qa = 'flax-sentence-embeddings/multi-qa_v1-distilbert-cls_dot'
10
+ )
11
+
12
# Models selectable for the "Search" task, keyed by short name -> model id.
SEARCH_MODELS_ID = {
    'mpnet_qa': 'flax-sentence-embeddings/mpnet_stackexchange_v1',
}
backend/inference.py CHANGED
@@ -1,11 +1,16 @@
 
 
 
1
  import pandas as pd
 
2
  import jax.numpy as jnp
 
3
 
 
4
  from typing import List, Union
 
5
 
6
- # Defining cosine similarity using flax.
7
- from backend.config import MODELS_ID
8
- from backend.utils import load_model
9
 
10
 
11
  def cos_sim(a, b):
@@ -35,3 +40,34 @@ def text_similarity(anchor: str, inputs: List[str], model_name: str, model_dict:
35
  df = pd.DataFrame(d, columns=['inputs', 'score'])
36
 
37
  return df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gzip
2
+ import json
3
+
4
  import pandas as pd
5
+ import numpy as np
6
  import jax.numpy as jnp
7
+ import tqdm
8
 
9
+ from sentence_transformers import util
10
  from typing import List, Union
11
+ import torch
12
 
13
+ from backend.utils import load_model, filter_questions, load_embeddings
 
 
14
 
15
 
16
  def cos_sim(a, b):
 
40
  df = pd.DataFrame(d, columns=['inputs', 'score'])
41
 
42
  return df
43
+
44
+
45
+ # Search
46
def text_search(anchor: str, n_answers: int, model_name: str, model_dict: dict):
    """Return the stored questions that score highest against a query.

    The query is encoded with the requested model and scored (dot product)
    against the pre-generated corpus embeddings from ``load_embeddings()``.
    Each hit's ``corpus_id`` is then used as an index into
    ``filter_questions("python")`` to recover the post's title and id.

    NOTE(review): this assumes row ``i`` of the embeddings corresponds to
    entry ``i`` of the filtered posts -- confirm against how the embeddings
    file was generated.

    Args:
        anchor: The query text.
        n_answers: Number of hits to request (passed as ``top_k``).
        model_name: Key into ``model_dict``; only ``"mpnet_qa"`` is accepted.
        model_dict: Mapping forwarded to ``load_model``.

    Returns:
        A tuple ``(hits_titles, hits_scores, urls)`` of three parallel lists:
        post titles, scores formatted with three decimals, and
        ``https://stackoverflow.com/q/<id>`` links.

    Raises:
        ValueError: If ``model_name`` is not ``"mpnet_qa"``.
    """
    # Proceeding with model
    print(model_name)
    # Explicit check rather than `assert`, which is removed under `python -O`.
    if model_name != "mpnet_qa":
        raise ValueError(
            f"text_search only supports the 'mpnet_qa' model, got {model_name!r}"
        )
    model = load_model(model_name, model_dict)

    # Creating embeddings; [None, :] adds a leading axis so the single query
    # is passed to the search as a batch of one.
    query_emb = model.encode(anchor, convert_to_tensor=True)[None, :]

    print("loading embeddings")
    corpus_emb = load_embeddings()

    # Getting hits; [0] selects the result list of the only query.
    # top_k is coerced to int in case a UI widget hands over a float.
    hits = util.semantic_search(
        query_emb,
        corpus_emb,
        score_function=util.dot_score,
        top_k=int(n_answers),
    )[0]

    filtered_posts = filter_questions("python")
    print(f"{len(filtered_posts)} posts found with tag: python")

    hits_titles = []
    hits_scores = []
    urls = []
    for hit in hits:
        corpus_id = hit['corpus_id']
        # Skip hits that have no matching post instead of raising IndexError
        # when fewer posts than embedding rows were loaded.
        if corpus_id >= len(filtered_posts):
            continue
        post = filtered_posts[corpus_id]
        hits_titles.append(post['title'])
        hits_scores.append("{:.3f}".format(hit['score']))
        urls.append(f"https://stackoverflow.com/q/{post['id']}")

    return hits_titles, hits_scores, urls
backend/utils.py CHANGED
@@ -1,4 +1,10 @@
 
 
 
 
1
  import streamlit as st
 
 
2
  from sentence_transformers import SentenceTransformer
3
 
4
 
@@ -13,3 +19,28 @@ def load_model(model_name, model_dict):
13
  output = [SentenceTransformer(name) for name in model_ids]
14
 
15
  return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gzip
2
+ import json
3
+ import numpy as np
4
+
5
  import streamlit as st
6
+ import torch
7
+ import tqdm
8
  from sentence_transformers import SentenceTransformer
9
 
10
 
 
19
  output = [SentenceTransformer(name) for name in model_ids]
20
 
21
  return output
22
+
23
@st.cache(allow_output_mutation=True)
def load_embeddings():
    """Load the pre-generated corpus embeddings as a float32 torch tensor.

    Reads at most the first 10000 rows of the embeddings text file under
    ``./data`` (path is relative to the current working directory).
    """
    embeddings_path = './data/stackoverflow-titles-mpnet-emb.csv'
    raw_matrix = np.loadtxt(embeddings_path, max_rows=10000)
    # np.loadtxt yields float64 by default; .float() converts to float32.
    return torch.from_numpy(raw_matrix).float()
28
+
29
@st.cache(allow_output_mutation=True)
def filter_questions(tag, max_questions=10000):
    """Collect the first posts from the titles dump whose tags contain ``tag``.

    The gzip-compressed JSON-lines file is streamed once: each line is parsed,
    kept if ``tag in post["tags"]``, and reading stops as soon as
    ``max_questions`` posts were kept or 6,000,000 lines were read. This
    yields the same list as loading everything first and filtering afterwards,
    without holding every parsed post in memory.

    NOTE(review): whether ``post["tags"]`` is a list (exact tag match) or a
    string (substring match) is not visible here -- confirm against the data.

    Args:
        tag: Value tested for membership in each post's ``"tags"`` field.
        max_questions: Stop after this many matching posts.

    Returns:
        List of matching posts (parsed JSON objects) in file order.
    """
    # Import the submodule explicitly: a bare `import tqdm` is not guaranteed
    # to expose `tqdm.auto` unless some other module already imported it.
    from tqdm.auto import tqdm as progress_bar

    max_posts = 6_000_000  # upper bound on lines read; also the progress total
    filtered_posts = []
    with gzip.open("./data/stackoverflow-titles.jsonl.gz", "rt") as fIn:
        lines = progress_bar(fIn, total=max_posts, desc="Load data")
        for n_read, line in enumerate(lines, start=1):
            post = json.loads(line)
            if tag in post["tags"]:
                filtered_posts.append(post)
                if len(filtered_posts) >= max_questions:
                    break
            if n_read >= max_posts:
                break
    return filtered_posts
data/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data/__init__.py ADDED
File without changes
requirements.txt CHANGED
@@ -3,3 +3,5 @@ pandas
3
  jax
4
  jaxlib
5
  streamlit
 
 
 
3
  jax
4
  jaxlib
5
  streamlit
6
+ numpy
7
+ torch