devtrent committed on
Commit
883e41e
1 Parent(s): 75c3a89

Clustering function

Browse files
Files changed (3) hide show
  1. app.py +5 -1
  2. backend/inference.py +68 -1
  3. requirements.txt +2 -0
app.py CHANGED
@@ -118,4 +118,8 @@ For more cool information on sentence embeddings, see the [sBert project](https:
118
 
119
  if st.button('Give me my search.'):
120
  results = {model: inference.text_search(anchor, n_texts, model, QA_MODELS_ID) for model in select_models}
121
- st.table(pd.DataFrame(results[select_models[0]]).T)
 
 
 
 
 
118
 
119
  if st.button('Give me my search.'):
120
  results = {model: inference.text_search(anchor, n_texts, model, QA_MODELS_ID) for model in select_models}
121
+ st.table(pd.DataFrame(results[select_models[0]]).T)
122
+
123
+ if st.button('3D Clustering of search result (new window)'):
124
+ fig = inference.text_cluster(anchor, 1000, select_models[0], QA_MODELS_ID)
125
+ fig.show()
backend/inference.py CHANGED
@@ -1,5 +1,6 @@
1
  import gzip
2
  import json
 
3
 
4
  import pandas as pd
5
  import numpy as np
@@ -11,7 +12,7 @@ from typing import List, Union
11
  import torch
12
 
13
  from backend.utils import load_model, filter_questions, load_embeddings
14
-
15
 
16
  def cos_sim(a, b):
17
  return jnp.matmul(a, jnp.transpose(b)) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))
@@ -71,3 +72,69 @@ def text_search(anchor: str, n_answers: int, model_name: str, model_dict: dict):
71
  urls.append(f"https://stackoverflow.com/q/{post['id']}")
72
 
73
  return hits_titles, hits_scores, urls
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gzip
2
  import json
3
+ from collections import Counter
4
 
5
  import pandas as pd
6
  import numpy as np
 
12
  import torch
13
 
14
  from backend.utils import load_model, filter_questions, load_embeddings
15
+ from MulticoreTSNE import MulticoreTSNE as TSNE
16
 
17
  def cos_sim(a, b):
18
  return jnp.matmul(a, jnp.transpose(b)) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))
 
72
  urls.append(f"https://stackoverflow.com/q/{post['id']}")
73
 
74
  return hits_titles, hits_scores, urls
75
+
76
+
77
def text_cluster(anchor: str, n_answers: int, model_name: str, model_dict: dict):
    """Build a 3D t-SNE scatter plot of the posts closest to a query.

    The query is embedded with the requested model, the `n_answers` best
    corpus hits are retrieved by dot-product score, and the hit embeddings
    plus the query embedding are projected to three dimensions with t-SNE.
    Each point is labelled with the first of the frequent tags found on the
    post, or 'others'; the query is the last row, labelled 'QUERY' and drawn
    larger than the hits.

    Args:
        anchor: The query text.
        n_answers: Number of corpus hits to retrieve and plot.
        model_name: Model identifier; only "mpnet_qa" is accepted.
        model_dict: Model lookup passed through to `load_model`.

    Returns:
        The Plotly figure produced by `plotly.express.scatter_3d`.

    Raises:
        ValueError: If `model_name` is not "mpnet_qa".
    """
    # Proceeding with model
    print(model_name)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if model_name != "mpnet_qa":
        raise ValueError(
            f"text_cluster only supports the 'mpnet_qa' model, got {model_name!r}"
        )
    model = load_model(model_name, model_dict)

    # Creating embeddings; [None, :] adds a leading axis so the query can be
    # concatenated with the stacked hit embeddings below.
    query_emb = model.encode(anchor, convert_to_tensor=True)[None, :]

    print("loading embeddings")
    corpus_emb = load_embeddings()

    # Getting hits
    hits = util.semantic_search(
        query_emb, corpus_emb, score_function=util.dot_score, top_k=n_answers
    )[0]

    filtered_posts = filter_questions("python")

    # The query is appended as a pseudo-post so it is plotted with the hits.
    posts = [filtered_posts[hit['corpus_id']] for hit in hits]
    posts.append(dict(id='1', title=anchor, tags=['']))

    hits_emb = torch.stack([corpus_emb[hit['corpus_id']] for hit in hits])
    hits_emb = torch.cat((hits_emb, query_emb))

    # Dimensionality reduction with t-SNE
    tsne = TSNE(n_components=3, verbose=1, perplexity=15, n_iter=1000)
    tsne_results = tsne.fit_transform(hits_emb.cpu())

    # Count every tag across all plotted posts.
    tag_counts = Counter(tag for post in posts for tag in post['tags'])
    # Skip the single most frequent tag and keep the next four.
    # NOTE(review): presumably the top tag is 'python', shared by every post
    # because the corpus is filtered on it - confirm.
    most_common_tags = [tag for tag, _ in tag_counts.most_common()[1:5]]

    # One label per post: the first frequent tag it carries, else 'others'.
    # Using a default here also covers the case where no frequent tags exist.
    labels = [
        next((tag for tag in most_common_tags if tag in post['tags']), 'others')
        for post in posts
    ]
    sizes = [2] * len(posts)

    # Making the query (last row) bigger than the rest of the observations.
    # Set on the plain lists before building the frame to avoid chained
    # DataFrame assignment, which does not write through under Copy-on-Write.
    sizes[-1] = 10
    labels[-1] = 'QUERY'

    df = pd.DataFrame(posts)
    df['title'] = [post['title'] for post in posts]
    df['labels'] = labels
    df['tsne_x'] = tsne_results[:, 0]
    df['tsne_y'] = tsne_results[:, 1]
    df['tsne_z'] = tsne_results[:, 2]
    df['size'] = sizes

    import plotly.express as px

    fig = px.scatter_3d(
        df,
        x='tsne_x',
        y='tsne_y',
        z='tsne_z',
        color='labels',
        size='size',
        color_discrete_sequence=px.colors.qualitative.D3,
        hover_data=['title'],
    )
    return fig
requirements.txt CHANGED
@@ -5,3 +5,5 @@ jaxlib
5
  streamlit
6
  numpy
7
  torch
 
 
 
5
  streamlit
6
  numpy
7
  torch
8
+ MulticoreTSNE
9
+ plotly