devtrent commited on
Commit
a41bdbc
1 Parent(s): 49438d6

Multi model select and local model loading

Browse files
__init__.py ADDED
File without changes
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import streamlit as st
2
  import pandas as pd
3
- import base64
4
- import requests
 
5
 
6
  st.title('Demo using Flax-Sentence-Tranformers')
7
 
@@ -20,12 +21,12 @@ For more cool information on sentence embeddings, see the [sBert project](https:
20
  Please enjoy!!
21
  ''')
22
 
23
-
24
  anchor = st.text_input(
25
  'Please enter here the main text you want to compare:'
26
  )
27
 
28
  if anchor:
 
29
  n_texts = st.sidebar.number_input(
30
  f'''How many texts you want to compare with: '{anchor}'?''',
31
  value=2,
@@ -34,40 +35,21 @@ if anchor:
34
  inputs = []
35
 
36
  for i in range(n_texts):
37
-
38
- input = st.sidebar.text_input(f'Text {i+1}:')
39
 
40
  inputs.append(input)
41
 
42
-
43
-
44
- api_base_url = 'http://127.0.0.1:8000/similarity'
45
-
46
  if anchor:
47
  if st.sidebar.button('Tell me the similarity.'):
48
- res_distilroberta = requests.get(url = api_base_url, params = dict(anchor = anchor,
49
- inputs = inputs,
50
- model = 'distilroberta'))
51
- res_mpnet = requests.get(url = api_base_url, params = dict(anchor = anchor,
52
- inputs = inputs,
53
- model = 'mpnet'))
54
- res_minilm_l6 = requests.get(url = api_base_url, params = dict(anchor = anchor,
55
- inputs = inputs,
56
- model = 'minilm_l6'))
57
-
58
- d_distilroberta = res_distilroberta.json()['dataframe']
59
- d_mpnet = res_mpnet.json()['dataframe']
60
- d_minilm_l6 = res_minilm_l6.json()['dataframe']
61
-
62
- index = list(d_distilroberta['inputs'].values())
63
  df_total = pd.DataFrame(index=index)
64
- df_total['distilroberta'] = list(d_distilroberta['score'].values())
65
- df_total['mpnet'] = list(d_mpnet['score'].values())
66
- df_total['minilm_l6'] = list(d_minilm_l6['score'].values())
67
 
68
- st.write('Here are the results for our three models:')
69
  st.write(df_total)
70
  st.write('Visualize the results of each model:')
71
  st.area_chart(df_total)
72
-
73
-
1
  import streamlit as st
2
  import pandas as pd
3
+
4
+ from backend import inference
5
+ from backend.config import MODELS_ID
6
 
7
  st.title('Demo using Flax-Sentence-Tranformers')
8
 
21
  Please enjoy!!
22
  ''')
23
 
 
24
  anchor = st.text_input(
25
  'Please enter here the main text you want to compare:'
26
  )
27
 
28
  if anchor:
