pushpdeep committed
Commit c139b1e
1 Parent(s): 52e3664

Upload 7 files

.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ pib2022_23_cleaned.csv filter=lfs diff=lfs merge=lfs -text
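
(The new pattern routes the ~620 MB pib2022_23_cleaned.csv through Git LFS; a line like this is what "git lfs track pib2022_23_cleaned.csv" appends to .gitattributes.)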
app.py ADDED
@@ -0,0 +1,68 @@
+ import faiss
+ import pickle
+ import pandas as pd
+ import streamlit as st
+ from sentence_transformers import SentenceTransformer
+ from vector_engine.utils import vector_search
+
+
+ @st.cache_data
+ def read_data(pibdata="pib2022_23_cleaned.csv"):
+     """Read the PIB press-release data."""
+     return pd.read_csv(pibdata)
+
+
+ @st.cache_resource
+ def load_bert_model(name="pushpdeep/sbertmsmarco-en_to_indic_ur-murilv1"):
+     """Instantiate the sentence-transformer encoder."""
+     return SentenceTransformer(name)
+
+
+ @st.cache_resource
+ def load_faiss_index(path_to_faiss="models/faiss_index_ip.pickle"):
+     """Load and deserialize the FAISS index.
+     cache_resource rather than cache_data: a deserialized index is not picklable.
+     """
+     with open(path_to_faiss, "rb") as h:
+         data = pickle.load(h)
+     return faiss.deserialize_index(data)
+
+
+ def main():
+     # Load data and models
+     data = read_data()
+     model = load_bert_model()
+     faiss_index = load_faiss_index()
+
+     st.title("Vector-based search with Sentence Transformers and Faiss")
+
+     # User search
+     user_input = st.text_area("Search box", "Aatmanirbhar Bharat")
+
+     # Filters
+     st.sidebar.markdown("**Filters**")
+     # filter_year = st.sidebar.slider("Publication year", 2010, 2021, (2010, 2021), 1)
+     # filter_citations = st.sidebar.slider("Citations", 0, 250, 0)
+     num_results = st.sidebar.slider("Number of search results", 10, 50, 10)
+
+     # Fetch results
+     if user_input:
+         # Get the distances and record IDs of the nearest neighbours
+         D, I = vector_search([user_input], model, faiss_index, num_results)
+         frame = data
+         # Render each result, skipping IDs not present in the data
+         for id_ in I.flatten().tolist():
+             if id_ in set(frame.rid):
+                 f = frame[frame.rid == id_]
+             else:
+                 continue
+
+             st.write(
+                 f"""
+ **Language**: {f.iloc[0].language}
+ **Monthyear**: {f.iloc[0]['posted-on']}
+ **Abstract**
+ {f.iloc[0].body}
+ """
+             )
+
+
+ if __name__ == "__main__":
+     main()
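
Note: app.py expects models/faiss_index_ip.pickle (added below) to already hold a serialized inner-product index whose IDs line up with the rid column of the CSV. The commit does not include the build step; a minimal sketch of how such an index could be produced offline (hypothetical script, assuming numeric rid and text body columns, as app.py implies):

# build_index.py (hypothetical) -- build and serialize the inner-product FAISS index
import pickle

import faiss
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer

df = pd.read_csv("pib2022_23_cleaned.csv")
model = SentenceTransformer("pushpdeep/sbertmsmarco-en_to_indic_ur-murilv1")

# Encode the press-release bodies into dense vectors
embeddings = np.array(model.encode(df.body.tolist())).astype("float32")

# Inner-product index wrapped in an IDMap so search returns rid values directly
index = faiss.IndexIDMap(faiss.IndexFlatIP(embeddings.shape[1]))
index.add_with_ids(embeddings, df.rid.values.astype("int64"))

# Serialize to a byte array and pickle it, matching load_faiss_index() in app.py
with open("models/faiss_index_ip.pickle", "wb") as h:
    pickle.dump(faiss.serialize_index(index), h)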
models/faiss_index_ip.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:418bbb2ecd560a57a6007a1c6dfbbd6e48babe5e5aee0219d6b78c3c6ee0862e
+ size 271674732
pib2022_23_cleaned.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:08b9c8bc455941f30610fc05c588f46af2c769ca38d42b4919cebb895631351b
+ size 619820988
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ transformers
+ sentence-transformers
+ pandas
+ faiss-cpu
+ numpy
+ -e .
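
The trailing "-e ." installs the repository itself in editable mode so that app.py can resolve "from vector_engine.utils import vector_search"; that presupposes a build script at the repo root, which this commit does not include. A minimal hypothetical setup.py that would satisfy it:

# setup.py (hypothetical) -- makes the local vector_engine package installable via "-e ."
from setuptools import find_packages, setup

setup(
    name="vector_engine",
    version="0.1.0",
    packages=find_packages(),
)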
vector_engine/.DS_Store ADDED
Binary file (6.15 kB)
vector_engine/__init__.py ADDED
File without changes
vector_engine/utils.py ADDED
@@ -0,0 +1,24 @@
+ import numpy as np
+
+
+ def vector_search(query, model, index, num_results=10):
+     """Transforms the queries to vectors using a pretrained sentence-level
+     transformer model and finds similar vectors using FAISS.
+     Args:
+         query (list of str): User queries; pass a list, since a bare string
+             would be split into single characters by list() below.
+         model (sentence_transformers.SentenceTransformer): Encoder model.
+         index (faiss.Index): Deserialized FAISS index.
+         num_results (int): Number of results to return.
+     Returns:
+         D (:obj:`numpy.array` of `float`): Distance between results and query.
+         I (:obj:`numpy.array` of `int`): Record ID (rid) of the results.
+     """
+     vector = model.encode(list(query))
+     D, I = index.search(np.array(vector).astype("float32"), k=num_results)
+     return D, I
+
+
+ def id2details(df, I, column):
+     """Returns the given column's values for the first query's result IDs."""
+     return [list(df[df.rid == idx][column]) for idx in I[0]]
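
For reference, the two helpers compose as follows (a sketch, assuming the data, model, and faiss_index objects already loaded in app.py):

# Top-10 nearest press releases for one query, then their languages by rid
D, I = vector_search(["Aatmanirbhar Bharat"], model, faiss_index, num_results=10)
languages = id2details(data, I, "language")  # "language" is a column app.py also uses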