ssilwal committed
Commit
3975d16
1 Parent(s): 269f587

Upload 4 files

Files changed (5)
  1. .gitattributes +2 -0
  2. app.py +307 -0
  3. contexts-emb.txt +3 -0
  4. requirements.txt +5 -0
  5. synthetic-dataset.csv +3 -0
.gitattributes CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ contexts-emb.txt filter=lfs diff=lfs merge=lfs -text
+ synthetic-dataset.csv filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,307 @@
+ import os
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+ import numpy as np
+ import pandas as pd
+ import torch
+ # import faiss  # uncomment only if using the FAISS index path below
+ from sentence_transformers import util, LoggingHandler
+ from sentence_transformers.cross_encoder import CrossEncoder
+ import streamlit as st
+
+
+ def get_embeddings_from_contexts(model, contexts):
+     """
+     It takes a list of contexts and returns a list of embeddings.
+
+     :param model: the model you want to use to get the embeddings
+     :param contexts: a list of strings, each string is a context
+     :return: the embeddings of the contexts
+     """
+     return model.encode(contexts)
+
+ def load_semantic_search_model(model_name):
+     """
+     It loads the model.
+
+     :param model_name: the name of the model to load
+     :return: a SentenceTransformer object
+     """
+     from sentence_transformers import SentenceTransformer
+
+     return SentenceTransformer(model_name)
+
+
+ def convert_embeddings_to_faiss_index(embeddings, context_ids):
+     """
+     Takes a list of embeddings and a list of context IDs, converts the embeddings to a
+     numpy array, instantiates a flat inner-product index, wraps it in IndexIDMap, adds
+     the embeddings and their IDs to the index, instantiates the GPU resources, and
+     moves the index to the GPU. Requires the `import faiss` at the top of the file to
+     be uncommented, plus a GPU build of FAISS.
+
+     :param embeddings: the embeddings you want to convert to a faiss index
+     :param context_ids: the IDs of the contexts
+     :return: a GPU index
+     """
+     embeddings = np.array(embeddings).astype("float32")  # Step 1: Change data type
+
+     index = faiss.IndexFlatIP(embeddings.shape[1])  # Step 2: Instantiate the index
+     index = faiss.IndexIDMap(index)  # Step 3: Pass the index to IndexIDMap
+
+     index.add_with_ids(embeddings, context_ids)  # Step 4: Add vectors and their IDs
+
+     res = faiss.StandardGpuResources()  # Step 5: Instantiate the resources
+     gpu_index = faiss.index_cpu_to_gpu(res, 0, index)  # Step 6: Move the index to the GPU
+     return gpu_index
+
+
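Since every model in this Space is loaded with device='cpu', the GPU transfer above would fail on the Space itself. A minimal CPU-only sketch of the same index, assuming faiss-cpu is installed (the name convert_embeddings_to_faiss_index_cpu is hypothetical, not part of this commit):

    import faiss
    import numpy as np

    def convert_embeddings_to_faiss_index_cpu(embeddings, context_ids):
        # Same inner-product flat index wrapped in an ID map, minus the GPU move.
        embeddings = np.array(embeddings).astype("float32")
        index = faiss.IndexIDMap(faiss.IndexFlatIP(embeddings.shape[1]))
        index.add_with_ids(embeddings, np.array(context_ids).astype("int64"))
        return index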
+ def vector_search(query, model, index, num_results=20):
+     """Transforms the query to a vector using a pretrained, sentence-level
+     model and finds similar vectors using FAISS.
+     """
+     vector = model.encode(list(query))
+     D, I = index.search(np.array(vector).astype("float32"), k=num_results)
+     return D, I
+
+
+ def id2details(df, I, column):
+     """Returns the requested column values for the row indices in I."""
+     return [list(df[df.index.values == idx][column])[0] for idx in I[0]]
+
+
+ def combine(user_query, model, index, df, column, num_results=10):
+     """
+     Takes a user query, a model, an index, a dataframe, and a column name, and returns
+     the top results from the dataframe.
+
+     :param user_query: the query you want to search for
+     :param model: the sentence-embedding model
+     :param index: the FAISS index over the vectorized dataframe
+     :param df: the dataframe containing the data
+     :param column: the column in the dataframe that contains the text you want to search
+     :param num_results: the number of results to return, defaults to 10 (optional)
+     :return: the top num_results results from the vector search
+     """
+     D, I = vector_search([user_query], model, index, num_results=num_results)
+     return id2details(df, I, column)
+
+
+ def get_context(model, query, contexts, contexts_emb, top_k=100):
+     """
+     Given a query, a set of contexts, and their embeddings, return the top_k contexts
+     with the highest similarity score.
+
+     :param model: the sentence-embedding model
+     :param query: the query string
+     :param contexts: DataFrame of contexts with a 'premise' column
+     :param contexts_emb: the embeddings of the contexts
+     :param top_k: the number of contexts to return, defaults to 100 (optional)
+     :return: a list of the top_k contexts most similar to the query
+     """
+     # Encode the query; the context embeddings are precomputed
+     query_emb = model.encode(query)
+     query_emb = torch.from_numpy(query_emb.reshape(1, -1))
+     contexts_emb = torch.from_numpy(contexts_emb)
+     # Compute similarity score between the query and all context embeddings
+     scores = util.cos_sim(query_emb, contexts_emb)[0].cpu().tolist()
+     # Combine contexts & scores, keep the top_k highest-scoring contexts
+     contexts_score_pairs = list(zip(contexts.premise.tolist(), scores))
+     result = sorted(contexts_score_pairs, key=lambda x: x[1], reverse=True)[:top_k]
+     top_context = [c for c, s in result]
+     return top_context
+
+
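get_context is plain cosine-similarity ranking over the precomputed embeddings, so it can be exercised on its own with the objects that load_models (below) returns — a sketch with illustrative arguments:

    top = get_context(
        semantic_search_model,  # SentenceTransformer from load_models()
        'Quelles protections la Loi sur la protection du consommateur accorde-t-elle?',
        contexts,               # DataFrame with a 'premise' column
        context_emb,            # float32 array loaded from contexts-emb.txt
        top_k=5,
    )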
+ def get_answer(model, query, context):
+     """
+     Given a generative model, a query, and a context, return the answer.
+     (Not called anywhere else in this file.)
+
+     :param model: a text-generation pipeline
+     :param query: the question you want to ask
+     :param context: the context for the question
+     :return: a string
+     """
+     formatted_query = f"{query}\n{context}"
+     res = model(formatted_query)
+     return res[0]["generated_text"]
+
+
+ def evaluate_semantic_model(model, question, contexts, contexts_emb, index=None):
+     """
+     For a given question, use the model to find the most similar contexts.
+
+     :param model: the semantic-search model being evaluated
+     :param question: the question string
+     :param contexts: the list of contexts
+     :param contexts_emb: the embeddings of the contexts
+     :param index: an optional FAISS index over the context embeddings
+     :return: the predicted contexts
+     """
+     # FAISS retrieval if an index is given, otherwise cosine similarity
+     predictions = combine(question, model, index, contexts, "premise") if index else get_context(model, question, contexts, contexts_emb)
+     return predictions
+
+
+ @st.experimental_singleton
+ def load_models():
+     semantic_search_model = load_semantic_search_model("distiluse-base-multilingual-cased-v1")
+     model_nli_stsb = CrossEncoder('ssilwal/nli-stsb-fr', max_length=512, device='cpu')
+     model_nli = CrossEncoder('ssilwal/CASS-civile-nli', max_length=512, device='cpu')
+     model_baseline = CrossEncoder('amberoad/bert-multilingual-passage-reranking-msmarco', max_length=512, device='cpu')
+
+     df = pd.read_csv('synthetic-dataset.csv')
+     contexts = pd.DataFrame(df.premise.unique(), columns=['premise'])
+     context_emb = np.loadtxt('contexts-emb.txt', dtype=np.float32)
+
+     return semantic_search_model, model_nli, model_nli_stsb, model_baseline, contexts, context_emb
+
+
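load_models expects contexts-emb.txt to already hold one embedding row per unique premise. A plausible one-off, offline step to produce it with the helpers above (an assumption — the generation script is not part of this commit):

    model = load_semantic_search_model('distiluse-base-multilingual-cased-v1')
    df = pd.read_csv('synthetic-dataset.csv')
    emb = get_embeddings_from_contexts(model, df.premise.unique().tolist())
    np.savetxt('contexts-emb.txt', emb)  # read back later via np.loadtxt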
+ def callback(state, value):
+     # Placeholder widget callback; intentionally does nothing.
+     return
+
+
+ if 'slider' not in st.session_state:
+     st.session_state['slider'] = 0
+
+ if 'radio' not in st.session_state:
+     st.session_state['radio'] = 'Model 1'
+
+ if 'show' not in st.session_state:
+     st.session_state['show'] = False
+
+ if 'results' not in st.session_state:
+     st.session_state['results'] = None
+
+
+ semantic_search_model, model_nli, model_nli_stsb, model_baseline, contexts, context_emb = load_models()
+
+ @st.cache(suppress_st_warning=True)
+ def run_inference(model_name, query):
+     # Note: top_K and the models/contexts come from the enclosing scope, so
+     # st.cache keys only on (model_name, query); changing top_K alone will
+     # not invalidate the cache.
+     pred = evaluate_semantic_model(
+         semantic_search_model,
+         query,
+         contexts,
+         context_emb,
+         # index,  # pass the FAISS index here if you want to use FAISS
+     )
+
+     # Create the respective (query, candidate) sentence combinations
+     sentence_combinations = [[query, corpus_sentence] for corpus_sentence in pred]
+
+     # Compute the similarity scores for these combinations and re-rank
+     if model_name == 'Model 1':
+         similarity_scores = model_nli.predict(sentence_combinations)
+         scores = [(score_max[0], idx) for idx, score_max in enumerate(similarity_scores)]
+         sim_scores_argsort = sorted(scores, key=lambda x: x[0], reverse=True)
+         results = [pred[idx] for _, idx in list(sim_scores_argsort)[:int(top_K)]]
+
+     if model_name == 'Model 2':
+         similarity_scores = model_nli_stsb.predict(sentence_combinations)
+         sim_scores_argsort = reversed(np.argsort(similarity_scores))
+         results = [pred[idx] for idx in list(sim_scores_argsort)[:int(top_K)]]
+
+     if model_name == 'Model 3':
+         similarity_scores = model_baseline.predict(sentence_combinations)
+         scores = [(score_max[0], idx) for idx, score_max in enumerate(similarity_scores)]
+         sim_scores_argsort = sorted(scores, key=lambda x: x[0], reverse=True)
+         results = [pred[idx] for _, idx in list(sim_scores_argsort)[:int(top_K)]]
+
+     return results
+
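The three branches differ only in how the CrossEncoder output is read: the NLI and baseline models emit a score vector per pair (hence score_max[0]), while the STSb model emits a single scalar per pair. A minimal sketch of the shared re-ranking step, assuming a scalar-scoring cross encoder (rerank is a hypothetical helper, not part of this commit):

    import numpy as np
    from sentence_transformers.cross_encoder import CrossEncoder

    def rerank(cross_encoder, query, candidates, top_k=10):
        # Score every (query, candidate) pair, sort candidates by
        # descending cross-encoder score, and keep the top_k.
        pairs = [[query, c] for c in candidates]
        scores = cross_encoder.predict(pairs)
        order = np.argsort(scores)[::-1]
        return [candidates[i] for i in order[:top_k]]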
+
+ # Only needed for the FAISS index path:
+ # index = convert_embeddings_to_faiss_index(context_emb, contexts.index.values)
+
+
+ query = st.text_input('Civil Legal Query', 'Quelles protections la Loi sur la protection du consommateur accorde-t-elle aux individus?')
+ top_K = st.text_input('Choose number of results:', '10')
+
+
+ model_name = st.radio(
+     "Choose Model",
+     ("Model 1", "Model 2", "Model 3"),
+     key='radio', on_change=callback, args=('radio', 'Model 1')
+ )
+
+
+ if st.button('Run', key='run'):
+     results = run_inference(model_name, query)
+
+     st.session_state['show'] = True
+     st.session_state['results'] = results
+     st.session_state['query'] = query
+     model_dict = {'Model 1': 'NLI-Syn', 'Model 2': 'NLI-stsb', 'Model 3': 'NLI-baseline'}
+     st.session_state['model'] = model_dict[model_name]
+
+
+ if st.session_state['show'] and st.session_state['results'] is not None:
+     st.write("-" * 50)
+     for result in st.session_state['results']:
+         line = f'Context: {result}\n\n'
+         st.write(line)
+
+     rate = st.slider('Please rate this output', min_value=0, max_value=5, key='slider', on_change=callback, args=('slider', '0'))
+
+     if st.session_state['slider'] != 0:
+         rate = st.session_state['slider']
+         st.write(f'You rated {rate}')
+
+
+ if st.button('Submit', key='rate'):
+     if st.session_state['results'] is not None:
+         import json
+         item = {'query': st.session_state['query'], 'results': st.session_state['results'], 'model': st.session_state['model'], 'rating': st.session_state['slider']}
+         try:
+             # Append to the existing ratings archive if it exists
+             with open('human.json', 'r') as file:
+                 archive = json.load(file)
+             archive.append(item)
+             with open('human.json', 'w') as file:
+                 json.dump(archive, file)
+         except FileNotFoundError:
+             # First rating: create the archive
+             with open('human.json', 'w') as file:
+                 json.dump([item], file)
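For reference, each submitted rating lands in human.json as one record of this shape (values illustrative):

    {
        'query': 'Quelles protections ...',
        'results': ['<context 1>', '<context 2>'],
        'model': 'NLI-Syn',
        'rating': 4,
    }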
contexts-emb.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3d46f89976c4ea5e8c573950b51c94db43b03d231c557096a8273d75cb506576
+ size 76331907
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ numpy
+ pandas
+ torch
+ sentence_transformers
+ streamlit
synthetic-dataset.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9b05dffb7e3b522a85fe20263c22ab91430f6e9c535705515dd6bf869a20199d
+ size 38688491