29
+ select_models = st.sidebar.multiselect("Choose models", options=MODELS_ID.keys())
30
  n_texts = st.sidebar.number_input(
31
  f'''How many texts you want to compare with: '{anchor}'?''',
32
  value=2,
35
  inputs = []
36
 
37
  for i in range(n_texts):
38
+ input = st.sidebar.text_input(f'Text {i + 1}:')
 
39
 
40
  inputs.append(input)
41
 
 
 
 
 
42
  if anchor:
43
  if st.sidebar.button('Tell me the similarity.'):
44
+ results = {model: inference.text_similarity(anchor, inputs, model) for model in select_models}
45
+ df_results = {model: results[model] for model in results}
46
+
47
+ index = inputs
 
 
 
 
 
 
 
 
 
 
 
48
  df_total = pd.DataFrame(index=index)
49
+ for key, value in df_results.items():
50
+ df_total[key] = list(value['score'].values)
 
51
 
52
+ st.write('Here are the results for selected models:')
53
  st.write(df_total)
54
  st.write('Visualize the results of each model:')
55
  st.area_chart(df_total)
 
 
backend/__init__.py ADDED
File without changes
backend/config.py CHANGED
@@ -1,3 +1,4 @@
1
  MODELS_ID = dict(distilroberta = 'flax-sentence-embeddings/st-codesearch-distilroberta-base',
2
  mpnet = 'flax-sentence-embeddings/all_datasets_v3_mpnet-base',
 
3
  minilm_l6 = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L6')
1
  MODELS_ID = dict(distilroberta = 'flax-sentence-embeddings/st-codesearch-distilroberta-base',
2
  mpnet = 'flax-sentence-embeddings/all_datasets_v3_mpnet-base',
3
+ mpnet_qa = 'flax-sentence-embeddings/mpnet_stackexchange_v1',
4
  minilm_l6 = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L6')
backend/inference.py CHANGED
@@ -1,41 +1,30 @@
1
- from sentence_transformers import SentenceTransformer
2
  import pandas as pd
3
  import jax.numpy as jnp
4
 
5
  from typing import List
6
- import config
7
-
8
- # We download the models we will be using.
9
- # If you do not want to use all, you can comment the unused ones.
10
- distilroberta_model = SentenceTransformer(config.MODELS_ID['distilroberta'])
11
- mpnet_model = SentenceTransformer(config.MODELS_ID['mpnet'])
12
- minilm_l6_model = SentenceTransformer(config.MODELS_ID['minilm_l6'])
13
 
14
  # Defining cosine similarity using flax.
 
 
 
15
  def cos_sim(a, b):
16
- return jnp.matmul(a, jnp.transpose(b))/(jnp.linalg.norm(a)*jnp.linalg.norm(b))
17
 
18
 
19
  # We get similarity between embeddings.
20
- def text_similarity(anchor: str, inputs: List[str], model: str = 'distilroberta'):
 
21
 
22
  # Creating embeddings
23
- if model == 'distilroberta':
24
- anchor_emb = distilroberta_model.encode(anchor)[None, :]
25
- inputs_emb = distilroberta_model.encode([input for input in inputs])
26
- elif model == 'mpnet':
27
- anchor_emb = mpnet_model.encode(anchor)[None, :]
28
- inputs_emb = mpnet_model.encode([input for input in inputs])
29
- elif model == 'minilm_l6':
30
- anchor_emb = minilm_l6_model.encode(anchor)[None, :]
31
- inputs_emb = minilm_l6_model.encode([input for input in inputs])
32
 
33
  # Obtaining similarity
34
  similarity = list(jnp.squeeze(cos_sim(anchor_emb, inputs_emb)))
35
 
36
  # Returning a Pandas' dataframe
37
  d = {'inputs': [input for input in inputs],
38
- 'score': [round(similarity[i],3) for i in range(len(similarity))]}
39
  df = pd.DataFrame(d, columns=['inputs', 'score'])
40
 
41
  return df.sort_values('score', ascending=False)
 
1
  import pandas as pd
2
  import jax.numpy as jnp
3
 
4
  from typing import List
 
 
 
 
 
 
 
5
 
6
  # Defining cosine similarity using flax.
7
+ from backend.utils import load_model
8
+
9
+
10
  def cos_sim(a, b):
11
+ return jnp.matmul(a, jnp.transpose(b)) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))
12
 
13
 
14
  # We get similarity between embeddings.
15
+ def text_similarity(anchor: str, inputs: List[str], model_name: str):
16
+ model = load_model(model_name)
17
 
18
  # Creating embeddings
19
+ anchor_emb = model.encode(anchor)[None, :]
20
+ inputs_emb = model.encode([input for input in inputs])
 
 
 
 
 
 
 
21
 
22
  # Obtaining similarity
23
  similarity = list(jnp.squeeze(cos_sim(anchor_emb, inputs_emb)))
24
 
25
  # Returning a Pandas' dataframe
26
  d = {'inputs': [input for input in inputs],
27
+ 'score': [round(similarity[i], 3) for i in range(len(similarity))]}
28
  df = pd.DataFrame(d, columns=['inputs', 'score'])
29
 
30
  return df.sort_values('score', ascending=False)
backend/main.py DELETED
@@ -1,19 +0,0 @@
1
- from fastapi import Query, FastAPI
2
-
3
- import config
4
- import inference
5
- from typing import List
6
-
7
- app = FastAPI()
8
-
9
- @app.get("/")
10
- def read_root():
11
- return {"message": "Welcome to the API of flax-sentence-embeddings."}
12
-
13
- @app.get('/similarity')
14
- def get_similarity(anchor: str, inputs: List[str] = Query([]), model: str = 'distilroberta'):
15
- return {'dataframe': inference.text_similarity(anchor, inputs, model)}
16
-
17
-
18
- #if __name__ == "__main__":
19
- # uvicorn.run("main:app", host="0.0.0.0", port=8080)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
backend/utils.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from sentence_transformers import SentenceTransformer
3
+ from .config import MODELS_ID
4
+
5
+
6
+ @st.cache(allow_output_mutation=True)
7
+ def load_model(model_name):
8
+ assert model_name in MODELS_ID.keys()
9
+ # Lazy downloading
10
+ model = SentenceTransformer(MODELS_ID[model_name])
11
+ return model
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
- fastapi
2
  sentence_transformers
3
  pandas
4
  jax
 
5
  streamlit
 
1
  sentence_transformers
2
  pandas
3
  jax
4
+ jaxlib
5
  streamlit